In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from glob import glob
import tifffile
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import datetime
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
from tqdm import tqdm
torch.backends.cudnn.benchmark = True

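### Creating a Dataset class

Each dataset item is a whole batch of fixed-size crops from a single image. Each crop is centred on a randomly chosen nucleus, and every pixel outside that nucleus is zeroed.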

In [2]:
class EmbryoNucleiDataset(Dataset):
    """Dataset of single-nucleus crops cut from paired raw/mask TIFF images.

    Files are discovered under ``root_dir`` as ``masks/masks*.tif`` and
    ``raw_files/raw*.tif``. The two sorted lists are paired by position, so the
    i-th mask is assumed to belong to the i-th raw image.

    Each item is a whole batch: ``batch_size`` crops of size
    ``crop_size x crop_size``, centred on randomly chosen objects of one image.
    Every pixel that does not belong to the chosen object is set to zero.
    ``__getitem__`` returns ``(raw, mask)`` float32 numpy arrays of shape
    ``(batch_size, 1, crop_size, crop_size)``.

    Args:
        root_dir: directory containing the ``masks`` and ``raw_files`` folders.
        crop_size: side length of the square crops, in pixels.
        batch_size: number of crops per item (default 32, the value previously
            hard-coded). Images with fewer labelled objects are skipped.

    Raises:
        ValueError: if the number of raw files and mask files differ.
    """

    # Maximum number of images __getitem__ inspects before giving up, so a
    # dataset without any usable image raises instead of looping forever.
    max_resample_attempts = 1000

    def __init__(self,
                 root_dir,
                 crop_size,
                 batch_size=32,
                ):
        # Raw images and masks are paired purely by position in the sorted lists.
        self.mask_files = sorted(glob(os.path.join(root_dir, 'masks', 'masks*.tif')))
        self.raw_files = sorted(glob(os.path.join(root_dir, 'raw_files', 'raw*.tif')))
        if len(self.mask_files) != len(self.raw_files):
            raise ValueError(
                f"found {len(self.raw_files)} raw files but {len(self.mask_files)} "
                f"mask files under {root_dir!r}; they are paired by sorted order "
                "and must match one-to-one")
        self.crop_size = crop_size
        self.batch_size = batch_size

    def __len__(self):
        """Number of raw images found under ``root_dir``."""
        return len(self.raw_files)

    @staticmethod
    def _object_ids(mask):
        """Return the sorted non-zero labels of ``mask`` (label 0 is background)."""
        ids = np.unique(mask)
        return ids[ids != 0]

    def _ids_and_centroids(self, mask):
        """Return ``(ids, centroids)`` for every labelled object in ``mask``.

        ``centroids[i]`` is the ``(y, x)`` mean pixel position of object
        ``ids[i]``, truncated to ``int``. For a non-convex object this point
        may lie outside the object itself.
        """
        ids = self._object_ids(mask)
        centroids = []
        for obj_id in ids:
            ys, xs = np.where(mask == obj_id)
            centroids.append((int(np.mean(ys)), int(np.mean(xs))))
        return ids, np.array(centroids)

    def get_centroids(self, mask):
        """Return an ``(n_objects, 2)`` int array of ``(y, x)`` object centroids.

        Every non-zero label counts as an object, even when the mask contains
        no background pixels. Returns an empty array if there are no objects.
        """
        return self._ids_and_centroids(mask)[1]

    def crop_top_left(self, coord):
        """Return the ``(y, x)`` top-left corner of a crop centred on ``coord``."""
        y, x = coord
        y_top_left = int(y-(self.crop_size//2))
        x_top_left = int(x-(self.crop_size//2))
        return y_top_left, x_top_left

    def _croppable_objects(self, mask):
        """List ``(object_id, y_top_left, x_top_left)`` for every object whose
        centroid-centred crop lies entirely inside ``mask``."""
        ids, centroids = self._ids_and_centroids(mask)
        max_y = mask.shape[0] - self.crop_size
        max_x = mask.shape[1] - self.crop_size
        croppable = []
        for obj_id, centroid in zip(ids, centroids):
            y_top_left, x_top_left = self.crop_top_left(centroid)
            if 0 <= y_top_left <= max_y and 0 <= x_top_left <= max_x:
                croppable.append((obj_id, y_top_left, x_top_left))
        return croppable

    def _sample_crops(self, raw, mask, croppable):
        """Draw ``batch_size`` crops (with replacement) from the ``croppable`` objects."""
        crops_raw = []
        crops_mask = []
        for _ in range(self.batch_size):
            obj_id, y_top_left, x_top_left = random.choice(croppable)
            window = (slice(y_top_left, y_top_left + self.crop_size),
                      slice(x_top_left, x_top_left + self.crop_size))
            # Select the sampled object by its own label, not by the value at
            # the crop centre: for non-convex nuclei the centroid pixel can be
            # background or a neighbouring nucleus.
            crop_mask = mask[window] == obj_id
            crops_raw.append(raw[window] * crop_mask)
            crops_mask.append(crop_mask)
        return np.array(crops_raw), np.array(crops_mask)

    def get_masked_crops(self, raw, mask):
        """Cut ``batch_size`` object-centred crops from a 2-D ``raw``/``mask`` pair.

        Returns:
            ``(crops_raw, crops_mask)``, each of shape
            ``(batch_size, crop_size, crop_size)``. ``crops_mask`` is boolean
            and True on the sampled object. ``crops_raw`` is the raw crop with
            every other pixel set to zero.

        Raises:
            ValueError: if no object is far enough from the image border for
                a full crop (previously this looped forever).
        """
        croppable = self._croppable_objects(mask)
        if not croppable:
            raise ValueError(
                f"no object is far enough from the image border for a "
                f"{self.crop_size}x{self.crop_size} crop")
        return self._sample_crops(raw, mask, croppable)

    def __getitem__(self, idx):
        """Return one batch of crops as float32 ``(B, 1, h, w)`` arrays ``(raw, mask)``.

        Image ``idx`` is used when it has at least ``batch_size`` objects and
        at least one of them can be fully cropped. Otherwise images are drawn
        at random until a usable one is found.

        Raises:
            RuntimeError: if no usable image is found within
                ``max_resample_attempts`` tries.
        """
        file_idx = idx
        for _ in range(self.max_resample_attempts):
            mask = tifffile.imread(self.mask_files[file_idx])
            if len(self._object_ids(mask)) >= self.batch_size:
                croppable = self._croppable_objects(mask)
                if croppable:
                    raw = tifffile.imread(self.raw_files[file_idx])
                    crops_raw, crops_mask = self._sample_crops(raw, mask, croppable)
                    # Scaling by 65535 assumes 16-bit raw images - TODO confirm.
                    crops_raw = crops_raw.astype(np.float32) / 65535
                    crops_mask = crops_mask.astype(np.float32)
                    # Add the channel axis to match PyTorch's (B, C, H, W) layout.
                    return (np.expand_dims(crops_raw, axis=1),
                            np.expand_dims(crops_mask, axis=1))
            file_idx = np.random.randint(len(self.raw_files))
        raise RuntimeError(
            f"no image with at least {self.batch_size} croppable objects found "
            f"after {self.max_resample_attempts} attempts")

### Creating the autoencoder

In [3]:
class Autoencoder(torch.nn.Module):
    """Convolutional autoencoder built from unpadded ("valid") 2-D convolutions.

    Encoder: for every entry of ``downsampling_factors``, two Conv2d+ReLU
    layers followed by ``MaxPool2d(factor)``. The first level has ``fmaps``
    channels, and each further level multiplies that by ``fmul``. A final
    Conv2d+ReLU produces the bottleneck with
    ``fmaps * fmul ** len(downsampling_factors)`` channels.

    Decoder: a Conv2d+ReLU from the bottleneck back to the width of the deepest
    encoder level. Then, for the factors in reverse order, bilinear upsampling
    followed by two Conv2d+ReLU layers whose width is the previous width
    integer-divided by ``fmul``. A final Conv2d (no activation) maps back to
    ``in_channels`` channels.

    No convolution is padded, so each one trims ``kernel_size - 1`` pixels
    from both spatial dimensions. The reconstruction is therefore smaller than
    the input.

    ``forward(x)`` returns ``(encoding, reconstruction)``.
    """

    def __init__(
            self,
            in_channels,
            downsampling_factors,
            fmaps,
            fmul,
            kernel_size=3):

        super().__init__()

        def conv_relu(n_in, n_out):
            # Conv2d is created before ReLU, matching the original init order.
            return [torch.nn.Conv2d(n_in, n_out, kernel_size),
                    torch.nn.ReLU(inplace=True)]

        image_channels = in_channels

        # ---- encoder -------------------------------------------------------
        channels, width = in_channels, fmaps
        encoder_layers = []
        for factor in downsampling_factors:
            encoder_layers += conv_relu(channels, width)
            encoder_layers += conv_relu(width, width)
            encoder_layers.append(torch.nn.MaxPool2d(factor))
            channels, width = width, width * fmul
        bottleneck = width
        encoder_layers += conv_relu(channels, bottleneck)
        self.encoder = torch.nn.Sequential(*encoder_layers)

        # ---- decoder -------------------------------------------------------
        decoder_layers = conv_relu(bottleneck, channels)
        for factor in downsampling_factors[::-1]:
            narrower = channels // fmul
            decoder_layers.append(
                torch.nn.Upsample(scale_factor=factor, mode='bilinear'))
            decoder_layers += conv_relu(channels, narrower)
            decoder_layers += conv_relu(narrower, narrower)
            channels = narrower
        decoder_layers.append(
            torch.nn.Conv2d(channels, image_channels, kernel_size))
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(encoding, reconstruction)`` for the input batch ``x``."""
        encoding = self.encoder(x)
        return encoding, self.decoder(encoding)
        


### Training

In [4]:
# Training configuration.
# NOTE(review): batch_size is not passed to the dataset or the DataLoader below;
# the number of crops per training step comes from EmbryoNucleiDataset. Confirm intent.
batch_size = 64
crop_size = 156          # side length (px) of the square nucleus crops
num_epochs = 50
model_depth = 1          # number of down/upsampling levels in the autoencoder
downsampling_factor = 2  # max-pool / upsample factor used at every level
seed = 42                # seeds Python, NumPy and torch RNGs for reproducibility
root_dir = '/mnt/efs/shared_data/instance_no_gt/20230830_TIF_cellpose_test/'

# Crop sampling uses `random` and `np.random`; weight init uses torch.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Explicit check instead of `assert`, which is skipped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("This notebook requires a CUDA-capable GPU.")
device = torch.device("cuda")

In [5]:
# Inspect per-layer output shapes for one single-channel crop of the configured size.
model = Autoencoder(in_channels=1, downsampling_factors=[downsampling_factor] * model_depth,
                    fmaps=32, fmul=2, kernel_size=3).to(device)
summary(model, (1, crop_size, crop_size))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 154, 154]             320
              ReLU-2         [-1, 32, 154, 154]               0
            Conv2d-3         [-1, 32, 152, 152]           9,248
              ReLU-4         [-1, 32, 152, 152]               0
         MaxPool2d-5           [-1, 32, 76, 76]               0
            Conv2d-6           [-1, 64, 74, 74]          18,496
              ReLU-7           [-1, 64, 74, 74]               0
            Conv2d-8           [-1, 32, 72, 72]          18,464
              ReLU-9           [-1, 32, 72, 72]               0
         Upsample-10         [-1, 32, 144, 144]               0
           Conv2d-11         [-1, 16, 142, 142]           4,624
             ReLU-12         [-1, 16, 142, 142]               0
           Conv2d-13         [-1, 16, 140, 140]           2,320
             ReLU-14         [-1, 16, 1

In [6]:
# Each run gets its own TensorBoard directory under logs/, named after its start time.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join("logs", run_timestamp)
writer = SummaryWriter(logdir)

In [7]:
def train():
    """Build the dataset, dataloader, model, loss and optimizer, then train.

    Uses the notebook-level configuration (``root_dir``, ``crop_size``,
    ``downsampling_factor``, ``model_depth``, ``num_epochs``, ``device``).

    Returns:
        The trained ``Autoencoder`` (on ``device``). Previously the model was
        discarded when the function returned.
    """
    dataset = EmbryoNucleiDataset(root_dir, crop_size)

    # Each dataset item is already a whole stack of crops, hence batch_size=1;
    # train_epoch strips the extra leading dimension.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    # Move the model to the device before building the optimizer over its parameters.
    model = Autoencoder(in_channels=1, downsampling_factors=[downsampling_factor] * model_depth,
                        fmaps=32, fmul=2, kernel_size=3).to(device)

    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())

    for epoch in tqdm(range(num_epochs)):
        train_epoch(dataloader, model, epoch, optimizer, loss_function)

    return model

def train_epoch(dataloader, model, epoch, optimizer, loss_function, log_image_interval = 20):
    """Run one training epoch and log progress to the global TensorBoard ``writer``.

    Each dataloader item holds one pre-built stack of crops, wrapped in an
    extra leading dimension by the DataLoader; ``[0]`` removes it. The input
    is centre-cropped to the model output's spatial size before the loss is
    computed.

    Every ``log_image_interval`` global steps, one target/prediction pair is
    logged as images. The scalar 'loss' logged per epoch is the mean batch
    loss. It was previously the last batch's loss scaled by 0.001.

    Returns:
        The mean training loss of the epoch (``nan`` if the dataloader is empty).
    """
    model.train()
    model = model.to(device)
    losses = []

    for batch_id, (raw, _mask) in enumerate(dataloader):
        raw = raw.to(device)[0]  # drop the DataLoader's batch_size=1 dimension
        optimizer.zero_grad()

        _, prediction = model(raw)
        target = _center_crop_to(raw, prediction)
        loss = loss_function(prediction, target)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

        step = epoch * len(dataloader) + batch_id
        if step % log_image_interval == 0:
            # Sample 16 as before, clamped so smaller batches still log an image.
            sample = min(16, target.shape[0] - 1)
            writer.add_images(
                tag="input",
                img_tensor=target[sample:sample + 1].to("cpu"),
                global_step=step,
            )
            writer.add_images(
                tag="prediction",
                img_tensor=prediction[sample:sample + 1].to("cpu").detach(),
                global_step=step,
            )

    mean_loss = float(np.mean(losses)) if losses else float("nan")
    print(f"Loss at Epoch {epoch} is {mean_loss}")
    writer.add_scalar('loss', mean_loss, epoch)
    return mean_loss


def _center_crop_to(tensor, reference):
    """Centre-crop the last two dimensions of ``tensor`` to those of ``reference``.

    Handles any size difference, including zero and odd differences. The old
    slice ``r//2:-r//2`` returned an empty tensor when r == 0.
    """
    height, width = reference.shape[-2], reference.shape[-1]
    top = (tensor.shape[-2] - height) // 2
    left = (tensor.shape[-1] - width) // 2
    return tensor[..., top:top + height, left:left + width]

In [8]:
train() # TODO(review): open questions - add normalization? train for longer? (TensorBoard logging already happens in train_epoch via `writer`)

  0%|                                                                                                         | 0/50 [00:00<?, ?it/s]

0


  2%|█▉                                                                                               | 1/50 [00:42<34:43, 42.53s/it]

Loss at Epoch 0 is 0.0011134577820484993
1


  4%|███▉                                                                                             | 2/50 [01:30<36:39, 45.81s/it]

Loss at Epoch 1 is 0.00022380693517334293
2


  6%|█████▊                                                                                           | 3/50 [02:16<35:58, 45.93s/it]

Loss at Epoch 2 is 0.00043085340093966805
3


  6%|█████▊                                                                                           | 3/50 [02:23<37:28, 47.84s/it]


KeyboardInterrupt: 

In [None]:
# To view runs in tensorboard you can call either (uncommented):
%reload_ext tensorboard
!tensorboard --logdir logs --port 6009

In [None]:
# Sanity check: the per-level downsampling factors passed to Autoencoder.
# List repetition builds one entry per level, e.g. [2] * 5 == [2, 2, 2, 2, 2].
[downsampling_factor] * model_depth
