In [None]:
import torch

import torch.nn as nn
import torch.optim as optim

from utility.cifar_dataset import get_dataloaders
import utility

from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE, EPOCHS
from pruning_metadata import PruningMetadata

from dataclasses import asdict

## Load the data

In [None]:
# Build the train / validation / test loaders, using the batch size exported
# alongside the LeNet model definition.
cifar_loaders = get_dataloaders(data_path=DATA_PATH, batch_size=BATCH_SIZE)
train_loader, validation_loader, test_loader = cifar_loaders

## Training Phase

In [None]:
# Select the training device: a CUDA GPU when one is available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# torch.cuda.get_device_name() raises when CUDA is unavailable, so only query
# it for the CUDA case; otherwise report the CPU fallback.
device_name = (
    torch.cuda.get_device_name(torch.cuda.current_device())
    if device.type == "cuda"
    else "CPU"
)
print(f"Using {device_name}")

In [None]:
# Seed torch's global RNG so weight initialisation is reproducible across
# Restart & Run All. (Presumably this also fixes any DataLoader shuffling in
# get_dataloaders if it draws from torch's global RNG — verify.)
SEED = 42
torch.manual_seed(SEED)

base_model = LeNet().to(device)

# Loss function and optimizer (AdamW with its default hyperparameters).
cross_entropy = nn.CrossEntropyLoss()
# Set this to an object exposing `early_stop(valid_loss) -> bool` to enable
# early stopping in the training loop; None disables it.
early_stopper = None
optimizer = optim.AdamW(base_model.parameters())

### Training loop

In [None]:
# Train for up to EPOCHS epochs, validating after each one.
# `last_epoch` holds the index of the final epoch that actually ran, both when
# the loop completes normally and when the early stopper cuts it short.
last_epoch = 0
for epoch in range(EPOCHS):
    last_epoch = epoch

    train_loss = utility.training.train_epoch(
        module=base_model,
        train_dl=train_loader,
        optimizer=optimizer,
        loss_function=cross_entropy,
        device=device,
    )

    valid_loss, valid_accuracy = utility.training.validate(
        module=base_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"Epoch: {epoch}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    # A falsy early_stopper (e.g. None) short-circuits this check, so training
    # then always runs the full EPOCHS.
    if early_stopper and early_stopper.early_stop(valid_loss):
        print("Early stopping")
        break

In [None]:
# Evaluate the trained model once on the held-out test split.
test_loss, accuracy = utility.training.test(
    base_model,
    test_dl=test_loader,
    loss_function=cross_entropy,
    device=device,
)
# NOTE(review): the "%" suffix assumes `test` returns accuracy already scaled
# to 0-100; if it returns a 0-1 fraction this line under-reports by 100x.
test_summary = f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n"
print(test_summary)

In [None]:
# Metadata describing the unpruned baseline model: nothing pruned, no pruning
# method, and both epoch counters set to EPOCHS.
# NOTE(review): total_epochs / early_stopping are hardcoded and would not
# reflect an early stop if an early_stopper were configured — confirm intent.
baseline_metadata_fields = dict(
    method=None,
    total_pruned=0,
    pruning_step=0,
    total_epochs=EPOCHS,
    finetune_epochs=EPOCHS,
    early_stopping=False,
)
metadata = PruningMetadata(**baseline_metadata_fields)

In [None]:
# Persist the baseline model under "<ModelClass>_cifar10", both as the file
# stem inside MODELS_PATH and as the recorded model name, with its metadata.
baseline_model_id = f"{type(base_model).__name__}_cifar10"
utility.save.save_model_with_metadata(
    base_model,
    path=f"{MODELS_PATH}/{baseline_model_id}",
    model_name=baseline_model_id,
    metadata=asdict(metadata),
)