In [1]:
import torchvision, torch
import numpy as np
import os
import sys
sys.path.append(os.getcwd())
from models.mlp import *
from models.hebbian import *
from models.hybrid import *
from tqdm import tqdm

In [2]:
def download_dataset(train_prop=0.8, keep_prop=0.5, ds_name=None):
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    if ds_name is None:
        download = os.path.exists("../data/MNIST/")
        full_train_set = torchvision.datasets.MNIST(
            root="../data/", train=True, download=download, transform=transform
        )
        full_test_set = torchvision.datasets.MNIST(
            root="../data/", train=False, transform=transform
        )
    elif ds_name == 'EMNIST':
        # Might change split to "balanced" for balanced characters if performance is low
        download = os.path.exists("../data/EMNIST/")

        full_train_set = torchvision.datasets.EMNIST(
            root="../data/", split="byclass", download=download, transform=transform
        )
        full_test_set = torchvision.datasets.EMNIST(
            root="../data/", split="byclass", download=download, transform=transform
        )
    else:
        download = os.path.exists("../data/omniglot-py/")

        full_train_set = torchvision.datasets.Omniglot(
            root="../data/", download=download, transform=transform
        )
        full_test_set = torchvision.datasets.Omniglot(
            root="../data/", download=download, transform=transform
        )
    train_set, valid_set, _ = torch.utils.data.random_split(
        full_train_set, [train_prop * keep_prop, (1 - train_prop) * keep_prop, 1-keep_prop]
    )
    test_set, _ = torch.utils.data.random_split(
        full_test_set, [keep_prop, 1 - keep_prop]
    )

    print("Number of examples retained:")
    print(f"  {len(train_set)} (training)")
    print(f"  {len(valid_set)} (validation)")
    print(f"  {len(test_set)} (test)")

    return train_set, valid_set, test_set

train_set, valid_set, test_set = download_dataset(ds_name='EMNIST')

Number of examples retained:
  279173 (training)
  69793 (validation)
  348966 (test)


In [3]:
# Model
NUM_HIDDEN = 100
NUM_INPUTS = np.prod(train_set.dataset.data[0].shape)
NUM_OUTPUTS = len(train_set.dataset.classes)  # number of classes
ACTIVATION = "sigmoid"  # output constrained between 0 and 1
BIAS = False
MLP1 = MultiLayerPerceptron(
    num_inputs=NUM_INPUTS,
    num_outputs=NUM_OUTPUTS,
    num_hidden=NUM_HIDDEN,
    activation_type=ACTIVATION,
    bias=BIAS,
)
ACTIVATION = "relu"
MLP2 = MultiLayerPerceptron(
    num_inputs=NUM_INPUTS,
    num_outputs=NUM_OUTPUTS,
    num_hidden=NUM_HIDDEN,
    activation_type=ACTIVATION,
    bias=BIAS,
)

hybrid_MLP = HybridMLP(MLP1, MLP2, 'weighted', 0.9)

# Dataloaders
BATCH_SIZE = 32

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True
)
valid_loader = torch.utils.data.DataLoader(
    valid_set, batch_size=BATCH_SIZE, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False
)

In [4]:
####### FUNCTIONS ########


def train_model(model, train_loader, valid_loader, optimizer, num_epochs=5):
    results_dict = {
        "avg_train_losses": list(),
        "avg_valid_losses": list(),
        "avg_train_accuracies": list(),
        "avg_valid_accuracies": list(),
    }

    for e in tqdm(range(num_epochs)):
        no_train = True if e == 0 else False  # to get a baseline
        latest_epoch_results_dict = train_epoch(
            model, train_loader, valid_loader, optimizer=optimizer, no_train=no_train
        )

        for key, result in latest_epoch_results_dict.items():
            if key in results_dict.keys() and isinstance(results_dict[key], list):
                results_dict[key].append(latest_epoch_results_dict[key])
            else:
                results_dict[key] = result  # copy latest

    return results_dict


def train_epoch(model, train_loader, valid_loader, optimizer, no_train=False):

    criterion = torch.nn.NLLLoss()

    epoch_results_dict = dict()
    for dataset in ["train", "valid"]:
        for sub_str in ["correct_by_class", "seen_by_class"]:
            epoch_results_dict[f"{dataset}_{sub_str}"] = {
                i: 0 for i in range(model.num_outputs)
            }

    model.train()
    train_losses, train_acc = list(), list()
    for X, y in train_loader:
        y_pred = model(X, y=y)
        loss = criterion(torch.log(y_pred), y)
        acc = (torch.argmax(y_pred.detach(), axis=1) == y).sum() / len(y)
        train_losses.append(loss.item() * len(y))
        train_acc.append(acc.item() * len(y))
        update_results_by_class_in_place(
            y,
            y_pred.detach(),
            epoch_results_dict,
            dataset="train",
            num_classes=model.num_outputs,
        )
        optimizer.zero_grad()
        if not no_train:
            loss.backward()
            optimizer.step()

    num_items = len(train_loader.dataset)
    epoch_results_dict["avg_train_losses"] = np.sum(train_losses) / num_items
    epoch_results_dict["avg_train_accuracies"] = np.sum(train_acc) / num_items * 100

    model.eval()
    valid_losses, valid_acc = list(), list()
    with torch.no_grad():
        for X, y in valid_loader:
            y_pred = model(X)
            loss = criterion(torch.log(y_pred), y)
            acc = (torch.argmax(y_pred, axis=1) == y).sum() / len(y)
            valid_losses.append(loss.item() * len(y))
            valid_acc.append(acc.item() * len(y))
            update_results_by_class_in_place(
                y, y_pred.detach(), epoch_results_dict, dataset="valid"
            )

    num_items = len(valid_loader.dataset)
    epoch_results_dict["avg_valid_losses"] = np.sum(valid_losses) / num_items
    epoch_results_dict["avg_valid_accuracies"] = np.sum(valid_acc) / num_items * 100

    return epoch_results_dict


def update_results_by_class_in_place(
    y, y_pred, result_dict, dataset="train", num_classes=10
):
    y_pred = np.argmax(y_pred, axis=1)
    for i in result_dict[f"{dataset}_seen_by_class"].keys():
        idxs = np.where(y == int(i))[0]
        result_dict[f"{dataset}_seen_by_class"][int(i)] += len(idxs)

        num_correct = int(sum(y[idxs] == y_pred[idxs]))
        result_dict[f"{dataset}_correct_by_class"][int(i)] += num_correct

def evaluate_accuracy(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in data_loader:
            y_pred = model(X)
            predicted = torch.argmax(y_pred, axis=1)
            correct += (predicted == y).sum().item()
            total += y.size(0)
    accuracy = 100 * correct / total
    return accuracy

In [5]:
LR = 0.01
backprop_optimizer = BasicOptimizer(hybrid_MLP.parameters(), lr=LR)


NUM_EPOCHS = 5

MLP_results_dict = train_model(
    hybrid_MLP,
    train_loader,
    valid_loader,
    optimizer=backprop_optimizer,
    num_epochs=NUM_EPOCHS,
)
test_accuracy = evaluate_accuracy(hybrid_MLP, test_loader)
print(f"Test Accuracy: {test_accuracy:.2f}%")

  0%|          | 0/5 [00:00<?, ?it/s]