In [None]:
import torch
import torchvision
import torchvision.transforms as T
import nbimporter
from unet import YouNet
import utils

In [None]:
# --- Data ---------------------------------------------------------------
# Root of the dataset on the local machine.
DATA_DIR = 'C:/Users/Hayden/Machine Learning/d2l/d2l-en/pytorch/chapter_computer-vision/data/VOCdevkit/VOC2012/'
BATCH_SIZE = 32
# NOTE(review): NUM_WORKERS, IMAGE_HEIGHT and IMAGE_WIDTH are not referenced
# by the cells shown here -- confirm whether they are still needed.
NUM_WORKERS = 0
IMAGE_HEIGHT = 320
IMAGE_WIDTH = 480

# --- Optimisation -------------------------------------------------------
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# --- Hardware -----------------------------------------------------------
# Use the GPU when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [None]:
def train_batch(loader, model, optimizer, loss, scaler):
    """Run one pass over ``loader``, updating ``model`` under autocast.

    Args:
        loader: Iterable yielding ``(features, labels)`` tensor pairs.
        model: Network to train. Each batch is moved to the device that
            holds the model's parameters before the forward pass.
        optimizer: Optimizer that owns ``model``'s parameters.
        loss: Callable ``loss(predictions, labels)`` returning a scalar tensor.
        scaler: Gradient scaler exposing ``scale``, ``step`` and ``update``
            (for example ``torch.cuda.amp.GradScaler``).

    Returns:
        None. The model and optimizer are updated in place.
    """
    # Follow the model's own device so inputs and weights can never end up
    # on different devices.
    device = next(model.parameters()).device

    for features, labels in loader:
        # Tensor.to() returns a new tensor instead of moving in place, so the
        # result has to be rebound.
        features = features.to(device)
        labels = labels.to(device)

        # Forward propagation (autocast requires an explicit device type).
        with torch.autocast(device_type=device.type):
            predictions = model(features)
            cost = loss(predictions, labels)

        # Backward propagation
        optimizer.zero_grad()
        scaler.scale(cost).backward()
        scaler.step(optimizer)
        scaler.update()

In [None]:
def train(model, transforms, batch_size):
    """Train a model on the data under ``DATA_DIR`` for ``NUM_EPOCHS`` epochs.

    Args:
        model: Network to train. Pass ``None`` to build the default
            ``YouNet(in_channels=128, out_channels=20)``.
        transforms: Transform applied to the training split. Pass ``None``
            to use random horizontal/vertical flips followed by ``ToTensor``.
            The validation split always uses a plain ``ToTensor``.
        batch_size: Batch size handed to ``utils.get_loaders``. Pass ``None``
            to use the module-level ``BATCH_SIZE``.

    Returns:
        The trained model, placed on ``DEVICE``.

    After every epoch a dict holding the model and optimizer state dicts is
    passed to ``utils.save_checkpoint``.
    """
    if model is None:
        # NOTE(review): 128 input channels and 20 output channels are kept
        # from the original code. Confirm against YouNet's definition and the
        # dataset (image channel count / number of classes).
        model = YouNet(in_channels=128, out_channels=20)
    model = model.to(DEVICE)

    if batch_size is None:
        batch_size = BATCH_SIZE

    if transforms is None:
        train_transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.ToTensor()
        ])
    else:
        train_transform = transforms
    validation_transform = T.Compose([
        T.ToTensor()
    ])
    train_loader, validation_loader = utils.get_loaders(
        data_dir=DATA_DIR,
        train_transforms=train_transform,
        validation_transforms=validation_transform,
        batch_size=batch_size,
    )

    # Initialize weights
    # def init_weights(module):
    #     if type(module) in [torch.nn.Linear, torch.nn.Conv2d]:
    #         torch.nn.init.normal_(module.weight, std=0.01)
    # model.apply(init_weights)

    loss = torch.nn.CrossEntropyLoss()
    # Adam's keyword is `weight_decay`; `wd` raises a TypeError.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Gradient scaling is only meaningful on CUDA; disable it elsewhere so the
    # scaler becomes a pass-through instead of emitting a warning.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == 'cuda'))
    for epoch in range(NUM_EPOCHS):
        train_batch(train_loader, model, optimizer, loss, scaler)

        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        utils.save_checkpoint(checkpoint)

        # TODO: Complete the predictions accuracy function and save predictions to disc.

        # predictions accuracy

        # utils.save_predictions(validation_loader, model, data_dir='predictions/', device=DEVICE)

    return model