# Hierarchical DINO Pretraining

generic imports

In [None]:
import os
import sys
import math
import wandb
import datetime
import time
import json
import hydra
import shutil
import random
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import multiprocessing as mp

from tqdm import tqdm
from pathlib import Path
from omegaconf import OmegaConf

module imports

In [None]:
# sys.path.append("/path/to/dino")
# Make the local `dino` checkout importable for the module imports that follow.
# NOTE(review): machine-specific absolute path — adjust it for your environment.
sys.path.append("/home/clementgrisi/clement/code/dino")

In [None]:
import dino.models.vision_transformer as vits

from dino.components import DINOLoss
from dino.data import RegionDataAugmentationDINO, HierarchicalPretrainingDataset
from dino.models import MultiCropWrapper
from dino.distributed import get_world_size, is_main_process
from dino.utils import (
train_one_epoch,
cosine_scheduler,
fix_random_seeds,
has_batchnorms,
get_params_groups,
compute_time,
start_from_checkpoint,
resume_from_checkpoint,
)
from dino.utils.utils import clip_gradients, cancel_gradients_last_layer
from dino.log import initialize_wandb, update_log_dict, MetricLogger

load config

In [None]:
# Path of the YAML configuration file that is loaded with OmegaConf below.
# NOTE(review): machine-specific absolute path — adjust it for your environment.
config_file = "/home/clementgrisi/clement/code/dino/dino/config/region_debug.yaml"

In [None]:
# Parse the YAML file into the `cfg` object read by every subsequent cell.
cfg = OmegaConf.load(config_file)

initialize distributed session (if necessary)

In [None]:
# More than one visible GPU is treated as a distributed run.
# NOTE(review): the process-group init and LOCAL_RANK lookup presumably rely on
# a torchrun-style launch that sets the rendezvous variables — confirm.
distributed = torch.cuda.device_count() > 1

# -1 is the sentinel for "not distributed"; it is replaced by the local rank
# when a process group is created.
gpu_id = -1
if distributed:
    torch.distributed.init_process_group(backend="nccl")
    gpu_id = int(os.environ["LOCAL_RANK"])
    if gpu_id == 0:
        print("Distributed session successfully initialized")

In [None]:
# Only the main process chooses the run id; every other process starts with an
# empty placeholder that is replaced by the broadcast below (distributed runs).
run_id = ""
if is_main_process():
    print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
    # default id: launch timestamp (minute resolution)
    run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
    if cfg.wandb.enable:
        # when wandb is enabled, its run id takes precedence over the timestamp
        key = os.environ.get("WANDB_API_KEY")
        wandb_run = initialize_wandb(cfg, key=key)
        wandb_run.define_metric("epoch", summary="max")
        run_id = wandb_run.id

if distributed:
    # share the main process's run id with all other ranks
    obj = [run_id]
    torch.distributed.broadcast_object_list(
        obj,
        src=0,
        device=torch.device(f"cuda:{gpu_id}"),
    )
    run_id = obj[0]

fix_random_seeds(cfg.seed)
cudnn.benchmark = True

create output directory

In [None]:
# Notebook display only: echo the configured output directory.
cfg.output_dir

In [None]:
# Layout: <cfg.output_dir>/<cfg.experiment_name>/<run_id>/snapshots
output_dir = Path(cfg.output_dir, cfg.experiment_name, run_id)
snapshot_dir = output_dir / "snapshots"

# Directories are (re)created by the main process only, and only for fresh
# (non-resumed) runs; a pre-existing run directory is wiped first.
if not cfg.resume and is_main_process():
    if output_dir.exists():
        print(f"WARNING: {output_dir} already exists! Deleting its content...")
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    snapshot_dir.mkdir(parents=True, exist_ok=True)

prepare pretraining data

In [None]:
# Region-level DINO data augmentation, configured from the crop settings under
# `cfg.aug` and the region / patch sizes under `cfg.model`.
transform = RegionDataAugmentationDINO(
    cfg.aug.global_crops_scale,
    cfg.aug.local_crops_number,
    cfg.aug.local_crops_scale,
    cfg.model.region_size,
    cfg.model.patch_size,
)

In [None]:
# using custom dataset for our [256 x 384] tensors ("local" features)
# The dataset is built from the data directory and the augmentation defined above.
dataset = HierarchicalPretrainingDataset(cfg.data_dir, transform)

In [None]:
# Optionally pretrain on a random fraction of the dataset (skipped when
# `cfg.training.pct` is falsy).
if cfg.training.pct:
    print(f"Pre-training on {cfg.training.pct*100}% of the data")
    # number of samples to keep, rounded down
    nsample = int(len(dataset) * cfg.training.pct)
    # draw distinct indices without replacement
    idxs = random.sample(range(len(dataset)), nsample)
    dataset = torch.utils.data.Subset(dataset, idxs)

In [None]:
# Distributed runs use a shuffling DistributedSampler;
# otherwise a plain random sampler is used.
sampler = (
    torch.utils.data.DistributedSampler(dataset, shuffle=True)
    if distributed
    else torch.utils.data.RandomSampler(dataset)
)

In [None]:
# Never use more loader workers than the machine has CPUs or than the config
# requests.
num_workers = min(mp.cpu_count(), cfg.speed.num_workers)
if "SLURM_JOB_CPUS_PER_NODE" in os.environ:
    # Slurm formats this variable as "CPU_count[(xN)][,CPU_count[(xN)]...]",
    # e.g. "8", "16(x2)" or "16(x2),8". Calling int() on the raw value fails
    # for multi-node allocations, so parse the CPU count of every node group
    # and cap by the smallest one (the group of the current node is unknown).
    slurm_cpus = min(
        int(group.split("(")[0])
        for group in os.environ["SLURM_JOB_CPUS_PER_NODE"].split(",")
    )
    num_workers = min(num_workers, slurm_cpus)
# Notebook display only: show the resolved worker count.
num_workers

In [None]:
# Batches are drawn through `sampler`; `drop_last=True` discards a trailing
# incomplete batch, so len(data_loader) counts full batches only.
data_loader = torch.utils.data.DataLoader(
    dataset,
    sampler=sampler,
    batch_size=cfg.training.batch_size_per_gpu,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=True,
)
print(f"Pretraining data loaded ({len(dataset)} regions)")

build student and teacher networks

In [None]:
# Student and teacher use the same architecture, looked up by name in the
# `vits` module; only the student is given `drop_path_rate`.
student = vits.__dict__[cfg.model.arch](
    img_size=cfg.model.region_size,
    patch_size=cfg.model.patch_size,
    drop_path_rate=cfg.model.drop_path_rate,
)
teacher = vits.__dict__[cfg.model.arch](
    img_size=cfg.model.region_size, patch_size=cfg.model.patch_size
)
# embedding size of the backbone, used as input size of the DINO heads
embed_dim = student.embed_dim

In [None]:
# multi-crop wrapper handles forward with inputs of different resolutions
student = MultiCropWrapper(
    student,
    vits.DINOHead(
        embed_dim,
        cfg.model.out_dim,
        use_bn=cfg.model.use_bn_in_head,
        norm_last_layer=cfg.model.norm_last_layer,
    ),
)
teacher = MultiCropWrapper(
    teacher,
    vits.DINOHead(
        embed_dim,
        cfg.model.out_dim,
        use_bn=cfg.model.use_bn_in_head,
    ),
)

In [None]:
# place both networks on this process's device
if distributed:
    student = student.to(gpu_id)
    teacher = teacher.to(gpu_id)
else:
    student = student.cuda()
    teacher = teacher.cuda()

# Batch norms in a distributed run are converted to SyncBatchNorm. In that case
# the teacher is wrapped in DDP too, so `teacher_without_ddp` keeps a handle on
# the bare module; otherwise it is simply the teacher itself.
if has_batchnorms(student) and distributed:
    student = nn.SyncBatchNorm.convert_sync_batchnorm(student)
    teacher = nn.parallel.DistributedDataParallel(
        nn.SyncBatchNorm.convert_sync_batchnorm(teacher),
        device_ids=[gpu_id],
        output_device=gpu_id,
    )
    teacher_without_ddp = teacher.module
else:
    teacher_without_ddp = teacher

# the student is always wrapped in DDP for distributed runs
if distributed:
    student = nn.parallel.DistributedDataParallel(
        student, device_ids=[gpu_id], output_device=gpu_id
    )

In [None]:
# When `cfg.start_from_checkpoint` is set, initialise the student from that
# checkpoint file before the teacher is synchronised with it.
if cfg.start_from_checkpoint:
    ckpt_path = Path(cfg.start_from_checkpoint)
    start_from_checkpoint(ckpt_path, student)

In [None]:
# teacher and student start with the same weights
student_sd = student.state_dict()
# drop the "module." key prefix (present when the student is DDP-wrapped) so
# that the keys match the unwrapped teacher
nn.modules.utils.consume_prefix_in_state_dict_if_present(student_sd, "module.")
teacher_without_ddp.load_state_dict(student_sd)

In [None]:
# there is no backpropagation through the teacher, so no need for gradients
# (the teacher parameters are updated by the EMA step in the training loop)
for p in teacher.parameters():
    p.requires_grad = False

create loss

In [None]:
# total number of crops = 2 global crops + local_crops_number
# (passed to DINOLoss as the number of crops)
crops_number = cfg.aug.local_crops_number + 2

In [None]:
# Arguments are positional: output dimension, number of crops, teacher
# temperature settings (warmup value, final value, warmup length in epochs)
# and the total number of epochs.
# NOTE(review): parameter meaning inferred from the config names — confirm
# against the DINOLoss signature.
dino_loss = DINOLoss(
    cfg.model.out_dim,
    crops_number,
    cfg.model.warmup_teacher_temp,
    cfg.model.teacher_temp,
    cfg.model.warmup_teacher_temp_epochs,
    cfg.training.nepochs,
)

In [None]:
# move the loss module to the same device as the networks
dino_loss = dino_loss.to(gpu_id) if distributed else dino_loss.cuda()

create optimizer

In [None]:
# AdamW over the student's parameter groups. No lr / weight decay is set here:
# both are overwritten every iteration from their schedules inside
# train_one_epoch (weight decay on the first group only).
params_groups = get_params_groups(student)
optimizer = torch.optim.AdamW(params_groups)

In [None]:
# Gradient scaler for mixed precision training; stays None when fp16 is
# disabled, which train_one_epoch interprets as full-precision training.
fp16_scaler = torch.cuda.amp.GradScaler() if cfg.speed.use_fp16 else None

create schedulers

In [None]:
# sanity check: the warmup must fit within the total number of epochs
assert (
    cfg.training.nepochs >= cfg.training.warmup_epochs
), f"nepochs ({cfg.training.nepochs}) must be greater than or equal to warmup_epochs ({cfg.training.warmup_epochs})"
# scale the configured learning rate linearly with the global batch size
# (per-GPU batch size x world size), relative to a reference batch of 256
base_lr = (
    cfg.optim.lr * (cfg.training.batch_size_per_gpu * get_world_size()) / 256.0
)

In [None]:
# Learning-rate schedule, indexed by global iteration inside train_one_epoch.
# Built from base_lr and min_lr over nepochs x len(data_loader) iterations,
# with `warmup_epochs` epochs of warmup.
# NOTE(review): exact warmup / decay shape is defined by cosine_scheduler — confirm there.
lr_schedule = cosine_scheduler(
    base_lr,
    cfg.optim.lr_scheduler.min_lr,
    cfg.training.nepochs,
    len(data_loader),
    warmup_epochs=cfg.training.warmup_epochs,
)

In [None]:
# Weight-decay schedule (from weight_decay to weight_decay_end), indexed by
# global iteration inside train_one_epoch; no warmup argument is passed.
wd_schedule = cosine_scheduler(
    cfg.optim.lr_scheduler.weight_decay,
    cfg.optim.lr_scheduler.weight_decay_end,
    cfg.training.nepochs,
    len(data_loader),
)

In [None]:
# momentum parameter is increased to 1. during training with a cosine schedule
# (indexed by global iteration in the teacher EMA update of train_one_epoch)
momentum_schedule = cosine_scheduler(
    cfg.model.momentum_teacher,
    1,
    cfg.training.nepochs,
    len(data_loader),
)

pretrain utils

In [None]:
def train_one_epoch(
    student,
    teacher,
    teacher_without_ddp,
    dino_loss,
    data_loader,
    optimizer,
    lr_schedule,
    wd_schedule,
    momentum_schedule,
    epoch,
    nepochs,
    fp16_scaler,
    clip_grad,
    freeze_last_layer,
    gpu_id,
):
    """Train the student for one epoch and update the teacher by EMA.

    NOTE: this definition shadows the `train_one_epoch` imported from
    `dino.utils` earlier in the file.

    Args:
        student: network being optimized; may be wrapped in DataParallel or
            DistributedDataParallel.
        teacher: network that receives only the first two views of each batch.
        teacher_without_ddp: unwrapped teacher whose parameters receive the
            EMA update.
        dino_loss: callable invoked as
            `dino_loss(student_output, teacher_output, epoch)`.
        data_loader: yields `(images, _)` pairs where `images` is a sequence
            of tensors whose first two entries are the global views.
        optimizer: optimizer whose param groups get `lr` (all groups) and
            `weight_decay` (first group only) overwritten every iteration.
        lr_schedule: learning-rate values indexed by global iteration.
        wd_schedule: weight-decay values indexed by global iteration.
        momentum_schedule: teacher EMA momentum indexed by global iteration.
        epoch: zero-based index of the current epoch.
        nepochs: total number of epochs (used for the progress-bar label).
        fp16_scaler: gradient scaler for mixed precision, or None to train in
            full precision.
        clip_grad: falsy to disable gradient clipping, otherwise forwarded to
            `clip_gradients`.
        freeze_last_layer: forwarded to `cancel_gradients_last_layer`.
        gpu_id: CUDA device index of this process, or -1 to use the default
            CUDA device.

    Returns:
        Dict mapping each meter name of the metric logger (loss, lr, wd) to
        its `global_avg`.

    The process exits with status 1 if the loss becomes non-finite.
    """
    metric_logger = MetricLogger(delimiter="  ")
    n_iter_per_epoch = len(data_loader)
    use_amp = fp16_scaler is not None

    # Unwrap the student based on how it is actually wrapped rather than on
    # the number of visible GPUs, so an unwrapped student on a multi-GPU
    # machine does not fail on a missing `.module` attribute.
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    if isinstance(student, parallel_wrappers):
        student_without_ddp = student.module
    else:
        student_without_ddp = student

    with tqdm(
        data_loader,
        desc=(f"Epoch [{epoch+1}/{nepochs}]"),
        unit=" img",
        ncols=80,
        unit_scale=data_loader.batch_size,
        leave=False,
        # show the bar only for single-process runs (-1) or local rank 0
        disable=gpu_id not in (-1, 0),
    ) as t:
        for step, (images, _) in enumerate(t):
            # update weight decay and learning rate according to their schedule
            it = n_iter_per_epoch * epoch + step  # global training iteration
            for i, param_group in enumerate(optimizer.param_groups):
                param_group["lr"] = lr_schedule[it]
                if i == 0:  # only the first group is regularized
                    param_group["weight_decay"] = wd_schedule[it]

            # move images to gpu
            if gpu_id == -1:
                images = [im.cuda(non_blocking=True) for im in images]
            else:
                device = torch.device(f"cuda:{gpu_id}")
                images = [im.to(device, non_blocking=True) for im in images]

            # teacher and student forward passes + compute dino loss
            with torch.cuda.amp.autocast(use_amp):
                # only the 2 global views pass through the teacher
                teacher_output = teacher(images[:2])
                student_output = student(images)
                loss = dino_loss(student_output, teacher_output, epoch)

            # read the scalar once; it is reused for the check and the logging
            loss_value = loss.item()
            if not math.isfinite(loss_value):
                # tqdm.write() accepts no `force` keyword; passing one raised a
                # TypeError instead of reporting the loss and exiting cleanly
                tqdm.write("Loss is {}, stopping training".format(loss_value))
                sys.exit(1)

            # student update
            optimizer.zero_grad()
            param_norms = None
            if fp16_scaler is None:
                loss.backward()
                if clip_grad:
                    param_norms = clip_gradients(student, clip_grad)
                cancel_gradients_last_layer(epoch, student, freeze_last_layer)
                optimizer.step()
            else:
                fp16_scaler.scale(loss).backward()
                if clip_grad:
                    # unscale the gradients of optimizer's assigned params
                    # in-place so that clipping sees their true magnitude
                    fp16_scaler.unscale_(optimizer)
                    param_norms = clip_gradients(student, clip_grad)
                cancel_gradients_last_layer(epoch, student, freeze_last_layer)
                fp16_scaler.step(optimizer)
                fp16_scaler.update()

            # EMA update for the teacher
            with torch.no_grad():
                m = momentum_schedule[it]  # momentum parameter
                for param_q, param_k in zip(
                    student_without_ddp.parameters(),
                    teacher_without_ddp.parameters(),
                ):
                    param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

            # logging
            torch.cuda.synchronize()
            metric_logger.update(loss=loss_value)
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            metric_logger.update(wd=optimizer.param_groups[0]["weight_decay"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes(gpu_id)
    train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    return train_stats

pretrain model

In [None]:
# Main pretraining loop: one call to train_one_epoch per epoch, followed by
# snapshotting and logging on the main process.
epochs_run = 0
start_time = time.time()

for epoch in range(epochs_run, cfg.training.nepochs):
    epoch_start_time = time.time()

    # the wandb payload only exists on the main process when wandb is enabled
    if cfg.wandb.enable and is_main_process():
        log_dict = {"epoch": epoch}

    # tell the distributed sampler which epoch is running
    if distributed:
        data_loader.sampler.set_epoch(epoch)

    # training one epoch of DINO
    train_stats = train_one_epoch(
        student,
        teacher,
        teacher_without_ddp,
        dino_loss,
        data_loader,
        optimizer,
        lr_schedule,
        wd_schedule,
        momentum_schedule,
        epoch,
        cfg.training.nepochs,
        fp16_scaler,
        cfg.training.clip_grad,
        cfg.training.freeze_last_layer,
        gpu_id,
    )

    if cfg.wandb.enable and is_main_process():
        update_log_dict("train", train_stats, log_dict, step="epoch")

    # snapshots are written by the main process only
    if is_main_process():
        snapshot = {
            "epoch": epoch,
            "student": student.state_dict(),
            "teacher": teacher.state_dict(),
            "optimizer": optimizer.state_dict(),
            "dino_loss": dino_loss.state_dict(),
        }
        if fp16_scaler is not None:
            snapshot["fp16_scaler"] = fp16_scaler.state_dict()

        # periodic, epoch-numbered snapshot (disabled when the period is falsy)
        save_path = snapshot_dir / f"epoch_{epoch:03}.pt"
        if cfg.logging.save_snapshot_every and epoch % cfg.logging.save_snapshot_every == 0:
            torch.save(snapshot, save_path)
        # the most recent state is always overwritten
        torch.save(snapshot, snapshot_dir / "latest.pt")

        if cfg.wandb.enable:
            wandb.log(log_dict, step=epoch)

    # one JSON line per epoch, appended to log.txt by the main process
    log_stats = {f"train_{k}": v for k, v in train_stats.items()}
    log_stats["epoch"] = epoch
    if is_main_process():
        with (output_dir / "log.txt").open("a") as f:
            f.write(json.dumps(log_stats) + "\n")

    epoch_end_time = time.time()
    epoch_mins, epoch_secs = compute_time(epoch_start_time, epoch_end_time)
    if is_main_process():
        tqdm.write(
            f"End of epoch {epoch+1}/{cfg.training.nepochs} \t Time Taken:  {epoch_mins}m {epoch_secs}s"
        )

# report the total wall-clock duration, truncated to whole seconds
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print("Pretraining time {}".format(total_time_str))

if distributed:
    torch.distributed.destroy_process_group()