DINO for SeCo
- ResNet-18

In [4]:
import copy
import torch
import torchvision
from torch import nn
import sys 
sys.path.insert(0, '/home/akansh-i2sc/Desktop/Study/HLCV/Why-Self-Supervision-in-Time/src/modules')
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
import seco_dataset_temporal as seco
import numpy as np
import matplotlib.pyplot as plt
import glob
from tqdm.auto import tqdm
from lightly.data import LightlyDataset

In [5]:
class DINO(torch.nn.Module):
    """Student/teacher pair of backbones and projection heads for DINO training.

    The student uses the ``backbone`` object passed in; the teacher uses a deep
    copy of it. Each side gets its own ``DINOProjectionHead``. Both teacher
    modules are passed through ``deactivate_requires_grad``.

    Args:
        backbone: Feature extractor module. Its output is flattened from
            dimension 1 onwards before being fed to a projection head.
        input_dim: First argument given to both projection heads
            (presumably the flattened feature width of ``backbone`` -- confirm
            against the backbone in use).
    """

    def __init__(self, backbone, input_dim):
        super().__init__()
        # Registration order is kept (student backbone, student head, teacher
        # backbone, teacher head) so parameter/state-dict ordering is unchanged.
        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        for teacher_module in (self.teacher_backbone, self.teacher_head):
            deactivate_requires_grad(teacher_module)

    @staticmethod
    def _embed(backbone, head, x):
        """Run ``backbone``, flatten everything after dim 0, then apply ``head``."""
        features = backbone(x)
        return head(features.flatten(start_dim=1))

    def forward(self, x):
        """Return the student head output for input ``x``."""
        return self._embed(self.student_backbone, self.student_head, x)

    def forward_teacher(self, x):
        """Return the teacher head output for input ``x``."""
        return self._embed(self.teacher_backbone, self.teacher_head, x)

# Backbone = every child module of a default-constructed ResNet-18 except the last one.
resnet = torchvision.models.resnet18()
resnet_children = list(resnet.children())
backbone = nn.Sequential(*resnet_children[:-1])
# Passed to DINO(...) as the projection heads' input size
# (presumably the backbone's flattened feature width -- confirm if the backbone changes).
input_dim = 512

In [6]:
# Use CUDA when PyTorch reports it as available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = DINO(backbone, input_dim)
model.to(device);  # trailing ';' keeps the module repr out of the cell output

In [7]:
# NOTE(review): absolute, machine-specific path -- anyone else running this
# notebook has to point it at their own copy of the images.
path2data = "/home/akansh-i2sc/Desktop/Study/HLCV/SeCo_dataset/seco_100k/jpeg_40k/"

# Default DINO augmentation pipeline applied to every image read from `path2data`.
transform = DINOTransform()
seco_dataset = LightlyDataset(input_dir=path2data, transform=transform)

# Shuffled batches of 64; the trailing incomplete batch is dropped.
dataloader = torch.utils.data.DataLoader(
    dataset=seco_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

In [15]:
# Write the teacher backbone's weights to disk.
# NOTE(review): in file order this cell sits before the training loop, so a
# top-to-bottom run stores the weights the teacher has *before* training.
# The state dict comes from the nn.Sequential wrapper, so its keys are
# positional rather than ResNet attribute names.
teacher_backbone_weights = model.teacher_backbone.state_dict()
torch.save(teacher_backbone_weights, "teacher_backbone.pth")

# Plain ResNet-18 constructed with weights=None, to receive the saved weights.
model_new = torchvision.models.resnet18(weights=None)

In [18]:
# "teacher_backbone.pth" holds the state dict of the nn.Sequential wrapper built
# from `resnet.children()[:-1]`, so its keys start with a position
# ("0.weight", "4.0.conv1.weight", ...). `model_new` is a plain torchvision
# ResNet whose keys start with attribute names ("conv1.weight",
# "layer1.0.conv1.weight", ...), so loading the file unchanged fails the strict
# key check. Translate each leading index to the matching child name first.
checkpoint_state = torch.load("teacher_backbone.pth", map_location="cpu")
child_names = [child_name for child_name, _ in model_new.named_children()]

remapped_state = {}
for key, value in checkpoint_state.items():
    prefix, dot, remainder = key.partition(".")
    # Only rewrite keys of the form "<index>.<rest>"; anything else is kept as-is.
    if dot and prefix.isdigit() and int(prefix) < len(child_names):
        key = f"{child_names[int(prefix)]}.{remainder}"
    remapped_state[key] = value

# strict=False: the checkpoint was saved without the model's last child (the
# classifier), so those weights are legitimately absent from the file.
load_result = model_new.load_state_dict(remapped_state, strict=False)

# Anything other than the dropped last child being absent, or any key the model
# does not know, means the checkpoint and the model do not actually match.
last_child_prefix = f"{child_names[-1]}." if child_names else ""
stray_missing = [
    name
    for name in load_result.missing_keys
    if not (last_child_prefix and name.startswith(last_child_prefix))
]
if load_result.unexpected_keys or stray_missing:
    raise RuntimeError(
        "teacher_backbone.pth does not match model_new: "
        f"unexpected={list(load_result.unexpected_keys)}, missing={stray_missing}"
    )

load_result

In [8]:
# DINO loss. output_dim=2048 equals the last size argument given to both
# DINOProjectionHead instances in the DINO class (presumably these must agree --
# confirm against the lightly API). Teacher-temperature warm-up is set to 5 epochs.
criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

In [9]:
# Number of passes over `dataloader` made by the training loop.
epochs = 10

# Adam over everything returned by model.parameters(), learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Keep the loss module on the same device as the model.
criterion = criterion.to(device)

In [20]:
print("Starting Training")
for epoch in range(epochs):
    # Teacher momentum is computed once per epoch and reused for every batch in it
    # (presumably cosine_schedule ramps it from 0.996 towards 1 -- confirm in lightly).
    momentum_val = cosine_schedule(epoch, epochs, 0.996, 1)
    total_loss = 0

    for batch in tqdm(dataloader):
        # Move the augmented views (first element of the batch) to the training device.
        views = [view.to(device) for view in batch[0]]

        # Momentum-update both teacher modules from their student counterparts
        # before this step's forward passes.
        for student_part, teacher_part in (
            (model.student_backbone, model.teacher_backbone),
            (model.student_head, model.teacher_head),
        ):
            update_momentum(student_part, teacher_part, m=momentum_val)

        # The teacher only sees the first two views; the student sees all of them.
        global_views = views[:2]
        teacher_out = list(map(model.forward_teacher, global_views))
        student_out = list(map(model.forward, views))

        loss = criterion(teacher_out, student_out, epoch=epoch)
        loss.backward()
        total_loss += loss.detach()

        # We only cancel gradients of student head.
        model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    # Mean loss over the batches of this epoch.
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")


Starting Training


100%|██████████| 312/312 [01:37<00:00,  3.19it/s]


epoch: 00, loss: 6.62916


100%|██████████| 312/312 [01:39<00:00,  3.13it/s]


epoch: 01, loss: 3.66713


100%|██████████| 312/312 [01:39<00:00,  3.12it/s]


epoch: 02, loss: 5.08117


100%|██████████| 312/312 [01:40<00:00,  3.12it/s]


epoch: 03, loss: 5.19397


100%|██████████| 312/312 [01:40<00:00,  3.11it/s]


epoch: 04, loss: 4.97491


100%|██████████| 312/312 [01:40<00:00,  3.12it/s]


epoch: 05, loss: 4.81806


100%|██████████| 312/312 [01:39<00:00,  3.13it/s]


epoch: 06, loss: 4.71214


100%|██████████| 312/312 [01:40<00:00,  3.12it/s]


epoch: 07, loss: 4.66292


100%|██████████| 312/312 [01:40<00:00,  3.12it/s]


epoch: 08, loss: 4.63802


100%|██████████| 312/312 [01:40<00:00,  3.11it/s]

epoch: 09, loss: 4.62389





#### Testing

In [39]:
# Rebuild the backbone for testing: a default-constructed ResNet-18 with its
# last child module removed. This rebinds `resnet`, `backbone` and `input_dim`.
resnet = torchvision.models.resnet18()
backbone = nn.Sequential(*tuple(resnet.children())[:-1])
input_dim = 512  # same value as used when the DINO model was first built