# Machine Intelligence with Deep Learning
## Importance batching for improved training of neural networks
---

In [1]:
import timeit

In [2]:
from models.resnet import ResNet18
from utils.data_utils import DataLoader
from utils.logging_utils import *

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import os
import pandas as pd

from datetime import datetime
today = datetime.today().strftime('%Y%m%d')

In [3]:
SEEDS = [10, 42, 4] # don't change!
STRATEGIES = ['freeze', 'shuffle', 'homogeneous', 'heterogeneous'] # can be changed

In [4]:
### Training
def train(epoch, optimizer, criterion, seed, dataloader, strategy):
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches(strategy, \
                                                                           random_state=seed, use_train=True)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_acc = 100.*correct/total
    train_loss /= total
    return train_acc, train_loss

### Testing
def test(epoch, best_acc, seed, dataloader, strategy):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches(strategy, \
                                                                               random_state=seed, use_train=False)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    test_acc = 100.*correct/total
    test_loss /= total
    if test_acc > best_acc:
        best_acc = test_acc
        net.save(best_acc, epoch, seed, strategy)
    return test_acc, test_loss

In [5]:
resume = False
given_date = '20191113'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
### task: classification of the following classes
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    
### hyperparameters
test_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
num_epochs = 150 # number of iterations the model gets trained
learning_rates = { # learning rate is reset after specific epochs
    '1': 0.1, # 50 epochs
    '50': 0.01, # 40 epochs
    '90': 0.005, # 30 epochs
    '120': 0.001 # 30 epochs
}
momentum = 0.9
weight_decay = 5e-4

print("Begin training.")
start = timeit.default_timer()
#Logging header
length_table = 90
log_separating_line(length_table)
log_header_line("Seeds: {}".format(SEEDS), length_table)
log_header_line("Strategies: {}".format(STRATEGIES), length_table)
log_header_line("-> Resulting number of iterations: {}".format(len(SEEDS) * len(STRATEGIES)), length_table)
log_header_line("Number of iterations: {}".format(num_epochs), length_table)
log_header_line("Learning rates: {}".format(learning_rates), length_table)
log_header_line("Resuming from checkpoint: {}".format(True if resume else False), length_table)
log_separating_line(length_table)

rows = []
for seed in SEEDS:
    for strategy in STRATEGIES:
                
        np.random.seed(seed)
        torch.manual_seed(seed)
        if device == 'cuda':
            torch.cuda.manual_seed_all(seed)
             
        ### Model
        net = ResNet18()
        net = net.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=learning_rates['1'], momentum=momentum, weight_decay=weight_decay)
        
        ### load the data
        # if needed, specify batch sizes and shuffle settings
        dataloader = DataLoader()
        dataloader.download_cifar()
        print()
        log_separating_line(length_table)
        log_position_header(seed, strategy, length_table)
        log_separating_line(length_table)

        if resume:
            assert os.path.isdir('serialized'), 'Error: no serialized directory found!'
            ckpt = torch.load('./serialized/{}/{}_ckpt_{}.pth'.format(given_date, strategy, seed))
            test_acc, start_epoch, net = net.load(ckpt)

        for epoch in range(start_epoch, start_epoch+num_epochs):
            # reset learning rate at specific epochs
            if str(epoch+1) in learning_rates.keys():
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rates[str(epoch+1)]
                    
            train_acc, train_loss = train(epoch, optimizer, criterion, seed, dataloader, strategy)
            test_acc, test_loss = test(epoch, test_acc, seed, dataloader, strategy)
            log_position_line(epoch + 1, num_epochs, train_acc, test_acc, train_loss, test_loss, length_table)
            row = {
                'epoch': epoch + 1,
                'seed': seed,
                'train': True,
                'strategy': strategy,
                'accuracy': train_acc,
                'loss': train_loss
            }
            rows.append(row)
            row = {
                'epoch': epoch + 1,
                'seed': seed,
                'train': False,
                'strategy': strategy,
                'accuracy': test_acc,
                'loss': test_loss
            }
            rows.append(row)
        log_separating_line(length_table)
            
stop = timeit.default_timer()
time_needed = stop - start
hrs = int(time_needed / 3600)
mins = int((time_needed / 60) % 60)
secs = int(time_needed % 60)
print()
print("Finished training. Time needed: {} hrs {} mins {} secs".format(hrs, mins, secs))

logging_df = pd.DataFrame(rows, columns=['epoch', 'seed', 'train', 'strategy', 'accuracy', 'loss'])   
training_logs_dir = 'evaluation_logs'
logging_df.to_csv('{}.txt'.format(os.path.join(training_logs_dir, today)), sep='\t', index=False)

Begin training.
+----------------------------------------------------------------------------------------+
| Seeds: [10, 42, 4]                                                                     |
| Strategies: ['freeze', 'shuffle', 'homogeneous', 'heterogeneous']                      |
| -> Resulting number of iterations: 12                                                  |
| Number of iterations: 150                                                              |
| Learning rates: {'1': 0.1, '50': 0.01, '90': 0.005, '120': 0.001}                      |
| Resuming from checkpoint: False                                                        |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 10      Strategy: freeze                                                        

| [079/150]:    100.0           87.29           003e-05      0.00709                     |
| [080/150]:    100.0           87.19           003e-05      0.00709                     |
| [081/150]:    100.0           87.13           003e-05      00.0071                     |
| [082/150]:    100.0           87.16           003e-05      0.00712                     |
| [083/150]:    100.0           87.15           003e-05      0.00714                     |
| [084/150]:    100.0           87.12           003e-05      0.00716                     |
| [085/150]:    100.0           87.07           003e-05      00.0072                     |
| [086/150]:    100.0           87.05           003e-05      0.00723                     |
| [087/150]:    100.0           87.06           003e-05      0.00727                     |
| [088/150]:    100.0           087.0           003e-05      00.0073                     |
| [089/150]:    100.0           87.02           003e-05      0.00735                     |

| [014/150]:    83.84           74.47           0.00736      0.01335                     |
| [015/150]:    84.19           76.79           00.0072      0.01148                     |
| [016/150]:    84.54           71.15           0.00705      00.0149                     |
| [017/150]:    84.83           77.86           0.00689      00.0106                     |
| [018/150]:    85.14           75.14           0.00685      0.01252                     |
| [019/150]:    85.26           79.57           0.00677      0.01044                     |
| [020/150]:    85.22           076.6           00.0067      0.01164                     |
| [021/150]:    085.6           077.0           0.00663      0.01158                     |
| [022/150]:    85.57           73.93           00.0066      0.01286                     |
| [023/150]:    85.57           69.74           0.00658      0.01862                     |
| [024/150]:    85.81           074.8           0.00648      0.01236                     |

| [105/150]:    100.0           89.22           003e-05      0.00604                     |
| [106/150]:    100.0           89.33           003e-05      0.00603                     |
| [107/150]:    100.0           89.35           003e-05      0.00602                     |
| [108/150]:    100.0           89.38           004e-05      0.00615                     |
| [109/150]:    100.0           89.45           003e-05      0.00612                     |
| [110/150]:    100.0           89.41           003e-05      0.00596                     |
| [111/150]:    100.0           89.54           003e-05      0.00597                     |
| [112/150]:    100.0           089.3           004e-05      0.00613                     |
| [113/150]:    100.0           89.26           003e-05      0.00603                     |
| [114/150]:    100.0           089.4           003e-05      0.00594                     |
| [115/150]:    100.0           89.47           003e-05      0.00597                     |

| [040/150]:    81.42           010.0           0.02616      0.20365                     |
| [041/150]:    81.68           010.0           0.02614      0.20318                     |
| [042/150]:    81.55           010.0           0.02608      0.20405                     |
| [043/150]:    081.3           010.0           0.02629      00.2002                     |
| [044/150]:    81.55           010.0           0.02592      0.20456                     |
| [045/150]:    81.55           010.0           0.02608      0.20564                     |
| [046/150]:    81.81           010.0           0.02611      0.20629                     |
| [047/150]:    81.55           010.0           0.02621      0.20458                     |
| [048/150]:    81.68           010.0           0.02608      0.20574                     |
| [049/150]:    81.68           010.0           0.02625      0.20307                     |
| [050/150]:    38.54           010.0           0.04577      0.05694                     |

| [131/150]:    11.02           010.0           0.03654      00.0362                     |
| [132/150]:    07.82           010.0           0.03655      0.03625                     |
| [133/150]:    08.46           010.0           00.0366      0.03629                     |
| [134/150]:    09.49           010.0           0.03659      0.03625                     |
| [135/150]:    07.06           010.0           0.03661      0.03625                     |
| [136/150]:    08.72           010.0           0.03655      0.03626                     |
| [137/150]:    08.59           010.0           0.03654      00.0362                     |
| [138/150]:    11.02           010.0           0.03653      0.03622                     |
| [139/150]:    08.34           010.0           0.03661      0.03631                     |
| [140/150]:    12.56           010.0           0.03657      0.03627                     |
| [141/150]:    10.13           010.0           0.03657      00.0363                     |

| [066/150]:    100.0           88.61           003e-05      0.00642                     |
| [067/150]:    100.0           88.65           003e-05      0.00642                     |
| [068/150]:    100.0           88.52           003e-05      0.00649                     |
| [069/150]:    100.0           88.79           003e-05      0.00641                     |
| [070/150]:    100.0           88.65           003e-05      0.00642                     |
| [071/150]:    100.0           88.66           003e-05      0.00636                     |
| [072/150]:    100.0           088.6           003e-05      0.00646                     |
| [073/150]:    100.0           88.56           003e-05      00.0064                     |
| [074/150]:    100.0           088.7           003e-05      0.00648                     |
| [075/150]:    100.0           88.54           003e-05      0.00651                     |
| [076/150]:    100.0           88.64           002e-05      0.00651                     |

| [001/150]:    28.21           39.24           0.03096      0.02556                     |
| [002/150]:    43.05           50.09           0.02403      0.02134                     |
| [003/150]:    54.81           055.3           0.01952      0.01952                     |
| [004/150]:    63.69           66.26           0.01597      0.01487                     |
| [005/150]:    69.18           65.76           0.01372      0.01588                     |
| [006/150]:    73.48           70.51           0.01189      0.01363                     |
| [007/150]:    076.8           67.87           0.01048      0.01672                     |
| [008/150]:    78.64           068.8           0.00967      0.01496                     |
| [009/150]:    80.38           70.73           0.00887      0.01466                     |
| [010/150]:    81.49           70.49           0.00839      0.01522                     |
| [011/150]:    82.26           70.07           0.00793      0.01563                     |

| [092/150]:    100.0           86.22           003e-05      0.00756                     |
| [093/150]:    100.0           86.25           003e-05      0.00757                     |
| [094/150]:    100.0           86.24           003e-05      0.00759                     |
| [095/150]:    100.0           86.24           003e-05      0.00761                     |
| [096/150]:    100.0           86.27           003e-05      0.00762                     |
| [097/150]:    100.0           86.22           003e-05      0.00764                     |
| [098/150]:    100.0           86.23           003e-05      0.00766                     |
| [099/150]:    100.0           86.18           003e-05      0.00767                     |
| [100/150]:    100.0           86.15           003e-05      0.00769                     |
| [101/150]:    100.0           86.13           003e-05      0.00771                     |
| [102/150]:    100.0           86.16           003e-05      0.00773                     |

| [027/150]:    086.0           70.22           0.00644      0.01676                     |
| [028/150]:    85.84           78.98           0.00647      0.01047                     |
| [029/150]:    86.01           74.78           0.00643      0.01269                     |
| [030/150]:    85.78           75.31           0.00654      0.01182                     |
| [031/150]:    85.82           75.61           0.00639      00.0126                     |
| [032/150]:    86.11           78.24           0.00634      0.01148                     |
| [033/150]:    85.86           64.47           0.00642      0.01993                     |
| [034/150]:    86.23           76.26           0.00635      0.01163                     |
| [035/150]:    86.45           80.54           0.00626      0.00927                     |
| [036/150]:    86.03           79.78           0.00639      0.00997                     |
| [037/150]:    86.17           077.7           0.00629      0.01077                     |

| [118/150]:    100.0           89.27           003e-05      0.00596                     |
| [119/150]:    100.0           89.34           003e-05      0.00596                     |
| [120/150]:    100.0           89.34           003e-05      0.00602                     |
| [121/150]:    100.0           89.52           003e-05      0.00597                     |
| [122/150]:    100.0           89.51           003e-05      0.00593                     |
| [123/150]:    100.0           089.5           003e-05      0.00591                     |
| [124/150]:    100.0           89.41           003e-05      0.00598                     |
| [125/150]:    100.0           89.67           003e-05      0.00594                     |
| [126/150]:    100.0           89.54           003e-05      0.00588                     |
| [127/150]:    100.0           89.45           003e-05      00.0059                     |
| [128/150]:    100.0           89.47           003e-05      0.00598                     |

| [053/150]:    52.24           010.0           0.02636      0.05821                     |
| [054/150]:    50.83           010.0           0.02768      0.05767                     |
| [055/150]:    050.7           010.0           00.0278      0.05847                     |
| [056/150]:    51.34           010.0           0.02713      0.05876                     |
| [057/150]:    50.58           010.0           0.02785      0.05848                     |
| [058/150]:    51.47           010.0           0.02698      0.06006                     |
| [059/150]:    51.34           010.0           00.0272      0.05982                     |
| [060/150]:    55.31           010.0           0.02522      0.05759                     |
| [061/150]:    050.7           010.0           0.02759      00.0591                     |
| [062/150]:    51.22           010.0           0.02732      0.05796                     |
| [063/150]:    50.83           010.0           0.02778      00.0583                     |

| [144/150]:    08.59           010.0           0.03663      0.03634                     |
| [145/150]:    11.79           010.0           0.03656      0.03624                     |
| [146/150]:    11.02           010.0           0.03655      0.03624                     |
| [147/150]:    10.26           010.0           00.0366      0.03631                     |
| [148/150]:    10.64           010.0           0.03663      00.0363                     |
| [149/150]:    08.85           010.0           0.03658      0.03625                     |
| [150/150]:    07.95           010.0           00.0366      0.03628                     |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 42      Strategy: heterogeneous                                                         |
| Epoc

| [079/150]:    097.5           85.55           0.00116      0.00802                     |
| [080/150]:    98.45           86.07           0.00074      0.00839                     |
| [081/150]:    98.79           85.68           0.00058      0.00878                     |
| [082/150]:    98.95           084.1           0.00051      0.00976                     |
| [083/150]:    099.1           85.41           0.00043      0.00946                     |
| [084/150]:    98.83           85.68           0.00056      0.00902                     |
| [085/150]:    98.92           85.66           0.00053      0.00878                     |
| [086/150]:    99.18           85.61           0.00041      0.00953                     |
| [087/150]:    99.12           85.17           0.00043      0.00935                     |
| [088/150]:    99.11           84.76           0.00043      0.00987                     |
| [089/150]:    99.07           84.73           0.00045      0.00987                     |

| [014/150]:    84.64           70.71           0.00694      0.01586                     |
| [015/150]:    85.04           72.78           00.0068      0.01437                     |
| [016/150]:    85.29           68.34           0.00665      0.01664                     |
| [017/150]:    85.08           73.83           0.00663      0.01326                     |
| [018/150]:    085.8           70.94           0.00647      0.01588                     |
| [019/150]:    85.78           71.69           0.00637      0.01504                     |
| [020/150]:    85.93           72.54           0.00636      0.01466                     |
| [021/150]:    86.06           71.75           0.00634      0.01507                     |
| [022/150]:    86.32           73.17           0.00619      00.0142                     |
| [023/150]:    86.28           69.22           0.00617      0.01631                     |
| [024/150]:    86.47           59.77           0.00613      00.0256                     |

| [105/150]:    100.0           85.58           003e-05      0.00836                     |
| [106/150]:    100.0           085.6           003e-05      0.00838                     |
| [107/150]:    100.0           85.55           003e-05      0.00841                     |
| [108/150]:    100.0           85.58           003e-05      0.00843                     |
| [109/150]:    100.0           85.56           003e-05      0.00846                     |
| [110/150]:    100.0           85.63           003e-05      0.00848                     |
| [111/150]:    100.0           85.58           002e-05      0.00851                     |
| [112/150]:    100.0           85.62           002e-05      0.00854                     |
| [113/150]:    100.0           85.66           002e-05      0.00857                     |
| [114/150]:    100.0           85.74           002e-05      00.0086                     |
| [115/150]:    100.0           85.63           002e-05      0.00864                     |

| [040/150]:    86.45           78.02           0.00623      0.01147                     |
| [041/150]:    86.53           77.13           0.00613      00.0109                     |
| [042/150]:    86.37           75.53           0.00625      0.01266                     |
| [043/150]:    86.59           70.78           0.00614      0.01552                     |
| [044/150]:    86.54           080.3           0.00614      0.00978                     |
| [045/150]:    86.45           76.79           0.00623      0.01177                     |
| [046/150]:    86.67           74.94           0.00609      0.01326                     |
| [047/150]:    86.64           77.79           0.00612      0.01094                     |
| [048/150]:    86.76           080.9           00.0061      0.00907                     |
| [049/150]:    86.61           73.86           0.00611      00.0135                     |
| [050/150]:    95.76           87.53           0.00209      0.00629                     |

| [131/150]:    100.0           89.35           003e-05      0.00599                     |
| [132/150]:    100.0           89.26           003e-05      0.00599                     |
| [133/150]:    100.0           89.25           003e-05      0.00599                     |
| [134/150]:    100.0           089.5           003e-05      0.00599                     |
| [135/150]:    100.0           89.32           003e-05      0.00606                     |
| [136/150]:    100.0           89.34           003e-05      000.006                     |
| [137/150]:    100.0           89.13           003e-05      0.00599                     |
| [138/150]:    100.0           89.41           003e-05      0.00599                     |
| [139/150]:    100.0           89.44           003e-05      0.00605                     |
| [140/150]:    100.0           89.36           003e-05      0.00605                     |
| [141/150]:    100.0           089.4           003e-05      0.00598                     |

| [066/150]:    50.83           010.0           0.02762      0.05967                     |
| [067/150]:    55.06           010.0           0.02557      0.05826                     |
| [068/150]:    55.44           010.0           0.02532      0.05918                     |
| [069/150]:    51.47           010.0           0.02697      0.05777                     |
| [070/150]:    50.58           010.0           0.02777      0.05831                     |
| [071/150]:    55.06           010.0           0.02573      0.05856                     |
| [072/150]:    52.24           010.0           0.02636      0.05821                     |
| [073/150]:    50.83           010.0           0.02768      0.05767                     |
| [074/150]:    050.7           010.0           00.0278      0.05847                     |
| [075/150]:    51.34           010.0           0.02713      0.05876                     |
| [076/150]:    50.58           010.0           0.02785      0.05848                     |

| [001/150]:    36.19           46.55           0.02771      0.02292                     |
| [002/150]:    52.63           55.51           0.02041      0.02055                     |
| [003/150]:    63.76           60.82           0.01597      0.01843                     |
| [004/150]:    72.25           70.25           0.01247      00.0136                     |
| [005/150]:    76.75           71.84           0.01052      0.01308                     |
| [006/150]:    79.02           074.7           0.00946      0.01194                     |
| [007/150]:    81.04           75.34           00.0087      0.01216                     |
| [008/150]:    82.23           78.04           0.00813      0.01038                     |
| [009/150]:    82.91           75.84           0.00777      0.01141                     |
| [010/150]:    83.68           074.7           0.00741      0.01199                     |
| [011/150]:    84.29           77.08           0.00709      00.0113                     |

| [092/150]:    100.0           88.52           002e-05      0.00705                     |
| [093/150]:    100.0           88.66           001e-05      0.00697                     |
| [094/150]:    100.0           88.64           002e-05      0.00691                     |
| [095/150]:    100.0           88.76           002e-05      0.00682                     |
| [096/150]:    100.0           88.73           002e-05      0.00668                     |
| [097/150]:    100.0           88.71           002e-05      0.00668                     |
| [098/150]:    100.0           88.77           002e-05      0.00657                     |
| [099/150]:    100.0           88.74           002e-05      0.00653                     |
| [100/150]:    100.0           88.79           002e-05      00.0065                     |
| [101/150]:    100.0           88.89           002e-05      0.00657                     |
| [102/150]:    100.0           88.94           002e-05      0.00645                     |