### Code for the updated data transformations and augmentations, with visualisations to inspect their effect

In [None]:
import numpy as np
import pandas as pd
import os
import torch
import random
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import torchvision.transforms.functional as TF
import random
from typing import Sequence

import sys 
root_code = os.path.dirname(os.path.dirname(os.getcwd()))
sys.path.insert(0, root_code)

from codebase.utils.constants import *
from codebase.utils.raw_utils import *
from codebase.utils.dataset_utils import clean_train_val_test_sep_for_manual_aligns_ordered_by_quality, filter_and_transform

from codebase.utils.zipper import do_zip
from codebase.experiments.cgan3.training_helpers import *
from codebase.experiments.cgan3.network import * #Translator, Discriminator
from codebase.experiments.cgan3.loaders import CGANDataset
from codebase.utils.eval_utils import get_protein_list
from codebase.utils.HEtransform_utils import *

In [None]:
# Echo the module search path (notebook display) to check that the repository root inserted above is first.
sys.path

In [None]:
# ----- Experiment configuration -----
protein_set = 'reduced_ext'
patch_size = 256
cv_split = 'split0'
imc_prep_seq = 'raw_clip99_arc_otsu3'
project_path = '/cluster/work/grlab/projects/projects2021-multivstain/'
data_path = project_path + 'data/tupro/'
standardize_imc = True
scale01_imc = True
batch_size = 8

# Training-split alignment results and the marker list of the chosen protein set
train_aligns = get_aligns(project_path, aligns_set='train', cv_split=cv_split, protein_set=protein_set)
protein_subset = get_protein_list(protein_set)

### 2. Data loader v2 
- modifies the current setup by adding transform functions, each applied with a configurable probability 
- flow: data loader with transforms --> batch to device
- IMC transforms for channel-wise standardisation and min-max scaling; shared transforms for flips and 90° rotations; H&E transforms for colour, HED and affine transforms  

#### 2.1 IMC transforms

In [None]:
def IMC_transforms(standardize_data, minmax_data, cohort_stats_file, channel_list):
    '''
    Build the channel-wise normalisation pipeline applied to IMC patches.

    standardize_data: if True, apply (x - mean_cohort) / std_cohort per channel
    minmax_data: if True, apply (x - min) / (max - min) per channel. Both min and max
        come from the '*_stand_cohort' columns when standardisation is also applied,
        and from the '*_cohort' columns otherwise
    cohort_stats_file: path to the tab-separated cohort statistics table of the split
        (its first column is used as the index)
    channel_list: channels (markers) whose statistics are selected from the table

    Returns a torch.nn.Sequential of torchvision Normalize modules, applied in the
    order standardise -> min-max. It is empty (identity) when both flags are False.
    '''
    imc_transforms = []
    if not (standardize_data or minmax_data):
        return torch.nn.Sequential(*imc_transforms)

    cohort_stats = pd.read_csv(cohort_stats_file, sep='\t', index_col=[0])

    def _channel_stats(col):
        # Return a plain float array: torchvision converts mean/std with torch.as_tensor,
        # which may mis-handle a pandas Series (e.g. an integer index not starting at 0).
        # NOTE(review): `[channel_list]` is label- or position-based depending on the
        # table's index dtype -- confirm it selects the intended rows.
        return cohort_stats[col][channel_list].to_numpy(dtype=float)

    if standardize_data:
        imc_transforms.append(
            T.Normalize(_channel_stats('mean_cohort').tolist(),
                        _channel_stats('std_cohort').tolist())
        )

    if minmax_data:
        # min and max must come from the same column family: both describe the
        # standardised data when standardisation precedes this step, else the raw data.
        suffix = 'stand_cohort' if standardize_data else 'cohort'
        min_vals = _channel_stats('min_' + suffix)
        max_vals = _channel_stats('max_' + suffix)
        imc_transforms.append(
            T.Normalize(min_vals.tolist(), (max_vals - min_vals).tolist())
        )

    return torch.nn.Sequential(*imc_transforms)



#### 2.2 Shared transforms 

In [None]:
# ----- SHARED TRANSFORMS -----  
def shared_transforms(img1, img2, p=0.5):
    '''
    Apply the same random geometric augmentation to an H&E patch and its IMC patch.

    Three independent draws, each succeeding with probability p:
      - horizontal flip (last axis)
      - vertical flip (second-to-last axis)
      - rotation by a random choice of 90, 180 or 270 degrees (counter-clockwise)
    The identical operation is applied to both inputs so they stay spatially aligned.

    img1: H&E ROI patch, torch tensor (..., H, W)
    img2: IMC ROI patch, torch tensor (..., H, W)
    p: probability of each individual augmentation

    Returns the (possibly) transformed pair (img1, img2).

    Rotations use torch.rot90, an exact index permutation with no resampling or fill.
    Patches are expected to be square, so the output shape equals the input shape.
    '''
    # Random horizontal flipping
    if random.random() < p:
        img1 = torch.flip(img1, dims=(-1,))
        img2 = torch.flip(img2, dims=(-1,))

    # Random vertical flipping
    if random.random() < p:
        img1 = torch.flip(img1, dims=(-2,))
        img2 = torch.flip(img2, dims=(-2,))

    # Random rotation by a multiple of 90 degrees (same RNG draw as before: one choice of angle)
    if random.random() < p:
        quarter_turns = random.choice([90, 180, 270]) // 90
        img1 = torch.rot90(img1, quarter_turns, dims=(-2, -1))
        img2 = torch.rot90(img2, quarter_turns, dims=(-2, -1))
    return img1, img2


#### 2.3 HE transforms 

In [None]:
def HE_transforms(img, p=(0.0, 0.5, 0.5)):
    '''
    Apply random H&E-only augmentations to a single H&E patch.

    Three independent draws:
      - colour jitter (torchvision ColorJitter) with probability p[0]
      - HED jitter (HEDJitter, from HEtransform_utils) with probability p[1]
      - random affine warp (RandomAffineCV2, from HEtransform_utils) with probability p[2]

    img: H&E patch as a channel-first torch tensor (3, H, W)
    p: sequence of at least three probabilities (jitter, HED, affine). The default is
       a tuple, so there is no shared mutable default argument.

    Returns the augmented patch. Its layout depends on which augmentations fired:
    - with no HED/affine step, the channel-first tensor is returned;
    - otherwise the image is passed channel-last to the HEtransform_utils helpers, and
      their output is returned unchanged (the dataset converts it back to channel-first).
    '''
    # Random colour jitter, applied to the channel-first tensor.
    # NOTE(review): torchvision colour ops expect float images in [0, 1] -- confirm the H&E scale.
    if random.random() < p[0]:
        jitter = T.ColorJitter(brightness=.15, hue=.05, saturation=0.15)
        img = jitter(img)

    # Random HED jitter; the helper receives a channel-last image.
    if random.random() < p[1]:
        img = torch.permute(img, (1, 2, 0))  # (3, H, W) -> (H, W, 3)
        hedjitter = HEDJitter(theta=0.01)  # from HEtransform_utils
        img = hedjitter(img)

    # Random affine warp. Bring the image to channel-last unless the HED step already did.
    if random.random() < p[2]:
        if not img.shape[2] == 3:
            img = torch.permute(img, (1, 2, 0))  # (3, H, W) -> (H, W, 3)
        randomaffine = RandomAffineCV2(alpha=0.02)  # from HEtransform_utils
        img = randomaffine(img)
    return img

In [None]:
# torch.manual_seed(0)
# random.seed(0)
# np.random.seed(0)

In [None]:
# ----- New version of dataloader -----
class CGANDataset_v2(Dataset):
    '''
    Paired H&E / IMC patch dataset with random crops and on-the-fly augmentation.

    Every __getitem__ call ignores its index. It:
      1. picks a random alignment result;
      2. crops a random patch_size x patch_size window from the IMC ROI, and the matching
         window at 4x resolution from the H&E ROI;
      3. applies the shared flip/rotation transforms to both patches;
      4. normalises the IMC patch (standardise / min-max are skipped when the storage
         path name indicates they were already applied);
      5. produces an additionally augmented copy of the H&E patch ('he_patch_T').

    p_flip_jitter_hed_affine: probabilities for (shared flip/rotation, colour jitter,
        HED jitter, affine warp).
    '''

    # ROI side length (in IMC pixels) assumed when sampling crop offsets and estimating length.
    ROI_SIZE = 1000

    def __init__(self, project_path, align_results: list, name: str, data_path: str, protein_subset=PROTEIN_LIST_MVS, patch_size=400, imc_prep_seq='raw', cv_split='split0', standardize_imc=True, scale01_imc=True, factor_len_dataloader=8.0, which_HE='new', p_flip_jitter_hed_affine=(0.5, 0.0, 0.5, 0.5)):
        super().__init__()

        self.project_path = project_path
        self.align_results = align_results
        self.name = name
        self.data_path = data_path
        self.patch_size = patch_size
        self.channel_list = [protein2index[prot_name] for prot_name in protein_subset]
        self.cv_split = cv_split

        self.HE_ROI_STORAGE = get_he_roi_storage(self.data_path, which_HE)
        self.IMC_ROI_STORAGE = get_imc_roi_storage(self.data_path, imc_prep_seq, standardize_imc, scale01_imc, cv_split)

        # Only normalise on the fly what is not already baked into the stored IMC data
        standardize_data = standardize_imc and ('std' not in self.IMC_ROI_STORAGE)
        minmax_data = scale01_imc and ('minmax' not in self.IMC_ROI_STORAGE)
        cohort_stats_file = os.path.join(project_path, COHORT_STATS_PATH, cv_split, 'imc_rois_' + imc_prep_seq + '-agg_stats.tsv')

        self.imc_transforms = IMC_transforms(standardize_data, minmax_data, cohort_stats_file, self.channel_list)
        self.shared_transforms = shared_transforms
        self.he_transforms = HE_transforms

        self.p_shared = p_flip_jitter_hed_affine[0]
        self.p_jitter = p_flip_jitter_hed_affine[1]
        self.p_hed = p_flip_jitter_hed_affine[2]
        self.p_affine = p_flip_jitter_hed_affine[3]

        assert len(self.align_results) > 0, "Dataset received empty list of alignment results !"
        print(self.name + " has ", len(self.align_results), " alignment results !")

        # An estimate of the number of training samples: non-overlapping half-patch grid per ROI
        self.num_samples = int(len(self.align_results) * ((self.ROI_SIZE // (self.patch_size // 2)) ** 2) * factor_len_dataloader)
        print(self.name + " has ", self.num_samples, " training samples !")

    def __len__(self):
        return self.num_samples

    @staticmethod
    def _to_chw_tensor(img):
        '''Return a 3-channel H&E image as a channel-first torch tensor (accepts numpy/torch, HWC/CHW).'''
        if isinstance(img, np.ndarray):
            # Contiguous copy: torch.from_numpy rejects negative strides (e.g. flipped views)
            img = torch.from_numpy(np.ascontiguousarray(img))
        if img.shape[0] != 3:
            img = img.permute(2, 0, 1)  # (H, W, 3) -> (3, H, W)
        return img

    def __getitem__(self, idx):  # CAUTION idx argument is ignored, dataset is purely random !
        # Load both ROIs lazily from disk
        ar = random.choice(self.align_results)
        roi_file = ar["sample"] + "_" + ar["ROI"] + ".npy"
        he_roi = np.load(os.path.join(self.HE_ROI_STORAGE, roi_file), mmap_mode='r')
        imc_roi = np.load(os.path.join(self.IMC_ROI_STORAGE, roi_file), mmap_mode='r')

        augment_x_offset = random.randint(0, self.ROI_SIZE - self.patch_size)
        augment_y_offset = random.randint(0, self.ROI_SIZE - self.patch_size)
        y0, y1 = augment_y_offset, augment_y_offset + self.patch_size
        x0, x1 = augment_x_offset, augment_x_offset + self.patch_size

        # H&E is indexed at 4x the IMC resolution
        he_patch = he_roi[4 * y0: 4 * y1, 4 * x0: 4 * x1, :]
        # Crop first, then select channels, so only the patch is read from the memmap
        imc_patch = imc_roi[y0:y1, x0:x1][:, :, self.channel_list]

        # HWC -> CHW float32. np.array always copies, detaching from the read-only memmap.
        he_patch = torch.from_numpy(np.array(he_patch.transpose((2, 0, 1)), dtype=np.float32, order='C'))
        imc_patch = torch.from_numpy(np.array(imc_patch.transpose((2, 0, 1)), dtype=np.float32, order='C'))

        he_patch, imc_patch = self.shared_transforms(he_patch, imc_patch, p=self.p_shared)
        imc_patch = self.imc_transforms(imc_patch)
        he_patch_T = self.he_transforms(he_patch, p=[self.p_jitter, self.p_hed, self.p_affine])

        # HE_transforms may return a channel-last (possibly numpy) image; normalise the layout
        he_patch = self._to_chw_tensor(he_patch)
        he_patch_T = self._to_chw_tensor(he_patch_T)

        return {'he_patch': he_patch.to(torch.float),
                'imc_patch': imc_patch.to(torch.float),
                'he_patch_T': he_patch_T.to(torch.float),
                'sample': ar['sample'], 'roi': ar['ROI'], 'x_offset': augment_x_offset, 'y_offset': augment_y_offset
               }

In [None]:
standardize_imc, scale01_imc = True, True
which_HE = 'new'

# Augmentation probabilities, comma-separated: flip/rotate, colour jitter, HED jitter, affine
prob_flip_jitter_hed_affine = '0.5,1,1,0.5'
p_flip_jitter_hed_affine = [float(token) for token in prob_flip_jitter_hed_affine.split(',')]
print(p_flip_jitter_hed_affine)

In [None]:

# Training dataset restricted to the first 30 alignment results
train_ds = CGANDataset_v2(
    project_path,
    align_results=train_aligns[:30],
    name="Train",
    data_path=data_path,
    protein_subset=protein_subset,
    patch_size=patch_size,
    imc_prep_seq=imc_prep_seq,
    cv_split=cv_split,
    standardize_imc=standardize_imc,
    scale01_imc=scale01_imc,
    which_HE=which_HE,
    factor_len_dataloader=1,
    p_flip_jitter_hed_affine=p_flip_jitter_hed_affine,
)
print('Loaded data')

# NOTE: shuffle only permutes indices, which CGANDataset_v2.__getitem__ ignores
trainloader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

print('HE_ROI_STORAGE: ', train_ds.HE_ROI_STORAGE)
print('IMC_ROI_STORAGE: ', train_ds.IMC_ROI_STORAGE)
print(train_ds.p_shared, train_ds.p_jitter, train_ds.p_hed, train_ds.p_affine)

In [None]:
# Draw a single batch from the loader for inspection; drop the iterator afterwards
# so its worker processes are released, as with a temporary iterator.
batch_iterator = iter(trainloader)
batch = next(batch_iterator)
del batch_iterator

In [None]:
# Shape, tensor type and (min, max) of the first sample for each patch kind
for key in ('he_patch', 'he_patch_T', 'imc_patch'):
    patches = batch[key]
    print(patches.shape, patches.type(), torch.aminmax(patches[0]))

In [None]:
# Visualise data:
#   top row    = H&E patches after the shared transforms
#   bottom row = the same patches after the additional H&E-only transforms
n_show = min(8, batch['he_patch'].shape[0])  # never index past the batch
fig, axes = plt.subplots(2, n_show, figsize=(20, 5), squeeze=False)
for col in range(n_show):
    axes[0, col].imshow(torch.permute(batch['he_patch'][col], (1, 2, 0)))
    axes[1, col].imshow(torch.permute(batch['he_patch_T'][col], (1, 2, 0)))
    # To inspect one IMC channel instead: axes[1, col].imshow(batch['imc_patch'][col][6])
for ax in axes.flat:
    ax.axis('off')
