# SSL Quickstart (SimCLR/MoCo/MAE)

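A short run to verify that self-supervised pre-training (SimCLR, MoCo, or MAE) works end to end. Set `METHOD` in the first cell to choose the method, and keep the number of epochs small for a fast sanity check.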


In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader

from src.augment import TwoCropsTransform, get_mae_transform, get_ssl_transform
from src.data import NIHChestXrayDataset
from src.ssl import MAEWrapper, MoCoV2, SimCLR, train_mae, train_moco, train_simclr

# --- Run configuration -------------------------------------------------------
# CSV file listing the training split.
TRAIN_CSV = "splits/train.csv"

# Use the GPU when PyTorch reports one is available; otherwise stay on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Which self-supervised method to run.
METHOD = "simclr"  # simclr | moco | mae

# --- Load the training split -------------------------------------------------
train_df = pd.read_csv(TRAIN_CSV)
num_train_rows = train_df.shape[0]
print("Train samples:", num_train_rows)


In [None]:
# Build the data pipeline for the selected METHOD and run a single debug epoch.

# Supported methods, each mapped to (model factory, training function,
# checkpoint directory). Factories are lambdas so that only the selected
# model is actually instantiated.
_SSL_METHODS = {
    "simclr": (lambda: SimCLR(backbone="resnet50"), train_simclr, "checkpoints/simclr_debug"),
    "moco": (lambda: MoCoV2(backbone="resnet50"), train_moco, "checkpoints/moco_debug"),
    "mae": (lambda: MAEWrapper(mask_ratio=0.75), train_mae, "checkpoints/mae_debug"),
}

# Fail fast on a typo in METHOD. Without this check, any unrecognised value
# would silently fall through to the MAE branch and train the wrong method.
if METHOD not in _SSL_METHODS:
    raise ValueError(
        f"Unknown METHOD {METHOD!r}; expected one of {sorted(_SSL_METHODS)}"
    )

# SimCLR and MoCo share the SSL transform wrapped in TwoCropsTransform
# (presumably yielding two augmented views per image -- confirm in src.augment);
# MAE uses its own transform. 224 is the size argument passed to both.
if METHOD in {"simclr", "moco"}:
    transform = TwoCropsTransform(get_ssl_transform(224))
else:
    transform = get_mae_transform(224)

train_ds = NIHChestXrayDataset(train_df, transform=transform)
# drop_last=True discards the final incomplete batch, so every batch has
# exactly 32 samples.
loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2, drop_last=True)

# All three methods use the same optimizer settings and call signature, so the
# only per-method differences are the model, the train function and the
# checkpoint directory.
build_model, train_fn, output_dir = _SSL_METHODS[METHOD]
model = build_model()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
train_fn(model, loader, optimizer, torch.device(DEVICE), epochs=1, output_dir=output_dir)
