In [1]:
import torch

import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

import utility

from constants import DATA_PATH, MODELS_PATH
from LeNet import LeNet, BATCH_SIZE, LEARNING_RATE, EPOCHS, MOMENTUM
from pruning_metadata import PruningMetadata

from dataclasses import asdict

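## Load the data

The official FashionMNIST training set is split 80/20 into training and validation subsets; the official test set is kept aside for the final evaluation.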

In [2]:
# Load FashionMNIST; every image is converted to a tensor and nothing else.
transform = transforms.Compose([transforms.ToTensor()])

# Fix the randomness so a fresh kernel reproduces the same split.
# Seeding torch's global RNG also makes later random draws repeatable
# (e.g. LeNet weight initialisation and the shuffling train DataLoader).
SEED = 42
torch.manual_seed(SEED)

# Carve an 80/20 train/validation split out of the official training set.
# An explicit generator keeps the split independent of any other RNG use.
full_train_ds = datasets.FashionMNIST(
    DATA_PATH, train=True, transform=transform, download=True
)
train_ds, valid_ds = random_split(
    full_train_ds, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED)
)

# The official test split is left untouched for the final evaluation.
test_ds = datasets.FashionMNIST(
    DATA_PATH, train=False, transform=transform, download=True
)

## Training Phase

In [3]:
# Pick the training device: a CUDA GPU if present, then Apple-silicon MPS,
# otherwise the CPU. The CUDA device name is only queried when CUDA is
# actually in use, so the fallback paths no longer crash.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    device_name = "Apple MPS"
else:
    device = torch.device("cpu")
    device_name = "CPU"
print(f"Using {device_name}")

Using NVIDIA GeForce GTX 1660 Ti


In [4]:
# Baseline network, placed on the device chosen above.
base_model = LeNet().to(device)

# Objective, early-stopping policy and optimiser for the training loop.
cross_entropy = nn.CrossEntropyLoss()
early_stopper = utility.early_stopping.EarlyStopper(patience=3, min_delta=0)
optimizer = optim.SGD(
    base_model.parameters(),
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
)

# Only the training loader shuffles; the evaluation loaders keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
validation_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE) for split in (valid_ds, test_ds)
)

### Training loop

In [5]:
# Train for up to EPOCHS epochs, stopping early when the stopper fires.
# `last_epoch` is the zero-based index of the final epoch that actually ran.
# It starts at -1 ("no epoch yet"), so `last_epoch + 1` is always the number
# of epochs trained, whether or not early stopping triggered.
last_epoch = -1
for epoch in range(EPOCHS):
    train_loss = utility.training.train_epoch(
        module=base_model,
        train_dl=train_loader,
        optimizer=optimizer,
        loss_function=cross_entropy,
        device=device,
    )

    valid_loss, valid_accuracy = utility.training.validate(
        module=base_model,
        valid_dl=validation_loader,
        loss_function=cross_entropy,
        device=device,
    )

    # Record every completed epoch. Previously this was only set on early
    # stopping, so a full run was recorded as a single epoch.
    last_epoch = epoch

    print(
        f"Epoch: {epoch}\nTrain Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}"
    )

    if early_stopper.early_stop(valid_loss):
        print("Early stopping")
        break

Epoch: 0
Train Loss: 0.7960, Valid Loss: 0.5905, Valid Accuracy: 0.7706
Epoch: 1
Train Loss: 0.4431, Valid Loss: 0.4164, Valid Accuracy: 0.8509
Epoch: 2
Train Loss: 0.3780, Valid Loss: 0.3921, Valid Accuracy: 0.8569
Epoch: 3
Train Loss: 0.3392, Valid Loss: 0.3322, Valid Accuracy: 0.8805
Epoch: 4
Train Loss: 0.3150, Valid Loss: 0.3353, Valid Accuracy: 0.8801
Epoch: 5
Train Loss: 0.2963, Valid Loss: 0.3178, Valid Accuracy: 0.8785
Epoch: 6
Train Loss: 0.2809, Valid Loss: 0.2980, Valid Accuracy: 0.8913
Epoch: 7
Train Loss: 0.2685, Valid Loss: 0.3011, Valid Accuracy: 0.8932
Epoch: 8
Train Loss: 0.2554, Valid Loss: 0.2769, Valid Accuracy: 0.9007
Epoch: 9
Train Loss: 0.2464, Valid Loss: 0.2784, Valid Accuracy: 0.9024
Epoch: 10
Train Loss: 0.2367, Valid Loss: 0.2783, Valid Accuracy: 0.8999
Epoch: 11
Train Loss: 0.2264, Valid Loss: 0.2837, Valid Accuracy: 0.8972
Early stopping


In [6]:
# Evaluate the trained baseline once on the held-out FashionMNIST test split.
test_results = utility.training.test(
    base_model,
    test_dl=test_loader,
    loss_function=cross_entropy,
    device=device,
)
test_loss, accuracy = test_results
test_report = f"Test Error: \n Accuracy: {accuracy:>0.1f}%, Avg loss: {test_loss:>8f} \n"
print(test_report)

Test Error: 
 Accuracy: 89.0%, Avg loss: 0.302961 



In [7]:
# Describe how this baseline was produced: nothing is pruned yet, and the
# fine-tune count is the number of epochs trained (last_epoch is zero-based).
epochs_trained = last_epoch + 1
metadata = PruningMetadata(
    method=None,
    total_pruned=0,
    pruning_step=0,
    finetune_epochs=epochs_trained,
    early_stopping=True,
)

In [9]:
# Persist the baseline together with its metadata. The model name is derived
# from the class name once and reused for both the directory and the name.
model_name = f"{type(base_model).__name__}_fmnist"
utility.save.save_model_with_metadata(
    base_model,
    path=f"{MODELS_PATH}/{model_name}",
    model_name=model_name,
    metadata=asdict(metadata),
)