In [None]:
import torch
import torch.nn as nn

In [None]:
class Trainer:
    """Run single training / validation epochs for a classification model.

    The model is moved to ``device`` on construction. Each epoch method
    iterates a loader of ``(inputs, labels)`` batches and returns the
    sample-weighted mean loss and the accuracy over the whole epoch.
    """

    def __init__(self, model, device, criterion=None):
        """Store the model, device and loss function.

        Args:
            model: Module to train; it is moved to ``device``.
            device: Device that the model and every batch are moved to.
            criterion: Loss callable taking ``(outputs, labels)``. When
                ``None``, ``nn.CrossEntropyLoss()`` is used.
        """
        self.model = model.to(device)
        self.device = device
        # Explicit None check: a truthiness test would silently replace a
        # caller-supplied criterion that happens to evaluate as falsy.
        self.criterion = (
            criterion if criterion is not None else nn.CrossEntropyLoss()
        )
        # Per-epoch metric lists. The epoch methods below do not append to
        # these themselves; presumably a caller records the returned values.
        self.history = {
            "train_loss": [],
            "train_acc": [],
            "val_loss": [],
            "val_acc": [],
        }

    def train_epoch(self, train_dl, optimizer):
        """Train the model for one pass over ``train_dl``.

        Args:
            train_dl: Iterable yielding ``(inputs, labels)`` batches.
            optimizer: Optimizer whose ``zero_grad``/``step`` are called
                once per batch.

        Returns:
            Tuple ``(epoch_loss, epoch_acc)``: the loss averaged over all
            samples seen and the fraction of correctly predicted samples.

        Raises:
            ZeroDivisionError: If ``train_dl`` yields no samples.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for imgs, labels in train_dl:
            imgs, labels = imgs.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            outs = self.model(imgs)
            loss = self.criterion(outs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate across batches. Scaling by the batch size undoes a
            # per-batch mean so the final division yields a per-sample
            # average (assumes the criterion uses mean reduction, as the
            # default CrossEntropyLoss does).
            running_loss += loss.item() * imgs.size(0)
            _, predicted = outs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / total
        epoch_acc = correct / total

        return epoch_loss, epoch_acc

    def validate(self, val_dl):
        """Evaluate the model over ``val_dl`` without computing gradients.

        Args:
            val_dl: Iterable yielding ``(inputs, labels)`` batches.

        Returns:
            Tuple ``(epoch_loss, epoch_acc)``: the loss averaged over all
            samples seen and the fraction of correctly predicted samples.

        Raises:
            ZeroDivisionError: If ``val_dl`` yields no samples.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for imgs, labels in val_dl:
                imgs, labels = imgs.to(self.device), labels.to(self.device)

                outs = self.model(imgs)
                loss = self.criterion(outs, labels)

                # Accumulate across batches (see train_epoch for why the
                # batch loss is scaled by the batch size).
                running_loss += loss.item() * imgs.size(0)
                _, predicted = outs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        epoch_loss = running_loss / total
        epoch_acc = correct / total

        return epoch_loss, epoch_acc