In [1]:
import torch
from torch import nn
from torch.utils import data
from tqdm import tqdm
from utils.fashion_mnist import load_data_fashion_mnist

In [2]:
# Baseline LeNet-style classifier: two conv/sigmoid/average-pool stages
# followed by a three-layer fully connected head producing 10 outputs.
net = nn.Sequential(
    # Stage 1: 5x5 convolution with padding 2, then 2x2 average pooling.
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Stage 2: unpadded 5x5 convolution, then 2x2 average pooling.
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.Sigmoid(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # Classifier head on the flattened 16 x 5 x 5 feature map.
    nn.Flatten(),
    nn.Linear(in_features=16 * 5 * 5, out_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [3]:
# Feed one random 1x28x28 dummy image through the network a layer at a time
# and print the tensor shape produced by each layer.
X = torch.rand(1, 1, 28, 28, dtype=torch.float32)
for layer in net:
    X = layer(X)
    print(type(layer).__name__, 'output shape: \t', X.shape)

Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
AvgPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
AvgPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])


In [2]:
# Mini-batch size handed to the Fashion-MNIST loader helper.
# NOTE(review): the first recorded training log in this notebook shows 938
# batches per epoch while the second shows 235, so batch_size was presumably
# changed between those runs — restart and run all cells to refresh outputs.
batch_size = 256
# load_data_fashion_mnist is a project helper (utils.fashion_mnist); it is
# unpacked here into a training iterable and a test iterable — presumably
# DataLoaders, confirm against the helper's definition.
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [3]:
def evaluate(model: nn.Module, dataloader: data.DataLoader, device: str) -> float:
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is moved to ``device`` and switched to evaluation mode for the
    duration of the pass; whatever train/eval mode it was in on entry is
    restored before returning, even if the pass raises.

    Args:
        model: Network mapping a batch ``X`` to per-class scores; the
            prediction is the argmax over dimension 1.
        dataloader: Iterable yielding ``(X, y)`` batches, where ``y`` holds the
            target class index of each sample.
        device: Device the model and every batch are moved to.

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``, or ``0.0``
        when ``dataloader`` yields no samples.
    """
    model.to(device)
    # Remember the caller's mode so evaluation does not silently flip a model
    # that was deliberately left in eval mode back into training mode.
    was_training = model.training
    model.eval()

    total = 0
    correct = 0
    try:
        # No gradients are needed for scoring; skipping autograd bookkeeping
        # saves memory and time.
        with torch.no_grad():
            for X, y in tqdm(dataloader):
                X, y = X.to(device), y.to(device)
                predict = torch.argmax(model(X), dim=1)
                correct += (predict == y).sum().item()
                total += y.shape[0]
    finally:
        model.train(was_training)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

In [4]:
def train(
    model: nn.Module,
    dataloader: data.DataLoader,
    optim: torch.optim.Optimizer,
    criterion: nn.Module,
    test_data: data.DataLoader,
    num_epochs: int = 10,
    device: str = "cuda",
) -> None:
    """Train ``model`` for ``num_epochs`` epochs, reporting metrics per epoch.

    Each epoch iterates once over ``dataloader`` performing a standard
    zero-grad / forward / backward / step update, shows a progress bar with
    the running mean batch loss, running training accuracy (percent) and the
    learning rate of the first parameter group, then scores the model on
    ``test_data`` via ``evaluate`` and prints a summary line.

    Args:
        model: Network to optimise; moved to ``device`` and put in train mode.
        dataloader: Iterable with a length yielding ``(X, y)`` training batches.
        optim: Optimiser already bound to ``model``'s parameters.
        criterion: Callable ``criterion(y_hat, y)`` returning a scalar loss.
        test_data: Iterable of ``(X, y)`` batches used for the per-epoch
            accuracy passed to ``evaluate``.
        num_epochs: Number of passes over ``dataloader``.
        device: Target device. If a CUDA device is requested but CUDA is not
            available, a warning is issued and training runs on the CPU.
    """
    # The default is "cuda"; fall back rather than crash on CPU-only hosts.
    if str(device).startswith("cuda") and not torch.cuda.is_available():
        import warnings

        warnings.warn(
            f"CUDA device {device!r} requested but CUDA is unavailable; training on CPU instead.",
            RuntimeWarning,
            stacklevel=2,
        )
        device = "cpu"

    model.to(device)
    model.train()

    num_batches = len(dataloader)

    for epoch in range(num_epochs):
        total = 0
        total_loss = 0.0
        total_correct = 0

        progress_bar = tqdm(enumerate(dataloader), total=num_batches)
        # The description only depends on the epoch, so set it once.
        progress_bar.set_description(f"Epoch {epoch}")

        for i, (X, y) in progress_bar:
            X, y = X.to(device), y.to(device)

            optim.zero_grad()
            y_hat = model(X)
            l = criterion(y_hat, y)
            l.backward()
            optim.step()

            total_loss += l.item()
            # Metrics only: detach so the argmax is not tracked by autograd.
            predict = torch.argmax(y_hat.detach(), dim=1)
            total += y.shape[0]
            total_correct += (predict == y).sum().item()

            progress_bar.set_postfix(loss=total_loss/(i+1), accuracy=100.*total_correct/total, Learning_rate=optim.param_groups[0]['lr'])

        test_acc = evaluate(model, test_data, device)
        # Guard the averages so an empty training loader does not divide by zero.
        accuracy = 100 * total_correct / total if total else 0.0
        mean_loss = total_loss / num_batches if num_batches else 0.0
        print(f"Loss: {mean_loss}, Accuracy: {accuracy}%, test_acc: {test_acc}")

    

In [7]:
# Adam over every parameter of the baseline network, step size 1e-3.
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
# Cross-entropy loss applied to the network's raw class scores.
criterion = nn.CrossEntropyLoss()

In [8]:
# Train the baseline network for 20 epochs, scoring on the test split after each.
train(
    model=net,
    dataloader=train_iter,
    optim=optim,
    criterion=criterion,
    test_data=test_iter,
    num_epochs=20,
)

Epoch 0: 100%|██████████| 938/938 [00:05<00:00, 156.76it/s, Learning_rate=0.001, accuracy=55.3, loss=1.18]
100%|██████████| 157/157 [00:00<00:00, 241.52it/s]

Loss: 1.1817171348691748, Accuracy: 55.27166666666667%, test_acc: 0.7019



Epoch 1: 100%|██████████| 938/938 [00:06<00:00, 141.95it/s, Learning_rate=0.001, accuracy=73.8, loss=0.675]
100%|██████████| 157/157 [00:00<00:00, 220.35it/s]

Loss: 0.6748720108827294, Accuracy: 73.81333333333333%, test_acc: 0.7398



Epoch 2: 100%|██████████| 938/938 [00:06<00:00, 139.93it/s, Learning_rate=0.001, accuracy=77.3, loss=0.583]
100%|██████████| 157/157 [00:00<00:00, 227.62it/s]

Loss: 0.5825603211612336, Accuracy: 77.33833333333334%, test_acc: 0.7759



Epoch 3: 100%|██████████| 938/938 [00:05<00:00, 165.52it/s, Learning_rate=0.001, accuracy=79.9, loss=0.529]
100%|██████████| 157/157 [00:00<00:00, 282.26it/s]

Loss: 0.5287880431423818, Accuracy: 79.94666666666667%, test_acc: 0.7964



Epoch 4: 100%|██████████| 938/938 [00:05<00:00, 177.33it/s, Learning_rate=0.001, accuracy=81.8, loss=0.484]
100%|██████████| 157/157 [00:00<00:00, 265.40it/s]

Loss: 0.4839009398590527, Accuracy: 81.78%, test_acc: 0.8175



Epoch 5: 100%|██████████| 938/938 [00:05<00:00, 175.84it/s, Learning_rate=0.001, accuracy=83.2, loss=0.449]
100%|██████████| 157/157 [00:00<00:00, 268.03it/s]

Loss: 0.4493539646299663, Accuracy: 83.18333333333334%, test_acc: 0.8266



Epoch 6: 100%|██████████| 938/938 [00:05<00:00, 170.90it/s, Learning_rate=0.001, accuracy=84.3, loss=0.427]
100%|██████████| 157/157 [00:00<00:00, 265.03it/s]

Loss: 0.4268936789366228, Accuracy: 84.28166666666667%, test_acc: 0.8291



Epoch 7: 100%|██████████| 938/938 [00:05<00:00, 179.59it/s, Learning_rate=0.001, accuracy=84.9, loss=0.407]
100%|██████████| 157/157 [00:00<00:00, 276.87it/s]

Loss: 0.4072182572313717, Accuracy: 84.88666666666667%, test_acc: 0.8395



Epoch 8: 100%|██████████| 938/938 [00:04<00:00, 193.84it/s, Learning_rate=0.001, accuracy=85.5, loss=0.391]
100%|██████████| 157/157 [00:00<00:00, 321.07it/s]

Loss: 0.3912442306688091, Accuracy: 85.50833333333334%, test_acc: 0.8436



Epoch 9: 100%|██████████| 938/938 [00:05<00:00, 170.40it/s, Learning_rate=0.001, accuracy=86.1, loss=0.376]
100%|██████████| 157/157 [00:00<00:00, 326.30it/s]

Loss: 0.37633228705508875, Accuracy: 86.13%, test_acc: 0.8454



Epoch 10: 100%|██████████| 938/938 [00:04<00:00, 196.69it/s, Learning_rate=0.001, accuracy=86.6, loss=0.364]
100%|██████████| 157/157 [00:00<00:00, 270.41it/s]

Loss: 0.36408047044455116, Accuracy: 86.60166666666667%, test_acc: 0.8531



Epoch 11: 100%|██████████| 938/938 [00:04<00:00, 191.54it/s, Learning_rate=0.001, accuracy=86.9, loss=0.352]
100%|██████████| 157/157 [00:00<00:00, 312.40it/s]

Loss: 0.3522199939118265, Accuracy: 86.90666666666667%, test_acc: 0.8537



Epoch 12: 100%|██████████| 938/938 [00:05<00:00, 175.67it/s, Learning_rate=0.001, accuracy=87.2, loss=0.342]
100%|██████████| 157/157 [00:00<00:00, 259.41it/s]

Loss: 0.34245355463803195, Accuracy: 87.24666666666667%, test_acc: 0.8616



Epoch 13: 100%|██████████| 938/938 [00:05<00:00, 182.83it/s, Learning_rate=0.001, accuracy=87.6, loss=0.335]
100%|██████████| 157/157 [00:00<00:00, 357.37it/s]

Loss: 0.3347955585511, Accuracy: 87.605%, test_acc: 0.8651



Epoch 14: 100%|██████████| 938/938 [00:05<00:00, 181.64it/s, Learning_rate=0.001, accuracy=88, loss=0.326]  
100%|██████████| 157/157 [00:00<00:00, 276.94it/s]

Loss: 0.3255427184659662, Accuracy: 87.95%, test_acc: 0.8632



Epoch 15: 100%|██████████| 938/938 [00:05<00:00, 165.48it/s, Learning_rate=0.001, accuracy=88.1, loss=0.32] 
100%|██████████| 157/157 [00:00<00:00, 229.15it/s]

Loss: 0.3196786491792085, Accuracy: 88.14166666666667%, test_acc: 0.8731



Epoch 16: 100%|██████████| 938/938 [00:05<00:00, 183.50it/s, Learning_rate=0.001, accuracy=88.3, loss=0.314]
100%|██████████| 157/157 [00:00<00:00, 260.11it/s]

Loss: 0.3139794432421102, Accuracy: 88.315%, test_acc: 0.8713



Epoch 17: 100%|██████████| 938/938 [00:05<00:00, 175.41it/s, Learning_rate=0.001, accuracy=88.5, loss=0.307]
100%|██████████| 157/157 [00:00<00:00, 276.18it/s]

Loss: 0.3073392922181819, Accuracy: 88.50166666666667%, test_acc: 0.8621



Epoch 18: 100%|██████████| 938/938 [00:05<00:00, 176.86it/s, Learning_rate=0.001, accuracy=88.8, loss=0.301]
100%|██████████| 157/157 [00:00<00:00, 258.79it/s]

Loss: 0.30100927258859567, Accuracy: 88.79166666666667%, test_acc: 0.8735



Epoch 19: 100%|██████████| 938/938 [00:04<00:00, 195.81it/s, Learning_rate=0.001, accuracy=89.1, loss=0.295]
100%|██████████| 157/157 [00:00<00:00, 316.07it/s]

Loss: 0.2950122020423794, Accuracy: 89.085%, test_acc: 0.8743





In [24]:
class ImproveLeNet(nn.Module):
    """LeNet-style CNN with wider convolutions, Tanh / LeakyReLU and max pooling.

    The spatial sizes noted in the comments hold for 1 x 28 x 28 inputs, which
    is what the ``16 * 5 * 5`` input width of the first linear layer requires.
    All layers live in a single flat ``nn.Sequential`` stored as ``self.net``.
    """

    def __init__(self) -> None:
        super().__init__()

        # Convolutional feature extractor.
        conv_layers = [
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, padding=2),   # 28 x 28
            nn.Tanh(),
            nn.Dropout(p=0.2),
            nn.Conv2d(in_channels=32, out_channels=48, kernel_size=3, padding=1),  # 28 x 28
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),                                 # 14 x 14
            nn.Conv2d(in_channels=48, out_channels=16, kernel_size=5),             # 10 x 10
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),                                 # 5 x 5
        ]

        # Fully connected classifier producing 10 class scores.
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=128),
            nn.LeakyReLU(),
            nn.Linear(in_features=128, out_features=64),
            nn.LeakyReLU(),
            nn.Linear(in_features=64, out_features=10),
        ]

        self.net = nn.Sequential(*conv_layers, *head_layers)

    def forward(self, X):
        """Run ``X`` through the whole stack and return the class scores."""
        scores = self.net(X)
        return scores

In [25]:
# Train the improved network with the same optimiser settings and loss as the
# baseline run. Note that this rebinds the notebook-level `optim` and
# `criterion` names.
improveModel = ImproveLeNet()
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(improveModel.parameters(), lr=1e-3)
train(
    model=improveModel,
    dataloader=train_iter,
    optim=optim,
    criterion=criterion,
    test_data=test_iter,
    num_epochs=20,
)

Epoch 0: 100%|██████████| 235/235 [00:02<00:00, 86.71it/s, Learning_rate=0.001, accuracy=76.1, loss=0.658]
100%|██████████| 40/40 [00:00<00:00, 98.35it/s]

Loss: 0.657974954615248, Accuracy: 76.14666666666666%, test_acc: 0.8429



Epoch 1: 100%|██████████| 235/235 [00:02<00:00, 93.80it/s, Learning_rate=0.001, accuracy=86.2, loss=0.371]
100%|██████████| 40/40 [00:00<00:00, 88.78it/s]

Loss: 0.3712228908817819, Accuracy: 86.22666666666667%, test_acc: 0.8663



Epoch 2: 100%|██████████| 235/235 [00:02<00:00, 93.31it/s, Learning_rate=0.001, accuracy=87.9, loss=0.326]
100%|██████████| 40/40 [00:00<00:00, 94.06it/s]

Loss: 0.32564190924167635, Accuracy: 87.86166666666666%, test_acc: 0.8752



Epoch 3: 100%|██████████| 235/235 [00:02<00:00, 92.69it/s, Learning_rate=0.001, accuracy=89.2, loss=0.295]
100%|██████████| 40/40 [00:00<00:00, 96.37it/s]

Loss: 0.29492687811242774, Accuracy: 89.15666666666667%, test_acc: 0.8858



Epoch 4: 100%|██████████| 235/235 [00:02<00:00, 92.61it/s, Learning_rate=0.001, accuracy=90, loss=0.273]  
100%|██████████| 40/40 [00:00<00:00, 102.48it/s]

Loss: 0.2726254943203419, Accuracy: 89.95833333333333%, test_acc: 0.8918



Epoch 5: 100%|██████████| 235/235 [00:02<00:00, 94.80it/s, Learning_rate=0.001, accuracy=90.5, loss=0.254]
100%|██████████| 40/40 [00:00<00:00, 97.42it/s]

Loss: 0.2537701392427404, Accuracy: 90.54333333333334%, test_acc: 0.896



Epoch 6: 100%|██████████| 235/235 [00:02<00:00, 92.79it/s, Learning_rate=0.001, accuracy=90.8, loss=0.244]
100%|██████████| 40/40 [00:00<00:00, 95.39it/s]

Loss: 0.24372950843039978, Accuracy: 90.83833333333334%, test_acc: 0.8922



Epoch 7: 100%|██████████| 235/235 [00:02<00:00, 94.21it/s, Learning_rate=0.001, accuracy=91.3, loss=0.234]
100%|██████████| 40/40 [00:00<00:00, 94.57it/s]

Loss: 0.23355789355775144, Accuracy: 91.28833333333333%, test_acc: 0.9008



Epoch 8: 100%|██████████| 235/235 [00:02<00:00, 93.62it/s, Learning_rate=0.001, accuracy=92.1, loss=0.216]
100%|██████████| 40/40 [00:00<00:00, 93.28it/s]

Loss: 0.21598959672958293, Accuracy: 92.05333333333333%, test_acc: 0.9031



Epoch 9: 100%|██████████| 235/235 [00:02<00:00, 93.76it/s, Learning_rate=0.001, accuracy=92.3, loss=0.207]
100%|██████████| 40/40 [00:00<00:00, 94.46it/s]

Loss: 0.2071053776335209, Accuracy: 92.28%, test_acc: 0.9061



Epoch 10: 100%|██████████| 235/235 [00:02<00:00, 91.68it/s, Learning_rate=0.001, accuracy=92.7, loss=0.197]
100%|██████████| 40/40 [00:00<00:00, 90.18it/s]

Loss: 0.1971926487189658, Accuracy: 92.67166666666667%, test_acc: 0.9074



Epoch 11: 100%|██████████| 235/235 [00:02<00:00, 92.40it/s, Learning_rate=0.001, accuracy=93.1, loss=0.187]
100%|██████████| 40/40 [00:00<00:00, 87.45it/s]

Loss: 0.18748561755773868, Accuracy: 93.095%, test_acc: 0.9116



Epoch 12: 100%|██████████| 235/235 [00:02<00:00, 85.60it/s, Learning_rate=0.001, accuracy=93.4, loss=0.178]
100%|██████████| 40/40 [00:00<00:00, 94.99it/s]

Loss: 0.17839913025815435, Accuracy: 93.42666666666666%, test_acc: 0.902



Epoch 13: 100%|██████████| 235/235 [00:02<00:00, 91.92it/s, Learning_rate=0.001, accuracy=93.7, loss=0.17] 
100%|██████████| 40/40 [00:00<00:00, 93.94it/s]

Loss: 0.16978321281519343, Accuracy: 93.70166666666667%, test_acc: 0.9074



Epoch 14: 100%|██████████| 235/235 [00:02<00:00, 93.32it/s, Learning_rate=0.001, accuracy=94, loss=0.162]  
100%|██████████| 40/40 [00:00<00:00, 84.06it/s] 

Loss: 0.16188947393539105, Accuracy: 93.985%, test_acc: 0.9082



Epoch 15: 100%|██████████| 235/235 [00:02<00:00, 95.49it/s, Learning_rate=0.001, accuracy=94.2, loss=0.154]
100%|██████████| 40/40 [00:00<00:00, 90.97it/s]

Loss: 0.15381994073061234, Accuracy: 94.25%, test_acc: 0.9134



Epoch 16: 100%|██████████| 235/235 [00:02<00:00, 94.86it/s, Learning_rate=0.001, accuracy=94.4, loss=0.152]
100%|██████████| 40/40 [00:00<00:00, 93.41it/s]

Loss: 0.15163111163580673, Accuracy: 94.39%, test_acc: 0.9101



Epoch 17: 100%|██████████| 235/235 [00:02<00:00, 95.68it/s, Learning_rate=0.001, accuracy=94.6, loss=0.143]
100%|██████████| 40/40 [00:00<00:00, 93.97it/s]

Loss: 0.14292468247895546, Accuracy: 94.6%, test_acc: 0.9127



Epoch 18: 100%|██████████| 235/235 [00:02<00:00, 94.88it/s, Learning_rate=0.001, accuracy=95, loss=0.136]  
100%|██████████| 40/40 [00:00<00:00, 102.44it/s]

Loss: 0.13594677419738566, Accuracy: 95.00166666666667%, test_acc: 0.9087



Epoch 19: 100%|██████████| 235/235 [00:02<00:00, 91.57it/s, Learning_rate=0.001, accuracy=95, loss=0.134]  
100%|██████████| 40/40 [00:00<00:00, 93.41it/s]

Loss: 0.13424210963731117, Accuracy: 94.95%, test_acc: 0.9149



