In [None]:
from functions.load_all_data import load_data_by_color, load_imgs_masks
from functions.load_training_data import load_rescaled_samples_opt, load_rescaled_samples
from functions.crop_image import random_crop
from functions.composites import composite_masks
from functions.rescaling import rescale_img_comp
from functions.sizes import compute_avg_size
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split

# Load Training Samples

In [None]:
class SizeDataSet:
    """
    Data structure containing randomly cropped images, their
    corresponding mask composites, and the sizes of the nuclei in each image.

    After construction every sample is a float32 tensor:
      * images are transposed from numpy layout (H x W x C) to torch layout
        (C x H x W) but are NOT rescaled here
        (NOTE(review): presumably load_rescaled_samples already normalizes them — confirm);
      * mask composites are scaled by 1/255;
      * nuclei sizes are collected into one 1-D float32 tensor.

    Data is returned as a 2 dimensional array
        [x][:] x is the index to the tuple containing the nuclei size, mask composite, and image
        [:][x] x=0 --> nuclei size, x=1 --> mask composite, x=2 --> image
    """

    def __init__(self, n_samples=1000, bins=50, crop_size=128, loud=False):
        """
        Args:
            n_samples: number of cropped samples to load.
            bins, crop_size, loud: only consumed by the slower
                load_rescaled_samples_opt path (commented out below); the active
                load_rescaled_samples call ignores them.
        """
        # NOTE: rescaling using load_rescaled_samples_opt may take ~20 minutes
        # use load_rescaled_samples for faster results
        # cropped_sizes, cropped_comps, cropped_imgs = load_rescaled_samples_opt(
        #     n_samples, BINS=bins, CROP_SIZE=crop_size, loud=loud)

        cropped_sizes, cropped_comps, cropped_imgs = load_rescaled_samples(n_samples)
        self.cropped_sizes = cropped_sizes
        self.cropped_comps = cropped_comps
        self.cropped_imgs = cropped_imgs
        self.ToTensor()

    def __len__(self):
        return len(self.cropped_sizes)

    def __getitem__(self, idx):
        return self.cropped_sizes[idx], self.cropped_comps[idx], self.cropped_imgs[idx]

    def ToTensor(self):
        """
        Convert the loaded samples to float32 tensors (in place) for PyTorch.

        Images must swap axes because
            numpy image: H x W x C
            torch image: C x H x W
        Mask composites are divided by 255. Sizes become a single 1-D tensor.
        """
        self.cropped_imgs = [
            torch.tensor(img.transpose((2, 0, 1)).astype(float)).type("torch.FloatTensor")
            for img in self.cropped_imgs
        ]
        self.cropped_comps = [
            torch.tensor(comp.astype(float) / 255.).type("torch.FloatTensor")
            for comp in self.cropped_comps
        ]
        # Convert the sizes exactly once, outside any per-sample loop, so that an
        # empty dataset still ends up with a tensor. float32 (rather than int64)
        # keeps fractional sizes intact and matches the float output of the
        # regression model, which L1Loss requires.
        self.cropped_sizes = torch.tensor(np.asarray(self.cropped_sizes, dtype=np.float32))
    
    

In [None]:
# Build the full dataset (1000 samples is the class default); `loud` only affects the
# slower load_rescaled_samples_opt path.
training_data = SizeDataSet(n_samples=1000, loud=True)

In [None]:
# 80 / 10 / 10 train / validation / test split.
# random_split requires the lengths to sum exactly to len(training_data), so the
# test share takes the remainder instead of its own int(0.1 * n), which could
# leave a sample unassigned and raise. A seeded generator makes the split reproducible.
SPLIT_SEED = 42
BATCH_SIZE = 50

n_total = len(training_data)
n_train = int(0.8 * n_total)
n_val = int(0.1 * n_total)
n_test = n_total - n_train - n_val

split_data = random_split(training_data, [n_train, n_val, n_test],
                          generator=torch.Generator().manual_seed(SPLIT_SEED))
train_set = DataLoader(split_data[0], batch_size=BATCH_SIZE, shuffle=True)  # reshuffle every epoch
validation_set = DataLoader(split_data[1], batch_size=BATCH_SIZE)
test_set = DataLoader(split_data[2], batch_size=BATCH_SIZE)

In [None]:
# Sanity check: show the second training sample (tuple order is size, composite, image).
sample_size, _sample_comp, sample_img = split_data[0][1]
fig, ax = plt.subplots()
ax.imshow(np.moveaxis(sample_img.numpy(), 0, -1))  # C x H x W -> H x W x C for matplotlib
ax.set_title(f"Training sample 1 (nuclei size = {float(sample_size):.2f})")
ax.axis("off");

In [None]:
# Configure the GPU used for training.
# Check which GPUs are free with `watch nvidia-smi` in a terminal, then set the index here.
GPU_INDEX = 7
device = torch.device(f"cuda:{GPU_INDEX}")

# PyTorch Lightning Model
* This is the same model as the Keras implementation.
* It is wrapped in a PyTorch Lightning class to give the training code structure and improve readability.

In [None]:
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
criterion = nn.L1Loss(reduction='mean')  # mean absolute error (MAE) averaged over the batch

class LightningSizePredictor(pl.LightningModule):
    """
    Small CNN that regresses one nuclei-size value per 3-channel image.

    Batches are expected in the order produced by SizeDataSet.__getitem__:
    (nuclei size, mask composite, image). Only the image is fed to the network;
    the mask composite is ignored.
    """

    def __init__(self, crop_size=128):
        """
        Args:
            crop_size: height/width of the square input images. It is used to size
                the first fully connected layer. The default of 128 matches
                SizeDataSet's default crop size (NOTE(review): confirm against the
                images actually returned by load_rescaled_samples).
        """
        super(LightningSizePredictor, self).__init__()

        conv_stack = [nn.Conv2d(in_channels=3, out_channels=64,
                                kernel_size=(3, 3), stride=(2, 2)),
                      nn.ReLU(),
                      nn.Conv2d(in_channels=64, out_channels=32,
                                kernel_size=(3, 3), stride=(2, 2)),
                      nn.ReLU(),
                      nn.Conv2d(in_channels=32, out_channels=16,
                                kernel_size=(3, 3), stride=(2, 2)),
                      nn.ReLU(),
                      nn.Flatten()]
        # Infer the flattened feature count from a dummy pass instead of hard-coding it.
        # A hard-coded 392 cannot equal 16 * k * k for any spatial size k.
        with torch.no_grad():
            n_flat = nn.Sequential(*conv_stack)(torch.zeros(1, 3, crop_size, crop_size)).shape[1]

        # Same layer order as before, so the state_dict keys (layers.0 ... layers.9) are unchanged.
        self.layers = torch.nn.Sequential(*conv_stack,
                                          nn.Linear(n_flat, 256),
                                          nn.ReLU(),
                                          nn.Linear(256, 1))
        self.train_regress = nn.L1Loss()
        self.valid_regress = nn.L1Loss()

    def forward(self, x):
        return self.layers(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        # NOTE(review): Lightning recommends assigning state in `setup` rather than
        # `prepare_data` (which runs on a single process); this is fine for 1-GPU training.
        self.size_train, self.size_val = train_set, validation_set

    @staticmethod
    def _as_dataset(data):
        """Unwrap a DataLoader to its dataset so it is never wrapped in a second DataLoader."""
        return data.dataset if isinstance(data, DataLoader) else data

    @staticmethod
    def _split_batch(batch):
        """Return (images, sizes as a float (B, 1) tensor) from a (size, composite, image) batch."""
        sizes, _comps, imgs = batch
        # L1Loss needs a float target shaped like the (B, 1) prediction; a (B,)
        # target would silently broadcast against it to (B, B).
        return imgs, sizes.float().view(-1, 1)

    def training_step(self, train_batch, batch_idx):
        x, y = self._split_batch(train_batch)
        preds = self(x)
        loss = criterion(preds, y)
        train_reg = self.train_regress(preds, y)

        self.log('train_loss', loss)
        self.log('train_loss_step', train_reg, on_step=True, on_epoch=True)
        return loss

    def training_epoch_end(self, outs):
        # nn.L1Loss is a plain loss module, not a stateful metric (it has no
        # .compute()), so average the per-step losses Lightning passes back.
        # Each entry may be a {'loss': tensor} dict or a bare tensor, depending on the Lightning version.
        losses = [o['loss'] if isinstance(o, dict) else o for o in outs]
        if losses:
            self.log('train_reg_epoch', torch.stack(losses).mean())

    def train_dataloader(self):
        return DataLoader(self._as_dataset(self.size_train), batch_size=64, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self._as_dataset(self.size_val), batch_size=64)

    def validation_step(self, val_batch, batch_idx):
        x, y = self._split_batch(val_batch)
        preds = self(x)
        loss = criterion(preds, y)
        val_reg = self.valid_regress(preds, y)

        self.log('val_loss', loss)
        self.log('valid_reg', val_reg, on_step=True, on_epoch=True)

        return loss

    def validation_end(self, outputs):
        """
        Average the validation losses over an epoch.

        NOTE(review): this is a legacy hook name that recent Lightning versions do
        not call. The epoch-level 'val_loss' used for checkpointing comes from
        self.log in validation_step.
        """
        # validation_step returns bare tensors; accept legacy dict outputs too.
        losses = [o['val_loss'] if isinstance(o, dict) else o for o in outputs]
        avg_loss = torch.stack(losses).mean() if losses else torch.tensor(float('nan'))
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

In [None]:
from pytorch_lightning.callbacks import ModelCheckpoint

model = LightningSizePredictor()

# Keep the three best checkpoints by validation loss. Only metrics the model logs
# may appear in the filename template; `valid_acc` is never logged, so it is omitted.
checkpoint_callback = ModelCheckpoint(monitor='val_loss',
                                      dirpath='checkpoints/',
                                      filename='sizeMod-{epoch:03d}-{val_loss:.2f}',
                                      save_top_k=3,
                                      mode='min')

# Train on the single GPU reserved above via `device`. A list selects GPUs by index;
# the string form '7' means "GPU 7" in some Lightning versions and "7 GPUs" in others.
# Callbacks go through `callbacks=`; passing a ModelCheckpoint to `checkpoint_callback=`
# is deprecated in some Lightning versions and rejected in others.
trainer = pl.Trainer(gpus=[device.index],
                     max_epochs=10,
                     default_root_dir='checkpoints/',
                     callbacks=[checkpoint_callback])
trainer.fit(model)