In [None]:
import torch

# Show whether PyTorch reports CUDA as available in this environment.
torch.cuda.is_available()

In [None]:
# Show how many CUDA devices PyTorch reports.
torch.cuda.device_count()

In [None]:
import warnings

# Suppress every Python warning from this point on.
# NOTE(review): this blanket filter also hides deprecation and runtime
# warnings raised by the libraries used below; consider narrowing it to a
# specific category or module.
warnings.filterwarnings("ignore")

In [None]:
from os.path import join
import torch

# Dataset identifier; it is embedded in the experiment name further down.
dataset_name = "cifar10"

# All experiment settings live in this nested mapping; later cells unpack it
# into flat constants.
config = dict(
    order="01",
    task="multiclass",
    data=dict(batch_size=64, size=32),
    network=dict(input_channels=3, linear_size=576, num_classes=10),
    train=dict(
        criterion=torch.nn.CrossEntropyLoss(),
        checkpoint_save_dir="checkpoints",
        epochs=50,
        # Use the GPU when PyTorch can see one, otherwise run on the CPU.
        device="cuda" if torch.cuda.is_available() else "cpu",
    ),
)

In [None]:
# Unpack the nested config into flat constants, grouped by config section.

# Top-level experiment identifiers.
ORDER, TASK = config["order"], config["task"]

# Data settings.
SIZE, BATCH_SIZE = config["data"]["size"], config["data"]["batch_size"]

# Network settings.
NUM_CLASSES, INPUT_CHANNEL, LINEAR_SIZE = (
    config["network"]["num_classes"],
    config["network"]["input_channels"],
    config["network"]["linear_size"],
)

# Training settings.
EPOCHS, CRITERION, DEVICE = (
    config["train"]["epochs"],
    config["train"]["criterion"],
    config["train"]["device"],
)

# Run identifier built from the order tag, the dataset name and the size setting.
experiment_name = f"{ORDER}-{dataset_name}-{SIZE}"

In [None]:
# Checkpoint path for this experiment: the configured base directory joined
# with the experiment name.
checkpoint_save_dir = join(config["train"]["checkpoint_save_dir"], experiment_name)

In [None]:
# Display the resolved checkpoint directory as the cell output.
checkpoint_save_dir

In [None]:
from torchvision import datasets, transforms

# Training split of CIFAR-10, downloaded into ./data when missing and
# converted to tensors by ToTensor.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(
        "./data",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor()]),
    ),
    # Without these two arguments the DataLoader falls back to batch_size=1
    # and a fixed sample order, ignoring the configured batch size that the
    # test loader already uses.
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split of CIFAR-10, served in batches of BATCH_SIZE.
# NOTE(review): shuffling the test split is kept as in the original; it is
# presumably unnecessary for evaluation — confirm before changing.
test_loader = torch.utils.data.DataLoader(
    dataset=datasets.CIFAR10(
        root="./data",
        train=False,
        transform=transforms.Compose([transforms.ToTensor()]),
        download=True,
    ),
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [None]:
from torch.optim.lr_scheduler import ExponentialLR
import lenet_models


def get_optimizers_dict(model_dict, lr=0.001):
    """
    Get a dictionary of optimizers for the models.

    One Adam optimizer is created per model, over that model's parameters.

    Args:
        model_dict: Mapping of model name to model; each model must expose
            ``parameters()``.
        lr: Learning rate given to every optimizer. Defaults to 0.001, the
            value that was previously hard-coded.

    Returns:
        A dict with the same keys (in the same order) as ``model_dict``,
        mapping each model name to its ``torch.optim.Adam`` instance.
    """
    return {
        model_name: torch.optim.Adam(model.parameters(), lr=lr)
        for model_name, model in model_dict.items()
    }


def get_scheduler_dict(optimizer_dict, gamma=0.9):
    """
    Get a dictionary of schedulers for the optimizers.

    One ExponentialLR scheduler is created per optimizer.

    Args:
        optimizer_dict: Mapping of model name to optimizer.
        gamma: Decay factor given to every scheduler. Defaults to 0.9, the
            value that was previously hard-coded.

    Returns:
        A dict with the same keys (in the same order) as ``optimizer_dict``,
        mapping each model name to its ``ExponentialLR`` instance.
    """
    return {
        model_name: ExponentialLR(optimizer, gamma=gamma)
        for model_name, optimizer in optimizer_dict.items()
    }


# Build the models to compare, then one Adam optimizer and one ExponentialLR
# scheduler per model, all keyed by model name.
# NOTE(review): get_constant_model_dict lives in lenet_models (not shown
# here); the argument order is taken as given — confirm against its signature.
model_dict = lenet_models.get_constant_model_dict(INPUT_CHANNEL, NUM_CLASSES, LINEAR_SIZE)
optimizers_dict = get_optimizers_dict(model_dict)
schedulers_dict = get_scheduler_dict(optimizers_dict)

In [None]:
# Filled by the training cell: model name -> value returned by train_and_test_model.
history_dict = {}

In [None]:
from tcnn.utils.experiment.train import train_and_test_model

# Keyword arguments that are identical for every model.
shared_train_kwargs = dict(
    epochs=EPOCHS,
    save_checkpoint=True,
    save_checkpoint_interval=1,
    task=TASK,
)

# Train each model with its own optimizer and scheduler, and keep what
# train_and_test_model returns under the model's name.
for model_name, model in model_dict.items():
    print(f"Training model {model_name}")
    history_dict[model_name] = train_and_test_model(
        model,
        train_loader,
        test_loader,
        CRITERION,
        optimizers_dict[model_name],
        scheduler=schedulers_dict[model_name],
        # Each model gets its own sub-directory under the experiment directory.
        checkpoint_save_dir=join(checkpoint_save_dir, model_name),
        **shared_train_kwargs,
    )
    print("***" * 10)

In [None]:
from tcnn.utils.experiment.plot import plot_history

# Plot the recorded training history of each model, one model at a time.
for model_name, history in history_dict.items():
    print(f"Model {model_name} history:")
    plot_history(history, model_name)

In [None]:
import torch

# Random batch that is later passed to torchprofile.profile_macs.
# The channel count is read from the config instead of a literal 3 so it
# cannot drift from the network's configured input_channels.
input_shape = (BATCH_SIZE, INPUT_CHANNEL, SIZE, SIZE)
input_tensor = torch.randn(input_shape).to(DEVICE)

In [None]:
import torchprofile
from tcnn.utils.experiment.model import count_parameters
from tcnn.utils.experiment.train import eval_model

# Per-model summary: MACs from torchprofile, the value of count_parameters,
# and whatever eval_model reports on the test loader.
result = {}
for model_name, model in model_dict.items():
    print(f"Evaluating model {model_name}")
    metrics = result[model_name] = {}
    metrics["macs"] = torchprofile.profile_macs(model, input_tensor)
    metrics["params"] = count_parameters(model)
    metrics["performance"] = eval_model(model, test_loader, CRITERION, TASK)
    # NOTE(review): the no_grad context only wraps empty_cache, which does no
    # autograd work; it was presumably meant to cover the evaluation above —
    # confirm and move or drop it.
    with torch.no_grad():
        torch.cuda.empty_cache()
    print("***" * 10)

In [None]:
from tcnn.utils.experiment.plot import plot_history_dict

# Plot the recorded histories of all models via a single call.
plot_history_dict(history_dict)

In [None]:
from tcnn.utils.experiment.log import show_test_result

# Show the per-model MACs / params / performance collected in `result`.
show_test_result(result)