In [None]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

## Local configuration: dataset directory and the grade sheet stored inside it
data_path = "/home/mas/Desktop/Project-ai-master/data/GlaS/"
grade_file = "Grade.csv"

# data_path already ends with a separator, so the two parts are simply glued together
csv_file = "".join((data_path, grade_file))

# Kept as the trailing expression of the cell so the loaded grade table is displayed
pd.read_csv(csv_file)

In [None]:
# Author: Alexander Hustinx
# Date: 8-06-2018
#
# GlaS Dataset
# Version: v0.1

from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

## My configuration/path
data_path = "/home/mas/Desktop/Project-ai-master/data/GlaS/"
grade_file = "Grade.csv"

class GlaSDataset(Dataset):
    """GlaS Dataset.

    Every row of the grade csv-file describes one sample. The columns are read
    by position: 0 = image name (without extension), 1 = patient id,
    2 = GlaS grade, 3 = (Sirinukunwattana et al. 2015) grade.
    """

    # Column headers exactly as they appear in the csv-file (note the leading space)
    GLAS_GRADE_COLUMN = ' grade (GlaS)'
    OTHER_GRADE_COLUMN = ' grade (Sirinukunwattana et al. 2015)'

    def __init__(self, csv_file=data_path+grade_file, root_dir=data_path, transform=None, transform_anno=None, desired_dataset=None):
        """
        Arguments:
            csv_file: path to the grade csv-file
            root_dir: path to the directory containing the images
            transform: (optional) transformation to be applied on sample['image']
            transform_anno: (optional) transformation to be applied on the sample['image_anno']
            desired_dataset: (optional) rows whose name does not contain this keyword are dropped,
                    this allows you to split the dataset into 'train' and 'test'
        """
        # File naming convention: <name>.bmp for the image, <name>_anno.bmp for its annotation
        self.image_ext = '.bmp'
        self.annotation_label = '_anno'

        print(csv_file)
        print(root_dir)

        #Load csv-file into pandas
        self.framework = pd.read_csv(csv_file)

        #Get rid of the whitespaces at the start and end of the grades
        for column in (self.GLAS_GRADE_COLUMN, self.OTHER_GRADE_COLUMN):
            self.framework[column] = self.framework[column].str.strip()

        #Remove all rows not containing the given desired_dataset, allowing to split 'test' and 'train'
        if desired_dataset:
            # na=False: rows with a missing name never match (same outcome as comparing with True)
            keep = self.framework['name'].str.contains(desired_dataset, na=False)
            self.framework = self.framework[keep]

        self.root_dir = root_dir
        self.transform = transform
        self.transform_anno = transform_anno

    def __len__(self):
        """Number of rows left after the optional desired_dataset filter."""
        return len(self.framework)

    @staticmethod
    def _as_hwc(array):
        """Return `array` with a trailing channel axis added when it is only HxW."""
        if array.ndim == 2:
            return np.expand_dims(array, axis=2)
        return array

    def __getitem__(self, index):
        """Sample format:
            image: image containing the to segment/grade cells
            image_anno: image containing the segmented cells
            patient_id: id of the patient the cell originated from
            GlaS: assigned GlaS grade (target #1)
            grade: assigned (Sirinukunwattana et al. 2015) grade (target #2)
        """
        # os.path.join works whether or not root_dir ends with a path separator
        image_name = os.path.join(self.root_dir, self.framework.iloc[index, 0])
        image = io.imread(image_name + self.image_ext)
        image_anno = io.imread(image_name + self.annotation_label + self.image_ext)

        #Currently unused, but future-proofing
        patient_id = self.framework.iloc[index, 1]

        GlaS = self.framework.iloc[index, 2]
        grade = self.framework.iloc[index, 3]

        sample = {'image':image, 'image_anno':image_anno, 'patient_id':patient_id, 'GlaS':GlaS, 'grade':grade}

        # Each transform only touches its own entry. The array is given an explicit
        # channel axis (HxWxC) before the PIL conversion.
        if self.transform:
            pil_image = transforms.functional.to_pil_image(self._as_hwc(sample['image']))
            sample['image'] = self.transform(pil_image)

        if self.transform_anno:
            pil_anno = transforms.functional.to_pil_image(self._as_hwc(sample['image_anno']))
            sample['image_anno'] = self.transform_anno(pil_anno)

        return sample


## Example for the proof-of-concept:
##         Draws the first 4 images and their segmentations
##        Including their GlaS grade and (Sirinukunwattana et al. 2015) grade
if __name__ == '__main__':

    #load dataset
    dataset = GlaSDataset(desired_dataset='test')

    #proof-of-concept: show at most the first 4 samples, fewer when the dataset is smaller
    n_shown = min(4, len(dataset))

    plt.figure()
    for i in range(n_shown):
        #load a sample
        sample = dataset[i]

        print("Index #{}:\n\tPatient id:\t\t{}\n\tImage size:\t\t{}\n\tAnnotated image size:\t{}\n\tGlaS grade:\t\t{}\n\tOther grade:\t\t{}"
            .format(i, sample['patient_id'], sample['image'].shape, sample['image_anno'].shape, sample['GlaS'], sample['grade']))

        ##plots: start
        #top row: the image itself
        ax = plt.subplot(2, 4, i + 1)
        ax.axis('off')
        ax.set_title('Sample #{}'.format(i))
        ax.imshow(sample['image'])

        #bottom row: the matching annotation, directly below its image
        ax = plt.subplot(2, 4, i + 5)
        ax.axis('off')
        ax.imshow(sample['image_anno'])
        ##plots: end

    #lay out and display once, after all subplots exist; this is also reached
    #when the dataset holds fewer than 4 samples
    plt.tight_layout()
    plt.show()

In [None]:
# Author: Alexander Hustinx
# Date: 12-06-2018
#
# GlaS DataLoader example and transformation example
# Version: v1.1

from __future__ import print_function, division
import os

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from skimage import io, transform

from GlaS_dataset import GlaSDataset

# Machine-specific absolute dataset directory (ends with a path separator) and the
# name of the grade csv-file; the two are concatenated to form the csv_file argument
# passed to GlaSDataset further down.
data_path = "/home/mas/Desktop/Project-ai-master/data/GlaS/"
grade_file = "Grade.csv"

def imshow(input, title=None):
    """Draw the images of one batch as a single grid with matplotlib.

    Arguments:
        input: batch dictionary; its 'image' entry is handed to torchvision.utils.make_grid
        title: (optional) title of the plot, defaults to 'batch from dataloader'
    """
    images_batch = input['image']

    print('images_batch.shape: ', images_batch.shape)

    grid = torchvision.utils.make_grid(images_batch)

    print('grid.shape: ', grid.shape)

    # Move the first axis of the grid to the end for plt.imshow; computed once and reused
    grid_hwc = grid.numpy().transpose((1, 2, 0))
    print('grid T . shape: ', grid_hwc.shape)

    plt.imshow(grid_hwc)
    # Honour the caller's title; previously the argument was accepted but ignored
    plt.title('batch from dataloader' if title is None else title)
        
## Example for the proof-of-concept:
##         Draws the first 4 images and their segmentations
##        Including their GlaS grade and (Sirinukunwattana et al. 2015) grade
if __name__ == '__main__':
    batch_size = 4

    # transforms.Resize replaces the deprecated transforms.Scale (same arguments)
    # NOTE(review): the same transform is also applied to the annotation image; Resize
    # interpolates by default, which may blend mask values - confirm whether the
    # annotation needs a nearest-neighbour resize instead.
    data_transform = transforms.Compose([transforms.Resize((256,256)),transforms.ToTensor()])

    print(data_transform)

    #load train dataset
    GlaS_train_dataset = GlaSDataset(csv_file=data_path+grade_file, root_dir=data_path,transform=data_transform,
                                transform_anno=data_transform,
                                desired_dataset='train')

    #load test dataset
    GlaS_test_dataset = GlaSDataset(csv_file=data_path+grade_file, root_dir=data_path,transform=data_transform,
                                transform_anno=data_transform,
                                desired_dataset='test')

    #training batches are reshuffled every epoch, test batches keep the csv order
    train_loader = DataLoader(GlaS_train_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)

    test_loader = DataLoader(GlaS_test_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1)

    #loop over the set
    for batch_i, sampled_batch in enumerate(train_loader):
        print("Index #{}:\n\tPatient id:\t\t{}\n\tImage size:\t\t{}\n\tAnnotated image size:\t{}\n\tGlaS grade:\t\t{}\n\tOther grade:\t\t{}"
            .format(batch_i, sampled_batch['patient_id'], sampled_batch['image'].shape, sampled_batch['image_anno'].shape, sampled_batch['GlaS'], sampled_batch['grade']))

        #Observe the 3rd batch (batch_i is 0-based), then stop iterating
        if batch_i == 2:
            plt.figure()
            imshow(sampled_batch)
            plt.axis('off')
            plt.ioff()
            plt.show()
            ##plots: end

            break
    
    

In [None]:
# Assumption: image augmentation has already been applied, so a normalized and cropped image and its mask are available.
# Assumption: a hyper-parameter file is available as well.

In [1]:
# U-net
# Project AI
# 13 June 2018, Masoumeh 

# Initial dummy version

import torch
print(torch.__version__)

import torch
import torchvision
import torchvision.transforms as transforms
from collections import defaultdict
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

import scipy.ndimage


# ---------------------------------- helper functions -------------------------------------------
def load_dataset():
    """Return the (train, validation) containers; loading real data is still a TODO.

    Both containers are ``defaultdict(list)`` objects and are returned empty.
    Intended layout: train["image"][i] = i-th image of the training set.
    """
    # TODO: fill both splits with the actual samples
    splits = tuple(defaultdict(list) for _ in range(2))
    return splits


def load_hyperparams(param_path):
    """Parse a hyper-parameter file made of ``key = value[,value...]`` lines.

    Arguments:
        param_path: path to the hyper-parameter text file

    Returns:
        Mapping from every key found in the file to a list of values:
        a list of floats for keys declared as "int", "float" or "listfloat",
        and a list of strings for every other key. Numeric entries are always
        floats, so callers cast them where an int is needed.

    Raises:
        KeyError: a line uses a key that is not one of the known keywords
        ValueError: a line has no '=' separator, or a numeric value is not a number
    """
    keywords = ['filterNumStart',"lr", "epochs", "lambda2", "batchSize", "doBatchNorm", "channels", "dropout","depth", "mask", "labels", "dataPath"]
    types = ["int","listfloat", "int", "float", "int", "int", "liststring", "float","int","string", "string", "string"]

    key_type = dict(zip(keywords, types))
    print(key_type)

    numeric_types = ("int", "listfloat", "float")
    hyperparams = defaultdict()

    # the with-block closes the file even when a line turns out to be malformed
    with open(param_path, "r") as param_file:
        for line_no, line in enumerate(param_file, 1):
            if not line.strip():
                # blank lines (e.g. a trailing newline) carry no setting
                continue

            # split on the first '=' only, so values may themselves contain '='
            key, separator, raw_value = line.partition('=')
            key = key.strip()
            if key not in key_type:
                raise KeyError("unknown hyper-parameter {!r} on line {} of {}".format(key, line_no, param_path))
            if not separator:
                raise ValueError("missing '=' on line {} of {}".format(line_no, param_path))

            # trim whitespace around every item but keep spaces inside it (e.g. in paths)
            values = [item.strip() for item in raw_value.split(',')]

            if key_type[key] in numeric_types:
                print([key, ','.join(values)])
                hyperparams[key] = [float(value) for value in values]
            else:
                hyperparams[key] = values

    return hyperparams

class encoder(nn.Module):
    """Placeholder encoder: defines no layers yet; `hyper_params` is accepted but not used."""

    def __init__(self, hyper_params):
        # Only the base-class set-up runs, so the module has no parameters or children
        nn.Module.__init__(self)
   
        
    
class decoder(nn.Module):
    """Placeholder decoder: defines no layers yet; `hyper_params` is accepted but not used."""

    def __init__(self, hyper_params):
        # Only the base-class set-up runs, so the module has no parameters or children
        nn.Module.__init__(self)

        
        
class _LayerGroup(nn.Module):
    """Named group of layers registered as sub-modules, readable as ``group['name']``."""

    def __init__(self, named_layers):
        super(_LayerGroup, self).__init__()
        for name, layer in named_layers:
            self.add_module(name, layer)

    def __getitem__(self, name):
        return self._modules[name]

    def __contains__(self, name):
        return name in self._modules


class UNet(nn.Module):
    """Encoder/decoder network with skip connections (initial dummy version).

    The encoder consists of `depth` blocks of two 3x3 convolutions, each followed
    by BatchNorm and ReLU; every block except the last ends with a 2x2 max-pool
    and doubles the number of filters. The decoder consists of `depth - 1` blocks:
    a stride-2 transposed convolution that halves the number of channels, a
    concatenation with the cropped encoder output of the same level, and two
    3x3 convolutions.

    All convolutions use padding=0, so the output is spatially smaller than the
    input. The network ends after the last decoder block: there is no final
    classification layer yet.
    """

    def __init__(self, channels, hyper_params):
        """
        Arguments:
            channels: number of channels of the input images
            hyper_params: mapping of hyper-parameters; only the first entries of
                    hyper_params['depth'] and hyper_params['filterNumStart'] are used
        """
        super(UNet, self).__init__()

        self.hyper_params = hyper_params

        # later will change these to loop over all options for every epoch
        self.network_depth = int(hyper_params['depth'][0])
        filter_num = int(hyper_params['filterNumStart'][0])
        in_channels = int(channels)

        # nn.ModuleList (rather than a plain list) registers every layer, so that
        # parameters(), modules(), state_dict() and .to()/.cuda() can see them.
        # Blocks are still accessed as self.down_blocks[d]['conv1'].
        self.down_blocks = nn.ModuleList()
        self.up_blocks = nn.ModuleList()

        print("----------- Building Encoder -------------")
        print(in_channels, filter_num)
        for d in range(self.network_depth):
            layers = [('conv1', self._conv_bn_relu(in_channels, filter_num))]
            in_channels = filter_num
            print(in_channels, filter_num)
            layers.append(('conv2', self._conv_bn_relu(in_channels, filter_num)))
            if d != self.network_depth - 1:
                # the deepest block is not pooled and keeps its filter count
                layers.append(('maxpool', nn.MaxPool2d(kernel_size=2, stride=2, padding=0)))
                filter_num = filter_num * 2
            self.down_blocks.append(_LayerGroup(layers))
            print(in_channels, filter_num)

        print("----------- Building Decoder -------------")
        in_channels = filter_num
        for d in range(self.network_depth - 1):
            filter_num = in_channels // 2
            print(in_channels, filter_num)
            upconv = nn.Sequential(nn.ConvTranspose2d(in_channels, filter_num, kernel_size=2, stride=2), nn.BatchNorm2d(filter_num), nn.ReLU())

            # conv1 sees the upsampled features (filter_num channels) concatenated
            # with the skip features (filter_num channels), i.e. in_channels in total
            print(in_channels, filter_num)
            conv1 = self._conv_bn_relu(in_channels, filter_num)

            in_channels = in_channels // 2
            print(in_channels, filter_num)
            conv2 = self._conv_bn_relu(in_channels, filter_num)

            self.up_blocks.append(_LayerGroup([('upconv', upconv), ('conv1', conv1), ('conv2', conv2)]))

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """3x3 convolution (stride 1, no padding) followed by BatchNorm and ReLU."""
        return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride=1, padding=0), nn.BatchNorm2d(out_channels), nn.ReLU())

    def weight_init(self):
        """Re-initialise the weights of every nn.Conv2d layer in place.

        Weights are drawn from a normal distribution with mean 0 and standard
        deviation sqrt(2 / n), where n = kernel_height * kernel_width * out_channels.
        Callable both as ``net.weight_init()`` and ``UNet.weight_init(net)``.
        """
        import math  # local import: math is not imported at the top of this file

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))

    def forward(self,x):
        """Run the encoder, then the decoder with cropped skip connections.

        Arguments:
            x: input batch, indexed as (batch, channels, height, width)

        Returns:
            Output of the last decoder block (of the deepest encoder block when depth is 1).
        """
        # torch.nn.functional.leaky_relu_(input, negative_slope=0.01)
        concat_features = []
        for d in range(self.network_depth):
            x = self.down_blocks[d]['conv1'](x)
            x = self.down_blocks[d]['conv2'](x)
            # remembered before pooling; reused by the decoder block of the same level
            concat_features.append(x)
            if d != self.network_depth - 1:
                x = self.down_blocks[d]['maxpool'](x)

        # now create the upsampling path
        for d in range(self.network_depth-1):
            x = self.up_blocks[d]['upconv'](x)

            # decoder block d pairs with encoder level depth-2-d (deepest skip first)
            x = self.crop_and_concat(x, concat_features[self.network_depth-2-d], crop=True)

            x = self.up_blocks[d]['conv1'](x)
            x = self.up_blocks[d]['conv2'](x)

        #print(x)
        return x

    # from https://discuss.pytorch.org/t/cropping-images-in-a-batch-on-the-gpu/7485/2
    def crop_and_concat(self, upsampled, bypass, crop=False):
        """Concatenate `bypass` to `upsampled` along the channel axis (dim 1).

        Arguments:
            upsampled: decoder feature map
            bypass: encoder feature map of the same level
            crop: when True, `bypass` is first centre-cropped to the height and
                    width of `upsampled`

        Height and width are handled independently and an odd size difference
        takes the extra row/column off the bottom/right, so non-square inputs work.
        """
        if crop:
            print(bypass.size())
            print(upsampled.size())
            diff_h = bypass.size()[2] - upsampled.size()[2]
            diff_w = bypass.size()[3] - upsampled.size()[3]
            top, left = diff_h // 2, diff_w // 2
            bottom, right = diff_h - top, diff_w - left
            # negative padding crops; F.pad expects (left, right, top, bottom)
            bypass = F.pad(bypass, (-left, -right, -top, -bottom))
        return torch.cat((upsampled, bypass), 1)
        
            
            
if __name__ == '__main__':
    # scipy.ndimage.imread is no longer available in current SciPy releases;
    # skimage (already used for image loading elsewhere in this file) reads the image instead
    from skimage import io

    augmented_mask = "/home/mas/Desktop/Project-ai-master/data/GlaS/testA_1_anno.bmp"
    augmented_image = "/home/mas/Desktop/Project-ai-master/data/GlaS/testA_1.bmp"

    hyper_params = load_hyperparams("/home/mas/Desktop/Project-ai-master/data/hyper_params")

    print(len(hyper_params['filterNumStart']))

    # the image is only read to find out how many input channels the network needs
    height, width, channels = io.imread(augmented_image).shape
    print(height, width, channels)

    net = UNet(channels,hyper_params)

    # random dummy batch for a smoke test of the forward pass; it uses the same channel
    # count the network was built with, and needs no Variable wrapper
    x = torch.rand(1, channels, 572, 572)
    out = net(x)
    #main()

0.4.0
{'filterNumStart': 'int', 'lr': 'listfloat', 'epochs': 'int', 'lambda2': 'float', 'batchSize': 'int', 'doBatchNorm': 'int', 'channels': 'liststring', 'dropout': 'float', 'depth': 'int', 'mask': 'string', 'labels': 'string', 'dataPath': 'string'}
['filterNumStart', '64']
['lr', '0.0001,0.0003']
['epochs', '20']
['lambda2', '0.0001']
['batchSize', '32']
['doBatchNorm', '1']
['dropout', '0.3']
['depth', '5']
1
522 775 3
----------- Building Encoder -------------
3 64
64 64
64 128
128 128
128 256
256 256
256 512
512 512
512 1024
1024 1024
1024 1024
----------- Building Decoder -------------
1024 512
1024 512
512 512
512 256
512 256
256 256
256 128
256 128
128 128
128 64
128 64
64 64
torch.Size([1, 512, 64, 64])
torch.Size([1, 512, 56, 56])
torch.Size([1, 256, 136, 136])
torch.Size([1, 256, 104, 104])
torch.Size([1, 128, 280, 280])
torch.Size([1, 128, 200, 200])
torch.Size([1, 64, 568, 568])
torch.Size([1, 64, 392, 392])
