In [None]:
import numpy as np

import torch
from tqdm import tqdm 


In [None]:
# Batch adapters: map a data-loader batch to model inputs and targets (adapt these to your dataset/model).

def process_batch_features(batch):
    """Extract the model inputs from a data-loader batch.

    If the model takes several positional inputs, return them as a tuple.

    Args:
        batch: a batch yielded by the data loader; must expose an ``x`` attribute.

    Returns:
        ``batch.x``.
    """
    # Alternative adapters for other models (swap one in instead of the line below):
    #   return batch.x, batch.edge_index   # graph models called as model(x, edge_index)
    #   return batch.x.view(-1, 28, 28)    # reshape inputs to (N, 28, 28), e.g. MNIST-like images
    return batch.x

def process_batch_labels(batch):
    """Return the target labels carried by a data-loader batch (its ``y`` attribute)."""
    targets = batch.y
    return targets


In [None]:
def train_epoch(model, loss_f, optimizer, data_loader, device):
    """Run one full training epoch over ``data_loader``.

    Args:
        model: model put into training mode via ``model.train()`` and called as
            ``model(*inputs)``.
        loss_f: callable ``loss_f(output, labels)`` returning a single-element
            tensor (``.backward()`` and ``.item()`` are called on it).
        optimizer: object with ``zero_grad()`` / ``step()`` (e.g. a ``torch.optim`` optimizer).
        data_loader: iterable of batches; each batch must support ``.to(device)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        np.ndarray of per-batch loss values (empty if the loader yields nothing).
    """
    model.train()

    data_iterator = tqdm(data_loader, desc='Training')
    losses = []

    for data in data_iterator:
        # ``.to`` may return a new object (it always does for tensors), so keep the result.
        data = data.to(device)
        features = process_batch_features(data)
        labels = process_batch_labels(data)
        # Accept either a single input or a tuple/list of inputs; a bare tensor
        # must not be splatted row-by-row into the model call.
        model_args = features if isinstance(features, (tuple, list)) else (features,)

        optimizer.zero_grad()
        output = model(*model_args)
        loss = loss_f(output, labels)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        data_iterator.set_postfix(loss=np.mean(losses))

    # Return only after the whole loader has been consumed.
    return np.array(losses)


def eval_epoch(model, loss_f, score_f, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: model switched to eval mode via ``model.eval()`` and called as
            ``model(*inputs)``.
        loss_f: callable ``loss_f(output, labels)`` returning a single-element tensor.
        score_f: callable ``score_f(output, labels)`` computing a metric; tensor
            arguments are moved to CPU before it is called.
        data_loader: iterable of batches; each batch must support ``.to(device)``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(losses, scores)`` of np.ndarrays collected batch by batch.
    """
    model.eval()

    data_iterator = tqdm(data_loader, desc='Validation')
    losses = []
    scores = []

    for data in data_iterator:
        # ``.to`` may return a new object (it always does for tensors), so keep the result.
        data = data.to(device)
        features = process_batch_features(data)
        labels = process_batch_labels(data)
        # Accept either a single input or a tuple/list of inputs.
        model_args = features if isinstance(features, (tuple, list)) else (features,)

        with torch.no_grad():
            output = model(*model_args)
            loss = loss_f(output, labels)
            # switch to CPU before scoring
            output_cpu = output.cpu() if torch.is_tensor(output) else output
            labels_cpu = labels.cpu() if torch.is_tensor(labels) else labels
            score = score_f(output_cpu, labels_cpu)

        # Store host-side values so np.array() below works even for CUDA results.
        losses.append(loss.item())
        if torch.is_tensor(score):
            score = score.detach().cpu().numpy()
        scores.append(score)
        data_iterator.set_postfix(loss=np.mean(losses))

    return np.array(losses), np.array(scores)

In [None]:
def train(
    model, loss_f, optimizer, score_f,
    train_dataloader, val_dataloader, device,
    epoch_n, val_freq
):
    """Train ``model`` for ``epoch_n`` epochs, validating every ``val_freq`` epochs.

    Validation runs on epochs whose 0-based index is a multiple of ``val_freq``
    (so the first epoch is always validated). A non-positive ``val_freq``
    disables validation instead of raising ``ZeroDivisionError``.

    Args:
        model, loss_f, optimizer: forwarded to ``train_epoch``.
        score_f: metric forwarded to ``eval_epoch``.
        train_dataloader: loader used for training.
        val_dataloader: loader used for validation.
        device: device batches are moved to.
        epoch_n: number of training epochs.
        val_freq: validate every ``val_freq`` epochs.

    Returns:
        List with one np.ndarray of per-batch validation scores per validated epoch.
    """
    scores = []

    for epoch in range(epoch_n):
        # Keyword arguments guard against the positional-order mix-up with train_epoch.
        epoch_losses = train_epoch(
            model=model, loss_f=loss_f, optimizer=optimizer,
            data_loader=train_dataloader, device=device,
        )
        print("Epoch {:05d} | Loss: {:.4f}".format(epoch + 1, epoch_losses.mean()))

        if val_freq > 0 and epoch % val_freq == 0:
            # Validation losses are not reported; only the scores are kept.
            _, epoch_scores = eval_epoch(
                model=model, loss_f=loss_f, score_f=score_f,
                data_loader=val_dataloader, device=device,
            )
            print("Score: {:.4f}".format(epoch_scores.mean()))
            scores.append(epoch_scores)

    return scores