In [1]:
import os
import time

import numpy as np
import torch
from monai.data import ArrayDataset, DataLoader, MetaTensor
from monai.losses import DiceLoss
from torch.optim import Adam

from deeplabv3.network.modeling import _segm_resnet
from utils import (
    read_data,
    get_metadata,
    get_cell_annotations_in_tissue_coordinates,
    crop_and_upscale_tissue,
    get_image_tensor_from_path,
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"device: {device}")

device: cpu


In [None]:
# Keyword arguments for the segmentation network builder, kept in one place
# so the architecture choices are easy to spot and tweak.
segmentation_config = {
    "name": "deeplabv3plus",
    "backbone_name": "resnet50",
    "num_classes": 3,
    "output_stride": 8,
    "pretrained_backbone": True,
}
model = _segm_resnet(**segmentation_config)
# Move the parameters to the selected device; as the last expression of the
# cell this also displays the model.
model.to(device)

In [None]:
class OcelotCellDataset(ArrayDataset):
    """Image/label dataset that is fully preprocessed once, at construction.

    Every sample is read through the parent ``ArrayDataset``, its label values
    are remapped, and the results are stored as two stacked tensors (moved to
    ``device`` when one is given). ``__getitem__`` is then a plain index lookup
    into those tensors.
    """

    @classmethod
    def decode_target(cls, target):
        """Return the index of the largest value along dimension 1 of ``target``.

        Note that ``argmax`` has no gradient, so the result is suitable for
        inspection/inference but not as the input of a loss to be optimised.
        """
        return target.argmax(1)

    def __init__(self, cell_tensors, cell_annotations, device=None):
        """Register the raw data with the parent class and preprocess it.

        Args:
            cell_tensors: indexable collection of images, passed to the parent
                as ``img``.
            cell_annotations: indexable collection of label maps, passed to
                the parent as ``seg``.
            device: optional device the preprocessed tensors are moved to.
        """
        super().__init__(img=cell_tensors, seg=cell_annotations)
        self.device = device
        self.processed_images, self.processed_labels = self.process_data()

    def process_data(self):
        """Remap labels and reorder image axes for every sample.

        Label values 1 and 255 become 0 and label value 2 becomes 1; all other
        values are left as they are. Each image has its last axis moved to the
        front via ``permute((2, 0, 1))``.

        Returns:
            Tuple ``(images, labels)`` of stacked tensors, on ``self.device``
            when it is not ``None``.
        """
        processed_images = []
        processed_labels = []
        for index in range(len(self)):
            image, label = super().__getitem__(index)
            # Work on a private copy: the parent returns a view of the
            # caller's annotation tensor, and remapping that in place would
            # corrupt the source data (a second construction would then map
            # the freshly written 1s to 0 and erase the foreground class).
            label = torch.as_tensor(label).clone()
            to_background = (label == 1) | (label == 255)
            label[to_background] = 0
            label[label == 2] = 1
            # torch.stack below allocates new storage, so a view is enough here.
            processed_images.append(torch.as_tensor(image).permute((2, 0, 1)))
            processed_labels.append(label)

        images = torch.stack(processed_images)
        labels = torch.stack(processed_labels)
        if self.device is not None:
            images = images.to(self.device)
            labels = labels.to(self.device)
        return images, labels

    def __getitem__(self, index):
        """Return the preprocessed ``(image, label)`` pair at ``index``."""
        return self.processed_images[index], self.processed_labels[index]

    def __len__(self):
        """Return the number of samples as reported by the parent dataset."""
        return super().__len__()


In [None]:
# Data folder handed to both loader helpers below.
folder_path = "ocelot_data/"
# Metadata is read first, then the per-sample data (same order as before).
metadata, data = get_metadata(folder_path), read_data(folder_path)

In [None]:
# Folder holding one "<sample key>.png" annotation file per entry of `data`.
segmented_cell_folder = "ocelot_data/annotations/train/segmented_cell/"

# Visit the keys in sorted order so the stacked annotations line up with any
# other tensor that is built by iterating `data` the same way.
annotation_paths = [
    os.path.join(segmented_cell_folder, f"{sample_key}.png")
    for sample_key in sorted(data.keys())
]
cell_annotations = [get_image_tensor_from_path(path) for path in annotation_paths]

# One tensor with a new leading sample dimension.
cell_annotations_tensor = torch.stack(cell_annotations)

In [None]:
# For every sample: crop/upscale its tissue annotation and append the result
# to the cell image, then stack all samples into one tensor.
cell_channels_with_tissue_annotations = []
image_size = 1024  # multiplier applied to the stored x/y offsets

# `sample_id` rather than `id`: a loop variable named `id` would stay in the
# kernel namespace and shadow the builtin `id()` for every later cell.
for sample_id in sorted(data.keys()):
    data_object = data[sample_id]

    # Offsets are scaled by image_size before being passed on
    # (presumably fractions of the image extent converted to pixels — TODO confirm).
    offset_tensor = torch.tensor([data_object["x_offset"], data_object["y_offset"]]) * image_size
    # Ratio of the two stored "mpp" values, handed to the crop helper as the scale.
    scaling_value = data_object["cell_mpp"] / data_object["tissue_mpp"]
    tissue_tensor = data_object["tissue_annotated"]
    cell_tensor = data_object["cell_image"]

    cropped_scaled = crop_and_upscale_tissue(tissue_tensor, offset_tensor, scaling_value)
    # NOTE(review): concatenation is along dim 0, while OcelotCellDataset later
    # applies permute((2, 0, 1)) to each sample — confirm which axis holds the
    # channels and that dim 0 is the intended one.
    cell_tensor_tissue_annotation = torch.cat([cell_tensor, cropped_scaled], 0)

    cell_channels_with_tissue_annotations.append(cell_tensor_tissue_annotation)

tissue_crops_scaled_tensor = torch.stack(cell_channels_with_tissue_annotations)

In [None]:
# Preprocess every sample up front and keep the resulting tensors on `device`.
dataset = OcelotCellDataset(
    cell_tensors=tissue_crops_scaled_tensor,
    cell_annotations=cell_annotations_tensor,
    device=device,
)

# Two samples per batch.
data_loader = DataLoader(dataset=dataset, batch_size=2)

# Dice loss with its default options: no activation or one-hot conversion is
# enabled, so it compares prediction and target exactly as they are passed in.
loss_function = DiceLoss()

# Adam over all model parameters.
optimizer = Adam(model.parameters(), lr=1e-3)

In [None]:
num_epochs = 1

# Argmax decoder, kept available for turning predictions into class maps when
# inspecting results. It must NOT sit between the model and the loss: argmax
# has no gradient, so a loss computed on its output cannot update the model.
decode_fn = data_loader.dataset.decode_target

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    start = time.time()

    for inputs, labels in data_loader:
        optimizer.zero_grad()

        # Raw per-class scores; class dimension is 1 (the same dimension
        # decode_target reduces over).
        logits = model(inputs)
        # Differentiable class probabilities keep the graph connected to the
        # model parameters, so loss.backward() produces real gradients.
        probabilities = torch.softmax(logits, dim=1)

        # assumes labels are integer class maps shaped (B, H, W) or
        # (B, 1, H, W) with values below the number of classes — TODO confirm
        if labels.ndim == logits.ndim:
            labels = labels.squeeze(1)
        targets = torch.nn.functional.one_hot(labels.long(), num_classes=logits.shape[1])
        # one_hot appends the class axis last; move it next to the batch axis
        # so the target has the same layout as the probabilities.
        targets = targets.movedim(-1, 1).to(probabilities.dtype)

        loss = loss_function(probabilities, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1
        print("One iteration")

    end = time.time()
    # Mean loss per batch; the guard avoids a ZeroDivisionError on an empty loader.
    epoch_loss /= max(num_batches, 1)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}, epoch_time: {end-start:.2f} seconds")