In [9]:
import torch

import torch.nn as nn
import torch.optim as optim

from utility.cifar_dataset import get_dataloaders
import utility

from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE, EPOCHS
from pruning_metadata import PruningMetadata

from dataclasses import asdict

### Load the data

In [10]:
# Build the three data loaders (train / validation / test) from the dataset
# stored under DATA_PATH, batched with the batch size that LeNet exports.
loaders = get_dataloaders(data_path=DATA_PATH, batch_size=BATCH_SIZE)
train_loader, validation_loader, test_loader = loaders

Files already downloaded and verified
Files already downloaded and verified


## Training Phase

In [11]:
# Select the training device: the CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Only query the CUDA runtime on this branch: current_device() /
    # get_device_name() raise when no CUDA device is present.
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
else:
    device = torch.device("cpu")
    device_name = "CPU"
print(f"Using {device_name}")

Using NVIDIA GeForce GTX 1660 Ti


In [12]:
# Instantiate the network, then move it to the selected device before the
# optimizer is created from its parameters.
base_model = LeNet()
base_model = base_model.to(device)

# Optimizer (AdamW with its default hyperparameters) and the training loss.
optimizer = optim.AdamW(base_model.parameters())
cross_entropy = nn.CrossEntropyLoss()

# Early stopping is disabled for this run; the training loop skips the
# early-stop check while this is None.
early_stopper = None

### Training loop

In [13]:
# Train for up to EPOCHS epochs, running a validation pass after each one.
# `last_epoch` holds the zero-based index of the most recent epoch that ran,
# whether the loop runs to completion or is cut short by early stopping.
last_epoch = 0
for epoch in range(EPOCHS):
    # Record progress on every iteration; previously this was only assigned in
    # the early-stop branch, so a full run left it stuck at 0.
    last_epoch = epoch

    # One optimisation pass over the training loader; the returned value is
    # reported below as the training loss.
    train_loss = utility.training.train_epoch(
        module=base_model,
        train_dl=train_loader,
        optimizer=optimizer,
        loss_function=cross_entropy,
        device=device,
    )

    # Evaluate on the validation loader with the same loss function.
    valid_loss, valid_accuracy = utility.training.validate(
        module=base_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    print(
        f"Epoch: {epoch:}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    # The early-stop check is skipped entirely when no stopper is configured.
    if early_stopper and early_stopper.early_stop(valid_loss):
        print("Early stopping")
        break

Epoch: 0
Train Loss: 1.7302, Valid Loss: 1.5031, Valid Accuracy: 0.4582
Epoch: 1
Train Loss: 1.4363, Valid Loss: 1.3837, Valid Accuracy: 0.5046
Epoch: 2
Train Loss: 1.3177, Valid Loss: 1.2931, Valid Accuracy: 0.5347
Epoch: 3
Train Loss: 1.2223, Valid Loss: 1.2232, Valid Accuracy: 0.5621
Epoch: 4
Train Loss: 1.1559, Valid Loss: 1.2106, Valid Accuracy: 0.5704
Epoch: 5
Train Loss: 1.0973, Valid Loss: 1.1656, Valid Accuracy: 0.5905
Epoch: 6
Train Loss: 1.0502, Valid Loss: 1.1290, Valid Accuracy: 0.6022
Epoch: 7
Train Loss: 1.0040, Valid Loss: 1.1098, Valid Accuracy: 0.6120
Epoch: 8
Train Loss: 0.9568, Valid Loss: 1.0974, Valid Accuracy: 0.6130
Epoch: 9
Train Loss: 0.9189, Valid Loss: 1.0935, Valid Accuracy: 0.6185
Epoch: 10
Train Loss: 0.8843, Valid Loss: 1.0899, Valid Accuracy: 0.6237
Epoch: 11
Train Loss: 0.8482, Valid Loss: 1.0908, Valid Accuracy: 0.6223
Epoch: 12
Train Loss: 0.8113, Valid Loss: 1.0667, Valid Accuracy: 0.6310
Epoch: 13
Train Loss: 0.7824, Valid Loss: 1.0607, Valid Accur

In [14]:
# Evaluate the trained model on the test loader and report the two values
# returned by utility.training.test (loss first, accuracy second).
test_metrics = utility.training.test(
    base_model,
    test_dl=test_loader,
    loss_function=cross_entropy,
    device=device,
)
test_loss, accuracy = test_metrics
print(f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n")

Test Error: 
 Accuracy: 64.2%, Avg loss: 1.128266 



In [15]:
# Metadata for the unpruned baseline model: nothing pruned, pruning step 0,
# no pruning method, and both epoch counters set to the configured EPOCHS.
baseline_fields = dict(
    total_pruned=0,
    pruning_step=0,
    finetune_epochs=EPOCHS,
    total_epochs=EPOCHS,
    method=None,
    early_stopping=False,
)
metadata = PruningMetadata(**baseline_fields)

In [16]:
# Persist the trained model together with its metadata. Both the target path
# (under MODELS_PATH) and the model name are derived from the model's class name.
save_path = f"{MODELS_PATH}/{type(base_model).__name__}_cifar10"
save_name = f"{type(base_model).__name__}_cifar10"
metadata_dict = asdict(metadata)  # dataclass instance -> plain dict

utility.save.save_model_with_metadata(
    base_model,
    path=save_path,
    model_name=save_name,
    metadata=metadata_dict,
)