In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision
from model import FullModelInCode, FullModelInPaper, FullModel
from dataset import BodyMeasurementDataset
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

from DataSource import DataSource
# Project-local path provider. Later cells use it for the train/validation dataset files
# (getTrainH5Path / getValidateH5Path) and for checkpoint file names (getModelPath).
dataSource = DataSource()

## Select device, load datasets, and build the model

In [None]:
# Seed torch so that shuffling and weight initialisation are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Prefer CUDA, then Apple MPS, and fall back to the CPU. An explicit elif chain makes the
# precedence obvious (the original second `if` silently overrode a CUDA choice).
if torch.cuda.is_available():
    device = 'cuda:0'
elif torch.backends.mps.is_available():
    device = 'mps:0'
else:
    device = 'cpu'

training_set = BodyMeasurementDataset(dataSource.getTrainH5Path())
training_loader = DataLoader(training_set, batch_size=32, shuffle=True)
validate_set = BodyMeasurementDataset(dataSource.getValidateH5Path())
# Validation order does not affect the model, so keep it fixed. With shuffling, the samples that
# land in the short final batch change every epoch, which makes the reported validation loss jitter.
validate_loader = DataLoader(validate_set, batch_size=16, shuffle=False)

# Alternative architectures, kept for the "replicated from code / paper" sections below.
# model_in_code = FullModelInCode()
# model_in_paper = FullModelInPaper()
model = FullModel()

## Show samples from the training set

In [None]:
def showimg(img, ax=None, title=None):
    """Display a channel-first image tensor (e.g. the output of ``make_grid``) with matplotlib.

    Args:
        img: ``(C, H, W)`` image tensor. It is detached and copied to the CPU before plotting.
        ax: axes to draw on. When omitted, a new figure is created so that successive calls
            do not draw over each other.
        title: optional axes title.

    Returns:
        The axes the image was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels-last (H, W, C).
    ax.imshow(np.transpose(npimg, (1, 2, 0)))
    ax.set_axis_off()
    if title is not None:
        ax.set_title(title)
    return ax


dataiter = iter(training_loader)
fronts, sides, labels = next(dataiter)
showimg(torchvision.utils.make_grid(fronts), title='fronts')
showimg(torchvision.utils.make_grid(sides), title='sides');

## Define the training loop

In [None]:
def train_one_epoch(model, dataloader, optimizer, device, epoch_index, tb_writer, write_every=5):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Each batch is unpacked as ``(fronts, sides, labels)``. The model is called as
    ``model(fronts, sides)`` and optimised with mean-squared-error loss.

    Args:
        model: module being trained. The caller is responsible for putting it in train mode.
        dataloader: sized iterable of ``(fronts, sides, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batch tensors are moved to.
        epoch_index: zero-based epoch number, used to compute the global TensorBoard step.
        tb_writer: object with ``add_scalar(tag, value, step)`` (e.g. ``SummaryWriter``).
        write_every: every ``write_every`` batches, log the mean loss of those batches.

    Returns:
        Mean of the per-batch training losses, or ``nan`` if ``dataloader`` is empty.
    """
    num_batches = len(dataloader)
    epoch_loss = 0.0
    window_loss = 0.0

    for i, (fronts, sides, labels) in enumerate(tqdm(dataloader, desc='epoch {}'.format(epoch_index + 1))):
        fronts, sides, labels = fronts.to(device), sides.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(fronts, sides)
        loss = torch.nn.functional.mse_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        window_loss += batch_loss
        if i % write_every == write_every - 1:
            # Global step derives from *this* loader's length, not from a notebook global.
            tb_x = epoch_index * num_batches + i + 1
            tb_writer.add_scalar('Loss/train', window_loss / write_every, tb_x)
            window_loss = 0.0

    return epoch_loss / num_batches if num_batches else float('nan')

def _evaluate(model, dataloader, device):
    """Return the per-sample mean MSE of ``model`` over ``dataloader``, or ``nan`` if it is empty.

    Puts the model in eval mode and disables gradient tracking. Each batch is weighted by its
    size, so a short final batch does not skew the average.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for fronts, sides, labels in dataloader:
            fronts, sides, labels = fronts.to(device), sides.to(device), labels.to(device)
            outputs = model(fronts, sides)
            batch_size = labels.shape[0]
            total_loss += torch.nn.functional.mse_loss(outputs, labels).item() * batch_size
            total_samples += batch_size
    return total_loss / total_samples if total_samples else float('nan')


def train(model, train_dataloader, validate_dataloader, device, lr=0.001, write_every=10, epochs=150):
    """Train ``model`` with Adam, validating and checkpointing after every epoch.

    Losses are logged to a TensorBoard run under ``runs/model_<timestamp>``. A checkpoint is
    written (path from ``dataSource.getModelPath``) whenever the validation loss improves.

    Args:
        model: module called as ``model(fronts, sides)``. It is moved to ``device``.
        train_dataloader: iterable of ``(fronts, sides, labels)`` training batches.
        validate_dataloader: iterable of ``(fronts, sides, labels)`` validation batches.
        device: device used for training and evaluation.
        lr: Adam learning rate.
        write_every: forwarded to ``train_one_epoch``. It sets how often (in batches) the
            training loss is written to TensorBoard.
        epochs: number of epochs to run.

    Returns:
        The lowest validation loss seen, or ``float('inf')`` if no epoch improved on the
        sentinel (e.g. ``epochs == 0`` or an empty validation loader).
    """
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/model_{}'.format(timestamp))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_vloss = float('inf')

    try:
        for epoch_number in range(epochs):
            print('EPOCH {}:'.format(epoch_number + 1))

            # Make sure gradient tracking / training-mode layers are on for the training pass.
            model.train(True)
            avg_loss = train_one_epoch(model, train_dataloader, optimizer, device,
                                       epoch_number, writer, write_every=write_every)

            avg_vloss = _evaluate(model, validate_dataloader, device)
            print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

            # Log the per-epoch training and validation losses on one chart.
            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss},
                               epoch_number + 1)
            writer.flush()

            # Track best performance, and save the model's state.
            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                torch.save(model.state_dict(), dataSource.getModelPath(timestamp, epoch_number))
    finally:
        # Release the event file even if training is interrupted (e.g. KeyboardInterrupt).
        writer.close()

    return best_vloss

## Train model (from scratch, or resume from a checkpoint)

In [None]:
# Fresh training run using train()'s default learning rate, logging interval and epoch count.
train(
    model=model,
    train_dataloader=training_loader,
    validate_dataloader=validate_loader,
    device=device,
)

In [None]:
# Resume training from a previously saved checkpoint with a lower learning rate.
CHECKPOINT_PATH = 'model/model_20231028_220313_138.ckpt'
model = FullModel()
# map_location='cpu' lets a checkpoint written on a CUDA/MPS machine load on any machine;
# train() moves the model to `device` afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
train(model, training_loader, validate_loader, device, lr=0.0001, epochs=200)

## Train model replicated from code (from scratch)

## Train model replicated from paper (from scratch)