In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import medmnist
from medmnist import INFO, Evaluator, BloodMNIST

In [3]:
# Load the train / val / test splits of BloodMNIST with size=64, in that order.
# download=True lets medmnist fetch the data when it is not available locally.
train_dataset, val_dataset, test_dataset = (
    BloodMNIST(split=split_name, size=64, download=True)
    for split_name in ("train", "val", "test")
)

## Dinobloom

Let's test the embeddings produced by the DinoBloom model.  
The [small version](https://huggingface.co/1aurent/vit_small_patch14_224.dinobloom)
is sufficient for now.

In [30]:
import timm

# Pull the pretrained Dinobloom checkpoint from the Hugging Face hub, then
# switch the module to evaluation mode because it is only used for inference.
model = timm.create_model(
    "hf-hub:1aurent/vit_small_patch14_224.dinobloom", pretrained=True
)
model = model.eval()

config.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/86.5M [00:00<?, ?B/s]

In [31]:
# Build the model-specific preprocessing (normalization, resize) from the
# data configuration timm resolves for this checkpoint; is_training=False
# requests the evaluation-time pipeline.
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(is_training=False, **data_config)

In [20]:
import torch
from tqdm import tqdm
from torch.utils import data

class CustomDataset(data.Dataset):
    """Wrap an indexable collection of ``(img, label)`` pairs with an optional transform.

    Args:
        data: Object supporting ``len()`` and integer indexing, where each
            item unpacks into ``(img, label)``.
        transform: Optional callable applied to ``img`` on every access.
            ``None`` (the default) returns the image untouched.
    """

    def __init__(self, data, transform=None):
        super().__init__()
        # NOTE: inside this method the parameter `data` shadows the
        # `torch.utils.data` module imported under the same name.
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(img, label)`` for ``index``, with the transform applied to ``img``."""
        img, label = self.data[index]
        # Test against None instead of truthiness: a callable that defines
        # __len__ or __bool__ can be falsy and would be skipped silently.
        if self.transform is not None:
            img = self.transform(img)
        return (img, label)


class EmbeddingDataset(data.Dataset):
    """Dataset of pre-computed model embeddings and their labels.

    The whole ``dataset`` is passed through ``model`` once, at construction
    time, and the resulting vectors are stored on the CPU.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs that the
            default ``DataLoader`` collation can batch.
        model: Module mapping a batch of images to a batch of embeddings.
            It is moved to ``device`` as a side effect.
        transform: Unused; kept only so existing call sites keep working.
            Apply preprocessing inside ``dataset`` instead.
        device: Device on which the forward passes run.
        batch_size: Samples per forward pass. Defaults to 8, the value that
            used to be hard-coded.
    """

    def __init__(self, dataset, model, transform, device, batch_size=8):
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.embeddings, self.labels = self._create_vectors(model, dataset)

    def _create_vectors(self, model, dataset):
        """Run ``model`` over ``dataset`` and return ``(embeddings, labels)`` tensors.

        Both tensors are the per-batch results concatenated along dim 0.
        """
        embeddings = []
        label_list = []
        model.to(self.device)

        # Embeddings must be deterministic, so force evaluation mode while
        # extracting and put the module back in its previous mode afterwards.
        was_training = model.training
        model.eval()
        dataloader = data.DataLoader(dataset, batch_size=self.batch_size)

        try:
            # no_grad (rather than inference_mode) keeps the stored tensors
            # usable as inputs to a classifier trained with autograd later.
            with torch.no_grad():
                for images, labels in tqdm(dataloader):
                    embs = model(images.to(self.device)).to("cpu")
                    embeddings.append(embs)
                    label_list.append(labels)
        finally:
            model.train(was_training)

        embeddings = torch.cat(embeddings, dim=0)
        label_list = torch.cat(label_list, dim=0)
        return embeddings, label_list

    def __len__(self):
        """Return the number of stored embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` at ``idx`` (an int or a slice)."""
        return self.embeddings[idx, :], self.labels[idx]

In [32]:
# NOTE(review): the GPU index is hard-coded; adjust it for the machine at hand.
device = "cuda:6"

# Embed the training split with the Dinobloom model. Use `transforms` (the
# pipeline built for this model): `transform` (singular) is only defined in the
# later UNI section, so referencing it here fails on a fresh kernel.
training_dataset = CustomDataset(train_dataset, transforms)
emb_train = EmbeddingDataset(training_dataset, model, transforms, device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1495/1495 [06:01<00:00,  4.14it/s]


In [37]:
# Embed the validation split with the same model and preprocessing as the
# training split. `transforms` is the pipeline built for this model;
# `transform` (singular) is only defined in the later UNI section.
validation_dataset = CustomDataset(val_dataset, transforms)
emb_validation = EmbeddingDataset(validation_dataset, model, transforms, device)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 214/214 [00:27<00:00,  7.66it/s]


In [36]:
# Sanity check: fetch the first (embedding, label) pair and display the
# shape of the embedding vector.
tensor, label = emb_train[0]
tensor.shape

torch.Size([384])

In [41]:
import torch
from torch.utils.data import DataLoader
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

def embeddings_to_numpy(dataset, batch_size=512):
    """Collect a dataset of ``(features, labels)`` tensors into two NumPy arrays.

    Args:
        dataset: Dataset yielding CPU tensors ``(features, labels)``.
        batch_size: Number of samples read per step.

    Returns:
        ``(X, y)`` where ``X`` stacks the feature batches along axis 0 and
        ``y`` holds the labels flattened to one dimension.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    feature_chunks, label_chunks = [], []
    for features, labels in loader:
        feature_chunks.append(features.numpy())
        label_chunks.append(labels.numpy())

    X = np.concatenate(feature_chunks, axis=0)
    # The labels arrive as a column vector; scikit-learn expects a 1-D target
    # and otherwise emits a DataConversionWarning from column_or_1d.
    y = np.concatenate(label_chunks, axis=0).ravel()
    return X, y


X_train, y_train = embeddings_to_numpy(emb_train)

# Fit a linear classifier (logistic regression) on the frozen embeddings.
# `multi_class` is deprecated in recent scikit-learn releases; with the lbfgs
# solver the default already fits a multinomial model for multi-class targets.
clf = LogisticRegression(max_iter=1000, solver="lbfgs")
clf.fit(X_train, y_train)

# Evaluate the probe on the validation embeddings.
X_val, y_val = embeddings_to_numpy(emb_validation)

y_pred = clf.predict(X_val)
acc = accuracy_score(y_val, y_pred)
print(f"Linear probe accuracy: {acc:.4f}")


  y = column_or_1d(y, warn=True)


Linear probe accuracy: 0.9871


## UNI

In [5]:
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# Load the pretrained UNI checkpoint from the Hugging Face hub.
# NOTE(review): this rebinds `model`, which the Dinobloom section bound to a
# different network; cells run after this one see the UNI model.
model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True)
# Preprocessing for UNI. NOTE(review): `transform` (singular) is a different
# name from the `transforms` pipeline built for Dinobloom; do not mix them up.
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
model.eval()
# Ending the cell with a statement keeps the module repr returned by
# model.eval() from being displayed as the cell output.
print()




In [6]:
# Total number of parameters in the model, expressed in millions.
sum(param.numel() for param in model.parameters()) / 1e6

303.350784

In [20]:
import torch
from tqdm import tqdm
from torch.utils import data

class CustomDataset(data.Dataset):
    """Wrap an indexable collection of ``(img, label)`` pairs with an optional transform.

    NOTE(review): a class with this name is also defined earlier in this
    notebook; this later definition shadows it. Keep a single definition.

    Args:
        data: Object supporting ``len()`` and integer indexing, where each
            item unpacks into ``(img, label)``.
        transform: Optional callable applied to ``img`` on every access.
            ``None`` (the default) returns the image untouched.
    """

    def __init__(self, data, transform=None):
        super().__init__()
        # NOTE: inside this method the parameter `data` shadows the
        # `torch.utils.data` module imported under the same name.
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(img, label)`` for ``index``, with the transform applied to ``img``."""
        img, label = self.data[index]
        # Test against None instead of truthiness: a callable that defines
        # __len__ or __bool__ can be falsy and would be skipped silently.
        if self.transform is not None:
            img = self.transform(img)
        return (img, label)


class EmbeddingDataset(data.Dataset):
    """Dataset of pre-computed model embeddings and their labels.

    The whole ``dataset`` is passed through ``model`` once, at construction
    time, and the resulting vectors are stored on the CPU.

    NOTE(review): a class with this name is also defined earlier in this
    notebook; this later definition shadows it. Keep a single definition.

    Args:
        dataset: Map-style dataset yielding ``(image, label)`` pairs that the
            default ``DataLoader`` collation can batch.
        model: Module mapping a batch of images to a batch of embeddings.
            It is moved to ``device`` as a side effect.
        transform: Unused; kept only so existing call sites keep working.
            Apply preprocessing inside ``dataset`` instead.
        device: Device on which the forward passes run.
        batch_size: Samples per forward pass. Defaults to 8, the value that
            used to be hard-coded.
    """

    def __init__(self, dataset, model, transform, device, batch_size=8):
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.embeddings, self.labels = self._create_vectors(model, dataset)

    def _create_vectors(self, model, dataset):
        """Run ``model`` over ``dataset`` and return ``(embeddings, labels)`` tensors.

        Both tensors are the per-batch results concatenated along dim 0.
        """
        embeddings = []
        label_list = []
        model.to(self.device)

        # Embeddings must be deterministic, so force evaluation mode while
        # extracting and put the module back in its previous mode afterwards.
        was_training = model.training
        model.eval()
        dataloader = data.DataLoader(dataset, batch_size=self.batch_size)

        try:
            # no_grad (rather than inference_mode) keeps the stored tensors
            # usable as inputs to a classifier trained with autograd later.
            with torch.no_grad():
                for images, labels in tqdm(dataloader):
                    embs = model(images.to(self.device)).to("cpu")
                    embeddings.append(embs)
                    label_list.append(labels)
        finally:
            model.train(was_training)

        embeddings = torch.cat(embeddings, dim=0)
        label_list = torch.cat(label_list, dim=0)
        return embeddings, label_list

    def __len__(self):
        """Return the number of stored embeddings."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` at ``idx`` (an int or a slice)."""
        return self.embeddings[idx, :], self.labels[idx]

In [21]:
# NOTE(review): the GPU index is hard-coded; adjust it for the machine at hand.
device = "cuda:6"
# Embed the training split with the UNI model and its own preprocessing.
# NOTE(review): this rebinds `training_dataset` and `emb_train`, replacing the
# Dinobloom embeddings computed in the earlier section under the same names.
training_dataset = CustomDataset(train_dataset, transform)
emb_train = EmbeddingDataset(training_dataset, model, transform, device)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1495/1495 [02:23<00:00, 10.41it/s]


In [25]:
# EmbeddingDataset.__getitem__ indexes the stored tensors directly, so a slice
# returns the first 10 embedding vectors and their labels. Despite its name,
# `images` holds embeddings here, not raw images.
images, labels = emb_train[:10]

In [27]:
from torchvision.models import resnet18

# Size the classification head from the dataset metadata instead of a
# hard-coded count: BloodMNIST defines 8 classes, not 9.
cnn = resnet18(num_classes=len(INFO["bloodmnist"]["label"]))