In [2]:
from collections import OrderedDict
from functools import partial
import os

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import DataLoader

import torchvision 
import torchvision.transforms as t


In [3]:
DEFAULT_CONV_KWARGS = {'kernel_size': 3, 'padding': 'same', 'bias': False}

BATCHSIZE = 512
BIAS_SCALER = 32

# To replicate the ~95.77% accuracy in 188 seconds runs, simply change the base_depth from 64->128 and the num_epochs from 10->80
HYP = {
    'opt': {
        'bias_lr':        1.15 * 1.35 * 1. * BIAS_SCALER/BATCHSIZE, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
        'non_bias_lr':    1.15 * 1.35 * 1. / BATCHSIZE,
        'bias_decay':     .85 * 4.8e-4 * BATCHSIZE/BIAS_SCALER,
        'non_bias_decay': .85 * 4.8e-4 * BATCHSIZE,
        'scaling_factor': 1./10,
        'percent_start': .2,
    },
    'net': {
        'whitening': {
            'kernel_size': 2,
            'num_examples': 50000,
        },
        'batch_norm_momentum': .8,
        'cutout_size': 0,
        'pad_amount': 3,
        'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
    },
    'misc': {
        'ema': {
            'epochs': 2,
            'decay_base': .986,
            'every_n_steps': 2,
        },
        'train_epochs': 10,
        'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        'data_location': 'data.pt',
    }
}

SCALER = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
DEPTHS = {
    'init':   round(SCALER**-1*HYP['net']['base_depth']), # 64  w/ scaler at base value
    'block1': round(SCALER**1*HYP['net']['base_depth']), # 128 w/ scaler at base value
    'block2': round(SCALER**2*HYP['net']['base_depth']), # 256 w/ scaler at base value
    'block3': round(SCALER**3*HYP['net']['base_depth']), # 512 w/ scaler at base value
    'num_classes': 10
}


In [4]:
if not os.path.exists(HYP['misc']['data_location']):

        CIFAR10_MEAN, CIFAR10_STD = [
            torch.tensor([0.4913997551666284, 0.48215855929893703,
                         0.4465309133731618], device=HYP['misc']['device']),
            torch.tensor([0.24703225141799082, 0.24348516474564,
                         0.26158783926049628],  device=HYP['misc']['device'])
        ]

        PREP_TRANSFORM = t.Compose([TO_TENSOR := t.ToTensor(),
                                    NORMALIZE := t.Normalize(CIFAR10_MEAN, CIFAR10_STD),
                                    FLATTEN := t.Lambda(lambda x:x.view(-1)),
                                    PAD := t.Pad(BORDER := 4)])

        TRAIN_TRANSFORM = t.Compose([PREP_TRANSFORM,
                                     CROP := t.CenterCrop(IMAGE_SIZE := (32, 32)),
                                     FLIP := t.RandomVerticalFlip()
                                     ])

        cifar10 = torchvision.datasets.CIFAR10(
            'cifar10/', download=True,  train=True,  transform=TRAIN_TRANSFORM)
        cifar10_eval = torchvision.datasets.CIFAR10(
            'cifar10/', download=False, train=False, transform=PREP_TRANSFORM)

        # use   the dataloader to get a single batch of all of the dataset items at once.
        train_dataset_gpu_loader = DataLoader(cifar10, batch_size=len(cifar10), drop_last=True, persistent_workers=False)

        eval_dataset_gpu_loader = DataLoader(cifar10_eval, batch_size=len(cifar10_eval), drop_last=True, persistent_workers=False)

                                                 
        train_dataset_gpu = zip([item.to(device=HYP['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))])
        eval_dataset_gpu  = zip([item.to(device=HYP['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader))])

        data = {
            'train': train_dataset_gpu,
            'eval': eval_dataset_gpu
        }

        torch.save(data, HYP['misc']['data_location'])

else:
    ## This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :)
    ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above
    ## HYP dictionary, then we should be good. :)
    data = torch.load(HYP['misc']['data_location'])


## As you'll note above and below, one difference is that we don't count loading the raw data to GPU since it's such a variable operation, and can sort of get in the way
## of measuring other things. That said, measuring the preprocessing (outside of the PADding) is still important to us.


Files already downloaded and verified


TypeError: Tensor is not a torch image.

In [None]:
def zeros_like(weights):
    return [torch.zeros_like(w) for w in weights]

def nesterov_update(w, dw, v, lr, weight_decay, momentum):
    dw.add_(weight_decay, w).mul_(-lr)
    v.mul_(momentum).add_(dw)
    w.add_(dw.add_(momentum, v))

def optimiser(weights, param_schedule, update, state_init):
    weights = list(weights)
    return {'update': update, 'param_schedule': param_schedule, 'step_number': 0, 'weights': weights,  'opt_state': state_init(weights)}

SGD = partial(optimiser, update=nesterov_update, state_init=zeros_like)

def train(model, lr_schedule, train_set, test_set, batch_size, num_workers=0):
    train_batches = DataLoader(train_set, batch_size, shuffle=True, set_random_choices=True, num_workers=num_workers)
    test_batches = DataLoader(test_set, batch_size, shuffle=False, num_workers=num_workers)
    
    lr = lambda step: lr_schedule(step/len(train_batches))/batch_size
    opts = [SGD(trainable_params(model).values(), {'lr': lr, 'weight_decay': (5e-4*batch_size), 'momentum': 0.9})]
    logs, state = Table(), {MODEL: model, LOSS: x_ent_loss, OPTS: opts}
    for epoch in range(lr_schedule.knots[-1]):
        logs.append({**{'epoch': epoch+1, 'lr': lr_schedule(epoch+1)}, 
                          **train_epoch(state, Timer(torch.cuda.synchronize), train_batches, test_batches)})
    return logs

def train_epoch(state, timer, train_batches, valid_batches, train_steps=default_train_steps, valid_steps=default_valid_steps, 
            on_epoch_end=(lambda state: state)):
    train_summary, train_time = epoch_stats(on_epoch_end(reduce(train_batches, state, train_steps))), timer()
    valid_summary, valid_time = epoch_stats(reduce(valid_batches, state, valid_steps)), timer(include_in_total=False) #DAWNBench rules
    return {
        'train': {**{'time': train_time}, **train_summary}, 
        'valid': {**{'time': valid_time}, **valid_summary}, 
        'total time': timer.total_time
    }

In [5]:
class Cat(nn.Module):
    def __init__(self, modules: OrderedDict[str, nn.Module]) -> None:
        super().__init__()

        for name, module in modules.items():
            setattr(Cat, name, module)

    def forward(self, x: Tensor):
        return torch.cat([module(x) for _, module in self.modules()]) # type: ignore

class Add(nn.Module):
    def __init__(self, modules: OrderedDict[str, nn.Module]) -> None:
        super().__init__()

        for name, module in modules.items():
            setattr(Cat, name, module)

    def forward(self, x: Tensor):
        return sum([module(x) for _, module in self.modules()]) # type: ignore

class Id(nn.Module):
    def forward(self, x: Tensor): return x

class Flatten(nn.Module):
    def forward(self, x: Tensor): return x.view(x.size(0), x.size(1))


In [6]:
class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, weight_requires_grad=True, bias_requires_grad=True, weights_init=False, *args, **kwargs):

        super().__init__(num_features, *args, **kwargs)

        if weights_init:
            self.weight.data.fill_(1.0)
            self.bias.data.fill_(0.0)

        self.weight.requires_grad=weight_requires_grad
        self.bias.requires_grad=bias_requires_grad

# Allows us to set default arguments for the whole convolution itself.
class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, *args, **kwargs):
        kwargs = {**kwargs, **DEFAULT_CONV_KWARGS}
        super().__init__(in_channels, out_channels, *args, **kwargs)


In [7]:
class ResBlock(nn.Sequential):
    def __init__(self, c_in: int, c_out: int, stride: int=1) -> None:

        bn1 = BatchNorm(c_in)
        relu1 = nn.ReLU(inplace = True)

        branch = nn.Sequential(OrderedDict([
            ('conv1', Conv(c_in, c_out, kernel_size=3,
                            stride=stride, padding=1)),
            ('bn2', BatchNorm(c_out)),
            ('relu2', nn.ReLU(inplace=True)),
            ('conv2', Conv(c_out, c_out, bias=False)),]))

        super().__init__(
        bn1,
        relu1,
        Add(OrderedDict([
            (('conv3', Conv(
                c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False)) if (projection := (stride != 1) or (c_in != c_out)) else ("id", Id())),
            ("branch", branch)
        ])),
        )




class DawnNet(nn.Sequential):

    def __init__(self, c = 64, Block = ResBlock, prep_bn_relu = False, concat_pool = True, **kw) -> None:

        if isinstance(c, int):
            c=[c, 2*c, 4*c, 4*c]

        classifier_pool = Cat(OrderedDict([('maxpool', nn.MaxPool2d(4)),
                                           (('avgpool', (nn.AvgPool2d(4), ['in'])) if concat_pool else ('pool', nn.MaxPool2d(4)))
                                           ]))

        super().__init__(OrderedDict([
            ('input', (None, [])),
            ('prep', nn.Sequential(OrderedDict([
                ('conv', Conv(3, c[0], bias=False)),
                ('bn', BatchNorm(c[0], **kw)),
                ('relu', nn.ReLU(True) if prep_bn_relu else None)
            ]))),
            ('layer1', nn.Sequential(OrderedDict([
                ('block0', Block(c[0], c[0], **kw)),
                ('block1', Block(c[0], c[0], **kw))
            ]))),
            ('layer2', nn.Sequential(OrderedDict([
                ('block0', Block(c[0], c[1], stride=2, **kw)),
                ('block1', Block(c[1], c[1], **kw))
            ]))),
            ('layer3', nn.Sequential(OrderedDict([
                ('block0', Block(c[1], c[2], stride=2, **kw)),
                ('block1', Block(c[2], c[2], **kw))
            ]))),
            ('layer4', nn.Sequential(OrderedDict([
                ('block0', Block(c[2], c[3], stride=2, **kw)),
                ('block1', Block(c[3], c[3], **kw))
            ]))),
            ('final', nn.Sequential(OrderedDict([
                ('pool', classifier_pool),
                ('flatten', Flatten()),
                ('linear', nn.Linear((2*c[3] if concat_pool else c[3]), 10, bias=True))
            ]))),
            ('logits', Id()),
            ]))


In [8]:
net = DawnNet()
print(net)

ValueError: padding='same' is not supported for strided convolutions

In [None]:
# can hack any changes to each residual group that you want directly in here
class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out, residual, short, pool, se):
        super().__init__()
        self.short = short
        self.pool = pool
        self.se = se

        self.residual = residual
        self.channels_in = channels_in
        self.channels_out = channels_out

        self.conv1 = Conv(channels_in, channels_out)
        self.pool1 = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.activ = nn.GELU()          

        if not short:
            self.conv2 = Conv(channels_out, channels_out)
            self.conv3 = Conv(channels_out, channels_out)
            self.norm2 = BatchNorm(channels_out)
            self.norm3 = BatchNorm(channels_out)

            self.se1 = nn.Linear(channels_out, channels_out//16)
            self.se2 = nn.Linear(channels_out//16, channels_out)

    def forward(self, x):
        x = self.conv1(x)
        if self.pool:
            x = self.pool1(x)
        x = self.norm1(x)
        x = self.activ(x)
        if self.short: # layer 2 doesn't necessarily need the residual, so we just return it.
            return x
        residual = x
        if self.se:
            mult = torch.sigmoid(self.se2(self.activ(self.se1(torch.mean(residual, dim=(2,3)))))).unsqueeze(-1).unsqueeze(-1)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
        x = self.conv3(x)
        
        if self.se:
            x = x * mult

        x = self.norm3(x)
        x = self.activ(x)
        x = x + residual # haiku

        return x

# Set to 1 for now just to debug a few things....
class TemperatureScaler(nn.Module):
    def __init__(self, init_val):
        super().__init__()
        self.scaler = torch.tensor(init_val)

    def forward(self, x):
        x.float() ## save precision for the gradients in the backwards pass
                  ## I personally believe from experience that this is important
                  ## for a few reasons. I believe this is the main functional difference between
                  ## my implementation, and David's implementation...
        return x.mul(self.scaler)

class FastGlobalMaxPooling(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        # Previously was chained torch.max calls.
        # requires less time than AdaptiveMax2dPooling -- about ~.3s for the entire run, in fact (which is pretty significant! :O :D :O :O <3 <3 <3 <3)
        return torch.amax(x, dim=(2,3)) # Global maximum pooling

In [None]:
def make_net():
    # TODO: A way to make this cleaner??
    # Note, you have to specify any arguments overlapping with defaults (i.e. everything but in/out depths) as kwargs so that they are properly overridden (TODO cleanup somehow?)
    whiten_conv_depth = 3*HYP['net']['whitening']['kernel_size']**2
    network_dict = nn.ModuleDict({
        'initial_block': nn.ModuleDict({
            'whiten': Conv(3, whiten_conv_depth, kernel_size=HYP['net']['whitening']['kernel_size'], padding=0),
            'project': Conv(whiten_conv_depth, DEPTHS['init'], kernel_size=1),
            'norm': BatchNorm(DEPTHS['init'], weight=False),
            'activation': nn.GELU(),
        }),
        'residual1': ConvGroup(DEPTHS['init'], DEPTHS['block1'], residual=True, short=False, pool=True, se=True),
        'residual2': ConvGroup(DEPTHS['block1'], DEPTHS['block2'], residual=True, short=True, pool=True, se=True),
        'residual3': ConvGroup(DEPTHS['block2'], DEPTHS['block3'], residual=True, short=False, pool=True, se=True),
        'pooling': FastGlobalMaxPooling(),
        'linear': nn.Linear(DEPTHS['block3'], DEPTHS['num_classes'], bias=False),
        'temperature': TemperatureScaler(HYP['opt']['scaling_factor'])
    })

    net = SpeedyResNet(network_dict)
    net = net.to(HYP['misc']['device'])
    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
    net.train()
    net.half() # Convert network to half before initializing the initial whitening layer.

    ## Initialize the whitening convolution
    with torch.no_grad():
        # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit HYPersphere. (i.e. their...average vector length is 1.?, IIRC)
        init_whitening_conv(net.net_dict['initial_block']['whiten'],
                            data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
                            num_examples=HYP['net']['whitening']['num_examples'],
                            pad_amount=HYP['net']['pad_amount'],
                            whiten_splits=5000) ## Hardcoded for now while we figure out the optimal whitening number
                                                ## If you're running out of memory (OOM) feel free to decrease this, but
                                                ## the index lookup in the dataloader may give you some trouble depending
                                                ## upon exactly how memory-limited you are

    return net

In [None]:
class SpeedyResNet(nn.Module):
    def __init__(self, network_dict):
        super().__init__()
        self.net_dict = network_dict # flexible, defined in the make_net function

    # This allows you to customize/change the execution order of the network as needed.
    def forward(self, x):
        if not self.training:
            x = torch.cat((x, torch.flip(x, (-1,))))
        x = self.net_dict['initial_block']['whiten'](x)
        x = self.net_dict['initial_block']['project'](x)
        x = self.net_dict['initial_block']['norm'](x)
        x = self.net_dict['initial_block']['activation'](x)
        x = self.net_dict['residual1'](x)
        x = self.net_dict['residual2'](x)
        x = self.net_dict['residual3'](x)
        x = self.net_dict['pooling'](x)
        x = self.net_dict['linear'](x)
        x = self.net_dict['temperature'](x)
        if not self.training:
            # Average the predictions from the lr-flipped inputs during eval
            orig, flipped = x.split(x.shape[0]//2, dim=0)
            x = .5 * orig + .5 * flipped
        return x

In [None]:
scaler = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
depths = {
    'init':   round(scaler**-1*hyp['net']['base_depth']), # 64  w/ scaler at base value
    'block1': round(scaler**1*hyp['net']['base_depth']), # 128 w/ scaler at base value
    'block2': round(scaler**2*hyp['net']['base_depth']), # 256 w/ scaler at base value
    'block3': round(scaler**3*hyp['net']['base_depth']), # 512 w/ scaler at base value
    'num_classes': 10
}



def make_net():
    # TODO: A way to make this cleaner??
    # Note, you have to specify any arguments overlapping with defaults (i.e. everything but in/out depths) as kwargs so that they are properly overridden (TODO cleanup somehow?)
    whiten_conv_depth = 3*hyp['net']['whitening']['kernel_size']**2
    network_dict = nn.ModuleDict({
        'initial_block': nn.ModuleDict({
            'whiten': Conv(3, whiten_conv_depth, kernel_size=hyp['net']['whitening']['kernel_size'], padding=0),
            'project': Conv(whiten_conv_depth, depths['init'], kernel_size=1),
            'norm': BatchNorm(depths['init'], weight=False),
            'activation': nn.GELU(),
        }),
        'residual1': ConvGroup(depths['init'], depths['block1'], residual=True, short=False, pool=True, se=True),
        'residual2': ConvGroup(depths['block1'], depths['block2'], residual=True, short=True, pool=True, se=True),
        'residual3': ConvGroup(depths['block2'], depths['block3'], residual=True, short=False, pool=True, se=True),
        'pooling': FastGlobalMaxPooling(),
        'linear': nn.Linear(depths['block3'], depths['num_classes'], bias=False),
        'temperature': TemperatureScaler(hyp['opt']['scaling_factor'])
    })

    net = SpeedyResNet(network_dict)
    net = net.to(hyp['misc']['device'])
    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
    net.train()
    net.half() # Convert network to half before initializing the initial whitening layer.

    ## Initialize the whitening convolution
    with torch.no_grad():
        # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit hypersphere. (i.e. their...average vector length is 1.?, IIRC)
        init_whitening_conv(net.net_dict['initial_block']['whiten'],
                            data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
                            num_examples=hyp['net']['whitening']['num_examples'],
                            pad_amount=hyp['net']['pad_amount'],
                            whiten_splits=5000) ## Hardcoded for now while we figure out the optimal whitening number
                                                ## If you're running out of memory (OOM) feel free to decrease this, but
                                                ## the index lookup in the dataloader may give you some trouble depending
                                                ## upon exactly how memory-limited you are

    return net