## 0 Import Modules

In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [2]:
from sklearn.metrics import accuracy_score
from tqdm import tqdm

In [3]:
from neumeta.models.lenet import MnistNet

## 1 Functions

### 1.1 Training Loop Function

In [4]:
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader`` and return the mean training loss.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: iterable of ``(inputs, targets)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(predictions, targets)``. Assumed to
            use mean reduction (the ``nn.CrossEntropyLoss`` default) — each batch
            loss is re-weighted by batch size below.
        device: device every batch is moved to before the forward pass.

    Returns:
        Training loss averaged per sample over the whole epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    num_samples = 0
    for x, target in tqdm(train_loader):
        x, target = x.to(device), target.to(device)

        optimizer.zero_grad()
        predict = model(x)
        loss = criterion(predict, target)
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a smaller final batch is not
        # over-counted (averaging batch means would bias the epoch loss).
        # Assumes targets are batch-first, e.g. 1-D class labels.
        n_batch = target.size(0)
        loss_sum += loss.item() * n_batch
        num_samples += n_batch

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")
    return loss_sum / num_samples

### 1.2 Validation Function

In [5]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        val_loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(predictions, targets)``. Assumed to
            use mean reduction (the ``nn.CrossEntropyLoss`` default) — each batch
            loss is re-weighted by batch size below.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged per sample, and the fraction
        of samples (in ``[0, 1]``) whose argmax over the last dimension of the
        model output equals the target.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for x, target in val_loader:
            x, target = x.to(device), target.to(device)
            predict = model(x)

            # Weight by batch size so the (possibly smaller) last batch counts
            # proportionally; averaging batch means would bias the result.
            n_batch = target.size(0)
            loss_sum += criterion(predict, target).item() * n_batch
            num_correct += (predict.argmax(dim=-1) == target).sum().item()
            num_samples += n_batch

    if num_samples == 0:
        raise ValueError("val_loader yielded no samples")
    return loss_sum / num_samples, num_correct / num_samples


## 2 Training LeNet (hidden_dim = 32)

### 2.0 Device

In [6]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### 2.1 Hyperparameters

In [7]:
learning_rate = 1e-3
batch_size = 128
num_epochs = 10
hidden_dim = 32

# Seed torch's global RNG so that anything drawing from it later in the notebook
# (model weight initialisation, DataLoader shuffling) is repeatable across runs.
# Some CUDA kernels remain nondeterministic even with a fixed seed.
seed = 42
torch.manual_seed(seed)

### 2.2 Data Preparation

In [8]:
# ToTensor yields pixel values in [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Only the training split requests a download. In this cell's recorded output,
# that single call fetched all four MNIST files (train and test), which is why
# the held-out split below can omit download=True.
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
val_dataset = datasets.MNIST(
    root='./data', train=False, transform=transform
)

100%|██████████| 9.91M/9.91M [00:12<00:00, 785kB/s] 
100%|██████████| 28.9k/28.9k [00:00<00:00, 94.8kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 972kB/s] 
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.32MB/s]


In [9]:
# Shuffle the training split each epoch; keep the evaluation split in a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)

### 2.3 Create Model

In [10]:
# Build the network with the configured hidden width, then place it on the chosen device.
model = MnistNet(hidden_dim=hidden_dim)
model = model.to(device)

### 2.4 Optimizer and Criterion

In [11]:
# CrossEntropyLoss expects unnormalised class scores; presumably MnistNet
# returns logits — confirm against neumeta.models.lenet.
criterion = nn.CrossEntropyLoss()
# Adam over all parameters of the freshly built model.
optimizer = Adam(params=model.parameters(), lr=learning_rate)

### 2.5 Training and Validation Loop

In [12]:
# Run validation every `eval_every` epochs (1 = after every epoch).
eval_every = 1

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device=device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}")

    if (epoch + 1) % eval_every == 0:
        val_loss, acc = validate(model, val_loader, criterion, device=device)
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")

# torch.save does not create missing parent directories, so make sure toy/
# exists; otherwise a fresh checkout fails here after the whole training run.
checkpoint_path = Path("toy") / f"mnist_{model.__class__.__name__}_dim{hidden_dim}.pth"
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)
print("Training finished.")

100%|██████████| 469/469 [00:10<00:00, 44.46it/s]


Epoch [1/10], Training Loss: 0.5588
Epoch [1/10], Validation Loss: 0.2161, Validation Accuracy: 93.06%


100%|██████████| 469/469 [00:09<00:00, 47.45it/s]


Epoch [2/10], Training Loss: 0.1902
Epoch [2/10], Validation Loss: 0.1340, Validation Accuracy: 95.85%


100%|██████████| 469/469 [00:10<00:00, 46.39it/s]


Epoch [3/10], Training Loss: 0.1400
Epoch [3/10], Validation Loss: 0.1131, Validation Accuracy: 96.53%


100%|██████████| 469/469 [00:10<00:00, 46.50it/s]


Epoch [4/10], Training Loss: 0.1149
Epoch [4/10], Validation Loss: 0.0930, Validation Accuracy: 97.14%


100%|██████████| 469/469 [00:09<00:00, 49.55it/s]


Epoch [5/10], Training Loss: 0.0979
Epoch [5/10], Validation Loss: 0.0852, Validation Accuracy: 97.32%


100%|██████████| 469/469 [00:09<00:00, 48.65it/s]


Epoch [6/10], Training Loss: 0.0858
Epoch [6/10], Validation Loss: 0.0772, Validation Accuracy: 97.74%


100%|██████████| 469/469 [00:09<00:00, 47.72it/s]


Epoch [7/10], Training Loss: 0.0737
Epoch [7/10], Validation Loss: 0.0760, Validation Accuracy: 97.71%


100%|██████████| 469/469 [00:09<00:00, 48.67it/s]


Epoch [8/10], Training Loss: 0.0684
Epoch [8/10], Validation Loss: 0.0673, Validation Accuracy: 97.93%


100%|██████████| 469/469 [00:09<00:00, 47.66it/s]


Epoch [9/10], Training Loss: 0.0609
Epoch [9/10], Validation Loss: 0.0633, Validation Accuracy: 98.22%


100%|██████████| 469/469 [00:09<00:00, 47.69it/s]


Epoch [10/10], Training Loss: 0.0555
Epoch [10/10], Validation Loss: 0.0621, Validation Accuracy: 98.14%
Training finished.


## 3 Testing: Reload the Checkpoint and Evaluate on the MNIST Test Split

In [14]:
hidden_dim = 32
batch_size = 128

# Only the held-out split is needed for evaluation; reuse the training-time
# preprocessing. download=True lets this section run even if section 2.2 has
# not populated ./data yet.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
val_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

model = MnistNet(hidden_dim=hidden_dim).to(device)
checkpoint_path = Path("toy") / f"mnist_{model.__class__.__name__}_dim{hidden_dim}.pth"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain containers
# (requires torch >= 1.13).
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

criterion = torch.nn.CrossEntropyLoss()
val_loss, acc = validate(model, val_loader, criterion, device=device)
print(f"Test on MNIST, Validation Loss: {val_loss:.4f}, Validation Accuracy: {acc*100:.2f}%")

Test on MNIST, Validation Loss: 0.0621, Validation Accuracy: 98.14%
