In [None]:
# Root folder of the dataset and the tomogram (MRC) file stored inside it.
source = '/data/FIBSEM/testing_scripts/sourceFile'
mrc_file = 'TS_30_032123_tomo_bin4_flipped.mrc'

# Annotated images and their masks live under <source>/annotations.
annotation_imgs = source + '/annotations/images'
annotation_masks = source + '/annotations/masks'

# Switches on the tiling and filtering cells further down the notebook.
cryo_flag = True

#### Convert the MRC file to JPEG images. Use mrc2tif or 3dmod. The images should be placed in the source folder, in a subfolder named `images`.
NOTE: Predictions with NAD-filtered images have shown better results. Use Etomo to create NAD-filtered MRC files from the source file and change the path to the NAD-filtered MRC file.

In [None]:
path = source + '/' +mrc_file
!mkdir {source}/images
!mrc2tif -j {path}  {source}/images/zap

### Install required packages
#### qlty (`qlty2D`) is used to split images into smaller tiles. https://qlty.readthedocs.io/en/latest/
#### dlsia, developed by LBNL, is used to create custom MSDNets and UNets. https://dlsia.readthedocs.io/en/latest/index.html

In [None]:
!pip install numpy
!python3 -m pip install --upgrade Pillow
!pip install tiffile
!pip install qlty
!pip install opencv-python

In [None]:
!git clone https://github.com/phzwart/dlsia.git
!cd dlsia && pip install -e .

#### Pick 15-20 images for annotation. Use Apeer or napari to generate annotations for the chosen images.
The annotations should be placed in the source folder under the following paths:
-  {source}/annotations/images 
-  {source}/annotations/masks 

In [None]:
import os
import cv2
import glob
import torch
import random
import numpy as np
import collections
from PIL import Image
import torch.nn as nn
from skimage import exposure
import matplotlib.pyplot as plt
from tifffile import imread, imwrite
import torch.optim as optim

from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, Dataset
from dlsia.core import helpers, train_scripts, corcoef
from dlsia.core.networks import msdnet, tunet


### For training, the network requires the images and the masks as two multi-page TIFF files (train.tif and masks.tif) with the slices stored in the same order. Both files are placed in the annotations folder.

Apeer generates a separate file for each mask. Use the code below to generate train.tif and masks.tif.

### Use the code below to rename the mask files so that they follow the same naming convention as the images.

In [None]:
# Strip the "_filament.ome" suffix from every mask file name so masks and
# images share one naming convention.
for filename in sorted(os.listdir(annotation_masks)):
    newfname = filename.replace("_filament.ome", "")
    if newfname == filename:
        continue  # nothing to strip (e.g. the cell was already run once)

    target = os.path.join(annotation_masks, newfname)
    # os.rename silently replaces an existing file on POSIX; refuse instead of
    # losing a mask.
    if os.path.exists(target):
        raise FileExistsError(f'Cannot rename {filename}: {target} already exists')
    os.rename(os.path.join(annotation_masks, filename), target)

In [None]:
# Collect the annotated images and masks; sorting both lists pairs them by name.
training_imgs = sorted(glob.glob(annotation_imgs + "/*.jpg"))
training_masks = sorted(glob.glob(annotation_masks + "/*.tiff"))

# A count mismatch would silently misalign images and masks in the stacks.
if len(training_imgs) != len(training_masks):
    raise ValueError(
        f'Found {len(training_imgs)} images but {len(training_masks)} masks; '
        'every image needs exactly one mask.')

train_imgs = []
for img_path in training_imgs:
    # The context manager closes the file handle once the pixels are copied.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img, dtype='float32')
    # PIL delivers colour images in RGB order, so use the RGB conversion code.
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    train_imgs.append(img)

train_masks = []
for mask_path in training_masks:
    with Image.open(mask_path) as pil_mask:
        # Binarise before casting: casting first would wrap non-zero values
        # such as 256 to 0 in uint8.
        mask = (np.array(pil_mask) != 0).astype('uint8')
    train_masks.append(mask)

train_imgs, train_masks = np.array(train_imgs), np.array(train_masks)
print(train_imgs.shape, train_imgs.dtype)
print(train_masks.shape, train_masks.dtype)

# One multi-page TIFF per stack, slices in the same order in both files.
imwrite(f'{source}/annotations/train.tif', train_imgs)
imwrite(f'{source}/annotations/masks.tif', train_masks)

### Use the code below for annotation masks where each class has its own mask file. If all the classes are present in a single mask file, use the code above instead.
List the classes that are annotated in the masks. This mapping is used to combine multiple annotation files into a single image mask.

In [None]:
# Class name (matched as a substring of the mask file name) -> integer label
# written into the combined mask.
class_mapper = {
    'ribosome': 1,
    'tube': 2,
    'mem': 3,
    'filament': 4,
}

In [None]:
# Group the per-class mask files by the last "_"-separated token of their name.
# NOTE(review): that token is reused below as the image name ("<token>.jpg");
# confirm this matches how the annotation files are actually named.
mapper = collections.defaultdict(list)
for filename in sorted(os.listdir(annotation_masks)):  # sorted => reproducible
    mapper[filename.split('_')[-1]].append(filename)

train_masks = []
train_imgs = []

for k, v in mapper.items():
    combined = None  # label image built from all class masks of this group

    for name in v:
        with Image.open(f'{annotation_masks}/{name}') as pil_mask:
            # Test for non-zero before any cast so 16-bit values are not
            # wrapped to 0.
            annotated = np.array(pil_mask) != 0

        # Start from an all-background image; raw mask values never leak into
        # the result, even when no class name matches the file.
        if combined is None:
            combined = np.zeros(annotated.shape, dtype='uint8')

        for class_type, label in class_mapper.items():
            if class_type in name:
                combined[annotated] = label

    train_masks.append(combined)

    with Image.open(f'{annotation_imgs}/{k}.jpg') as pil_img:
        img = np.array(pil_img, dtype='float32')
    # PIL delivers colour images in RGB order, so use the RGB conversion code.
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    train_imgs.append(img)

train_imgs, train_masks = np.array(train_imgs), np.array(train_masks)
imwrite(f'{source}/annotations/train.tif', train_imgs)
imwrite(f'{source}/annotations/masks.tif', train_masks)

### Use napari or any other visualization tool to verify that the generated image and mask TIFF stacks are aligned.

In [None]:
# Insert a new axis of length 1 at position 1 of both stacks.
train_imgs = train_imgs[:, np.newaxis]
train_masks = train_masks[:, np.newaxis]

def shuffle_training(imgs, masks, seed=123):
    """Shuffle images and masks along axis 0 with one shared permutation.

    Args:
        imgs: Array of images, indexed along axis 0.
        masks: Array of masks with the same length along axis 0 as ``imgs``.
        seed: Seed of the permutation; the same seed always gives the same
            order.

    Returns:
        Tuple ``(imgs, masks)`` reordered by the same permutation, so every
        image stays paired with its mask.

    Raises:
        ValueError: If ``imgs`` and ``masks`` differ in length along axis 0.
    """
    if imgs.shape[0] != masks.shape[0]:
        raise ValueError(
            f'imgs and masks differ in length: {imgs.shape[0]} vs {masks.shape[0]}')

    x = np.arange(imgs.shape[0])
    # A private generator yields the same permutation as seeding the module
    # level one, without resetting the global `random` state for other code.
    random.Random(seed).shuffle(x)
    return imgs[x, :], masks[x, :]

# Shuffle both stacks together, then report shape and dtype of each.
train_imgs, train_masks = shuffle_training(imgs=train_imgs, masks=train_masks)
for shuffled in (train_imgs, train_masks):
    print(shuffled.shape, shuffled.dtype)

In [None]:
if cryo_flag:
    import qlty
    from qlty import qlty2D

    # Quilt configured from the image size, with 256x256 windows and a step
    # of 64 in both directions.
    quilt = qlty2D.NCYXQuilt(X=train_imgs.shape[3],
                            Y=train_imgs.shape[2],
                            window=(256,256),
                            step=(64,64),
                            border=(10,10),
                            border_weight=0)

    # Cut images and masks into windows in one call so the pairs stay aligned.
    labeled_imgs = torch.Tensor(train_imgs)
    labeled_masks = torch.Tensor(train_masks)
    labeled_imgs, labeled_masks = quilt.unstitch_data_pair(labeled_imgs,labeled_masks)

    print("x shape: ",train_imgs.shape)
    print("y shape: ",train_masks.shape)
    print("x_bits shape:", labeled_imgs.shape)
    print("y_bits shape:", labeled_masks.shape)
else:
    # Without tiling the full-size stacks are used as they are; the following
    # cells expect labeled_imgs / labeled_masks to exist either way.
    labeled_imgs = torch.Tensor(train_imgs)
    labeled_masks = torch.Tensor(train_masks)

#### The code below adds additional annotations from different datasets.
The input is a list of folders that follow the same annotation convention (images and masks folders).

In [None]:
# Folders that each provide a train.tif / masks.tif pair.
# NOTE(review): the stacks are read from "<dir>/train.tif", while the cells
# above write them to "<source>/annotations/train.tif" — confirm the paths.
additional_dirs = [
    '/data/FIBSEM/cryo/filament_Cryo/reannotate_masks',
    '/data/FIBSEM/testing_scripts/sourceFile',
]

for dirs in additional_dirs:
    imgs = imread(dirs + '/train.tif')
    masks = imread(dirs + '/masks.tif')

    # Images become float32; masks are taken index by index for every image.
    add_imgs = np.array([img.astype(np.float32) for img in imgs])
    add_masks = np.array([masks[i] for i in range(len(imgs))])

    # New axis at position 1, then the same seeded shuffle as the main set.
    add_imgs = np.expand_dims(add_imgs, axis=1)
    add_masks = np.expand_dims(add_masks, axis=1)
    add_imgs, add_masks = shuffle_training(add_imgs, add_masks)

    labeled_actin_imgs = torch.Tensor(add_imgs)
    labeled_actin_masks = torch.Tensor(add_masks)
    if cryo_flag == True:
        # Same window / step / border settings as for the main training set.
        quilt = qlty2D.NCYXQuilt(X=add_imgs.shape[3],
                                Y=add_masks.shape[2],
                                window=(256,256),
                                step=(64,64),
                                border=(10,10),
                                border_weight=0)

        labeled_actin_imgs, labeled_actin_masks = quilt.unstitch_data_pair(
            labeled_actin_imgs, labeled_actin_masks)

    print("x shape: ",add_imgs.shape)
    print("y shape: ",add_masks.shape)
    print("x_bits shape:", labeled_actin_imgs.shape)
    print("y_bits shape:", labeled_actin_masks.shape)

    # Append the new samples along axis 0 of the running training set.
    labeled_imgs = torch.cat((labeled_imgs, labeled_actin_imgs), 0)
    labeled_masks = torch.cat((labeled_masks, labeled_actin_masks), 0)

print("images shape:", labeled_imgs.shape)
print("masks shape:", labeled_masks.shape)

In [None]:
if cryo_flag:
    # One CLAHE object serves every tile; no need to rebuild it per iteration.
    clahe = cv2.createCLAHE(clipLimit=3)

    dicedImgs,dicedMasks = [],[]
    for i in range(len(labeled_imgs)):
        # comment this to include all slices even the non annotated slices.
        if np.unique(labeled_masks[i][0]).shape[0] > 1:
            # bilateral filter
            bilateral = cv2.bilateralFilter(labeled_imgs[i][0].numpy(),5,50,10)
            # clahe equalization, applied to the uint16 version of the tile
            bilateral = bilateral.astype(np.uint16)
            final = clahe.apply(bilateral)
            # Equalize histogram
            x = exposure.equalize_hist(final)
            dicedImgs.append(x.astype(np.float32))
            dicedMasks.append(labeled_masks[i][0].numpy())

    if not dicedImgs:
        raise ValueError('No tile contains more than one mask label; nothing to train on.')

    # verify a slice; clamp the index so small datasets do not raise IndexError
    sliceNum = min(250, len(dicedImgs) - 1)
    plt.subplot(1, 2, 1)
    plt.imshow(dicedMasks[sliceNum],cmap='gray',interpolation='none')
    plt.subplot(1, 2, 2)
    plt.imshow(dicedImgs[sliceNum],cmap='gray',interpolation='none')
    plt.show()

    # Stack the kept tiles and insert a new axis at position 1.
    train_imgs,train_masks = np.array(dicedImgs),np.array(dicedMasks)
    train_imgs,train_masks = np.expand_dims(train_imgs, axis=1),np.expand_dims(train_masks, axis=1)

In [None]:
# Hold out 10% of the samples for testing; random_state fixes the split.
split_parts = train_test_split(
    train_imgs,
    train_masks,
    test_size=0.1,
    random_state=0,
)
labeled_imgs, test_imgs, labeled_masks, test_masks = split_parts

In [None]:
def make_loaders(train_data, val_data, test_data,
                batch_size_train, batch_size_val, batch_size_test,
                num_workers=None):
    """Build the training, validation and test DataLoaders.

    Only the training loader shuffles; all three keep the last partial batch
    and request pinned memory.

    Args:
        train_data: Dataset for training.
        val_data: Dataset for validation.
        test_data: Dataset for testing.
        batch_size_train: Batch size of the training loader.
        batch_size_val: Batch size of the validation loader.
        batch_size_test: Batch size of the test loader.
        num_workers: Worker processes per loader. When omitted, a module-level
            ``num_workers`` variable is used if one exists (the behaviour
            existing callers rely on), otherwise 0.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.
    """
    if num_workers is None:
        # Backward compatible with callers that set a global before calling,
        # without raising NameError when that global is absent.
        num_workers = globals().get('num_workers', 0)

    # Settings shared by all loaders; batch size can be adjusted depending on
    # available memory.
    shared_params = {'num_workers': num_workers,
                     'pin_memory': True,
                     'drop_last': False}

    train_loader = DataLoader(train_data, batch_size=batch_size_train,
                              shuffle=True, **shared_params)
    val_loader = DataLoader(val_data, batch_size=batch_size_val,
                            shuffle=False, **shared_params)
    test_loader = DataLoader(test_data, batch_size=batch_size_test,
                             shuffle=False, **shared_params)

    return train_loader, val_loader, test_loader

### Use the cell below only to augment the data. This is helpful when there are not enough training annotations.

In [None]:
labeled_imgs = torch.Tensor(labeled_imgs)
labeled_masks = torch.Tensor(labeled_masks)

# Start with the originals and append every transformed copy; images and masks
# always receive the same transform so they stay aligned.
augmented_imgs = [labeled_imgs]
augmented_masks = [labeled_masks]

# Rotations by 90, 180 and 270 degrees in the plane of axes 2 and 3.
for quarter_turns in (1, 2, 3):
    augmented_imgs.append(torch.rot90(labeled_imgs, quarter_turns, [2, 3]))
    augmented_masks.append(torch.rot90(labeled_masks, quarter_turns, [2, 3]))

# Flips along axis 2 and along axis 3. Flipping both axes at once is the same
# as the 180 degree rotation above, so it is left out to avoid adding every
# sample twice in identical form.
for flip_dims in ([2], [3]):
    augmented_imgs.append(torch.flip(labeled_imgs, flip_dims))
    augmented_masks.append(torch.flip(labeled_masks, flip_dims))

# NOTE: re-running this cell augments the already augmented data again.
labeled_imgs = torch.cat(augmented_imgs, 0)
labeled_masks = torch.cat(augmented_masks, 0)

print('Shape of augmented data:    ', labeled_imgs.shape, labeled_masks.shape)

In [None]:
### Create validation set

# The last 10% of the samples are used for validation.
# NOTE(review): when the augmentation cell was run, this tail consists of
# transformed copies of training tiles; consider splitting before augmenting.
num_val = int(0.1*labeled_imgs.shape[0])
print('Number of images for validation: '+ str(num_val))

# Use an explicit split index: with num_val == 0 the slices [-0:] and [:-0]
# would put every sample into validation and none into training.
split_at = labeled_imgs.shape[0] - num_val
val_imgs = labeled_imgs[split_at:,:,:]
val_masks = labeled_masks[split_at:,:,:]
train_imgs = labeled_imgs[:split_at,:,:]   # actual training
train_masks = labeled_masks[:split_at,:,:]   # actual training

In [None]:
print('Size of training data:   ', train_imgs.shape)
print('Size of validation data: ', val_imgs.shape)
print('Size of testing data:    ', test_imgs.shape)

# Collect the labels from ALL training masks. A fixed window such as 200:400
# is empty for small datasets and can miss classes that only occur elsewhere.
# Working in chunks of 256 samples keeps the temporary memory small.
label_chunks = [np.unique(np.asarray(chunk))
                for chunk in torch.as_tensor(train_masks).split(256)]
num_labels = np.unique(np.concatenate(label_chunks))
# NOTE(review): the network cells use len(num_labels) as the number of output
# channels, which presumably requires labels 0..K-1 without gaps — confirm.
print('The unique mask labels: ', num_labels)

In [None]:
# Wrap every split as a TensorDataset of (image, mask) pairs.
train_data, val_data, test_data = (
    TensorDataset(torch.Tensor(split_imgs), torch.Tensor(split_masks))
    for split_imgs, split_masks in (
        (train_imgs, train_masks),
        (val_imgs, val_masks),
        (test_imgs, test_masks),
    )
)

# Worker processes of the data loaders.
num_workers = 0   # 1 or 2 work better with CPU, 0 best for GPU

# One sample per batch everywhere; raise these when memory allows.
batch_size_train = batch_size_val = batch_size_test = 1

loaders = make_loaders(train_data, val_data, test_data,
                       batch_size_train, batch_size_val, batch_size_test)
train_loader, val_loader, test_loader = loaders

In [None]:
# MSDNET

# Input/output channels: one input channel, one output channel per mask label.
in_channels = 1
out_channels = len(num_labels)

# Network size.
num_layers = 40
layer_width = 1
max_dilation = 15

# Building blocks.
activation = nn.ReLU()
normalization = nn.BatchNorm2d
final_layer = None   # defined here but not passed to the constructor below

msd_net_kwargs = dict(in_channels=in_channels,
                      out_channels=out_channels,
                      num_layers=num_layers,
                      layer_width=layer_width,
                      max_dilation=max_dilation,
                      activation=activation,
                      normalization=normalization,
                      convolution=nn.Conv2d)
msd_net = msdnet.MixedScaleDenseNetwork(**msd_net_kwargs)

print('Number of parameters: ', helpers.count_parameters(msd_net))

In [None]:
# TUNET

# Input/output channels: one input channel, one output channel per mask label.
in_channels = 1
out_channels = len(num_labels)

# Architecture settings handed to the constructor.
depth = 4
base_channels = 64
growth_rate = 2
hidden_rate = 1
normalization = nn.BatchNorm2d

# Same values as in the MSDNet cell; not passed to the TUNet constructor.
num_layers = 40
layer_width = 1
max_dilation = 15

tunet_kwargs = dict(image_shape=(train_imgs.shape[2:4]),
                    in_channels=in_channels,
                    out_channels=out_channels,
                    depth=depth,
                    kernel_down=nn.AvgPool2d,
                    base_channels=base_channels,
                    normalization=normalization,
                    growth_rate=growth_rate,
                    hidden_rate=hidden_rate)
tunet3 = tunet.TUNet(**tunet_kwargs)

print('Number of parameters: ', helpers.count_parameters(tunet3))

In [None]:
# Start from the device dlsia selects and switch to the second GPU only when
# the machine really has one; an unconditional "cuda:1" fails on CPU-only and
# single-GPU machines.
device = helpers.get_device()
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = "cuda:1"
epochs = 60   # Set number of epochs

criterion = nn.CrossEntropyLoss()   # For segmenting >2 classes
LEARNING_RATE = 1e-2

# One Adam optimizer per network, both with the same learning rate.
optimizer_msd = optim.Adam(msd_net.parameters(), lr=LEARNING_RATE)
optimizer_tunet3 = optim.Adam(tunet3.parameters(), lr=LEARNING_RATE)

print('Device we will compute on: ', device)   # cuda:0 for GPU. Else, CPU

In [None]:
# Folder that receives all training results; created on demand, and existing
# folders are left untouched.
newds_path = source+'/Results'
os.makedirs(newds_path, exist_ok=True)

# Per-model subfolder names, appended to newds_path by the training cells.
model_msdnet = '/msdnet'
model_tunet3 = '/tunet3'


### Train MSDNET

In [None]:
msd_net.to(device)   # send network to GPU

# Output folder of this model.
main_dir = newds_path + model_msdnet
os.makedirs(main_dir, exist_ok=True)

# Drop the learning rate by 10x after every 1/num_steps_down of the training,
# with the period expressed in optimizer steps.
stepsPerEpoch = np.ceil(train_imgs.shape[0]/batch_size_train)
num_steps_down = 2
# `verbose` is deprecated in recent PyTorch releases and was False anyway.
scheduler = optim.lr_scheduler.StepLR(optimizer_msd,
                                 step_size=int(stepsPerEpoch*(epochs/num_steps_down)),
                                 gamma = 0.1)

msd_net, results = train_scripts.train_segmentation(
            msd_net,train_loader, val_loader, epochs,
            criterion, optimizer_msd, device,saveevery=3,scheduler=scheduler,savepath=main_dir,show=1)   # training happens here

# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()
torch.save(msd_net.state_dict(), main_dir + '/net')

# Training and validation loss per epoch on a log scale.
plt.figure(figsize=(10,4))
plt.rcParams.update({'font.size': 16})
plt.plot(results['Training loss'], linewidth=2, label='training')
plt.plot(results['Validation loss'], linewidth=2, label='validation')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('MSDNet with ReLU and BatchNorm')
plt.legend()
plt.tight_layout()
plt.savefig(main_dir + '/losses')
plt.show()

In [None]:
batch_size_train = 10
batch_size_val = 10
batch_size_test = 10

# Rebuild the loaders: the existing ones were created with the previous batch
# sizes, so without this step the new values would never be used.
train_loader, val_loader, test_loader = make_loaders(train_data,
                                                    val_data,
                                                    test_data,
                                                    batch_size_train,
                                                    batch_size_val,
                                                    batch_size_test)

tunet3.to(device)   # send network to GPU
torch.cuda.empty_cache()

# Output folder of this model.
main_dir = newds_path + model_tunet3
os.makedirs(main_dir, exist_ok=True)

stepsPerEpoch = np.ceil(train_imgs.shape[0]/batch_size_train)
num_steps_down = 2
# `verbose` is deprecated in recent PyTorch releases and was False anyway.
# The scheduler is built but intentionally not handed to the training call
# (see the commented-out argument below).
scheduler = optim.lr_scheduler.StepLR(optimizer_tunet3,
                                 step_size=int(stepsPerEpoch*(epochs/num_steps_down)),
                                 gamma = 0.1)

tunet3, results = train_scripts.train_segmentation(
    tunet3,train_loader, val_loader, epochs,
    criterion, optimizer_tunet3, device,saveevery=3,
    #scheduler=scheduler,
    show=1)   # training happens here

# clear out unnecessary variables from device (GPU) memory
torch.cuda.empty_cache()

torch.save(tunet3.state_dict(), main_dir + '/net')

# Training and validation loss per epoch on a log scale.
plt.figure(figsize=(10,4))
plt.rcParams.update({'font.size': 16})
plt.plot(results['Training loss'], linewidth=2, label='training')
plt.plot(results['Validation loss'], linewidth=2, label='validation')
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('TUnet with ReLU and BatchNorm')
plt.legend()
plt.tight_layout()
plt.savefig(main_dir + '/losses')
plt.show()

In [None]:
# Store the TUNet weights in the model's result folder.
tunet3_weights = tunet3.state_dict()
tunet3_weights_path = main_dir + '/net'
torch.save(tunet3_weights, tunet3_weights_path)

In [None]:
# Constructor arguments of the TUNet, stored next to its weights.
tunet_params = {
    'image_shape': train_imgs.shape[2:4],
    'in_channels': in_channels,
    'out_channels': out_channels,
    'depth': depth,
    'base_channels': base_channels,
    'growth_rate': growth_rate,
    'hidden_rate': hidden_rate,
}
# NOTE(review): `params` is a one-element tuple holding the dict, which is what
# is written to params.npy. Kept as is so existing readers keep working;
# confirm whether the bare dict was intended.
params = (tunet_params,)
np.save(main_dir+'/params.npy',params)

In [None]:
def regression_metrics(preds, target):
    """Return ``corcoef.cc`` of the flattened CPU copies of both tensors.

    Args:
        preds: Predicted tensor.
        target: Reference tensor.

    Returns:
        Whatever ``corcoef.cc`` returns for the two flattened tensors.
    """
    flat_preds = preds.cpu().flatten()
    flat_target = target.cpu().flatten()
    return corcoef.cc(flat_preds, flat_target)

def segment_imgs(testloader, net, device=None):
    """Run ``net`` on every batch of ``testloader`` and collect the results.

    The network is switched to evaluation mode for the duration of the call,
    so layers such as BatchNorm use their stored statistics and do not update
    them; the previous mode is restored afterwards.

    Args:
        testloader: Iterable of batches; the first element of every batch is
            the input tensor.
        net: The ``torch.nn.Module`` to evaluate.
        device: Device the inputs are moved to. Defaults to the device of the
            network's parameters; for a network without parameters a
            module-level ``device`` variable is used if present, else the CPU.

    Returns:
        Tuple ``(seg_imgs, noisy_imgs)`` of CPU tensors: the network outputs
        and the corresponding inputs, concatenated along dimension 0. Two
        empty lists are returned when the loader yields no batch.
    """
    torch.cuda.empty_cache()

    if device is None:
        first_param = next(net.parameters(), None)
        if first_param is not None:
            # Inputs must live where the weights are.
            device = first_param.device
        else:
            device = globals().get('device', 'cpu')

    seg_batches = []
    noisy_batches = []

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch in testloader:
                noisy = batch[0].to(device=device, dtype=torch.float32)
                output = net(noisy)
                # Collect on the CPU and concatenate once at the end instead
                # of re-copying the growing result for every batch.
                seg_batches.append(output.detach().cpu())
                noisy_batches.append(noisy.detach().cpu())
    finally:
        net.train(was_training)
        torch.cuda.empty_cache()

    if not seg_batches:
        return [], []
    return torch.cat(seg_batches, 0), torch.cat(noisy_batches, 0)

In [None]:
# Segment the test set with the MSDNet and keep the most likely class per pixel.
msdnet_output, noisy  = segment_imgs(test_loader, msd_net)
msdnet_output = torch.argmax(msdnet_output.cpu()[:,:,:,:].data, dim=1)
print(msdnet_output.size())
noisy = torch.squeeze(noisy,1)
# Write to the MSDNet's own file; the previous name 'tunet3_output.tif' is the
# TUNet's output file and gets overwritten by the TUNet prediction cell.
imwrite(newds_path + '/msdnet_output.tif', msdnet_output.numpy())
imwrite(newds_path + '/input.tif', noisy.numpy())

In [None]:
# Segment the test set with the TUNet and keep the most likely class per pixel.
tunet3_scores, noisy = segment_imgs(test_loader, tunet3)
tunet3_output = tunet3_scores.cpu()[:, :, :, :].data.argmax(dim=1)
print(tunet3_output.size())

# Drop dimension 1 of the inputs (when it has size 1) before saving them
# alongside the labels.
noisy = noisy.squeeze(1)
imwrite(newds_path + '/tunet3_output.tif', tunet3_output.numpy())
imwrite(newds_path + '/input.tif', noisy.numpy())