In [1]:
import copy
import torch

import torch.nn as nn
import pytorch_lightning as pl
import lightly.models.utils as utils
from typing import List
from torch.utils.data import DataLoader
import torchvision
from lightly.utils.benchmarking import BenchmarkModule
import torch.distributed as dist
from lightly.loss import MSNLoss
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_autoencoder import MAEBackbone
import torch.nn.functional as F

In [47]:
class VICRegLoss(nn.Module):
    def __init__(self,
                 #beta:float =1.0,
                 lamda:float =25.0,
                 mu: float = 25.0,
                 nu: float = 1.0,
                 gamma:float = 1.0,
                 epsilon: float = 0.001
                ) -> None:
        
        super().__init__()
        #self.beta = beta
        self.lamda = lamda
        self.mu = mu
        self.nu = nu
        self.gamma = gamma
        self.epsilon = epsilon

        #self.msn_loss = MSNLoss(regularization_weight=0)
    
    def _invar_loss(self,
                  Z1: torch.Tensor,
                  Z2: torch.Tensor,
                  ) -> torch.Tensor:
        Z2 = Z2.repeat((11,1))
        invar_loss = F.mse_loss(Z1,Z2)

        return invar_loss


    def _var_loss(self,
                Z1: torch.Tensor,
                Z2: torch.Tensor,
               ) -> torch.Tensor:
        Z2 = Z2.repeat((11,1))
        std_Z_1 = torch.sqrt(Z1.var(dim=0) + self.epsilon)
        std_Z_2 = torch.sqrt(Z2.var(dim=0) + self.epsilon)
        var_loss = torch.mean(torch.relu(self.gamma - std_Z_1)) + \
                   torch.mean(torch.relu(self.gamma - std_Z_2))

        return var_loss


    def _covar_loss(self,
                  Z1: torch.Tensor,
                  Z2: torch.Tensor
                  ) -> torch.Tensor:
        Z2 = Z2.repeat((11,1))
        n, d = Z1.shape
        Z1 = Z1-Z1.mean(dim=0)
        Z2 = Z2-Z2.mean(dim=0)
        cov_Z1 = torch.mm(Z1.T,Z1)/(n-1)
        cov_Z2 = torch.mm(Z2.T,Z2)/(n-1)
        covar_Z1_loss = (cov_Z1.sum() - cov_Z1.diagonal().sum()).pow(2) / d
        covar_Z2_loss = (cov_Z2.sum() - cov_Z2.diagonal().sum()).pow(2) / d
        covar_loss = covar_Z1_loss + covar_Z2_loss
        return covar_loss



    def forward(self,
              Z1: torch.Tensor,
              Z2: torch.Tensor,
              Z3:torch.Tensor
              ) -> torch.Tensor:
        #msn_loss = self.beta * self.msn_loss(Z1,Z2,Z3)
        var_loss = self.mu * self._var_loss(Z1,Z2)
        invar_loss = self.lamda * self._invar_loss(Z1,Z2)
        covar_loss = self.nu * self._covar_loss(Z1,Z2)

        return invar_loss,var_loss,covar_loss

In [48]:
class MSN1(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # ViT small configuration (ViT-S/16)
        self.mask_ratio = 0.15
        # self.backbone = MAEBackbone(
        #     image_size=224,
        #     patch_size=16,
        #     num_layers=12,
        #     num_heads=6,
        #     hidden_dim=384,
        #     mlp_dim=384 * 4,
        # )
        # # or use a torchvision ViT backbone:
        vit = torchvision.models.VisionTransformer(image_size=224,patch_size=16,num_layers=12,num_heads=6,hidden_dim=192,mlp_dim=192*4)
        self.backbone = MAEBackbone.from_vit(vit)
        self.projection_head = MSNProjectionHead(192,768,192)

        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)

        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)

        self.prototypes = nn.Linear(192, 1024, bias=False).weight
        self.criterion = VICRegLoss(lamda=2)#VICRegLoss()

    def training_step(self, batch, batch_idx):
        utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
        utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)

        views1 = batch[0]
        views2 = batch[1]
        views1 = [view.to(self.device, non_blocking=True) for view in views1]
        views2 = [view.to(self.device, non_blocking=True) for view in views2]
        targets = views1[0]
        anchors = views2[1]
        anchors_focal = torch.concat(views2[2:], dim=0)

        targets_out = self.backbone(targets)
        targets_out = self.projection_head(targets_out)
        anchors_out = self.encode_masked(anchors)
        anchors_focal_out = self.encode_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)

        inv,var,cov = self.criterion(anchors_out, targets_out, self.prototypes.data)
        #self.log("msn_loss", msn, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        self.log("inv_loss", inv, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        self.log("var_loss", var, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        self.log("covar_loss", cov, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        loss = inv + var + cov 
        self.log("train_loss", loss, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        
        return loss

    def encode_masked(self, anchors):
        batch_size, _, _, width = anchors.shape
        seq_length = (width // self.anchor_backbone.patch_size) ** 2
        idx_keep, _ = utils.random_token_mask(
            size=(batch_size, seq_length),
            mask_ratio=self.mask_ratio,
            device=self.device,
        )
        out = self.anchor_backbone(anchors, idx_keep=idx_keep)
        return self.anchor_projection_head(out)

    def configure_optimizers(self):
        params = [
            *list(self.anchor_backbone.parameters()),
            *list(self.anchor_projection_head.parameters()),
            self.prototypes,
        ]
        optim = torch.optim.AdamW(params, lr=0.0001,weight_decay=0.001)
        return optim

In [49]:
import os
import glob
import torch
import random
import pandas as pd
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from typing import Tuple, Optional,Dict,Union
from torch.utils.data import Dataset
from lightly.transforms.msn_transform import MSNTransform

In [50]:
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor


from lightly.transforms.multi_view_transform import MultiViewTransform
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}

MIMIC_NORMALIZE ={"mean":torch.tensor([0.4723, 0.4723, 0.4723]), 
                  "std":torch.tensor([0.3023, 0.3023, 0.3023])}


In [51]:
from lightly.transforms.multi_view_transform import MultiViewTransform
class MSNTransform(MultiViewTransform):
    """Implements the transformations for MSN [0].

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of Tensor of length 2 * random_views + focal_views. (12 by default)

    Applies the following augmentations by default:
        - Random resized crop
        - Random horizontal flip
        - ImageNet normalization

    Generates a set of random and focal views for each input image. The generated output
    is (views, target, filenames) where views is list with the following entries:
    [random_views_0, random_views_1, ..., focal_views_0, focal_views_1, ...].

    - [0]: Masked Siamese Networks, 2022: https://arxiv.org/abs/2204.07141

    Attributes:
        random_size:
            Size of the random image views in pixels.
        focal_size:
            Size of the focal image views in pixels.
        random_views:
            Number of random views to generate.
        focal_views:
            Number of focal views to generate.
        random_crop_scale:
            Minimum and maximum size of the randomized crops for the relative to random_size.
        focal_crop_scale:
            Minimum and maximum size of the randomized crops relative to focal_size.
        hf_prob:
            Probability that horizontal flip is applied.
        vf_prob:
            Probability that vertical flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        random_size: int = 224,
        focal_size: int = 96,
        random_views: int = 2,
        focal_views: int = 10,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        random_crop_scale: Tuple[float, float] = (0.3, 1.0),
        focal_crop_scale: Tuple[float, float] = (0.05, 0.3),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_STAT,
    ):
        random_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=random_size,
            crop_scale=random_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        focal_view_transform = MSNViewTransform(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            crop_size=focal_size,
            crop_scale=focal_crop_scale,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )
        transforms = [random_view_transform] * random_views
        transforms += [focal_view_transform] * focal_views
        super().__init__(transforms=transforms)

In [52]:
class MSNViewTransform:
    def __init__(
        self,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        crop_size: int = 224,
        crop_scale: Tuple[float, float] = (0.3, 1.0),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_STAT,
    ):

        transform = [
            T.RandomAffine(degrees=affine_dgrees, 
                          scale=affine_scale, 
                          shear=affine_shear, 
                          translate=affine_translate),
            T.RandomResizedCrop(size=crop_size, scale=crop_scale),
            T.RandomHorizontalFlip(p=hf_prob),
            T.RandomVerticalFlip(p=vf_prob),
            T.ToTensor(),
            T.Normalize(mean=normalize["mean"], std=normalize["std"]),
        ]

        self.transform = T.Compose(transform)

    def __call__(self, image: Union[torch.Tensor, Image]) -> torch.Tensor:
        """
        Applies the transforms to the input image.

        Args:
            image:
                The input image to apply the transforms to.

        Returns:
            The transformed image.

        """
        transformed = self.transform(image)
        return transformed

In [53]:
import os
import torch
import torch.nn as nn 
import pytorch_lightning as pl

import random
from collections import Counter
from typing import Tuple, Optional
from torch import Tensor
from PIL import Image
from torch.utils.data import Dataset

import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T
from tqdm import tqdm
from torch import Tensor
from copy import deepcopy
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader

In [54]:
class ChexMSNDataset(Dataset):
    def __init__(self, 
                 data_dir: str,
                 transforms: nn.Module,
                 same = True
                 ) -> None:
      
        self.meta = pd.read_csv(data_dir)
        self.all_images = list(self.meta.path)
        self.transforms = transforms
        self.same = same
        
    def __len__(self
                ) -> int:
        return len(self.all_images)
    
    def __getitem__(self,
                    index: int
                    ) -> Tuple[torch.Tensor]:

        
        target_path = self.all_images[index]
        image_id = target_path.split('/')[-1][:-4]
        img_gender_path = self._retrieve_anchors(image_id=image_id,
                                              meta = self.meta,
                                              same=self.same)

        img_target = Image.open(fp=target_path).convert('RGB')
        img_target = self.transforms(img_target)
        
#         img_age = Image.open(fp=img_age_path).convert('RGB')
#         img_age = self.transforms(img_age)

        img_gender = Image.open(fp=img_gender_path).convert('RGB')
        img_gender = self.transforms(img_gender)

        return (img_target,img_gender)#,img_gender)
    
    
    def _retrieve_anchors(self,
                          image_id: str,
                          meta: pd.DataFrame,
                          same: bool = False) -> Tuple[str]:
        record = meta[meta.dicom_id == image_id]
    
        subject_id = list(record.subject_id)[0]
        age_groub =list(record.ageR10)[0] 
        gender = list(record.gender)[0]
    
        group = meta[meta.ageR10 == age_groub]
    
        if same:
            candidate_anchors = group[group.gender == gender]
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            sampled_images = random.sample(images,k=2)
            image_age, image_gender = sampled_images[0],sampled_images[1]
            image_age = sampled_images[0]
            return image_gender #,image_age
        else:
            candidate_anchors = group
            candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
            images= list(candidate_anchors.path)
            image_age = random.sample(images,k=1)[0]
            #candidate_anchors = candidate_anchors[candidate_anchors.gender == gender]
            #images= list(candidate_anchors.path)
            #image_gender = random.sample(images,k=1)[0]
            return image_age#, image_gender

In [55]:
data_dir = '../data/meta.csv'

In [56]:
dataset = ChexMSNDataset(data_dir,transforms=MSNTransform(focal_views=10))

In [57]:
dataloader = DataLoader(dataset=dataset,
                              batch_size=64,
                              num_workers=16,
                              pin_memory=True,
                              shuffle=True
                              )

In [58]:
# plt.imshow(next(iter(dataloader))[0][0][0][0],cmap='gray')

In [59]:
# plt.imshow(next(iter(dataloader))[1][0][0][0],cmap='gray')

In [60]:
model = MSN1()

In [61]:
trainer = pl.Trainer(max_epochs=100,
                     log_every_n_steps=1,
                     precision='16-mixed'
                     )

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=model, train_dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type              | Params
-------------------------------------------------------------
0 | backbone               | MAEBackbone       | 5.7 M 
1 | projection_head        | MSNProjectionHead | 888 K 
2 | anchor_backbone        | MAEBackbone       | 5.7 M 
3 | anchor_projection_head | MSNProjectionHead | 888 K 
4 | criterion              | VICRegLoss        | 0     
  | other params           | n/a               | 196 K 
-------------------------------------------------------------
6.8 M     Trainable params
6.6 M     Non-trainable params
13.4 M    Total params
53.630    Total estimated model params size (MB)


Epoch 0:  81%|████████▏ | 4791/5879 [51:49<11:46,  1.54it/s, v_num=13, inv_loss_step=3.600, var_loss_step=3.460, covar_loss_step=1.490, train_loss_step=8.550]   