In [1]:
import os
import random
import cv2
import napari
import numpy as np
import imutils
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

from pyMSDtorch.core import helpers, train_scripts, corcoef
from pyMSDtorch.core.networks import MSDNet, TUNet

import torch.nn as nn
from skimage import exposure
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader, Dataset

from tifffile import imread, imwrite

### Utility functions

In [None]:
# Display how many CUDA devices PyTorch can see (shown as the cell output).
torch.cuda.device_count()

In [3]:
def shuffle_training(imgs, masks, seed=123):
    """Shuffle images and masks along axis 0 with one shared permutation.

    A private ``random.Random(seed)`` instance is used, so the process-wide
    ``random`` generator is left untouched. The permutation is the same one
    that ``random.seed(seed); random.shuffle(...)`` would produce.

    Parameters
    ----------
    imgs, masks : array-like exposing ``.shape`` and integer-array indexing
        Must have the same length along axis 0.
    seed : int, optional
        Seed of the private generator (default 123).

    Returns
    -------
    tuple
        ``(shuffled_imgs, shuffled_masks)``; both are new arrays.

    Raises
    ------
    ValueError
        If ``imgs`` and ``masks`` differ in length along axis 0.
    """
    num_samples = imgs.shape[0]
    if masks.shape[0] != num_samples:
        raise ValueError(
            'imgs and masks must have the same number of samples, got '
            + str(num_samples) + ' and ' + str(masks.shape[0]))

    order = list(range(num_samples))
    random.Random(seed).shuffle(order)
    # Explicit integer dtype keeps the index valid for an empty input as well.
    order = np.asarray(order, dtype=np.intp)

    return imgs[order], masks[order]

### Load Data

In [4]:
def crop_data(image, mask, crop_size=1700, threshold=60):
    """Crop a square window around the first bright region of ``image``.

    The image is blurred and thresholded, external contours are extracted,
    and the centroid of the first contour with non-zero area becomes the
    window centre. The window is shifted as needed to stay inside the image,
    and the same window is cut out of ``mask``.

    Parameters
    ----------
    image : array-like
        Input image (presumably 2-D grayscale, as required by
        ``cv2.findContours`` -- confirm against the data).
    mask : array-like
        Label image indexed with the same row/column window as ``image``.
    crop_size : int, optional
        Side length of the square window in pixels (default 1700).
    threshold : int, optional
        Intensity threshold applied after blurring (default 60).

    Returns
    -------
    tuple of numpy.ndarray
        ``(cropped_image, cropped_mask)``. If the image is smaller than
        ``crop_size`` along an axis, the full extent of that axis is returned.

    Notes
    -----
    When no usable contour exists (nothing above the threshold, or only
    zero-area contours) the window is centred on the image instead of
    returning ``None`` or dividing by zero.
    """
    image, mask = np.array(image), np.array(mask)
    blurred = cv2.GaussianBlur(image, (5, 5), 0)
    thresh = cv2.threshold(blurred, threshold, 255, cv2.THRESH_BINARY)[1]
    thresh = thresh.astype(np.uint8)
    cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)

    height, width = image.shape[:2]

    # Fallback centre, used only when no contour yields a centroid.
    cX, cY = width // 2, height // 2
    for c in cnts:
        M = cv2.moments(c)
        if M["m00"] != 0:   # zero-area contours have no defined centroid
            cX = int(M["m10"] / M["m00"])
            cY = int(M["m01"] / M["m00"])
            break

    # Clamp the top-left corner so the window lies inside the actual image
    # bounds rather than an assumed 2000x2000 frame.
    x = min(max(0, cX - crop_size // 2), max(0, width - crop_size))
    y = min(max(0, cY - crop_size // 2), max(0, height - crop_size))

    return (image[y:y + crop_size, x:x + crop_size],
            mask[y:y + crop_size, x:x + crop_size])

In [5]:
# Location of the raw TIFF stacks
datadir = 'training_processed_images/combined_nuclie/raw'
train_imgs, train_masks, test_imgs, test_masks = [], [], [], []

# Training stack: crop every image/mask pair to the region of interest
imgs, masks = imread(datadir + '/train.tif'), imread(datadir + '/masks.tif')
for i in range(len(imgs)):
    cropped_img, cropped_mask = crop_data(imgs[i], masks[i])
    train_imgs.append(cropped_img)
    train_masks.append(cropped_mask)

# Testing stack: same treatment
imgs, masks = imread(datadir + '/test.tif'), imread(datadir + '/test_masks.tif')
for i in range(len(imgs)):
    cropped_img, cropped_mask = crop_data(imgs[i], masks[i])
    test_imgs.append(cropped_img)
    test_masks.append(cropped_mask)

# Stack the crops and insert a singleton axis at position 1
train_imgs = np.array(train_imgs)[:, np.newaxis]
train_masks = np.array(train_masks)[:, np.newaxis]
test_imgs = np.array(test_imgs)[:, np.newaxis]
test_masks = np.array(test_masks)[:, np.newaxis]


In [6]:
# Shuffle images and masks together using the default seed of shuffle_training.
# The names are rebound, so re-running this cell shuffles the already shuffled arrays.
train_imgs, train_masks = shuffle_training(train_imgs, train_masks)

In [None]:
# create validation set
# NOTE: this cell rebinds train_imgs/train_masks, so running it twice splits
# the already reduced training set again.
num_val = int(0.1*train_imgs.shape[0])

print('Number of images for validation: '+ str(num_val))
if num_val == 0:
    print('Warning: fewer than 10 training images, validation set is empty')

# Split at an explicit non-negative index. Negative slicing breaks for
# num_val == 0: [-0:] selects everything and [:-0] selects nothing, which
# would hand the whole dataset to validation and leave training empty.
split_idx = train_imgs.shape[0] - num_val

val_imgs = train_imgs[split_idx:]
val_masks = train_masks[split_idx:]
train_imgs = train_imgs[:split_idx]   # actual training
train_masks = train_masks[:split_idx]   # actual training

### Verify dimensionality


In [None]:
print('Size of training data:   ', train_imgs.shape)
print('Size of validation data: ', val_imgs.shape)
print('Size of testing data:    ', test_imgs.shape)

# Collect the labels from every training mask, not just the first five: a
# class missing from those five would make out_channels (len(num_labels),
# set in the network cells) too small. Reducing each mask first keeps the
# temporary arrays small.
# Despite its name, num_labels holds the sorted unique label VALUES.
per_mask_labels = [np.unique(single_mask) for single_mask in train_masks]
if per_mask_labels:
    num_labels = np.unique(np.concatenate(per_mask_labels))
else:
    num_labels = np.unique(train_masks)
print('The unique mask labels: ', num_labels)

In [9]:
#viewer = napari.view_image(train_imgs[::3,:], colormap='gray', name='Input')
#viewer.add_labels(train_masks[::3,:], name='Target')

### Prep data for PyTorch ingestion with DataLoaders

PyTorch DataLoaders are a common way to pass batches of data to the GPU.

- increase batch_size_train to load GPU with more data
- with such large images (2000x2000), U-Nets can handle bigger batch sizes

In [9]:
def make_loaders(train_data, val_data, test_data,
                 batch_size_train, batch_size_val, batch_size_test,
                 num_workers=None):
    """Build the train/validation/test ``DataLoader`` objects.

    Only the training loader shuffles; validation and test loaders keep the
    dataset order. No loader drops the last incomplete batch.

    Parameters
    ----------
    train_data, val_data, test_data : torch.utils.data.Dataset
        Datasets to wrap.
    batch_size_train, batch_size_val, batch_size_test : int
        Batch size of the respective loader; adjust to the available memory.
    num_workers : int or None, optional
        Worker processes per loader. ``None`` (default) keeps the previous
        behaviour of reading the module-level ``num_workers`` variable, and
        falls back to 0 when that variable is not defined.

    Returns
    -------
    tuple of DataLoader
        ``(train_loader, val_loader, test_loader)``.
    """
    if num_workers is None:
        # Backward compatibility: earlier versions read this notebook global.
        num_workers = globals().get('num_workers', 0)

    # Settings shared by all three loaders. Pinned memory only helps
    # host-to-GPU copies, so it is requested only when CUDA is available.
    shared_params = {'num_workers': num_workers,
                     'pin_memory': torch.cuda.is_available(),
                     'drop_last': False}

    train_loader = DataLoader(train_data, batch_size=batch_size_train,
                              shuffle=True, **shared_params)
    val_loader = DataLoader(val_data, batch_size=batch_size_val,
                            shuffle=False, **shared_params)
    test_loader = DataLoader(test_data, batch_size=batch_size_test,
                             shuffle=False, **shared_params)

    return train_loader, val_loader, test_loader

### Data augmentation (if required)

In [None]:
# Start from the un-augmented training set
labeled_imgs = torch.Tensor(train_imgs)
labeled_masks = torch.Tensor(train_masks)

# Data augmentation: one 90-degree rotation over dims 2 and 3, applied
# identically to images and masks.
rotated_imgs1 = labeled_imgs.rot90(1, [2, 3])
rotated_masks1 = labeled_masks.rot90(1, [2, 3])

# Disabled variants, kept for reference. Enable one by computing it for both
# images and masks and concatenating it along dim 0 like the rotation below:
#   rotated_*2 = torch.rot90(x, 2, [2, 3])
#   rotated_*3 = torch.rot90(x, 3, [2, 3])
#   flipped_*1 = torch.flip(x, [2])
#   flipped_*2 = torch.flip(x, [3])
#   flipped_*3 = torch.flip(x, [2, 3])

# Append the rotated copies after the originals
labeled_imgs = torch.cat([labeled_imgs, rotated_imgs1], dim=0)
labeled_masks = torch.cat([labeled_masks, rotated_masks1], dim=0)

print('Shape of augmented data:    ', labeled_imgs.shape, labeled_masks.shape)


In [11]:
# Get data in pytorch Dataset format

# The augmentation cell above leaves its result in labeled_imgs/labeled_masks.
# This cell used to overwrite both unconditionally, so augmented samples could
# never reach training. Set USE_AUGMENTED_DATA = True (after running the
# augmentation cell) to train on them; the default keeps the previous
# behaviour of training on the un-augmented data.
USE_AUGMENTED_DATA = False

if not USE_AUGMENTED_DATA:
    labeled_imgs = torch.Tensor(train_imgs)
    labeled_masks = torch.Tensor(train_masks)

val_imgs = torch.Tensor(val_imgs)
val_masks = torch.Tensor(val_masks)

# The tensors above are already float tensors; no second torch.Tensor() wrap needed.
train_data = TensorDataset(labeled_imgs, labeled_masks)
val_data = TensorDataset(val_imgs, val_masks)
test_data = TensorDataset(torch.Tensor(test_imgs))   # images only, no masks

# create data loaders
num_workers = 0   # 1 or 2 work better with CPU, 0 best for GPU

batch_size_train = 1
batch_size_val = 1
batch_size_test = 1

train_loader, val_loader, test_loader = make_loaders(train_data,
                                                     val_data,
                                                     test_data,
                                                     batch_size_train,
                                                     batch_size_val,
                                                     batch_size_test)

### Initialize MSDNet and TUNets

In [14]:

# Network hyper-parameters.
# NOTE: the "# MSDNet" cell below reassigns every one of these names before a
# network is built, so on a top-to-bottom run these values are not used.
in_channels = 1                  # number of input image channels
out_channels = len(num_labels)   # one output channel per unique mask label
num_layers = 50
layer_width = 1
max_dilation = 20
activation = nn.ReLU()
normalization = nn.BatchNorm2d
final_layer = None               # not passed to any constructor in this notebook

In [11]:
# MSDNet

# Hyper-parameters (in_channels / out_channels are reused by the TUNet cell)
in_channels = 1
out_channels = len(num_labels)
num_layers = 40
layer_width = 1
max_dilation = 15
activation = nn.ReLU()
normalization = nn.BatchNorm2d
final_layer = None   # not passed to the constructor below

# Gather the constructor arguments, then build the network in one call
msdnet_kwargs = dict(in_channels=in_channels,
                     out_channels=out_channels,
                     num_layers=num_layers,
                     layer_width=layer_width,
                     max_dilation=max_dilation,
                     activation=activation,
                     normalization=normalization,
                     convolution=nn.Conv2d)
msdnet = MSDNet.MixedScaleDenseNetwork(**msdnet_kwargs)

print('Number of parameters: ', helpers.count_parameters(msdnet))

Number of parameters:  5333


In [12]:
# Rebuild msdnet with an explicit dilation list instead of max_dilation.
# This replaces the network created in the previous cell.
custom_dilations = np.array([1, 2, 4, 8, 16,32])
msdnet = MSDNet.MixedScaleDenseNetwork(
    in_channels=1,
    out_channels=len(num_labels),
    num_layers=45,
    layer_width=None,
    max_dilation=None,
    custom_MSDNet=custom_dilations,
)
print('Number of parameters: ', helpers.count_parameters(msdnet))

Number of parameters:  9548


In [None]:
# TUNet 3

# Architecture settings (also written to params.npy further down)
depth = 4
base_channels = 64
growth_rate = 2
hidden_rate = 1

# image_shape is taken from axes 2 and 3 of the training array;
# in_channels / out_channels come from the MSDNet cell.
# normalization is left at the library default (a normalization=None override was disabled).
tunet3_kwargs = dict(image_shape=(train_imgs.shape[2:4]),
                     in_channels=in_channels,
                     out_channels=out_channels,
                     depth=depth,
                     base_channels=base_channels,
                     growth_rate=growth_rate,
                     hidden_rate=hidden_rate)
tunet3 = TUNet.TUNet(**tunet3_kwargs)

print('Number of parameters: ', helpers.count_parameters(tunet3))

### Train networks

In [None]:

# Compute device: start from the library's choice and prefer the second GPU
# only when it actually exists. An unconditional "cuda:1" fails on machines
# with fewer than two GPUs.
device = helpers.get_device()
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = "cuda:1"
epochs = 100   # Set number of epochs

criterion = nn.CrossEntropyLoss()   # For segmenting >2 classes

LEARNING_RATE = 1e-2
optimizer_msd = optim.Adam(msdnet.parameters(), lr=LEARNING_RATE)
optimizer_tunet3 = optim.Adam(tunet3.parameters(), lr=LEARNING_RATE)

print('Device we will compute on: ', device)   # "cuda:1" with two or more GPUs, else helpers.get_device()

In [14]:
# Root folder for everything this notebook writes
newds_path = 'InitialMSDNetVsTUNet_newmasks_Results'
os.makedirs(newds_path, exist_ok=True)

# Per-model sub-folder suffixes (appended to newds_path by the training cells)
model_msdnet = '/msdnet'
model_tunet3 = '/tunet3'

In [None]:
# Train MSDNet

msdnet.to(device)   # send network to GPU

# Folder receiving checkpoints, results and the loss plot of this model
main_dir = newds_path + model_msdnet
os.makedirs(main_dir, exist_ok=True)

# Take the step count from the loader that is actually iterated, so the
# schedule stays correct if the dataset or batch size behind it changes.
stepsPerEpoch = len(train_loader)
num_steps_down = 2
# `verbose` is deprecated in recent PyTorch releases and False was its
# default, so it is no longer passed.
scheduler = optim.lr_scheduler.StepLR(optimizer_msd,
                                      step_size=int(stepsPerEpoch*(epochs/num_steps_down)),
                                      gamma=0.1)

msdnet, results = train_scripts.train_segmentation(
    msdnet, train_loader, val_loader, epochs,
    criterion, optimizer_msd, device, saveevery=3, scheduler=scheduler, savepath=main_dir, show=1)   # training happens here

# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()

# Persist the trained weights and the training history
torch.save(msdnet.state_dict(), main_dir + '/net')
np.save(main_dir + '/results.npy', results)

# Loss curves (log scale)
plt.figure(figsize=(10,4))
plt.rcParams.update({'font.size': 16})

plt.plot(results['Training loss'], linewidth=2, label='training')
plt.plot(results['Validation loss'], linewidth=2, label='validation')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('MSDNet with ReLU and BatchNorm')
plt.legend()
plt.tight_layout()
plt.savefig(main_dir + '/losses')
plt.show()

In [None]:
# Train TUNet 3

# NOTE: these assignments do not change the loaders built earlier; the batch
# size in effect is the one passed to make_loaders.
batch_size_train = 1
batch_size_val = 1
batch_size_test = 1

tunet3.to(device)   # send network to GPU

# Folder receiving checkpoints, results and the loss plot of this model
main_dir = newds_path + model_tunet3
os.makedirs(main_dir, exist_ok=True)

# Take the step count from the loader that is actually iterated, so the
# schedule cannot drift from the batch_size_* values reassigned above.
stepsPerEpoch = len(train_loader)
num_steps_down = 2
# `verbose` is deprecated in recent PyTorch releases and False was its
# default, so it is no longer passed.
scheduler = optim.lr_scheduler.StepLR(optimizer_tunet3,
                                      step_size=int(stepsPerEpoch*(epochs/num_steps_down)),
                                      gamma=0.1)

tunet3, results = train_scripts.train_segmentation(
    tunet3, train_loader, val_loader, epochs,
    criterion, optimizer_tunet3, device, saveevery=3, scheduler=scheduler, savepath=main_dir, show=1)   # training happens here

# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()

# Persist the trained weights and the training history
torch.save(tunet3.state_dict(), main_dir + '/net')
np.save(main_dir + '/results.npy', results)

# Loss curves (log scale)
plt.figure(figsize=(10,4))
plt.rcParams.update({'font.size': 16})

plt.plot(results['Training loss'], linewidth=2, label='training')
plt.plot(results['Validation loss'], linewidth=2, label='validation')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('TUnet with ReLU and BatchNorm')
plt.legend()
plt.tight_layout()
plt.savefig(main_dir + '/losses')
plt.show()

In [22]:
# Constructor arguments needed to rebuild tunet3 before loading its weights.
# A stray trailing comma previously turned this into a 1-tuple holding the
# dict; it is now the dict itself, matching how results.npy is stored.
# Load with: np.load(path, allow_pickle=True).item()
params = {'image_shape': train_imgs.shape[2:4],
          'in_channels': in_channels,
          'out_channels': out_channels,
          'depth': depth,
          'base_channels': base_channels,
          'growth_rate': growth_rate,
          'hidden_rate': hidden_rate}

# Name the TUNet folder explicitly instead of relying on whichever training
# cell assigned main_dir last.
tunet3_dir = newds_path + model_tunet3
os.makedirs(tunet3_dir, exist_ok=True)
np.save(tunet3_dir+'/params.npy',params)

### Segment testing data

In [25]:
def regression_metrics(preds, target):
    """Return ``corcoef.cc`` of the flattened CPU copies of ``preds`` and ``target``."""
    flat_preds = preds.cpu().flatten()
    flat_target = target.cpu().flatten()
    return corcoef.cc(flat_preds, flat_target)


def segment_imgs(testloader, net):
    """Run ``net`` over every batch of ``testloader`` and collect the results.

    The network is switched to evaluation mode for the pass, so layers such
    as BatchNorm use their stored statistics and do not update them. Its
    previous training/eval mode is restored afterwards. Inputs are moved to
    the device the network's parameters live on.

    Parameters
    ----------
    testloader : iterable
        Yields batches whose first element is the input tensor (for example
        a ``DataLoader`` over a ``TensorDataset``).
    net : torch.nn.Module
        Network to evaluate.

    Returns
    -------
    tuple
        ``(seg_imgs, noisy_imgs)``: raw network outputs and the matching
        inputs, each concatenated along dim 0 on the CPU. Two empty lists
        are returned when the loader yields no batches.
    """
    torch.cuda.empty_cache()

    # Follow the network's own device instead of a notebook-level global.
    try:
        net_device = next(net.parameters()).device
    except StopIteration:
        net_device = torch.device('cpu')   # parameter-free module

    was_training = net.training
    net.eval()

    seg_batches = []
    noisy_batches = []
    try:
        with torch.no_grad():
            for batch in testloader:
                noisy = batch[0].float().to(net_device)
                output = net(noisy)
                seg_batches.append(output.detach().cpu())
                noisy_batches.append(noisy.detach().cpu())
    finally:
        net.train(was_training)
        torch.cuda.empty_cache()

    if not seg_batches:
        return [], []

    # Concatenate once at the end rather than re-copying the whole
    # accumulator on every iteration.
    return torch.cat(seg_batches, 0), torch.cat(noisy_batches, 0)

In [26]:
# Run both networks over the test loader. `noisy` (the collected inputs) is
# assigned by both calls; the value from the second call is the one kept.
msdnet_output, noisy  = segment_imgs(test_loader, msdnet)
tunet3_output, noisy  = segment_imgs(test_loader, tunet3)

In [27]:
# Reduce the network outputs to label maps by taking the arg-max over dim 1.
# The four-axis slice is kept on purpose: it raises on a tensor that has
# already been reduced, i.e. when this cell is accidentally run twice.
msdnet_output = msdnet_output.cpu()[:, :, :, :].data.argmax(dim=1)
tunet3_output = tunet3_output.cpu()[:, :, :, :].data.argmax(dim=1)

In [None]:
# Report the size of both label maps
for label_map in (msdnet_output, tunet3_output):
    print(label_map.size())

# Drop axis 1 of the inputs (only removed when its size is 1)
noisy = noisy.squeeze(1)
print(noisy.size())

In [29]:
# Write both label maps and the network inputs as TIFF stacks
tiff_exports = (
    ('/msdnet_output.tif', msdnet_output),
    ('/tunet3_output.tif', tunet3_output),
    ('/input.tif', noisy),
)
for tiff_name, tiff_tensor in tiff_exports:
    imwrite(newds_path + tiff_name, tiff_tensor.numpy())

In [None]:
# Save the final weights of both networks next to the exported TIFFs.
# newds_path + '/msdnet' and newds_path + '/tunet3' are the per-model
# directories created by the training cells, so writing a file to those exact
# paths fails. Distinct file names are used instead.
torch.save(msdnet.state_dict(), newds_path + '/msdnet_state_dict.pt')
torch.save(tunet3.state_dict(), newds_path + '/tunet3_state_dict.pt')

In [None]:
# Open the inputs in napari and overlay one labels layer per network
viewer = napari.view_image(noisy.cpu().numpy(), colormap='gray', name='Input')
for layer_name, label_map in (('msdnet', msdnet_output), ('tunet3', tunet3_output)):
    viewer.add_labels(label_map.numpy(), name=layer_name)