In [None]:
from glob import glob
import os 
import torch
from torch.utils.data import Dataset, DataLoader
import skimage
from skimage import io
from skimage import transform as stransform
from sklearn.model_selection import train_test_split
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import unet_parts_torch as u_parts

In [None]:
# Request GPU execution; the training cell turns this flag into a torch.device.
USE_CUDA = True
# Directory that is globbed for '*.png' input images.
IMAGE_FOLDER = './images'
# Directory holding the '<image id>_mask.png' ground-truth masks.
MASK_FOLDER = './masks'
# Side length (pixels) of the square that every image and mask is resized to.
IMAGE_SIZE = 128
# Samples per mini-batch for both the training and evaluation DataLoader.
BATCH_SIZE = 6
# NOTE(review): not referenced by any cell visible here — presumably meant for
# saved artifacts; confirm or remove.
OUTPUT_DIR = './output'
# Channel count of the first U-Net stage; deeper stages use multiples of it.
KERNEL_NUM = 12
# random_state handed to train_test_split.
SEED = 42
# Number of train-then-evaluate passes over the data.
EPOCHS = 20

In [None]:
class NucleiDataset(Dataset):
    """Image/mask pairs for nuclei semantic segmentation.

    Data source: https://data.broadinstitute.org/bbbc/BBBC038/

    Every id maps to ``<images_folder>/<id>.png`` and
    ``<masks_folder>/<id>_mask.png``; both files are read with
    ``as_gray=True``.

    Args:
        id_lst: sequence of sample ids (file names without extension).
        images_folder: directory containing the input images.
        masks_folder: directory containing the mask images.
        transformer: optional callable applied to the
            ``{'image': ..., 'mask': ...}`` dict before it is returned.
    """

    def __init__(self, id_lst, images_folder, masks_folder, transformer=None):
        self.id_lst, self.images_folder = id_lst, images_folder
        self.masks_folder, self.transformer = masks_folder, transformer

    def __len__(self):
        """Return the number of sample ids."""
        return len(self.id_lst)

    def __getitem__(self, idx):
        """Read the image and mask for ``id_lst[idx]`` and return them as a dict."""
        sample_id = self.id_lst[idx]
        locations = {
            'image': os.path.join(self.images_folder, sample_id + '.png'),
            'mask': os.path.join(self.masks_folder, sample_id + '_mask.png'),
        }
        sample = {key: io.imread(path, as_gray=True) for key, path in locations.items()}

        # Without a transformer the raw arrays are handed back untouched.
        if not self.transformer:
            return sample
        return self.transformer(sample)
        

In [None]:
# glob returns paths in a filesystem-dependent order; sorting them keeps the id
# list (and therefore any seeded split derived from it) identical across machines.
all_images = sorted(glob(os.path.join(IMAGE_FOLDER, '*.png')))
# Fail here with a clear message instead of later inside the split/training code.
if not all_images:
    raise FileNotFoundError('No .png images found in {!r}'.format(IMAGE_FOLDER))
# An id is the image file name without directory and extension.
ids = [os.path.splitext(os.path.basename(x))[0] for x in all_images]
# show some ids
ids[:4]

In [None]:


class TransformerEval(object):
    """Deterministic preprocessing for evaluation samples.

    Resizes the image and the mask to ``(IMAGE_SIZE, IMAGE_SIZE)``, casts them
    to float32, appends a trailing axis and moves it to the front (so a 2-D
    input ends up as ``(1, H, W)``), then wraps each array in a torch tensor.

    NOTE(review): the mask goes through the same interpolating, anti-aliased
    resize as the image, so its values need not stay strictly 0/1 — confirm
    this is intended.
    """

    def __init__(self):
        self.image_shape = (IMAGE_SIZE, IMAGE_SIZE)

    def __call__(self, sample):
        """Return ``{'image': tensor, 'mask': tensor}`` for the given sample dict."""
        tensors = {}
        for key in ('image', 'mask'):
            resized = stransform.resize(sample[key], self.image_shape,
                                        mode='constant', anti_aliasing=True)
            channel_first = resized[..., np.newaxis].astype(np.float32).transpose((2, 0, 1))
            # Copy so the tensor is backed by its own freshly allocated buffer.
            tensors[key] = torch.from_numpy(channel_first.copy())
        return tensors
    
class TransformerTrain(object):
    """Preprocessing with random augmentation for training samples.

    After resizing to ``(IMAGE_SIZE, IMAGE_SIZE)`` the image and mask receive
    the same randomly chosen transforms, each applied with probability 0.5:
    left/right flip, up/down flip, and a rotation by an integer angle drawn
    from [-180, 180). The result is cast to float32, given a leading axis
    (``(1, H, W)`` for 2-D input) and returned as torch tensors.

    Randomness comes from the global ``np.random`` state.
    """

    def __init__(self):
        self.image_shape = (IMAGE_SIZE, IMAGE_SIZE)

    def __call__(self, sample):
        """Return ``{'image': tensor, 'mask': tensor}`` for the given sample dict."""
        pair = [
            stransform.resize(sample[key], self.image_shape,
                              mode='constant', anti_aliasing=True)
            for key in ('image', 'mask')
        ]

        # Each decision is drawn once and applied to both arrays so that the
        # image and its mask stay geometrically aligned.
        if np.random.binomial(1, 0.5):
            pair = [np.fliplr(arr) for arr in pair]
        if np.random.binomial(1, 0.5):
            pair = [np.flipud(arr) for arr in pair]
        if np.random.binomial(1, 0.5):
            angle = np.random.randint(-180, 180)
            pair = [stransform.rotate(arr, angle, mode='constant') for arr in pair]

        # The explicit copy is kept from the original code, whose author noted
        # that torch.from_numpy failed without it (the flips return views).
        image, mask = (
            torch.from_numpy(
                arr[..., np.newaxis].astype(np.float32).transpose((2, 0, 1)).copy()
            )
            for arr in pair
        )
        return {'image': image, 'mask': mask}
    

        

In [None]:
# Partition the ids into a training and an evaluation subset, using
# train_test_split's default split fraction; SEED pins the shuffle.
train_ids, eval_ids = train_test_split(ids, random_state=SEED)
for caption, subset in (('Trainids length: ', train_ids), ('Eval length: ', eval_ids)):
    print(caption, len(subset))

In [None]:
# Training samples are augmented randomly; evaluation samples are only resized.
ds_train = NucleiDataset(
    id_lst=train_ids,
    images_folder=IMAGE_FOLDER,
    masks_folder=MASK_FOLDER,
    transformer=TransformerTrain(),
)
ds_eval = NucleiDataset(
    id_lst=eval_ids,
    images_folder=IMAGE_FOLDER,
    masks_folder=MASK_FOLDER,
    transformer=TransformerEval(),
)

In [None]:
def _seed_numpy_worker(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    The training transform draws its augmentations from ``np.random``.
    PyTorch gives every worker a distinct torch seed, but (depending on the
    PyTorch version) NumPy's state may simply be inherited from the parent, so
    workers could produce identical "random" augmentations. Deriving the NumPy
    seed from the per-worker torch seed avoids that.

    Args:
        worker_id: index of the worker (unused; required by the
            ``worker_init_fn`` signature).
    """
    np.random.seed(torch.initial_seed() % 2**32)


dataloader_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=2, worker_init_fn=_seed_numpy_worker)
# Evaluation gains nothing from shuffling; a fixed order keeps the score and
# the saved preview image comparable from epoch to epoch.
dataloader_eval = DataLoader(ds_eval, batch_size=BATCH_SIZE, shuffle=False, num_workers=1)

#for i_batch, sample_batched in enumerate(dataloader_train):
#    print(i_batch, sample_batched['image'].shape, sample_batched['mask'].shape)

In [None]:

class Unet(nn.Module):
    """U-Net style encoder/decoder assembled from ``unet_parts_torch`` blocks.

    The network takes one input channel and produces one output channel.
    With ``k = KERNEL_NUM`` the stages are declared with output widths
    k, 2k, 4k, 8k, 8k on the way down and 4k, 2k, k, k on the way up.
    ``forward`` returns the result of ``outc`` as-is.
    """

    def __init__(self):
        super().__init__()
        width = KERNEL_NUM
        in_ch, out_ch = 1, 1
        block = u_parts.double_conv

        # Contracting path.
        self.inc = u_parts.inconv(in_ch, width, block)
        self.down1 = u_parts.down(width, 2 * width, block)
        self.down2 = u_parts.down(2 * width, 4 * width, block)
        self.down3 = u_parts.down(4 * width, 8 * width, block)
        self.down4 = u_parts.down(8 * width, 8 * width, block)

        # Expanding path. Each `up` is declared with twice the width of the
        # tensor coming from below — presumably because it is concatenated with
        # the matching skip connection; confirm in unet_parts_torch.
        self.up1 = u_parts.up(16 * width, 4 * width, block)
        self.up2 = u_parts.up(8 * width, 2 * width, block)
        self.up3 = u_parts.up(4 * width, width, block)
        self.up4 = u_parts.up(2 * width, width, block)
        self.outc = u_parts.outconv(width, out_ch)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections, then ``outc``."""
        # Collect every encoder activation; the last one is the bottleneck.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # Walk back up, pairing each decoder stage with the most recent
        # not-yet-consumed encoder activation.
        out = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            out = up(out, skips.pop())
        return self.outc(out)

In [None]:
def train(model, device, loader, optimizer, epoch):
    """Train ``model`` for one pass over ``loader`` with binary cross-entropy.

    The loss is computed directly on the model's raw outputs (logits) with
    ``binary_cross_entropy_with_logits``. Compared with applying ``sigmoid``
    first and then ``binary_cross_entropy`` this is the same quantity, but it
    does not saturate: very large or very small logits still yield the correct
    loss and a non-zero gradient.

    Args:
        model: network returning one logit per mask element.
        device: torch device the batches are moved to.
        loader: iterable of ``{'image': tensor, 'mask': tensor}`` batches that
            also exposes ``__len__`` and ``.dataset`` (used for progress output).
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: epoch number, used for progress output only.

    Returns:
        Mean batch loss as a float, or NaN if ``loader`` yielded no batches.
    """
    model.train()
    loss_sum = 0.0
    n_batches = 0
    for batch_idx, sampled_batch in enumerate(loader):
        image = sampled_batch['image'].to(device)
        mask = sampled_batch['mask'].to(device)

        optimizer.zero_grad()
        logits = model(image)
        loss = F.binary_cross_entropy_with_logits(logits, mask)
        loss.backward()
        optimizer.step()

        # Accumulate as a detached tensor so no device sync is forced per batch.
        loss_sum = loss_sum + loss.detach()
        n_batches += 1

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(image), len(loader.dataset),
                    100. * batch_idx / len(loader), loss.item()))

    if n_batches == 0:
        return float('nan')
    return float(loss_sum) / n_batches
       
    
def dice_pyt(prediction, label):
    """Dice coefficient between a rounded prediction and a rounded label.

    Both tensors are binarised with ``torch.round`` and the score
    ``2 * |P ∩ L| / (|P| + |L|)`` is computed over all elements at once.
    If neither tensor contains any foreground the denominator is zero; the
    two (empty) masks agree perfectly, so 1.0 is returned instead of NaN.

    Args:
        prediction: tensor of predicted probabilities.
        label: tensor of ground-truth values, same shape as ``prediction``.

    Returns:
        0-dim tensor holding the Dice score.
    """
    # TODO should mean dice over batch
    prediction = torch.round(prediction)
    label = torch.round(label)
    overlap = torch.sum(prediction * label)
    total = torch.sum(prediction) + torch.sum(label)
    # torch.where keeps everything on-device and swaps the 0/0 case for 1.0.
    return torch.where(total > 0, 2 * overlap / total, torch.ones_like(total))

def evaluate(model, device, loader):
    """Compute the mean per-batch Dice score of ``model`` over ``loader``.

    Runs in eval mode without gradients. On every 10th batch the first sample's
    image, mask and predicted mask are written side by side to 'progress.png'.

    Args:
        model: network returning one logit per mask element.
        device: torch device the batches are moved to.
        loader: iterable of ``{'image': tensor, 'mask': tensor}`` batches.

    Returns:
        Mean Dice score as a float, or NaN if ``loader`` yielded no batches
        (previously this case raised ZeroDivisionError).
    """
    model.eval()
    dice_sum = 0.0
    cnt = 0
    with torch.no_grad():
        for sampled_batch in loader:
            cnt += 1
            image = sampled_batch['image'].to(device)
            mask = sampled_batch['mask'].to(device)
            mask_pred = torch.sigmoid(model(image))
            dice_sum += float(dice_pyt(mask_pred, mask))

            # visualization could also be done with tensorboardX
            # https://github.com/lanpa/tensorboardX
            if cnt % 10 == 0:
                # First sample, first channel of each tensor.
                panels = [t[0][0].detach().cpu().numpy()
                          for t in (image, mask, mask_pred)]
                merge = np.hstack(panels)
                # Convert to 8-bit explicitly rather than relying on the image
                # writer's implicit float handling; values are assumed to lie
                # in [0, 1] and are clipped to that range.
                merge_u8 = np.rint(np.clip(merge, 0.0, 1.0) * 255).astype(np.uint8)
                io.imsave('progress.png', merge_u8)

    if cnt == 0:
        print('Test Dice: no evaluation batches')
        return float('nan')

    dice = dice_sum / cnt
    print('Test Dice: ', dice)
    return dice
            
            
    

In [None]:
# Seed PyTorch and NumPy so weight initialisation and batch shuffling repeat
# from run to run.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Honour USE_CUDA only when a CUDA device is actually present; otherwise fall
# back to the CPU instead of failing in .to(device).
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
model = Unet().to(device)
optimizer = optim.Adam(model.parameters())

# One training pass followed by one evaluation pass per epoch (1-based).
for epoch in range(1, EPOCHS+1):
    train(model, device, dataloader_train, optimizer, epoch)
    evaluate(model, device, dataloader_eval)