In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from glob import glob
import tifffile
import torch
from torch.utils.data import Dataset

### Creating a Dataset class

In [4]:
class EmbryoNucleiDataset(Dataset):
    def __init__(self,
                 root_dir,
                 crop_size,
                ):
        
        # using root_dir, split and mask create a path to files and sort it 
        self.mask_files = natsorted(glob(os.path.join(root_dir, 'masks', 'masks*.tif'))) # load mask files into sorted list
        self.raw_files = natsorted(glob(os.path.join(root_dir, 'raw_files', 'raw*.tif'))) # load image files into sorted list
    
    def __len__(self):
        return len(self.raw_files)

    def get_centroids(self, mask):
        ids = np.unique(mask)
        ids = ids[1:]
        centroids = []
        for id in ids:
            y,x = np.where(mask == id)
            xm = int(np.mean(x))
            ym = int(np.mean(y))
            centroid = (ym, xm)
            centroids.append(centroid)
        
        centroids = np.array(centroids)
        return centroids
    
    def crop_top_left(self, coord): 
        y,x = coord
        y_top_left = int(y-(self.crop_size//2))
        x_top_left = int(x-(self.crop_size//2))
        return y_top_left, x_top_left
    
    def get_masked_crop(self, raw, mask, batch_size):
        crops_raw = []
        crops_mask = []
        centroids = get_centroids(mask)
        
        while len(crops_masks) < batch_size: 
            centroid = random.choice(centroids)
            y_top_left, x_top_left = self.crop_top_left(centroid)
            crop_mask = mask[y_top_left:y_top_left+self.crop_size, x_top_left:x_top_left+self.crop_size]
            if crop_mask.shape == (self.crop_size, self.crop_size):
                crop_mask = (crop_mask == crop_mask[int(self.crop_size//2), int(self.crop_size//2)])
                crop_raw = raw[x_top_left:x_top_left+self.crop_size, y_top_left:y_top_left+self.crop_size]
                crops_raw.append(crop_raw*crop_mask)
                crops_mask.append(crop_mask)
        return crops_raw, crops_mask 
    
    def __getitem__(self, idx):
        raw_file = self.raw_files[idx] 
        mask_file = self.mask_files[idx] 
        
        raw = imread(raw_file) # load raw to numpy array
        mask = imread(mask_file) # load mask to numpy array

        # from (H, W) mask extract (B, h, h)
        crops_raw, crops_mask = get_masked_crops(raw, mask)

        # need to cast to float32
        crops_mask = (crops_mask !=0).astype(np.float32)
        crops_raw = (crops_raw != 0).astype(np.float32)
        
        # add channel dimensions to comply with pytorch standard (B, C, H, W) 
        crops_raw = np.expand_dims(crops_raw, axis=1)
        crops_mask = np.expand_dims(crops_mask, axis=1)
        
        return crops_raw, crops_mask

### Creating Autoencoder

In [5]:
class Autoencoder(torch.nn.Module):

    def __init__(
            self,
            in_channels,
            downsampling_factors,
            fmaps,
            fmul,
            kernel_size=3):

        super(Autoencoder, self).__init__()

        out_channels = in_channels

        encoder = []

        for downsampling_factor in downsampling_factors:

            encoder.append(
                    torch.nn.Conv2d(
                        in_channels,
                        fmaps,
                        kernel_size))
            encoder.append(
                    torch.nn.ReLU(inplace=True))
            encoder.append(
                    torch.nn.Conv2d(
                        fmaps,
                        fmaps,
                        kernel_size))
            encoder.append(
                    torch.nn.ReLU(inplace=True))
            encoder.append(
                    torch.nn.MaxPool2d(downsampling_factor))

            in_channels = fmaps

            fmaps = fmaps * fmul

        fmaps_bottle = fmaps

        encoder.append(
            torch.nn.Conv2d(
                in_channels,
                fmaps_bottle,
                kernel_size))
        encoder.append(
            torch.nn.ReLU(inplace=True))

        self.encoder = torch.nn.Sequential(*encoder)

        decoder = []

        fmaps = in_channels

        decoder.append(
            torch.nn.Conv2d(
                fmaps_bottle,
                fmaps,
                kernel_size))
        decoder.append(
            torch.nn.ReLU(inplace=True))

        for downsampling_factor in downsampling_factors[::-1]:

            fmaps = in_channels / fmul

            decoder.append(
                torch.nn.Upsample(
                    scale_factor=downsampling_factor,
                    mode='trilinear'))
            decoder.append(
                torch.nn.Conv2d(
                    in_channels,
                    fmaps,
                    kernel_size))
            decoder.append(
                torch.nn.ReLU(inplace=True))
            decoder.append(
                torch.nn.Conv2d(
                    fmaps,
                    fmaps,
                    kernel_size))
            decoder.append(
                torch.nn.ReLU(inplace=True))

            in_channels = fmaps

        decoder.append(
            torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size))

        self.decoder = torch.nn.Sequential(*decoder)

    def forward(self, x):

        enc = self.encoder(x)

        dec = self.decoder(enc)

        return enc, dec


### Training Time ! 

In [6]:
# identifying params for training
batch_size = 32
crop_size = 156
num_epochs = 50
model_depth = 3
root_dir = '/mnt/efs/shared_data/instance_no_gt/20230830_TIF_cellpose_test/'

In [7]:
assert torch.cuda.is_available()
device = torch.device("cuda")

def train(args):
    print(args)
    
    # create train dataset
    dataset = EmbryoNucleiDataset(root_dir, crop_size)

    

    # create val dataset


    # create model
    model = Autoencoder(in_channels=1, downsampling_factors=[2,2,2],
        fmaps=32, fmul=2, kernel_size = 3)

    # create loss object
    loss_function = torch.nn.MSELoss()

    # create optimizer
    optimizer = torch.optim.Adam(model.parameters())
    
    for epoch in range(num_epochs):
        train_epoch(dataset, model, epoch, optimizer, loss_function)

def train_epoch(dataset,model,epoch,optimizer,loss_function):
    model.train()
    model = model.to(device)
    loss_list = []    
    for batch_id, (raw, mask) in enumerate(dataset):
        raw = raw.to(device)
        optimizer.zero_grad()
        # apply model and calculate loss
        prediction = model(raw)
        loss = loss_function(prediction, raw)
        loss_list.append(loss.item())
        # backpropagate the loss and adjust the parameters
        loss.backward()
        optimizer.step()
    print(f"Loss at Epoch {epoch} is {loss_list.mean()}")

AssertionError: 

In [8]:
torch.cuda.is_available()

False