In [1]:
import os
import torch

from monai.data import ImageDataset, DataLoader, MetaTensor
from monai.losses import DiceLoss
from torch.optim import Adam

from deeplabv3.network.modeling import _segm_resnet

# Pick the compute device once; later cells move tensors (and the model) onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"device: {device}")

device: cpu


In [2]:
# DeepLabV3+ segmentation network on a ResNet-50 backbone, requesting
# pretrained backbone weights.
model = _segm_resnet(
    backbone_name="resnet50",
    pretrained_backbone=True,
    name="deeplabv3plus",
    num_classes=2,     # number of segmentation classes predicted by the head
    output_stride=8,   # encoder output stride (DeepLab convention)
)

In [3]:
# Creating the dataset class
class OcelotDataset(ImageDataset):
  """MONAI ``ImageDataset`` for OCELOT tissue images paired with segmentation masks.

  ``decode_target`` turns a mask of integer class indices into RGB colours
  (e.g. for visualisation) using the class-level ``cmap`` palette.
  """

  # Palette read by ``decode_target``: row ``i`` is the RGB colour of class
  # index ``i``. Placeholder two-class palette (class 0 black, class 1 white)
  # sized to match the ``num_classes=2`` model; override on the class or a
  # subclass if a different palette is wanted.
  cmap = torch.tensor([[0, 0, 0], [255, 255, 255]], dtype=torch.uint8)

  @classmethod
  def decode_target(cls, mask):
    """Colour a class-index mask with ``cls.cmap``.

    Args:
      mask: class indices (tensor, array or nested sequence) of any shape;
        every value must be a valid row index into ``cls.cmap``.

    Returns:
      A tensor of shape ``mask.shape + (3,)`` holding one RGB triple per element.
    """
    # Advanced indexing needs integer indices; masks may arrive as
    # floating-point tensors, so cast before looking up colours.
    class_indices = torch.as_tensor(mask).long()
    return cls.cmap[class_indices]

In [4]:
# Directories holding the OCELOT training tissue images and their masks.
image_file_path = "ocelot_data/images/train/tissue/"
segmentation_file_path = "ocelot_data/annotations/train/tissue/"


def _sorted_file_paths(directory):
    """Return full paths of the non-hidden entries in ``directory``, sorted by path.

    Hidden entries (names starting with ".", e.g. ``.DS_Store``) are skipped so
    they are neither handed to the image reader nor allowed to shift the
    positional image/mask pairing that relies on this sort order.
    """
    return sorted(
        os.path.join(directory, file_name)
        for file_name in os.listdir(directory)
        if not file_name.startswith(".")
    )


# Images and masks are paired by position, so both lists use the same ordering.
image_files = _sorted_file_paths(image_file_path)
segmentation_files = _sorted_file_paths(segmentation_file_path)

In [5]:
# Images and masks are paired by position (both path lists are sorted).
dataset = OcelotDataset(image_files=image_files, seg_files=segmentation_files)
data_loader = DataLoader(dataset=dataset, batch_size=2)

# The model outputs one channel per class (num_classes=2). The masks are
# presumably single-channel class-index maps (TODO confirm), so the targets are
# one-hot encoded to match the softmaxed prediction channels.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = Adam(model.parameters(), lr=1e-3)

In [7]:
num_epochs = 1

# The inputs are moved to `device` below, so the model must live there too.
# Module.to() modifies the module in place, so the optimizer built earlier
# keeps referencing the same parameters.
model.to(device)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for inputs, labels in data_loader:
        # Images arrive channel-last; move the channel axis to position 1 as the
        # network expects. MetaTensor.permute keeps the attached metadata.
        inputs = inputs.permute(0, 3, 1, 2)

        # Give masks that lack a channel axis one, so targets are laid out
        # (batch, channel, spatial...) like the predictions.
        if labels.ndim == inputs.ndim - 1:
            labels = labels.unsqueeze(1)

        # TODO(review): confirm mask values are class indices in [0, num_classes)
        # and whether the pretrained backbone needs normalised inputs; neither
        # is handled here.
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Mean loss over the batches actually seen; avoids dividing by zero for an
    # empty loader.
    epoch_loss /= max(num_batches, 1)
    print(f"Epoch {epoch + 1}, Loss: {epoch_loss}")