# Vanilla DINO Pretraining

generic imports

In [None]:
import os
import sys
import time
import json
import math
import hydra
import wandb
import shutil
import random
import datetime
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import numpy as np
import pandas as pd
import multiprocessing as mp

from pathlib import Path
from omegaconf import OmegaConf
from torchvision import datasets
from tqdm import tqdm
from collections import defaultdict

module imports

In [None]:
# make the local `dino` package (imported in the next cell) importable;
# replace the placeholder with the path to your own checkout
sys.path.append("/path/to/your/dino/folder")

In [None]:
import dino.models.vision_transformer as vits

from dino.components import DINOLoss, EarlyStoppingDINO
from dino.data import PatchDataAugmentationDINO
from dino.eval import prepare_data
from dino.models import MultiCropWrapper
from dino.distributed import get_world_size, is_main_process
from dino.utils import cosine_scheduler, fix_random_seeds, has_batchnorms, get_params_groups, compute_time, resume_from_checkpoint
from dino.log import initialize_wandb, update_log_dict, MetricLogger
from dino.eval.knn import knn_classifier
from dino.utils.utils import load_weights, clip_gradients, cancel_gradients_last_layer

load config

In [None]:
# path to the YAML configuration of this pretraining run (placeholder, edit before running)
config_file = "/path/to/your/patch/config.yaml"

In [None]:
# load the YAML file into an OmegaConf object; the rest of the notebook reads it with dotted keys
cfg = OmegaConf.load(config_file)

initialize distributed session (if necessary)

In [None]:
# a distributed (NCCL) session is started whenever more than one GPU is visible;
# gpu_id == -1 is the sentinel used throughout the notebook for "single-process run"
distributed = torch.cuda.device_count() > 1
gpu_id = -1
if distributed:
    torch.distributed.init_process_group(backend="nccl")
    # LOCAL_RANK must be set in the environment by the process launcher
    gpu_id = int(os.environ["LOCAL_RANK"])
    if gpu_id == 0:
        print("Distributed session successfully initialized")

# the run identifier is chosen on the main process only: a timestamp by default,
# replaced by the wandb run id when wandb logging is enabled
run_id = ""
if is_main_process():
    print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
    run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
    # set up wandb
    if cfg.wandb.enable:
        key = os.environ.get("WANDB_API_KEY")
        wandb_run = initialize_wandb(cfg, key=key)
        wandb_run.define_metric("epoch", summary="max")
        run_id = wandb_run.id

# share the run id chosen by rank 0 with every other rank so that all
# processes build the same output directory path
if distributed:
    obj = [run_id]
    torch.distributed.broadcast_object_list(
        obj, 0, device=torch.device(f"cuda:{gpu_id}")
    )
    run_id = obj[0]

fix_random_seeds(cfg.seed)
cudnn.benchmark = True

create output directory

In [None]:
# notebook display: root output directory read from the config
cfg.output_dir

In [None]:
# layout: <cfg.output_dir>/<experiment_name>/<run_id>/{snapshots,features}
output_dir = Path(cfg.output_dir, cfg.experiment_name, run_id)
snapshot_dir = Path(output_dir, "snapshots")
features_dir = Path(output_dir, "features")

# directories are only (re)created by the main process, and only for fresh (non-resumed) runs
if not cfg.resume and is_main_process():
    if output_dir.exists():
        # start from a clean slate: any previous content of this run directory is wiped
        print(f"WARNING: {output_dir} already exists! Deleting its content...")
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    snapshot_dir.mkdir(exist_ok=True, parents=True)
    # the features folder is only needed when tuning runs and is asked to save kNN features
    save_knn_features = (
        cfg.early_stopping.tune_every and cfg.early_stopping.knn.save_features
    )
    if save_knn_features:
        features_dir.mkdir(exist_ok=True, parents=True)

prepare downstream tuning data

In [None]:
# tuning is not run distributed for now, so the downstream loaders are only
# built on the master rank (and only when tuning is enabled)
if is_main_process() and cfg.early_stopping.tune_every:
    downstream_cfg = cfg.early_stopping.downstream

    train_df = pd.read_csv(downstream_cfg.train_csv)
    test_df = pd.read_csv(downstream_cfg.test_csv)

    # never ask for more workers than the machine (or the SLURM allocation) provides
    num_workers = min(mp.cpu_count(), downstream_cfg.num_workers)
    slurm_cpus = os.environ.get("SLURM_JOB_CPUS_PER_NODE")
    if slurm_cpus is not None:
        num_workers = min(num_workers, int(slurm_cpus))

    downstream_train_loader, downstream_test_loader = prepare_data(
        train_df,
        test_df,
        downstream_cfg.batch_size_per_gpu,
        False,
        num_workers,
        downstream_cfg.label_name,
    )
    print(
        f"Tuning data loaded with {len(downstream_train_loader.dataset)} train patches and {len(downstream_test_loader.dataset)} test patches."
    )

prepare pretraining data

In [None]:
# augmentation applied to every pretraining patch, configured from cfg.aug
# (global crops scale, local crops scale, number of local crops)
transform = PatchDataAugmentationDINO(
    cfg.aug.global_crops_scale,
    cfg.aug.local_crops_scale,
    cfg.aug.local_crops_number,
)

In [None]:
# index the pretraining patches under cfg.data_dir and time how long it takes
dataset_loading_start_time = time.time()
dataset = datasets.ImageFolder(cfg.data_dir, transform=transform)
# NOTE: despite its name, this variable holds the elapsed time in seconds, not a timestamp
dataset_loading_end_time = time.time() - dataset_loading_start_time
total_time_str = str(datetime.timedelta(seconds=int(dataset_loading_end_time)))
if is_main_process():
    print(f"Pretraining data loaded in {total_time_str} ({len(dataset)} patches)")

In [None]:
# optionally pretrain on a random fraction (cfg.training.pct, in [0, 1]) of the dataset;
# a falsy value keeps the full dataset
if cfg.training.pct:
    print(f"Pre-training on {cfg.training.pct*100}% of the data")
    nsample = int(cfg.training.pct * len(dataset))
    # sample indices without replacement using Python's `random` module
    idxs = random.sample(range(len(dataset)), k=nsample)
    dataset = torch.utils.data.Subset(dataset, idxs)

In [None]:
# distributed runs use a DistributedSampler with shuffling (its epoch is set
# at the start of every epoch in the training loop); single-process runs simply
# shuffle with a RandomSampler
if distributed:
    sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
else:
    sampler = torch.utils.data.RandomSampler(dataset)

In [None]:
# number of data loading workers: the configured value, capped by the number of
# CPUs of the machine and, when running under SLURM, by the job's CPU allocation
num_workers = min(mp.cpu_count(), cfg.speed.num_workers)
if "SLURM_JOB_CPUS_PER_NODE" in os.environ:
    num_workers = min(num_workers, int(os.environ["SLURM_JOB_CPUS_PER_NODE"]))
# notebook display of the resulting value
num_workers

In [None]:
# pretraining loader; the batch size is per process (per GPU), and the last
# incomplete batch is dropped so that every iteration sees a full batch
data_loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=cfg.training.batch_size_per_gpu,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)

build student and teacher networks

In [None]:
# look up the backbone constructor by name (cfg.model.arch) in the vision_transformer module;
# only the student is built with the configured drop path rate
student = vits.__dict__[cfg.model.arch](
    patch_size=cfg.model.patch_size,
    drop_path_rate=cfg.model.drop_path_rate,
)
teacher = vits.__dict__[cfg.model.arch](patch_size=cfg.model.patch_size)
# the student's embedding dimension is used as input dimension of both projection heads below
embed_dim = student.embed_dim
# notebook display
embed_dim

In [None]:
# notebook display: sanity check that the teacher has the same embedding dimension as the student
teacher.embed_dim

In [None]:
# notebook display: output dimension of the DINO projection heads
cfg.model.out_dim

In [None]:
# multi-crop wrapper handles forward with inputs of different resolutions
# each backbone is paired with its own DINO projection head (embed_dim -> cfg.model.out_dim)
student = MultiCropWrapper(
    student,
    vits.DINOHead(
        embed_dim,
        cfg.model.out_dim,
        use_bn=cfg.model.use_bn_in_head,
        norm_last_layer=cfg.model.norm_last_layer,
    ),
)
# note: the teacher head does not receive cfg.model.norm_last_layer and therefore
# uses the DINOHead default for that argument
teacher = MultiCropWrapper(
    teacher,
    vits.DINOHead(
        embed_dim,
        cfg.model.out_dim,
        use_bn=cfg.model.use_bn_in_head,
    ),
)

In [None]:
# put both networks on the GPU of this process (or on the current CUDA device
# for single-process runs)
if distributed:
    student = student.to(gpu_id)
    teacher = teacher.to(gpu_id)
else:
    student = student.cuda()
    teacher = teacher.cuda()

# by default the teacher is not wrapped, so both names refer to the same module
teacher_without_ddp = teacher

# when the networks contain batch norm layers, they are converted to synchronized
# batch norms; the teacher then needs a DDP wrapper for the synchronization to work
if has_batchnorms(student) and distributed:
    student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
    teacher = nn.SyncBatchNorm.convert_sync_batchnorm(teacher)
    teacher = nn.parallel.DistributedDataParallel(
        teacher, device_ids=[gpu_id], output_device=gpu_id
    )
    # keep a handle on the bare teacher (used for weight loading and the EMA update)
    teacher_without_ddp = teacher.module

# the student is always wrapped in DDP in distributed runs (its gradients must be averaged)
if distributed:
    student = nn.parallel.DistributedDataParallel(
        student, device_ids=[gpu_id], output_device=gpu_id
    )

In [None]:
# teacher and student start with the same weights
student_sd = student.state_dict()
# when the student is wrapped in DDP its keys carry a "module." prefix; strip it (in place)
# so that the keys match those of the unwrapped teacher
nn.modules.utils.consume_prefix_in_state_dict_if_present(student_sd, "module.")
teacher_without_ddp.load_state_dict(student_sd)

In [None]:
# there is no backpropagation through the teacher, so no need for gradients
# (the teacher weights are updated in train_one_epoch as a moving average of the student weights)
for p in teacher.parameters():
    p.requires_grad = False

create loss

In [None]:
# total number of crops = 2 global crops + local_crops_number
# (the 2 global crops are the first two views of each batch, the only ones fed to the teacher)
crops_number = cfg.aug.local_crops_number + 2

In [None]:
# DINO loss, configured with the head output dimension, the total number of crops,
# the teacher temperature settings (warmup value, final value, number of warmup epochs)
# and the total number of epochs
# NOTE(review): argument meaning is inferred from the config key names; confirm against DINOLoss
dino_loss = DINOLoss(
    cfg.model.out_dim,
    crops_number,
    cfg.model.warmup_teacher_temp,
    cfg.model.teacher_temp,
    cfg.model.warmup_teacher_temp_epochs,
    cfg.training.nepochs,
)

In [None]:
# the loss is a module with its own state (it is saved in the snapshots through state_dict()),
# so it is moved to the same device as the networks
if distributed:
    dino_loss = dino_loss.to(gpu_id)
else:
    dino_loss = dino_loss.cuda()

create optimizer

In [None]:
# split the student parameters into groups; train_one_epoch applies the weight decay
# schedule to the first group only
params_groups = get_params_groups(student)
# no learning rate / weight decay is passed here: both are overwritten at every
# iteration from their schedules in train_one_epoch
optimizer = torch.optim.AdamW(params_groups)

In [None]:
# gradient scaler for mixed precision training; None disables mixed precision
# (train_one_epoch enables autocast only when a scaler is provided)
fp16_scaler = torch.cuda.amp.GradScaler() if cfg.speed.use_fp16 else None

create schedulers

In [None]:
# the learning rate warmup cannot last longer than the whole training
assert (
    cfg.training.nepochs >= cfg.training.warmup_epochs
), f"nepochs ({cfg.training.nepochs}) must be greater than or equal to warmup_epochs ({cfg.training.warmup_epochs})"
# scale the configured learning rate linearly with the total batch size
# (batch size per GPU x number of processes), using 256 as the reference batch size
base_lr = (
    cfg.optim.lr * (cfg.training.batch_size_per_gpu * get_world_size()) / 256.0
)

In [None]:
# learning rate schedule going from base_lr to min_lr, with warmup epochs;
# train_one_epoch indexes it by global training iteration (len(data_loader) * epoch + it)
lr_schedule = cosine_scheduler(
    base_lr,
    cfg.optim.lr_scheduler.min_lr,
    cfg.training.nepochs,
    len(data_loader),
    warmup_epochs=cfg.training.warmup_epochs,
)

In [None]:
# weight decay schedule going from weight_decay to weight_decay_end (no warmup requested),
# also indexed by global training iteration in train_one_epoch
wd_schedule = cosine_scheduler(
    cfg.optim.lr_scheduler.weight_decay,
    cfg.optim.lr_scheduler.weight_decay_end,
    cfg.training.nepochs,
    len(data_loader),
)

In [None]:
# momentum parameter is increased to 1. during training with a cosine schedule
# (it is the coefficient m of the teacher update: teacher = m * teacher + (1 - m) * student)
momentum_schedule = cosine_scheduler(
    cfg.model.momentum_teacher, 1, cfg.training.nepochs, len(data_loader)
)

create early stopper

In [None]:
# the early stopper is only created when downstream tuning is enabled; it is called
# once per epoch in the training loop with the tuning results and the current snapshot
# NOTE(review): positional arguments are presumably (tracked metric, min/max mode, patience,
# minimum epoch) -- confirm against EarlyStoppingDINO
if cfg.early_stopping.tune_every:
    early_stopping = EarlyStoppingDINO(
        cfg.early_stopping.tracking,
        cfg.early_stopping.min_max,
        cfg.early_stopping.patience,
        cfg.early_stopping.min_epoch,
        checkpoint_dir=snapshot_dir,
        save_every=cfg.early_stopping.save_every,
        verbose=True,
    )

tune utils

In [None]:
@torch.no_grad()
def extract_multiple_features(
    student,
    teacher,
    loader,
    distributed,
    use_cuda=True,
    multiscale=False,
):
    """Extract L2-normalized student and teacher features for every sample of ``loader``.

    Args:
        student: model used to compute the student features.
        teacher: model used to compute the teacher features.
        loader: data loader yielding ``(index, img, label)`` batches, where ``index``
            holds the position of each sample in ``loader.dataset`` (it is used as row
            index of the feature matrices).
        distributed: when True, indices and features are gathered from all processes
            before being stored on the main process.
        use_cuda: when True the feature matrices are kept on the GPU, otherwise on
            the CPU (the forward passes always run on the GPU).
        multiscale: when True, features are computed with ``multi_scale``.

    Returns:
        ``(features, labels)`` where ``features`` is a dict with keys ``"student"`` and
        ``"teacher"`` mapping to ``(len(loader.dataset), feature_dim)`` matrices (``None``
        on processes other than the main one) and ``labels`` is a long tensor with the
        labels seen by this process, in loading order.
    """
    student_features = None
    teacher_features = None

    labels = []

    def _gather(tensor):
        # collect `tensor` from every process and concatenate along the first dimension
        world_size = dist.get_world_size()
        buffer = torch.empty(
            (world_size, *tensor.shape), dtype=tensor.dtype, device=tensor.device
        )
        chunks = list(buffer.unbind(0))
        dist.all_gather(chunks, tensor, async_op=True).wait()
        return torch.cat(chunks)

    def _store(storage, rows, feats):
        # write `feats` into the given rows of `storage`, on the device the storage lives on
        if not use_cuda:
            rows, feats = rows.cpu(), feats.cpu()
        storage.index_copy_(0, rows, feats)

    with tqdm(
        loader,
        desc=(f"Feature extraction"),
        unit=" slide",
        ncols=80,
        unit_scale=loader.batch_size,
        leave=True,
    ) as t:
        for batch in t:
            index, img, label = batch
            index = index.cuda(non_blocking=True)
            img = img.cuda(non_blocking=True)
            labels.extend(label.tolist())
            if multiscale:
                # NOTE(review): `multi_scale` is neither defined nor imported in this
                # notebook, so this branch raises NameError unless it is provided elsewhere
                student_feats = multi_scale(img, student)
                teacher_feats = multi_scale(img, teacher)
            else:
                student_feats = student(img).clone()
                teacher_feats = teacher(img).clone()

            # init storage feature matrix (main process only, on the first batch)
            if (
                is_main_process()
                and student_features is None
                and teacher_features is None
            ):
                student_features = torch.zeros(
                    len(loader.dataset), student_feats.shape[-1]
                )
                teacher_features = torch.zeros(
                    len(loader.dataset), teacher_feats.shape[-1]
                )
                if use_cuda:
                    student_features = student_features.cuda(non_blocking=True)
                    teacher_features = teacher_features.cuda(non_blocking=True)

            if distributed:
                # share indices and features between processes; every process must
                # take part in the collectives, even if only the main one stores the result
                index = _gather(index)
                student_feats = _gather(student_feats)
                teacher_feats = _gather(teacher_feats)

            # update storage feature matrix; the same device-aware path is used for
            # distributed and single-process runs, so `use_cuda=False` also works when
            # not distributed (features are moved to the CPU before being stored)
            if is_main_process():
                _store(student_features, index, student_feats)
                _store(teacher_features, index, teacher_feats)

    if is_main_process():
        student_features = nn.functional.normalize(student_features, dim=1, p=2)
        teacher_features = nn.functional.normalize(teacher_features, dim=1, p=2)

    features = {"student": student_features, "teacher": teacher_features}
    labels = torch.tensor(labels).long()

    return features, labels

pretrain utils

In [None]:
def train_one_epoch(
    student,
    teacher,
    teacher_without_ddp,
    dino_loss,
    data_loader,
    optimizer,
    lr_schedule,
    wd_schedule,
    momentum_schedule,
    epoch,
    nepochs,
    fp16_scaler,
    clip_grad,
    freeze_last_layer,
    gpu_id,
):
    """Run one epoch of DINO pretraining and return the averaged training statistics.

    Args:
        student: student network (possibly wrapped in DistributedDataParallel).
        teacher: teacher network used for the forward pass.
        teacher_without_ddp: unwrapped teacher, whose parameters receive the EMA update.
        dino_loss: loss called as ``dino_loss(student_output, teacher_output, epoch)``.
        data_loader: loader yielding ``(images, _)`` where ``images`` is a list of crops,
            the first two being the global views.
        optimizer: optimizer of the student parameters.
        lr_schedule: learning rate for each global training iteration.
        wd_schedule: weight decay for each global training iteration.
        momentum_schedule: teacher momentum for each global training iteration.
        epoch: index of the current epoch (0-based).
        nepochs: total number of epochs (only used for the progress bar).
        fp16_scaler: gradient scaler for mixed precision, or None to train in full precision.
        clip_grad: gradient clipping value; falsy to disable clipping.
        freeze_last_layer: forwarded to ``cancel_gradients_last_layer``.
        gpu_id: local GPU index, or -1 for single-process runs.

    Returns:
        dict mapping each logged metric (``loss``, ``lr``, ``wd``) to its global average.
    """
    metric_logger = MetricLogger(delimiter="  ")
    steps_per_epoch = len(data_loader)

    # the EMA update needs the bare student: unwrap it based on its actual type rather
    # than on the number of visible GPUs
    if isinstance(student, nn.parallel.DistributedDataParallel):
        student_without_ddp = student.module
    else:
        student_without_ddp = student

    # single-process runs (gpu_id == -1) use the current CUDA device
    device = torch.device("cuda") if gpu_id == -1 else torch.device(f"cuda:{gpu_id}")

    with tqdm(
        data_loader,
        desc=(f"Epoch [{epoch+1}/{nepochs}]"),
        unit=" img",
        ncols=80,
        unit_scale=data_loader.batch_size,
        leave=False,
        disable=not (gpu_id in [-1, 0]),
    ) as t:
        for step, (images, _) in enumerate(t):
            # update weight decay and learning rate according to their schedule
            it = steps_per_epoch * epoch + step  # global training iteration
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_schedule[it]
                if i == 0:  # only the first group is regularized
                    param_group["weight_decay"] = wd_schedule[it]

            # move images to gpu
            images = [im.to(device, non_blocking=True) for im in images]

            # teacher and student forward passes + compute dino loss
            with torch.cuda.amp.autocast(fp16_scaler is not None):
                # only the 2 global views pass through the teacher
                teacher_output = teacher(images[:2])
                student_output = student(images)
                loss = dino_loss(student_output, teacher_output, epoch)

            loss_value = loss.item()
            if not math.isfinite(loss_value):
                # tqdm.write has no `force` argument: passing one raised a TypeError
                # and hid the actual reason for stopping
                tqdm.write("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            # student update
            optimizer.zero_grad()
            if fp16_scaler is None:
                loss.backward()
                if clip_grad:
                    clip_gradients(student, clip_grad)
                cancel_gradients_last_layer(epoch, student, freeze_last_layer)
                optimizer.step()
            else:
                fp16_scaler.scale(loss).backward()
                if clip_grad:
                    # unscale the gradients of optimizer's assigned params in-place,
                    # so that clipping operates on the true gradient values
                    fp16_scaler.unscale_(optimizer)
                    clip_gradients(student, clip_grad)
                cancel_gradients_last_layer(epoch, student, freeze_last_layer)
                fp16_scaler.step(optimizer)
                fp16_scaler.update()

            # EMA update for the teacher: teacher = m * teacher + (1 - m) * student
            with torch.no_grad():
                m = momentum_schedule[it]  # momentum parameter
                for param_q, param_k in zip(
                    student_without_ddp.parameters(), teacher_without_ddp.parameters()
                ):
                    param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

            # logging
            torch.cuda.synchronize()
            metric_logger.update(loss=loss_value)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes(gpu_id)
    train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return train_stats

In [None]:
def tune_one_epoch(
    epoch,
    student: nn.Module,
    teacher: nn.Module,
    train_dataloader,
    test_dataloader,
    features_dir: Path,
    arch: str,
    patch_size: int,
    drop_path_rate: float,
    k: int,
    temperature: float,
    distributed: bool,
    save_features: bool = False,
    use_cuda: bool = False,
):
    """Evaluate the current student and teacher with a kNN classifier on downstream data.

    Fresh backbones are built with ``num_classes=0``, loaded with the current student and
    teacher weights, and used to extract features for the train (query) and test sets.

    Args:
        epoch: epoch number, only used in the log message.
        student: current student network, source of the student weights.
        teacher: current teacher network, source of the teacher weights.
        train_dataloader: loader over the downstream train set.
        test_dataloader: loader over the downstream test set.
        features_dir: folder where features and labels are saved when ``save_features``.
        arch: name of the backbone constructor in the vision_transformer module.
        patch_size: patch size of the backbone.
        drop_path_rate: drop path rate used to build the student backbone.
        k: number of neighbours of the kNN classifier.
        temperature: temperature of the kNN classifier.
        distributed: whether feature extraction gathers results across processes.
        save_features: whether to save the extracted features and labels to disk.
        use_cuda: whether features are stored, and kNN is run, on the GPU.

    Returns:
        dict with keys ``"teacher"`` and ``"student"``, each mapping to
        ``{"acc": ..., "auc": ...}``; empty on processes other than the main one.
    """
    student_model = vits.__dict__[arch](
        patch_size=patch_size, drop_path_rate=drop_path_rate, num_classes=0
    )
    teacher_model = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
    tqdm.write(f"Teacher & student models {arch} {patch_size}x{patch_size} built.")
    student_model.cuda()
    teacher_model.cuda()
    tqdm.write(f"Loading epoch {epoch} weights...")
    load_weights(student_model, student.state_dict())
    load_weights(teacher_model, teacher.state_dict())
    student_model.eval()
    teacher_model.eval()

    # ============ extract student & teacher features ============
    tqdm.write("Extracting features for query set...")
    train_features, train_labels = extract_multiple_features(
        student_model, teacher_model, train_dataloader, distributed, use_cuda
    )
    tqdm.write("Extracting features for test set...")
    test_features, test_labels = extract_multiple_features(
        student_model, teacher_model, test_dataloader, distributed, use_cuda
    )

    teacher_train_features = train_features["teacher"]
    teacher_test_features = test_features["teacher"]
    student_train_features = train_features["student"]
    student_test_features = test_features["student"]

    # save features and labels
    if save_features and is_main_process():
        # each split is saved from its own features (the test files used to be
        # written from the train features)
        for split, split_features in (("train", train_features), ("test", test_features)):
            for name, feats in split_features.items():
                torch.save(feats.cpu(), Path(features_dir, f"{name}_{split}_feat.pth"))
        torch.save(train_labels.cpu(), Path(features_dir, "train_labels.pth"))
        torch.save(test_labels.cpu(), Path(features_dir, "test_labels.pth"))

    results = defaultdict(dict)
    if is_main_process():
        assert len(torch.unique(train_labels)) == len(
            torch.unique(test_labels)
        ), "train & test dataset have different number of classes!"
        num_classes = len(torch.unique(train_labels))
        if use_cuda:
            teacher_train_features = teacher_train_features.cuda()
            teacher_test_features = teacher_test_features.cuda()
            student_train_features = student_train_features.cuda()
            student_test_features = student_test_features.cuda()
            train_labels, test_labels = train_labels.cuda(), test_labels.cuda()

        tqdm.write("Features are ready!\nStarting kNN classification.")
        teacher_acc, teacher_auc = knn_classifier(
            teacher_train_features,
            train_labels,
            teacher_test_features,
            test_labels,
            k,
            temperature,
            num_classes,
        )
        student_acc, student_auc = knn_classifier(
            student_train_features,
            train_labels,
            student_test_features,
            test_labels,
            k,
            temperature,
            num_classes,
        )
        results["teacher"].update({"acc": teacher_acc, "auc": teacher_auc})
        results["student"].update({"acc": student_acc, "auc": student_auc})

    return results

pretrain model

In [None]:
# Main pretraining loop: for every epoch, train, (optionally) tune on the downstream
# task from rank 0, update the early stopper, save snapshots and log.
epochs_run = 0
stop = False
start_time = time.time()


for epoch in range(epochs_run, cfg.training.nepochs):
    epoch_start_time = time.time()
    if cfg.wandb.enable and is_main_process():
        log_dict = {"epoch": epoch}

    # make the distributed sampler shuffle differently at every epoch
    if distributed:
        data_loader.sampler.set_epoch(epoch)

    # training one epoch of DINO
    train_stats = train_one_epoch(
        student,
        teacher,
        teacher_without_ddp,
        dino_loss,
        data_loader,
        optimizer,
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        epoch,
        cfg.training.nepochs,
        fp16_scaler,
        cfg.training.clip_grad,
        cfg.training.freeze_last_layer,
        gpu_id,
    )

    if cfg.wandb.enable and is_main_process():
        update_log_dict("train", train_stats, log_dict, step="epoch")

    # everything needed to restore the training state after this epoch
    if is_main_process():
        snapshot = {
            "epoch": epoch,
            "student": student.state_dict(),
            "teacher": teacher.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dino_loss": dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            snapshot["fp16_scaler"] = fp16_scaler.state_dict()

    # only run tuning on rank 0, otherwise one has to take care of gathering knn metrics from multiple gpus
    tune_results = None
    if (
        cfg.early_stopping.tune_every
        and epoch % cfg.early_stopping.tune_every == 0
        and is_main_process()
    ):
        tune_results = tune_one_epoch(
            epoch + 1,
            student,
            teacher_without_ddp,
            downstream_train_loader,
            downstream_test_loader,
            features_dir,
            cfg.model.arch,
            cfg.model.patch_size,
            cfg.model.drop_path_rate,
            cfg.early_stopping.knn.k,
            cfg.early_stopping.knn.temperature,
            False,
            cfg.early_stopping.knn.save_features,
            cfg.early_stopping.knn.use_cuda,
        )

        if cfg.wandb.enable and is_main_process():
            update_log_dict("tune", tune_results, log_dict, step="epoch")

    # the early stopper only exists when tuning is enabled; calling it otherwise
    # would raise a NameError
    if is_main_process() and cfg.early_stopping.tune_every:
        early_stopping(epoch, tune_results, snapshot)
        if early_stopping.early_stop and cfg.early_stopping.enable:
            stop = True

    # the stop decision is taken on rank 0 only: share it with the other ranks so that
    # all processes leave the loop together (otherwise they would wait forever for
    # rank 0 at the barrier below)
    if distributed:
        stop_flag = [stop]
        torch.distributed.broadcast_object_list(
            stop_flag, 0, device=torch.device(f"cuda:{gpu_id}")
        )
        stop = stop_flag[0]

    if stop:
        if is_main_process():
            tqdm.write(
                f"Stopping early because best {cfg.early_stopping.tracking} was reached {cfg.early_stopping.patience} epochs ago"
            )
        break

    # save snapshot and log to wandb
    if is_main_process():
        save_path = Path(snapshot_dir, f"epoch_{epoch:03}.pt")
        if (
            cfg.early_stopping.save_every
            and epoch % cfg.early_stopping.save_every == 0
            and not save_path.is_file()
        ):
            torch.save(snapshot, save_path)

        if cfg.wandb.enable:
            wandb.log(log_dict, step=epoch)

    # append the training stats of this epoch to a plain-text (JSON lines) log
    log_stats = {
        **{f"train_{k}": v for k, v in train_stats.items()},
        "epoch": epoch,
    }
    if is_main_process():
        with open(Path(output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(log_stats) + "\n")

    epoch_end_time = time.time()
    epoch_mins, epoch_secs = compute_time(epoch_start_time, epoch_end_time)
    if is_main_process():
        tqdm.write(
            f"End of epoch {epoch+1}/{cfg.training.nepochs} \t Time Taken:  {epoch_mins}m {epoch_secs}s"
        )

    # ensure other gpus wait until gpu_0 is finished with tuning before starting next training iteration
    if distributed:
        torch.distributed.barrier()

total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Pretraining time {}".format(total_time_str))

if distributed:
    torch.distributed.destroy_process_group()