In [2]:
import os
import torch
import torch.nn as nn 
import pytorch_lightning as pl

import random
from collections import Counter
from typing import Tuple, Optional
from torch import Tensor
from PIL import Image
from torch.utils.data import Dataset

import os
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as T
from tqdm import tqdm
from torch import Tensor
from copy import deepcopy
from typing import List, Tuple, Dict
from torch.utils.data import DataLoader

from lightly.transforms import MSNTransform
from lightly.transforms.utils import IMAGENET_NORMALIZE

In [3]:
# Per-channel (R, G, B) normalisation statistics used by val_test_transforms.
# NOTE(review): despite the name, these differ from lightly's IMAGENET_NORMALIZE
# (the default inside MSNTransform) — confirm where these numbers came from.
IMAGENET_STAT = dict(
    mean=torch.tensor([0.4884, 0.4550, 0.4171]),
    std=torch.tensor([0.2596, 0.2530, 0.2556]),
)

In [4]:
# Training augmentation: a mild random affine jitter on the PIL image, followed by
# lightly's MSNTransform, which produces the random-crop and focal-crop views.
# Colour jitter, random grayscale and Gaussian blur are all switched off.
transforms = T.Compose([
    T.RandomAffine(degrees=15,
                   scale=(.9, 1.1),
                   shear=0,
                   translate=(0.1, 0.1)),
    MSNTransform(cj_prob=0,
                 random_gray_scale=0,
                 # `gaussian_blur` is a probability (float) in lightly's API; the
                 # previous tuple (0, 0, 0) is not a valid probability value.
                 gaussian_blur=0.0,
                 sigmas=(0, 0),
                 random_crop_scale=(0.2, 1.0),
                 focal_crop_scale=(0.05, 0.2),
                 # NOTE(review): `normalize` is left at MSNTransform's default
                 # (presumably lightly's IMAGENET_NORMALIZE), whereas
                 # val_test_transforms below uses IMAGENET_STAT — confirm which
                 # statistics are intended so train and eval inputs match.
                 ),
])


# Evaluation preprocessing: deterministic resize + centre crop, then normalise
# with the notebook's IMAGENET_STAT mean/std.
val_test_transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_STAT["mean"],
                std=IMAGENET_STAT["std"]),
])

In [5]:
# Cluster location of the resized MIMIC-CXR-JPG images (absolute, machine-specific)
# and the local metadata folder (relative to the notebook's working directory).
# NOTE(review): images_dir is not referenced by the cells shown here; the dataset
# reads image locations from the `path` column of meta.csv instead.
images_dir = (
    '/scratch/fs999/shamoutlab/data/physionet.org/files/mimic-cxr-jpg/2.0.0/resized'
)
data_dir = os.path.join(os.pardir, 'data')

In [6]:
# Image-level metadata; the dropped "Unnamed: *" columns are stray index columns
# (presumably left over from earlier to_csv round-trips).
meta = pd.read_csv(os.path.join(data_dir, 'meta.csv')).drop(
    columns=['Unnamed: 0.2', 'Unnamed: 0.1', 'Unnamed: 0']
)

In [7]:
# NOTE(review): this rebinds `transforms`, replacing the RandomAffine + MSNTransform
# pipeline built earlier. The dataset constructed below uses this plain
# MSNTransform: colour jitter and blur disabled, every other lightly default kept.
transforms = MSNTransform(gaussian_blur=0, cj_prob=0)

In [8]:
class ChexMSNDataset(Dataset):
    """Chest X-ray dataset returning a target image plus two demographic anchors.

    For every image listed in the metadata CSV, two anchor images are sampled
    at random from *other* subjects in the same ``ageR5`` bucket:

    * ``same=True``  - both anchors also share the target's ``gender``.
    * ``same=False`` - the age anchor only shares the age bucket; the gender
      anchor shares both age bucket and gender.

    All three images are loaded as RGB and passed through ``transforms``.

    Args:
        data_dir: path to the metadata CSV *file* (despite the name). It must
            contain ``path``, ``dicom_id``, ``subject_id``, ``ageR5`` and
            ``gender`` columns.
        transforms: callable applied to each PIL image.
        same: anchor-sampling mode, see above.
    """

    def __init__(self,
                 data_dir: str,
                 transforms: nn.Module,
                 same: bool = True
                 ) -> None:
        self.meta = pd.read_csv(data_dir)
        # Image paths in CSV row order; dataset indices are row positions.
        self.all_images = list(self.meta.path)
        self.transform = transforms
        self.same = same

    def __len__(self) -> int:
        """Number of images (rows) in the metadata CSV."""
        return len(self.all_images)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(target, age_anchor, gender_anchor)`` after ``self.transform``.

        Each element is whatever ``self.transform`` returns. For lightly's
        MSNTransform this is presumably a list of view tensors (confirm).
        Anchors are resampled on every call, so the same ``index`` yields
        different anchors across epochs.
        """
        target_path = self.all_images[index]
        img_age_path, img_gender_path = self._retrieve_anchors(
            image_id=self._image_id(target_path),
            meta=self.meta,
            same=self.same)
        # Evaluated left to right: target, then age anchor, then gender anchor.
        return (self._load(target_path),
                self._load(img_age_path),
                self._load(img_gender_path))

    @staticmethod
    def _image_id(path: str) -> str:
        """File stem of ``path``, i.e. its ``dicom_id`` (any extension length)."""
        return os.path.splitext(os.path.basename(path))[0]

    def _load(self, path: str):
        """Open ``path`` as RGB, apply the transform, and close the file handle."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    @staticmethod
    def _sample(paths: List[str], k: int, image_id: str) -> List[str]:
        """Draw ``k`` distinct paths, raising a descriptive error if too few exist."""
        if len(paths) < k:
            raise ValueError(
                f"need {k} anchor candidate(s) for {image_id!r}, found {len(paths)}")
        return random.sample(paths, k=k)

    def _retrieve_anchors(self,
                          image_id: str,
                          meta: pd.DataFrame,
                          same: bool = False) -> Tuple[str, str]:
        """Sample ``(image_age, image_gender)`` anchor paths for ``image_id``.

        Raises:
            ValueError: if ``image_id`` is not in ``meta``, or there are too
                few candidate anchors to sample from.
        """
        record = meta[meta.dicom_id == image_id]
        if record.empty:
            raise ValueError(f"dicom_id {image_id!r} not found in metadata")
        subject_id = record.subject_id.iloc[0]
        age_group = record.ageR5.iloc[0]
        gender = record.gender.iloc[0]

        # Same age bucket, excluding every image of the target's own subject.
        others = meta[(meta.ageR5 == age_group) & (meta.subject_id != subject_id)]
        same_gender_paths = list(others.path[others.gender == gender])

        if same:
            image_age, image_gender = self._sample(same_gender_paths, 2, image_id)
        else:
            (image_age,) = self._sample(list(others.path), 1, image_id)
            (image_gender,) = self._sample(same_gender_paths, 1, image_id)
        return image_age, image_gender
        

In [9]:
# Build the anchor dataset from the same metadata CSV loaded above, then
# sanity-check one item: a sample is a (target, age_anchor, gender_anchor)
# triple, so this should display 3.
dataset = ChexMSNDataset(data_dir=os.path.join(data_dir, 'meta.csv'),
                         transforms=transforms)
len(dataset[5])


3

In [10]:
# Batches of 32 (target, age_anchor, gender_anchor) samples, loaded by 8 workers.
# NOTE(review): shuffle is left at its default (off); fine for inspection, but
# a training run would normally want shuffle=True.
dataloader = DataLoader(dataset, batch_size=32, num_workers=8)
# To inspect one batch: next(iter(dataloader))

[[tensor([[[[-2.1179, -2.1179, -2.1179,  ...,  0.3652,  0.6563,  0.9132],
            [-2.1179, -2.1179, -2.1179,  ...,  0.3652,  0.5878,  0.8447],
            [-2.1179, -2.1179, -2.1179,  ...,  0.3309,  0.5022,  0.7762],
            ...,
            [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
            [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179],
            [-2.1179, -2.1179, -2.1179,  ..., -2.1179, -2.1179, -2.1179]],
  
           [[-2.0357, -2.0357, -2.0357,  ...,  0.5028,  0.8004,  1.0630],
            [-2.0357, -2.0357, -2.0357,  ...,  0.5028,  0.7304,  0.9930],
            [-2.0357, -2.0357, -2.0357,  ...,  0.4678,  0.6429,  0.9230],
            ...,
            [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
            [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357],
            [-2.0357, -2.0357, -2.0357,  ..., -2.0357, -2.0357, -2.0357]],
  
           [[-1.8044, -1.8044, -1.8044,  ...,  0.7228,  1.0191,  1.280

In [57]:
# Scratch cell: echoes a dicom_id for manual inspection (no effect on state).
'02aa804e-bde0afdd-112c0b34-7bc16630-4e384014'

'02aa804e-bde0afdd-112c0b34-7bc16630-4e384014'

In [21]:
# Metadata row(s) for a single image, looked up by its dicom_id.
a = meta.loc[meta['dicom_id'].eq('174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962')]
a

Unnamed: 0,dicom_id,subject_id,gender,age,ageR5,ageR10,path
1,174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...


In [22]:
# Every image in the same ageR5 bucket (5-year bins such as "50-55") as that row.
age_groub = a['ageR5'].iloc[0]
group = meta.loc[meta['ageR5'].eq(age_groub)]
group

Unnamed: 0,dicom_id,subject_id,gender,age,ageR5,ageR10,path
0,02aa804e-bde0afdd-112c0b34-7bc16630-4e384014,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
1,174413ec-4ec4c1f7-34ea26b7-c5f994f8-79ef1962,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
2,2a2277a9-b0ded155-c0de8eb9-c124d10e-82c5caab,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
3,e084de3b-be89b11e-20fe3f9f-9c8d8dfe-4cfd202c,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
4,68b5c4b1-227d0485-9cc38c3f-7b84ab51-4b472714,10000032,F,52.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
...,...,...,...,...,...,...,...
375962,e4cc4978-6158179d-6660b53f-06864161-e904cf0c,19995179,F,54.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
375963,01306696-5204aab3-62041daa-6574c92b-cf4c79f3,19995179,F,54.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
376162,5647da0d-52eadee5-ee406fe6-007be4f8-9f4f14e4,19998350,M,53.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...
376163,cc638eb5-d96655f2-1e6cc325-c41744cd-376cf2db,19998350,M,53.0,50-55,50-60,/scratch/fs999/shamoutlab/data/physionet.org/f...


In [19]:
def _retrieve_anchors(self,
                      image_id: str,
                      meta: pd.DataFrame
                      same: bool = False) -> Tuple[str]:
    
    record = meta[meta.dicom_id == image_id]
    
    subject_id = list(record.subject_id)[0]
    age_groub =list(record.ageR5)[0] 
    gender = list(record.gender)[0]
    
    group = meta[meta.ageR5 == age_groub]
    
    if same:
        candidate_anchors = group[group.gender == gender]
        candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
        images= list(candidate_anchors.path)
        sampled_images = random.sample(images,k=2)
        image_age, image_gender = sampled_images[0],sampled_images
        return image_age, image_gender
    else:
        candidate_anchors = group
        candidate_anchors = candidate_anchors[candidate_anchors.subject_id != subject_id]
        images= list(candidate_anchors.path)
        image_age = random.sample(images,k=1)[0]
        candidate_anchors = candidate_anchors[candidate_anchors.gender == gender]
        images= list(candidate_anchors.path)
        image_gender = random.sample(images,k=1)[0]
        return image_age, image_gender
        
    

In [31]:
same = True
image_id = '02aa804e-bde0afdd-112c0b34-7bc16630-4e384014'
# Sample anchors for one image and display the dicom_id (file stem) of the age
# anchor. `meta` must be passed explicitly: the standalone helper requires it.
# `None` fills the unused `self` slot.
anchor_age, _ = _retrieve_anchors(None, image_id, meta, same=same)
os.path.splitext(os.path.basename(anchor_age))[0]

'9ef725c8-ed7b6358-9be7a059-7f839fd8-84bfaa3f'