In [12]:
import copy
import os
from typing import Callable, List, Optional, Tuple

import pytorch_lightning as pl
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch import nn
from torch.utils.data import Dataset

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.modules.masked_autoencoder import MAEBackbone
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

In [2]:
# DINO transform with the colour-jitter, grayscale, blur and solarization
# settings all zeroed out.
# NOTE(review): `transform1` is not referenced by any other visible cell --
# confirm whether it is still needed.
transform1 = DINOTransform(
    cj_prob=0,
    random_gray_scale=0,
    gaussian_blur=(0, 0, 0),
    sigmas=(0, 0),
    solarization_prob=0,
)

class BaselinesDataset(Dataset):
    """Unlabelled dataset over every image file directly inside a folder.

    Files whose name starts with ``'._'`` are ignored. The remaining file
    names are sorted so that an index always maps to the same file.

    Args:
        data_dir: Folder that contains the image files.
        transform: Callable applied to each RGB ``PIL.Image`` before it is
            returned. When ``None`` a default ``DINOTransform()`` is used.
    """

    def __init__(self,
                 data_dir: str,
                 transform: Optional[Callable] = None,
                 ) -> None:
        self.data_dir = data_dir
        # Keep the transform on the instance instead of reading a notebook
        # global, so the dataset does not depend on cell execution order.
        self.transform = transform if transform is not None else DINOTransform()
        # Build a new list rather than removing entries while iterating:
        # removing during iteration skips the element that follows each
        # removed one, so consecutive '._' files were not all filtered out.
        self.all_images = sorted(
            name
            for name in os.listdir(self.data_dir)
            if not name.startswith('._')
        )

    def __len__(self) -> int:
        """Return the number of image files kept after filtering."""
        return len(self.all_images)

    def __getitem__(self, index: int
                    ) -> Tuple[List[torch.Tensor], int, str]:
        """Load, convert to RGB and transform the image at ``index``.

        Returns:
            ``(img, index, name)`` where ``img`` is the output of the
            transform (presumably a list of view tensors when a DINO
            transform is used -- confirm against the transform in use),
            ``index`` is the requested index and ``name`` the file name.
        """
        name = self.all_images[index]
        path = os.path.join(self.data_dir, name)
        # The context manager closes the file handle once the pixel data has
        # been copied into the converted image.
        with Image.open(fp=path) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, index, name

In [29]:
class DINO(pl.LightningModule):
    """DINO self-distillation module with a small vision-transformer backbone.

    A student (backbone + projection head) is optimised by gradient descent.
    A teacher with the same architecture has gradients deactivated and is
    updated from the student with a momentum update at every training step.

    Args:
        learning_rate: Learning rate passed to AdamW.
        weight_decay: Weight decay passed to AdamW.
        max_epochs: Training length in epochs. It is the horizon of both the
            cosine LR schedule and the teacher-momentum schedule, so it
            should equal the ``max_epochs`` given to the Trainer.
    """

    def __init__(self,
                 learning_rate: float = 0.0001,
                 weight_decay: float = 0.000,
                 max_epochs: int = 100,
                 ) -> None:
        super().__init__()
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.max_epochs = max_epochs

        # Vision-transformer backbone wrapped as a lightly MAEBackbone.
        vit = torchvision.models.VisionTransformer(image_size=224,
                                                   patch_size=16,
                                                   num_layers=12,
                                                   num_heads=6,
                                                   hidden_dim=192,
                                                   mlp_dim=192 * 4,
                                                   )
        backbone = MAEBackbone.from_vit(vit)
        input_dim = backbone.hidden_dim

        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim, 768, 256, 2048, freeze_last_layer=1
        )
        # The teacher backbone starts as a copy of the student backbone; the
        # teacher head is a separately constructed head of the same shape.
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 768, 256, 2048)
        # The teacher is never trained by backprop, only by momentum updates.
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

        self.criterion = DINOLoss(output_dim=2048, warmup_teacher_temp_epochs=5)

    def forward(self, x):
        """Run ``x`` through the student backbone and projection head."""
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        """Run ``x`` through the teacher backbone and projection head."""
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

    def training_step(self, batch, batch_idx):
        """Update the teacher, then compute and log the DINO loss.

        ``batch[0]`` is expected to hold the list of augmented views; the
        first two are treated as the global views fed to the teacher, while
        the student sees every view.
        """
        # The momentum schedule must span the whole training run. A horizon
        # shorter than the number of epochs (it used to be hard-coded to 10)
        # makes the schedule be evaluated past its end from epoch 11 onwards.
        momentum = cosine_schedule(self.current_epoch, self.max_epochs, 0.996, 1)
        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        update_momentum(self.student_head, self.teacher_head, m=momentum)
        views = batch[0]
        views = [view.to(self.device) for view in views]
        global_views = views[:2]
        teacher_out = [self.forward_teacher(view) for view in global_views]
        student_out = [self.forward(view) for view in views]
        loss = self.criterion(teacher_out, student_out, epoch=self.current_epoch)
        self.log("train_loss", loss, on_epoch=True, on_step=True, logger=True, prog_bar=True)
        return loss

    def on_after_backward(self):
        """Let the student head cancel its last-layer gradients if required."""
        self.student_head.cancel_last_layer_gradients(current_epoch=self.current_epoch)

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealing LR schedule over ``max_epochs``."""
        optimizer = optim.AdamW(params=self.parameters(),
                                lr=self.learning_rate,
                                weight_decay=self.weight_decay
                                )

        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                         eta_min=0,
                                                         T_max=self.max_epochs
                                                         )

        return {'optimizer': optimizer,
                'lr_scheduler': scheduler
               }




In [30]:
# max_epochs is spelled out (it equals the default) because the model uses it
# as T_max of its cosine LR schedule; keep it equal to the Trainer's max_epochs.
model = DINO(max_epochs=100)

In [31]:
# DINO transform with the library's default settings.
transform = DINOTransform()
# Folder with the resized MIMIC-CXR-JPG images. Set MIMIC_CXR_RESIZED_DIR to
# point somewhere else without editing the notebook; the fallback is the
# original cluster path.
data_dir = os.environ.get(
    'MIMIC_CXR_RESIZED_DIR',
    '/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0/resized/',
)
dataset = BaselinesDataset(data_dir=data_dir)

In [32]:

# Training loader over `dataset`.
dataloader = torch.utils.data.DataLoader(
    dataset,
    # batching
    batch_size=64,
    drop_last=True,   # discard the final incomplete batch
    shuffle=True,
    # loading
    num_workers=24,
    pin_memory=True,
)

In [None]:
# Allow reduced-precision float32 matmuls in exchange for speed.
torch.set_float32_matmul_precision('medium')

trainer = pl.Trainer(
    # Take the training length from the model so the Trainer and the model's
    # own epoch-based schedule cannot drift apart.
    max_epochs=model.max_epochs,
    devices='auto',
    accelerator='auto',
    precision='16-mixed',
    log_every_n_steps=1,
)
trainer.fit(model=model, train_dataloaders=dataloader)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type               | Params
--------------------------------------------------------
0 | student_backbone | MAEBackbone        | 5.7 M 
1 | student_head     | DINOProjectionHead | 1.5 M 
2 | teacher_backbone | MAEBackbone        | 5.7 M 
3 | teacher_head     | DINOProjectionHead | 1.5 M 
4 | criterion        | DINOLoss           | 0     
--------------------------------------------------------
7.2 M     Trainable params
7.2 M     Non-trainable params
14.4 M    Total params
57.435    Total estimated model params size (MB)


Epoch 8:  75%|███████▌  | 4446/5892 [22:23<07:17,  3.31it/s, v_num=4, train_loss_step=1.940, train_loss_epoch=1.800]