In [1]:
from dataset import PetsDataset
from model import UNet
from utils import train

In [2]:
import os

import torch
from torch.utils.data import DataLoader, random_split

In [3]:
# Default to the CPU and switch to the GPU only when PyTorch can see one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Hyperparameters

In [4]:
# Epoch count, optimizer step size and mini-batch size, read by the data-loading
# and training cells below.
NUM_EPOCHS, LEARNING_RATE, BATCH_SIZE = 20, 3e-4, 32

# Load data

In [5]:
# Hold out 25% of the "trainval" split for validation. Both the split and the
# training-set shuffle order come from seeded generators, so Restart & Run All
# yields the same train/val partition and the same batch order every time.
DATA_SEED = 42

trainval_set = PetsDataset("../data/", split="trainval")
train_set, val_set = random_split(
    trainval_set,
    [0.75, 0.25],
    generator=torch.Generator().manual_seed(DATA_SEED),
)
train_loader = DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Without an explicit generator, shuffling draws from the unseeded global RNG.
    generator=torch.Generator().manual_seed(DATA_SEED),
)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)  # fixed order for evaluation

# Training

In [6]:
# Instantiate the segmentation network, then move its parameters to DEVICE.
# NOTE(review): num_classes=3 presumably matches the label set of PetsDataset masks — confirm.
model = UNet(num_classes=3, in_channels=3)
model = model.to(DEVICE)

In [7]:
# Train the model, then checkpoint its weights.
# The checkpoint name is derived from NUM_EPOCHS so it stays accurate if the epoch
# count is changed. The target directory is created before training starts, so a
# missing folder cannot throw away a finished run at the final torch.save.
CHECKPOINT_PATH = f"../pretrained/pets_unet_e{NUM_EPOCHS}.pt"
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

logger = train(model, train_loader, val_loader, NUM_EPOCHS, LEARNING_RATE, DEVICE)
torch.save(model.state_dict(), CHECKPOINT_PATH)

 1/20: train_loss=0.6828, train_iou=0.5136, val_iou=0.5110, train_acc=0.8504, val_acc=0.8481
 2/20: train_loss=0.5190, train_iou=0.5759, val_iou=0.5731, train_acc=0.8753, val_acc=0.8743
 3/20: train_loss=0.4430, train_iou=0.6046, val_iou=0.6009, train_acc=0.8927, val_acc=0.8905
 4/20: train_loss=0.4086, train_iou=0.6279, val_iou=0.6188, train_acc=0.8969, val_acc=0.8916
 5/20: train_loss=0.3866, train_iou=0.6535, val_iou=0.6410, train_acc=0.9079, val_acc=0.9017
 6/20: train_loss=0.3607, train_iou=0.6587, val_iou=0.6455, train_acc=0.9097, val_acc=0.9027
 7/20: train_loss=0.3347, train_iou=0.6821, val_iou=0.6626, train_acc=0.9210, val_acc=0.9118
 8/20: train_loss=0.3247, train_iou=0.6863, val_iou=0.6668, train_acc=0.9221, val_acc=0.9124
 9/20: train_loss=0.3116, train_iou=0.7005, val_iou=0.6794, train_acc=0.9279, val_acc=0.9175
10/20: train_loss=0.2942, train_iou=0.7081, val_iou=0.6791, train_acc=0.9295, val_acc=0.9159
11/20: train_loss=0.2833, train_iou=0.7165, val_iou=0.6855, train_acc=

In [8]:
# Save training log to csv: one row per (mode, epoch).
# The header and each row are built from the same column list, so they cannot drift
# apart. Modes without a "loss" series (the printed log shows only train_loss) get an
# empty loss cell rather than shifting iou/acc into the wrong columns.
# NOTE(review): assumes logger maps mode -> mapping of per-epoch lists keyed
# "loss"/"iou"/"acc" — confirm against utils.train.
LOG_PATH = "../logs/training_log.csv"
LOG_COLUMNS = ("epoch", "mode", "loss", "iou", "acc")

rows = []
for mode, metrics in logger.items():
    losses = metrics.get("loss", [])
    # Epochs are 1-based to match the printed "1/20, 2/20, ..." progress lines.
    for epoch, (iou, acc) in enumerate(zip(metrics["iou"], metrics["acc"]), start=1):
        loss = losses[epoch - 1] if epoch <= len(losses) else ""
        rows.append((epoch, mode, loss, iou, acc))

os.makedirs(os.path.dirname(LOG_PATH), exist_ok=True)
with open(LOG_PATH, "w") as f:
    f.write(",".join(LOG_COLUMNS) + "\n")
    for row in rows:
        f.write(",".join(str(value) for value in row) + "\n")