In [36]:
import os
import pandas as pd
import numpy as np
import os.path as osp
import spatialdata as sd
from skimage.measure import regionprops
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch.optim as optim
import torch.nn as nn
from geomloss import SamplesLoss
import torch.nn.functional as F

from config import DATA_PATH
from data import patch_john, patch_harry
from models import utils as model_utils
from eval import utils as eval_utils

# huggingface
from dotenv import load_dotenv
from huggingface_hub import login
# Load environment variables from the user's env file, then authenticate with
# the Hugging Face hub using the API_TOKEN entry (None if the variable is unset).
env_file = os.path.expanduser('~/hl/.gutinstinct.env')
load_dotenv(dotenv_path=env_file)
api_token = os.environ.get("API_TOKEN")
login(token=api_token)


In [39]:
# Scratch check of the metric helper, left over from debugging.
# NOTE(review): y_true_all / y_pred_all are only assigned inside the training
# cell further down, so this cell cannot run on a fresh kernel before that one.
# The recorded output is an AttributeError saying eval.utils has no
# compute_r2_spearman — presumably a stale import; confirm the helper exists
# and restart/reload the kernel before relying on this cell.
# eval_utils.l1_normalize(1)
eval_utils.compute_r2_spearman(y_true_all, y_pred_all)

AttributeError: module 'eval.utils' has no attribute 'compute_r2_spearman'

In [2]:
# Resolve the UC6_I zarr store under DATA_PATH (with ~ expanded) and read it
# with spatialdata.
data_root = osp.expanduser(DATA_PATH)
zarr_path = osp.join(data_root, "UC6_I.zarr/UC6_I.zarr")
sdata = sd.read_zarr(zarr_path)

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


In [14]:
# Build the patch datasets from the loaded sdata and unpack the three results
# returned by patch_harry.get_patches into the train / val / test names.
patch_splits = patch_harry.get_patches(sdata)
dataset_patch_train, dataset_patch_val, dataset_patch_test = patch_splits

In [9]:
# Compute the expression targets for each patch dataset, in train / val / test
# order, by calling model_utils.get_expression once per split.
# NOTE(review): the recorded execution counts show this cell (In [9]) ran
# before the patch cell (In [14]); re-run it after regenerating the patches so
# the targets match the current datasets.
dataset_expression_train, dataset_expression_val, dataset_expression_test = (
    model_utils.get_expression(sdata, split_patches)
    for split_patches in (dataset_patch_train, dataset_patch_val, dataset_patch_test)
)

In [15]:
# Loaders over the image patch datasets; these feed the feature extractor.
# Shuffling must stay off here: the extracted embeddings are later paired
# row-by-row with the dataset_expression_* targets (which are computed from the
# datasets themselves, not from these loaders), so any reordering at this stage
# would silently mismatch inputs and targets.
train_loader = DataLoader(dataset_patch_train, batch_size=32, shuffle=False)
val_loader = DataLoader(dataset_patch_val, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset_patch_test, batch_size=32, shuffle=False)

In [16]:
## instantiate a uni model
# Architecture arguments forwarded to timm.create_model for the UNI2-h
# checkpoint pulled from the Hugging Face hub.
timm_kwargs = dict(
    img_size=224,
    patch_size=32,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667*2,
    num_classes=0,
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)

# Create the pretrained model and move it to the GPU in one step.
model_uni = timm.create_model(
    "hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs
).to('cuda')

# Preprocessing transform derived from the model's pretrained config.
# NOTE(review): `transform` is not referenced by any other cell visible here.
data_config = resolve_data_config(model_uni.pretrained_cfg, model=model_uni)
transform = create_transform(**data_config)

# Switch to inference mode; as the last expression this also echoes the model.
model_uni.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(32, 32), stride=(32, 32))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
    

In [17]:
# Run every split's patch loader through model_uni via
# model_utils.extract_features, unpacking each (embeddings, cell_ids) pair.
# The splits are processed in train / val / test order.
# NOTE(review): the returned *_cell_ids are not used by any later cell visible
# here; they look like the natural key for checking embedding/target alignment.
(
    (train_embeddings, train_cell_ids),
    (val_embeddings, val_cell_ids),
    (test_embeddings, test_cell_ids),
) = (
    model_utils.extract_features(patch_loader, model_uni)
    for patch_loader in (train_loader, val_loader, test_loader)
)

In [18]:
# Pair each split's embeddings with its expression targets as float32 tensors.
# TensorDataset matches rows by index, so both inputs must be in the same order.
train_dataset = TensorDataset(
    torch.tensor(train_embeddings, dtype=torch.float32),
    torch.tensor(dataset_expression_train, dtype=torch.float32)
)

val_dataset = TensorDataset(
    torch.tensor(val_embeddings, dtype=torch.float32),
    torch.tensor(dataset_expression_val, dtype=torch.float32)
)

test_dataset = TensorDataset(
    torch.tensor(test_embeddings, dtype=torch.float32),
    torch.tensor(dataset_expression_test, dtype=torch.float32)
)

# create dataloader
# Note: these rebind train_loader / val_loader / test_loader, which previously
# pointed at the image patch loaders.
# Shuffle only the training split so each epoch sees a different batch order;
# validation and test stay in a fixed order for reproducible metrics.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [35]:
# fit fcn model
# Small MLP head trained on the pre-extracted embeddings to predict the
# expression targets. The loss mixes a point-wise MSE term with a Sinkhorn
# divergence between the L1-normalised prediction and target.
# NOTE(review): the recorded output of this cell is an AttributeError for
# eval_utils.l1_normalize — presumably the module changed after it was
# imported; restart the kernel (or reload eval.utils) before running.
feature_dim = train_embeddings.shape[1]
model = nn.Sequential(
    nn.Linear(feature_dim, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 460)  # NOTE(review): 460 presumably equals the target width — confirm
)
model.to('cuda')

# some setups
criterion = SamplesLoss("sinkhorn", p=2, blur=0.05)
lmd = 0.7        # weight of the point-wise MSE term; (1 - lmd) goes to the Sinkhorn term
optimizer = optim.Adam(model.parameters(), lr=1e-5)
train_losses = []
val_losses = []

# es logic
es = None                   # early-stopping switch: None disables it, any other value enables it
best_r2 = float('-inf')     # higher R2 is better, so start from the worst possible value
best_model_state = None     # snapshot of the weights at the best validation R2 so far
patience = 5
counter = 0

for epoch in range(20):
    model.train()
    running_train_loss = 0.0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to('cuda'), y_batch.to('cuda')
        optimizer.zero_grad()
        y_pred = model(x_batch)
        # Normalise along dim=1, exactly as the validation pass below does.
        y_pred_dist = eval_utils.l1_normalize(y_pred, dim=1)
        y_true_dist = eval_utils.l1_normalize(y_batch, dim=1)

        # use both point-wise and distributional loss
        mse_loss = F.mse_loss(y_pred, y_batch)
        dist_loss = criterion(y_pred_dist, y_true_dist)
        loss = lmd * mse_loss + (1 - lmd) * dist_loss

        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()

    train_losses.append(running_train_loss / len(train_loader))

    # validation
    model.eval()
    running_val_loss = 0.0
    y_true_all, y_pred_all = [], []
    with torch.no_grad():
        for x_val, y_val in val_loader:
            x_val, y_val = x_val.to('cuda'), y_val.to('cuda')
            y_pred_val = model(x_val)
            y_pred_dist_val = eval_utils.l1_normalize(y_pred_val, dim=1)
            y_true_dist_val = eval_utils.l1_normalize(y_val, dim=1)

            # use both point-wise and distributional loss
            mse_loss = F.mse_loss(y_pred_val, y_val)
            dist_loss = criterion(y_pred_dist_val, y_true_dist_val)
            val_loss = lmd * mse_loss + (1 - lmd) * dist_loss

            running_val_loss += val_loss.item()

            y_true_all.append(y_val.detach().cpu().numpy())
            y_pred_all.append(y_pred_val.detach().cpu().numpy())

    val_losses.append(running_val_loss / len(val_loader))
    val_r2, val_spearman = eval_utils.compute_r2_spearman(y_true_all, y_pred_all)

    print(f"Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, Val R2: {val_r2:.4f}, Val Spearman: {val_spearman:.4f}")

    # Track the best epoch by validation R2 (an improvement means a larger value).
    if val_r2 > best_r2:
        best_r2 = val_r2
        counter = 0
        # state_dict() hands back references to the live parameters, so clone
        # them; otherwise later optimizer steps would overwrite the snapshot.
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }
    else:
        counter += 1
        # Only stop early when the switch above is enabled.
        if es is not None and counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

AttributeError: module 'eval.utils' has no attribute 'l1_normalize'