In [1]:
import os
import pandas as pd
import numpy as np
import os.path as osp
import spatialdata as sd
from skimage.measure import regionprops
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

from config import DATA_PATH
from data import patch_john, patch_harry
from models import utils as model_utils
from eval import utils as eval_utils

# Hugging Face authentication: load variables from a local dotenv file, then
# log in to the Hub with the value stored under API_TOKEN.
from dotenv import load_dotenv
from huggingface_hub import login

hf_env_file = os.path.expanduser('~/hl/.gutinstinct.env')
load_dotenv(dotenv_path=hf_env_file)
api_token = os.environ.get("API_TOKEN")  # None when the variable is not set
login(token=api_token)



In [2]:
# Resolve the Zarr store under DATA_PATH (with `~` expanded) and load it
# through spatialdata's `read_zarr`.
data_root = os.path.expanduser(DATA_PATH)
zarr_path = osp.join(data_root, "UC6_I.zarr/UC6_I.zarr")
sdata = sd.read_zarr(zarr_path)

  compressor, fill_value = _kwargs_compat(compressor, fill_value, kwargs)


In [3]:
# Ask patch_harry for the patch datasets of `sdata`; the result is unpacked
# into the train / validation / test splits, in that order.
patch_splits = patch_harry.get_patches(sdata)
dataset_patch_train, dataset_patch_val, dataset_patch_test = patch_splits

In [4]:
# Fetch the expression values for each patch split (train, val, test order)
# with the same `model_utils.get_expression(sdata, <split>)` call.
(
    dataset_expression_train,
    dataset_expression_val,
    dataset_expression_test,
) = [
    model_utils.get_expression(sdata, split_patches)
    for split_patches in (dataset_patch_train, dataset_patch_val, dataset_patch_test)
]

In [5]:
# Patch loaders that feed the encoder during feature extraction.
#
# These must iterate in dataset order (shuffle=False): the embeddings extracted
# from them are later paired row-by-row with `dataset_expression_*` inside
# `TensorDataset`, and those expression arrays were computed directly from the
# datasets, not from the loaders. Shuffling here would pair each embedding
# with the expression of a different, random patch.
# NOTE(review): this assumes `model_utils.extract_features` keeps loader order
# and does not re-sort by cell id -- confirm against its implementation.
train_loader = DataLoader(dataset_patch_train, batch_size=32, shuffle=False)
val_loader = DataLoader(dataset_patch_val, batch_size=32, shuffle=False)
test_loader = DataLoader(dataset_patch_test, batch_size=32, shuffle=False)

In [6]:
## Instantiate the UNI2-h encoder
# Architecture overrides forwarded to timm when building the
# `MahmoodLab/UNI2-h` checkpoint below; `num_classes=0` drops the
# classification head so the model returns features.
timm_kwargs = {
   'img_size': 224,
   'patch_size': 32,
   'depth': 24,
   'num_heads': 24,
   'init_values': 1e-5,
   'embed_dim': 1536,
   'mlp_ratio': 2.66667*2,
   'num_classes': 0,
   'no_embed_class': True,
   'mlp_layer': timm.layers.SwiGLUPacked,
   'act_layer': torch.nn.SiLU,
   'reg_tokens': 8,
   'dynamic_img_size': True
  }
# Use the GPU when one is present, otherwise fall back to the CPU so this cell
# does not raise on machines without CUDA.
# NOTE(review): `model_utils.extract_features` may move batches to a fixed
# device itself -- confirm it follows the model's device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_uni = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)
model_uni = model_uni.to(device)
# Preprocessing pipeline derived from the checkpoint's pretrained config.
# NOTE(review): `transform` is not used by any cell visible here.
transform = create_transform(**resolve_data_config(model_uni.pretrained_cfg, model=model_uni))
# Switch to inference mode; as the last expression this also displays the
# module structure as the cell output.
model_uni.eval()

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1536, kernel_size=(32, 32), stride=(32, 32))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1536, out_features=4608, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1536, out_features=1536, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1536,), eps=1e-06, elementwise_affine=True)
      (mlp): GluMlp(
        (fc1): Linear(in_features=1536, out_features=8192, bias=True)
        (act): SiLU()
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()
    

In [7]:
# Run the encoder over each split's patch loader. Each call's result is
# unpacked into an (embeddings, cell_ids) pair.
# NOTE(review): the saved run of this cell stopped with
# "TypeError: can't convert np.ndarray of type numpy.uint32". Presumably an
# unsigned-integer array (the cell ids?) reaches a torch conversion somewhere
# in the dataset / extract_features path -- confirm, and cast it to a supported
# dtype (e.g. int64) where it is produced.
train_features = model_utils.extract_features(train_loader, model_uni)
train_embeddings, train_cell_ids = train_features

val_features = model_utils.extract_features(val_loader, model_uni)
val_embeddings, val_cell_ids = val_features

test_features = model_utils.extract_features(test_loader, model_uni)
test_embeddings, test_cell_ids = test_features

TypeError: can't convert np.ndarray of type numpy.uint32. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

In [None]:
def _float_pair_dataset(embeddings, expression):
    """Wrap one split as a TensorDataset of (embedding, expression) float32 tensors.

    Both inputs are converted with `torch.tensor(..., dtype=torch.float32)`;
    rows are paired by position.
    """
    embedding_tensor = torch.tensor(embeddings, dtype=torch.float32)
    expression_tensor = torch.tensor(expression, dtype=torch.float32)
    return TensorDataset(embedding_tensor, expression_tensor)


train_dataset = _float_pair_dataset(train_embeddings, dataset_expression_train)
val_dataset = _float_pair_dataset(val_embeddings, dataset_expression_val)
test_dataset = _float_pair_dataset(test_embeddings, dataset_expression_test)

# Loaders over the (embedding, expression) pairs, iterated in stored order.
# Note: this rebinds `train_loader` / `val_loader` / `test_loader`, which
# previously held the image-patch loaders; re-running the feature-extraction
# cell after this one would therefore receive these loaders instead.
loader_options = dict(batch_size=32, shuffle=False)
train_loader = DataLoader(train_dataset, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)