In [41]:
import torch
from torch import nn
from torch.nn import functional as F

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, ConfusionMatrix, mIoU

from albumentations import Compose, Normalize, Resize
from albumentations.pytorch import ToTensorV2

In [42]:
from torch_semantic_segmentation.models import ENet
from torch_semantic_segmentation.data import CityScapesDataset

In [43]:
# Preprocessing applied to every sample, in order: Normalize() with its default
# arguments, Resize(512, 1024), then conversion to a tensor via ToTensorV2().
tfms = Compose([Normalize(), Resize(512, 1024), ToTensorV2()])

In [44]:
# Settings shared by both splits, kept in one place so they cannot drift apart.
DATA_ROOT = '/home/bml/datasets/cities-scapes/'
LOADER_KWARGS = dict(batch_size=8, num_workers=4)

# Training split (no shuffling is requested, so samples come in dataset order).
train_ds = CityScapesDataset(root_dir=DATA_ROOT, split='train', transforms=tfms)
train_loader = torch.utils.data.DataLoader(train_ds, **LOADER_KWARGS)

# Validation split, built with the same transforms and loader settings.
val_ds = CityScapesDataset(root_dir=DATA_ROOT, split='val', transforms=tfms)
val_loader = torch.utils.data.DataLoader(val_ds, **LOADER_KWARGS)

In [45]:
# Instantiate the ENet segmentation network.
# NOTE(review): the arguments are presumably (input channels, number of
# classes) — confirm against the ENet constructor.
model = ENet(3, 19)

In [47]:
# Restore the trained weights. map_location='cpu' makes the load independent of
# the device the checkpoint was saved from (it would otherwise require that
# exact CUDA device to exist); the model is moved to its target device later.
model.load_state_dict(
  torch.load(
    '/tmp/checkpointsjs1li3w1/enet-cityscapes_model_200_mIOU=0.3210312020186483.pth',
    map_location='cpu',
  )
)

<All keys matched successfully>

In [48]:
# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing when tensors are moved.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [49]:
# Move the model to the selected device before running inference.
model = model.to(device)

In [50]:
def accuracy(inputs, targets, ignore_index=None):
  """Fraction of positions whose arg-max prediction equals the target.

  Args:
    inputs: Score tensor; the predicted class is the arg-max along dim 1.
    targets: Label tensor with one entry per prediction.
    ignore_index: Optional label value; positions carrying it are excluded
      from the average. `None` keeps every position.

  Returns:
    Scalar float tensor holding the mean of the element-wise matches.
  """
  predictions = inputs.argmax(dim=1)

  # Without an ignore label every position counts.
  if ignore_index is None:
    return predictions.eq(targets).float().mean()

  # Otherwise compare only the positions whose label is not the ignore value.
  keep = targets.ne(ignore_index)
  return predictions[keep].eq(targets[keep]).float().mean()

def confusion_matrix(inputs, targets, num_classes):
  """Normalized confusion matrix of arg-max predictions against targets.

  Args:
    inputs: Score tensor; the predicted class is the arg-max along dim 1.
    targets: Integer label tensor with one entry per prediction. Labels
      outside [0, num_classes) (for example an ignore label) are dropped.
    num_classes: Number of valid classes.

  Returns:
    Float tensor of shape (num_classes, num_classes). Entry [t, p] is the
    fraction of valid positions with target `t` that were predicted as `p`.
    The matrix is all zeros when no position carries a valid label.
  """
  preds = torch.argmax(inputs, dim=1).flatten()
  # Promote to int64 before any arithmetic: with a narrow label dtype such as
  # uint8 the product num_classes * target wraps around and hits the wrong bin.
  targets = targets.flatten().long()

  mask = (targets >= 0) & (targets < num_classes)
  preds = preds[mask]
  targets = targets[mask]

  # Encode each (target, prediction) pair as one flat index so a single
  # bincount yields the whole matrix.
  indices = num_classes * targets + preds
  m = torch.bincount(indices, minlength=num_classes**2).reshape(num_classes, num_classes)
  # Clamp the denominator: with no valid label the division would be 0/0 and
  # the resulting NaNs would contaminate any average accumulated over batches.
  return m.float() / m.sum().clamp(min=1)

In [51]:
class MetricAverage:
  """Running mean of a metric that is evaluated once per `update` call.

  Every update contributes with equal weight, regardless of how many samples
  the metric call covered.

  Args:
    metric_fn: Callable invoked with the arguments given to `update`; its
      results must support `+` with each other and `/` by a number.
  """

  def __init__(self, metric_fn):
    self.metric_fn = metric_fn
    self.count = 0      # number of successful updates so far
    self.value = None   # running sum of metric results; None until first update

  def update(self, *args, **kwargs):
    """Evaluate the metric on the given arguments and add it to the sum."""
    # Evaluate first so a failing metric call leaves count and sum untouched.
    result = self.metric_fn(*args, **kwargs)
    self.count += 1
    if self.value is None:
      self.value = result
    else:
      # Out-of-place addition: `+=` would modify the first result object in
      # place, silently changing a tensor the caller may still hold.
      self.value = self.value + result

  def compute(self):
    """Return the mean of all metric results seen so far.

    Raises:
      ValueError: If `update` has never been called.
    """
    if self.count == 0:
      raise ValueError('MetricAverage.compute() called before any update()')
    return self.value / self.count

In [52]:
from functools import partial

In [53]:
from tqdm import tqdm

In [54]:
def evaluate(dataloader, num_classes=19, ignore_index=255):
  """Run the notebook-level `model` over `dataloader` and average its metrics.

  Uses the globals `model` and `device`. The model is switched to eval mode
  and left in that mode.

  Args:
    dataloader: Iterable yielding `(inputs, targets)` batches.
    num_classes: Number of classes used for the confusion matrix.
    ignore_index: Target value excluded from the cross-entropy loss.

  Returns:
    Tuple `(loss, cm)`: the cross-entropy loss and the confusion matrix, each
    averaged over batches. Every batch has equal weight, so a smaller final
    batch counts as much as a full one.
  """
  model.eval()
  cm = MetricAverage(partial(confusion_matrix, num_classes=num_classes))
  loss = MetricAverage(partial(F.cross_entropy, ignore_index=ignore_index))

  # Keep the metric computation inside no_grad as well: nothing here needs
  # autograd bookkeeping.
  with torch.no_grad():
    for inputs, targets in tqdm(dataloader):
      inputs = inputs.to(device)
      targets = targets.to(device)

      outputs = model(inputs)

      loss.update(outputs, targets)
      cm.update(outputs, targets)
  return loss.compute(), cm.compute()

In [55]:
# Evaluate the restored model on both splits; each call returns the
# batch-averaged loss and the batch-averaged confusion matrix.
loss_train, cm_train = evaluate(train_loader)
loss_val, cm_val = evaluate(val_loader)

100%|██████████| 372/372 [00:57<00:00,  6.42it/s]
100%|██████████| 63/63 [00:10<00:00,  6.11it/s]


In [56]:
def cm_accuracy(cm):
  """Sum of the diagonal of `cm`.

  For a confusion matrix normalized to sum to one, this is the overall
  fraction of correctly classified positions.
  """
  correct = torch.diagonal(cm)
  return correct.sum()

def cm_miou(cm):
  """Mean intersection-over-union across the classes of a confusion matrix.

  Per class, the intersection is the diagonal entry and the union is its row
  sum plus its column sum minus the diagonal entry. A tiny epsilon keeps the
  division defined for classes that never occur, which then score 0.
  """
  intersection = cm.diag()
  row_totals = cm.sum(dim=1)
  col_totals = cm.sum(dim=0)
  union = row_totals + col_totals - intersection + 1e-15
  per_class_iou = intersection / union
  return per_class_iou.mean()

In [57]:
# Batch-averaged cross-entropy loss on the train and validation splits.
loss_train, loss_val

(tensor(0.3992, device='cuda:0'), tensor(0.4387, device='cuda:0'))

In [58]:
# Train split: (accuracy, mean IoU) derived from the averaged confusion matrix.
cm_accuracy(cm_train).item(), cm_miou(cm_train).item()

(0.8786742687225342, 0.3388688266277313)

In [59]:
# Validation split: (accuracy, mean IoU) derived from the averaged confusion matrix.
cm_accuracy(cm_val).item(), cm_miou(cm_val).item()

(0.8671292066574097, 0.3206044137477875)