In [2]:
import torch
import torch.nn as nn
import math
import torchvision.transforms as transforms
import numpy as np
from torch.optim.lr_scheduler import MultiStepLR
import torch.nn.functional as tF
import torchvision.transforms.functional as tvF

import random
import torch.utils.model_zoo as model_zoo
import torch.backends.cudnn as cudnn
from easydict import EasyDict
import logging
from torch.utils.data import Dataset, DataLoader
import glob
import os
from skimage import io, transform
import matplotlib.pyplot as plt
import PIL
import time
from collections import Counter
import gc
from unet import UNet

In [3]:
PIL.__version__

'5.0.0'

In [4]:
cudnn.benchmark = True

In [5]:
EXTENSIONS = ['.jpg', '.png']

In [7]:
OPTIONS = EasyDict()
OPTIONS.DEBUG = False # Less data to train and test
OPTIONS.CODE_NAME = 'UNET_3class_64_DO20' + ('_debug' if OPTIONS.DEBUG else '')
OPTIONS.NUM_CLASSES = 3

OPTIONS.LOG = EasyDict()
OPTIONS.LOG.LOG_FILE = '/home/kevin/nuclei_segmentation/log/log_{}.txt'.format(OPTIONS.CODE_NAME)

OPTIONS.TRAIN = EasyDict()
OPTIONS.TRAIN.BATCH_SIZE = 8
OPTIONS.TRAIN.SHUFFLE = True
OPTIONS.TRAIN.IMG_DIR = './train_aug_with_tumor2/imgs/train/'
OPTIONS.TRAIN.MASK_DIR = './train_aug_with_tumor2/masks/train/'
OPTIONS.TRAIN.PRINT_FREQ = 50
OPTIONS.TRAIN.LR = 1e-4
OPTIONS.TRAIN.LR_DECAY_GAMMA = 0.5
OPTIONS.TRAIN.LR_DECAY_MILESTONES = [3, 7, 11, 15, 19]
OPTIONS.TRAIN.MAX_EPOCH = 20
OPTIONS.TRAIN.EPITHELIUM_WEIGHT = None #set to None if you don't need to weight the crossentropy.

OPTIONS.VAL = EasyDict()
OPTIONS.VAL.SPLIT_RND_SEED = 1357
OPTIONS.VAL.RATIO = 0.20
OPTIONS.VAL.BATCH_SIZE = 16
OPTIONS.VAL.SHUFFLE = False
OPTIONS.VAL.IMG_DIR = OPTIONS.TRAIN.IMG_DIR
OPTIONS.VAL.MASK_DIR = OPTIONS.TRAIN.MASK_DIR


OPTIONS.DATA = EasyDict()
OPTIONS.DATA.INPUT_SIZE = 256

OPTIONS.CHECKPOINT = EasyDict()
OPTIONS.CHECKPOINT.DIR = '/home/kevin/nuclei_segmentation/checkpoints/checkpoints_{}'.format(OPTIONS.CODE_NAME)

In [8]:
logger = logging.getLogger(OPTIONS.CODE_NAME)
logger.setLevel(logging.DEBUG)
logger.propagate = False

log_file = logging.FileHandler(OPTIONS.LOG.LOG_FILE)
log_file.setLevel(logging.DEBUG)

fmt = '%(asctime)s %(levelname)-8s: %(message)s'
fmt = logging.Formatter(fmt)

log_file.setFormatter(fmt)
logger.addHandler(log_file)

logger.info('\n\n'+str(OPTIONS))

In [9]:
def log_info(msg):
    print(msg)
    logger.info(msg)

In [10]:
class MyDataset(Dataset):
    def __init__(self, img_dir, mask_dir, transform=None, debug=False):
        """
        Args:
            img_dir (string): Path to directory with images.
            mask_dir (string): Path to directory with masks.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        assert len(glob.glob(os.path.join(img_dir, '*'))) == len(glob.glob(os.path.join(mask_dir, '*')))
        self.debug = debug
        if self.debug:
            self.img_dir = img_dir
            self.img_names = glob.glob(os.path.join(self.img_dir, '*'))
            random.shuffle(self.img_names)
            self.img_names = self.img_names[:4096]
            self.mask_dir = mask_dir
            self.transform = transform
        else:
            self.img_dir = img_dir
            self.img_names = glob.glob(os.path.join(self.img_dir, '*'))
            self.mask_dir = mask_dir
            self.transform = transform
    
    def __len__(self):
        if self.debug:
            return len(self.img_names)
        else:
            return len(self.img_names)
    
    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        mask_name = os.path.join(self.mask_dir,self.img_names[idx].split('/')[-1].split('.')[0] + '_mask.png')
        image = io.imread(img_name)
        mask = io.imread(mask_name)
        sample = {'image': image, 'mask': mask}#
        if self.transform:
            sample = self.transform(sample)
        return sample


In [11]:
class RandomResizedCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size, scale, interpolation=PIL.Image.NEAREST):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
            self.scale = scale
            self.interpolation = interpolation
        else:
            assert len(output_size) == 2
            self.output_size = output_size
            self.scale = scale
            self.interpolation = interpolation
        

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']
        for attempt in range(10):
            area = image.shape[0] * image.shape[1]
            target_area= random.uniform(*self.scale) * area
            w = int(round(math.sqrt(target_area)))
            h = w
            if w <= image.shape[0] and h <= image.shape[1]:
                top = random.randint(0, image.shape[0] - h)
                left = random.randint(0, image.shape[1] - w)
                image = image[top:top+h,
                              left:left+w]
                mask = mask[top:top+h,
                            left:left+w]

                image = PIL.Image.fromarray(image).resize(self.output_size, self.interpolation)
                mask  = PIL.Image.fromarray(mask).resize(self.output_size, self.interpolation)
        
                return {'image': image, 'mask': mask}

class RandomRot90(object):
    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.
        Returns:
            PIL Image: Randomly rotated image.
        """

        t = random.choice([None, PIL.Image.ROTATE_90, PIL.Image.ROTATE_180, PIL.Image.ROTATE_270])
        
        if t is None:
            return sample
        else:
            return {'image': sample['image'].transpose(t),
                    'mask' : sample['mask'].transpose(t)}

        
    def __repr__(self):
        return self.__class__.__name__

class ColorJitter(transforms.ColorJitter):
    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return {'image': transform(sample['image']),
                'mask': sample['mask']}

class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.
        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return {'image': tvF.hflip(sample['image']),
                    'mask': tvF.hflip(sample['mask'])}
        return sample

class RandomVerticalFlip(transforms.RandomVerticalFlip):
    def __call__(self, sample):
        """
        Args:
            img (PIL Image): Image to be flipped.
        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return {'image': tvF.vflip(sample['image']),
                    'mask': tvF.vflip(sample['mask'])}
        return sample
    
class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    .. note::
        This transform acts in-place, i.e., it mutates the input tensor.
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        return {'image': transforms.functional.normalize(sample['image'], self.mean, self.std),
                'mask' : sample['mask']}

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

class NormalizeTest(object):
    """Normalize a tensor image with mean and standard deviation.
    Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
    will normalize each channel of the input ``torch.*Tensor`` i.e.
    ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
    .. note::
        This transform acts in-place, i.e., it mutates the input tensor.
    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        """
        Args:
            tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
        Returns:
            Tensor: Normalized Tensor image.
        """
        return {'image_ori': sample['image'],
                'image': transforms.functional.normalize(sample['image'], self.mean, self.std),
                'mask' : sample['mask'],
                }

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
    
class ToTensor(object):
    """Convert PIL images and masks in sample to Tensors."""

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = np.array(image)/255.
        mask = np.array(mask)/255.
        image = image.transpose((2, 0, 1))
        mask = mask.transpose((2, 0, 1))
        mask[mask<0.5] = 0.
        mask[mask>0.5] = 1.
        image = image.astype(np.float32)
        mask = mask.astype(np.float32)
        return {'image': torch.from_numpy(image),
                'mask': torch.from_numpy(mask)}
    
class ToTensorTest(object):
    """Convert PIL images and masks in sample to Tensors."""

    def __call__(self, sample):
        image, mask = sample['image'], sample['mask']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = np.array(image)/255.
        mask = np.array(mask)
        image = image.transpose((2, 0, 1))
#        mask = mask.transpose((2, 0, 1))
#        mask[mask<0.5] = 0.
#        mask[mask>0.5] = 1.
        image = image.astype(np.float32)
#        mask = mask.astype(np.float32)
        return {'image': torch.from_numpy(image),
                'mask': torch.from_numpy(mask)}

In [12]:
normalize = Normalize(mean=[0.485, 0.456, 0.406],
                      std=[0.229, 0.224, 0.225])
train_dataset = MyDataset(OPTIONS.TRAIN.IMG_DIR,
                          OPTIONS.TRAIN.MASK_DIR,
                          transform=transforms.Compose([
                                    RandomResizedCrop(OPTIONS.DATA.INPUT_SIZE, scale=(0.8, 1.0)),
                                    RandomRot90(),
                                    ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
                                    RandomHorizontalFlip(),
                                    RandomVerticalFlip(),
                                    ToTensor(),
                                    normalize]),
                          debug=OPTIONS.DEBUG)
    

In [13]:
val_dataset = MyDataset(OPTIONS.VAL.IMG_DIR,
                        OPTIONS.VAL.MASK_DIR,
                          transform=transforms.Compose([
                                    RandomResizedCrop(OPTIONS.DATA.INPUT_SIZE, scale=(0.8, 1.0)),
                                    RandomRot90(),
                                    ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
                                    RandomHorizontalFlip(),
                                    RandomVerticalFlip(),
                                    ToTensor(),
                                    normalize]),
                        debug=OPTIONS.DEBUG)

random.shuffle(val_dataset.img_names)
val_dataset.img_names = val_dataset.img_names[:int(len(val_dataset)*OPTIONS.VAL.RATIO)]


In [14]:
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=OPTIONS.TRAIN.BATCH_SIZE, shuffle=OPTIONS.TRAIN.SHUFFLE,
    num_workers=4, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=OPTIONS.VAL.BATCH_SIZE, shuffle=False,
    num_workers=4, pin_memory=True)

In [15]:
log_info('Training set:')
cnt = Counter()
log_info('TOTAL: {} images'.format(len(train_dataset)))

Training set:
TOTAL: 25000 images


In [16]:
log_info('Val set:')
log_info('TOTAL: {} images'.format(len(val_dataset)))

Val set:
TOTAL: 5000 images


## Average Meter

In [17]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [18]:
def get_accuracy(output, target):
    """Computes the accuracy"""
    with torch.no_grad():
        batch_size = output.size(0)
        img_area = output.size(2) * output.size(3)
        pred = tF.softmax(output).argmax(dim=1)
        correct = pred.eq(target).sum()
        return float(correct) / (float(batch_size) * float(img_area))

In [19]:
def prec_rec(cm):
    pr_mat = np.zeros((len(classes), 2))
    for i in range(len(classes)):
        pr_mat[i, 0] = cm[i, i] / cm[:, i].sum()
        pr_mat[i, 1] = cm[i, i] / cm[i, :].sum()
        
    pr_mat = pd.DataFrame(pr_mat)
    pr_mat.columns = ['Precision', 'Recall']
    pr_mat.index = classes
    return pr_mat

In [20]:
def train(train_loader, model, criterion, optimizer, epoch):
    # switch to train mode
    model.train()

    # update parameters
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    end = time.time()
    for i, sample in enumerate(train_loader):
        data_time.update(time.time() - end)
        image= sample['image']
        mask=sample['mask'].type(torch.LongTensor)
        n_samples = image.size(0)
        
        # need to create labels to match dimensions
        labels = torch.argmax(mask, dim=1)
        labels = labels.cuda()
        image = image.cuda()
        mask = mask.cuda()

        output = model(image)
        output = output.cuda()

        loss = criterion(output, labels)
        accuracy = get_accuracy(output, labels)

        losses.update(loss.item(), n_samples)
        accuracies.update(accuracy)
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        del labels, image, mask
        gc.collect()
        if i % OPTIONS.TRAIN.PRINT_FREQ == 0:
            msg = 'Epoch: [{}][{}/{}]\t'.format(epoch, i, len(train_loader))        
            msg += 'Time: {:.3f} ({:.3f})\tData: {:.3f}'.format(batch_time.val, batch_time.avg, data_time.val)
            msg += '\tAccuracy: {:.3f}\t'.format(accuracies.val)
            msg += '\tLoss: {:.06f} ({:0.6f})'.format(losses.val, losses.avg)
            log_info(msg)


In [21]:
def validate(val_loader, model, criterion):
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()
    confusion_mat = np.zeros((OPTIONS.NUM_CLASSES, OPTIONS.NUM_CLASSES), dtype=np.int)
    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        
        for i, sample in enumerate(val_loader):
            
            image= sample['image']
            mask=sample['mask'].type(torch.LongTensor)
            n_samples = image.size(0)
            
            mask = mask.cuda(non_blocking=True)
            
            labels = torch.argmax(mask, dim=1)
            labels = labels.cuda()
            image = image.cuda()
            mask = mask.cuda()
            
            # compute output
            output = model(image)
            output = output.cuda()

            # measure accuracy and record loss

            loss = criterion(output, labels)
            accuracy = get_accuracy(output, labels)

            losses.update(loss.item(), n_samples)
            accuracies.update(accuracy)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            gc.collect()
            del labels, image, mask
        log_info(' * Accuracy: {}'.format(accuracies.avg))
        log_info(' * Loss: {}'.format(losses.avg))
    return accuracies.avg, losses.avg

In [22]:
def save_checkpoint(state, filename):
    if not os.path.exists(OPTIONS.CHECKPOINT.DIR):
        log_info("Creating dir '{}'...".format(OPTIONS.CHECKPOINT.DIR))
        os.makedirs(OPTIONS.CHECKPOINT.DIR)
        log_info("Done.")
        
    checkpoint_path = os.path.join(OPTIONS.CHECKPOINT.DIR, filename)
    torch.save(state, checkpoint_path)
    log_info("Checkpoint was saved to '{}'".format(checkpoint_path))

In [23]:
model = UNet(3,3)
model = torch.nn.DataParallel(model).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = MultiStepLR(
    optimizer, 
    OPTIONS.TRAIN.LR_DECAY_MILESTONES,                    
    gamma=OPTIONS.TRAIN.LR_DECAY_GAMMA)


In [None]:
for epoch in range(1, OPTIONS.TRAIN.MAX_EPOCH+1):
    lr_scheduler.step(epoch)
    log_info('Epoch: {}\tlr: {}'.format(epoch, lr_scheduler.get_lr()[0]))
    train(train_loader, model, criterion, optimizer, epoch=epoch)
    validate(val_loader, model, criterion)
    save_checkpoint({
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer' : optimizer.state_dict(),
    }, 'epoch{}.pth'.format(epoch))


Epoch: 1	lr: 0.0001




Epoch: [1][0/3125]	Time: 5.888 (5.888)	Data: 0.186	Accuracy: 0.464		Loss: 1.052194 (1.052194)
Epoch: [1][50/3125]	Time: 0.337 (0.443)	Data: 0.031	Accuracy: 0.615		Loss: 0.981316 (0.894671)
Epoch: [1][100/3125]	Time: 0.340 (0.390)	Data: 0.033	Accuracy: 0.695		Loss: 0.834702 (0.868566)
Epoch: [1][150/3125]	Time: 0.342 (0.374)	Data: 0.033	Accuracy: 0.674		Loss: 0.816043 (0.847997)
Epoch: [1][200/3125]	Time: 0.340 (0.365)	Data: 0.033	Accuracy: 0.711		Loss: 0.807860 (0.828714)
Epoch: [1][250/3125]	Time: 0.342 (0.361)	Data: 0.033	Accuracy: 0.779		Loss: 0.715131 (0.811818)
Epoch: [1][300/3125]	Time: 0.343 (0.358)	Data: 0.033	Accuracy: 0.759		Loss: 0.678405 (0.799044)
Epoch: [1][350/3125]	Time: 0.342 (0.355)	Data: 0.029	Accuracy: 0.757		Loss: 0.685873 (0.784436)
Epoch: [1][400/3125]	Time: 0.344 (0.354)	Data: 0.033	Accuracy: 0.833		Loss: 0.571442 (0.772005)
Epoch: [1][450/3125]	Time: 0.347 (0.353)	Data: 0.032	Accuracy: 0.733		Loss: 0.685467 (0.760835)
Epoch: [1][500/3125]	Time: 0.342 (0.352)	Da