# Machine Intelligence with Deep Learning
## Importance batching for improved training of neural networks
---

In [1]:
import timeit

In [2]:
from models.resnet import ResNet18
from utils.data_utils import DataLoader
from utils.logging_utils import *

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import os
import pandas as pd

from datetime import datetime
today = datetime.today().strftime('%Y%m%d')

In [3]:
SEEDS = [10, 42, 4] # don't change!
#STRATEGIES = ['freeze', 'shuffle', 'homogeneous', 'heterogeneous'] # can be changed
STRATEGIES = ['max_k_loss', 'min_k_loss']

In [4]:
### Training
def train(epoch, optimizer, criterion_fn, seed, dataloader, strategy, device):
    criterion = criterion_fn()
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches(strategy, \
                                                  random_state=seed, use_train=True, \
                                                  criterion=criterion_fn(reduction='none'),\
                                                  device=device)):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    train_acc = 100.*correct/total
    train_loss /= total
    return train_acc, train_loss

### Testing
def test(epoch, best_acc, criterion_fn, seed, dataloader):
    criterion = criterion_fn()
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(dataloader.yield_batches('shuffle', \
                                                                               random_state=seed, use_train=False)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    test_acc = 100.*correct/total
    test_loss /= total
    if test_acc > best_acc:
        best_acc = test_acc
        net.save(best_acc, epoch, seed, strategy)
    return test_acc, test_loss

In [5]:
resume = False
given_date = '20191113' #only needed if resumed from checkpoint
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
### task: classification of the following classes
classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

    
### hyperparameters
test_acc = 0  # best test accuracy
start_epoch = 0  # start from epoch 0 or last checkpoint epoch
num_epochs = 150 # number of iterations the model gets trained
learning_rates = { # learning rate is reset after specific epochs
    '1': 0.1, # 50 epochs
    '50': 0.01, # 40 epochs
    '90': 0.005, # 30 epochs
    '120': 0.001 # 30 epochs
}
momentum = 0.9
weight_decay = 5e-4

print("Begin training.")
start = timeit.default_timer()
#Logging header
length_table = 90
log_separating_line(length_table)
log_header_line("Seeds: {}".format(SEEDS), length_table)
log_header_line("Strategies: {}".format(STRATEGIES), length_table)
log_header_line("-> Resulting number of iterations: {}".format(len(SEEDS) * len(STRATEGIES)), length_table)
log_header_line("Number of iterations: {}".format(num_epochs), length_table)
log_header_line("Learning rates: {}".format(learning_rates), length_table)
log_header_line("Resuming from checkpoint: {}".format(True if resume else False), length_table)
log_separating_line(length_table)

rows = []
for seed in SEEDS:
    for strategy in STRATEGIES:
                
        np.random.seed(seed)
        torch.manual_seed(seed)
        if device == 'cuda':
            torch.cuda.manual_seed_all(seed)
             
        ### Model
        net = ResNet18()
        net = net.to(device)
        criterion = nn.CrossEntropyLoss # no function!
        optimizer = optim.SGD(net.parameters(), lr=learning_rates['1'], momentum=momentum, weight_decay=weight_decay)
        
        ### load the data
        # if needed, specify batch sizes and shuffle settings
        dataloader = DataLoader()
        dataloader.download_cifar()
        dataloader.set_model(net)
        if strategy in ['max_k_loss', 'min_k_loss']:
            dataloader.initialize_weights(criterion(reduction='none'), device, seed=seed)
        print()
        log_separating_line(length_table)
        log_position_header(seed, strategy, length_table)
        log_separating_line(length_table)

        if resume:
            assert os.path.isdir('serialized'), 'Error: no serialized directory found!'
            ckpt = torch.load('./serialized/{}/{}_ckpt_{}.pth'.format(given_date, strategy, seed))
            test_acc, start_epoch, net = net.load(ckpt)

        for epoch in range(start_epoch, start_epoch+num_epochs):
            # reset learning rate at specific epochs
            if str(epoch+1) in learning_rates.keys():
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rates[str(epoch+1)]
                    
            train_acc, train_loss = train(epoch, optimizer, criterion, seed, dataloader, \
                                          strategy, device)
            test_acc, test_loss = test(epoch, test_acc, criterion, seed, dataloader)
            log_position_line(epoch + 1, num_epochs, train_acc, test_acc, train_loss, test_loss, length_table)
            row = {
                'iteration': epoch + 1,
                'seed': seed,
                'train': True,
                'strategy': strategy,
                'accuracy': train_acc,
                'loss': train_loss
            }
            rows.append(row)
            row = {
                'iteration': epoch + 1,
                'seed': seed,
                'train': False,
                'strategy': strategy,
                'accuracy': test_acc,
                'loss': test_loss
            }
            rows.append(row)
        log_separating_line(length_table)
            
stop = timeit.default_timer()
time_needed = stop - start
hrs = int(time_needed / 3600)
mins = int((time_needed / 60) % 60)
secs = int(time_needed % 60)
print()
print("Finished training. Time needed: {} hrs {} mins {} secs".format(hrs, mins, secs))

logging_df = pd.DataFrame(rows, columns=['epoch', 'seed', 'train', 'strategy', 'accuracy', 'loss'])   
training_logs_dir = 'evaluation_logs'
logging_df.to_csv('{}.txt'.format(os.path.join(training_logs_dir, today)), sep='\t', index=False)

Begin training.
+----------------------------------------------------------------------------------------+
| Seeds: [10, 42, 4]                                                                     |
| Strategies: ['max_k_loss', 'min_k_loss']                                               |
| -> Resulting number of iterations: 6                                                   |
| Number of iterations: 150                                                              |
| Learning rates: {'1': 0.1, '50': 0.01, '90': 0.005, '120': 0.001}                      |
| Resuming from checkpoint: False                                                        |
+----------------------------------------------------------------------------------------+
Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 10      Strategy: max_k_loss                                                    

| [078/150]:    31.25           10.19           0.03132      00.0433                     |
| [079/150]:    18.75           10.19           0.03412      0.04381                     |
| [080/150]:    025.0           10.21           00.0326      0.04376                     |
| [081/150]:    18.75           10.18           0.03518      0.04411                     |
| [082/150]:    18.75           10.23           0.03649      0.04398                     |
| [083/150]:    23.44           10.24           0.03694      00.0436                     |
| [084/150]:    23.44           10.24           00.0354      00.0433                     |
| [085/150]:    20.31           10.23           0.03624      0.04285                     |
| [086/150]:    21.88           010.2           0.03496      0.04258                     |
| [087/150]:    29.69           10.19           0.03492      0.04251                     |
| [088/150]:    21.88           10.17           0.03571      0.04211                     |

| [012/150]:    100.0           010.0           00000.0      329.08118                     |
| [013/150]:    100.0           010.0           00000.0      248.9327                     |
| [014/150]:    100.0           010.0           00000.0      191.42381                     |
| [015/150]:    100.0           010.0           00000.0      150.20345                     |
| [016/150]:    100.0           010.0           00000.0      119.11023                     |
| [017/150]:    100.0           010.0           00000.0      96.08717                     |
| [018/150]:    100.0           010.0           00000.0      78.76854                     |
| [019/150]:    100.0           010.0           00000.0      65.36083                     |
| [020/150]:    100.0           010.0           00000.0      54.99471                     |
| [021/150]:    100.0           010.0           00000.0      46.78355                     |
| [022/150]:    100.0           010.0           00000.0      40.25671       

| [102/150]:    100.0           010.0           00000.0      9.49899                     |
| [103/150]:    100.0           010.0           00000.0      9.48696                     |
| [104/150]:    100.0           010.0           00000.0      9.48633                     |
| [105/150]:    100.0           010.0           00000.0      9.51581                     |
| [106/150]:    100.0           010.0           00000.0      009.485                     |
| [107/150]:    100.0           010.0           00000.0      9.47589                     |
| [108/150]:    100.0           010.0           00000.0      9.51245                     |
| [109/150]:    100.0           010.0           00000.0      9.47035                     |
| [110/150]:    100.0           010.0           00000.0      9.48761                     |
| [111/150]:    100.0           010.0           00000.0      9.49004                     |
| [112/150]:    100.0           010.0           00000.0      9.50698                     |

| [036/150]:    67.19           10.01           0.01574      0.17528                     |
| [037/150]:    65.62           09.99           0.01509      0.20034                     |
| [038/150]:    70.31           09.99           0.01471      0.22314                     |
| [039/150]:    70.31           010.0           0.01259      0.26737                     |
| [040/150]:    68.75           010.0           00.0141      0.26813                     |
| [041/150]:    71.88           010.0           0.01419      0.26824                     |
| [042/150]:    65.62           10.03           0.01521      0.26009                     |
| [043/150]:    65.62           010.1           00.0159      0.25577                     |
| [044/150]:    60.94           011.1           0.01721      0.27176                     |
| [045/150]:    65.62           12.48           0.01455      0.23716                     |
| [046/150]:    71.88           13.15           0.01203      0.22636                     |

| [127/150]:    46.88           12.31           0.01936      0.06594                     |
| [128/150]:    48.44           12.43           0.01873      0.06498                     |
| [129/150]:    46.88           012.6           0.01924      0.06282                     |
| [130/150]:    54.69           12.79           0.01927      0.06035                     |
| [131/150]:    46.88           12.82           0.02087      0.05928                     |
| [132/150]:    037.5           12.81           0.02156      0.05901                     |
| [133/150]:    51.56           12.85           00.0219      00.0584                     |
| [134/150]:    34.38           12.67           0.02311      0.06073                     |
| [135/150]:    40.62           12.69           0.02219      0.06044                     |
| [136/150]:    39.06           12.54           0.02048      0.06209                     |
| [137/150]:    32.81           12.38           0.02341      00.0609                     |

| [061/150]:    100.0           010.0           00000.0      44.01865                     |
| [062/150]:    100.0           010.0           00000.0      44.23226                     |
| [063/150]:    100.0           010.0           00000.0      44.16966                     |
| [064/150]:    100.0           010.0           00000.0      44.21102                     |
| [065/150]:    100.0           010.0           00000.0      44.30547                     |
| [066/150]:    100.0           010.0           00000.0      44.20766                     |
| [067/150]:    100.0           010.0           00000.0      44.01004                     |
| [068/150]:    100.0           010.0           00000.0      44.16743                     |
| [069/150]:    100.0           010.0           00000.0      43.99418                     |
| [070/150]:    100.0           010.0           00000.0      44.03119                     |
| [071/150]:    100.0           010.0           00000.0      43.9433            

Files already downloaded and verified
Files already downloaded and verified

+----------------------------------------------------------------------------------------+
| Seed: 04      Strategy: max_k_loss                                                         |
| Epoch         Train Accuracy  Test Accuracy   Train Loss   Test Loss                   |
+----------------------------------------------------------------------------------------+
| [001/150]:    000.0           010.0           0.04201      0.52156                     |
| [002/150]:    53.12           010.0           0.05904      5045.07646                     |
| [003/150]:    54.69           010.0           0.05699      25927.39724                     |
| [004/150]:    32.81           010.0           0.03914      8324.08061                     |
| [005/150]:    025.0           010.0           0.03438      1297.35867                     |
| [006/150]:    28.12           010.0           0.08001      411431.459                

| [086/150]:    050.0           010.5           0.02391      0.05501                     |
| [087/150]:    57.81           10.59           0.02279      0.05506                     |
| [088/150]:    050.0           10.69           0.02341      0.05532                     |
| [089/150]:    53.12           010.9           0.03827      00.0602                     |
| [090/150]:    51.56           11.08           0.06267      0.06773                     |
| [091/150]:    57.81           11.15           0.02353      0.06886                     |
| [092/150]:    57.81           11.29           0.02196      00.0654                     |
| [093/150]:    59.38           11.25           0.02137      0.06138                     |
| [094/150]:    59.38           11.31           0.02105      0.06418                     |
| [095/150]:    062.5           11.43           0.02172      0.06342                     |
| [096/150]:    56.25           11.47           00.0233      0.06576                     |

| [020/150]:    100.0           010.0           00000.0      3606.9595                     |
| [021/150]:    100.0           010.0           00000.0      3306.87617                     |
| [022/150]:    100.0           010.0           00000.0      3047.47936                     |
| [023/150]:    100.0           010.0           00000.0      2840.74612                     |
| [024/150]:    100.0           010.0           00000.0      2664.06889                     |
| [025/150]:    100.0           010.0           00000.0      2512.43619                     |
| [026/150]:    100.0           010.0           00000.0      2387.04995                     |
| [027/150]:    100.0           010.0           00000.0      2277.03738                     |
| [028/150]:    100.0           010.0           00000.0      2179.6156                     |
| [029/150]:    100.0           010.0           00000.0      2099.60188                     |
| [030/150]:    100.0           010.0           00000.0      2

| [108/150]:    100.0           010.0           00000.0      1530.25613                     |
| [109/150]:    100.0           010.0           00000.0      1530.13623                     |
| [110/150]:    100.0           010.0           00000.0      1532.70171                     |
| [111/150]:    100.0           010.0           00000.0      1533.07144                     |
| [112/150]:    100.0           010.0           00000.0      1530.7457                     |
| [113/150]:    100.0           010.0           00000.0      1530.88639                     |
| [114/150]:    100.0           010.0           00000.0      1529.1194                     |
| [115/150]:    100.0           010.0           00000.0      1529.18119                     |
| [116/150]:    100.0           010.0           00000.0      1530.35651                     |
| [117/150]:    100.0           010.0           00000.0      1531.90639                     |
| [118/150]:    100.0           010.0           00000.0      1

# Notes: