In [218]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from glob import glob
import tifffile
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import datetime
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm
torch.backends.cudnn.benchmark = True

### Creating a Dataset class

In [219]:
class EmbryoNucleiDataset(Dataset):
    def __init__(self,
                 root_dir,
                 epoch_size
                ):
        
        # using root_dir, split and mask create a path to files and sort it 
        self.mask_files = sorted(glob(os.path.join(root_dir, 'cropped_masks', '*.tif'))) # load mask files into sorted list
        self.raw_files = sorted(glob(os.path.join(root_dir, 'cropped_rawfiles', '*.tif'))) # load image files into sorted list
        self.epoch_size = epoch_size
    
    def __len__(self):
        #return len(self.raw_files)
        return self.epoch_size

    def __getitem__(self, idx):
        idx = np.random.randint(len(self.raw_files))
        raw_file = self.raw_files[idx] 
        mask_file = self.mask_files[idx] 
        crops_raw = tifffile.imread(raw_file) # load raw to numpy array
        crops_mask = tifffile.imread(mask_file) # load mask to numpy array
        crops_mask = (crops_mask !=0).astype(np.float32)
        crops_raw = ((crops_raw.astype(np.float32))/65535) * crops_mask
        
        # add channel dimensions to comply with pytorch standard (B, C, H, W) 
        crops_raw = np.expand_dims(crops_raw, axis=0)
        crops_mask = np.expand_dims(crops_mask, axis=0)
        
        return crops_raw, crops_mask

### Creating Autoencoder

In [220]:
class Autoencoder(torch.nn.Module):
    def __init__(
            self,
            in_channels,
            downsampling_factors,
            fmaps,
            fmul,
            fmaps_bottle = 'default',
            kernel_size=3):

        super(Autoencoder, self).__init__()

        out_channels = in_channels

        encoder = []

        for downsampling_factor in downsampling_factors:

            encoder.append(
                    torch.nn.Conv2d(
                        in_channels,
                        fmaps,
                        kernel_size))
            encoder.append(
                    torch.nn.ReLU(inplace=True))
            encoder.append(
                    torch.nn.Conv2d(
                        fmaps,
                        fmaps,
                        kernel_size))
            encoder.append(
                    torch.nn.ReLU(inplace=True))
            encoder.append(
                    torch.nn.MaxPool2d(downsampling_factor))

            in_channels = fmaps

            fmaps = fmaps * fmul

        if fmaps_bottle == 'default':
            fmaps_bottle = fmaps
        
        encoder.append(
            torch.nn.Conv2d(
                in_channels,
                fmaps_bottle,
                kernel_size))
        encoder.append(
            torch.nn.ReLU(inplace=True))

        self.encoder = torch.nn.Sequential(*encoder)

        decoder = []

        fmaps = in_channels

        decoder.append(
            torch.nn.Conv2d(
                fmaps_bottle,
                fmaps,
                kernel_size))
        decoder.append(
            torch.nn.ReLU(inplace=True))

        for idx, downsampling_factor in enumerate(downsampling_factors[::-1]):

            decoder.append(
                torch.nn.Upsample(
                    scale_factor=downsampling_factor,
                    mode='bilinear'))

            in_channels = fmaps
            
            decoder.append(
                torch.nn.Conv2d(
                    in_channels,
                    fmaps,
                    kernel_size))
            decoder.append(
                torch.nn.ReLU(inplace=True))
            if idx < len(downsampling_factors) - 1:
                fmaps = in_channels // fmul
                decoder.append(
                    torch.nn.Conv2d(
                        in_channels,
                        fmaps,
                        kernel_size))
                decoder.append(
                    torch.nn.ReLU(inplace=True))

            else:
                decoder.append(
                    torch.nn.Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size))

        self.decoder = torch.nn.Sequential(*decoder)

    def forward(self, x):

        enc = self.encoder(x)

        dec = self.decoder(enc)

        return enc, dec
        


### Create training function

In [221]:
def train(batch_size,num_epochs,epoch_size):
    # create train dataset
    dataset = EmbryoNucleiDataset(root_dir,epoch_size)

    # create train dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    # create model
    model = Autoencoder(in_channels=1, downsampling_factors=[downsampling_factor]*model_depth,
        fmaps=32, fmul=2, kernel_size = 3)

    # create loss object
    loss_function = torch.nn.MSELoss()
    #loss_function = torch.nn.L1Loss()

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in tqdm(range(num_epochs), position=0, leave=True):
        train_epoch(dataloader, model, epoch, optimizer, loss_function)

def train_epoch(dataloader, model, epoch, optimizer, loss_function, log_image_interval = 20):
    model.train()
    model = model.to(device)
    loss_list = []
    
    for batch_id, (raw, mask) in enumerate(dataloader):
        raw = raw.to(device) # move to GPU
        optimizer.zero_grad()
        _, prediction = model(raw)
        reduction = raw.shape[2] - prediction.shape[2]
        raw = raw[:, :, reduction//2:-reduction//2, reduction//2:-reduction//2]
        loss = loss_function(prediction, raw)
        step = epoch * len(dataloader) + batch_id
        writer.add_scalar('train loss',loss.item(), step)
        loss_list.append(loss.item())
        loss.backward()
        optimizer.step()
        
        if step % log_image_interval == 0:
            writer.add_images(
                tag="input", img_tensor=raw.to("cpu"), global_step=step
            )
            writer.add_images(
                tag="prediction",
                img_tensor=prediction.to("cpu").detach(),
                global_step=step,
            )
    loss_list = np.array(loss_list)
    print(f"Loss at Epoch {epoch} is {loss_list.mean()}")

### Training Time ! 

In [222]:
# identifying params for training
batch_size = 64
crop_size = 156
num_epochs = 10
epoch_size = 5000
root_dir = '/mnt/efs/shared_data/instance_no_gt/20230830_TIF_cellpose_test/'
assert torch.cuda.is_available()
device = torch.device("cuda")

In [223]:
model_depth = 1
downsampling_factor = 2
downsampling_factors = [downsampling_factor]*model_depth
fmaps = 2
fmul = 2
fmaps_bottle = 'default'
kernel_size = 3
loss = 'MSE'

model = Autoencoder(in_channels=1, downsampling_factors=downsampling_factors, fmaps=fmaps,
                    fmul=fmul, fmaps_bottle = fmaps_bottle, kernel_size = kernel_size).to(device)
summary(model, (1, 156, 156))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 2, 154, 154]              20
              ReLU-2          [-1, 2, 154, 154]               0
            Conv2d-3          [-1, 2, 152, 152]              38
              ReLU-4          [-1, 2, 152, 152]               0
         MaxPool2d-5            [-1, 2, 76, 76]               0
            Conv2d-6            [-1, 4, 74, 74]              76
              ReLU-7            [-1, 4, 74, 74]               0
            Conv2d-8            [-1, 2, 72, 72]              74
              ReLU-9            [-1, 2, 72, 72]               0
         Upsample-10          [-1, 2, 144, 144]               0
           Conv2d-11          [-1, 2, 142, 142]              38
             ReLU-12          [-1, 2, 142, 142]               0
           Conv2d-13          [-1, 1, 140, 140]              19
Total params: 265
Trainable params: 265

In [224]:
# create a logdir for each run and a corresponding summary writer
train_identifier = f'FINALTESTautoencoder_downsamplingfactors_{downsampling_factors}__fmaps_{fmaps}__fmul_{fmul}__fmapsbottle_{fmaps_bottle}__kernelsize_{kernel_size}__loss_{loss}'
logdir = os.path.join("logs", f'{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}_{train_identifier}')
writer = SummaryWriter(logdir)

In [None]:
train(batch_size=batch_size,num_epochs=num_epochs,epoch_size=epoch_size) # tensorboard? train for longer

 10%|████▍                                       | 1/10 [00:47<07:08, 47.65s/it]

Loss at Epoch 0 is 0.0002143297482044672


 20%|████████▊                                   | 2/10 [01:35<06:22, 47.86s/it]

Loss at Epoch 1 is 3.602980916783261e-05


 30%|█████████████▏                              | 3/10 [02:23<05:34, 47.75s/it]

Loss at Epoch 2 is 2.3179328343735216e-05


 40%|█████████████████▌                          | 4/10 [03:10<04:44, 47.35s/it]

Loss at Epoch 3 is 1.5523930299803142e-05


In [13]:
# To view runs in tensorboard you can call either (uncommented):
%reload_ext tensorboard
!tensorboard --logdir logs --port 6009

TensorFlow installation not found - running with reduced feature set.
/home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.29' not found (required by /home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server)
/home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.33' not found (required by /home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server)
/home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server: /lib/x86_64-linux-gnu/libc.so.6: version `GLIBC_2.28' not found (required by /home/evan/conda/envs/06_instance_segmentation/lib/python3.8/site-packages/tensorboard_data_server/bin/server)
/home/evan/conda/envs/06_instance

In [23]:
# Saving the model weights
state = model.state_dict()
filename = root_dir+'models/'+datetime.datetime.now().strftime("%Y%m%d-%H%M%S")+'.pt'
torch.save(state, filename)



In [None]:
# Saving the latent space

In [15]:
# To test: 
# Model depth, L1 loss

In [None]:
# To calculate: 
# IOU (segmentation performance), Pearson (reconstruction)

In [None]:
# UMAP 

In [None]:
# MOBIE 