### Imports

In [1]:
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as t
from torch import Tensor
from torch.utils.data import DataLoader

### Download & Preprocess

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DATA_LOCATION = "data.pt"

# if not os.path.exists(HYP['misc']['data_location']):

CIFAR10_MEAN, CIFAR10_STD = [
    torch.tensor(
        [0.4913997551666284, 0.48215855929893703, 0.4465309133731618], device=DEVICE
    ),
    torch.tensor(
        [0.24703225141799082, 0.24348516474564, 0.26158783926049628], device=DEVICE
    ),
]

PREP_TRANSFORM = t.Compose(
    [
        TO_TENSOR := t.ToTensor(),
        NORMALIZE := t.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

TRAIN_TRANSFORM = t.Compose([PAD := t.Pad(BORDER := 4), PREP_TRANSFORM])

cifar10 = torchvision.datasets.CIFAR10(
    "cifar10/", download=True, train=True, transform=TRAIN_TRANSFORM
)
cifar10_eval = torchvision.datasets.CIFAR10(
    "cifar10/", download=False, train=False, transform=PREP_TRANSFORM
)


data = {"train": cifar10, "eval": cifar10_eval}
# torch.save(data, HYP['misc']['data_location'])

# else:
#     # This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :)
#     # So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above
#     # HYP dictionary, then we should be good. :)
#     data = torch.load(HYP['misc']['data_location'])

Files already downloaded and verified


### Training loop

In [3]:
########################################
#          Training Helpers            #
########################################


def sgd_optimizer(trainable_parameters, weight_decay: float):
    return torch.optim.SGD(
        trainable_parameters,
        weight_decay,
        **(SGD_DEFAULT_KWARGS := {"nesterov": True, "momentum": 0.9}),
    )


def piecewise_linear_scheduler(
    optimizer: torch.optim.Optimizer, epochs: list[int], learning_rates: list[float]
):
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[
            torch.optim.lr_scheduler.LinearLR(
                optimizer,
                start_factor=learning_rates[i],
                end_factor=learning_rates[i + 1],
            )
            for i in range(len(learning_rates) - 1)
        ],
        milestones=epochs[1:-1],
    )


########################################
#           Train and Eval             #
########################################


def train_test(
    model: nn.Module,
    criterion,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_set: torchvision.datasets.CIFAR10,
    test_set: torchvision.datasets.CIFAR10,
    num_epochs: int,
    batch_size: int,
    num_workers: int = 0,
):
    train_loader = DataLoader(
        train_set, batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = DataLoader(
        test_set, batch_size, shuffle=False, num_workers=num_workers
    )

    for epoch in range(num_epochs):
        print("Epoch {}/{}".format(epoch, num_epochs))

        train_epoch(model, criterion, optimizer, train_loader)
        test_epoch(model, criterion, test_loader)
        scheduler.step()


def train_epoch(
    model: nn.Module,
    criterion: nn.modules.loss._Loss,
    optimizer: torch.optim.Optimizer,
    train_loader: DataLoader,
):
    model.train()

    TRANSFORM = t.Compose(
        [CROP := t.RandomCrop(IMAGE_SIZE := (32, 32)), FLIP := t.RandomVerticalFlip()]
    )

    train_correct = 0
    train_loss = 0

    size = 0
    for i, (batch, targets) in enumerate(train_loader):
        bs = targets.size(0)

        batch = TRANSFORM(batch)
        batch.to(DEVICE)

        output = model(batch)
        loss = criterion(output, targets)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        pred = output.max(1, keepdim=True)[1]
        train_correct += pred.eq(targets.view_as(pred)).sum().item()
        train_loss += loss

        size += bs

        if i % 100 == 10:
            print(
                "{:.2f}% Train - Loss: {:.4f} ; Acc: {:.2f}%".format(
                    100 * i / len(train_loader),
                    train_loss / size,
                    100 * train_correct / size,
                )
            )


def test_epoch(
    model: nn.Module, criterion: nn.modules.loss._Loss, test_loader: DataLoader
):
    model.eval()

    test_correct = 0
    test_loss = 0

    size = 0
    for batch, targets in test_loader:
        bs = targets.size(0)

        output = model(batch)

        loss = criterion(output, targets)
        test_loss += loss

        pred = output.max(1, keepdim=True)[1]
        test_correct += pred.eq(targets.view_as(pred)).sum().item()

        size += bs

    print(
        "Train - Loss: {:.4f} ; Acc: {:.2f}%".format(
            test_loss / size, 100 * test_correct / size
        )
    )

### Network Definition

In [4]:
#############################################
#            Network Components             #
#############################################


class Cat(nn.Module):
    def __init__(self, modules: OrderedDict[str, nn.Module]) -> None:
        super().__init__()

        for name, module in modules.items():
            self.add_module(name, module)

    def forward(self, x: Tensor):
        return torch.cat([module(x) for module in self.children()])


class Add(nn.Module):
    def __init__(self, modules: OrderedDict[str, nn.Module]) -> None:
        super().__init__()

        for name, module in modules.items():
            self.add_module(name, module)

    def forward(self, x: Tensor):
        return sum([module(x) for module in self.children()])


class Id(nn.Module):
    def forward(self, x: Tensor):
        return x


class Flatten(nn.Module):
    def forward(self, x: Tensor):
        return x.view(x.size(0), x.size(1))


class BatchNorm(nn.BatchNorm2d):
    def __init__(
        self,
        num_features,
        weight_requires_grad=True,
        bias_requires_grad=True,
        weights_init=False,
        *args,
        **kwargs,
    ):
        super().__init__(num_features, *args, **kwargs)

        if weights_init:
            self.weight.data.fill_(1.0)
            self.bias.data.fill_(0.0)

        self.weight.requires_grad = weight_requires_grad
        self.bias.requires_grad = bias_requires_grad


class Conv(nn.Conv2d):
    def __init__(self, in_channels, out_channels, padding=None, *args, **kwargs):
        kwargs = {
            **kwargs,
            **(
                DEFAULT_CONV_KWARGS := {
                    "kernel_size": 3,
                    "padding": "same",
                    "bias": False,
                }
            ),
        }

        if padding is not None:
            kwargs["padding"] = padding

        super().__init__(in_channels, out_channels, *args, **kwargs)


class ResBlock(nn.Sequential):
    def __init__(self, c_in: int, c_out: int, stride: int = 1) -> None:
        bn1 = BatchNorm(c_in)
        relu1 = nn.ReLU(inplace=True)

        branch = nn.Sequential(
            OrderedDict(
                [
                    (
                        "conv1",
                        Conv(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                    ),
                    ("bn2", BatchNorm(c_out)),
                    ("relu2", nn.ReLU(inplace=True)),
                    ("conv2", Conv(c_out, c_out, bias=False)),
                ]
            )
        )

        is_projection_needed = (stride != 1) or (c_in != c_out)

        super().__init__(
            OrderedDict(
                [
                    ("bn1", bn1),
                    ("relu1", relu1),
                    (
                        "res",
                        Add(
                            OrderedDict(
                                [
                                    (
                                        (
                                            "conv3",
                                            Conv(
                                                c_in,
                                                c_out,
                                                kernel_size=1,
                                                stride=stride,
                                                padding=1,
                                                bias=False,
                                            ),
                                        )
                                        if is_projection_needed
                                        else ("id", Id())
                                    ),
                                    ("branch", branch),
                                ]
                            )
                        ),
                    ),
                ]
            )
        )


#############################################
#            Network Architercture          #
#############################################


class DawnNet(nn.Sequential):
    def __init__(
        self, c=64, Block=ResBlock, prep_bn_relu=False, concat_pool=False, **kw
    ) -> None:
        if isinstance(c, int):
            c = [c, 2 * c, 4 * c, 4 * c]

        prep = nn.Sequential(
            OrderedDict([("conv", Conv(in_channels=3, out_channels=c[0], bias=False))])
        )

        if prep_bn_relu:
            prep.add_module("bn", BatchNorm(c[0], **kw))
            prep.add_module("relu", nn.ReLU(True))

        classifier_pool = (
            Cat(
                OrderedDict(
                    [("maxpool", nn.MaxPool2d(4)), ("avgpool", nn.AvgPool2d(4))]
                )
            )
            if concat_pool
            else nn.MaxPool2d(4)
        )

        super().__init__(
            OrderedDict(
                [
                    ("prep", prep),
                    (
                        "layer1",
                        nn.Sequential(
                            OrderedDict(
                                [
                                    ("block0", Block(c[0], c[0], **kw)),
                                    ("block1", Block(c[0], c[0], **kw)),
                                ]
                            )
                        ),
                    ),
                    (
                        "layer2",
                        nn.Sequential(
                            OrderedDict(
                                [
                                    ("block0", Block(c[0], c[1], stride=2, **kw)),
                                    ("block1", Block(c[1], c[1], **kw)),
                                ]
                            )
                        ),
                    ),
                    (
                        "layer3",
                        nn.Sequential(
                            OrderedDict(
                                [
                                    ("block0", Block(c[1], c[2], stride=2, **kw)),
                                    ("block1", Block(c[2], c[2], **kw)),
                                ]
                            )
                        ),
                    ),
                    (
                        "layer4",
                        nn.Sequential(
                            OrderedDict(
                                [
                                    ("block0", Block(c[2], c[3], stride=2, **kw)),
                                    ("block1", Block(c[3], c[3], **kw)),
                                ]
                            )
                        ),
                    ),
                    (
                        "final",
                        nn.Sequential(
                            OrderedDict(
                                [
                                    ("pool", classifier_pool),
                                    ("flatten", Flatten()),
                                    (
                                        "linear",
                                        nn.Linear(
                                            (2 * c[3] if concat_pool else c[3]),
                                            10,
                                            bias=True,
                                        ),
                                    ),
                                ]
                            )
                        ),
                    ),
                    ("logits", Id()),
                ]
            )
        )

### [Post 1: Baseline](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_1/) - DAWNbench baseline + no initial bn-relu+ efficient dataloading/augmentation, 1 dataloader process (301s)

In [5]:
net = DawnNet()

criterion = nn.CrossEntropyLoss()
optimizer = sgd_optimizer(TRAINABLE_PARAMETERS:=net.parameters(), weight_decay=5e-4*(BATCHSIZE := 128))
scheduler = piecewise_linear_scheduler(optimizer, 
                                       EPOCHS:=[0, 15, 30, NUM_EPOCHS:=35], 
                                       LR:=[0, 0.1, 0.005, 0])

# print(net.final)
train_test(net, criterion, optimizer, scheduler, data["train"], data["eval"], NUM_EPOCHS, BATCHSIZE, NUM_WORKERS:=1)


Epoch 0
10%
Train - Loss: 0.0211 ; Acc: 9.87%


KeyboardInterrupt: 

### Project-hlb-CIFAR10

#### Parameters

In [None]:
BATCHSIZE = 512
BIAS_SCALER = 32

# To replicate the ~95.77% accuracy in 188 seconds runs, simply change the base_depth from 64->128 and the num_epochs from 10->80
HYP = {
    'opt': {
        'bias_lr':        1.15 * 1.35 * 1. * BIAS_SCALER/BATCHSIZE, # TODO: How we're expressing this information feels somewhat clunky, is there maybe a better way to do this? :'))))
        'non_bias_lr':    1.15 * 1.35 * 1. / BATCHSIZE,
        'bias_decay':     .85 * 4.8e-4 * BATCHSIZE/BIAS_SCALER,
        'non_bias_decay': .85 * 4.8e-4 * BATCHSIZE,
        'scaling_factor': 1./10,
        'percent_start': .2,
    },
    'net': {
        'whitening': {
            'kernel_size': 2,
            'num_examples': 50000,
        },
        'batch_norm_momentum': .8,
        'cutout_size': 0,
        'pad_amount': 3,
        'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
    },
    'misc': {
        'ema': {
            'epochs': 2,
            'decay_base': .986,
            'every_n_steps': 2,
        },
        'train_epochs': 10,
        'device': DEVICE,
        'data_location': DATA_LOCATION,
    }
}

SCALER = 2. ## You can play with this on your own if you want, for the first beta I wanted to keep things simple (for now) and leave it out of the hyperparams dict
DEPTHS = {
    'init':   round(SCALER**-1*HYP['net']['base_depth']), # 64  w/ scaler at base value
    'block1': round(SCALER**1*HYP['net']['base_depth']), # 128 w/ scaler at base value
    'block2': round(SCALER**2*HYP['net']['base_depth']), # 256 w/ scaler at base value
    'block3': round(SCALER**3*HYP['net']['base_depth']), # 512 w/ scaler at base value
    'num_classes': 10
}

#### Init Helper Functions

In [None]:
def get_patches(x, patch_shape=(3, 3), dtype=torch.float32):
    # TODO: Annotate
    c, (h, w) = x.shape[1], patch_shape
    return x.unfold(2,h,1).unfold(3,w,1).transpose(1,3).reshape(-1,c,h,w).to(dtype) # TODO: Annotate?

def get_whitening_parameters(patches):
    # TODO: Let's annotate this, please! :'D / D':
    n,c,h,w = patches.shape
    est_covariance = torch.cov(patches.view(n, c*h*w).t())
    eigenvalues, eigenvectors = torch.linalg.eigh(est_covariance, UPLO='U') # this is the same as saying we want our eigenvectors, with the specification that the matrix be an upper triangular matrix (instead of a lower-triangular matrix)
    return eigenvalues.flip(0).view(-1, 1, 1, 1), eigenvectors.t().reshape(c*h*w,c,h,w).flip(0)

# Run this over the training set to calculate the patch statistics, then set the initial convolution as a non-learnable 'whitening' layer
def init_whitening_conv(layer, train_set=None, num_examples=None, previous_block_data=None, pad_amount=0, freeze=True, whiten_splits=None):
    if train_set is not None and previous_block_data is None:
        if pad_amount > 0:
            previous_block_data = train_set[:num_examples,:,pad_amount:-pad_amount,pad_amount:-pad_amount] # if it's none, we're at the beginning of our network.
        else:
            previous_block_data = train_set[:num_examples,:,:,:]
    if whiten_splits is None:
         previous_block_data_split = [previous_block_data] # list of length 1 so we can reuse the splitting code down below
    else:
         previous_block_data_split = previous_block_data.split(whiten_splits, dim=0)

    eigenvalue_list, eigenvector_list = [], []
    for data_split in previous_block_data_split:
        eigenvalues, eigenvectors = get_whitening_parameters(get_patches(data_split, patch_shape=layer.weight.data.shape[2:])) # center crop to remove padding
        eigenvalue_list.append(eigenvalues)
        eigenvector_list.append(eigenvectors)

    eigenvalues = torch.stack(eigenvalue_list, dim=0).mean(0)
    eigenvectors = torch.stack(eigenvector_list, dim=0).mean(0)
    # for some reason, the eigenvalues and eigenvectors seem to come out all in float32 for this? ! ?! ?!?!?!? :'(((( </3
    set_whitening_conv(layer, eigenvalues.to(dtype=layer.weight.dtype), eigenvectors.to(dtype=layer.weight.dtype), freeze=freeze)
    data = layer(previous_block_data.to(dtype=layer.weight.dtype))
    return data

def set_whitening_conv(conv_layer, eigenvalues, eigenvectors, eps=1e-2, freeze=True):
    shape = conv_layer.weight.data.shape
    conv_layer.weight.data[-eigenvectors.shape[0]:, :, :, :] = (eigenvectors/torch.sqrt(eigenvalues+eps))[-shape[0]:, :, :, :]
    ## We don't want to train this, since this is implicitly whitening over the whole dataset
    ## For more info, see David Page's original blogposts (link in the README.md as of this commit.)
    if freeze: 
        conv_layer.weight.requires_grad = False

#### Network Definition

In [None]:
#############################################
#         (EXTRA) Network Components        #
#############################################

# can hack any changes to each residual group that you want directly in here
class ConvGroup(nn.Module):
    def __init__(self, channels_in, channels_out, residual, short, pool, se):
        super().__init__()
        self.short = short
        self.pool = pool
        self.se = se

        self.residual = residual
        self.channels_in = channels_in
        self.channels_out = channels_out

        self.conv1 = Conv(channels_in, channels_out)
        self.pool1 = nn.MaxPool2d(2)
        self.norm1 = BatchNorm(channels_out)
        self.activ = nn.GELU()          

        if not short:
            self.conv2 = Conv(channels_out, channels_out)
            self.conv3 = Conv(channels_out, channels_out)
            self.norm2 = BatchNorm(channels_out)
            self.norm3 = BatchNorm(channels_out)

            self.se1 = nn.Linear(channels_out, channels_out//16)
            self.se2 = nn.Linear(channels_out//16, channels_out)

    def forward(self, x):
        x = self.conv1(x)
        if self.pool:
            x = self.pool1(x)
        x = self.norm1(x)
        x = self.activ(x)
        if self.short: # layer 2 doesn't necessarily need the residual, so we just return it.
            return x
        residual = x
        if self.se:
            mult = torch.sigmoid(self.se2(self.activ(self.se1(torch.mean(residual, dim=(2,3)))))).unsqueeze(-1).unsqueeze(-1)

        x = self.conv2(x)
        x = self.norm2(x)
        x = self.activ(x)
        x = self.conv3(x)
        
        if self.se:
            x = x * mult

        x = self.norm3(x)
        x = self.activ(x)
        x = x + residual # haiku

        return x

# Set to 1 for now just to debug a few things....
class TemperatureScaler(nn.Module):
    def __init__(self, init_val):
        super().__init__()
        self.scaler = torch.tensor(init_val)

    def forward(self, x):
        x.float() ## save precision for the gradients in the backwards pass
                  ## I personally believe from experience that this is important
                  ## for a few reasons. I believe this is the main functional difference between
                  ## my implementation, and David's implementation...
        return x.mul(self.scaler)

class FastGlobalMaxPooling(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x):
        # Previously was chained torch.max calls.
        # requires less time than AdaptiveMax2dPooling -- about ~.3s for the entire run, in fact (which is pretty significant! :O :D :O :O <3 <3 <3 <3)
        return torch.amax(x, dim=(2,3)) # Global maximum pooling
    
#############################################
#            Network Definition             #
#############################################
    
class SpeedyResNet(nn.Module):
    def __init__(self, network_dict):
        super().__init__()
        self.net_dict = network_dict # flexible, defined in the make_net function

    # This allows you to customize/change the execution order of the network as needed.
    def forward(self, x):
        if not self.training:
            x = torch.cat((x, torch.flip(x, (-1,))))
        x = self.net_dict['initial_block']['whiten'](x)
        x = self.net_dict['initial_block']['project'](x)
        x = self.net_dict['initial_block']['norm'](x)
        x = self.net_dict['initial_block']['activation'](x)
        x = self.net_dict['residual1'](x)
        x = self.net_dict['residual2'](x)
        x = self.net_dict['residual3'](x)
        x = self.net_dict['pooling'](x)
        x = self.net_dict['linear'](x)
        x = self.net_dict['temperature'](x)
        if not self.training:
            # Average the predictions from the lr-flipped inputs during eval
            orig, flipped = x.split(x.shape[0]//2, dim=0)
            x = .5 * orig + .5 * flipped
        return x
    
def make_net():
    # TODO: A way to make this cleaner??
    # Note, you have to specify any arguments overlapping with defaults (i.e. everything but in/out depths) as kwargs so that they are properly overridden (TODO cleanup somehow?)
    whiten_conv_depth = 3*HYP['net']['whitening']['kernel_size']**2
    network_dict = nn.ModuleDict({
        'initial_block': nn.ModuleDict({
            'whiten': Conv(3, whiten_conv_depth, kernel_size=HYP['net']['whitening']['kernel_size'], padding=0),
            'project': Conv(whiten_conv_depth, DEPTHS['init'], kernel_size=1),
            'norm': BatchNorm(DEPTHS['init'], weight=False),
            'activation': nn.GELU(),
        }),
        'residual1': ConvGroup(DEPTHS['init'], DEPTHS['block1'], residual=True, short=False, pool=True, se=True),
        'residual2': ConvGroup(DEPTHS['block1'], DEPTHS['block2'], residual=True, short=True, pool=True, se=True),
        'residual3': ConvGroup(DEPTHS['block2'], DEPTHS['block3'], residual=True, short=False, pool=True, se=True),
        'pooling': FastGlobalMaxPooling(),
        'linear': nn.Linear(DEPTHS['block3'], DEPTHS['num_classes'], bias=False),
        'temperature': TemperatureScaler(HYP['opt']['scaling_factor'])
    })

    net = SpeedyResNet(network_dict)
    net = net.to(HYP['misc']['device'])
    net = net.to(memory_format=torch.channels_last) # to appropriately use tensor cores/avoid thrash while training
    net.train()
    net.half() # Convert network to half before initializing the initial whitening layer.

    ## Initialize the whitening convolution
    with torch.no_grad():
        # Initialize the first layer to be fixed weights that whiten the expected input values of the network be on the unit HYPersphere. (i.e. their...average vector length is 1.?, IIRC)
        init_whitening_conv(net.net_dict['initial_block']['whiten'],
                            data['train']['images'].index_select(0, torch.randperm(data['train']['images'].shape[0], device=data['train']['images'].device)),
                            num_examples=HYP['net']['whitening']['num_examples'],
                            pad_amount=HYP['net']['pad_amount'],
                            whiten_splits=5000) ## Hardcoded for now while we figure out the optimal whitening number
                                                ## If you're running out of memory (OOM) feel free to decrease this, but
                                                ## the index lookup in the dataloader may give you some trouble depending
                                                ## upon exactly how memory-limited you are

    return net

#### Data Preprocessing (EXTRA)

In [None]:
## This is actually (I believe) a pretty clean implementation of how to do something like this, since shifted-square masks unique to each depth-channel can actually be rather
## tricky in practice. That said, if there's a better way, please do feel free to submit it! This can be one of the harder parts of the code to understand (though I personally get
## stuck on the fold/unfold process for the lower-level convolution calculations.
def make_random_square_masks(inputs, mask_size):
    ##### TODO: Double check that this properly covers the whole range of values. :'( :')
    if mask_size == 0:
        return None # no need to cutout or do anything like that since the patch_size is set to 0
    is_even = int(mask_size % 2 == 0)
    in_shape = inputs.shape

    # seed centers of squares to cutout boxes from, in one dimension each
    mask_center_y = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(mask_size//2-is_even, in_shape[-2]-mask_size//2-is_even)
    mask_center_x = torch.empty(in_shape[0], dtype=torch.long, device=inputs.device).random_(mask_size//2-is_even, in_shape[-1]-mask_size//2-is_even)

    # measure distance, using the center as a reference point
    to_mask_y_dists = torch.arange(in_shape[-2], device=inputs.device).view(1, 1, in_shape[-2], 1) - mask_center_y.view(-1, 1, 1, 1)
    to_mask_x_dists = torch.arange(in_shape[-1], device=inputs.device).view(1, 1, 1, in_shape[-1]) - mask_center_x.view(-1, 1, 1, 1)

    to_mask_y = (to_mask_y_dists >= (-(mask_size // 2) + is_even)) * (to_mask_y_dists <= mask_size // 2)
    to_mask_x = (to_mask_x_dists >= (-(mask_size // 2) + is_even)) * (to_mask_x_dists <= mask_size // 2)

    final_mask = to_mask_y * to_mask_x ## Turn (y by 1) and (x by 1) boolean masks into (y by x) masks through multiplication. Their intersection is square, hurray! :D

    return final_mask

def batch_cutout(inputs, patch_size):
    with torch.no_grad():
        cutout_batch_mask = make_random_square_masks(inputs, patch_size)
        if cutout_batch_mask is None:
            return inputs # if the mask is None, then that's because the patch size was set to 0 and we will not be using cutout today.
        # TODO: Could be fused with the crop operation for sheer speeeeeds. :D <3 :))))
        cutout_batch = torch.where(cutout_batch_mask, torch.zeros_like(inputs), inputs)
        return cutout_batch
    
def batch_crop(inputs, crop_size):
    with torch.no_grad():
        crop_mask_batch = make_random_square_masks(inputs, crop_size)
        cropped_batch = torch.masked_select(inputs, crop_mask_batch).view(inputs.shape[0], inputs.shape[1], crop_size, crop_size)
        return cropped_batch

def batch_flip_lr(batch_images, flip_chance=.5):
    with torch.no_grad():
        # TODO: Is there a more elegant way to do this? :') :'((((
        return torch.where(torch.rand_like(batch_images[:, 0, 0, 0].view(-1, 1, 1, 1)) < flip_chance, torch.flip(batch_images, (-1,)), batch_images)


#### Training Helpers 

In [None]:
import copy


class NetworkEMA(nn.Module):
    def __init__(self, net, decay):
        super().__init__() # init the parent module so this module is registered properly
        self.net_ema = copy.deepcopy(net).eval().requires_grad_(False) # copy the model
        self.decay = decay ## you can update/hack this as necessary for update scheduling purposes :3

    def update(self, current_net):
        with torch.no_grad():
            for ema_net_parameter, incoming_net_parameter in zip(self.net_ema.state_dict().values(), current_net.state_dict().values()): # potential bug: assumes that the network architectures don't change during training (!!!!)
                if incoming_net_parameter.dtype in (torch.half, torch.float):
                    ema_net_parameter.mul_(self.decay).add_(incoming_net_parameter.detach().mul(1. - self.decay)) # update the ema values in place, similar to how optimizer momentum is coded

    def forward(self, inputs):
        with torch.no_grad():
            return self.net_ema(inputs)

# TODO: Could we jit this in the (more distant) future? :)
@torch.no_grad()
def get_batches(data_dict, key, batchsize):
    num_epoch_examples = len(data_dict[key]['images'])
    shuffled = torch.randperm(num_epoch_examples, device='cuda')
    crop_size = 32
    ## Here, we prep the dataset by applying all data augmentations in batches ahead of time before each epoch, then we return an iterator below
    ## that iterates in chunks over with a random derangement (i.e. shuffled indices) of the individual examples. So we get perfectly-shuffled
    ## batches (which skip the last batch if it's not a full batch), but everything seems to be (and hopefully is! :D) properly shuffled. :)
    if key == 'train':
        images = batch_crop(data_dict[key]['images'], crop_size) # TODO: hardcoded image size for now?
        images = batch_flip_lr(images)
        images = batch_cutout(images, patch_size=HYP['net']['cutout_size'])
    else:
        images = data_dict[key]['images']

    # Send the images to an (in beta) channels_last to help improve tensor core occupancy (and reduce NCHW <-> NHWC thrash) during training
    images = images.to(memory_format=torch.channels_last)
    for idx in range(num_epoch_examples // batchsize):
        if not (idx+1)*batchsize > num_epoch_examples: ## Use the shuffled randperm to assemble individual items into a minibatch
            yield images.index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]), \
                  data_dict[key]['targets'].index_select(0, shuffled[idx*batchsize:(idx+1)*batchsize]) ## Each item is only used/accessed by the network once per epoch. :D


def init_split_parameter_dictionaries(network):
    params_non_bias = {'params': [], 'lr': HYP['opt']['non_bias_lr'], 'momentum': .85, 'nesterov': True, 'weight_decay': HYP['opt']['non_bias_decay']}
    params_bias     = {'params': [], 'lr': HYP['opt']['bias_lr'],     'momentum': .85, 'nesterov': True, 'weight_decay': HYP['opt']['bias_decay']}

    for name, p in network.named_parameters():
        if p.requires_grad:
            if 'bias' in name:
                params_bias['params'].append(p)
            else:
                params_non_bias['params'].append(p)
    return params_non_bias, params_bias


## Hey look, it's the soft-targets/label-smoothed loss! Native to PyTorch. Now, _that_ is pretty cool, and simplifies things a lot, to boot! :D :)
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.2, reduction='none')

logging_columns_list = ['epoch', 'train_loss', 'val_loss', 'train_acc', 'val_acc', 'ema_val_acc', 'total_time_seconds']
# define the printing function and print the column heads
def print_training_details(columns_list, separator_left='|  ', separator_right='  ', final="|", column_heads_only=False, is_final_entry=False):
    print_string = ""
    if column_heads_only:
        for column_head_name in columns_list:
            print_string += separator_left + column_head_name + separator_right
        print_string += final
        print('-'*(len(print_string))) # print the top bar
        print(print_string)
        print('-'*(len(print_string))) # print the bottom bar
    else:
        for column_value in columns_list:
            print_string += separator_left + column_value + separator_right
        print_string += final
        print(print_string)
    if is_final_entry:
        print('-'*(len(print_string))) # print the final output bar

print_training_details(logging_columns_list, column_heads_only=True) ## print out the training column heads before we print the actual content for each run.


#### Train & Eval

In [None]:
def main():
    # Initializing constants for the whole run.
    net_ema = None ## Reset any existing network emas, we want to have _something_ to check for existence so we can initialize the EMA right from where the network is during training
                   ## (as opposed to initializing the network_ema from the randomly-initialized starter network, then forcing it to play catch-up all of a sudden in the last several epochs)

    total_time_seconds = 0.
    current_steps = 0.
    
    # TODO: Doesn't currently account for partial epochs really (since we're not doing "real" epochs across the whole BATCHSIZE)....
    num_steps_per_epoch      = len(data['train']['images']) // BATCHSIZE
    total_train_steps        = num_steps_per_epoch * HYP['misc']['train_epochs']
    ema_epoch_start          = HYP['misc']['train_epochs'] - HYP['misc']['ema']['epochs']
    num_cooldown_before_freeze_steps = 0
    num_low_lr_steps_for_ema = HYP['misc']['ema']['epochs'] * num_steps_per_epoch

    ## I believe this wasn't logged, but the EMA update power is adjusted by being raised to the power of the number of "every n" steps
    ## to somewhat accomodate for whatever the expected information intake rate is. The tradeoff I believe, though, is that this is to some degree noisier as we
    ## are intaking fewer samples of our distribution-over-time, with a higher individual weight each. This can be good or bad depending upon what we want.
    projected_ema_decay_val  = HYP['misc']['ema']['decay_base'] ** HYP['misc']['ema']['every_n_steps']

    # Adjust pct_start based upon how many epochs we need to finetune the ema at a low lr for
    pct_start = HYP['opt']['percent_start'] * (total_train_steps/(total_train_steps - num_low_lr_steps_for_ema))

    # Get network
    net = make_net()

    ## Stowing the creation of these into a helper function to make things a bit more readable....
    non_bias_params, bias_params = init_split_parameter_dictionaries(net)

    # One optimizer for the regular network, and one for the biases. This allows us to use the superconvergence onecycle training policy for our networks....
    opt = torch.optim.SGD(**non_bias_params)
    opt_bias = torch.optim.SGD(**bias_params)

    #opt = torch.optim.SGD(**non_bias_params)
    #opt_bias = torch.optim.SGD(**bias_params)

    ## Not the most intuitive, but this basically takes us from ~0 to max_lr at the point pct_start, then down to .1 * max_lr at the end (since 1e16 * 1e-15 = .1 --
    ##   This quirk is because the final lr value is calculated from the starting lr value and not from the maximum lr value set during training)
    initial_div_factor = 1e16 # basically to make the initial lr ~0 or so :D
    final_lr_ratio = .135
    lr_sched      = torch.optim.lr_scheduler.OneCycleLR(opt,  max_lr=non_bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)
    lr_sched_bias = torch.optim.lr_scheduler.OneCycleLR(opt_bias, max_lr=bias_params['lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=total_train_steps-num_low_lr_steps_for_ema, anneal_strategy='linear', cycle_momentum=False)

    ## For accurately timing GPU code
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    ## There's another repository that's mainly reorganized David's code while still maintaining some of the functional structure, and it
    ## has a timing feature too, but there's no synchronizes so I suspect the times reported are much faster than they may be in actuality
    ## due to some of the quirks of timing GPU operations.
    torch.cuda.synchronize() ## clean up any pre-net setup operations
    

    if True: ## Sometimes we need a conditional/for loop here, this is placed to save the trouble of needing to indent
        for epoch in range(HYP['misc']['train_epochs']):
          #################
          # Training Mode #
          #################
          torch.cuda.synchronize()
          starter.record()
          net.train()

          loss_train = None
          accuracy_train = None

          for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=BATCHSIZE)):
              ## Run everything through the network
              outputs = net(inputs)
              
              loss_scale_scaler = 1./16 # Hardcoded for now, preserves some accuracy during the loss summing process, balancing out its regularization effects
              ## If you want to add other losses or hack around with the loss, you can do that here.
              loss = loss_fn(outputs, targets).mul(loss_scale_scaler).sum().div(loss_scale_scaler) ## Note, as noted in the original blog posts, the summing here does a kind of loss scaling
                                                     ## (and is thus BATCHSIZE dependent as a result). This can be somewhat good or bad, depending...

              # we only take the last-saved accs and losses from train
              if epoch_step % 50 == 0:
                  train_acc = (outputs.detach().argmax(-1) == targets).float().mean().item()
                  train_loss = loss.detach().cpu().item()/BATCHSIZE

              loss.backward()

              ## Step for each optimizer, in turn.
              opt.step()
              opt_bias.step()

              if current_steps < total_train_steps - num_low_lr_steps_for_ema - 1: # the '-1' is because the lr scheduler tends to overshoot (even below 0 if the final lr is ~0) on the last step for some reason.
                  # We only want to step the lr_schedulers while we have training steps to consume. Otherwise we get a not-so-friendly error from PyTorch
                  lr_sched.step()
                  lr_sched_bias.step()

              ## Using 'set_to_none' I believe is slightly faster (albeit riskier w/ funky gradient update workflows) than under the default 'set to zero' method
              opt.zero_grad(set_to_none=True)
              opt_bias.zero_grad(set_to_none=True)
              current_steps += 1

              if epoch >= ema_epoch_start and current_steps % HYP['misc']['ema']['every_n_steps'] == 0:          
                  ## Initialize the ema from the network at this point in time if it does not already exist.... :D
                  if net_ema is None or epoch_step < num_cooldown_before_freeze_steps: # don't snapshot the network yet if so!
                      net_ema = NetworkEMA(net, decay=projected_ema_decay_val)
                      continue
                  net_ema.update(net)
          ender.record()
          torch.cuda.synchronize()
          total_time_seconds += 1e-3 * starter.elapsed_time(ender)

          ####################
          # Evaluation  Mode #
          ####################
          net.eval()

          EVAL_BATCHSIZE = 1000
          assert data['eval']['images'].shape[0] % EVAL_BATCHSIZE == 0, "Error: The eval BATCHSIZE must evenly divide the eval dataset (for now, we don't have drop_remainder implemented yet)."
          loss_list_val, acc_list, acc_list_ema = [], [], []
          
          with torch.no_grad():
              for inputs, targets in get_batches(data, key='eval', batchsize=EVAL_BATCHSIZE):
                  if epoch >= ema_epoch_start:
                      outputs = net_ema(inputs)
                      acc_list_ema.append((outputs.argmax(-1) == targets).float().mean())
                  outputs = net(inputs)
                  loss_list_val.append(loss_fn(outputs, targets).float().mean())
                  acc_list.append((outputs.argmax(-1) == targets).float().mean())
                  
              val_acc = torch.stack(acc_list).mean().item()
              ema_val_acc = None
              # TODO: We can fuse these two operations (just above and below) all-together like :D :))))
              if epoch >= ema_epoch_start:
                  ema_val_acc = torch.stack(acc_list_ema).mean().item()

              val_loss = torch.stack(loss_list_val).mean().item()
          # We basically need to look up local variables by name so we can have the names, so we can pad to the proper column width.
          ## Printing stuff in the terminal can get tricky and this used to use an outside library, but some of the required stuff seemed even
          ## more heinous than this, unfortunately. So we switched to the "more simple" version of this!
          format_for_table = lambda x, locals: (f"{locals[x]}".rjust(len(x))) \
                                                    if type(locals[x]) == int else "{:0.4f}".format(locals[x]).rjust(len(x)) \
                                                if locals[x] is not None \
                                                else " "*len(x)

          # Print out our training details (sorry for the complexity, the whole logging business here is a bit of a hot mess once the columns need to be aligned and such....)
          ## We also check to see if we're in our final epoch so we can print the 'bottom' of the table for each round.
          print_training_details(list(map(partial(format_for_table, locals=locals()), logging_columns_list)), is_final_entry=(epoch == HYP['misc']['train_epochs'] - 1))
    return ema_val_acc # Return the final ema accuracy achieved (not using the 'best accuracy' selection strategy, which I think is okay here....)

if __name__ == "__main__":
    acc_list = []
    for run_num in range(25):
        acc_list.append(torch.tensor(main()))
    print("Mean and variance:", (torch.mean(torch.stack(acc_list)).item(), torch.var(torch.stack(acc_list)).item()))