In [1]:
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule
import os
from lightly.data import LightlyDataset
from tqdm.auto import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys 
sys.path.insert(0, '/home/akansh-i2sc/Desktop/Study/HLCV/Why-Self-Supervision-in-Time/src/modules')
import seco_dataset_temporal as seco

In [3]:
class DINO(pl.LightningModule):
    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18()
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        # instead of a resnet you can also use a vision transformer backbone as in the
        # original paper (you might have to reduce the batch size in this case):
        # backbone = torch.hub.load('facebookresearch/dino:main', 'dino_vits16', pretrained=False)
        # input_dim = backbone.embed_dim

        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 512, 64, 2048, freeze_last_layer=1
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, 2048)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x):
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

    def training_step(self, batch, batch_idx):
        momentum = cosine_schedule(self.current_epoch, 10, 0.996, 1)
        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        update_momentum(self.student_head, self.teacher_head, m=momentum)
        # views = batch[0]
        views_1 = batch[0]
        views_2 = batch[1]
        # views = [view.to(self.device) for view in views]
        views_1 = [view.to(self.device) for view in views_1[0:4:2]]
        views_2 = [view.to(self.device) for view in views_2[0:4:2]]
        views = views_1
        views.extend(views_2)               # Teporal local+global views from 1 and 2
        global_views = views_1[:1]          # Teporal global view 1
        global_views.extend(views_2[:1])    # Teporal global view 2
        # global_views = views[:2]
        teacher_out = [self.forward_teacher(view) for view in global_views]
        student_out = [self.forward(view) for view in views]
        loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)
        return loss

    def on_after_backward(self):
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), lr=0.001)
        return optim


In [4]:
path2data = "/home/akansh-i2sc/Desktop/Study/HLCV/SeCo_dataset/seco_100k/seasonal_contrast_100k/"
transform = DINOTransform(cj_prob = 0, random_gray_scale = 0,gaussian_blur = (0,0,0), solarization_prob = 0)
seco_dataset = seco.SeasonalContrastTemporal(root=path2data, transform=transform)

In [5]:
dataloader = torch.utils.data.DataLoader(
    seco_dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)

In [6]:
model = DINO()

In [7]:
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
print(accelerator)
trainer = pl.Trainer(max_epochs=100, devices=1, accelerator=accelerator)

# trainer = pl.Trainer(
#     max_epochs=1,
#     devices="auto",
#     accelerator="gpu",
#     strategy="ddp",
#     sync_batchnorm=True,
#     replace_sampler_ddp=True,  # or replace_sampler_ddp=True for PyTorch Lightning <2.0
# )

trainer.fit(model=model, train_dataloaders=dataloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


gpu


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type               | Params
--------------------------------------------------------
0 | student_backbone | Sequential         | 11.2 M
1 | student_head     | DINOProjectionHead | 691 K 
2 | teacher_backbone | Sequential         | 11.2 M
3 | teacher_head     | DINOProjectionHead | 691 K 
4 | criterion        | DINOLoss           | 0     
--------------------------------------------------------
11.9 M    Trainable params
11.9 M    Non-trainable params
23.7 M    Total params
94.942    Total estimated model params size (MB)


Epoch 11:   0%|          | 1/312 [00:02<11:31,  2.22s/it, loss=3.59, v_num=2]  



Epoch 12:   0%|          | 1/312 [00:01<09:25,  1.82s/it, loss=3.69, v_num=2]  



Epoch 13:   0%|          | 1/312 [00:01<08:42,  1.68s/it, loss=3.53, v_num=2]  



Epoch 14:   0%|          | 1/312 [00:01<08:57,  1.73s/it, loss=3.44, v_num=2]  



Epoch 15:   0%|          | 1/312 [00:01<08:59,  1.73s/it, loss=3.39, v_num=2]  



Epoch 16:   0%|          | 1/312 [00:01<10:03,  1.94s/it, loss=3.25, v_num=2]  



Epoch 17:   0%|          | 1/312 [00:01<08:59,  1.74s/it, loss=3.06, v_num=2]  



Epoch 18:   0%|          | 1/312 [00:01<09:20,  1.80s/it, loss=2.99, v_num=2]  



Epoch 19:   0%|          | 1/312 [00:01<08:49,  1.70s/it, loss=2.79, v_num=2]  



Epoch 20:   0%|          | 1/312 [00:01<08:40,  1.67s/it, loss=2.73, v_num=2]  



Epoch 21:   0%|          | 1/312 [00:01<08:37,  1.66s/it, loss=2.7, v_num=2]   



Epoch 22:   0%|          | 1/312 [00:01<09:08,  1.76s/it, loss=2.61, v_num=2]  



Epoch 23:   0%|          | 1/312 [00:01<08:40,  1.67s/it, loss=2.58, v_num=2]  



Epoch 24:   0%|          | 1/312 [00:01<08:33,  1.65s/it, loss=2.47, v_num=2]  



Epoch 25:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=2.52, v_num=2]  



Epoch 26:   0%|          | 1/312 [00:01<08:38,  1.67s/it, loss=2.49, v_num=2]  



Epoch 27:   0%|          | 1/312 [00:01<09:04,  1.75s/it, loss=2.36, v_num=2]  



Epoch 28:   0%|          | 1/312 [00:01<08:50,  1.71s/it, loss=2.47, v_num=2]  



Epoch 29:   0%|          | 1/312 [00:01<08:52,  1.71s/it, loss=2.41, v_num=2]  



Epoch 30:   0%|          | 1/312 [00:01<08:42,  1.68s/it, loss=2.42, v_num=2]  



Epoch 31:   0%|          | 1/312 [00:01<08:34,  1.65s/it, loss=2.4, v_num=2]   



Epoch 32:   0%|          | 1/312 [00:01<08:52,  1.71s/it, loss=2.28, v_num=2]  



Epoch 33:   0%|          | 1/312 [00:01<10:05,  1.95s/it, loss=2.31, v_num=2]  



Epoch 34:   0%|          | 1/312 [00:01<08:53,  1.72s/it, loss=2.47, v_num=2]  



Epoch 35:   0%|          | 1/312 [00:01<09:04,  1.75s/it, loss=2.26, v_num=2]  



Epoch 36:   0%|          | 1/312 [00:01<08:59,  1.73s/it, loss=2.21, v_num=2]  



Epoch 37:   0%|          | 1/312 [00:01<09:02,  1.75s/it, loss=2.16, v_num=2]  



Epoch 38:   0%|          | 1/312 [00:01<08:55,  1.72s/it, loss=2.15, v_num=2]  



Epoch 39:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=2.18, v_num=2]  



Epoch 40:   0%|          | 1/312 [00:01<08:46,  1.69s/it, loss=2.09, v_num=2]  



Epoch 41:   0%|          | 1/312 [00:01<08:36,  1.66s/it, loss=2.07, v_num=2]  



Epoch 42:   0%|          | 1/312 [00:01<08:42,  1.68s/it, loss=2.18, v_num=2]  



Epoch 43:   0%|          | 1/312 [00:01<08:55,  1.72s/it, loss=1.95, v_num=2]  



Epoch 44:   0%|          | 1/312 [00:01<09:16,  1.79s/it, loss=2, v_num=2]     



Epoch 45:   0%|          | 1/312 [00:01<08:57,  1.73s/it, loss=1.94, v_num=2]  



Epoch 46:   0%|          | 1/312 [00:01<08:47,  1.70s/it, loss=2.03, v_num=2]  



Epoch 47:   0%|          | 1/312 [00:01<08:51,  1.71s/it, loss=1.96, v_num=2]  



Epoch 48:   0%|          | 1/312 [00:01<08:54,  1.72s/it, loss=2.09, v_num=2]  



Epoch 49:   0%|          | 1/312 [00:01<08:47,  1.70s/it, loss=1.96, v_num=2]  



Epoch 50:   0%|          | 1/312 [00:01<09:14,  1.78s/it, loss=1.99, v_num=2]  



Epoch 51:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=2.03, v_num=2]  



Epoch 52:   0%|          | 1/312 [00:01<09:01,  1.74s/it, loss=2.03, v_num=2]  



Epoch 53:   0%|          | 1/312 [00:01<09:01,  1.74s/it, loss=2, v_num=2]     



Epoch 54:   0%|          | 1/312 [00:01<09:21,  1.81s/it, loss=2.01, v_num=2]  



Epoch 55:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=1.89, v_num=2]  



Epoch 56:   0%|          | 1/312 [00:01<08:55,  1.72s/it, loss=1.79, v_num=2]  



Epoch 57:   0%|          | 1/312 [00:01<09:03,  1.75s/it, loss=1.89, v_num=2]  



Epoch 58:   0%|          | 1/312 [00:01<09:07,  1.76s/it, loss=1.81, v_num=2]  



Epoch 59:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=2, v_num=2]     



Epoch 60:   0%|          | 1/312 [00:01<09:04,  1.75s/it, loss=1.84, v_num=2]  



Epoch 61:   0%|          | 1/312 [00:01<08:57,  1.73s/it, loss=1.92, v_num=2]  



Epoch 62:   0%|          | 1/312 [00:01<09:09,  1.77s/it, loss=1.85, v_num=2]  



Epoch 63:   0%|          | 1/312 [00:01<08:55,  1.72s/it, loss=1.89, v_num=2]  



Epoch 64:   0%|          | 1/312 [00:01<09:22,  1.81s/it, loss=1.83, v_num=2]  



Epoch 65:   0%|          | 1/312 [00:01<09:13,  1.78s/it, loss=1.93, v_num=2]  



Epoch 66:   0%|          | 1/312 [00:01<08:59,  1.73s/it, loss=1.79, v_num=2]  



Epoch 67:   0%|          | 1/312 [00:01<08:57,  1.73s/it, loss=1.89, v_num=2]  



Epoch 68:   0%|          | 1/312 [00:01<09:08,  1.77s/it, loss=1.78, v_num=2]  



Epoch 69:   0%|          | 1/312 [00:01<09:04,  1.75s/it, loss=1.89, v_num=2]  



Epoch 70:   0%|          | 1/312 [00:01<09:11,  1.77s/it, loss=1.74, v_num=2]  



Epoch 71:   0%|          | 1/312 [00:01<09:25,  1.82s/it, loss=1.94, v_num=2]  



Epoch 72:   0%|          | 1/312 [00:01<08:53,  1.72s/it, loss=1.92, v_num=2]  



Epoch 73:   0%|          | 1/312 [00:01<09:20,  1.80s/it, loss=1.74, v_num=2]  



Epoch 74:   0%|          | 1/312 [00:01<09:15,  1.79s/it, loss=1.75, v_num=2]  



Epoch 75:   0%|          | 1/312 [00:01<08:56,  1.72s/it, loss=1.79, v_num=2]  



Epoch 76:   0%|          | 1/312 [00:01<08:50,  1.71s/it, loss=1.92, v_num=2]  



Epoch 77:   0%|          | 1/312 [00:01<09:24,  1.81s/it, loss=1.66, v_num=2]  



Epoch 78:   0%|          | 1/312 [00:02<10:52,  2.10s/it, loss=1.73, v_num=2]  



Epoch 79:   0%|          | 1/312 [00:01<09:13,  1.78s/it, loss=1.67, v_num=2]  



Epoch 80:   0%|          | 1/312 [00:01<09:03,  1.75s/it, loss=1.59, v_num=2]  



Epoch 81:   0%|          | 1/312 [00:01<09:22,  1.81s/it, loss=1.78, v_num=2]  



Epoch 82:   0%|          | 1/312 [00:01<09:50,  1.90s/it, loss=1.75, v_num=2]  



Epoch 83:   0%|          | 1/312 [00:01<08:45,  1.69s/it, loss=1.74, v_num=2]  



Epoch 84:   0%|          | 1/312 [00:01<09:16,  1.79s/it, loss=1.85, v_num=2]  



Epoch 85:   0%|          | 1/312 [00:01<09:18,  1.79s/it, loss=1.86, v_num=2]  



Epoch 86:   0%|          | 1/312 [00:01<08:49,  1.70s/it, loss=1.91, v_num=2]  



Epoch 87:   0%|          | 1/312 [00:01<09:45,  1.88s/it, loss=1.83, v_num=2]  



Epoch 88:   0%|          | 1/312 [00:01<09:22,  1.81s/it, loss=1.84, v_num=2]  



Epoch 89:   0%|          | 1/312 [00:01<09:12,  1.78s/it, loss=1.72, v_num=2]  



Epoch 90:   0%|          | 1/312 [00:01<10:13,  1.97s/it, loss=1.69, v_num=2]  



Epoch 91:   0%|          | 1/312 [00:01<10:18,  1.99s/it, loss=1.78, v_num=2]  



Epoch 92:   0%|          | 0/312 [00:00<?, ?it/s, loss=1.81, v_num=2]          



Epoch 93:   0%|          | 1/312 [00:01<09:55,  1.92s/it, loss=1.61, v_num=2]  



Epoch 94:   0%|          | 0/312 [00:00<?, ?it/s, loss=1.73, v_num=2]          



Epoch 95:   0%|          | 1/312 [00:01<08:51,  1.71s/it, loss=1.69, v_num=2]  



Epoch 96:   0%|          | 1/312 [00:01<09:40,  1.87s/it, loss=1.66, v_num=2]  



Epoch 97:   0%|          | 1/312 [00:01<09:59,  1.93s/it, loss=1.59, v_num=2]  



Epoch 98:   0%|          | 1/312 [00:01<09:10,  1.77s/it, loss=1.7, v_num=2]   



Epoch 99:   0%|          | 1/312 [00:01<09:24,  1.81s/it, loss=1.91, v_num=2]  



Epoch 99: 100%|██████████| 312/312 [01:30<00:00,  3.43it/s, loss=1.61, v_num=2]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 312/312 [01:31<00:00,  3.43it/s, loss=1.61, v_num=2]


In [8]:
pretrained_resnet_backbone = model.teacher_backbone
state_dict = {"resnet18_parameters": pretrained_resnet_backbone.state_dict()}
torch.save(state_dict, "/home/akansh-i2sc/Desktop/Study/HLCV/Why-Self-Supervision-in-Time/src/models/pre-trained_weights/resnet18_model_dino_temporal_20k_100e.pth")