# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
from tqdm import tqdm
import numpy as np


def calculate_accuracy(pred, target):
  """Return the percentage of elements whose predicted class equals ``target``.

  The predicted class is the index of the largest value of ``pred`` along
  dim 1. It is compared element-wise with ``target``, and the share of
  matches is returned as a percentage in [0, 100].

  Args:
    pred: per-class scores with the class dimension at dim 1; presumably
      (N, C, ...) logits — confirm against the caller.
    target: class indices; expected to match ``pred`` with dim 1 removed
      (``==`` broadcasts, so mismatched shapes would give a wrong count).

  Returns:
    Accuracy as a Python float percentage.

  Raises:
    ZeroDivisionError: if ``target`` has no elements.
  """
  class_ids = pred.argmax(dim=1)
  n_correct = torch.eq(class_ids, target).sum().item()
  n_total = target.numel()
  return n_correct / n_total * 100

from torchmetrics import F1Score
def calculate_f1(pred, target, num_classes, average="weighted"):
  """Compute the multiclass F1 score of one batch with torchmetrics.

  A fresh ``F1Score`` metric is built on every call, so the value covers only
  the tensors passed in (nothing accumulates across calls).

  Args:
    pred: model predictions; presumably per-class scores with the class
      dimension at dim 1, as ``calculate_accuracy`` expects — confirm.
    target: ground-truth class indices.
    num_classes: number of classes.
    average: torchmetrics averaging mode ("micro", "macro", "weighted" or
      "none"). Defaults to "weighted" to match ``calculate_iou``,
      ``calculate_precision`` and ``calculate_recall``.

  Returns:
    The F1 value produced by the torchmetrics metric (a tensor).

  Note:
    Relies on a module-level ``device`` defined outside this function; the
    metric and both inputs are moved there before scoring.
  """
  # Pass `average` explicitly. Without it the torchmetrics F1Score wrapper
  # uses "micro", which for single-label multiclass input is the same number
  # as plain accuracy and disagrees with the weighted metrics in this file.
  f1 = F1Score(task="multiclass", num_classes=num_classes, average=average).to(device)
  return f1(pred.to(device), target.to(device))


from torchmetrics import JaccardIndex
def calculate_iou(output, mask, num_classes):
  """Return the weighted multiclass Jaccard index (IoU) of one batch.

  A new torchmetrics ``JaccardIndex`` with ``average="weighted"`` is built on
  each call, so the score reflects only ``output`` and ``mask``.

  Args:
    output: predictions; presumably class scores with classes on dim 1 —
      confirm against the caller.
    mask: ground-truth class indices.
    num_classes: number of classes.

  Returns:
    The IoU value produced by the metric (a tensor).

  Note:
    Relies on a module-level ``device`` defined outside this function; the
    metric and both inputs are moved there before scoring.
  """
  iou_metric = JaccardIndex(
      task="multiclass",
      num_classes=num_classes,
      average="weighted",
  ).to(device)
  preds_dev, target_dev = output.to(device), mask.to(device)
  return iou_metric(preds_dev, target_dev)

from torchmetrics import Precision
def calculate_precision(output, mask, num_classes):
  """Return the weighted multiclass precision of one batch.

  A new torchmetrics ``Precision`` with ``average="weighted"`` is built on
  each call, so the score reflects only ``output`` and ``mask``.

  Args:
    output: predictions; presumably class scores with classes on dim 1 —
      confirm against the caller.
    mask: ground-truth class indices.
    num_classes: number of classes.

  Returns:
    The precision value produced by the metric (a tensor).

  Note:
    Relies on a module-level ``device`` defined outside this function.
  """
  preds, target = (tensor.to(device) for tensor in (output, mask))
  scorer = Precision(task="multiclass", num_classes=num_classes, average="weighted")
  return scorer.to(device)(preds, target)

from torchmetrics import Recall
def calculate_recall(output, mask, num_classes):
  """Return the weighted multiclass recall of one batch.

  A new torchmetrics ``Recall`` with ``average="weighted"`` is built on each
  call, so the score reflects only ``output`` and ``mask``.

  Args:
    output: predictions; presumably class scores with classes on dim 1 —
      confirm against the caller.
    mask: ground-truth class indices.
    num_classes: number of classes.

  Returns:
    The recall value produced by the metric (a tensor).

  Note:
    Relies on a module-level ``device`` defined outside this function.
  """
  settings = dict(task="multiclass", num_classes=num_classes, average="weighted")
  recall_metric = Recall(**settings).to(device)
  on_device = [tensor.to(device) for tensor in (output, mask)]
  return recall_metric(*on_device)