# Machine Intelligence with Deep Learning
## Importance batching for improved training of neural networks
---

In [1]:
import timeit

In [2]:
from models.resnet import ResNet18
from utils.data_utils import DataLoader
from utils.logging_utils import *

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import os
import pandas as pd

from datetime import datetime
today = datetime.today().strftime('%Y%m%d')

In [3]:
SEEDS = [10, 42, 4] # don't change!
STRATEGIES = ['freeze', 'shuffle', 'homogeneous', 'heterogeneous', 'max_k_loss', 'min_k_loss'] # can be changed

In [4]:
### Training
def train(epoch, optimizer, criterion_fn, seed, dataloader, strategy, device):
    criterion = criterion_fn()
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches(strategy, \
                                                  random_state=seed, use_train=True, \
                                                  criterion=criterion_fn(reduction='none'),\
                                                  device=device)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_acc = 100.*correct/total
    train_loss /= total
    return train_acc, train_loss

### Testing
def test(epoch, best_acc, criterion_fn, seed, dataloader):
    criterion = criterion_fn()
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches('shuffle', \
                                                                               random_state=seed, use_train=False)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    test_acc = 100.*correct/total
    test_loss /= total
    if test_acc > best_acc:
        best_acc = test_acc
        net.save(best_acc, epoch, seed, strategy)
    return test_acc, test_loss

In [5]:
resume = False
given_date = '20191113' #only needed if resumed from checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
### task: classification of the following classes
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    
### hyperparameters
test_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
num_epochs = 100 #50 # number of iterations the model gets trained
#learning_rates = { # learning rate is reset after specific epochs
#    '1': 0.1, # 50 epochs
#    '50': 0.01, # 40 epochs
#    '90': 0.005, # 30 epochs
#    '120': 0.001 # 30 epochs
#}
learning_rate = 0.01
momentum = 0.9
weight_decay = 5e-4

print("Begin training.")
start = timeit.default_timer()
#Logging header
length_table = 90
log_separating_line(length_table)
log_header_line("Seeds: {}".format(SEEDS), length_table)
log_header_line("Strategies: {}".format(STRATEGIES), length_table)
log_header_line("-> Resulting number of iterations: {}".format(len(SEEDS) * len(STRATEGIES)), length_table)
log_header_line("Number of iterations: {}".format(num_epochs), length_table)
log_header_line("Learning rate: {}".format(learning_rate), length_table)
log_header_line("Resuming from checkpoint: {}".format(True if resume else False), length_table)
log_separating_line(length_table)

rows = []
for seed in SEEDS:
    for strategy in STRATEGIES:
                
        np.random.seed(seed)
        torch.manual_seed(seed)
        if device == 'cuda':
            torch.cuda.manual_seed_all(seed)
             
        ### Model
        net = ResNet18()
        net = net.to(device)
        criterion = nn.CrossEntropyLoss # no function!
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
        
        ### load the data
        # if needed, specify batch sizes and shuffle settings
        dataloader = DataLoader()
        dataloader.download_cifar()
        dataloader.set_model(net)
        if strategy in ['max_k_loss', 'min_k_loss']:
            dataloader.initialize_weights(criterion(reduction='none'), device, seed=seed)
        print()
        log_separating_line(length_table)
        log_position_header(seed, strategy, length_table)
        log_separating_line(length_table)

        if resume:
            assert os.path.isdir('serialized'), 'Error: no serialized directory found!'
            ckpt = torch.load('./serialized/{}/{}_ckpt_{}.pth'.format(given_date, strategy, seed))
            test_acc, start_epoch, net = net.load(ckpt)

        for epoch in range(start_epoch, start_epoch+num_epochs):
            # reset learning rate at specific epochs
            #if str(epoch+1) in learning_rates.keys():
            #    for param_group in optimizer.param_groups:
            #        param_group['lr'] = learning_rates[str(epoch+1)]
                    
            train_acc, train_loss = train(epoch, optimizer, criterion, seed, dataloader, \
                                          strategy, device)
            test_acc, test_loss = test(epoch, test_acc, criterion, seed, dataloader)
            log_position_line(epoch + 1, num_epochs, train_acc, test_acc, train_loss, test_loss, length_table)
            row = {
                'epoch': epoch + 1,
                'seed': seed,
                'train': True,
                'strategy': strategy,
                'accuracy': train_acc,
                'loss': train_loss
            }
            rows.append(row)
            row = {
                'epoch': epoch + 1,
                'seed': seed,
                'train': False,
                'strategy': strategy,
                'accuracy': test_acc,
                'loss': test_loss
            }
            rows.append(row)
        log_separating_line(length_table)
            
stop = timeit.default_timer()
time_needed = stop - start
hrs = int(time_needed / 3600)
mins = int((time_needed / 60) % 60)
secs = int(time_needed % 60)
print()
print("Finished training. Time needed: {} hrs {} mins {} secs".format(hrs, mins, secs))

logging_df = pd.DataFrame(rows, columns=['epoch', 'seed', 'train', 'strategy', 'accuracy', 'loss'])   
training_logs_dir = 'evaluation_logs'
logging_df.to_csv('{}.txt'.format(os.path.join(training_logs_dir, today)), sep='\t', index=False)

Begin training.
+----------------------------------------------------------------------------------------+
| Seeds: [10, 42, 4]                                                                     |
| Strategies: ['freeze', 'shuffle', 'homogeneous', 'heterogeneous', 'max_k_loss', 'min_k_loss']|
| -> Resulting number of iterations: 18                                                  |
| Number of iterations: 100                                                              |
| Learning rate: 0.01                                                                    |
| Resuming from checkpoint: False                                                        |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 10      Strategy: freeze                                                  

| [078/100]:    85.36           074.2           0.00673      0.01309                     |
| [079/100]:    87.98           76.46           0.00546      0.01242                     |
| [080/100]:    90.16           75.41           0.00443      0.01421                     |
| [081/100]:    92.36           76.94           00.0035      0.01387                     |
| [082/100]:    93.98           77.04           0.00274      0.01432                     |
| [083/100]:    95.13           78.73           0.00223      00.0141                     |
| [084/100]:    95.98           79.19           0.00185      0.01354                     |
| [085/100]:    96.78           79.44           0.00149      0.01319                     |
| [086/100]:    97.34           79.29           0.00125      0.01424                     |
| [087/100]:    97.74           78.12           00.0011      00.0158                     |
| [088/100]:    97.82           80.21           0.00101      0.01311                     |

| [063/100]:    98.67           081.1           0.00066      00.0118                     |
| [064/100]:    99.13           81.84           0.00048      0.01151                     |
| [065/100]:    99.44           083.1           0.00031      0.01091                     |
| [066/100]:    99.39           81.93           0.00034      0.01216                     |
| [067/100]:    98.09           79.36           00.0009      0.01285                     |
| [068/100]:    98.66           81.51           0.00067      0.01179                     |
| [069/100]:    98.48           82.13           0.00076      0.01149                     |
| [070/100]:    98.62           81.33           0.00069      0.01191                     |
| [071/100]:    98.67           82.71           0.00065      0.01077                     |
| [072/100]:    98.74           82.54           0.00065      0.01102                     |
| [073/100]:    98.71           82.41           0.00064      0.01123                     |

| [048/100]:    62.36           010.0           0.02329      0.08256                     |
| [049/100]:    055.6           010.0           0.02547      0.06024                     |
| [050/100]:    62.32           010.0           0.02302      0.05943                     |
| [051/100]:    52.02           010.0           00.0267      0.05819                     |
| [052/100]:    51.47           010.0           0.02705      0.05922                     |
| [053/100]:    050.7           010.0           0.02757      00.0595                     |
| [054/100]:    50.83           010.0           0.02761      0.05982                     |
| [055/100]:    50.83           010.0           0.02771      0.05982                     |
| [056/100]:    68.82           010.0           0.02035      0.08283                     |
| [057/100]:    060.7           010.0           0.02356      0.06678                     |
| [058/100]:    52.36           010.0           0.02657      0.06014                     |

| [033/100]:    100.0           85.69           002e-05      0.00789                     |
| [034/100]:    100.0           85.78           002e-05      0.00776                     |
| [035/100]:    100.0           85.85           002e-05      00.0077                     |
| [036/100]:    100.0           85.93           002e-05      0.00763                     |
| [037/100]:    100.0           85.98           002e-05      0.00751                     |
| [038/100]:    100.0           85.83           002e-05      0.00756                     |
| [039/100]:    100.0           85.86           002e-05      00.0076                     |
| [040/100]:    100.0           085.9           002e-05      0.00757                     |
| [041/100]:    100.0           86.07           002e-05      0.00764                     |
| [042/100]:    100.0           85.87           002e-05      0.00764                     |
| [043/100]:    100.0           85.89           002e-05      0.00766                     |

| [018/100]:    35.94           10.32           00.0462      03.3951                     |
| [019/100]:    46.88           010.1           0.01847      02.0983                     |
| [020/100]:    57.81           09.88           0.02768      1.28252                     |
| [021/100]:    45.31           12.41           0.03586      0.55184                     |
| [022/100]:    59.38           15.39           0.02533      0.37777                     |
| [023/100]:    45.31           15.28           0.02381      0.41573                     |
| [024/100]:    53.12           15.15           0.02904      0.37034                     |
| [025/100]:    46.88           15.38           0.03636      0.22257                     |
| [026/100]:    46.88           14.13           0.03053      0.13585                     |
| [027/100]:    48.44           13.29           0.03718      0.15711                     |
| [028/100]:    51.56           11.54           00.0242      0.16794                     |

| [003/100]:    100.0           010.0           007e-05      0.42602                     |
| [004/100]:    100.0           010.0           00000.0      0.72485                     |
| [005/100]:    100.0           010.0           00000.0      1.06989                     |
| [006/100]:    100.0           010.0           00000.0      1.40094                     |
| [007/100]:    100.0           010.0           00000.0      1.67772                     |
| [008/100]:    100.0           010.0           00000.0      1.88335                     |
| [009/100]:    100.0           010.0           00000.0      02.0269                     |
| [010/100]:    100.0           010.0           00000.0      02.1051                     |
| [011/100]:    100.0           010.0           00000.0      2.13838                     |
| [012/100]:    100.0           010.0           00000.0      2.14185                     |
| [013/100]:    100.0           010.0           00000.0      2.11954                     |

| [094/100]:    100.0           010.0           00000.0      1.34643                     |
| [095/100]:    100.0           010.0           00000.0      01.3484                     |
| [096/100]:    100.0           010.0           00000.0      1.34674                     |
| [097/100]:    100.0           010.0           00000.0      1.34866                     |
| [098/100]:    100.0           010.0           00000.0      01.3478                     |
| [099/100]:    100.0           010.0           00000.0      1.34831                     |
| [100/100]:    100.0           010.0           00000.0      1.34545                     |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 42      Strategy: freeze                                                         |
| Epoch      

| [079/100]:    96.96           80.24           0.00142      0.01346                     |
| [080/100]:    97.74           81.15           0.00107      0.01304                     |
| [081/100]:    98.12           81.78           00.0009      0.01243                     |
| [082/100]:    97.99           80.63           0.00094      0.01365                     |
| [083/100]:    98.21           79.86           0.00084      00.0148                     |
| [084/100]:    98.14           81.91           0.00086      0.01227                     |
| [085/100]:    98.66           079.6           0.00067      0.01569                     |
| [086/100]:    98.39           79.39           0.00078      0.01488                     |
| [087/100]:    98.54           81.37           0.00071      0.01355                     |
| [088/100]:    98.74           82.19           0.00061      0.01218                     |
| [089/100]:    98.83           80.73           0.00061      0.01427                     |

| [064/100]:    98.16           81.86           0.00094      00.0116                     |
| [065/100]:    98.44           81.92           0.00075      0.01106                     |
| [066/100]:    98.98           83.95           0.00053      0.00986                     |
| [067/100]:    99.54           84.22           0.00029      0000.01                     |
| [068/100]:    99.36           82.03           0.00035      0.01103                     |
| [069/100]:    099.1           83.02           0.00049      0.01057                     |
| [070/100]:    98.75           81.91           0.00062      0.01132                     |
| [071/100]:    97.84           80.87           0.00105      0.01196                     |
| [072/100]:    98.26           82.21           0.00086      0.01099                     |
| [073/100]:    98.97           83.38           0.00054      0.01027                     |
| [074/100]:    98.99           83.67           0.00053      0.01048                     |

| [049/100]:    51.34           010.0           0.02718      0.05747                     |
| [050/100]:    51.09           010.0           0.02747      0.05793                     |
| [051/100]:    50.58           010.0           0.02764      0.05927                     |
| [052/100]:    50.83           010.0           00.0277      0.05974                     |
| [053/100]:    50.45           010.0           0.02798      0.05961                     |
| [054/100]:    50.45           010.0           0.02782      0.05929                     |
| [055/100]:    050.7           010.0           0.02748      0.05747                     |
| [056/100]:    51.22           010.0           0.02724      0.05762                     |
| [057/100]:    50.83           010.0           00.0276      0.05902                     |
| [058/100]:    50.58           010.0           0.02794      0.05835                     |
| [059/100]:    050.7           010.0           0.02772      0.05863                     |

| [034/100]:    100.0           86.66           002e-05      0.00723                     |
| [035/100]:    100.0           86.58           002e-05      0.00718                     |
| [036/100]:    100.0           86.74           002e-05      0.00718                     |
| [037/100]:    100.0           86.66           002e-05      0.00713                     |
| [038/100]:    100.0           86.73           002e-05      0.00709                     |
| [039/100]:    100.0           86.65           002e-05      0.00707                     |
| [040/100]:    100.0           86.78           002e-05      0.00715                     |
| [041/100]:    100.0           86.75           002e-05      0.00717                     |
| [042/100]:    100.0           86.87           002e-05      0.00721                     |
| [043/100]:    100.0           086.7           002e-05      0.00732                     |
| [044/100]:    100.0           86.78           002e-05      0.00727                     |

| [019/100]:    73.44           13.17           0.01418      0.53461                     |
| [020/100]:    79.69           12.59           0.01431      0.55355                     |
| [021/100]:    81.25           12.08           0.01195      00.5517                     |
| [022/100]:    81.25           12.04           0.00867      0.51776                     |
| [023/100]:    76.56           012.0           0.00867      0.47877                     |
| [024/100]:    71.88           12.06           0.01094      0.46235                     |
| [025/100]:    76.56           12.02           0.01167      0.46921                     |
| [026/100]:    73.44           11.99           0.01216      0.46143                     |
| [027/100]:    73.44           11.99           0.01659      0.48829                     |
| [028/100]:    76.56           12.13           0.01098      0.53509                     |
| [029/100]:    78.12           12.16           0.00974      0.55659                     |

| [004/100]:    100.0           010.0           00000.0      0.30168                     |
| [005/100]:    100.0           010.0           00000.0      0.39685                     |
| [006/100]:    100.0           010.0           00000.0      0.48465                     |
| [007/100]:    100.0           010.0           00000.0      0.56116                     |
| [008/100]:    100.0           010.0           00000.0      0.62627                     |
| [009/100]:    100.0           010.0           00000.0      0.68162                     |
| [010/100]:    100.0           010.0           00000.0      0.72457                     |
| [011/100]:    100.0           010.0           00000.0      0.75806                     |
| [012/100]:    100.0           010.0           00000.0      0.78479                     |
| [013/100]:    100.0           010.0           00000.0      0.80328                     |
| [014/100]:    100.0           010.0           00000.0      0.81712                     |

| [095/100]:    100.0           010.0           00000.0      0.76447                     |
| [096/100]:    100.0           010.0           00000.0      0.76393                     |
| [097/100]:    100.0           010.0           00000.0      0.76405                     |
| [098/100]:    100.0           010.0           00000.0      0.76262                     |
| [099/100]:    100.0           010.0           00000.0      0.76385                     |
| [100/100]:    100.0           010.0           00000.0      0.76363                     |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 04      Strategy: freeze                                                         |
| Epoch         Train Accuracy  Test Accuracy   Train Loss   Test Loss                   |
+------------

| [080/100]:    87.21           75.37           0.00581      00.0134                     |
| [081/100]:    89.43           73.72           0.00476      0.01457                     |
| [082/100]:    91.53           76.68           0.00381      00.0137                     |
| [083/100]:    93.38           75.89           0.00302      0.01559                     |
| [084/100]:    094.5           78.31           0.00246      0.01368                     |
| [085/100]:    95.59           78.69           0.00203      00.0143                     |
| [086/100]:    96.27           78.62           0.00169      0.01429                     |
| [087/100]:    96.96           78.76           00.0014      0.01427                     |
| [088/100]:    97.37           79.54           0.00122      0.01394                     |
| [089/100]:    97.76           80.07           0.00107      0.01345                     |
| [090/100]:    98.02           80.26           0.00096      0.01358                     |

| [065/100]:    98.83           80.55           0.00061      0.01262                     |
| [066/100]:    098.7           080.5           0.00063      0.01276                     |
| [067/100]:    98.73           81.71           0.00065      00.0118                     |
| [068/100]:    99.01           82.12           0.00051      0.01185                     |
| [069/100]:    99.03           81.25           00.0005      0.01204                     |
| [070/100]:    99.07           81.64           0.00049      0.01159                     |
| [071/100]:    99.02           79.75           0.00049      00.0127                     |
| [072/100]:    98.77           81.86           0.00061      0.01156                     |
| [073/100]:    98.44           79.48           0.00077      0.01367                     |
| [074/100]:    97.93           80.21           0.00099      0.01245                     |
| [075/100]:    98.89           83.39           0.00056      0.01058                     |

| [050/100]:    51.09           010.0           0.02713      0.05868                     |
| [051/100]:    54.93           010.0           0.02596      0.05772                     |
| [052/100]:    051.6           010.0           0.02686      0.05763                     |
| [053/100]:    55.44           010.0           0.02519      0.05797                     |
| [054/100]:    50.83           010.0           0.02766      0.05877                     |
| [055/100]:    51.73           010.0           0.02696      0.05873                     |
| [056/100]:    50.83           010.0           0.02775      0.05868                     |
| [057/100]:    51.47           010.0           0.02704      0.05817                     |
| [058/100]:    51.34           010.0           0.02709      0.05914                     |
| [059/100]:    050.7           010.0           0.02758      0.05935                     |
| [060/100]:    050.7           010.0           0.02762      0.05972                     |

| [035/100]:    99.07           81.71           0.00045      0.01276                     |
| [036/100]:    98.98           80.58           00.0005      0.01367                     |
| [037/100]:    98.83           81.93           0.00057      0.01253                     |
| [038/100]:    99.08           81.94           0.00046      00.0121                     |
| [039/100]:    98.77           81.23           0.00061      0.01224                     |
| [040/100]:    98.61           81.59           0.00063      0.01211                     |
| [041/100]:    98.73           80.91           0.00062      0.01287                     |
| [042/100]:    98.47           81.59           0.00074      00.0121                     |
| [043/100]:    099.0           82.46           00.0005      0.01141                     |
| [044/100]:    99.41           082.3           0.00032      0.01177                     |
| [045/100]:    99.71           83.97           0.00018      0.01054                     |

| [020/100]:    51.56           15.19           0.02592      0.14794                     |
| [021/100]:    57.81           17.17           0.02315      0.14417                     |
| [022/100]:    65.62           18.32           0.01999      0.15655                     |
| [023/100]:    53.12           16.72           0.02303      0.17706                     |
| [024/100]:    43.75           16.06           0.02229      0.16911                     |
| [025/100]:    59.38           17.07           0.01617      0.15697                     |
| [026/100]:    062.5           16.83           00.0165      0.15302                     |
| [027/100]:    65.62           016.7           0.01471      0.15825                     |
| [028/100]:    54.69           16.56           0.02037      0.17035                     |
| [029/100]:    64.06           16.31           0.01376      0.16339                     |
| [030/100]:    68.75           16.02           0.01981      00.1689                     |

| [005/100]:    100.0           010.0           00000.0      4.13044                     |
| [006/100]:    100.0           010.0           00000.0      6.26697                     |
| [007/100]:    100.0           010.0           00000.0      8.68021                     |
| [008/100]:    100.0           010.0           00000.0      11.26041                     |
| [009/100]:    100.0           010.0           00000.0      13.97656                     |
| [010/100]:    100.0           010.0           00000.0      16.66061                     |
| [011/100]:    100.0           010.0           00000.0      19.22197                     |
| [012/100]:    100.0           010.0           00000.0      21.5806                     |
| [013/100]:    100.0           010.0           00000.0      23.68533                     |
| [014/100]:    100.0           010.0           00000.0      25.58064                     |
| [015/100]:    100.0           010.0           00000.0      27.24504               

| [095/100]:    100.0           010.0           00000.0      31.70657                     |
| [096/100]:    100.0           010.0           00000.0      31.67622                     |
| [097/100]:    100.0           010.0           00000.0      31.74716                     |
| [098/100]:    100.0           010.0           00000.0      31.72537                     |
| [099/100]:    100.0           010.0           00000.0      31.70908                     |
| [100/100]:    100.0           010.0           00000.0      31.68316                     |
+----------------------------------------------------------------------------------------+

Finished training. Time needed: 15 hrs 39 mins 32 secs


# Notes: