In [None]:
from torch import nn

class ConvRelu(nn.Module):
    """A 2-D convolution followed by a ReLU activation.

    All positional and keyword arguments are forwarded unchanged to
    ``nn.Conv2d``. The convolution weight is initialised with Kaiming-normal
    (ReLU gain); the bias, when the layer has one, is initialised to zero.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwargs)
        self.relu = nn.ReLU()
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')
        # ``bias=False`` may be forwarded to Conv2d, in which case there is
        # no bias tensor to initialise.
        if self.conv.bias is not None:
            nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        """Return ``relu(conv(x))``."""
        return self.relu(self.conv(x))


class LinerRelu(nn.Module):
    """A fully connected layer followed by ReLU and then dropout.

    Args:
        *args: Positional arguments forwarded to ``nn.Linear``.
        dropout: Probability passed to ``nn.Dropout`` (default 0.5).
        **kwargs: Keyword arguments forwarded to ``nn.Linear``.

    The linear weight is initialised with Xavier-normal; the bias, when the
    layer has one, is initialised to zero.
    """

    def __init__(self, *args, dropout=0.5, **kwargs):
        super().__init__()
        self.lin = nn.Linear(*args, **kwargs)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        nn.init.xavier_normal_(self.lin.weight)
        # ``bias=False`` may be forwarded to Linear, in which case there is
        # no bias tensor to initialise.
        if self.lin.bias is not None:
            nn.init.zeros_(self.lin.bias)

    def forward(self, x):
        """Return ``dropout(relu(lin(x)))``."""
        return self.dropout(self.relu(self.lin(x)))

In [None]:
class AlexNet(nn.Module):
    """AlexNet-style CNN: five conv layers, three max-pools, three FC layers.

    The first fully connected layer expects ``6 * 6 * 256`` features, which is
    what the convolutional stack produces for a 227x227 input
    (227 -> 55 -> 27 -> 27 -> 13 -> 13 -> 6).

    Args:
        in_channels: Number of channels of the input image (default 1).
        num_classes: Size of the output logit vector (default 10).
        dropout: Dropout probability of the two hidden FC layers (default 0.5).

    ``forward`` returns raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, in_channels=1, num_classes=10, dropout=0.5):
        super().__init__()
        self.c1 = ConvRelu(in_channels, 96, kernel_size=11, stride=4)
        self.p1 = nn.MaxPool2d(3, 2)
        self.c2 = ConvRelu(96, 256, kernel_size=5, padding=2)
        self.p2 = nn.MaxPool2d(3, 2)
        self.c3 = ConvRelu(256, 384, kernel_size=3, padding=1)
        self.c4 = ConvRelu(384, 384, kernel_size=3, padding=1)
        self.c5 = ConvRelu(384, 256, kernel_size=3, padding=1)
        self.p3 = nn.MaxPool2d(3, 2)
        self.flatten = nn.Flatten()
        self.l1 = LinerRelu(6 * 6 * 256, 4096, dropout=dropout)
        self.l2 = LinerRelu(4096, 4096, dropout=dropout)
        # The classification head must emit unconstrained logits: putting
        # ReLU + Dropout after it would clip negative logits and randomly
        # zero class scores before the loss sees them.
        self.l3 = nn.Linear(4096, num_classes)
        nn.init.xavier_normal_(self.l3.weight)
        nn.init.zeros_(self.l3.bias)

    def forward(self, x):
        """Run the conv stack, flatten, and return the class logits."""
        x = self.p1(self.c1(x))
        x = self.p2(self.c2(x))
        x = self.p3(self.c5(self.c4(self.c3(x))))
        x = self.flatten(x)
        x = self.l3(self.l2(self.l1(x)))
        return x

In [None]:
from torchinfo import summary
# Build the full-size network and show its layer summary for one
# single-channel 227x227 image.
model = AlexNet()
summary(model, input_size=(1, 1, 227, 227))

In [None]:
class AlexNetSmall(nn.Module):
    """AlexNet convolutional stack with a much smaller fully connected head.

    The first fully connected layer expects 256 features, which is what the
    convolutional stack produces for a 67x67 input
    (67 -> 15 -> 7 -> 7 -> 3 -> 3 -> 1, i.e. a 1x1x256 feature map).

    Args:
        dropout: Dropout probability of the two hidden FC layers (default 0.5).
        in_channels: Number of channels of the input image (default 1).
        num_classes: Size of the output logit vector (default 10).

    ``forward`` returns raw logits of shape ``(batch, num_classes)``.
    """

    def __init__(self, dropout=0.5, in_channels=1, num_classes=10):
        super().__init__()
        self.c1 = ConvRelu(in_channels, 96, kernel_size=11, stride=4)
        self.p1 = nn.MaxPool2d(3, 2)
        self.c2 = ConvRelu(96, 256, kernel_size=5, padding=2)
        self.p2 = nn.MaxPool2d(3, 2)
        self.c3 = ConvRelu(256, 384, kernel_size=3, padding=1)
        self.c4 = ConvRelu(384, 384, kernel_size=3, padding=1)
        self.c5 = ConvRelu(384, 256, kernel_size=3, padding=1)
        self.p3 = nn.MaxPool2d(3, 2)
        self.flatten = nn.Flatten()
        self.l1 = LinerRelu(256, 128, dropout=dropout)
        self.l2 = LinerRelu(128, 128, dropout=dropout)
        # The classification head must emit unconstrained logits: putting
        # ReLU + Dropout after it would clip negative logits and randomly
        # zero class scores before the loss sees them.
        self.l3 = nn.Linear(128, num_classes)
        nn.init.xavier_normal_(self.l3.weight)
        nn.init.zeros_(self.l3.bias)

    def forward(self, x):
        """Run the conv stack, flatten, and return the class logits."""
        x = self.p1(self.c1(x))
        x = self.p2(self.c2(x))
        x = self.p3(self.c5(self.c4(self.c3(x))))
        x = self.flatten(x)
        x = self.l3(self.l2(self.l1(x)))
        return x

In [None]:
# Build the reduced network and show its layer summary for one
# single-channel 67x67 image.
model_small = AlexNetSmall()
summary(model_small, input_size=(1, 1, 67, 67))

In [None]:
# Execute utils.py with IPython's %run magic; later cells also import
# helpers from this module.
%run utils.py

In [None]:
from torchvision import datasets, transforms
from utils import train_val_split

# Fetch the official FashionMNIST train and test splits (downloaded into
# ./data on first use).
full, test = (
    datasets.FashionMNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

# Carve a validation subset out of the official training split.
train, valid = train_val_split(full, seed=666)

print(len(train), len(valid), len(test))

In [None]:
from utils import PackDataset

# Resize every image to 67 pixels and convert it to a tensor.
trans = transforms.Compose([
    transforms.Resize(size=67),
    transforms.ToTensor(),
])

# Wrap each split so the transform above is applied to it.
train_data, valid_data, test_data = (
    PackDataset(split, transform=trans) for split in (train, valid, test)
)

# Sanity check: print the size of the first transformed training image.
image, label = train_data[0]
print(image.size())

In [None]:
from skorch.callbacks import EarlyStopping, Checkpoint, EpochScoring, LRScheduler, ProgressBar
from torch.optim.lr_scheduler import CosineAnnealingLR

def control_callbacks(
        epochs, show_bar=True,
        model_name='best_model.pt', check_dir='./data/checkpoints',
        patience=6,
    ):
    """Build the list of skorch callbacks used for a training run.

    Args:
        epochs: Number of epochs of the run; used as ``T_max`` of the cosine
            annealing learning-rate schedule.
        show_bar: When true, a ``ProgressBar`` is placed first in the list.
        model_name: File name (``f_params``) of the saved best parameters.
        check_dir: Directory the checkpoint is written to.
        patience: Number of epochs without ``valid_acc`` improvement after
            which early stopping triggers (default 6).

    Returns:
        A list of freshly constructed callbacks, in the order: optional
        progress bar, LR scheduler, train-accuracy scoring, checkpoint,
        early stopping.
    """
    calls = [ProgressBar()] if show_bar else []
    calls.extend([
        LRScheduler(policy=CosineAnnealingLR, T_max=epochs),
        EpochScoring(name='train_acc', scoring='accuracy', on_train=True),
        Checkpoint(
            dirname=check_dir, f_params=model_name,
            monitor='valid_acc_best', load_best=True
        ),
        # EarlyStopping goes last: it ends training by interrupting the
        # epoch-end callback chain, so any callback listed after it would
        # miss its hook on the final epoch (e.g. 'train_acc' not recorded).
        EarlyStopping(monitor='valid_acc', lower_is_better=False, patience=patience),
    ])
    return calls

In [None]:
import torch
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

# Seed torch's RNGs so weight initialisation, dropout and batch shuffling are
# repeatable across Restart & Run All (same seed value as the train/valid split).
torch.manual_seed(666)

epochs = 50
calls = control_callbacks(epochs, check_dir='./data/alex-checkpoints')
net = NeuralNetClassifier(
    AlexNetSmall,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    lr=0.001,
    batch_size=2048,
    max_epochs=epochs,
    # Validate on the fixed, pre-split validation set instead of an internal split.
    train_split=predefined_split(valid_data),
    # skorch does not shuffle the training iterator by default; without this
    # every epoch would see exactly the same batches in the same order.
    iterator_train__shuffle=True,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    callbacks=calls,
    classes=list(range(10)),
)
# Labels come from the dataset itself, hence y=None.
net.fit(X=train_data, y=None)

In [None]:
import matplotlib.pyplot as plt

def plot_history(net):
    """Plot the loss and accuracy curves stored in a fitted net's history.

    Draws two side-by-side panels from ``net.history``: train/valid loss on
    the left and train/valid accuracy on the right, then shows the figure.

    Args:
        net: Object whose ``history`` supports ``history[:, key]`` for the
            keys 'train_loss', 'valid_loss', 'train_acc' and 'valid_acc'.
    """
    history = net.history
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 6))

    ax_loss.plot(history[:, 'train_loss'], label='Train Loss', linewidth=3)
    ax_loss.plot(history[:, 'valid_loss'], label='Valid Loss', linewidth=3)
    ax_loss.set_xlabel('Epoch', fontsize=14)
    ax_loss.set_ylabel('Loss', fontsize=14)
    ax_loss.set_title('Training & Validation Loss', fontsize=16)
    # Set tick sizes on the axes object itself: plt.xticks/plt.yticks only
    # affect the *current* axes, which after plt.subplots is the right panel.
    ax_loss.tick_params(axis='both', labelsize=14)
    ax_loss.legend()

    ax_acc.plot(history[:, 'train_acc'], label='Train Accuracy', linewidth=3)
    ax_acc.plot(history[:, 'valid_acc'], label='Valid Accuracy', linewidth=3)
    ax_acc.set_xlabel('Epoch', fontsize=14)
    # The plotted values are taken from the history as-is; no percent scaling
    # is applied here, so the label does not claim a "%" unit.
    ax_acc.set_ylabel('Accuracy', fontsize=14)
    ax_acc.set_title('Training & Validation Accuracy', fontsize=16)
    ax_acc.tick_params(axis='both', labelsize=14)
    ax_acc.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# Plot the loss and accuracy curves recorded in the fitted net's history.
plot_history(net)

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
import seaborn as sns
import numpy as np

def check_result(net, test_data):
    """Evaluate ``net`` on ``test_data`` and collect the misclassified samples.

    Prints the test accuracy and the confusion matrix, draws the confusion
    matrix as a heatmap, and prints the number of errors.

    Args:
        net: Classifier exposing ``predict`` and ``predict_proba``.
        test_data: Indexable, iterable dataset yielding ``(features, label)``.

    Returns:
        A list with one dict per misclassified sample, holding the keys
        "features", "true_label", "pred_label" and "probabilities".
    """
    predicted = net.predict(test_data)
    probabilities = net.predict_proba(test_data)
    # Ground-truth labels are read back by iterating over the dataset.
    actual = np.array([label for _, label in iter(test_data)])
    separator = '=' * 100

    test_accuracy = accuracy_score(actual, predicted)
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(separator)

    cm = confusion_matrix(actual, predicted)
    print("Confusion Matrix:\n", cm)
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", annot_kws={"size": 10})
    plt.xlabel("Predicted Label", fontsize=14)
    plt.ylabel("True Label", fontsize=14)
    plt.title("Confusion Matrix (Test Set)", fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
    plt.show()
    print(separator)

    # Positions where the prediction disagrees with the ground truth.
    mismatches = np.nonzero(predicted != actual)[0]
    error_list = [
        {
            "features": test_data[idx][0],
            "true_label": int(actual[idx]),
            "pred_label": int(predicted[idx]),
            "probabilities": probabilities[idx],
        }
        for idx in mismatches
    ]

    print(f'error number: {len(error_list)}')
    return error_list

In [None]:
# Keep the misclassified samples in a variable: as a bare last expression the
# whole list (one dict with an image tensor per error) would be dumped into
# the cell output.
test_errors = check_result(net, test_data)

In [None]:
from sklearn.model_selection import ParameterGrid


# Small manual grid search: train one net per parameter combination and keep
# track of the combination with the best validation accuracy.
epochs = 1
param_grid = {
    'lr': [0.01, 0.005],
    'batch_size': [2048],
}

results = {
    'best_params': None,
    'best_acc': 0.0,
    'all_results': []
}

for params in ParameterGrid(param_grid):
    print(f"\nTraining with params: {params}")

    # Build fresh callbacks for every candidate instead of reusing the ones
    # created for the main run: the cosine schedule's T_max must match this
    # run's epoch count, and each candidate gets its own checkpoint directory
    # so it cannot overwrite the main run's saved best model.
    grid_calls = control_callbacks(
        epochs,
        check_dir=f"./data/alex-grid-checkpoints/lr{params['lr']}_bs{params['batch_size']}",
    )
    net = NeuralNetClassifier(
        AlexNetSmall,
        criterion=nn.CrossEntropyLoss,
        optimizer=torch.optim.Adam,
        lr=params['lr'],
        batch_size=params['batch_size'],
        max_epochs=epochs,
        train_split=predefined_split(valid_data),
        # skorch does not shuffle the training iterator by default.
        iterator_train__shuffle=True,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        callbacks=grid_calls,
        classes=list(range(10)),
    )
    net.fit(X=train_data, y=None)
    valid_acc = max(net.history[:, 'valid_acc'])
    current_result = {'params': params, 'valid_acc': valid_acc}
    results['all_results'].append(current_result)

    if valid_acc > results['best_acc']:
        results['best_acc'] = valid_acc
        results['best_params'] = params

    # Running best after each candidate.
    print(f"\nBest params: {results['best_params']}, best acc: {results['best_acc']}")