In [None]:
import os
import pandas as pd
import torch

from tqdm import tqdm
from torch.utils.data import DataLoader

from torchvision.models import vit_b_16

from utils import fix_random_seeds, clip_gradients, compute_knn_accuracy
from dataset import ISICDataset, get_random_subset_without_given_indices
from dino import DataAugmentationDINO, MultiCropWrapper, DINOHead, DINOLoss

In [None]:
# Select the compute device: CUDA first, then Apple's MPS backend, else CPU.
# The CPU fallback matters because "mps" is only usable when its backend is
# actually available; blindly choosing it on a CPU-only machine makes the
# later `.to(device)` calls fail.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Fix the random seed to 42 through the project helper from `utils`.
# NOTE(review): presumably this seeds the Python / NumPy / PyTorch RNGs so the
# sampling and weight initialisation below are repeatable — confirm in utils.py.
fix_random_seeds(42)

In [None]:
# Read the metadata table and use its `malignant` column, cast to int, as labels.
metadata = pd.read_csv('data/metadata.csv')
labels = metadata['malignant'].values.astype(int)

# os.listdir() returns entries in arbitrary, platform-dependent order, so the
# listing is sorted to make the file list (and thus its pairing with `labels`)
# reproducible from run to run.
# NOTE(review): `labels` follows the row order of metadata.csv while `files`
# follows sorted file-name order; nothing here verifies that the two orders
# (or lengths) agree — confirm against the CSV before trusting the labels.
files = [f"data/ISIC24/{f}" for f in sorted(os.listdir('data/ISIC24'))]

In [None]:
# DINO multi-crop augmentation: global crops drawn at scale 0.4-1.0 and
# 8 local crops drawn at scale 0.05-0.4.
transform = DataAugmentationDINO(
    global_crops_scale=(0.4, 1.0),
    local_crops_scale=(0.05, 0.4),
    local_crops_number=8,
)

In [None]:
# Dataset over the image files and labels, with the DINO augmentation applied
# as its transform; the loader serves it in shuffled batches of 2.
dataset = ISICDataset(files, labels, transform=transform)
dataset_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=2,
)

In [None]:
# Build two disjoint, class-reweighted subsets of `dataset` for k-NN evaluation.
from torch.utils.data import WeightedRandomSampler, Subset

# Count how many samples carry label 0 and label 1.
targets = torch.tensor(dataset.labels)
class_counts = torch.bincount(targets)
num_zeros = class_counts[0].item()
num_ones = class_counts[1].item()

# Per-sample weights: label-1 samples are up-weighted by num_zeros / num_ones so
# that both classes carry the same total sampling weight.
weights = torch.ones(len(targets))
weights[targets == 1] = num_zeros / num_ones

# Draw 500 distinct indices for the validation side.
sampler = WeightedRandomSampler(weights, 500, replacement=False)
indices = list(sampler)

# Draw 1000 distinct indices, independently, for the train side.
sampler = WeightedRandomSampler(weights, 1000, replacement=False)
indices_train = list(sampler)

# Drop every train index that was also drawn for validation so the two subsets
# never overlap. A set makes each membership test O(1) instead of scanning the
# validation list once per train index. The train subset may therefore end up
# with fewer than 1000 samples.
val_index_set = set(indices)
indices_train = [idx for idx in indices_train if idx not in val_index_set]

# Validation subset for k-NN.
val_knn_subset = Subset(dataset, indices)

# Train subset for k-NN (validation indices already removed).
train_knn_subset = Subset(dataset, indices_train)

In [None]:
# Student and teacher networks: a randomly initialised ViT-B/16, each wrapped
# together with a DINOHead(768, 1024).
student = vit_b_16(weights=None)
teacher = vit_b_16(weights=None)

student = MultiCropWrapper(student, DINOHead(768, 1024))
teacher = MultiCropWrapper(teacher, DINOHead(768, 1024))

# Make teacher and student start from the same weights. The copy is done AFTER
# wrapping: copying only the bare ViT state (before the heads exist) would leave
# the two DINOHeads with different random initialisations.
# NOTE(review): assumes MultiCropWrapper is an nn.Module that registers both the
# backbone and the head as sub-modules, so its state_dict covers both — confirm
# in dino.py.
teacher.load_state_dict(student.state_dict())

student = student.to(device)
teacher = teacher.to(device)

# The teacher is never trained by backpropagation, so its parameters are frozen.
for p in teacher.parameters():
    p.requires_grad = False

In [None]:
# DINO loss, created and moved to the training device in one step.
# The first argument (1024) matches the DINOHead output size used for the
# networks, 8 + 2 matches the 8 local crops requested from the augmentation plus
# two more, and the last argument (100) matches the number of training epochs.
# NOTE(review): the positional arguments presumably follow the reference DINO
# order (out_dim, ncrops, warmup_teacher_temp, teacher_temp,
# warmup_teacher_temp_epochs, nepochs) — confirm against dino.py.
dino_loss = DINOLoss(1024, 8 + 2, 0.04, 0.04, 0, 100).to(device)

In [None]:
# Momentum used when blending student weights into the teacher.
momentum_teacher = 0.995

# Learning rate: base value 0.0005 scaled by 16 / 256.
# NOTE(review): if this is meant as linear scaling by batch size, note that the
# DataLoader uses batch_size=2, not 16 — confirm the intended factor.
lr = 0.0005 * 16 / 256

# Only the student is optimised.
optimizer = torch.optim.AdamW(
    params=student.parameters(),
    lr=lr,
    weight_decay=1e-6,
)

In [None]:
# Reporting interval: the training loop computes and prints the k-NN accuracy
# whenever the per-epoch batch counter is a multiple of this value.
log_number = 2

In [None]:
# DINO training: optimise the student against the teacher's output, then move
# the teacher towards the student with an exponential moving average.
epochs = 100

for e in range(epochs):
    # The batch counter restarts every epoch; it only drives the periodic
    # k-NN report below.
    num_batches = 0
    for num_batches, (images, _) in enumerate(tqdm(dataset_loader), start=1):
        # `images` is a list of crops; move each one to the training device.
        images = [crop.to(device) for crop in images]

        # The student sees every crop, the teacher only the first two.
        # NOTE(review): presumably the first two entries are the global crops
        # produced by DataAugmentationDINO — confirm in dino.py.
        student_output = student(images)
        teacher_output = teacher(images[:2])

        loss = dino_loss(student_output, teacher_output, e)

        # Standard optimisation step on the student, with gradient clipping
        # done by the project helper.
        optimizer.zero_grad()
        loss.backward()
        clip_gradients(student)
        optimizer.step()

        # EMA update: teacher = m * teacher + (1 - m) * student.
        with torch.no_grad():
            for student_ps, teacher_ps in zip(student.parameters(), teacher.parameters()):
                teacher_ps.data.mul_(momentum_teacher).add_(
                    (1 - momentum_teacher) * student_ps.detach().data
                )

        # Every `log_number` batches, report k-NN accuracy of the student backbone.
        if not num_batches % log_number:
            print("Calculating KNN accuracy for report")
            # NOTE(review): the trailing 64 is presumably a batch size for
            # feature extraction — confirm in utils.compute_knn_accuracy.
            acc, preds, train_lbls, val_lbls = compute_knn_accuracy(
                student.backbone,
                train_knn_subset,
                val_knn_subset,
                device,
                64,
            )
            print(f"KNN Accuracy {acc}")