In [None]:
def validate(model, val_dataloader, loss_fn, num_classes=14):
  """Run one pass over a validation set and return batch-averaged metrics.

  The model is switched to eval mode and evaluated under
  ``torch.inference_mode()``. It is left in eval mode on return, so callers
  must call ``model.train()`` themselves before resuming training.

  Args:
    model: A ``torch.nn.Module``. Each batch's inputs and targets are moved
      to the device of the model's first parameter.
    val_dataloader: Iterable yielding ``(inputs, targets)`` batches.
    loss_fn: Callable invoked as ``loss_fn(outputs, targets.long())``; its
      result must support ``.item()`` (a single-element tensor).
    num_classes: Number of classes forwarded to the IoU, precision and recall
      helpers. Defaults to 14, the value previously hard-coded here.

  Returns:
    Tuple ``(loss, accuracy, iou, f1, precision, recall)`` -- note that F1
    comes before precision. Every entry except F1 is the unweighted mean of
    the per-batch values, so a smaller final batch counts as much as a full
    one. F1 is the harmonic mean of the averaged precision and recall, not
    an average of per-batch F1 scores.

  Raises:
    ValueError: If ``val_dataloader`` yields no batches.
  """
  device = next(model.parameters()).device

  val_loss = 0.0
  val_accuracy = 0.0
  val_iou = 0.0
  val_precision = 0.0
  val_recall = 0.0
  # Count batches as they are consumed instead of trusting len(), which may
  # be missing or only an estimate for iterable-style loaders.
  num_batches = 0

  model.eval()
  with torch.inference_mode():
    for inputs, targets in tqdm(val_dataloader):
      inputs, targets = inputs.to(device), targets.to(device)

      outputs = model(inputs)
      # loss_fn gets int64 targets; the metric helpers get targets unchanged.
      loss = loss_fn(outputs, targets.long())
      val_loss += loss.item()

      val_accuracy += calculate_accuracy(outputs, targets)
      val_iou += calculate_iou(outputs, targets, num_classes=num_classes)
      val_precision += calculate_precision(outputs, targets, num_classes)
      val_recall += calculate_recall(outputs, targets, num_classes)

      num_batches += 1

  if num_batches == 0:
    raise ValueError("val_dataloader yielded no batches; cannot average metrics")

  mean_val_loss = val_loss / num_batches
  mean_val_accuracy = val_accuracy / num_batches
  mean_val_iou = val_iou / num_batches
  mean_val_precision = val_precision / num_batches
  mean_val_recall = val_recall / num_batches

  denominator = mean_val_precision + mean_val_recall
  try:
    mean_val_f1 = 2 * (mean_val_precision * mean_val_recall) / denominator
  except ZeroDivisionError:
    # Precision and recall both averaged to exactly 0 as Python numbers:
    # report F1 as 0 (the usual convention) instead of crashing. If the
    # helpers return tensors/arrays, the division does not raise and the
    # result is NaN instead.
    mean_val_f1 = 0.0

  return (mean_val_loss, mean_val_accuracy, mean_val_iou, mean_val_f1,
          mean_val_precision, mean_val_recall)