# Fully Connected Experiments on FashionMNIST with a standardized setup


## 1. Setup

In [None]:
import gc
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import seaborn as sns

from filelock import FileLock
from IPython import display
from ray import tune
from ray.air import RunConfig, session
from ray.tune.schedulers import ASHAScheduler
from semitorch import MaxPlus, maxplus_parameters, nonmaxplus_parameters
from semitorch import MinPlus, minplus_parameters, nonminplus_parameters
from semitorch import TropicalSGD
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from tqdm.notebook import trange

# Resolve the data directory: prefer ./data when it exists, otherwise use ../data.
if os.path.isdir("./data"):
    data_path = os.path.abspath("./data")
else:
    data_path = os.path.abspath("../data")

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

### Load FashionMNIST data

In [None]:
# DataLoader settings shared by the training and test loaders.
batch_size = 256
num_workers = 8

# Training pipeline: random horizontal flip as augmentation, then tensor
# conversion and single-channel normalization.
# NOTE(review): (0.286, 0.353) are presumably the FashionMNIST mean/std — confirm.
transforms_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.286,), std=(0.353,)),
])

# Evaluation pipeline: same conversion and normalization, without augmentation.
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.286,), std=(0.353,)),
])

# Both splits are stored under the current directory and downloaded if missing.
# NOTE(review): root="." is used here while the `data_path` variable computed
# during setup is never referenced — confirm which location is intended.
trainset = FashionMNIST(root=".", transform=transforms_train, train=True, download=True)
testset = FashionMNIST(root=".", transform=transforms_test, train=False, download=True)


def get_data_loaders():
    """Build the training and test ``DataLoader`` objects for the global datasets.

    Both loaders use the module-level ``batch_size`` and ``num_workers``; only
    the training loader shuffles its samples.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    lock_file = os.path.expanduser("~/data.lock")
    shared_kwargs = {"batch_size": batch_size, "num_workers": num_workers}

    # Loader construction is serialized through a file lock so that concurrent
    # workers do not execute this section at the same time.
    # NOTE(review): the datasets themselves are created (and downloaded) at
    # module level, outside this lock — confirm the lock guards what is intended.
    with FileLock(lock_file):
        train_loader = DataLoader(trainset, shuffle=True, **shared_kwargs)
        test_loader = DataLoader(testset, shuffle=False, **shared_kwargs)

    return train_loader, test_loader

## 2. Models

In [None]:
class Model(nn.Module):
    """Three-stage classifier with a small convolutional stem and a 10-way linear head.

    The stem is a strided ``Conv2d`` (1 -> 8 channels, kernel 4, stride 4)
    followed by ``Flatten`` and an optional ``LayerNorm``. The backbone stages
    map 8*7*7 -> 250 -> 150 -> 50 features, and the head maps 50 -> 10 without
    bias. Assumes inputs for which the flattened stem output has 8 * 7 * 7
    features (e.g. 1x28x28 images).

    Args:
        model_name: One of ``"linear/relu"``, ``"linear/maxplus"`` or
            ``"linear/minplus"``; selects what each backbone stage is built from.
        layer_norm: When true, append ``LayerNorm(8 * 7 * 7)`` to the stem.
        skip_connections: When true, stage 2 also receives the stem output and
            stage 3 also receives the stage-1 output (by concatenation).
        k: Forwarded as ``k=`` to the ``MaxPlus`` / ``MinPlus`` layers; not used
            by the ``"linear/relu"`` variant.

    Raises:
        RuntimeError: If ``model_name`` is not one of the known variants.
    """

    def __init__(self, model_name: str, layer_norm: bool = False, skip_connections: bool = False, k=None) -> None:
        super().__init__()
        self.name = model_name
        self.skip_connections = skip_connections

        stem_features = 8 * 7 * 7
        stem_layers = [
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=4, stride=4),
            nn.Flatten(),
        ]
        if layer_norm:
            stem_layers.append(nn.LayerNorm(stem_features))
        self.stem = nn.Sequential(*stem_layers)

        # Pick the factory that builds one backbone stage for this variant.
        if model_name == "linear/relu":
            def stage(n_in: int, n_mid: int, n_out: int) -> nn.Sequential:
                return nn.Sequential(
                    nn.Linear(n_in, n_mid), nn.ReLU(),
                    nn.Linear(n_mid, n_out), nn.ReLU(),
                )
        elif model_name == "linear/maxplus":
            def stage(n_in: int, n_mid: int, n_out: int) -> nn.Sequential:
                return nn.Sequential(
                    nn.Linear(n_in, n_mid),
                    MaxPlus(n_mid, n_out, k=k),
                )
        elif model_name == "linear/minplus":
            def stage(n_in: int, n_mid: int, n_out: int) -> nn.Sequential:
                return nn.Sequential(
                    nn.Linear(n_in, n_mid),
                    MinPlus(n_mid, n_out, k=k),
                )
        else:
            raise RuntimeError(f"Unknown model ({model_name})")

        # With skip connections, a stage's input is widened by the features
        # concatenated onto it in forward().
        extra_for_stage_2 = stem_features if skip_connections else 0
        extra_for_stage_3 = 250 if skip_connections else 0

        self.backbone_1 = stage(stem_features, 300, 250)
        self.backbone_2 = stage(250 + extra_for_stage_2, 200, 150)
        self.backbone_3 = stage(150 + extra_for_stage_3, 100, 50)

        self.head = nn.Linear(50, 10, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run stem, the three backbone stages and the head; return the head output."""
        stem_out = self.stem(x)
        stage_1_out = self.backbone_1(stem_out)

        stage_2_in = stage_1_out
        if self.skip_connections:
            stage_2_in = torch.cat((stage_1_out, stem_out), dim=-1)
        stage_2_out = self.backbone_2(stage_2_in)

        stage_3_in = stage_2_out
        if self.skip_connections:
            stage_3_in = torch.cat((stage_2_out, stage_1_out), dim=-1)
        stage_3_out = self.backbone_3(stage_3_in)

        return self.head(stage_3_out)

## 3. Training

In [None]:
def accuracy(model: nn.Module, x: torch.Tensor, y: torch.Tensor) -> float:
    """Return the fraction of entries in ``y`` that ``model`` predicts correctly for ``x``.

    The predicted class of each row is the index of the largest model output
    along dimension 1. The comparison is done on the CPU and no gradients are
    recorded.

    Args:
        model: Module called as ``model(x)``.
        x: Input batch, already on the device the model lives on.
        y: Target class indices, one per row of the model output.

    Returns:
        Number of correct predictions divided by ``y.numel()``.
    """
    with torch.no_grad():
        scores = model(x)

    predicted = torch.max(scores.cpu(), dim=1)[1]
    n_correct = (y.cpu() == predicted).sum().item()

    return n_correct / float(y.numel())


def test(model: nn.Module, device: str, testloader: DataLoader) -> float:
    """Return the accuracy of ``model`` over every sample yielded by ``testloader``.

    The model is switched to evaluation mode and left in that mode. Each
    batch's accuracy is weighted by its number of targets, so a smaller final
    batch does not count as much as a full one (a plain mean of per-batch
    accuracies would over-weight it).

    Args:
        model: Model to evaluate.
        device: Device the input batches are moved to before the forward pass.
        testloader: Iterable of ``(x, y)`` batches.

    Returns:
        Correct predictions divided by the total number of targets.

    Raises:
        ZeroDivisionError: If ``testloader`` yields no samples.
    """
    model.eval()
    weighted_correct = 0.0
    n_samples = 0

    with torch.no_grad():
        for x, y in testloader:
            x = x.to(device)
            n_in_batch = y.numel()
            # accuracy() returns a fraction; scale it back to a sample count.
            weighted_correct += accuracy(model, x, y) * n_in_batch
            n_samples += n_in_batch

    return weighted_correct / n_samples


def confusion_matrix(model: nn.Module, device: str, testloader: DataLoader) -> None:
    """Plot a heatmap of true versus predicted classes for ``model`` on ``testloader``.

    Rows are true labels and columns are predicted labels; class names come
    from the module-level ``testset.classes``. The model is switched to
    evaluation mode and left in that mode.

    Args:
        model: Model to evaluate.
        device: Device the input batches are moved to before the forward pass.
        testloader: Iterable of ``(x, y)`` batches; ``y`` holds class indices.
    """
    model.eval()

    class_names = testset.classes
    n_classes = len(class_names)
    conf_matrix = torch.zeros(n_classes, n_classes, dtype=torch.long)

    with torch.no_grad():
        for x, y in testloader:
            yout = model(x.to(device))
            _, prediction = torch.max(yout.cpu(), dim=1)
            targets = y.cpu()

            # `conf_matrix[targets, prediction] += 1` would add only 1 per
            # *distinct* (target, prediction) pair in the batch; accumulate=True
            # counts every sample, including repeated pairs.
            conf_matrix.index_put_(
                (targets, prediction),
                torch.ones_like(prediction),
                accumulate=True,
            )

    plt.figure(figsize=(6, 4))

    df_cm = pd.DataFrame(conf_matrix.numpy(), index=class_names, columns=class_names).astype(int)
    heatmap = sns.heatmap(df_cm, annot=True, fmt="d")

    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=15)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=15)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


def train(
        model: nn.Module,
        device: str,
        trainloader: DataLoader,
        testloader: DataLoader,
        optimizers: list[torch.optim.Optimizer],
        schedulers: list[torch.optim.lr_scheduler],
        loss: torch.nn.modules.loss,
        epochs: int,
) -> None:
    """Train ``model`` for ``epochs`` epochs with a live accuracy plot.

    After every epoch the test accuracy is measured with ``test`` and the
    notebook figure is redrawn with the per-batch training accuracies and the
    per-epoch test accuracies. A confusion matrix is plotted once training
    has finished.

    Args:
        model: Model to train; expected to already be on ``device``.
        device: Device each batch is moved to.
        trainloader: Yields ``(x, y)`` training batches.
        testloader: Yields ``(x, y)`` test batches.
        optimizers: Optimizers stepped after every batch. ``TropicalSGD``
            instances additionally receive a CPU copy of the input batch.
        schedulers: Learning-rate schedulers stepped after every batch.
        loss: Callable computing the loss from ``(model_output, targets)``.
        epochs: Number of passes over ``trainloader``.
    """
    test_history = []   # test accuracy, one entry per finished epoch
    batch_history = []  # training accuracy, one entry per training batch

    fig, ax = plt.subplots(1, 1, figsize=[6, 4])
    hdisplay = display.display("", display_id=True)

    for _ in trange(epochs):
        model.train()

        for x, y in trainloader:
            x = x.to(device)
            y = y.to(device)

            for optimizer in optimizers:
                optimizer.zero_grad()

            yout = model(x)

            # Track the training accuracy of this batch for the plot.
            predicted = torch.max(yout.cpu(), dim=1)[1]
            n_hits = (y.cpu() == predicted).sum().item()
            batch_history.append(n_hits / float(y.numel()))

            batch_loss = loss(yout, y.squeeze())
            batch_loss.backward()

            for optimizer in optimizers:
                if not isinstance(optimizer, TropicalSGD):
                    optimizer.step()
                    continue
                optimizer.step(input_tensor=x.cpu())

            for scheduler in schedulers:
                scheduler.step()

        test_history.append(test(model, device, testloader))
        n_done = len(test_history)

        # Redraw the figure from scratch with everything recorded so far.
        ax.clear()
        ax.set_xlim(0, epochs)
        ax.set_ylim(-0.02, 1.02)
        ax.plot(
            np.linspace(0, n_done, len(batch_history)),
            batch_history,
            ".",
            markersize=1.5,
            markerfacecolor=(0, 0, 1, 0.3),
        )
        ax.plot(np.linspace(1, n_done, n_done), test_history)
        ax.text(
            0.6 * epochs,
            0.30,
            f"max test acc = {max(test_history):.2%}",
            ha="center",
            fontsize=10,
        )
        hdisplay.update(fig)

        # prevents OOM when GPU memory is tight
        torch.cuda.empty_cache()
        gc.collect()

    confusion_matrix(model, device, testloader)

In [None]:
def run_ray_tune(config: dict):
    """Ray Tune trainable: build a model and optimizers from ``config`` and train for 20 epochs.

    After every epoch the test loss and accuracy returned by
    ``train_ray_tune`` are reported to Ray via ``session.report``.

    Args:
        config: Sampled hyperparameters. Keys read here: ``model_name``,
            ``layer_norm``, ``skip_connections``, ``k``, ``linear_optimizer``
            (``"AdamW"`` or ``"SGD"``), ``linear_scheduler``, ``linear_lr``,
            ``semiring_optimizer`` (``None``, ``"SGD"`` or ``"TropicalSGD"``),
            ``semiring_scheduler`` and ``semiring_lr``.

    Raises:
        RuntimeError: On an unknown model name or optimizer name.
    """
    loss = nn.CrossEntropyLoss()
    epochs = 20
    trainloader, testloader = get_data_loaders()

    def one_cycle(optimizer, max_lr):
        # Same one-cycle schedule for both optimizers, stepped once per batch.
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=max_lr,
            anneal_strategy="linear",
            pct_start=0.3,
            three_phase=True,
            final_div_factor=1000.0,
            div_factor=10.0,
            steps_per_epoch=len(trainloader),
            epochs=epochs,
        )

    # Create model
    model_name = config["model_name"]
    layer_norm = config["layer_norm"]
    skip_connections = config["skip_connections"]
    k = config["k"]

    model = Model(model_name=model_name, layer_norm=layer_norm, skip_connections=skip_connections, k=k).to(device)

    # Separate model parameters
    if model_name == "linear/relu":
        linear_params = model.parameters()
        semiring_params = nn.ParameterList()
    elif model_name == "linear/maxplus":
        linear_params = nonmaxplus_parameters(model)
        semiring_params = maxplus_parameters(model)
    elif model_name == "linear/minplus":
        linear_params = nonminplus_parameters(model)
        semiring_params = minplus_parameters(model)
    else:
        raise RuntimeError(f"Unknown model ({model_name})")

    # Create linear optimizer
    linear_lr = config["linear_lr"]
    if config["linear_optimizer"] == "AdamW":
        linear_optimizer = torch.optim.AdamW(linear_params, lr=linear_lr, weight_decay=0.01)
    elif config["linear_optimizer"] == "SGD":
        linear_optimizer = torch.optim.SGD(linear_params, lr=linear_lr)
    else:
        raise RuntimeError(f'Unknown linear optimizer {config["linear_optimizer"]}')
    linear_scheduler = one_cycle(linear_optimizer, linear_lr) if config["linear_scheduler"] else None

    # Create semiring optimizer
    semiring_lr = config["semiring_lr"]
    if config["semiring_optimizer"] is None:
        semiring_optimizer = None
    elif config["semiring_optimizer"] == "SGD":
        semiring_optimizer = torch.optim.SGD(semiring_params, lr=semiring_lr)
    elif config["semiring_optimizer"] == "TropicalSGD":
        semiring_optimizer = TropicalSGD(semiring_params, lr=semiring_lr)
    else:
        raise RuntimeError(f'Unknown semiring optimizer {config["semiring_optimizer"]}')
    # A scheduler needs an optimizer to drive; skip it when there is none.
    if config["semiring_scheduler"] and semiring_optimizer is not None:
        semiring_scheduler = one_cycle(semiring_optimizer, semiring_lr)
    else:
        semiring_scheduler = None

    # Create optimizers and schedulers.
    # These must be real lists: a `filter` object is a one-shot iterator, so it
    # would be exhausted by the first `zero_grad()` loop and `step()` would
    # never be called on any optimizer.
    optimizers = [opt for opt in (linear_optimizer, semiring_optimizer) if opt is not None]
    schedulers = [sch for sch in (linear_scheduler, semiring_scheduler) if sch is not None]

    for _ in range(epochs):
        # Feed to training function
        reported_loss, reported_accuracy = train_ray_tune(
            model,
            device,
            trainloader,
            testloader,
            optimizers,
            schedulers,
            loss,
        )

        # Report plain Python floats rather than tensors.
        session.report({"loss": float(reported_loss), "accuracy": reported_accuracy})


def train_ray_tune(
        model: nn.Module,
        device: str,
        trainloader: DataLoader,
        testloader: DataLoader,
        optimizers: list[torch.optim.Optimizer],
        schedulers: list[torch.optim.lr_scheduler],
        loss: torch.nn.modules.loss,
) -> tuple[float, float]:
    """Train ``model`` for one pass over ``trainloader``, then evaluate it on ``testloader``.

    Args:
        model: Model to train; expected to already be on ``device``.
        device: Device each batch is moved to.
        trainloader: Yields ``(x, y)`` training batches.
        testloader: Yields ``(x, y)`` test batches.
        optimizers: Optimizers stepped after every batch. ``TropicalSGD``
            instances additionally receive a CPU copy of the input batch.
        schedulers: Learning-rate schedulers stepped after every batch.
        loss: Callable computing the loss from ``(model_output, targets)``.

    Returns:
        ``(mean_test_loss, test_accuracy)``: the loss averaged over test
        batches and the fraction of correctly classified test samples, both
        as Python floats.
    """
    # Both collections are iterated more than once per batch; materialize them
    # so a one-shot iterator is not exhausted after its first use in this call.
    optimizers = list(optimizers)
    schedulers = list(schedulers)

    model.train()

    for x, y in trainloader:
        x, y = x.to(device), y.to(device)
        for optimizer in optimizers:
            optimizer.zero_grad()

        l = loss(model(x), y.squeeze())
        l.backward()

        for optimizer in optimizers:
            if isinstance(optimizer, TropicalSGD):
                optimizer.step(input_tensor=x.cpu())
            else:
                optimizer.step()

        for scheduler in schedulers:
            scheduler.step()

    # prevents OOM when GPU memory is tight (done once per pass, not per batch)
    torch.cuda.empty_cache()
    gc.collect()

    # Evaluate once after the training pass. Evaluating inside the batch loop
    # would run a full test pass per training batch and leave the model in
    # eval mode for the remaining training batches.
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(device), y.to(device)

            yout = model(x)
            _, predicted = torch.max(yout, dim=1)

            total += y.size(0)
            correct += (predicted == y).sum().item()
            loss_sum += loss(yout, y).item()

    return loss_sum / len(testloader), correct / total


def ray_find_best_model_for(config):
    """Run a Ray Tune search over ``config`` and print the best trial's configuration.

    Trials execute ``run_ray_tune`` and are compared on the reported
    ``"accuracy"`` metric (higher is better) using an ``ASHAScheduler``.

    Args:
        config: Ray Tune parameter space passed to the tuner as ``param_space``.
    """
    # Only request a GPU share when one exists; the rest of the notebook falls
    # back to the CPU, and a GPU request could never be satisfied on a
    # CPU-only machine.
    gpus_per_trial = 0.5 if torch.cuda.is_available() else 0

    tuner = tune.Tuner(
        tune.with_resources(
            tune.with_parameters(run_ray_tune),
            resources={"gpu": gpus_per_trial, "cpu": 1}
        ),
        tune_config=tune.TuneConfig(
            metric="accuracy",
            mode="max",
            scheduler=ASHAScheduler(),
            # NOTE(review): a single sample means only one configuration is
            # drawn from the search space — confirm this is intended.
            num_samples=1,
        ),
        param_space=config,
        run_config=RunConfig(
            verbose=0,
            sync_config=tune.SyncConfig(
                syncer=None,
            ),
        ),
    )
    results = tuner.fit()

    best_result = results.get_best_result("accuracy", "max")
    print(f"Best trial config: {best_result.config}")

# Default Linear Models

In [None]:
# Hyperparameter search for the "linear/relu" model.
# Architecture flags and the linear optimizer/scheduler are sampled; the
# learning rate is drawn log-uniformly. All semiring-related entries are None
# because run_ray_tune gives this model no semiring parameters to optimize.
ray_find_best_model_for({
    "model_name": "linear/relu",
    "layer_norm": tune.choice([True, False]),
    "skip_connections": tune.choice([True, False]),
    "k": None,
    "linear_optimizer": tune.choice(["AdamW", "SGD"]),
    "linear_scheduler": tune.choice([True, False]),
    "linear_lr": tune.loguniform(1e-6, 1),
    "semiring_lr": None,
    "semiring_optimizer": None,
    "semiring_scheduler": None,
})

Best linear model

In [None]:
# Train a "linear/relu" model (default architecture flags) with AdamW and a
# one-cycle learning-rate schedule, plotting progress as it goes.
best_linear_model = Model("linear/relu").to(device)
# nn.Module exposes `parameters()`; there is no `params()` method.
# NOTE(review): this counts parameter tensors, not individual scalar weights.
print(f"{best_linear_model.name} model has {len(list(best_linear_model.parameters()))} trainable parameters")

loss = nn.CrossEntropyLoss()
epochs = 20
trainloader, testloader = get_data_loaders()

# NOTE(review): `lr1` is not defined anywhere in the visible notebook; it must
# be set (presumably to the best `linear_lr` found by the search) before this
# cell is run.
best_linear_optimizer = torch.optim.AdamW(best_linear_model.parameters(), lr=lr1, weight_decay=0.01)
best_linear_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    best_linear_optimizer,
    max_lr=lr1,
    anneal_strategy="linear",
    pct_start=0.3,
    three_phase=True,
    final_div_factor=1000.0,
    div_factor=10.0,
    steps_per_epoch=len(trainloader),
    epochs=epochs,
)

train(
    best_linear_model,
    device,
    trainloader,
    testloader,
    [best_linear_optimizer],
    [best_linear_scheduler],
    loss,
    epochs,
)

# Tropical models

MaxPlus

In [None]:
# Hyperparameter search for the "linear/maxplus" model.
# In addition to the linear settings, this samples the MaxPlus layer argument
# `k` uniformly from [-10, -1] and the semiring optimizer, scheduler and
# learning rate.
ray_find_best_model_for({
    "model_name": "linear/maxplus",
    "layer_norm": tune.choice([True, False]),
    "skip_connections": tune.choice([True, False]),
    "k": tune.uniform(-10, -1),
    "linear_optimizer": tune.choice(["AdamW", "SGD"]),
    "linear_scheduler": tune.choice([True, False]),
    "linear_lr": tune.loguniform(1e-6, 1),
    "semiring_lr": tune.loguniform(1e-6, 1),
    "semiring_optimizer": tune.choice(["SGD", "TropicalSGD"]),
    "semiring_scheduler": tune.choice([True, False]),
})

Best MaxPlus model

In [None]:
# Train a "linear/maxplus" model with separate SGD optimizers for the linear
# and the MaxPlus parameters; only the linear optimizer gets a one-cycle
# learning-rate schedule.
# NOTE(review): the model is built with the default k=None — confirm that
# MaxPlus accepts this.
best_maxplus_model = Model("linear/maxplus").to(device)

# NOTE(review): these figures count parameter tensors, not scalar weights.
print(f"{best_maxplus_model.name} model has {len(list(best_maxplus_model.parameters()))} trainable parameters, "
      f"of which {len(list(nonmaxplus_parameters(best_maxplus_model)))} are linear "
      f"and {len(list(maxplus_parameters(best_maxplus_model)))} are semiring related")

loss = nn.CrossEntropyLoss()
epochs = 20
trainloader, testloader = get_data_loaders()

# NOTE(review): `lr1` and `lr2` are not defined anywhere in the visible
# notebook; they must be set before this cell is run.
best_maxplus_linear_optimizer = torch.optim.SGD(nonmaxplus_parameters(best_maxplus_model), lr=lr1)
best_maxplus_linear_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    best_maxplus_linear_optimizer,
    max_lr=lr1,
    anneal_strategy="linear",
    pct_start=0.3,
    three_phase=True,
    final_div_factor=1000.0,
    div_factor=10.0,
    steps_per_epoch=len(trainloader),
    epochs=epochs,
)
best_maxplus_semiring_optimizer = torch.optim.SGD(maxplus_parameters(best_maxplus_model), lr=lr2)

best_maxplus_optimizers = [best_maxplus_linear_optimizer, best_maxplus_semiring_optimizer]
best_maxplus_schedulers = [best_maxplus_linear_scheduler]

train(
    best_maxplus_model,
    device,
    trainloader,
    testloader,
    best_maxplus_optimizers,
    best_maxplus_schedulers,
    loss,
    epochs,
)

MinPlus

In [None]:
# Hyperparameter search for the "linear/minplus" model.
# Same search space as the MaxPlus search, except that the MinPlus layer
# argument `k` is sampled uniformly from [1, 10].
ray_find_best_model_for({
    "model_name": "linear/minplus",
    "layer_norm": tune.choice([True, False]),
    "skip_connections": tune.choice([True, False]),
    "k": tune.uniform(1, 10),
    "linear_optimizer": tune.choice(["AdamW", "SGD"]),
    "linear_scheduler": tune.choice([True, False]),
    "linear_lr": tune.loguniform(1e-6, 1),
    "semiring_lr": tune.loguniform(1e-6, 1),
    "semiring_optimizer": tune.choice(["SGD", "TropicalSGD"]),
    "semiring_scheduler": tune.choice([True, False]),
})

Best MinPlus model

In [None]:
# Train a "linear/minplus" model with separate SGD optimizers for the linear
# and the MinPlus parameters; only the linear optimizer gets a one-cycle
# learning-rate schedule.
# NOTE(review): the model is built with the default k=None — confirm that
# MinPlus accepts this.
best_minplus_model = Model("linear/minplus").to(device)

# NOTE(review): these figures count parameter tensors, not scalar weights.
print(f"{best_minplus_model.name} model has {len(list(best_minplus_model.parameters()))} trainable parameters, "
      f"of which {len(list(nonminplus_parameters(best_minplus_model)))} are linear "
      f"and {len(list(minplus_parameters(best_minplus_model)))} are semiring related")

loss = nn.CrossEntropyLoss()
epochs = 20
trainloader, testloader = get_data_loaders()

# NOTE(review): `lr1` and `lr2` are not defined anywhere in the visible
# notebook; they must be set before this cell is run.
best_minplus_linear_optimizer = torch.optim.SGD(nonminplus_parameters(best_minplus_model), lr=lr1)
best_minplus_linear_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    best_minplus_linear_optimizer,
    max_lr=lr1,
    anneal_strategy="linear",
    pct_start=0.3,
    three_phase=True,
    final_div_factor=1000.0,
    div_factor=10.0,
    steps_per_epoch=len(trainloader),
    epochs=epochs,
)
best_minplus_semiring_optimizer = torch.optim.SGD(minplus_parameters(best_minplus_model), lr=lr2)

best_minplus_optimizers = [best_minplus_linear_optimizer, best_minplus_semiring_optimizer]
best_minplus_schedulers = [best_minplus_linear_scheduler]

# Train the MinPlus model: the optimizers above hold its parameters, so passing
# any other model here would leave that model's weights untouched.
train(
    best_minplus_model,
    device,
    trainloader,
    testloader,
    best_minplus_optimizers,
    best_minplus_schedulers,
    loss,
    epochs,
)