In [1]:
import os
import pandas as pd
import torch

from tqdm import tqdm
from torch.utils.data import DataLoader, Subset

from torchvision.models import vit_b_16
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, ToPILImage

from utils import fix_random_seeds, clip_gradients, compute_knn_accuracy
from dataset import ISICDataset
from dino import DataAugmentationDINO, MultiCropWrapper, DINOHead, Loss

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "mps")
print(device)

mps


In [3]:
fix_random_seeds(42)

In [4]:
metadata = pd.read_csv('data/metadata.csv')
labels = metadata['malignant'].values.astype(int)
files = [f"data/ISIC_2024_Training_Input/{f}" for f in os.listdir('data/ISIC_2024_Training_Input') if f.endswith('.jpg')]

In [5]:
transform = DataAugmentationDINO(global_crops_scale=(0.4, 1.0), local_crops_scale=(0.05, 0.4), local_crops_number=8)

In [6]:
total_train_size = 10000

norm_only = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

dataset_train = ISICDataset(files, labels, transform=transform)
dataset_knn = ISICDataset(files, labels, transform=norm_only)

targets = torch.tensor(labels)
# get all indices where targets is 1
positive_indices = torch.where(targets == 1)[0]

# get 10.000 random indices where targets is 0
negative_indices = torch.where(targets == 0)[0]
negative_indices_train = negative_indices[torch.randperm(negative_indices.size(0))[:total_train_size-len(positive_indices)]]

# combine positive and negative indices
indices = torch.cat([positive_indices, negative_indices_train])

train_dataset = Subset(dataset_train, indices)

# get 50% of the positive indices
positive_indices_knn_val = positive_indices[torch.randperm(positive_indices.size(0))[:len(positive_indices)//2]]

# fill up to 1000 indices with negative indices
negative_indices_knn_val = negative_indices[torch.randperm(negative_indices.size(0))[:1000-len(positive_indices_knn_val)]]

knn_val_dataset = Subset(dataset_knn, torch.cat([positive_indices_knn_val, negative_indices_knn_val]))

# get the rest of the positive indices
positive_indices_knn_train = positive_indices[torch.randperm(positive_indices.size(0))[len(positive_indices)//2:]]

# fill up to 5000 indices with negative indices
negative_indices_knn_train = negative_indices[torch.randperm(negative_indices.size(0))[1000-len(positive_indices_knn_val):(5000-len(positive_indices_knn_train))+(1000-len(positive_indices_knn_val))]]

knn_train_dataset = Subset(dataset_knn, torch.cat([positive_indices_knn_train, negative_indices_knn_train]))

In [11]:
dataset_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

In [12]:
student = vit_b_16(weights=None)
teacher = vit_b_16(weights=None)

# make teacher and student have the same weights
teacher.load_state_dict(student.state_dict())

student = MultiCropWrapper(student, DINOHead(768, 1024))
teacher = MultiCropWrapper(teacher, DINOHead(768, 1024))

student = student.to(device)
teacher = teacher.to(device)

for p in teacher.parameters():
    p.requires_grad = False

  WeightNorm.apply(module, name, dim)


In [13]:
#dino_loss = DINOLoss(1024,8+2,0.04,0.04,0,100)
dino_loss = Loss(1024)
dino_loss = dino_loss.to(device)

In [14]:
lr = 0.0005 * 16 / 256
optimizer = torch.optim.AdamW(student.parameters(), lr=lr, weight_decay=1e-6)
momentum_teacher = 0.995

In [15]:
log_number = 2

In [16]:
epochs = 1

for e in range(epochs):
    num_batches = 0
    for images, _ in tqdm(dataset_loader):
        images = [img.to(device) for img in images]
        student_output = student(images)
        teacher_output = teacher(images[:2])

        loss = dino_loss(student_output, teacher_output)

        optimizer.zero_grad()
        loss.backward()
        clip_gradients(student)
        optimizer.step()

        with torch.no_grad():
            for student_ps, teacher_ps in zip(
                student.parameters(), teacher.parameters()
            ):
                teacher_ps.data.mul_(momentum_teacher)
                teacher_ps.data.add_(
                    (1 - momentum_teacher) * student_ps.detach().data
                )

        num_batches += 1

        if (num_batches % log_number) == 0:
            print(f"Calculating KNN accuracy for report")
            acc, preds, train_lbls, val_lbls = compute_knn_accuracy(student.backbone, knn_train_dataset, knn_val_dataset, device, 64)
            print(f"KNN Accuracy {acc}")

  0%|          | 1/5000 [00:02<2:47:21,  2.01s/it]

Calculating KNN accuracy for report


 50%|████▉     | 155/313 [01:01<01:02,  2.53it/s]
  0%|          | 1/5000 [01:04<89:57:04, 64.78s/it]


KeyboardInterrupt: 