In [None]:
!pip -q install torch torchvision medmnist scikit-learn matplotlib tqdm


In [17]:
import os, random, json, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms.v2 as vT
import torchvision.models as tvm
import matplotlib.pyplot as plt
import medmnist
from medmnist import INFO

def set_seed(seed: int = 42):
    """Make runs reproducible by seeding every RNG this notebook touches.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the
    RNGs of all CUDA devices, then forces cuDNN to use deterministic kernels
    without autotuning.

    Args:
        seed: the value passed to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Deterministic kernels and no benchmark-driven algorithm selection:
    # trades some speed for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

class SSLDataset(Dataset):
    """Turn a (sample, target) dataset into one that yields two augmented views.

    The second element of each base item (e.g. the label) is discarded, and
    ``ssl_transform`` is applied twice, independently, to the same sample.
    """

    def __init__(self, base_dataset, ssl_transform):
        self.base = base_dataset
        self.transform = ssl_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        sample, _ignored = self.base[idx]
        # Two separate calls -> two independently randomized views.
        return tuple(self.transform(sample) for _ in range(2))

def get_medmnist_train(key: str, download=True):
    """Load the training split of a MedMNIST dataset as 3-channel float tensors.

    Each image is resized with ``Resize(224)``, converted to a float32 tensor
    scaled to [0, 1], and single-channel images are repeated to 3 channels so
    they fit an RGB backbone.

    Args:
        key: MedMNIST dataset key looked up in ``medmnist.INFO`` (e.g. 'breastmnist').
        download: forwarded to the MedMNIST dataset constructor.

    Returns:
        The MedMNIST dataset instance constructed with ``split="train"``.
    """
    dataset_cls = getattr(medmnist, INFO[key]['python_class'])

    def _to_three_channels(img):
        # Grayscale (1 channel) -> 3 channels by repetition; otherwise unchanged.
        return img.repeat(3, 1, 1) if img.shape[0] == 1 else img

    preprocess = vT.Compose([
        vT.Resize(224),
        vT.ToImage(),
        vT.ToDtype(torch.float32, scale=True),
        vT.Lambda(_to_three_channels),
    ])
    return dataset_cls(split="train", transform=preprocess, download=download)

# Stochastic augmentation pipeline for contrastive pretraining: every call on
# the same input draws new random parameters, producing a different "view".
# Its input is the output of get_medmnist_train's base transform: a 3-channel
# float tensor scaled to [0, 1].
# NOTE: a GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)) step was previously
# listed here but is disabled.
ssl_transform = vT.Compose(
    [
        # geometric jitter
        vT.RandomResizedCrop(224, scale=(0.7, 1.0)),
        vT.RandomHorizontalFlip(p=0.5),
        vT.RandomRotation(degrees=10),
        # mild photometric jitter
        vT.GaussianNoise(sigma=0.01),
        vT.ColorJitter(brightness=0.05, contrast=0.05),
        # ImageNet channel statistics, matching the ImageNet-pretrained backbone
        vT.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)


class ResNet18Encoder(nn.Module):
    """ImageNet-pretrained ResNet-18 trunk used as the SSL backbone.

    All child modules except the last one (the classification layer) are kept
    in order, and ``forward`` flattens their output to ``(batch, out_dim)``.
    """

    def __init__(self):
        super().__init__()
        backbone = tvm.resnet18(weights=tvm.ResNet18_Weights.IMAGENET1K_V1)
        # The final child is the classifier; keep everything before it.
        *trunk, _classifier = backbone.children()
        self.features = nn.Sequential(*trunk)
        self.out_dim = 512

    def forward(self, x):
        pooled = self.features(x)
        # Collapse all non-batch dimensions into one feature vector per image.
        return torch.flatten(pooled, start_dim=1)

class ProjectionHead(nn.Module):
    """MLP projection head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Maps ``in_dim`` encoder features into the ``proj_dim``-dimensional space in
    which the contrastive loss is computed. The hidden width equals ``in_dim``.
    """

    def __init__(self, in_dim=512, proj_dim=128):
        super().__init__()
        hidden = in_dim
        layers = [
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, proj_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def nt_xent_loss(z1, z2, tau=0.2):
    """NT-Xent (normalized temperature-scaled cross-entropy) contrastive loss.

    ``z1[i]`` and ``z2[i]`` are the projections of two views of the same sample
    (the positive pair). Every other row of the concatenated 2B-row batch acts as
    a negative for that anchor.

    Args:
        z1: projections of the first views, shape (B, D).
        z2: projections of the second views, shape (B, D).
        tau: softmax temperature; similarities are divided by it.

    Returns:
        Scalar tensor: mean cross-entropy over all 2B anchors.
    """
    B = z1.size(0)

    # L2-normalize so the dot products below are cosine similarities.
    z = torch.cat([F.normalize(z1, dim=1), F.normalize(z2, dim=1)], dim=0)  # (2B, D)

    sim = z @ z.T / tau  # (2B, 2B) temperature-scaled cosine similarities

    # Mask the diagonal (each anchor's similarity with ITSELF) so it is never a
    # candidate. Use the dtype's most negative finite value instead of a fixed
    # -1e9, which overflows and raises for float16 logits.
    self_mask = torch.eye(2 * B, dtype=torch.bool, device=z.device)
    sim = sim.masked_fill(self_mask, torch.finfo(sim.dtype).min)

    # The positive for row i is its other view: i + B in the first half,
    # i - B in the second half.
    targets = (torch.arange(2 * B, device=z.device) + B) % (2 * B)
    return F.cross_entropy(sim, targets)

In [None]:
# ---- Config (choose ONE binary dataset) ----
SEED = 42
set_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', device)

DATASET_KEY  = 'breastmnist'  # any binary MedMNIST key, e.g. 'pneumoniamnist'
EPOCHS       = 50
BATCH_SIZE   = 128
LR           = 1e-3
WEIGHT_DECAY = 1e-4
TAU          = 0.2            # NT-Xent softmax temperature

save_dir = "results/week4"
os.makedirs(save_dir, exist_ok=True)

In [None]:
base_ds = get_medmnist_train(DATASET_KEY)
ssl_ds = SSLDataset(base_ds, ssl_transform)
# drop_last=True: a trailing batch of size 1 would make BatchNorm1d in the
# projection head raise in train mode, and a tiny final batch gives NT-Xent very
# few negatives. With shuffle=True the dropped samples differ every epoch.
ssl_loader = DataLoader(
    ssl_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=(device == 'cuda'),  # pinned memory only helps host->GPU copies
    drop_last=True,
)
if len(ssl_loader) == 0:
    raise ValueError(
        f"Training set has {len(ssl_ds)} samples, fewer than BATCH_SIZE={BATCH_SIZE}; "
        "lower BATCH_SIZE so at least one full batch is available."
    )

encoder = ResNet18Encoder().to(device)
proj_head = ProjectionHead().to(device)

optimizer = torch.optim.AdamW(
    list(encoder.parameters()) + list(proj_head.parameters()),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

loss_history = []  # mean NT-Xent loss per epoch

for epoch in range(1, EPOCHS + 1):
    encoder.train()
    proj_head.train()
    running_loss = 0.0
    n_seen = 0  # samples actually used this epoch (drop_last may skip a few)

    for v1, v2 in ssl_loader:
        v1 = v1.to(device)
        v2 = v2.to(device)

        # nt_xent_loss L2-normalizes its inputs itself, so raw projections suffice.
        z1 = proj_head(encoder(v1))
        z2 = proj_head(encoder(v2))
        loss = nt_xent_loss(z1, z2, tau=TAU)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_n = v1.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n

    epoch_loss = running_loss / n_seen
    loss_history.append(epoch_loss)
    print(f"[SSL] Epoch {epoch:03d}/{EPOCHS} | loss={epoch_loss:.4f}")

torch.save(encoder.state_dict(), f"{save_dir}/ssl_encoder.pt")
torch.save(proj_head.state_dict(), f"{save_dir}/ssl_proj_head.pt")

fig, ax = plt.subplots()
# Epochs are numbered from 1 in the log above; keep the x-axis consistent.
ax.plot(range(1, len(loss_history) + 1), loss_history)
ax.set(xlabel="Epoch", ylabel="NT-Xent loss", title="SSL Loss Curve")
fig.tight_layout()
fig.savefig(f"{save_dir}/ssl_loss_curve.png")
plt.close(fig)

ssl_config = {
    "method": "simclr-lite",
    "dataset": DATASET_KEY,
    "backbone": "resnet18",
    "proj_dim": 128,
    "batch_size": BATCH_SIZE,
    "drop_last": True,
    "epochs": EPOCHS,
    "lr": LR,
    "weight_decay": WEIGHT_DECAY,
    "augmentations": "RandomResizedCrop+Flip+Rot+Jitter, GaussianNoise",
    "tau": TAU,
    "seed": 42,
}

with open(f"{save_dir}/ssl_config.json", "w") as f:
    json.dump(ssl_config, f, indent=4)




In [20]:
# Write a float16 copy of the trained encoder weights (roughly half the size).
# Cast a copy of the state dict instead of calling `encoder.half()`: that call
# converts the live model in place, so `encoder` would silently stay in fp16
# for any later cell that reuses it.
# As with Module.half(), only floating-point tensors are cast; integer buffers
# (e.g. BatchNorm's num_batches_tracked) keep their dtype.
encoder_fp16_state = {
    name: tensor.half() if tensor.is_floating_point() else tensor
    for name, tensor in encoder.state_dict().items()
}
torch.save(encoder_fp16_state, f"{save_dir}/ssl_encoder_fp16.pt")