In [5]:
import math
import os
import copy

import torch
import torch.nn.functional as F
from torch import nn

import torchvision
from torchvision import transforms

lt_conv_kwargs = {'kernel_size': 3, 'padding': 'same', 'bias': False}

batchsize = 1024
bias_scaler = 56

hyp = {
    'opt': {
        'bias_lr':        1.64 * bias_scaler/512, # TODO: Is there maybe a better way to express the bias and batchnorm scaling? :'))))
        'non_bias_lr':    1.64 / 512,
        'bias_decay':     1.08 * 6.45e-4 * batchsize/bias_scaler,
        'non_bias_decay': 1.08 * 6.45e-4 * batchsize,
        'scaling_factor': 1./9,
        'percent_start': .23,
        'loss_scale_scaler': 1./128, # * Regularizer inside the loss summing (range: ~1/512 - 16+). FP8 should help with this somewhat too, whenever it comes out. :)
    },
    'net': {
        'whitening': {
            'kernel_size': 2,
            'num_examples': 50000,
        },
        'batch_norm_momentum': .5, # * Don't forget momentum is 1 - momentum here (due to a quirk in the original paper... >:( )
        'conv_norm_pow': 2.6,
        'cutmix_size': 3,
        'cutmix_epochs': 6,
        'pad_amount': 2,
        'base_depth': 64 ## This should be a factor of 8 in some way to stay tensor core friendly
    },
    'misc': {
        'ema': {
            'epochs': 10, # Slight bug in that this counts only full epochs and then additionally runs the EMA for any fractional epochs at the end too
            'decay_base': .95,
            'decay_pow': 3.,
            'every_n_steps': 5,
        },
        'train_epochs': 12.1,
        'device': 'cpu',
        'data_location': 'data.pt',
    }
}


if not os.path.exists(hyp['misc']['data_location']):

        transform = transforms.Compose([
            transforms.ToTensor()])

        cifar10      = torchvision.datasets.CIFAR10('cifar10/', download=True,  train=True,  transform=transform)
        cifar10_eval = torchvision.datasets.CIFAR10('cifar10/', download=False, train=False, transform=transform)

        # use the dataloader to get a single batch of all of the dataset items at once.
        train_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10, batch_size=len(cifar10), drop_last=True,
                                                  shuffle=True, num_workers=2, persistent_workers=False)
        eval_dataset_gpu_loader = torch.utils.data.DataLoader(cifar10_eval, batch_size=len(cifar10_eval), drop_last=True,
                                                  shuffle=False, num_workers=1, persistent_workers=False)

        train_dataset_gpu = {}
        eval_dataset_gpu = {}

        train_dataset_gpu['images'], train_dataset_gpu['targets'] = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(train_dataset_gpu_loader))]
        eval_dataset_gpu['images'],  eval_dataset_gpu['targets']  = [item.to(device=hyp['misc']['device'], non_blocking=True) for item in next(iter(eval_dataset_gpu_loader)) ]

        cifar10_std, cifar10_mean = torch.std_mean(train_dataset_gpu['images'], dim=(0, 2, 3)) # dynamically calculate the std and mean from the data. this shortens the code and should help us adapt to new datasets!

        def batch_normalize_images(input_images, mean, std):
            return (input_images - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

        # preload with our mean and std
        batch_normalize_images = partial(batch_normalize_images, mean=cifar10_mean, std=cifar10_std)

        ## Batch normalize datasets, now. Wowie. We did it! We should take a break and make some tea now.
        train_dataset_gpu['images'] = batch_normalize_images(train_dataset_gpu['images'])
        eval_dataset_gpu['images']  = batch_normalize_images(eval_dataset_gpu['images'])

        data = {
            'train': train_dataset_gpu,
            'eval': eval_dataset_gpu
        }

        ## Convert dataset to FP16 now for the rest of the process....
        data['train']['images'] = data['train']['images'].half().requires_grad_(False)
        data['eval']['images']  = data['eval']['images'].half().requires_grad_(False)

        # Convert this to one-hot to support the usage of cutmix (or whatever strange label tricks/magic you desire!)
        data['train']['targets'] = F.one_hot(data['train']['targets']).half()
        data['eval']['targets'] = F.one_hot(data['eval']['targets']).half()

        torch.save(data, hyp['misc']['data_location'])

else:
    ## This is effectively instantaneous, and takes us practically straight to where the dataloader-loaded dataset would be. :)
    ## So as long as you run the above loading process once, and keep the file on the disc it's specified by default in the above
    ## hyp dictionary, then we should be good. :)
    data = torch.load(hyp['misc']['data_location'])



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to cifar10/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [03:25<00:00, 828604.75it/s] 


Extracting cifar10/cifar-10-python.tar.gz to cifar10/


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9f59e4e200>
Traceback (most recent call last):
  File "/home/jkr/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/home/jkr/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1461, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception in thread QueueFeederThread:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/queues.py", line 239, in _feed
    reader_close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 177, in close
    self._close()
  File "/usr/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad fil

RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx