<a href="https://colab.research.google.com/github/larissabooth/cv_project/blob/main/self_supe_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive; every project path used later in this
# notebook lives under that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#@title Imports 

# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from pathlib import Path
import argparse
import json
import os
import random
import signal
import sys
import time
import urllib

from torch import nn, optim
from torchvision import datasets, transforms
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
import torch.distributed as dist

# Put the project's vicreg folder on the module search path.
# Presumably `augmentations` and `resnet` below are resolved from that folder
# (they are project-local modules, not installed packages) -- confirm.
sys.path.append("/content/drive/My Drive/computer_vision_project/vicreg")

import augmentations as aug
# from distributed import init_distributed_mode

import resnet

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Turn on cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True
# `gpu` is the name the model and training cells use for the target device.
gpu = torch.device(device)


In [None]:
#@title Network configurations

# Work from the vicreg project folder: the checkpoint and stats paths defined
# in this cell are relative ("./checkpoints/...", "./stats/...") and resolve here.
%cd "/content/drive/My Drive/computer_vision_project/vicreg"

# Time tools: an hour_minute stamp that becomes part of the stats file name.
# NOTE(review): the stamp carries no date, so runs started at the same HH_MM on
# different days append to the same stats file.
from datetime import datetime
currentDateAndTime = datetime.now()

currentTime = currentDateAndTime.strftime("%H_%M")

# Data: dataset root; the dataloader cell reads images from data_dir / "train".
data_dir = Path("/content/drive/My Drive/computer_vision_project/Kitchener_torch")

# Checkpoint
# NOTE(review): ckpt_file is not referenced by any cell visible in this notebook.
ckpt_file = Path("checkpoint.pth")
pretrained = Path("./checkpoints/lincls/resnet50_fullckpt.pth") # checkpoint to resume from (model + optimizer + epoch), if the file exists
exp_dir = Path("./checkpoints/self_sup") # directory where train() writes its checkpoints
log_freq_time = 10 # minimum number of seconds between two log lines written by train()

# Model
arch = "resnet50" # name looked up in the project's `resnet` module to build the backbone
mlp = "8192-8192-8192" # dash-separated layer widths of the MLP expander head

# Optim
epochs = 100 # number of epochs train() runs, counted from the resumed start epoch
batch_size = 128 # samples per DataLoader batch
base_lr = 0.1 # base learning rate; adjust_learning_rate scales it as base_lr * batch_size / 256
wd = 0 # weight decay passed to the LARS optimizer

# Loss
sim_coeff = 25.0 # invariance (MSE between the two embeddings) loss coefficient
std_coeff = 25.0 # variance regularization loss coefficient
cov_coeff = 1.0  # covariance regularization loss coefficient

# Running
num_workers= 4 # number of data loader worker processes

# Stats file
# Make sure both output directories exist before anything is written to them.
# Previously only the checkpoint directory was created, so opening the stats
# file failed with FileNotFoundError whenever ./stats/self_sup was missing.
exp_dir.mkdir(parents=True, exist_ok=True)
stats_dir = Path("./stats/self_sup")
stats_dir.mkdir(parents=True, exist_ok=True)

# Append mode with line buffering: each logged line reaches the file
# immediately, and an earlier run with the same time stamp is not overwritten.
stats_file = open(stats_dir / ("stats_file_" + currentTime + ".txt"), "a", buffering=1)

# Record the launching command line on screen and as the first stats line.
print(" ".join(sys.argv))
print(" ".join(sys.argv), file=stats_file)


/content/drive/.shortcut-targets-by-id/1RST5HayuSWDl47eVsz8vHQpNYKmjc800/computer_vision_project/vicreg
/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py -f /root/.local/share/jupyter/runtime/kernel-c0a9332a-ff12-4f80-bcbf-e081d223deb4.json


In [None]:
class FullGatherLayer(torch.autograd.Function):
    """
    Differentiable all-gather: collects a tensor from every process in the
    group and routes gradients back across processes in the backward pass.
    Both directions call torch.distributed collectives on the default group.
    """

    @staticmethod
    def forward(ctx, x):
        # One receive buffer per participating process, shaped like the input.
        world_size = dist.get_world_size()
        gathered = [torch.zeros_like(x) for _ in range(world_size)]
        dist.all_gather(gathered, x)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        # Reduce the per-output gradients over all processes, then keep the
        # slice that corresponds to this process's own input.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked)
        own_rank = dist.get_rank()
        return stacked[own_rank]

In [None]:
#@title VICReg and Projector

class VICReg(nn.Module):
    """Backbone plus MLP projector trained with the VICReg objective.

    The loss is a weighted sum of three terms computed on the projected
    embeddings of two inputs:

    * invariance -- mean squared error between the two embeddings,
    * variance   -- hinge pushing each embedding dimension's std towards >= 1,
    * covariance -- squared off-diagonal entries of each embedding's
      covariance matrix, divided by the embedding width.

    Args:
        mlp: dash-separated projector layer widths, e.g. "8192-8192-8192".
        arch: name of the constructor to look up in the project's ``resnet``
            module.
        batch_size: configured batch size. Stored for reference only; the
            covariance is normalised by the actual number of rows in each batch.
        sim_coeff: weight of the invariance term.
        std_coeff: weight of the variance term.
        cov_coeff: weight of the covariance term.
    """

    def __init__(self, mlp, arch, batch_size, sim_coeff, std_coeff, cov_coeff):
        super().__init__()
        self.mlp = mlp
        self.arch = arch
        self.batch_size = batch_size
        self.sim_coeff = sim_coeff
        self.std_coeff = std_coeff
        self.cov_coeff = cov_coeff
        # Width of the final projector layer, used to normalise the covariance term.
        self.num_features = int(mlp.split("-")[-1])
        # The resnet constructor returns the backbone and its output width.
        self.backbone, self.embedding = resnet.__dict__[arch](
            zero_init_residual=True
        )
        self.projector = Projector(mlp, self.embedding)

    def forward(self, x, y):
        """Return the scalar VICReg loss for two batches of inputs.

        Args:
            x: first batch of inputs accepted by the backbone.
            y: second batch, with the same number of rows as ``x``.

        Returns:
            Scalar tensor: sim_coeff * invariance + std_coeff * variance
            + cov_coeff * covariance.

        Note:
            The variance and covariance estimates divide by (rows - 1), so a
            batch needs at least two samples to give finite values.
        """
        x = self.projector(self.backbone(x))
        y = self.projector(self.backbone(y))

        repr_loss = F.mse_loss(x, y)

        # Single-process training: the embeddings are used as-is, without the
        # cross-process gather (FullGatherLayer) of the distributed setup.
        # Normalise by the real number of rows in this batch. The DataLoader
        # does not drop the last, shorter batch, so using the configured
        # batch size would mis-scale the covariance for that batch.
        n_samples = x.shape[0]

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # A small constant inside the square root keeps it well-behaved at zero variance.
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        cov_x = (x.T @ x) / (n_samples - 1)
        cov_y = (y.T @ y) / (n_samples - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(
            self.num_features
        ) + off_diagonal(cov_y).pow_(2).sum().div(self.num_features)

        loss = (
            self.sim_coeff * repr_loss
            + self.std_coeff * std_loss
            + self.cov_coeff * cov_loss
        )
        return loss


def Projector(mlp, embedding):
    """Build the MLP expander head as an ``nn.Sequential``.

    Args:
        mlp: dash-separated layer widths, e.g. "8192-8192-8192".
        embedding: width of the input fed to the first layer.

    Returns:
        A stack of Linear -> BatchNorm1d -> ReLU(inplace) stages for every
        width but the last, followed by one bias-free Linear output layer.
    """
    widths = [int(w) for w in f"{embedding}-{mlp}".split("-")]

    stages = []
    # Hidden stages: every consecutive width pair except the final one.
    for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
        stages.extend(
            [nn.Linear(fan_in, fan_out), nn.BatchNorm1d(fan_out), nn.ReLU(True)]
        )
    # Output layer: plain linear map without bias, normalisation or activation.
    stages.append(nn.Linear(widths[-2], widths[-1], bias=False))
    return nn.Sequential(*stages)

def exclude_bias_and_norm(p):
    """Return True when ``p`` is one-dimensional.

    Passed to LARS as a filter: parameters for which this returns True skip
    weight decay and LARS scaling. Presumably the 1-D parameters of this
    model are biases and normalisation weights, as the name suggests.
    """
    is_one_dimensional = (p.ndim == 1)
    return is_one_dimensional

def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a flat tensor.

    Dropping the last element of the flattened matrix and viewing the rest as
    (n - 1, n + 1) lines every diagonal entry up in column 0, so slicing that
    column away leaves exactly the off-diagonal values in row-major order.
    """
    rows, cols = x.shape
    assert rows == cols
    without_last = x.flatten()[:-1]
    realigned = without_last.view(rows - 1, rows + 1)
    return realigned[:, 1:].flatten()

In [None]:
#@title LARS Optimizer and Adjust Learning Rate

class LARS(optim.Optimizer):
    """SGD with momentum whose updates are rescaled per parameter tensor (LARS).

    Args:
        params: parameters (or parameter groups) to optimise.
        lr: learning rate.
        weight_decay: coefficient of the L2 term added to the gradient.
        momentum: factor applied to the running update buffer.
        eta: trust coefficient multiplying the norm ratio.
        weight_decay_filter: optional callable; parameters for which it
            returns True receive no weight decay.
        lars_adaptation_filter: optional callable; parameters for which it
            returns True skip the norm-ratio rescaling.
    """

    def __init__(
        self,
        params,
        lr,
        weight_decay=0,
        momentum=0.9,
        eta=0.001,
        weight_decay_filter=None,
        lars_adaptation_filter=None,
    ):
        super().__init__(
            params,
            {
                "lr": lr,
                "weight_decay": weight_decay,
                "momentum": momentum,
                "eta": eta,
                "weight_decay_filter": weight_decay_filter,
                "lars_adaptation_filter": lars_adaptation_filter,
            },
        )

    @torch.no_grad()
    def step(self):
        """Apply one update to every parameter that currently has a gradient."""
        for group in self.param_groups:
            skip_decay = group["weight_decay_filter"]
            skip_adaptation = group["lars_adaptation_filter"]

            for param in group["params"]:
                update = param.grad
                if update is None:
                    continue

                # L2 weight decay, unless the filter opts this parameter out.
                if skip_decay is None or not skip_decay(param):
                    update = update.add(param, alpha=group["weight_decay"])

                # Scale by eta * ||w|| / ||update||; the factor stays at 1
                # whenever either norm is not strictly positive.
                if skip_adaptation is None or not skip_adaptation(param):
                    weight_norm = torch.norm(param)
                    update_norm = torch.norm(update)
                    trust_ratio = torch.where(
                        (weight_norm > 0.0) & (update_norm > 0),
                        group["eta"] * weight_norm / update_norm,
                        torch.ones_like(weight_norm),
                    )
                    update = update.mul(trust_ratio)

                # The momentum buffer is created lazily on the first step.
                state = self.state[param]
                if "mu" not in state:
                    state["mu"] = torch.zeros_like(param)
                velocity = state["mu"]
                velocity.mul_(group["momentum"]).add_(update)

                param.add_(velocity, alpha=-group["lr"])

def adjust_learning_rate(
    optimizer,
    loader,
    step,
    total_epochs=None,
    peak_base_lr=None,
    samples_per_batch=None,
    warmup_epochs=10,
):
    """Set and return the learning rate for a given global step.

    The schedule is a linear warm-up from 0 to ``peak_base_lr *
    samples_per_batch / 256`` over ``warmup_epochs`` epochs, followed by a
    cosine decay down to 0.001 times that peak at the last scheduled step.
    Past the last scheduled step the rate is held at the final value.

    Args:
        optimizer: object whose ``param_groups`` each get their "lr" overwritten.
        loader: sized iterable; ``len(loader)`` is the number of steps per epoch.
        step: global step index (0-based, counted across epochs).
        total_epochs: epochs covered by the schedule; defaults to the
            notebook-level ``epochs``.
        peak_base_lr: base learning rate before batch-size scaling; defaults
            to the notebook-level ``base_lr``.
        samples_per_batch: batch size used for the scaling; defaults to the
            notebook-level ``batch_size``.
        warmup_epochs: number of warm-up epochs.

    Returns:
        The learning rate written into every parameter group.
    """
    # Imported here because the notebook's import cell does not import math.
    import math

    # Read the notebook-level settings without rebinding their names.
    # The previous version assigned to `base_lr` inside the function, which
    # made it a local and raised UnboundLocalError on every call.
    if total_epochs is None:
        total_epochs = epochs
    if peak_base_lr is None:
        peak_base_lr = base_lr
    if samples_per_batch is None:
        samples_per_batch = batch_size

    steps_per_epoch = len(loader)
    max_steps = total_epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch
    scaled_lr = peak_base_lr * samples_per_batch / 256

    if step < warmup_steps:
        lr = scaled_lr * step / warmup_steps
    else:
        # Guard against a zero-length decay phase and clamp the progress so
        # that steps beyond the schedule keep the final learning rate.
        decay_steps = max(1, max_steps - warmup_steps)
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        q = 0.5 * (1 + math.cos(math.pi * progress))
        end_lr = scaled_lr * 0.001
        lr = scaled_lr * q + end_lr * (1 - q)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr

In [None]:
#@title Dataloader
from torch.utils.data import random_split

# Training-time augmentation from the project's `augmentations` module.
# Presumably it yields a pair of views per image, since the training loop
# unpacks every batch as ((x, y), _) -- confirm against that module.
# NOTE(review): this rebinds `transforms`, hiding the torchvision.transforms
# module imported in the first cell until that cell is run again.
transforms = aug.TrainTransform()

train_dataset = datasets.ImageFolder(root=data_dir / "train", transform=transforms)

# Very small subset for testing: keep 10% of the images, set the rest aside.
# NOTE(review): no generator/seed is given, so the subset differs between runs.
dataset, extra = random_split(train_dataset, [0.1, 0.9])

loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=True,
    shuffle=True,
)
loader = torch.utils.data.DataLoader(dataset, **loader_options)

In [None]:
#@title Model config

# Build the network on the GPU and swap its BatchNorm layers for SyncBatchNorm.
model = VICReg(mlp, arch, batch_size, sim_coeff, std_coeff, cov_coeff)
model = model.cuda(gpu)
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])

# The optimizer starts at lr=0; unless a checkpoint restores a different value
# below, that is the rate training runs with.
# 1-D parameters are excluded from both weight decay and LARS scaling.
optimizer = LARS(
    model.parameters(),
    lr=0,
    weight_decay=wd,
    weight_decay_filter=exclude_bias_and_norm,
    lars_adaptation_filter=exclude_bias_and_norm,
)

if pretrained.is_file():
    print("resuming from checkpoint")
    checkpoint = torch.load(pretrained, map_location='cpu')

    # Remove any "module." prefix from the stored keys so that they match the
    # names of this (unwrapped) model.
    model_state_dict = {
        name.replace("module.", ""): value
        for name, value in checkpoint['model'].items()
    }
    opt_state_dict = {
        name.replace("module.", ""): value
        for name, value in checkpoint['optimizer'].items()
    }

    start_epoch = checkpoint["epoch"]
    model.load_state_dict(model_state_dict)
    optimizer.load_state_dict(opt_state_dict)
else:
    start_epoch = 0

resuming from checkpoint


In [None]:
#@title train sequence 
def train():
    """Run the self-supervised training loop and write checkpoints.

    Runs ``epochs`` epochs starting at ``start_epoch`` using the notebook-level
    ``model``, ``optimizer`` and ``loader``. Each step does a mixed-precision
    forward/backward pass with gradient scaling. At most once every
    ``log_freq_time`` seconds a JSON stats line is printed and appended to
    ``stats_file``. The full training state is saved to ``exp_dir`` after
    every epoch, and the backbone weights alone once training has finished.

    NOTE(review): no learning-rate schedule is applied here, so the rate stays
    at whatever the optimizer holds: 0 for a fresh optimizer, or the value
    restored from a checkpoint.
    """
    start_time = last_logging = time.time()
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, start_epoch + epochs):
        # `step` is a global counter that continues across epochs.
        for step, ((x, y), _) in enumerate(loader, start=epoch * len(loader)):
            x = x.cuda(gpu, non_blocking=True)
            y = y.cuda(gpu, non_blocking=True)

            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                # Call the module itself rather than .forward() so that any
                # registered forward hooks run as well.
                loss = model(x, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            current_time = time.time()
            if current_time - last_logging > log_freq_time:
                stats = dict(
                    epoch=epoch,
                    step=step,
                    loss=loss.item(),
                    time=int(current_time - start_time),
                    # Report the rate the optimizer is using right now, not a
                    # value captured once before the loop started.
                    lr=optimizer.param_groups[0]["lr"],
                )
                stats_line = json.dumps(stats)
                print(stats_line)
                print(stats_line, file=stats_file)
                last_logging = current_time

        # Save the full checkpoint; `epoch + 1` is the epoch a resumed run
        # starts from.
        state = dict(
            epoch=epoch + 1,
            model=model.state_dict(),
            optimizer=optimizer.state_dict(),
        )
        torch.save(state, exp_dir / "resnet50_full.pth")

    # Save just the backbone
    torch.save(model.backbone.state_dict(), exp_dir / "resnet50_bb.pth")


In [None]:
train()

{"epoch": 1000, "step": 11000, "loss": 23.628154754638672, "time": 106, "lr": 0.0032}
{"epoch": 1000, "step": 11004, "loss": 23.59101104736328, "time": 186, "lr": 0.0032}
{"epoch": 1000, "step": 11008, "loss": 23.7258358001709, "time": 280, "lr": 0.0032}
{"epoch": 1001, "step": 11011, "loss": 23.64259147644043, "time": 298, "lr": 0.0032}
{"epoch": 1001, "step": 11021, "loss": 23.61747169494629, "time": 309, "lr": 0.0032}
{"epoch": 1002, "step": 11022, "loss": 23.50779151916504, "time": 319, "lr": 0.0032}
{"epoch": 1002, "step": 11031, "loss": 23.55649185180664, "time": 329, "lr": 0.0032}
{"epoch": 1003, "step": 11033, "loss": 23.586668014526367, "time": 340, "lr": 0.0032}
{"epoch": 1003, "step": 11042, "loss": 23.505640029907227, "time": 350, "lr": 0.0032}
{"epoch": 1004, "step": 11044, "loss": 23.455398559570312, "time": 362, "lr": 0.0032}
{"epoch": 1004, "step": 11053, "loss": 23.413198471069336, "time": 372, "lr": 0.0032}
{"epoch": 1005, "step": 11055, "loss": 23.2421817779541, "tim