In [None]:
%load_ext autoreload
%autoreload 2
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
from config.settings import *

In [None]:
from functools import partial
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split
import torchvision
import torchvision.transforms as transforms
from ray import tune
from ray.air import Checkpoint, session
from ray.tune.schedulers import ASHAScheduler
from src.models.networks import UNet
from src.data.loaders import load_data_mitosemseg

In [None]:
def _train_one_epoch(net, train_iter, criterion, optimizer, device, epoch, log_every=2000):
    """Run one optimisation pass over ``train_iter``.

    Every ``log_every`` steps, prints the mean training loss over the steps since
    the previous log line.
    """
    net.train()
    running_loss = 0.0
    window_steps = 0
    for i, (inputs, targets) in enumerate(train_iter, 1):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window_steps += 1
        if i % log_every == 0:
            print(f"[{epoch}, {i:>5}] loss: {running_loss/window_steps:.3f}")
            # Reset BOTH accumulators so the next printed value is a mean over
            # the new window only, not sum-of-window / total-steps.
            running_loss = 0.0
            window_steps = 0


def _evaluate(net, val_iter, criterion, device):
    """Return ``(mean batch loss, element-wise accuracy)`` of ``net`` on ``val_iter``.

    Raises:
        ValueError: if ``val_iter`` yields no batches.
    """
    net.eval()
    val_loss = 0.0
    val_steps = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in val_iter:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = net(inputs)
            # Class index per element, taken over dim 1 (the class dim expected by
            # CrossEntropyLoss). Assumes `targets` holds class indices with the same
            # shape as `predicted` — TODO confirm against load_data_mitosemseg.
            predicted = outputs.argmax(dim=1)
            # Count every compared element (each pixel for segmentation), not just
            # batch samples, so accuracy stays in [0, 1].
            total += targets.numel()
            correct += (predicted == targets).sum().item()

            val_loss += criterion(outputs, targets).item()
            val_steps += 1

    if val_steps == 0:
        raise ValueError("validation loader yielded no batches; check the data split")
    return val_loss / val_steps, correct / total


def train_unet(config, num_workers=0, max_epochs=100):
    """Ray Tune trainable: train a UNet and report validation metrics each epoch.

    Args:
        config: Hyperparameters sampled by Ray Tune; reads ``"encoder_depth"``,
            ``"lr"`` and ``"batch_size"``.
        num_workers: Forwarded to ``load_data_mitosemseg`` as ``num_workers``.
        max_epochs: Last epoch (inclusive) to train; epochs are numbered from 1.

    Each epoch reports ``{"loss": mean validation loss, "accuracy": validation
    accuracy}`` through ``session.report`` together with a checkpoint holding the
    epoch number and the network/optimizer state dicts. If Tune provides a
    checkpoint, training resumes at the epoch after the one stored in it.
    """
    net = UNet(encoder_depth=config['encoder_depth'])

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config['lr'], momentum=0.9)

    checkpoint = session.get_checkpoint()
    if checkpoint:
        checkpoint_state = checkpoint.to_dict()
        # The stored epoch was already completed before the checkpoint was
        # reported, so resume with the next one instead of repeating it.
        start_epoch = checkpoint_state["epoch"] + 1
        net.load_state_dict(checkpoint_state["net_state_dict"])
        optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
    else:
        start_epoch = 1

    train_iter, val_iter = load_data_mitosemseg(batch_size=config['batch_size'],
                                                num_workers=num_workers,
                                                split=0.85)

    for epoch in range(start_epoch, max_epochs + 1):
        _train_one_epoch(net, train_iter, criterion, optimizer, device, epoch)
        val_loss, accuracy = _evaluate(net, val_iter, criterion, device)

        checkpoint_data = {
            "epoch": epoch,
            "net_state_dict": net.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        }
        checkpoint = Checkpoint.from_dict(checkpoint_data)

        session.report(
            {"loss": val_loss, "accuracy": accuracy},
            checkpoint=checkpoint,
        )
    print("Finished training")

In [None]:
def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2, cpus_per_trial=8):
    """Run a Ray Tune hyperparameter search over ``train_unet`` with ASHA early stopping.

    Args:
        num_samples: Number of hyperparameter configurations to sample.
        max_num_epochs: Per-trial epoch budget. Used both as ASHA's ``max_t`` and
            as ``train_unet``'s ``max_epochs`` so the two agree.
        gpus_per_trial: GPUs reserved for each trial.
        cpus_per_trial: CPUs reserved for each trial.

    Returns:
        The analysis object returned by ``tune.run``. Existing callers may ignore it.
    """
    config = {
        "lr": tune.loguniform(1e-4, 1e-1),
        "batch_size": tune.choice([2**x for x in range(6)]),  # 1, 2, 4, ..., 32
        "encoder_depth": tune.choice(list(range(3, 8))),       # 3..7 inclusive
    }

    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
        reduction_factor=2,
    )

    result = tune.run(
        # Bind max_epochs so the trainable's own epoch loop matches the scheduler's
        # max_t instead of falling back to train_unet's default of 100.
        partial(train_unet, max_epochs=max_num_epochs),
        resources_per_trial={"cpu": cpus_per_trial, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
        # No checkpoint_at_end: train_unet already reports a checkpoint every epoch,
        # including the last. NOTE(review): Ray 2.x's tune.run raises ValueError for
        # checkpoint_at_end with function trainables — confirm for installed version.
    )

    best_trial = result.get_best_trial("loss", "min", "last")
    print(f"Best trial config: {best_trial.config}")
    print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
    print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")
    return result


if __name__ == "__main__":
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)