In [1]:
import math
import time
import numpy as np
import copy

import dataset
import VanillaVisionTransformer

import matplotlib.pyplot as plt
import matplotlib

# Pytorch packages
import torch
import torch.optim as optim
import torch.nn as nn

# torchvision
import torchvision
from torchvision import transforms

# Tqdm progress bar
from tqdm import tqdm_notebook


In [None]:
# sigopt setup
import sigopt
import os

os.environ["SIGOPT_API_TOKEN"] = "XWBIVDWCVQXALUZQFDHNGOELLLKDJBMOJALEPCNQXQGBNIGC"
os.environ['SIGOPT_PROJECT'] = 'vanilla_vit'

In [2]:
# Check device availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("You are using device: %s" % device)

You are using device: cuda


In [4]:
torch.cuda.empty_cache()

LEARNING_RATE = 0.003
MOMENTUM = 0.5
WEIGHT_DECAY_REGULARIZATION_TERM = 0.005
NUM_EPOCHS = 4
BATCH_SIZE = 32

PATCH_SIZE = 8
HIDDEN_DIM = 512
EMBED_DIM = 256 # aka mlp_dim
NUM_CHANNELS = 3
NUM_HEADS = 8
NUM_LAYERS = 6
DROPOUT = 0.2

# patch size is the size of each patch
def img_to_patches(x, patch_size, flatten_channels=True):
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

        self.avg = self.sum / self.count

def train(epoch, data_loader, model, optimizer, criterion):
    iter_time = AverageMeter()
    losses = AverageMeter()

    for idx, (data, target) in enumerate(data_loader):
        start = time.time()
        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()

        data = img_to_patches(data, patch_size = PATCH_SIZE, flatten_channels=True)
        out = model(data).to(device)
        loss = criterion(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss, out.shape[0])

        iter_time.update(time.time() - start)
        if idx % 10 == 0:
            print(('Epoch: [{0}][{1}/{2}]\t'
                   'Time {iter_time.val:.3f} ({iter_time.avg:.3f})\t'
                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  .format(epoch, idx, len(data_loader), iter_time=iter_time, loss=losses)))
    
    return losses.avg

def validate(epoch, validation_loader, model, criterion):
    iter_time = AverageMeter()
    losses = AverageMeter()

    for idx, (data, target) in enumerate(validation_loader):
        start = time.time()

        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()

        out = None
        loss = None
        data = img_to_patches(data, patch_size = PATCH_SIZE, flatten_channels=True)

        with torch.no_grad():
            out = model(data).to(device)
            loss = criterion(out, target)

        losses.update(loss, out.shape[0])

        iter_time.update(time.time() - start)
        if idx % 10 == 0:
            print(('Epoch: [{0}][{1}/{2}]\t'
                   'Time {iter_time.val:.3f} ({iter_time.avg:.3f})\t')
                  .format(epoch, idx, len(validation_loader), iter_time=iter_time, loss=losses))

    print("* Average Loss @1: {loss.avg:.4f}".format(loss=losses))
    return losses.avg

def RMSELoss(yhat, y, eps=1e-6):
    return torch.sqrt(torch.mean((yhat-y)**2) + eps)

In [5]:
def main():
    # Normalizing images per the paper and resizing each image to 64 x 192.
    transform = transforms.Compose([
        # Citation:
        # https://pytorch.org/vision/stable/transforms.html#scriptable-transforms
        transforms.Resize((64, 192)),
    ])
    # Loading in images with normalization and resizing applied.
    training_set, validation_set, test_set = dataset.load_nvidia_dataset(batch_size=BATCH_SIZE, transform=transform)

    image_size = training_set.dataset[0][0].shape

    model = VanillaVisionTransformer.VisionTransformer(device,
                                                image_size,
                                               PATCH_SIZE, 
                                               HIDDEN_DIM,
                                               EMBED_DIM,
                                               NUM_CHANNELS,
                                               NUM_HEADS,
                                               NUM_LAYERS,
                                               DROPOUT).to(device)
    
    criterion = RMSELoss

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    
    best = 1.0
    best_model = None
    for epoch in range(NUM_EPOCHS):
        
        # Training.
        train_loss = train(epoch, training_set, model, optimizer, criterion)
        scheduler.step(train_loss)
                
        # Validation.
        valid_loss = validate(epoch, validation_set, model, criterion)
        
        if valid_loss < best:
            best = valid_loss
            best_model = copy.deepcopy(model)

    print('Best Loss @1: {:.4f}'.format(best))

    torch.save(best_model.state_dict(), './checkpoints/vanilla_vision_transformer.pth')
    
if __name__ == '__main__':
    main()

Epoch: [0][0/2282]	Time 1.118 (1.118)	Loss 0.2422 (0.2422)	
Epoch: [0][10/2282]	Time 0.017 (0.144)	Loss 0.7475 (2.5214)	
Epoch: [0][20/2282]	Time 0.016 (0.083)	Loss 0.4117 (1.5236)	
Epoch: [0][30/2282]	Time 0.017 (0.061)	Loss 0.2708 (1.1299)	
Epoch: [0][40/2282]	Time 0.016 (0.050)	Loss 0.3608 (0.9303)	
Epoch: [0][50/2282]	Time 0.016 (0.044)	Loss 0.3638 (0.8005)	
Epoch: [0][60/2282]	Time 0.016 (0.039)	Loss 0.4449 (0.7223)	
Epoch: [0][70/2282]	Time 0.015 (0.036)	Loss 0.3211 (0.6558)	
Epoch: [0][80/2282]	Time 0.016 (0.033)	Loss 0.3920 (0.6094)	
Epoch: [0][90/2282]	Time 0.016 (0.032)	Loss 0.2332 (0.5741)	
Epoch: [0][100/2282]	Time 0.016 (0.030)	Loss 0.3934 (0.5457)	
Epoch: [0][110/2282]	Time 0.016 (0.029)	Loss 0.2206 (0.5191)	
Epoch: [0][120/2282]	Time 0.015 (0.028)	Loss 0.1986 (0.4940)	
Epoch: [0][130/2282]	Time 0.016 (0.027)	Loss 0.2234 (0.4769)	
Epoch: [0][140/2282]	Time 0.017 (0.026)	Loss 0.2002 (0.4615)	
Epoch: [0][150/2282]	Time 0.015 (0.025)	Loss 0.3080 (0.4464)	
Epoch: [0][160/2282

Epoch: [0][1320/2282]	Time 0.016 (0.018)	Loss 0.2490 (0.2846)	
Epoch: [0][1330/2282]	Time 0.018 (0.018)	Loss 0.3415 (0.2842)	
Epoch: [0][1340/2282]	Time 0.017 (0.018)	Loss 0.1989 (0.2840)	
Epoch: [0][1350/2282]	Time 0.017 (0.018)	Loss 0.2808 (0.2841)	
Epoch: [0][1360/2282]	Time 0.016 (0.018)	Loss 0.2137 (0.2836)	
Epoch: [0][1370/2282]	Time 0.017 (0.018)	Loss 0.2304 (0.2838)	
Epoch: [0][1380/2282]	Time 0.016 (0.018)	Loss 0.4046 (0.2840)	
Epoch: [0][1390/2282]	Time 0.017 (0.018)	Loss 0.2223 (0.2838)	
Epoch: [0][1400/2282]	Time 0.017 (0.018)	Loss 0.1531 (0.2836)	
Epoch: [0][1410/2282]	Time 0.016 (0.018)	Loss 0.2037 (0.2834)	
Epoch: [0][1420/2282]	Time 0.016 (0.018)	Loss 0.2398 (0.2832)	
Epoch: [0][1430/2282]	Time 0.016 (0.018)	Loss 0.4569 (0.2830)	
Epoch: [0][1440/2282]	Time 0.017 (0.018)	Loss 0.2845 (0.2828)	
Epoch: [0][1450/2282]	Time 0.015 (0.018)	Loss 0.1814 (0.2826)	
Epoch: [0][1460/2282]	Time 0.017 (0.018)	Loss 0.4879 (0.2824)	
Epoch: [0][1470/2282]	Time 0.017 (0.018)	Loss 0.1984 (0

Epoch: [0][530/571]	Time 0.004 (0.004)	
Epoch: [0][540/571]	Time 0.004 (0.004)	
Epoch: [0][550/571]	Time 0.004 (0.004)	
Epoch: [0][560/571]	Time 0.004 (0.004)	
Epoch: [0][570/571]	Time 0.015 (0.004)	
* Average Loss @1: 0.2651
Epoch: [1][0/2282]	Time 0.021 (0.021)	Loss 0.1885 (0.1885)	
Epoch: [1][10/2282]	Time 0.016 (0.064)	Loss 0.1587 (0.2607)	
Epoch: [1][20/2282]	Time 0.016 (0.042)	Loss 0.3167 (0.2559)	
Epoch: [1][30/2282]	Time 0.016 (0.034)	Loss 0.2093 (0.2538)	
Epoch: [1][40/2282]	Time 0.016 (0.030)	Loss 0.2829 (0.2660)	
Epoch: [1][50/2282]	Time 0.017 (0.027)	Loss 0.2349 (0.2669)	
Epoch: [1][60/2282]	Time 0.016 (0.025)	Loss 0.2473 (0.2692)	
Epoch: [1][70/2282]	Time 0.016 (0.024)	Loss 0.2800 (0.2675)	
Epoch: [1][80/2282]	Time 0.017 (0.023)	Loss 0.4251 (0.2679)	
Epoch: [1][90/2282]	Time 0.016 (0.022)	Loss 0.2709 (0.2698)	
Epoch: [1][100/2282]	Time 0.016 (0.022)	Loss 0.2423 (0.2683)	
Epoch: [1][110/2282]	Time 0.016 (0.021)	Loss 0.2106 (0.2683)	
Epoch: [1][120/2282]	Time 0.016 (0.021)	L

Epoch: [1][1290/2282]	Time 0.016 (0.017)	Loss 0.2092 (0.2592)	
Epoch: [1][1300/2282]	Time 0.016 (0.017)	Loss 0.3978 (0.2594)	
Epoch: [1][1310/2282]	Time 0.016 (0.017)	Loss 0.4418 (0.2595)	
Epoch: [1][1320/2282]	Time 0.016 (0.017)	Loss 0.1723 (0.2591)	
Epoch: [1][1330/2282]	Time 0.015 (0.017)	Loss 0.3459 (0.2587)	
Epoch: [1][1340/2282]	Time 0.016 (0.017)	Loss 0.1617 (0.2585)	
Epoch: [1][1350/2282]	Time 0.016 (0.017)	Loss 0.1602 (0.2585)	
Epoch: [1][1360/2282]	Time 0.016 (0.017)	Loss 0.3473 (0.2588)	
Epoch: [1][1370/2282]	Time 0.016 (0.017)	Loss 0.3110 (0.2588)	
Epoch: [1][1380/2282]	Time 0.015 (0.017)	Loss 0.3493 (0.2590)	
Epoch: [1][1390/2282]	Time 0.017 (0.017)	Loss 0.1813 (0.2591)	
Epoch: [1][1400/2282]	Time 0.016 (0.017)	Loss 0.1500 (0.2592)	
Epoch: [1][1410/2282]	Time 0.016 (0.017)	Loss 0.2557 (0.2591)	
Epoch: [1][1420/2282]	Time 0.016 (0.017)	Loss 0.3611 (0.2592)	
Epoch: [1][1430/2282]	Time 0.016 (0.017)	Loss 0.3554 (0.2590)	
Epoch: [1][1440/2282]	Time 0.016 (0.017)	Loss 0.1624 (0

Epoch: [1][480/571]	Time 0.004 (0.004)	
Epoch: [1][490/571]	Time 0.004 (0.004)	
Epoch: [1][500/571]	Time 0.004 (0.004)	
Epoch: [1][510/571]	Time 0.004 (0.004)	
Epoch: [1][520/571]	Time 0.005 (0.004)	
Epoch: [1][530/571]	Time 0.005 (0.004)	
Epoch: [1][540/571]	Time 0.004 (0.004)	
Epoch: [1][550/571]	Time 0.004 (0.004)	
Epoch: [1][560/571]	Time 0.004 (0.004)	
Epoch: [1][570/571]	Time 0.004 (0.004)	
* Average Loss @1: 0.2656
Epoch: [2][0/2282]	Time 0.017 (0.017)	Loss 0.2038 (0.2038)	
Epoch: [2][10/2282]	Time 0.016 (0.016)	Loss 0.1867 (0.2569)	
Epoch: [2][20/2282]	Time 0.018 (0.016)	Loss 0.3025 (0.2503)	
Epoch: [2][30/2282]	Time 0.015 (0.016)	Loss 0.4294 (0.2599)	
Epoch: [2][40/2282]	Time 0.016 (0.016)	Loss 0.1912 (0.2587)	
Epoch: [2][50/2282]	Time 0.015 (0.016)	Loss 0.2230 (0.2583)	
Epoch: [2][60/2282]	Time 0.016 (0.016)	Loss 0.4813 (0.2636)	
Epoch: [2][70/2282]	Time 0.016 (0.016)	Loss 0.2094 (0.2668)	
Epoch: [2][80/2282]	Time 0.015 (0.016)	Loss 0.3978 (0.2701)	
Epoch: [2][90/2282]	Time 0

Epoch: [2][1260/2282]	Time 0.017 (0.016)	Loss 0.1481 (0.2587)	
Epoch: [2][1270/2282]	Time 0.016 (0.016)	Loss 0.4373 (0.2588)	
Epoch: [2][1280/2282]	Time 0.018 (0.016)	Loss 0.2608 (0.2589)	
Epoch: [2][1290/2282]	Time 0.016 (0.016)	Loss 0.2994 (0.2593)	
Epoch: [2][1300/2282]	Time 0.017 (0.016)	Loss 0.2610 (0.2590)	
Epoch: [2][1310/2282]	Time 0.016 (0.016)	Loss 0.3015 (0.2587)	
Epoch: [2][1320/2282]	Time 0.016 (0.016)	Loss 0.1750 (0.2588)	
Epoch: [2][1330/2282]	Time 0.016 (0.016)	Loss 0.2373 (0.2588)	
Epoch: [2][1340/2282]	Time 0.016 (0.016)	Loss 0.2337 (0.2588)	
Epoch: [2][1350/2282]	Time 0.017 (0.016)	Loss 0.2122 (0.2589)	
Epoch: [2][1360/2282]	Time 0.016 (0.016)	Loss 0.1895 (0.2591)	
Epoch: [2][1370/2282]	Time 0.016 (0.016)	Loss 0.2315 (0.2589)	
Epoch: [2][1380/2282]	Time 0.017 (0.016)	Loss 0.2911 (0.2590)	
Epoch: [2][1390/2282]	Time 0.017 (0.016)	Loss 0.2784 (0.2588)	
Epoch: [2][1400/2282]	Time 0.016 (0.016)	Loss 0.1909 (0.2589)	
Epoch: [2][1410/2282]	Time 0.016 (0.016)	Loss 0.3501 (0

Epoch: [2][430/571]	Time 0.004 (0.004)	
Epoch: [2][440/571]	Time 0.004 (0.004)	
Epoch: [2][450/571]	Time 0.004 (0.004)	
Epoch: [2][460/571]	Time 0.004 (0.004)	
Epoch: [2][470/571]	Time 0.004 (0.004)	
Epoch: [2][480/571]	Time 0.004 (0.004)	
Epoch: [2][490/571]	Time 0.004 (0.004)	
Epoch: [2][500/571]	Time 0.005 (0.004)	
Epoch: [2][510/571]	Time 0.004 (0.004)	
Epoch: [2][520/571]	Time 0.004 (0.004)	
Epoch: [2][530/571]	Time 0.004 (0.004)	
Epoch: [2][540/571]	Time 0.004 (0.004)	
Epoch: [2][550/571]	Time 0.004 (0.004)	
Epoch: [2][560/571]	Time 0.004 (0.004)	
Epoch: [2][570/571]	Time 0.004 (0.004)	
* Average Loss @1: 0.2672
Epoch: [3][0/2282]	Time 0.016 (0.016)	Loss 0.4502 (0.4502)	
Epoch: [3][10/2282]	Time 0.016 (0.016)	Loss 0.2252 (0.2511)	
Epoch: [3][20/2282]	Time 0.016 (0.016)	Loss 0.2131 (0.2662)	
Epoch: [3][30/2282]	Time 0.016 (0.016)	Loss 0.3058 (0.2606)	
Epoch: [3][40/2282]	Time 0.016 (0.016)	Loss 0.2176 (0.2534)	
Epoch: [3][50/2282]	Time 0.018 (0.016)	Loss 0.2783 (0.2545)	
Epoch: [3

Epoch: [3][1220/2282]	Time 0.016 (0.017)	Loss 0.2920 (0.2588)	
Epoch: [3][1230/2282]	Time 0.016 (0.017)	Loss 0.4134 (0.2590)	
Epoch: [3][1240/2282]	Time 0.016 (0.017)	Loss 0.1524 (0.2589)	
Epoch: [3][1250/2282]	Time 0.016 (0.017)	Loss 0.2324 (0.2591)	
Epoch: [3][1260/2282]	Time 0.016 (0.017)	Loss 0.2955 (0.2590)	
Epoch: [3][1270/2282]	Time 0.016 (0.017)	Loss 0.2942 (0.2588)	
Epoch: [3][1280/2282]	Time 0.017 (0.017)	Loss 0.3457 (0.2585)	
Epoch: [3][1290/2282]	Time 0.015 (0.017)	Loss 0.2535 (0.2582)	
Epoch: [3][1300/2282]	Time 0.015 (0.017)	Loss 0.1730 (0.2583)	
Epoch: [3][1310/2282]	Time 0.016 (0.017)	Loss 0.3568 (0.2582)	
Epoch: [3][1320/2282]	Time 0.016 (0.017)	Loss 0.1748 (0.2581)	
Epoch: [3][1330/2282]	Time 0.016 (0.017)	Loss 0.2599 (0.2582)	
Epoch: [3][1340/2282]	Time 0.017 (0.017)	Loss 0.3535 (0.2582)	
Epoch: [3][1350/2282]	Time 0.016 (0.017)	Loss 0.3152 (0.2582)	
Epoch: [3][1360/2282]	Time 0.016 (0.017)	Loss 0.2196 (0.2580)	
Epoch: [3][1370/2282]	Time 0.016 (0.017)	Loss 0.2369 (0

Epoch: [3][370/571]	Time 0.004 (0.004)	
Epoch: [3][380/571]	Time 0.005 (0.004)	
Epoch: [3][390/571]	Time 0.005 (0.004)	
Epoch: [3][400/571]	Time 0.004 (0.004)	
Epoch: [3][410/571]	Time 0.005 (0.004)	
Epoch: [3][420/571]	Time 0.004 (0.004)	
Epoch: [3][430/571]	Time 0.004 (0.004)	
Epoch: [3][440/571]	Time 0.004 (0.004)	
Epoch: [3][450/571]	Time 0.005 (0.004)	
Epoch: [3][460/571]	Time 0.004 (0.004)	
Epoch: [3][470/571]	Time 0.004 (0.004)	
Epoch: [3][480/571]	Time 0.004 (0.004)	
Epoch: [3][490/571]	Time 0.005 (0.004)	
Epoch: [3][500/571]	Time 0.004 (0.004)	
Epoch: [3][510/571]	Time 0.004 (0.004)	
Epoch: [3][520/571]	Time 0.004 (0.004)	
Epoch: [3][530/571]	Time 0.004 (0.004)	
Epoch: [3][540/571]	Time 0.004 (0.004)	
Epoch: [3][550/571]	Time 0.004 (0.004)	
Epoch: [3][560/571]	Time 0.004 (0.004)	
Epoch: [3][570/571]	Time 0.004 (0.004)	
* Average Loss @1: 0.2644
Epoch: [4][0/2282]	Time 0.015 (0.015)	Loss 0.3202 (0.3202)	
Epoch: [4][10/2282]	Time 0.016 (0.016)	Loss 0.2671 (0.3278)	
Epoch: [4][20

Epoch: [4][1190/2282]	Time 0.016 (0.016)	Loss 0.2761 (0.2602)	
Epoch: [4][1200/2282]	Time 0.016 (0.016)	Loss 0.1667 (0.2600)	
Epoch: [4][1210/2282]	Time 0.016 (0.016)	Loss 0.2298 (0.2602)	
Epoch: [4][1220/2282]	Time 0.016 (0.016)	Loss 0.2565 (0.2599)	
Epoch: [4][1230/2282]	Time 0.016 (0.016)	Loss 0.1734 (0.2600)	
Epoch: [4][1240/2282]	Time 0.016 (0.016)	Loss 0.2122 (0.2597)	
Epoch: [4][1250/2282]	Time 0.016 (0.016)	Loss 0.1908 (0.2597)	
Epoch: [4][1260/2282]	Time 0.016 (0.016)	Loss 0.2189 (0.2599)	
Epoch: [4][1270/2282]	Time 0.016 (0.016)	Loss 0.2622 (0.2599)	
Epoch: [4][1280/2282]	Time 0.016 (0.016)	Loss 0.1398 (0.2596)	
Epoch: [4][1290/2282]	Time 0.016 (0.016)	Loss 0.1891 (0.2598)	
Epoch: [4][1300/2282]	Time 0.016 (0.016)	Loss 0.3439 (0.2596)	
Epoch: [4][1310/2282]	Time 0.016 (0.016)	Loss 0.3983 (0.2596)	
Epoch: [4][1320/2282]	Time 0.016 (0.016)	Loss 0.2898 (0.2598)	
Epoch: [4][1330/2282]	Time 0.016 (0.016)	Loss 0.1498 (0.2597)	
Epoch: [4][1340/2282]	Time 0.016 (0.016)	Loss 0.1938 (0

Epoch: [4][320/571]	Time 0.004 (0.004)	
Epoch: [4][330/571]	Time 0.004 (0.004)	
Epoch: [4][340/571]	Time 0.005 (0.004)	
Epoch: [4][350/571]	Time 0.004 (0.004)	
Epoch: [4][360/571]	Time 0.004 (0.004)	
Epoch: [4][370/571]	Time 0.004 (0.004)	
Epoch: [4][380/571]	Time 0.004 (0.004)	
Epoch: [4][390/571]	Time 0.004 (0.004)	
Epoch: [4][400/571]	Time 0.004 (0.004)	
Epoch: [4][410/571]	Time 0.004 (0.004)	
Epoch: [4][420/571]	Time 0.004 (0.004)	
Epoch: [4][430/571]	Time 0.004 (0.004)	
Epoch: [4][440/571]	Time 0.005 (0.004)	
Epoch: [4][450/571]	Time 0.004 (0.004)	
Epoch: [4][460/571]	Time 0.004 (0.004)	
Epoch: [4][470/571]	Time 0.004 (0.004)	
Epoch: [4][480/571]	Time 0.005 (0.004)	
Epoch: [4][490/571]	Time 0.004 (0.004)	
Epoch: [4][500/571]	Time 0.004 (0.004)	
Epoch: [4][510/571]	Time 0.004 (0.004)	
Epoch: [4][520/571]	Time 0.004 (0.004)	
Epoch: [4][530/571]	Time 0.004 (0.004)	
Epoch: [4][540/571]	Time 0.004 (0.004)	
Epoch: [4][550/571]	Time 0.004 (0.004)	
Epoch: [4][560/571]	Time 0.005 (0.004)	


## Driver Code

In [None]:
def evaluate(assignments, args):
    
    # log source of hyperparameter suggestion
    sigopt.log_metadata('optimizer', args['optimizer'])
    sigopt.log_model("Vanilla Vision Transformer")
    sigopt.log_dataset("Udacity self-driving dataset ")

    sigopt.params.setdefault("learning_rate", args['learning_rate'])
    sigopt.params.setdefault("momentum", args['momentum'])
    sigopt.params.setdefault("reg", args['reg'])
    sigopt.params.setdefault("batch_size", args['batch_size'])
    
    
        # Normalizing images per the paper and resizing each image to 64 x 192.
    transform = transforms.Compose([
        # Citation:
        # https://pytorch.org/vision/stable/transforms.html#scriptable-transforms
        transforms.Resize((64, 192)),
    ])
    # Loading in images with normalization and resizing applied.
    training_set, validation_set, test_set = dataset.load_nvidia_dataset(batch_size=BATCH_SIZE, transform=transform)

    image_size = training_set.dataset[0][0].shape

    model = VanillaVisionTransformer.VisionTransformer(device,
                                                image_size,
                                               PATCH_SIZE, 
                                               HIDDEN_DIM,
                                               EMBED_DIM,
                                               NUM_CHANNELS,
                                               NUM_HEADS,
                                               NUM_LAYERS,
                                               DROPOUT).to(device)
    
    criterion = nn.MSELoss()

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
#     scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    
    best = 1.0
    best_model = None
    for epoch in range(NUM_EPOCHS):
        
        # Training.
        train_loss = train(epoch, training_set, model, optimizer, criterion)
#         scheduler.step(train_loss)
                
        # Validation.
        valid_loss = validate(epoch, validation_set, model, criterion)
        
        if valid_loss < best:
            best = valid_loss
            best_model = copy.deepcopy(model)

    print('Best Loss @1: {:.4f}'.format(best))

    torch.save(best_model.state_dict(), './checkpoints/vanilla_vision_transformer.pth')  
    sigopt.log_metric(name='MSE', value=best)
    return best.cpu().numpy()

In [None]:
args = {
    'epochs': 10,
    'batch_size': 256,
    'warmup' : 0,
    'learning_rate': 0.005,
    'momentum': 0.9,
    'reg': 0.0005
}

In [None]:
conn = sigopt.Connection(client_token=os.environ.get("SIGOPT_API_TOKEN"))
experiment = conn.experiments().create(
    
    name="NVIDIA CNN Optimization",
 
    parameters=[
        dict(name="momentum", bounds=dict(min=0.0, max=1), type="double"),
        dict(name="reg", bounds=dict(min=0.00001, max=1), type="double", transformation="log"),
        dict(name="learning_rate", bounds=dict(min=0.00001 ,max=1), type="double", transformation="log"),
        dict(name="batch_size",  type="categorical", categorical_values: [32, 64, 128, 256, 512, 1024])
        ],
 
    metrics=[
        dict(name="RMSE", objective="minimize", strategy="optimize")
        ],
 
    observation_budget = 5,
)
 
print("Explore your experiment: https://app.sigopt.com/experiment/" + experiment.id + "/analysis")

In [None]:
#Optimization Loop
for _ in range(experiment.observation_budget):
    suggestion = conn.experiments(experiment.id).suggestions().create()
    assignments = suggestion.assignments
    value = evaluate(assignments, args)
 
    conn.experiments(experiment.id).observations().create(
        suggestion=suggestion.id,
        value=value
    )
    
    #update experiment object
    experiment = conn.experiments(experiment.id).fetch()
 
assignments = conn.experiments(experiment.id).best_assignments().fetch().data[0].assignments  
 
print("BEST ASSIGNMENTS FOUND: \n", assignments)