In [32]:
import copy
import torch

import torch.nn as nn
import pytorch_lightning as pl
import lightly.models.utils as utils
from typing import List
from torch.utils.data import DataLoader
import torchvision
from lightly.utils.benchmarking import BenchmarkModule
import torch.distributed as dist
from lightly.loss import MSNLoss
from lightly.models.modules.heads import MSNProjectionHead
from lightly.models.modules.masked_autoencoder import MAEBackbone

In [107]:
class MSN(BenchmarkModule):
    def __init__(self,
               dataloader_kNN: DataLoader, 
               num_classes: int = 2,
               knn_k: int = 5,
               knn_t: float = 0.1,
               mask_ratio: float = 0.15,
               lr : float = 0.1,
               prototypes_num: int = 1024,
               weight_decay: float = 0.0,
               max_epochs: int = 100
               ) -> None:
        
        super().__init__(dataloader_kNN, num_classes,knn_k,knn_t)
        #self.save_hyperparameters() 
        self.weight_decay = weight_decay
        self.mask_ratio = mask_ratio
        self.lr = lr
        self.prototypes_num = prototypes_num
        self.max_epochs = max_epochs
       
        vit = torchvision.models.vit_b_16(weights=None)
        self.backbone = MAEBackbone.from_vit(vit)
        self.projection_head = MSNProjectionHead(768)

        self.anchor_backbone = copy.deepcopy(self.backbone)
        self.anchor_projection_head = copy.deepcopy(self.projection_head)

        utils.deactivate_requires_grad(self.backbone)
        utils.deactivate_requires_grad(self.projection_head)

        self.prototypes = nn.Linear(256, self.prototypes_num, bias=False).weight
        self.criterion = MSNLoss()

    def training_step(self, batch, batch_idx):
        utils.update_momentum(self.anchor_backbone, self.backbone, 0.996)
        utils.update_momentum(self.anchor_projection_head, self.projection_head, 0.996)

        views, _, _ = batch
        views = [view.to(self.device, non_blocking=True) for view in views]
        targets = views[0]
        anchors = views[1]
        anchors_focal = torch.concat(views[2:], dim=0)

        targets_out = self.backbone(targets)
        targets_out = self.projection_head(targets_out)
        anchors_out = self.encode_masked(anchors)
        anchors_focal_out = self.encode_masked(anchors_focal)
        anchors_out = torch.cat([anchors_out, anchors_focal_out], dim=0)

        loss = self.criterion(anchors_out, targets_out, self.prototypes.data)
        self.log("train_loss", loss, on_epoch= True,on_step=True , logger=True, prog_bar=True)
        return loss
    
        

    def encode_masked(self, anchors):
        batch_size, _, _, width = anchors.shape
        seq_length = (width // self.anchor_backbone.patch_size) ** 2
        idx_keep, _ = utils.random_token_mask(size=(batch_size, seq_length),
                                              mask_ratio=self.mask_ratio,
                                              device=self.device,
                                              )
        out = self.anchor_backbone(anchors, idx_keep)
        return self.anchor_projection_head(out)
    
    
#     def on_validation_epoch_end(self, outputs):
#         device = self.dummy_param.device
#         if outputs:
#             total_num = torch.Tensor([0]).to(device)
#             total_top1 = torch.Tensor([0.]).to(device)
#             for (num, top1) in outputs:
#                 total_num += num[0]
#                 total_top1 += top1
             

#             acc = float(total_top1.item() / total_num.item())
#             if acc > self.max_accuracy:
#                 self.max_accuracy = acc
#             self.log('kNN_accuracy', acc * 100.0, prog_bar=True, on_epoch=True, logger=True,)


    def configure_optimizers(self):
        params = [
            *list(self.anchor_backbone.parameters()),
            *list(self.anchor_projection_head.parameters()),
            self.prototypes,
        ]

        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr,weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                               eta_min=0.00001,
                                                               T_max=self.max_epochs
                                                            )
        return {'optimizer': optimizer,
               'lr_scheduler': scheduler
               }


In [108]:
import os
import glob
import torch
import random
import pandas as pd
import torch.nn as nn
import torchvision.transforms as T
from PIL import Image
from typing import Tuple, Optional
from torch.utils.data import Dataset
from lightly.transforms.msn_transform import MSNTransform

In [109]:
def preprocess(data_dir:str,
               paths: list,
               split:str,
               ) -> Tuple[List]:
    CLASSES  = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
                'Lung Opacity', 'No Finding', 'Pleural Effusion', 'Pleural Other',
                'Pneumonia', 'Pneumothorax', 'Support Devices']
        
    filenames_to_path = {path.split('/')[-1].split('.')[0]: path for path in paths}
    metadata = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-metadata.csv'))
    labels = pd.read_csv(os.path.join(data_dir,'mimic-cxr-2.0.0-chexpert.csv'))
    labels[CLASSES] = labels[CLASSES].fillna(0)
    labels = labels.replace(-1.0, 0.0)
    splits = pd.read_csv(os.path.join(data_dir,'mimic-cxr-ehr-split.csv'))
    metadata_with_labels = metadata.merge(labels[CLASSES+['study_id'] ], how='inner', on='study_id')
    filesnames_to_labels = dict(zip(metadata_with_labels['dicom_id'].values, metadata_with_labels[CLASSES].values))
    filenames_loaded = splits.loc[splits.split==split]['dicom_id'].values
    
    filenames_loaded = [filename  for filename in filenames_loaded if filename in filesnames_to_labels]

    return filenames_to_path, filenames_loaded, filesnames_to_labels

In [110]:
class MIMICCXR(Dataset):
    def __init__(self, 
                 paths: str,
                 data_dir: str, 
                 transform: Optional[T.Compose] = None, 
                 split: str = 'validate',
                 percentage:float = 1.0
                 ) -> None:
        self.data_dir = data_dir
        self.transform = transform
        self.filenames_to_path, \
        self.filenames_loaded, \
        self.filesnames_to_labels = preprocess(data_dir=self.data_dir,
                                               paths=paths,
                                               split=split
                                              )
        limit = (round(len(self.filenames_loaded) * percentage))
        self.filenames_loaded = self.filenames_loaded[0:limit]
 
        
    def __getitem__(self, index):
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            labels = torch.tensor(self.filesnames_to_labels[index]).float()

            if self.transform is not None:
                img = self.transform(img)
            return img, labels
        
        filename = self.filenames_loaded[index]
        
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')

        labels = torch.tensor(self.filesnames_to_labels[filename]).float()
        
            
        if self.transform is not None:
            img = self.transform(img)
        return img, labels

    def __len__(self):
        return len(self.filenames_loaded)
    


transform=MSNTransform(cj_prob=0,gaussian_blur=0,random_gray_scale=0)


class PretrainDataset(Dataset):
    def __init__(self, 
                 data_dir: str, 

                 ) -> None:
      
        self.data_dir = data_dir
        self.all_images = os.listdir(self.data_dir)
        for image in self.all_images:
            if image.startswith('._'):
                self.all_images.remove(image)
        
    def __len__(self
                ) -> int:
        return len(self.all_images)
    
    def __getitem__(self,
                    index: int
                    ) -> Tuple[torch.Tensor]:

        
        name = self.all_images[index]
        path = os.path.join(self.data_dir, name)
        img = Image.open(fp=path).convert('RGB')

        img = transform(img)



        return img, index, name



class MIMICVal(Dataset):
    def __init__(self, 
                 paths: str,
                 data_dir: str, 
                 transform: Optional[T.Compose] = None, 
                 split: str = 'val'
                 ) -> None:
        self.data_dir = data_dir
        self.transform = transform
        self.filenames_to_path, self.filenames_loaded, self.filesnames_to_labels = preprocess(data_dir=self.data_dir,
                                                                                              paths=paths,
                                                                                              split=split
                                                                                              )

    def __getitem__(self, index):
        if isinstance(index, str):
            img = Image.open(self.filenames_to_path[index]).convert('RGB')
            labels = torch.tensor(self.filesnames_to_labels[index]).float()
            if labels[8] == 1.0:
                label = torch.tensor(0)
            else:
                label = torch.tensor(1)

            if self.transform is not None:
                img = self.transform(img)
            return img, label, index
        
        filename = self.filenames_loaded[index]
        
        img = Image.open(self.filenames_to_path[filename]).convert('RGB')

        labels = torch.tensor(self.filesnames_to_labels[filename]).float()
        
        if labels[8] == 1.0:
            label = torch.tensor(0)
        else:
            label = torch.tensor(1)
            
        if self.transform is not None:
            img = self.transform(img)
        return img, label, filename

    def __len__(self):
        return len(self.filenames_loaded)

In [111]:
data_dir = "/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0"

In [112]:
train_dataset = PretrainDataset(data_dir=os.path.join(data_dir,'resized'))

In [113]:
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=64,
                              num_workers=16,
                              pin_memory=True,
                              shuffle=True
                              )

In [114]:
IMAGENET_STAT = {"mean":torch.tensor([0.4884, 0.4550, 0.4171]),
                 "std":torch.tensor([0.2596, 0.2530, 0.2556])}

val_test_transforms = T.Compose([T.Resize(256),
                                 T.CenterCrop(224),
                                 T.ToTensor(),
                                 T.Normalize(mean=IMAGENET_STAT["mean"],
                                             std=IMAGENET_STAT["std"])                                                                
                            ])


In [115]:
paths = paths = glob.glob(os.path.join(data_dir,'resized','**','*.jpg'), recursive=True)
val_dataset = MIMICVal(paths=paths,
                       data_dir=data_dir,
                       split='validate',
                       transform = val_test_transforms)

In [116]:
val_dataloader = DataLoader(dataset= val_dataset,
                            batch_size=64,
                            drop_last=True,
                            num_workers=16,
                            pin_memory=True
                            )

In [118]:
model = MSN(dataloader_kNN=val_dataloader,
           num_classes=2,
           knn_k=20,
           knn_t=0.2,
           mask_ratio=0.15,
           lr=0.0001,
           prototypes_num=1024,
           weight_decay=0.001,
           max_epochs=100
           )

In [119]:
trainer = pl.Trainer(max_epochs=100,
                     log_every_n_steps=1,
                     )

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/sas10092/.conda/envs/chexmsn-env/lib/python3.9 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [120]:
trainer.fit(model=model, train_dataloaders=train_dataloader)

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
Missing logger folder: /scratch/sas10092/ChexMSN/notebooks/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                   | Type              | Params
-------------------------------------------------------------
0 | backbone               | MAEBackbone       | 86.6 M
1 | projection_head        | MSNProjectionHead | 6.3 M 
2 | anchor_backbone        | MAEBackbone       | 86.6 M
3 | anchor_projection_head | MSNProjectionHead | 6.3 M 
4 | criterion              | MSNLoss           | 0     
  | other params           | n/a               | 262 K 
-------------------------------------------------------------
93.1 M    Trainable params
92.9 M    Non-trainable params
185 M     Total params
743.989   Total estimated model params size (MB)


Epoch 0: 100%|█████████▉| 5892/5893 [1:59:18<00:01,  0.82it/s, v_num=0, train_loss_step=5.410]

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 22. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 2:  16%|█▌        | 938/5893 [19:01<1:40:32,  0.82it/s, v_num=0, train_loss_step=5.420, train_loss_epoch=5.330] 

/home/sas10092/.conda/envs/chexmsn-env/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
