In [1]:
import os
import time
from datetime import datetime as dt

import matplotlib
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import datasets, transforms

from models.AlexNet import AlexNetMNIST, AlexNetMNISTee1, AlexNetMNISTee2
from models.Branchynet import Branchynet

In [2]:
def count_parameters(model):
    """Print the size of every trainable parameter tensor, then the grand total.

    Args:
        model: Any object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Each trainable tensor's element count is printed on its own line,
    right-aligned to a width of 6, followed by a rule and the sum.
    Tensors with ``requires_grad=False`` are left out. Returns ``None``.
    """
    trainable_sizes = []
    for tensor in model.parameters():
        if tensor.requires_grad:
            trainable_sizes.append(tensor.numel())

    for size in trainable_sizes:
        print(format(size, '>6'))

    total = sum(trainable_sizes)
    print('______\n' + format(total, '>6'))

In [2]:
# The only preprocessing step is converting each image to a tensor.
transform = transforms.ToTensor()
batch_size = 600

# Training split: fetched into ../data if absent, reshuffled every epoch.
train_data = datasets.FashionMNIST(
    root='../data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

# Held-out split: same root and transform, iterated in a fixed order.
test_data = datasets.FashionMNIST(
    root='../data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [9]:
# Build the AlexNetMNISTee1 variant imported from models.AlexNet with its
# default constructor arguments, then print its layer structure for inspection.
model = AlexNetMNISTee1()
print(model)

AlexNetMNISTee1(
  (layer1): Sequential(
    (0): Conv2d(1, 96, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=1, padding=0, dilation=1, ceil_mode=False)
  )
  (layer2): Sequential(
    (0): Conv2d(96, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (layer3): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=4096, out_features=10, bias=True)
  )
)


In [None]:
# Shape probe: push a single training image through the three convolutional
# stages and print the tensor shape after each one, so the flattened size
# expected by `model.fc` can be checked by eye.
X_train, y_train = train_data[0]   # first (image, label) pair of the training split

# Follow whatever device the model currently lives on, so this cell still works
# if it is re-run after the model has been moved to the GPU.
probe_device = next(model.parameters()).device

# Run in eval mode and without autograd: this pass is for shapes only, and a
# one-image forward pass in training mode would otherwise update the BatchNorm
# running statistics before training has even started.
was_training = model.training
model.eval()
with torch.no_grad():
    x = X_train.view(1, 1, 28, 28).to(probe_device)
    print(x.shape)
    for stage in (model.layer1, model.layer2, model.layer3):
        x = stage(x)
        print(x.shape)
model.train(was_training)   # put the model back in the mode it was in

In [4]:
# Train `model` for a fixed number of epochs, evaluating on the test split
# after each epoch. Per-epoch results are collected in four lists:
#   train_losses / test_losses   - loss of the LAST batch of the epoch (0-d CPU tensors)
#   train_correct / test_correct - number of correct predictions (0-d CPU tensors)

# Use the GPU when one is available, otherwise stay on the CPU. The model is
# moved before the optimizer is built so the optimizer sees the final parameters.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

start_time = time.time()

epochs = 5
train_losses = []
test_losses = []
train_correct = []
test_correct = []

n_train = len(train_loader.dataset)   # denominator shown in the progress line

for i in range(epochs):
    # Accumulate on `device` so the counters can be added to per-batch sums
    # without a device mismatch.
    trn_corr = torch.zeros((), dtype=torch.long, device=device)
    tst_corr = torch.zeros((), dtype=torch.long, device=device)
    seen = 0   # training samples processed so far in this epoch

    # Run the training batches
    model.train()   # BatchNorm uses batch statistics and updates its running stats
    for b, (X_train, y_train) in enumerate(train_loader, start=1):
        # Tensor.to()/.cuda() return a new tensor; the result must be rebound.
        X_train = X_train.to(device)
        y_train = y_train.to(device)

        # Apply the model (inputs stay as image tensors, no flattening here)
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)

        # Tally the number of correct predictions
        predicted = y_pred.argmax(dim=1)
        batch_corr = (predicted == y_train).sum()
        trn_corr += batch_corr
        seen += y_train.size(0)   # exact count, also right for a short final batch

        # Update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print interim results
        if b % 10 == 0:
            print(f'epoch: {i:2}  batch: {b:4} [{seen:6}/{n_train}]  loss: {loss.item():10.8f}  '
                  f'accuracy: {trn_corr.item()*100/seen:7.3f}%')

    # Detach before storing so the list does not keep the autograd graph alive,
    # and move to the CPU so the values can be plotted directly.
    train_losses.append(loss.detach().cpu())
    train_correct.append(trn_corr.cpu())

    # Run the testing batches
    model.eval()   # BatchNorm switches to its running statistics and stops updating them
    with torch.no_grad():
        for b, (X_test, y_test) in enumerate(test_loader):
            X_test = X_test.to(device)
            y_test = y_test.to(device)

            # Apply the model
            y_val = model(X_test)

            # Tally the number of correct predictions
            predicted = y_val.argmax(dim=1)
            tst_corr += (predicted == y_test).sum()

        # Loss of the final test batch only (mirrors how train_losses is recorded).
        loss = criterion(y_val, y_test)

    test_losses.append(loss.cpu())
    test_correct.append(tst_corr.cpu())

print(f'\nDuration: {time.time() - start_time:.0f} seconds') # print the time elapsed

epoch:  0  batch:   10 [  6000/60000]  loss: 0.72788686  accuracy:  57.433%
epoch:  0  batch:   20 [ 12000/60000]  loss: 0.62489033  accuracy:  66.708%
epoch:  0  batch:   30 [ 18000/60000]  loss: 0.52509028  accuracy:  70.656%
epoch:  0  batch:   40 [ 24000/60000]  loss: 0.46486497  accuracy:  73.542%
epoch:  0  batch:   50 [ 30000/60000]  loss: 0.44731867  accuracy:  75.487%
epoch:  0  batch:   60 [ 36000/60000]  loss: 0.43336985  accuracy:  77.000%
epoch:  0  batch:   70 [ 42000/60000]  loss: 0.38781312  accuracy:  78.088%
epoch:  0  batch:   80 [ 48000/60000]  loss: 0.38396457  accuracy:  79.062%
epoch:  0  batch:   90 [ 54000/60000]  loss: 0.34368518  accuracy:  79.961%
epoch:  0  batch:  100 [ 60000/60000]  loss: 0.33093280  accuracy:  80.630%
epoch:  1  batch:   10 [  6000/60000]  loss: 0.35612792  accuracy:  87.383%
epoch:  1  batch:   20 [ 12000/60000]  loss: 0.32843098  accuracy:  87.758%
epoch:  1  batch:   30 [ 18000/60000]  loss: 0.27989411  accuracy:  87.833%
epoch:  1  b