In [23]:
# Third-party imports used throughout the notebook.
from importlib import reload
import numpy as np
import torch
import matplotlib.pyplot as plt

# Each project module below is imported and then passed to reload(), so that
# edits made to the module sources are picked up when this cell is re-run
# without restarting the kernel. The name is (re-)bound after the reload so
# the notebook sees the refreshed object.
import lab3
reload(lab3)

import lab3.classes
reload(lab3.classes)
import lab3.classes as cs

import lab3.show
reload(lab3.show)
from lab3.show import display_img_with_masks

import lab3.trans
reload(lab3.trans)
from lab3.trans import validation_trans, train_trans, test_trans

import lab3.dataset
reload(lab3.dataset)
from lab3.dataset import FiftyOneDataset

import lab3.net
reload(lab3.net)
from lab3.net import Net

import lab1.device
reload(lab1.device)
from lab1.device import device

import lab3.util
reload(lab3.util)
from lab3.util import seconds_to_time

# Bare expression: show the selected torch device as the cell output.
device



device(type='mps')

In [24]:
from typing import Literal
import fiftyone as fo
import fiftyone.utils.openimages as fouo
import fiftyone.zoo as foz

def download(split = "train", max_samples: int = 2000):
  """Load one split of Open Images V6 through the FiftyOne dataset zoo.

  Args:
    split: name of the split to request (forwarded as `split`).
    max_samples: upper bound on the number of samples (forwarded as
      `max_samples`).

  The request asks for "segmentations" and "detections" labels, restricted to
  the classes in `cs.classes_no_background`, with "data-lab3" as the dataset
  directory and "open-images-v6-<split>" as the dataset name.

  Returns:
    Whatever `foz.load_zoo_dataset` returns for that request.
  """
  zoo_options = dict(
    split        = split,
    label_types  = ["segmentations", "detections"],
    classes      = cs.classes_no_background,
    max_samples  = max_samples,
    dataset_dir  = "data-lab3",
    dataset_name = f"open-images-v6-{split}",
  )
  return foz.load_zoo_dataset("open-images-v6", **zoo_options)

def load(split = "train"):
  """Create and set up an Open Images V6 importer for one on-disk split.

  Args:
    split: sub-directory of "data-lab3" to read from.

  Returns:
    The `fouo.OpenImagesV6DatasetImporter` instance, after `setup()` has been
    called on it. Only "segmentations" labels are requested.
  """
  split_dir = f"data-lab3/{split}"
  importer = fouo.OpenImagesV6DatasetImporter(dataset_dir = split_dir, label_types = "segmentations")
  importer.setup()
  return importer

# Downloads through the FiftyOne zoo, left commented out so that re-running
# this cell does not call download() again. Uncomment to (re)fetch the data.
# train_ds = download("train")
# valid_ds = download("validation", max_samples = 300)
# test_ds  = download("test", max_samples = 300)

# Build importers over the "data-lab3/<split>" directories. No importer is
# created for the "test" split here.
train_ds = load("train")
valid_ds = load("validation")

# Bare expression so the importer object is shown as the cell output.
train_ds #, fo.list_datasets()

<fiftyone.utils.openimages.OpenImagesV6DatasetImporter at 0x33e93bd40>

In [25]:
import torch

# Wrap the importers in the project's torch-style dataset, pairing each split
# with its own transform pipeline.
train_dataset = FiftyOneDataset(train_ds, train_trans)
valid_dataset = FiftyOneDataset(valid_ds, validation_trans)

# NOTE(review): the recorded run of the training cell ended with an MPS
# out-of-memory error using these settings; batch_size is presumably the first
# knob to lower - confirm on the target machine.
num_workers = 8
batch_size = 128

# Only the training loader shuffles; validation keeps a fixed order.
train_ld = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = True)
valid_ld = torch.utils.data.DataLoader(valid_dataset, batch_size = batch_size, num_workers = num_workers, shuffle = False)

# The second number is the size of the *validation* split (it was previously
# labelled "Test", although no test dataset is built in this notebook).
print(f'Train: {len(train_dataset)}, Valid: {len(valid_dataset)}')

Train: 808, Test: 157


In [26]:
# Empty boolean tensor of shape (0, 6, 128, 128): the "no data yet" value for
# the mask accumulators and the default for add_data's mask arguments.
ZERO = torch.zeros(0, 6, 128, 128) < 1

class Stats:
  """Accumulates boolean masks and per-batch losses, then reduces them to metrics.

  Attributes:
    ys: ground-truth masks added so far, concatenated along dim 0.
    y_hats: predicted masks added so far, concatenated along dim 0.
    loss_acum: 1-D tensor with one entry per loss value added so far.
  """

  def __init__(self):
    # State is per instance, so two Stats objects never share accumulators.
    self.init()

  def init(self):
    """Reset every accumulator to its empty state."""
    self.loss_acum = torch.tensor([])
    self.y_hats = ZERO.clone()
    self.ys = ZERO.clone()

  @staticmethod
  def _append(accumulated, batch):
    """Return `accumulated` extended by `batch` along dim 0.

    An empty operand is skipped instead of concatenated, so the accumulators
    are not tied to the trailing shape of ZERO.
    """
    if batch.numel() == 0:
      return accumulated
    if accumulated.numel() == 0:
      return batch.clone()
    return torch.cat((accumulated, batch), dim = 0)

  @staticmethod
  def _ratio(numerator, denominator):
    """Plain division that yields NaN (instead of raising) for a zero denominator."""
    return numerator / denominator if denominator else float('nan')

  def add_data(self, ys = ZERO.clone(), y_hats = ZERO.clone(), loss = torch.tensor([])):
    """Append one batch of results.

    Args:
      ys: ground-truth masks for the batch (boolean, batch on dim 0).
      y_hats: predicted masks for the batch, same shape as `ys`.
      loss: loss value(s) for the batch; a scalar (0-dim) tensor is accepted.
    """
    # Ground truth goes to self.ys and predictions to self.y_hats (the two
    # used to be stored crosswise).
    self.ys = self._append(self.ys, ys)
    self.y_hats = self._append(self.y_hats, y_hats)
    # reshape(-1): torch.cat refuses 0-dim tensors, and a loss is usually one.
    # .to(self.loss_acum) aligns dtype and device with the accumulator.
    self.loss_acum = torch.cat((self.loss_acum, loss.detach().reshape(-1).to(self.loss_acum)))

  def get_stats(self):
    """Reduce everything accumulated so far to summary metrics.

    All mask elements are pooled into one binary classification problem
    (every element is an independent True/False decision).

    Returns:
      (iou, dice, micro_f1, macro_f1, mean_loss) where the first four are
      Python floats and mean_loss is a 0-dim tensor (torch.mean of the
      accumulated losses). A ratio whose denominator is zero is NaN.
    """
    ys = self.ys.bool()
    y_hats = self.y_hats.bool()

    # Confusion counts over every mask element.
    tp = int(torch.sum(ys & y_hats))
    fp = int(torch.sum(~ys & y_hats))
    fn = int(torch.sum(ys & ~y_hats))
    tn = ys.numel() - tp - fp - fn
    errors = fp + fn

    # Jaccard = || A \intersect B || / || A \union B ||
    iou = self._ratio(tp, tp + errors)

    # DICE = 2 || A \intersect B || / (||A|| + ||B||)
    dice = self._ratio(2 * tp, 2 * tp + errors)

    # Micro-F1 over the two labels {False, True} of a single-label problem is
    # the fraction of correctly classified elements.
    micro_f1 = self._ratio(tp + tn, tp + tn + errors)

    # Macro-F1: unweighted mean of the per-label F1, taken over the labels that
    # occur in the ground truth or in the predictions. For the False label the
    # "hits" are the true negatives; fp and fn just swap roles.
    per_label_f1 = [2 * hits / (2 * hits + errors)
                    for hits in (tn, tp)
                    if 2 * hits + errors > 0]
    macro_f1 = self._ratio(sum(per_label_f1), len(per_label_f1))

    return iou, dice, micro_f1, macro_f1, torch.mean(self.loss_acum)

def run_epoch(model: Net,
              loader: torch.utils.data.DataLoader,
              loss_fn, optimizer, max_batches = None):
  """Run one pass over `loader`, training when an optimizer is given.

  Args:
    model: network to run; put in train() or eval() mode depending on the pass.
    loader: yields (images, true_masks) batches.
    loss_fn: callable following the PyTorch convention
      loss_fn(input = predictions, target = true_masks).
    optimizer: optimizer to step after every batch, or None for an
      evaluation-only pass (no gradients, no parameter updates).
    max_batches: optional cap on the number of batches processed, e.g. for a
      quick smoke test. None (default) processes the whole loader.

  Returns:
    The tuple produced by Stats.get_stats() for the batches processed.
  """
  stats = Stats()
  IS_TRAIN = optimizer is not None

  if IS_TRAIN:
    model.train()
  else:
    model.eval()

  for batch_ix, (images, true_masks) in enumerate(loader):
    if max_batches is not None and batch_ix >= max_batches:
      break

    images = images.to(device)
    true_masks = true_masks.to(device)

    if IS_TRAIN:
      # Start every step from clean gradients, whatever state was left behind.
      optimizer.zero_grad()

    # Autograd bookkeeping is only needed when we are going to backpropagate.
    with torch.set_grad_enabled(IS_TRAIN):
      predictions = model(images)
      # PyTorch losses take (input, target): predictions first. The arguments
      # used to be passed the other way round.
      loss = loss_fn(predictions, true_masks)

    if IS_TRAIN:
      loss.backward()
      optimizer.step()

    # Binarise both masks at 0.5 for the overlap metrics.
    # NOTE(review): this assumes the model output is already on a 0..1 scale;
    # if Net returns raw logits the threshold should follow a sigmoid/softmax -
    # confirm against lab3.net.
    true_bits = (true_masks > 0.5).detach().cpu()
    pred_bits = (predictions > 0.5).detach().cpu()

    # reshape(1): the loss is a scalar tensor, which cannot be concatenated
    # onto a 1-D accumulator as is.
    stats.add_data(true_bits, pred_bits, loss.detach().cpu().reshape(1))

  return stats.get_stats()


In [27]:
from datetime import datetime

def train_and_eval(model, train_ld, valid_ld, epoch_count = 10, learning_rate = 1e-3):
  """Train `model` for `epoch_count` epochs, evaluating on `valid_ld` after each one.

  Uses torch.nn.CrossEntropyLoss and an Adam optimizer with `learning_rate`.
  Progress and per-epoch metrics are printed as the run goes.

  Args:
    model: network to optimise (its parameters are handed to Adam).
    train_ld: loader for the training pass of every epoch.
    valid_ld: loader for the evaluation pass of every epoch.
    epoch_count: number of epochs to run.
    learning_rate: Adam learning rate.

  Returns:
    Ten lists with one entry per epoch, in this order:
    train_iou, valid_iou, train_dice, valid_dice, train_micro, valid_micro,
    train_macro, valid_macro, train_loss, valid_loss.
    Losses are stored as plain floats.
  """
  loss_func = torch.nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

  start_time = datetime.now()

  metric_names = ('iou', 'dice', 'micro', 'macro', 'loss')
  train_hist = {name: [] for name in metric_names}
  valid_hist = {name: [] for name in metric_names}

  def elapsed():
    """Wall-clock time since training started, formatted by seconds_to_time."""
    return seconds_to_time((datetime.now() - start_time).total_seconds())

  for epoch in range(epoch_count):
    print(f'EPOCH: {epoch}')
    train_iou, train_dice, train_micro, train_macro, train_loss = run_epoch(model, train_ld, loss_func, optimizer)
    print(f'  train      | Elapsed: {elapsed()}')

    # No optimizer for the validation pass: run_epoch treats a non-None
    # optimizer as "training", which would fit the model on validation data.
    valid_iou, valid_dice, valid_micro, valid_macro, valid_loss = run_epoch(model, valid_ld, loss_func, None)
    print(f'  valid      | Elapsed: {elapsed()}')

    # The mean loss arrives as a 0-dim tensor; keep a plain float so it prints
    # and plots like the other metrics.
    train_loss = float(train_loss)
    valid_loss = float(valid_loss)

    for name, value in zip(metric_names, (train_iou, train_dice, train_micro, train_macro, train_loss)):
      train_hist[name].append(value)
    for name, value in zip(metric_names, (valid_iou, valid_dice, valid_micro, valid_macro, valid_loss)):
      valid_hist[name].append(value)

    print(f'  Training Loss:  {train_loss},  Validation Loss:  {valid_loss}')
    print(f'  Training IoU:   {train_iou},   Validation IoU:   {valid_iou}')
    print(f'  Training Dice:  {train_dice},  Validation Dice:  {valid_dice}')
    print(f'  Training Micro: {train_micro}, Validation Micro: {valid_micro}')
    print(f'  Training Macro: {train_macro}, Validation Macro: {valid_macro}')

  return (train_hist['iou'], valid_hist['iou'],
          train_hist['dice'], valid_hist['dice'],
          train_hist['micro'], valid_hist['micro'],
          train_hist['macro'], valid_hist['macro'],
          train_hist['loss'], valid_hist['loss'])

In [28]:
def plot(train, valid, label = "IoU", ylim = None):
  """Plot a training curve (blue) and a validation curve (red) per epoch.

  Args:
    train: per-epoch values for the training set.
    valid: per-epoch values for the validation set.
    label: metric name used in the legend and on the y axis.
    ylim: optional (bottom, top) pair for the y axis. When None, the axis
      spans 0..1 unless some finite value exceeds 1 (e.g. a loss), in which
      case the top is raised so the whole curve stays visible.
  """
  train = [float(value) for value in train]
  valid = [float(value) for value in valid]

  if ylim is None:
    finite_values = [value for value in train + valid if np.isfinite(value)]
    highest = max(finite_values, default = 1.0)
    # Metrics bounded by 1 keep the fixed 0..1 window; anything larger gets a
    # 5% head-room above its maximum.
    ylim = (0.0, 1.0 if highest <= 1.0 else highest * 1.05)

  plt.clf()
  plt.plot(train, 'b', label = f'Training {label}')
  plt.plot(valid, 'r', label = f'Validation {label}')
  plt.ylim(*ylim)
  plt.xlabel('Epoch')
  plt.ylabel(label)
  plt.legend()
  plt.show()

In [29]:
# Build the network on the selected device. The first constructor argument is
# dim 0 of the first training image tensor (presumably the channel count -
# confirm against lab3.net.Net); the class count comes from lab3.classes.
model = Net(train_dataset[0][0].shape[0], num_classes = cs.num_classes).to(device)
# Count only parameters that take part in optimisation (requires_grad).
print(f'Parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

EPOCHS = 30

# Train, then draw one train-vs-validation figure per metric. The ten returned
# lists come back in (train, valid) pairs: IoU, dice, micro-F1, macro-F1, loss.
train_iou, valid_iou, train_dice, valid_dice, train_micro, valid_micro, train_macro, valid_macro, train_loss, valid_loss = train_and_eval(model, train_ld, valid_ld, epoch_count = EPOCHS, learning_rate = 1e-3)
plot(train_iou, valid_iou)
plot(train_dice, valid_dice, label = "dice")
plot(train_micro, valid_micro, label = "micro")
plot(train_macro, valid_macro, label = "macro")
plot(train_loss, valid_loss, label = "loss")


Parameter count: 1,928,582
EPOCH: 0


Python(2229) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2230) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2231) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2233) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2234) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2235) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2238) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(2239) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


RuntimeError: MPS backend out of memory (MPS allocated: 17.85 GB, other allocations: 258.83 MB, max allowed: 18.13 GB). Tried to allocate 256.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Pickles the whole module object (not just its state_dict), so loading it
# later requires the lab3.net class definitions to be importable.
torch.save(model, 'lab3/net_attempt2.pth')

In [None]:
# The checkpoint holds a fully pickled module, so it must be loaded with
# weights_only = False (recent PyTorch releases default to True and would
# refuse it). Unpickling executes code: only load files produced by this
# notebook. map_location places the tensors on the current device directly.
model = torch.load('lab3/net_attempt2.pth', map_location = device, weights_only = False).to(device)
# Inference mode for layers that behave differently while training (dropout,
# batch norm), if the network has any.
model.eval()

img, true_mask = valid_dataset[40]
y = true_mask
# Single-image batch; no autograd graph is needed for a forward-only pass.
with torch.no_grad():
  mask = model(img.unsqueeze(0).to(device)).cpu().detach().squeeze(0)
# mask = mask < .9

# aggregated_mask = lab3.dataset.aggregate_detections(mask)

# Left: image with the ground-truth mask. Right: same image with the model output.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize= (8, 3))
display_img_with_masks(ax1, img, true_mask)
display_img_with_masks(ax2, img, mask)