In [1]:
import os
import torch
import numpy as np

from monai.data import ImageDataset, DataLoader, MetaTensor
from monai.losses import DiceLoss
from torch.optim import Adam

from deeplabv3.network.modeling import _segm_resnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

device: cuda


  This value defaults to True when PyTorch version in [1.7, 1.11] and may affect precision.
  See https://docs.monai.io/en/latest/precision_accelerating.html#precision-and-accelerating


In [2]:
model = _segm_resnet(
    name="deeplabv3plus",
    backbone_name="resnet50",
    num_classes=2,
    output_stride=8,
    pretrained_backbone=True,
)
model.to(device)


DeepLabV3(
  (backbone): IntermediateLayerGetter(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Se

In [3]:
class OcelotDataset(ImageDataset):

    @classmethod
    def decode_target(cls, target):
        return target.argmax(1)
  
    def __init__(self, image_files, seg_files, device=None):
        super().__init__(image_files, seg_files)  # assuming the super class requires these arguments
        self.device = device
        a, b = self.process_data()
        self.processed_images = "hei"
        self.processed_labels = "hei"
        print(f"a: {a}")
        print(f"b: {b}")

    def process_data(self):
        processed_images = []
        processed_labels = []
        for index in range(len(self)):
            image, label = super().__getitem__(index)
            label[np.logical_or(label == 1, label == 255)] = 0  # Set pixels with label 255 to 1
            label[label == 2] = 1  # Set pixels with label 255 to 1
            processed_images.append(torch.tensor(image))
            processed_labels.append(torch.tensor(label))
        return torch.stack(processed_images).to(self.device), torch.stack(processed_labels).to(self.device)
    
    def __getitem__(self, index):
        return self.processed_images[index], self.processed_labels[index]

    def __len__(self):
        return super().__len__()


In [4]:
image_file_path = "ocelot_data/images/train/tissue/"
segmentation_file_path = "ocelot_data/annotations/train/tissue/"


image_files = [
    os.path.join(image_file_path, file_name)
    for file_name in os.listdir(image_file_path)
]
image_files.sort()

segmentation_files = [
    os.path.join(segmentation_file_path, file_name)
    for file_name in os.listdir(segmentation_file_path)
]
segmentation_files.sort()

In [5]:
dataset = OcelotDataset(image_files=image_files, seg_files=segmentation_files, device=device)

data_loader = DataLoader(dataset=dataset, batch_size=2)

loss_function = DiceLoss(
  # softmax=True
  )
optimizer = Adam(model.parameters(), lr=1e-3)

  processed_images.append(torch.tensor(image))
  processed_labels.append(torch.tensor(label))


: 

: 

In [14]:
num_epochs = 1
decode_fn = data_loader.dataset.decode_target

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0

    for (inputs, labels) in data_loader:
        # Transposing the inputs to fit the expected shape
        inputs_tensor = torch.Tensor(inputs)
        inputs_tensor = inputs_tensor.permute((0, 3, 1, 2))
        inputs = MetaTensor(inputs_tensor, meta=inputs.meta)

        # Continuing with the regular training loop
        # inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        outputs = decode_fn(outputs).to(torch.float32)
        outputs.requires_grad = True

        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        loop_loss = loss.item()
        print(f"Loop loss: {loop_loss}")
        epoch_loss += loop_loss
      
    epoch_loss /= len(data_loader)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")

Loop loss: 0.40607723593711853
Loop loss: 0.3937247395515442
Loop loss: 0.12468582391738892
Loop loss: 0.2804950177669525
Loop loss: 0.22310444712638855
Loop loss: 0.3182410001754761
Loop loss: 0.24973900616168976
Loop loss: 0.34927189350128174
