In [1]:
from models.AlexNet import AlexNetMNIST, AlexNetMNISTee1, AlexNetMNISTee2, AlexNetWithExistsMNIST
from models.Branchynet import Branchynet

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import datasets, transforms
import torchvision
import matplotlib

import os
import numpy as np
from datetime import datetime as dt

In [2]:
# Build the network with its default constructor arguments.
# NOTE(review): the training cell iterates over the forward output (one entry
# per exit) and reads `model.exit_loss_weights` — confirm both against the
# class definition in models/AlexNet.py.
model = AlexNetWithExistsMNIST()

In [None]:
# Print the model's string representation as a quick look at what was built.
print(model)

In [3]:
# Mini-batch size shared by both loaders; the training cell also reads it for
# its progress print.
batch_size = 600

# Per-sample preprocessing handed to both dataset splits.
transform = transforms.ToTensor()

# Train split first, then test split; both are stored under the same root and
# requested with download=True.
train_data, test_data = (
    datasets.FashionMNIST(root='../data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in ((train_data, True), (test_data, False))
)

In [4]:
# Train the multi-exit model and report the test accuracy of every exit after
# each epoch. Relies on `model`, `train_loader`, `test_loader` and `batch_size`
# created by the earlier cells.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

import time
start_time = time.time()  # wall-clock start of training; not read inside this cell

epochs = 5
log_every = 1  # print the losses every `log_every` batches (1 = every batch)
train_size = len(train_loader.dataset)  # denominator of the progress print

for i in range(epochs):
    # Actually switch to training mode. The previous bare `model.train` only
    # referenced the method without calling it, and the mode must be restored
    # here because the evaluation pass below switches to eval mode.
    model.train()

    for b, (xb, yb) in enumerate(train_loader, start=1):
        # One output per exit; each exit's loss is scaled by its weight.
        results = model(xb)
        losses = [weighting * criterion(res, yb)
                  for weighting, res in zip(model.exit_loss_weights, results)]

        # Gradients are linear in the loss, so a single backward pass over the
        # summed loss accumulates the same gradients as one backward call per
        # exit, without having to retain the graph between calls.
        optimizer.zero_grad()
        total_loss = sum(losses)
        total_loss.backward()
        optimizer.step()

        if b % log_every == 0:
            print_losses = [f'{exit_loss.item():10.8f}' for exit_loss in losses]
            seen = min(batch_size * b, train_size)  # a short final batch must not overshoot
            print(f'epoch: {i:2}  batch: {b:4} [{seen:6}/{train_size}] losses: {print_losses}')

    # Evaluate in eval mode so mode-dependent layers (dropout, batch norm, if
    # the model has any) behave deterministically.
    # NOTE(review): this assumes the model still returns one output per exit in
    # eval mode — confirm against the forward() implementation.
    model.eval()
    correct_per_exit = []  # grows to one plain-int counter per exit on the first batch
    total_tests = 0
    with torch.no_grad():
        for X_test, y_test in test_loader:
            y_val = model(X_test)

            for n, exit_out in enumerate(y_val):
                if n == len(correct_per_exit):
                    correct_per_exit.append(0)
                predicted = exit_out.argmax(dim=1)
                # .item() keeps the counters as Python ints so the percentages
                # below are computed exactly rather than in float32.
                correct_per_exit[n] += (predicted == y_test).sum().item()

            total_tests += len(y_test)

    accuracies = ' '.join(str(100 * correct / total_tests) for correct in correct_per_exit)
    print(f'Accuracy: {accuracies}')



epoch:  0  batch:    1 [   600/60000] losses: ['2.59961367', '1.26119518', '0.46754408']
epoch:  0  batch:    2 [  1200/60000] losses: ['2.23671889', '1.25556040', '1.34791768']
epoch:  0  batch:    3 [  1800/60000] losses: ['1.82373369', '1.15802646', '1.04345906']
epoch:  0  batch:    4 [  2400/60000] losses: ['1.20215595', '0.72629648', '0.70451087']
epoch:  0  batch:    5 [  3000/60000] losses: ['1.17375028', '0.56360120', '0.55334598']
epoch:  0  batch:    6 [  3600/60000] losses: ['1.02933168', '0.55307651', '0.46745703']
epoch:  0  batch:    7 [  4200/60000] losses: ['1.03842342', '0.54510200', '0.40105301']
epoch:  0  batch:    8 [  4800/60000] losses: ['0.95952982', '0.48803285', '0.39068148']
epoch:  0  batch:    9 [  5400/60000] losses: ['0.79567772', '0.43787125', '0.36351272']
epoch:  0  batch:   10 [  6000/60000] losses: ['0.77403730', '0.43133774', '0.33857253']
epoch:  0  batch:   11 [  6600/60000] losses: ['0.79140180', '0.43166468', '0.31460717']
epoch:  0  batch:   1

epoch:  0  batch:   94 [ 56400/60000] losses: ['0.30520639', '0.15357922', '0.09193730']
epoch:  0  batch:   95 [ 57000/60000] losses: ['0.36080912', '0.17826840', '0.10741962']
epoch:  0  batch:   96 [ 57600/60000] losses: ['0.35568300', '0.18522987', '0.10675906']
epoch:  0  batch:   97 [ 58200/60000] losses: ['0.35643065', '0.17041691', '0.09612516']
epoch:  0  batch:   98 [ 58800/60000] losses: ['0.39374104', '0.19748822', '0.11019895']
epoch:  0  batch:   99 [ 59400/60000] losses: ['0.32799962', '0.16513370', '0.08706550']
epoch:  0  batch:  100 [ 60000/60000] losses: ['0.34087148', '0.17266971', '0.08884279']
Accuracy: 87.01000213623047 86.69000244140625 81.30000305175781
epoch:  1  batch:    1 [   600/60000] losses: ['0.34987250', '0.18867518', '0.09644701']
epoch:  1  batch:    2 [  1200/60000] losses: ['0.39042497', '0.19867963', '0.10595768']
epoch:  1  batch:    3 [  1800/60000] losses: ['0.32744256', '0.16036533', '0.08453111']
epoch:  1  batch:    4 [  2400/60000] losses: 

epoch:  1  batch:   86 [ 51600/60000] losses: ['0.32138067', '0.14982344', '0.07479569']
epoch:  1  batch:   87 [ 52200/60000] losses: ['0.21972409', '0.10537992', '0.05459685']
epoch:  1  batch:   88 [ 52800/60000] losses: ['0.29449669', '0.14360201', '0.06717978']
epoch:  1  batch:   89 [ 53400/60000] losses: ['0.24091141', '0.10800074', '0.05743914']
epoch:  1  batch:   90 [ 54000/60000] losses: ['0.29865485', '0.14855437', '0.06821623']
epoch:  1  batch:   91 [ 54600/60000] losses: ['0.24656986', '0.12543873', '0.06088475']
epoch:  1  batch:   92 [ 55200/60000] losses: ['0.25038266', '0.12558013', '0.06152388']
epoch:  1  batch:   93 [ 55800/60000] losses: ['0.24352859', '0.11645695', '0.06344802']
epoch:  1  batch:   94 [ 56400/60000] losses: ['0.28710747', '0.14414020', '0.07153853']
epoch:  1  batch:   95 [ 57000/60000] losses: ['0.31194627', '0.15368934', '0.07020823']
epoch:  1  batch:   96 [ 57600/60000] losses: ['0.26719341', '0.13384138', '0.06593625']
epoch:  1  batch:   9

KeyboardInterrupt: 