In [1]:
import copy
import torch

import torch.nn as nn
import pytorch_lightning as pl
import lightly.models.utils as utils
from typing import List
from torch.utils.data import DataLoader
import torchvision
from lightly.utils.benchmarking import BenchmarkModule
import torch.distributed as dist
from lightly.loss import MSNLoss
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_autoencoder import MAEBackbone
import torch.nn.functional as F

from transformers import ViTForImageClassification

  from .autonotebook import tqdm as notebook_tqdm


In [192]:
def mask_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a copy of a 2-D tensor with one non-zero entry per row set to zero.

    For every row, one position is drawn uniformly at random among the
    entries that are currently non-zero and is overwritten with ``0.0``.
    Rows that contain no non-zero entry are returned unchanged (the previous
    rejection-sampling loop never terminated on such rows).

    The input is never modified: masking is applied to a clone, so the caller
    can keep using the original tensor as the unmasked version.

    Args:
        tensor: Tensor of shape ``(rows, cols)``.

    Returns:
        A new tensor with the same shape, dtype and device as ``tensor``.

    Raises:
        ValueError: If ``tensor`` is not 2-dimensional.
    """
    if tensor.dim() != 2:
        raise ValueError(
            f"mask_tensor expects a 2-D tensor, got {tensor.dim()} dimension(s)"
        )

    # Work on a copy so the caller's tensor keeps its original values.
    masked = tensor.clone()

    # Boolean map of entries that are eligible for masking.
    candidates = masked != 0
    has_candidate = candidates.any(dim=1)
    if not bool(has_candidate.any()):
        # Empty tensor or nothing left to mask.
        return masked

    row_idx = torch.nonzero(has_candidate, as_tuple=True)[0]
    # Equal weight for every non-zero entry, zero weight for the others, so
    # multinomial picks uniformly among the eligible columns of each row.
    weights = candidates[row_idx].to(torch.float32)
    col_idx = torch.multinomial(weights, num_samples=1).squeeze(1)

    masked[row_idx, col_idx] = 0.0
    return masked

In [234]:
class MSN2(pl.LightningModule):
    """Masked Siamese Network that fuses chest X-ray views with EHR features.

    Two branches are kept:

    * target branch: ``backbone`` -> ``ehr_cxr_project`` -> ``projection_head``.
      Its parameters have ``requires_grad`` disabled and are refreshed at the
      start of every training step via ``utils.update_momentum`` from the
      anchor branch.
    * anchor branch: ``anchor_backbone`` -> ``anchor_ehr_cxr_project`` ->
      ``anchor_projection_head``. It encodes patch-masked image views and is
      combined with a masked copy of the EHR vector.

    ``ehr_embed`` is shared by both branches.

    Args:
        mask_ratio: Ratio passed to ``utils.random_token_mask`` for the anchor
            views.
        ema_momentum: Momentum passed to ``utils.update_momentum``.
        learning_rate: AdamW learning rate.
        weight_decay: AdamW weight decay.
    """

    def __init__(
        self,
        mask_ratio: float = 0.15,
        ema_momentum: float = 0.996,
        learning_rate: float = 0.0001,
        weight_decay: float = 0.001,
    ):
        super().__init__()

        self.mask_ratio = mask_ratio
        self.ema_momentum = ema_momentum
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

        # ehr embedding layer (3 EHR features -> 64-d embedding)
        self.ehr_embed = nn.Linear(in_features=3, out_features=64)

        # cxr-ehr project layer: concatenated (192-d image, 64-d EHR) -> 192-d
        self.ehr_cxr_project = nn.Linear(in_features=64 + 192, out_features=192)

        vit = torchvision.models.VisionTransformer(image_size=224,
                                                   patch_size=16,
                                                   num_layers=12,
                                                   num_heads=6,
                                                   hidden_dim=192,
                                                   mlp_dim=192 * 4)

        self.backbone = MAEBackbone.from_vit(vit)
        self.projection_head = MSNProjectionHead(192, 2048, 1024)

        # The anchor branch starts as an exact copy of the target branch.
        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)
        self.anchor_ehr_cxr_project = copy.deepcopy(self.ehr_cxr_project)

        # Target branch is not optimised directly; see training_step.
        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)
        utils.deactivate_requires_grad(self.ehr_cxr_project)

        self.prototypes = nn.Linear(1024, 1024, bias=False).weight
        self.criterion = MSNLoss()

    def training_step(self, batch, batch_idx):
        """Run one MSN step on ``(views, ehr)`` and return the loss.

        ``batch[0]`` is the list of image views (index 0 is the target view,
        index 1 the full-size anchor view, the rest are focal anchor views) and
        ``batch[1]`` is the EHR tensor of shape ``(batch, 3)``.
        """
        utils.update_momentum(self.anchor_backbone, self.backbone, self.ema_momentum)
        utils.update_momentum(self.anchor_projection_head, self.projection_head, self.ema_momentum)
        utils.update_momentum(self.anchor_ehr_cxr_project, self.ehr_cxr_project, self.ema_momentum)

        views, ehr = batch[0], batch[1]
        # Mask a clone: mask_tensor may write into its argument, which would
        # otherwise leave the target branch with the masked EHR as well.
        ehr_masked = mask_tensor(ehr.clone())
        views = [view.to(self.device, non_blocking=True) for view in views]
        ehr = ehr.to(self.device)
        ehr_masked = ehr_masked.to(self.device)

        targets = views[0]
        anchors = views[1]
        anchors_focal = torch.concat(views[2:], dim=0)
        # Every view except the target is an anchor view.
        num_anchor_views = len(views) - 1

        ehr_embed = self.ehr_embed(ehr)
        # Anchor outputs are stacked view after view along dim 0, so the masked
        # EHR embedding is tiled once per anchor view to line up row by row.
        ehr_masked_embed = self.ehr_embed(ehr_masked).repeat((num_anchor_views, 1))

        targets_out = self.backbone(targets)
        targets_out = self.ehr_cxr_project(torch.cat((targets_out, ehr_embed), 1))
        targets_out = self.projection_head(targets_out)

        anchors_out = self.encode_masked(anchors)
        anchors_focal_out = self.encode_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)
        anchors_out = self.anchor_ehr_cxr_project(torch.cat((anchors_out, ehr_masked_embed), 1))
        anchors_out = self.anchor_projection_head(anchors_out)

        # NOTE(review): `.data` hands the prototypes to the loss outside of
        # autograd, so they receive no gradient from it -- confirm intended.
        loss = self.criterion(anchors_out, targets_out, self.prototypes.data)
        self.log("train_loss", loss, on_epoch=True, on_step=True, logger=True, prog_bar=True)

        return loss

    def encode_masked(self, anchors):
        """Encode ``anchors`` with the anchor backbone on a random token subset.

        The token count is derived from the image width and the backbone patch
        size; ``utils.random_token_mask`` selects which tokens are kept.
        """
        batch_size, _, _, width = anchors.shape
        seq_length = (width // self.anchor_backbone.patch_size) ** 2
        idx_keep, _ = utils.random_token_mask(
            size=(batch_size, seq_length),
            mask_ratio=self.mask_ratio,
            device=self.device,
        )
        out = self.anchor_backbone(anchors, idx_keep=idx_keep)
        return out

    def configure_optimizers(self):
        """Return an AdamW optimizer over all module parameters.

        No learning-rate scheduler is configured.
        """
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.learning_rate,
                                      weight_decay=self.weight_decay)

        return {'optimizer': optimizer}


In [235]:
# Smoke test: run one training step by hand, outside of a Trainer, and show the loss.
# NOTE(review): `dataloader` is only created in a later cell (the DataLoader(...) cell),
# so this cell fails on a fresh kernel when the notebook is run top to bottom.
model = MSN2()
model.training_step(next(iter(dataloader)),0)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/core/module.py:420: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


tensor(6.9527, grad_fn=<AddBackward0>)

In [236]:
import os
import glob
import torch
import random
import pandas as pd
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from typing import Tuple, Optional,Dict,Union
from torch.utils.data import Dataset
from lightly.transforms.msn_transform import MSNTransform

In [237]:
from typing import Dict, List, Optional, Tuple, Union

import torchvision.transforms as T
from PIL.Image import Image
from torch import Tensor


from lightly.transforms.multi_view_transform import MultiViewTransform
# Per-channel mean / std used as the default `normalize` argument of the MSN transforms below.
# NOTE(review): these values are not the commonly quoted ImageNet statistics
# (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) -- confirm where they come from.
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}

# MIMIC_NORMALIZE ={"mean":torch.tensor([0.4723, 0.4723, 0.4723]), 
#                   "std":torch.tensor([0.3023, 0.3023, 0.3023])}


In [238]:
from lightly.transforms.multi_view_transform import MultiViewTransform
class MSNTransform(MultiViewTransform):
    """Multi-view augmentation for MSN [0] with an extra random affine step.

    Input to this transform:
        PIL Image or Tensor.

    Output of this transform:
        List of Tensor of length random_views + focal_views. (12 by default)

    Each view is produced by an MSNViewTransform, which applies:
        - Random affine (rotation, translation, scale, shear)
        - Random resized crop
        - Random horizontal flip
        - Random vertical flip
        - Conversion to tensor and normalization

    The views are ordered as
    [random_views_0, random_views_1, ..., focal_views_0, focal_views_1, ...].
    All random views share one transform object and all focal views share
    another one.

    - [0]: Masked Siamese Networks, 2022: https://arxiv.org/abs/2204.07141

    Attributes:
        random_size:
            Size of the random image views in pixels.
        focal_size:
            Size of the focal image views in pixels.
        random_views:
            Number of random views to generate.
        focal_views:
            Number of focal views to generate.
        affine_dgrees:
            `degrees` argument of torchvision.transforms.RandomAffine.
        affine_scale:
            `scale` argument of torchvision.transforms.RandomAffine.
        affine_shear:
            `shear` argument of torchvision.transforms.RandomAffine.
        affine_translate:
            `translate` argument of torchvision.transforms.RandomAffine.
        random_crop_scale:
            Minimum and maximum size of the randomized crops for the random views.
        focal_crop_scale:
            Minimum and maximum size of the randomized crops for the focal views.
        hf_prob:
            Probability that horizontal flip is applied.
        vf_prob:
            Probability that vertical flip is applied.
        normalize:
            Dictionary with 'mean' and 'std' for torchvision.transforms.Normalize.
    """

    def __init__(
        self,
        random_size: int = 224,
        focal_size: int = 96,
        random_views: int = 2,
        focal_views: int = 10,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        random_crop_scale: Tuple[float, float] = (0.3, 1.0),
        focal_crop_scale: Tuple[float, float] = (0.05, 0.3),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_STAT,
    ):
        # Settings that are identical for the random and the focal views.
        shared_settings = dict(
            affine_dgrees=affine_dgrees,
            affine_scale=affine_scale,
            affine_shear=affine_shear,
            affine_translate=affine_translate,
            hf_prob=hf_prob,
            vf_prob=vf_prob,
            normalize=normalize,
        )

        # The two view types differ only in crop size and crop scale.
        random_view = MSNViewTransform(
            crop_size=random_size,
            crop_scale=random_crop_scale,
            **shared_settings,
        )
        focal_view = MSNViewTransform(
            crop_size=focal_size,
            crop_scale=focal_crop_scale,
            **shared_settings,
        )

        per_view = [random_view] * random_views + [focal_view] * focal_views
        super().__init__(transforms=per_view)

In [239]:
class MSNViewTransform:
    """Augmentation pipeline that produces a single view.

    The pipeline is, in order: random affine, random resized crop, random
    horizontal flip, random vertical flip, conversion to tensor, and
    normalization with ``normalize["mean"]`` / ``normalize["std"]``.
    """

    def __init__(
        self,
        affine_dgrees: int = 15,
        affine_scale: Tuple[float,float]= (.9, 1.1),
        affine_shear: int = 0,
        affine_translate: Tuple[float,float] = (0.1, 0.1),
        crop_size: int = 224,
        crop_scale: Tuple[float, float] = (0.3, 1.0),
        hf_prob: float = 0.5,
        vf_prob: float = 0.0,
        normalize: Dict[str, List[float]] = IMAGENET_STAT,
    ):
        affine = T.RandomAffine(
            degrees=affine_dgrees,
            translate=affine_translate,
            scale=affine_scale,
            shear=affine_shear,
        )
        resized_crop = T.RandomResizedCrop(size=crop_size, scale=crop_scale)
        horizontal_flip = T.RandomHorizontalFlip(p=hf_prob)
        vertical_flip = T.RandomVerticalFlip(p=vf_prob)
        to_tensor = T.ToTensor()
        standardize = T.Normalize(mean=normalize["mean"], std=normalize["std"])

        self.transform = T.Compose(
            [affine, resized_crop, horizontal_flip, vertical_flip, to_tensor, standardize]
        )

    def __call__(self, image: Union[torch.Tensor, Image]) -> torch.Tensor:
        """Run the augmentation pipeline on one image.

        Args:
            image:
                The image to augment.

        Returns:
            The augmented, normalized image tensor.

        """
        return self.transform(image)

In [240]:
import os
import torch
import torch.nn as nn 
import pytorch_lightning as pl
import numpy as np
import random
from collections import Counter
from typing import Tuple, Optional
from torch import Tensor
from PIL import Image
from torch.utils.data import Dataset

import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T

from torch import Tensor
from copy import deepcopy
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader

In [241]:
class ChexMSNDataset(Dataset):
    """Dataset yielding augmented image views together with an EHR feature vector.

    The metadata CSV must provide a ``path`` column with the image location of
    every sample; all columns from ``ehr_start_col`` onward are read as the
    numeric EHR features.

    Args:
        data_dir: Path of the metadata CSV file (despite the name, a file and
            not a directory: it is passed straight to ``pd.read_csv``).
        transforms: Callable applied to the RGB PIL image of each sample.
        ehr_start_col: Positional index of the first EHR feature column.
    """

    def __init__(self,
                 data_dir: str,
                 transforms: nn.Module,
                 ehr_start_col: int = 8,
                 ) -> None:

        self.meta = pd.read_csv(data_dir)
        self.transforms = transforms
        # Slice the feature columns before converting so non-numeric metadata
        # columns (e.g. the image path) never enter the float conversion.
        self.ehr = self.meta.iloc[:, ehr_start_col:].to_numpy(dtype='float32')

    def __len__(self
                ) -> int:
        """Number of rows in the metadata CSV."""
        return len(self.meta)

    def __getitem__(self,
                    index: int
                    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Return ``(transforms(image), ehr)`` for the sample at position ``index``.

        The first element is whatever ``transforms`` returns (a list of view
        tensors when an ``MSNTransform`` is used).
        """
        # Positional lookup keeps the image path aligned with self.ehr[index].
        img_path = self.meta['path'].iloc[index]
        # The context manager closes the file handle once the pixels are loaded.
        with Image.open(fp=img_path) as raw:
            img = raw.convert('RGB')
        img = self.transforms(img)
        # Copy the row so the returned tensor does not alias dataset storage.
        ehr = torch.from_numpy(self.ehr[index].copy())

        return img, ehr
    
    


In [242]:
# Metadata CSV read by ChexMSNDataset via pd.read_csv: it needs a 'path' column
# (image file per row); columns from position 8 onward are used as EHR features.
data_dir = '../data/meta-age-norm.csv'
# pd.read_csv(data_dir)

In [243]:
# Each item is (views, ehr); MSNTransform() with its defaults yields 2 random + 10 focal views.
dataset = ChexMSNDataset(data_dir,transforms=MSNTransform())

In [244]:
# Shuffled training loader over the ChexMSN dataset.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    num_workers=24,
    pin_memory=True,
)

In [245]:
# next(iter(dataloader))[0]

In [246]:
# Sanity check: shape of the EHR part of one batch (second element of the batch tuple).
next(iter(dataloader))[1].shape

torch.Size([64, 3])

In [247]:
# Fresh model instance for the actual training run (replaces the smoke-test instance above).
model = MSN2()

In [248]:
# 100-epoch run with 16-bit mixed precision; metrics are logged on every step.
trainer = pl.Trainer(
    precision='16-mixed',
    max_epochs=100,
    log_every_n_steps=1,
)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Start training. No validation dataloader is passed, so only `train_loss` is reported.
# NOTE(review): the captured output below hits the notebook IOPub message rate limit,
# presumably from progress-bar refreshes -- consider a lower refresh rate.
trainer.fit(model=model, train_dataloaders=dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type              | Params
-------------------------------------------------------------
0 | ehr_embed              | Linear            | 256   
1 | ehr_cxr_project        | Linear            | 49.3 K
2 | backbone               | MAEBackbone       | 5.7 M 
3 | projection_head        | MSNProjectionHead | 6.7 M 
4 | anchor_backbone        | MAEBackbone       | 5.7 M 
5 | anchor_projection_head | MSNProjectionHead | 6.7 M 
6 | anchor_ehr_cxr_project | Linear            | 49.3 K
7 | criterion              | MSNLoss           | 0     
  | other params           | n/a               | 1.0 M 
-------------------------------------------------------------
13.5 M    Trainable params
12.5 M    Non-trainable params
26.0 M    Total params
103.881   Total estimated model params size (MB)


Epoch 12:  17%|█▋        | 992/5879 [02:00<09:52,  8.25it/s, v_num=1, train_loss_step=4.970, train_loss_epoch=4.970] 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 13:   4%|▍         | 247/5879 [00:34<13:02,  7.20it/s, v_num=1, train_loss_step=4.840, train_loss_epoch=4.960] 

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 13:  91%|█████████▏| 5365/5879 [10:22<00:59,  8.62it/s, v_num=1, train_loss_step=4.900, train_loss_epoch=4.960]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 14:  73%|███████▎  | 4276/5879 [08:17<03:06,  8.59it/s, v_num=1, train_loss_step=5.030, train_loss_epoch=4.940]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 15:  60%|██████    | 3537/5879 [06:51<04:32,  8.59it/s, v_num=1, train_loss_step=4.930, train_loss_epoch=4.930]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 26:  51%|█████     | 3011/5879 [05:48<05:31,  8.65it/s, v_num=1, train_loss_step=4.820, train_loss_epoch=4.840]