In [None]:
import os
import glob
import torch
import pickle
import random
import numpy as np
import matplotlib.pyplot as plt

from torch import nn
from PIL import Image
from tqdm import tqdm
from torch.nn import functional as F
from torchvision.transforms import v2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split

## Dataloader and Supporting Functions

In [None]:
def compute_single_iou(prediction,mask):
  """Share of non-background pixels on which prediction and mask agree.

  Both inputs are converted to int tensors and flattened. A pixel is ignored
  when it is 0 (background) in both inputs; every other pixel counts towards
  the total, and counts as "right" only when the two labels are equal.

  NOTE: a later definition of `compute_single_iou` in this notebook replaces
  this one when the cells run in order.

  Args:
      prediction: array-like of predicted class ids (tensor, ndarray or nested list).
      mask: array-like of ground-truth class ids with the same number of elements.

  Returns:
      float: right / total, or 0.0 when every pixel is background in both inputs.
  """
  # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor inputs
  mask = torch.as_tensor(mask).int().reshape(-1)
  prediction = torch.as_tensor(prediction).int().reshape(-1)

  # A pixel takes part in the score unless both inputs call it background.
  counted = (mask != 0) | (prediction != 0)
  total = int(counted.sum())
  if total == 0:
    # Nothing to score; the per-pixel loop used to divide by zero here.
    return 0.0

  right = int(((prediction == mask) & counted).sum())
  return right / total

def compute_single_iou(prediction, mask):
  """
  Calculates Intersection-over-Union (IoU) for a single prediction and mask.

  Label 0 is treated as background. The intersection counts pixels where both
  inputs carry the same non-zero label. The union is
  (#non-zero prediction pixels) + (#non-zero mask pixels) - intersection, so a
  pixel that is non-zero in both inputs but with different labels contributes
  twice to the union.

  Args:
      prediction: Array-like (torch tensor, numpy array or nested list) of predicted labels.
      mask: Array-like of ground truth labels with the same number of elements.

  Returns:
      A float value representing the IoU between the prediction and the mask
      (0.0 when the union is empty).
  """

  # as_tensor accepts tensors, arrays and lists alike and, unlike torch.tensor(),
  # does not copy (or warn) when the input already is a tensor.
  prediction = torch.as_tensor(prediction).reshape(-1)
  mask = torch.as_tensor(mask).reshape(-1)

  # Intersection
  intersection = torch.sum((prediction == mask) & (prediction != 0))

  # Union
  union = torch.sum(prediction != 0) + torch.sum(mask != 0) - intersection

  # IoU (avoid division by zero)
  if union.item() == 0:
    return 0.0

  # float32 division keeps the numeric result identical to the previous implementation
  return (intersection.float() / union.float()).item()

def get_number(filename):
  """Return the integer found in the second '_'-separated token of `filename`.

  Everything from the first '.' of that token onwards (e.g. the extension) is
  discarded before the integer conversion.
  """
  second_token = filename.split("_")[1]
  digits, _, _ = second_token.partition(".")
  return int(digits)


def encode_labels(mask, num_classes=41):
    """One-hot encode a batch of class-index masks.

    Args:
        mask: array-like indexed as (batch, height, width) holding class ids.
            Anything `np.asarray` accepts works (numpy arrays, CPU tensors).
        num_classes: number of channels to produce; class ids outside
            `range(num_classes)` yield an all-zero channel vector.

    Returns:
        np.ndarray of dtype float64 and shape (batch, num_classes, height, width)
        where entry [i, k, y, x] is 1.0 iff mask[i, y, x] == k.
    """
    mask = np.asarray(mask)
    class_ids = np.arange(num_classes).reshape(1, num_classes, 1, 1)

    # Broadcasting compares every pixel of every sample with every class id in
    # one shot. The previous loop returned from inside its first iteration and
    # therefore encoded only the first sample of the batch.
    return (mask[:, np.newaxis, :, :] == class_ids).astype(np.float64)

def decode_prediction(prediction_mask):
    """Collapse per-class scores into a map of winning class indices.

    Args:
        prediction_mask: array-like whose first axis enumerates the classes.

    Returns:
        The index of the maximum along axis 0 (one entry per remaining position).
    """
    # The zero-filled placeholder that used to be allocated here was overwritten
    # immediately, and its shape lookup made the function reject 1-D input.
    return np.argmax(prediction_mask, axis=0)

# Preprocessing pipeline applied to image data before it is fed to the model.
# The second step converts to float32; scale=True asks torchvision to rescale
# the value range during that conversion (presumably 0-255 -> 0-1 for uint8
# input - confirm against the torchvision docs).
# NOTE(review): recent torchvision releases reportedly deprecate v2.ToTensor -
# confirm against the installed version before relying on it.
transforms = v2.Compose([
    v2.ToTensor(),
    v2.ToDtype(torch.float32, scale=True),
])

class SegmentationDataset(Dataset):
  """Frame-level dataset built from a directory of per-video folders.

  Every sub-folder of `root_dir` is treated as one video: its PNG files
  (ordered with `get_number`) are the frames and, in training mode, its
  `mask.npy` holds one mask per frame.
  """

  def __init__(self, root_dir="dataset/train", transform=None, train=True,
               mask_shape=(22, 160, 240)):
    """
    Args:
      root_dir (str): Directory that contains one sub-folder per video.
      transform (callable, optional): Applied to every image that is returned. Defaults to None.
      train (bool): When True, `mask.npy` is loaded for each video and items are
        (image, mask) pairs; when False, items are images only.
      mask_shape (tuple): Expected shape of each `mask.npy`, frames first.
        Videos whose mask has another shape, or whose number of PNG frames
        differs from the number of masks, are skipped entirely and their
        folder names recorded in `self.skipped_videos`.
    """
    self.train = train
    self.transform = transform  # Optional transformation for data
    self.skipped_videos = []

    mask_shape = tuple(mask_shape)
    frame_paths = []
    video_masks = []

    # sorted() makes the sample order independent of the file system's listing order
    for video in sorted(os.listdir(root_dir)):
      video_dir = os.path.join(root_dir, video)
      if "DS_Store" in video or not os.path.isdir(video_dir):
        continue

      frames = sorted(
        (name for name in os.listdir(video_dir) if name.lower().endswith(".png")),
        key=get_number,
      )

      if self.train:
        mask = np.load(os.path.join(video_dir, "mask.npy"))
        # Drop frames and masks together. Filtering only the masks (as was
        # done before) shifts every later image onto the wrong mask.
        if mask.shape != mask_shape or len(frames) != mask.shape[0]:
          self.skipped_videos.append(video)
          continue
        video_masks.append(mask)

      frame_paths.extend(os.path.join(video_dir, name) for name in frames)

    if self.train:
      if video_masks:
        # (videos, frames, H, W) -> (videos * frames, H, W)
        self.masks = np.concatenate(video_masks, axis=0)
      else:
        self.masks = np.empty((0,) + mask_shape[1:])
    else:
      self.masks = None
    self.all_paths = np.array(frame_paths, dtype=str)

  def __len__(self):
    """
    Returns the number of frames in the dataset.
    """
    return len(self.all_paths)

  def __getitem__(self, index):
    """
    Args:
      index (int): Index of the frame to return.

    Returns:
      (image, mask) in training mode, otherwise just the image. The image is a
      uint8 numpy array unless `transform` changes it.

    Raises:
      IndexError: if `index` is out of range. Out-of-range indices are no
        longer mapped to item 0, which made plain iteration never terminate.
    """
    path = self.all_paths[index]

    # The context manager closes the file handle once the pixels are copied out.
    with Image.open(path) as frame:
      # int32 -> uint8 round trip kept exactly as in the original pixel handling
      image = np.asarray(frame, dtype='int32').astype(np.uint8)

    # Apply transformation if defined (in both training and inference mode)
    if self.transform:
      image = self.transform(image)

    if self.train:
      return (image, self.masks[index])
    return image


## Model Definition

In [None]:
# u-net model definition
class UNet(nn.Module):
    """Small two-level U-Net for per-pixel classification.

    The encoder halves the resolution twice with max pooling; the decoder
    restores it with two stride-2 transposed convolutions and concatenates the
    matching encoder feature maps (skip connections). Height and width should
    therefore be divisible by 4, otherwise the concatenations see mismatched
    sizes.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of classes, i.e. channels of the returned score map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Encoder path
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        # Decoder path (conv7 / conv9 take the upsampled map concatenated with a skip connection)
        self.up6 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.conv7 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(128)
        self.conv8 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.up8 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.conv9 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn9 = nn.BatchNorm2d(64)
        self.conv10 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv11 = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Compute raw class scores.

        Args:
            x: batch indexed as (batch, height, width, channels); it is
               permuted to channels-first and cast to float here.

        Returns:
            Tensor of shape (batch, out_channels, height, width) with
            unnormalised scores (logits).
        """
        # channels-last -> channels-first, as expected by nn.Conv2d
        x = x.permute(0, 3, 1, 2).float()

        # Encoder level 1
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.conv2(x))
        encoder1 = x
        x = self.pool(x)

        # Encoder level 2
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.conv4(x))
        encoder2 = x
        x = self.pool(x)

        # Bottleneck
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.conv6(x))

        # Decoder level 2
        x = self.up6(x)
        x = torch.cat([x, encoder2], dim=1)
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.conv8(x))

        # Decoder level 1
        x = self.up8(x)
        x = torch.cat([x, encoder1], dim=1)
        x = F.relu(self.bn9(self.conv9(x)))
        x = F.relu(self.conv10(x))

        # No activation on the classification head: nn.CrossEntropyLoss expects
        # unnormalised logits, and a ReLU here would clamp every negative score to 0.
        return self.conv11(x)

## Training

#### Setup

In [None]:
# seed every RNG in use so that Restart & Run All reproduces the same run
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# declare device for running on gpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on {device}")

# hyperparams
epochs = 10
batch_size = 8
model = UNet(3,41)
model.to(device)
model_folder = "models"
os.makedirs(model_folder, exist_ok=True)  # checkpoints and metrics are written here

# learning params
criterion = nn.CrossEntropyLoss()  # multi-class cross entropy over the 41 output channels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# dataset (both splits load masks: validation needs them for loss and IoU)
train_dataset = SegmentationDataset(root_dir='../dataset/train', train=True)
val_dataset = SegmentationDataset(root_dir='../dataset/val', train=True)
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,  # Adjust batch size as needed
                              shuffle=True)

# validation metrics are averages over the whole split, so shuffling adds nothing
val_dataloader = DataLoader(val_dataset,
                            batch_size=batch_size,  # Adjust batch size as needed
                            shuffle=False)

#### Training

In [None]:
# Per-epoch history: mean training loss, mean validation loss,
# mean validation IoU and mean training IoU.
total_loss = []
total_val_loss = []
total_val_iou = []
total_iou = []

# torch.save / pickle below need the target folder to exist
os.makedirs(model_folder, exist_ok=True)

for epoch in range(epochs):
  # ---- training ----
  model.train()
  losses = []
  ious = []
  print(f"Epoch {epoch} training")
  for data, label in tqdm(train_dataloader):
    # Forward pass (the loader already yields tensors, no re-wrapping needed)
    new_data = transforms(data).to(device)
    logits = model(new_data)
    # CrossEntropyLoss takes class-index targets directly, so no one-hot
    # encoding is required. Assumes mask values are class ids below the number
    # of model output channels - TODO confirm for the dataset.
    loss = criterion(logits, label.long().to(device))

    # Backward pass and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Winning class per pixel, one (H, W) map per sample
    prediction = logits.detach().argmax(dim=1).cpu().numpy()
    for p, l in zip(prediction, label.cpu().numpy()):
      ious.append(compute_single_iou(p, l))  # per-sample IoU, not the whole batch repeated

  print(f"Train Mean Loss: {np.mean(losses)}")

  # ---- validation ----
  print(f"Epoch {epoch} validation")
  val_losses = []
  val_ious = []
  model.eval()
  with torch.no_grad():
    for val_data, val_label in tqdm(val_dataloader):
      val_logits = model(transforms(val_data).to(device))
      loss = criterion(val_logits, val_label.long().to(device))
      val_losses.append(loss.item())
      val_prediction = val_logits.argmax(dim=1).cpu().numpy()
      for p, l in zip(val_prediction, val_label.cpu().numpy()):
        val_ious.append(compute_single_iou(p, l))

  # save epoch metrics (one scalar per epoch for every series)
  total_loss.append(np.mean(losses))
  total_val_loss.append(np.mean(val_losses))
  total_val_iou.append(np.mean(val_ious))
  total_iou.append(np.mean(ious))

  print(f"Validation Mean Loss: {np.mean(val_losses)}")
  print(f"Validation Mean IOU: {np.mean(val_ious)}")

  torch.save(model.state_dict(),f'{model_folder}/model_{epoch}.pth') # save model every epoch

  # visualize results for the first sample of the last training batch
  print(f"Epoch: {epoch+1}, Loss: {np.average(losses):.4f}")
  plt.imshow(data[0].cpu().numpy())
  plt.title("Original Image")
  plt.show()
  plt.imshow(prediction[0])  # already decoded to class ids above
  plt.title("Predicted Mask")
  plt.show()

# binary write mode and an explicit file object are both required by pickle.dump
with open(f"{model_folder}/metrics.pkl", "wb") as f:
  pickle.dump([total_loss,total_val_loss,total_iou], f)

#### Validation Loop

In [None]:
# Smoke test: push a single validation batch through the model.
model.eval()

with torch.no_grad():
    batch = next(iter(val_dataloader))
    # The loader yields (image, mask) pairs when the dataset carries masks and
    # bare image batches otherwise; handle both.
    image = batch[0] if isinstance(batch, (list, tuple)) else batch
    print(image.shape)
    # Same preprocessing as in the training loop
    prediction = model(transforms(image).to(device))


## Visualize Results on Saved Model (later epochs not necessarily better)

In [None]:
model = UNet(3,41)
# map_location lets a checkpoint written on a GPU load on a CPU-only machine;
# weights_only=True limits unpickling to tensors and plain containers.
state_dict = torch.load("models/model_6.pth", map_location=device, weights_only=True) # change to reflect model path
model.load_state_dict(state_dict)
model.to(device)
model.eval()

In [None]:
# Items are (image, mask) pairs, so inspect each shape separately. A fixed index
# is used because `i` is only assigned further down in the notebook.
sample_image, sample_mask = val_dataset[0]
sample_image.shape, sample_mask.shape

In [None]:
# Displays whatever `data` currently holds; when the cells run in order this is
# the last batch bound by the training loop.
data

In [None]:
i = random.randint(0,len(val_dataset)-1) # random index in dataset
data, mask = val_dataset[i] # fetch the sample once (each access re-opens the image file)
mask = torch.tensor(mask)
new_data = transforms(torch.tensor(data)).to(device).unsqueeze(0) # transform and fit to expected size for model

# inference only: no autograd graph needed
with torch.no_grad():
    logits = model(new_data)[0].cpu()
# wrap as a tensor so later cells can keep using tensor methods on `prediction`
prediction = torch.as_tensor(decode_prediction(logits.numpy()))

plt.imshow(data)
plt.title("Input Image")
plt.show()
plt.imshow(prediction)
plt.title("Predicted Mask")
plt.show()
plt.imshow(mask)
plt.title("Ground Truth Mask")
plt.show()

In [None]:
# Pixel-level agreement between `mask` and `prediction`, ignoring pixels that
# are background (0) in both.
mask_flat = torch.as_tensor(mask).reshape(-1)
prediction_flat = torch.as_tensor(prediction).reshape(-1)

counted = (mask_flat != 0) | (prediction_flat != 0)  # pixels that take part in the score
total = int(counted.sum())
right = int(((prediction_flat == mask_flat) & counted).sum())
wrong = total - right

# Guard the all-background case instead of dividing by zero.
print(right / total if total else 0.0)



In [None]:
# Number of pixels counted by the agreement check above (every pixel that is
# not 0 in both mask and prediction).
total

In [None]:
# Number of counted pixels on which mask and prediction disagree.
wrong

In [None]:
# NOTE(review): `prediction_active` is not assigned in any cell visible in this
# notebook, so this raises NameError on a fresh kernel - confirm whether the
# defining cell was deleted, and restore or remove this cell accordingly.
prediction_active

In [None]:
# How many mask pixels carry a label greater than 0
(mask > 0).sum()

In [None]:
# How many predicted pixels carry a label greater than 0
(prediction > 0).sum()

In [None]:
# Shape of the ground-truth mask tensor for the current sample
mask.shape

In [None]:
# Re-fetch sample `i` with a single dataset access; indexing the dataset twice
# opens and decodes the image file twice.
data, mask = val_dataset[i] # image and mask of the same sample
mask = torch.tensor(mask)

In [None]:
# Shape of the image array for the current sample
data.shape

In [None]:
# Number of non-background (non-zero) pixels in the mask
(mask != 0).sum()

In [None]:
# Number of non-background (non-zero) pixels in the prediction
(prediction != 0).sum()

In [None]:
# `intersection` only exists as a local inside compute_single_iou, so compute it
# here with the same formula: pixels where both agree on a non-zero label.
intersection = torch.sum((prediction == mask) & (prediction != 0))
intersection

In [None]:
# `union` only exists as a local inside compute_single_iou, so compute it here
# with the same formula: non-zero prediction pixels + non-zero mask pixels
# - agreeing non-zero pixels.
union = (
    torch.sum(prediction != 0)
    + torch.sum(mask != 0)
    - torch.sum((prediction == mask) & (prediction != 0))
)
union

In [None]:
# `mask` is already a tensor at this point; torch.tensor(mask) would
# copy-construct it and emit a UserWarning, so take an explicit detached copy.
mask.clone().detach()