In [None]:
!pip install lightning

In [None]:
# Overview of data distribution

import pandas as pd
import numpy as np
import os
import pydicom

# Configuration for the overview (edit DATA_ROOT when running outside Kaggle).
DATA_ROOT = '/kaggle/input/rsna-pneumonia-detection-challenge'
N_PIXEL_SAMPLES = 100   # number of images used for the pixel statistics
SEED = 42               # makes the image sample reproducible


def describe(values):
    """Return min / mean / max / std (numpy default, ddof=0) of ``values`` as a dict."""
    return {
        'min': np.min(values),
        'mean': np.mean(values),
        'max': np.max(values),
        'std': np.std(values),
    }


train_labels = pd.read_csv(os.path.join(DATA_ROOT, 'stage_2_train_labels.csv'))

#-------------------------------
# Counts are per label row; a patientId may appear on several rows.
print('Target distribution')
n_zeros = int((train_labels['Target'] == 0).sum())
n_ones = int((train_labels['Target'] == 1).sum())
n_else = len(train_labels) - n_zeros - n_ones
print(f"Ones: {n_ones}, Zeros: {n_zeros}, else: {n_else}")

#------------------------------
# Rows without a box carry NaN coordinates and are dropped here.
print('bounding box position distribution')
bboxes = train_labels[['x', 'y', 'width', 'height']].dropna()
x_info = describe(bboxes['x'])
y_info = describe(bboxes['y'])
w_info = describe(bboxes['width'])
h_info = describe(bboxes['height'])
print(f"bbox x distribution - {x_info}")
print(f"bbox y distribution - {y_info}")
print(f"bbox w distribution - {w_info}")
print(f"bbox h distribution - {h_info}")

#-----------------------------
print('pixel data distribution')
image_dir = os.path.join(DATA_ROOT, 'stage_2_train_images')
# patientId can repeat across rows, so sample unique ids (seeded) to avoid
# reading the same image twice and to make the statistics reproducible.
unique_ids = train_labels['patientId'].drop_duplicates()
sampled_ids = unique_ids.sample(n=min(N_PIXEL_SAMPLES, len(unique_ids)), random_state=SEED)
p_info = {'min': np.inf, 'mean': 0.0, 'max': -np.inf, 'std': 0.0}
for pid in sampled_ids:
    # dcmread replaces read_file, which was removed in pydicom 3.0
    img = pydicom.dcmread(os.path.join(image_dir, f'{pid}.dcm')).pixel_array
    p_info['min'] = min(p_info['min'], np.min(img))
    p_info['mean'] += np.mean(img)
    p_info['max'] = max(p_info['max'], np.max(img))
    p_info['std'] += np.std(img)

# Average of per-image means / stds. The averaged std is an approximation,
# not the pooled std over all pixels.
n_sampled = max(len(sampled_ids), 1)
p_info['mean'] /= n_sampled
p_info['std'] /= n_sampled

print(f"pixel value distribution - {p_info}")


In [None]:
import lightning as L

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F


class CustomTransform(nn.Module):
    """Jointly transform a DICOM pixel array and its bounding box.

    Pipeline: numpy array -> float32 tensor of shape (1, H, W)
    -> [random horizontal flip -> random crop] (only when ``random=True``)
    -> normalize with ``(pixel - 126) / 57``.

    Args:
        size: crop size, either an int (square crop) or an ``(h, w)`` pair.
            Only used when ``random=True``.
        random: if False, only tensor conversion and normalization are
            applied (used for validation).

    ``forward`` takes ``(img, bbox)``, where ``bbox`` is ``[x, y, width, height]``
    (NaN entries stay NaN). It returns ``(img, bbox)``, with ``bbox`` as a new
    float32 tensor; the caller's bbox is not modified.
    """

    def __init__(self, size, random=True):
        # nn.Module.__init__ must run first, otherwise calling the module
        # fails on its missing internal state.
        super().__init__()
        # RandomCrop.get_params unpacks (h, w), so normalize an int size.
        if isinstance(size, int):
            size = (size, size)
        elif size is not None:
            size = tuple(size)
        self.size = size
        self.p = 0.5
        self.mean = 126
        self.std = 57
        self.random = random

    def forward(self, data):
        img, bbox = data
        # torch.from_numpy has no dtype argument: cast in numpy first (this also
        # handles dtypes such as uint16 that older torch cannot wrap directly).
        img = torch.from_numpy(np.asarray(img, dtype=np.float32))
        # torchvision ops (normalize in particular) need (..., C, H, W); the
        # model stem expects a single channel.
        if img.ndim == 2:
            img = img.unsqueeze(0)
        bbox = torch.as_tensor(bbox, dtype=torch.float32).clone()
        if self.random:
            width = img.shape[-1]
            # random horizontal flip: mirror the box's left edge
            if torch.rand(1) < self.p:
                img = F.hflip(img)
                bbox[0] = width - bbox[0] - bbox[2]
            # random crop (padding is not implemented)
            i, j, h, w = transforms.RandomCrop.get_params(img, self.size)
            img = F.crop(img, i, j, h, w)
            # Shift the box into crop coordinates and clip both corners to
            # [0, w] x [0, h]. Boxes partly outside shrink correctly, and
            # boxes fully outside become zero-sized instead of negative.
            x1 = (bbox[0] - j).clamp(0, w)
            y1 = (bbox[1] - i).clamp(0, h)
            x2 = (bbox[0] + bbox[2] - j).clamp(0, w)
            y2 = (bbox[1] + bbox[3] - i).clamp(0, h)
            bbox = torch.stack([x1, y1, x2 - x1, y2 - y1])
        # normalize (per-channel lists for the single channel)
        img = F.normalize(img, [self.mean], [self.std])

        return img, bbox

    def __repr__(self):
        out = "Custom Transform to transform both the image and the bbox\n"
        out += "\ttorch.from_numpy()\n"
        if self.random:
            out += f"\tRandomHorizontalFlip(p={self.p})\n"
            out += f"\tRandomCrop({self.size}, padding=False)\n"
        out += f"\tNormalize(mean={self.mean},std={self.std})"
        return out
        
        
class CustomDataset(Dataset):
    """Map-style dataset over an RSNA label dataframe.

    Each item is ``(patientId, image, bbox, label)``:
      * ``bbox``: float32 tensor ``[x, y, width, height]`` (NaN when the row
        has no box), passed through ``transform`` together with the image.
      * ``label``: float32 tensor of shape ``(1,)`` holding ``Target``, so a
        batch is ``(B, 1)``, matching the single-logit classifier head.

    Args:
        root: directory containing ``<patientId>.dcm`` files.
        df: dataframe whose columns are, positionally, patientId, x, y,
            width, height, Target (the layout of stage_2_train_labels.csv).
        transform: callable applied to ``(pixel_array, bbox)`` that returns
            ``(image, bbox)``.
    """

    def __init__(self, root, df, transform):
        super().__init__()
        self.root = root
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # torch.Tensor() accepts no dtype argument, and torch.Tensor(n) would
        # allocate an uninitialised length-n tensor. Build explicit float32
        # tensors and access columns positionally via .iloc.
        bbox = torch.tensor(row.iloc[1:5].to_numpy(dtype=np.float32))
        label = torch.tensor([float(row.iloc[5])], dtype=torch.float32)

        pid = row.iloc[0]
        dcm_path = os.path.join(self.root, f'{pid}.dcm')
        # dcmread replaces read_file, which was removed in pydicom 3.0
        img = pydicom.dcmread(dcm_path).pixel_array

        img, bbox = self.transform((img, bbox))

        return pid, img, bbox, label
        
    

class PDCDataModule(L.LightningDataModule):
    """LightningDataModule for the RSNA pneumonia labels with a 2-fold split.

    Args:
        data_dir: directory containing ``stage_2_train_labels.csv`` and
            ``stage_2_train_images/``.
        batch_size: an int used for both loaders, or a ``(train, val)`` pair.
        num_workers: DataLoader worker count.
        size: random-crop size for the training transform (int or (h, w)).
            It can alternatively be passed to ``setup``.
        fold: which half (0 or 1) is used for training; the other validates.
        random_seed: seed for the patient-level sampling.
    """

    def __init__(self, data_dir='./', batch_size=1, num_workers=0,
                 size=None, fold=0, random_seed=42):
        super().__init__()
        self.data_dir = data_dir
        if isinstance(batch_size, int):
            self.tr_batch_size = batch_size
            self.val_batch_size = batch_size
        elif len(batch_size) == 2:
            self.tr_batch_size, self.val_batch_size = batch_size
        else:
            raise ValueError(batch_size)
        self.num_workers = num_workers
        self.size = size
        self.fold = fold
        self.random_seed = random_seed

    def prepare_data(self, fold=None, random_seed=None):
        """Build ``self.tr_df`` / ``self.val_df`` for ``fold``.

        Negatives (Target == 0) are undersampled: each half gets a disjoint
        25 % of the negative patients. Positives are split 50/50. Sampling is
        done on unique patientIds so that all rows (boxes) of one patient land
        in the same half, and no patient is in both training and validation.
        ``fold`` / ``random_seed`` default to the constructor values.

        Raises:
            ValueError: if ``fold`` is not 0 or 1.
        """
        fold = self.fold if fold is None else fold
        random_seed = self.random_seed if random_seed is None else random_seed
        if fold not in (0, 1):
            raise ValueError('fold should be either 0 or 1')
        full_df = pd.read_csv(os.path.join(
            self.data_dir, 'stage_2_train_labels.csv'))
        neg_ids = pd.Series(full_df.loc[full_df['Target'] == 0, 'patientId'].unique())
        pos_ids = pd.Series(full_df.loc[full_df['Target'] == 1, 'patientId'].unique())
        # apply undersampling to the negatives: two disjoint, equally sized draws
        neg_0 = neg_ids.sample(frac=0.25, random_state=random_seed)
        neg_1 = neg_ids.drop(neg_0.index).sample(n=len(neg_0), random_state=random_seed)
        pos_0 = pos_ids.sample(frac=0.5, random_state=random_seed)
        pos_1 = pos_ids.drop(pos_0.index)
        half_0 = full_df[full_df['patientId'].isin(pd.concat((neg_0, pos_0)))]
        half_1 = full_df[full_df['patientId'].isin(pd.concat((neg_1, pos_1)))]
        if fold == 0:
            self.tr_df, self.val_df = half_0, half_1
        else:
            self.tr_df, self.val_df = half_1, half_0

    def setup(self, stage=None, size=None):
        """Create the train / validation datasets.

        Lightning calls ``setup(stage=...)`` with a string stage. The crop size
        comes from ``size`` or from the constructor. For backward
        compatibility, a non-string first positional argument (the old
        ``setup(size)`` form) is taken as the size.

        Raises:
            ValueError: if no crop size is available.
        """
        if size is None and stage is not None and not isinstance(stage, str):
            size = stage
        if size is None:
            size = self.size
        if size is None:
            raise ValueError('a crop size must be given to __init__ or setup()')
        # prepare_data only runs on one process under Lightning; build the
        # (deterministic) split here too if this process has not seen it.
        if not hasattr(self, 'tr_df'):
            self.prepare_data()
        image_dir = os.path.join(self.data_dir, 'stage_2_train_images')
        self.tr_dset = CustomDataset(
            root=image_dir,
            df=self.tr_df,
            transform=CustomTransform(size))
        self.val_dset = CustomDataset(
            root=image_dir,
            df=self.val_df,
            transform=CustomTransform(size, random=False))

    def train_dataloader(self):
        return DataLoader(
                    dataset=self.tr_dset,
                    batch_size=self.tr_batch_size,
                    shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(
                    dataset=self.val_dset,
                    batch_size=self.val_batch_size,
                    shuffle=False, num_workers=self.num_workers)
    
    
# lightning module hooks guide.
# https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#lightning-hooks
class LitMobileNetv2(L.LightningModule):
    """LightningModule training CustomMobileNetV2 with iterative foveated saccades.

    Each training step runs ``n_saccade`` forward passes on the batch. The
    image is foveated with ``batch_remap`` around the current fixation
    (``translate``) and zoom (``scale``), then classified. The returned
    attention peak ``(y, x, r)`` sets the next fixation and zoom. The loss is
    averaged over all saccades.

    Args:
        nr, p0, pmax: foveation parameters forwarded to ``get_inv_FCG_mapping``
            and ``get_FCG_revertFunc`` (defined outside this cell). ``p0`` is
            also used as the fovea radius for the zoom update.
        n_saccade: number of fixations (forward passes) per training step.

    NOTE(review): no ``configure_optimizers`` is defined yet; ``Trainer.fit``
    requires one.
    """

    def __init__(self, nr, p0, pmax, n_saccade):
        super().__init__()
        self.save_hyperparameters()
        # training_step unpacks ``pred, attn_info = self.model(...)`` with
        # attn_info as (y, x, r). The stock torchvision MobileNetV2 returns only
        # logits. CustomMobileNetV2 returns (logits, locs) and already has a
        # 1-channel stem and a single-logit head. It is defined in a later
        # cell, which must be run before this class is instantiated.
        self.model = CustomMobileNetV2()
        self.fovea_radius = p0
        self.base_fov = get_inv_FCG_mapping(nr, p0, pmax, (0,0))
        self.un_fov_func = get_FCG_revertFunc(nr,p0,pmax)
        self.n_sacc = n_saccade

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # Override this to prevent batch transfer at this stage.
        # The training step needs several foveated forwards per batch, so only
        # the foveated images (and labels) are moved to the device there.
        return batch

    def training_step(self, batch, batch_idx):
        b_pid, b_img, b_bbox, b_label = batch
        # transfer label to device (the batch itself stays on the CPU)
        b_label = b_label.to(self.device)
        b, c, h, w = b_img.shape
        # initialize zoom/scale to 1
        scale = torch.ones((b, 2))
        # initialize saccade/translate to the center of the original image
        # NOTE(review): this is (w/2, h/2), i.e. (x, y) order, while later
        # fixations come from attn_loc in (y, x) order. Confirm batch_remap's
        # convention.
        translate = torch.ones((b, 2)) * torch.tensor([w // 2, h // 2])
        # losses are accumulated over all of the saccades
        total_loss = 0
        for _ in range(self.n_sacc):
            # get foveated batch and transfer only it to the device
            b_fov = batch_remap(b_img, self.base_fov, scale, translate, border_value=0)
            b_fov = b_fov.to(self.device)
            pred, attn_info = self.model(b_fov)
            # `F` in this notebook is torchvision.transforms.functional, which
            # has no loss functions; use torch.nn.functional explicitly.
            loss = nn.functional.binary_cross_entropy_with_logits(pred, b_label)
            total_loss += loss

            # Saccade bookkeeping happens where b_img / scale / translate live
            # (the CPU); the locations carry no gradient.
            attn_info = attn_info.detach().to(device=b_img.device, dtype=torch.float32)
            # undo foveation to the attn_info (which is y, x, r)
            # NOTE(review): attn_info is in attention-map coordinates, while
            # b_fov.shape[2:] is the foveated image size. Confirm this offset
            # is the intended centering.
            attn_loc_0 = attn_info[:, :2] - torch.tensor(b_fov.shape[2:])
            attn_loc = self.un_fov_func(attn_loc_0)
            coef = (torch.max(attn_loc, dim=1).values
                    / torch.max(attn_loc_0, dim=1).values)
            attn_r = coef * attn_info[:, 2]
            attn_loc = attn_loc * scale + translate
            # update saccade (scale & translate)
            scale = (attn_r / self.fovea_radius).view(-1, 1).repeat_interleave(2, 1)
            translate = attn_loc

        mean_loss = total_loss / self.n_sacc
        self.log('train_loss', mean_loss, batch_size=b)
        return mean_loss
        
            
        

In [None]:
import torch
import torch.nn as nn

class CustomMobileNetV2(nn.Module):
    """Single-channel, single-logit MobileNetV2 that also reports an attention peak.

    ``forward(x)`` returns ``(logits, locs)``:
      * ``logits``: shape (B, 1), from the 1280 -> 1 classifier.
      * ``locs``: float32 tensor of shape (B, 3) with ``(y, x, r)`` rows from
        ``get_attn_loc``, in the coordinates of the feature map taken after the
        first 13 inverted-residual blocks (16:1 downsampling, per the
        forward-pass comment).

    The pretrained torchvision weights are loaded via ``torch.hub``. The new
    1-channel stem and the classifier are freshly initialized.
    """

    def __init__(self):
        super(CustomMobileNetV2, self).__init__()
        base = torch.hub.load('pytorch/vision:v0.10.0','mobilenet_v2',pretrained=True)
        self.input_layer = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True))
        self.bottom_features = nn.Sequential(*base.features[1:14])
        self.top_features = nn.Sequential(*base.features[14:])
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=False),
            nn.Linear(in_features=1280, out_features=1, bias=True))

        # weight initialization of the new (non-pretrained) layers
        for m in self.input_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                # the conv has no bias (bias=False)
            elif isinstance(m, (nn.BatchNorm2d)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        for m in self.classifier.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    @staticmethod
    def get_attn_loc(attn, threshold=0.9, min_r=3):
        """Locate the attention peak of each map and estimate its radius.

        Args:
            attn: tensor of shape (B, H, W).
            threshold: minimum fraction of above-mean cells a square window
                around the peak must contain to keep growing.
            min_r: first half-width tested.

        Returns:
            float32 tensor of shape (B, 3) on ``attn.device`` with rows
            ``(y, x, r)``. ``(y, x)`` is the argmax. Half-widths are tested
            upward from ``min_r`` while the ``(2r+1) x (2r+1)`` window stays
            inside the map. ``r`` is the last half-width that met
            ``threshold``; growth stops at the first failure, and a failure at
            ``min_r`` yields ``min_r - 1``. ``r`` is 1 when no window can be
            tested (peak too close to the border).
        """
        b, h, w = attn.shape
        _, max_i = torch.max(attn.flatten(start_dim=1), dim=1)
        max_y = torch.div(max_i, w, rounding_mode='floor')
        max_x = max_i % w
        above_avg = attn > attn.mean(dim=(1, 2), keepdim=True)
        # created on attn's device / as float so the final stack is consistent
        radii = torch.ones(b, dtype=torch.float32, device=attn.device)
        for bid, (y, x) in enumerate(zip(max_y.tolist(), max_x.tolist())):
            # largest half-width that keeps the window inside the map
            max_r = min(y, h - y, x, w - x)
            for r in range(min_r, max_r):
                window = above_avg[bid, y - r:y + r + 1, x - r:x + r + 1]
                if window.float().mean() < threshold:
                    # first failing window: the previous size was the last good one
                    radii[bid] = r - 1
                    break
                radii[bid] = r
        return torch.stack([max_y.float(), max_x.float(), radii]).T

    def _forward_impl(self, x):
        x = self.input_layer(x)
        x = self.bottom_features(x)
        # Attention is read after the 13th inverted-residual block to keep
        # a 16:1 reduction in the attention map's resolution. As in the paper,
        # it is the mean across the feature channels (detached: no gradient).
        attn = torch.mean(x.detach(), dim=1)
        locs = CustomMobileNetV2.get_attn_loc(attn)
        # continue forward
        x = self.top_features(x)
        # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
        x = nn.functional.adaptive_avg_pool2d(x, 1).reshape(x.shape[0], -1)
        x = self.classifier(x)
        return x, locs

    def forward(self, x):
        return self._forward_impl(x)

    
# Instantiate once (loads the pretrained MobileNetV2 via torch.hub) and print the module tree.
print(CustomMobileNetV2())

In [None]:
class PrinterNet(torch.nn.Module):
    """Identity module that prints the shape of the tensor passing through it.

    Used as a debugging probe between the stages of an ``nn.Sequential``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        shape = x.shape
        print(shape)
        return x

# Debug instrumentation: insert a PrinterNet after every stage of the
# torchvision MobileNetV2 feature extractor except stage 0, so a forward pass
# prints the activation shape after each stage.
model = torch.hub.load('pytorch/vision:v0.10.0','mobilenet_v2',pretrained=True)
new_features = []
for stage_idx, stage in enumerate(model.features):
    new_features.append(stage)
    if stage_idx == 0:
        continue
    new_features.append(PrinterNet())

model.features = torch.nn.Sequential(*new_features)

In [None]:
print(model.features[0][0])

# Single-channel input: replace the first stem convolution with a Conv2d that
# takes 1 input channel (its weights start untrained).
stem_conv = torch.nn.Conv2d(
    1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
model.features[0][0] = stem_conv

# Single-logit head: replace the classifier's layer at index 1 with a 1280 -> 1 Linear.
head = torch.nn.Linear(in_features=1280, out_features=1, bias=True)
model.classifier[1] = head

for part in (model.features[0], model.classifier):
    print(part)

In [None]:
# Rebind `model` to a fresh CustomMobileNetV2 for the smoke test that follows
# (replaces the instrumented torchvision model from the previous cells).
model = CustomMobileNetV2()

In [None]:
import pydicom
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms

# Smoke test: push one test DICOM through CustomMobileNetV2 (untrained head)
# and check the output shapes.
# dcmread replaces read_file, which was removed in pydicom 3.0.
dcm = pydicom.dcmread('/kaggle/input/rsna-pneumonia-detection-challenge/stage_2_test_images/0000a175-0e68-4ca4-b1af-167204a7e0bc.dcm')
img = dcm.pixel_array

print(img.shape, type(img))
print(np.min(img), np.mean(img), np.max(img))

# Use the normalisation the training pipeline applies (CustomTransform: raw
# pixel values, (p - 126) / 57). This avoids ImageNet red-channel statistics
# on ToTensor's rescaled input, which the model never sees during training.
PIXEL_MEAN, PIXEL_STD = 126.0, 57.0
preprocess = transforms.Compose([
    lambda arr: torch.from_numpy(np.asarray(arr, dtype=np.float32)).unsqueeze(0),
    transforms.Normalize(mean=[PIXEL_MEAN], std=[PIXEL_STD]),
])

# 4x spatial subsampling keeps the smoke test cheap; add the batch axis
# -> (1, 1, h, w)
x = preprocess(img[::4, ::4]).unsqueeze(0)
print(x.shape)
print(torch.min(x), torch.mean(x), torch.max(x))

model.eval()
with torch.no_grad():
    out, locs = model(x)

print(out.shape, locs.shape)
# torch.nn.functional.sigmoid is deprecated in favour of torch.sigmoid
conf = torch.sigmoid(out[0])
print(conf, locs)
