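# CIFAR10 experiment prototype

Prototype of the `cplxpaper.auto` dense → sparsify → fine-tune pipeline for complex-valued networks on CIFAR10, using Fourier features of the images as complex inputs.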

In [None]:
import torch
import numpy as np

In [None]:
import json

from pkg_resources import resource_stream

# Start from the experiment template bundled with `cplxpaper.mnist`; the
# following cells overwrite the dataset, feed, feature and model entries.
with resource_stream("cplxpaper.mnist", "template.json") as template_file:
    options = json.load(template_file)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from cplxpaper.auto import auto

In [None]:
# Prefer the second GPU (as in the original experiments), but fall back to the
# first GPU or the CPU so the notebook also runs on smaller machines.
if torch.cuda.device_count() > 1:
    options["device"] = "cuda:1"
elif torch.cuda.is_available():
    options["device"] = "cuda:0"
else:
    options["device"] = "cpu"

In [None]:
# CIFAR10 train/test splits, both rooted at ./data.
options["datasets"] = dict([
    ("cifar-train", {"cls": "<class 'cplxpaper.cifar.dataset.CIFAR10Train'>", "root": "./data"}),
    ("cifar-test", {"cls": "<class 'cplxpaper.cifar.dataset.CIFAR10Test'>", "root": "./data"}),
])

In [None]:
# Train and test loaders share one configuration; only the source dataset and
# shuffling differ. n_batches=-1 is passed through to cplxpaper unchanged.
options["feeds"] = {
    feed: {
        'cls': "<class 'torch.utils.data.dataloader.DataLoader'>",
        'dataset': dataset,
        'batch_size': 128,
        'shuffle': shuffle,
        'pin_memory': True,
        'n_batches': -1,
    }
    for feed, dataset, shuffle in (
        ('train', 'cifar-train', True),
        ('test', 'cifar-test', False),
    )
}

In [None]:
# Complex Fourier features over the two spatial dims (presumably a shifted 2-D
# FFT of each image -- confirm in cplxpaper.auto.feeds).
# Alternative: dict(cls="<class 'cplxpaper.auto.feeds.FeedRawFeatures'>")
options["features"] = dict(
    cls="<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
    signal_ndim=2,
    cplx=True,
    shift=True,
)

A small CNN adapted from the [PyTorch CIFAR10 tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html), rebuilt with complex-valued layers.

In [None]:
import torch
from collections import OrderedDict

from cplxmodule.nn import CplxToCplx
from cplxmodule.nn import CplxConv2d, CplxLinear

from cplxmodule.nn.layers import CplxReal
# from cplxmodule.nn.activation import CplxReal
from cplxmodule.nn.layers import ConcatenatedRealToCplx
from cplxmodule.nn.layers import CplxToConcatenatedReal
# from cplxmodule.nn.relevance import CplxConv2dARD, CplxLinearARD
from cplxmodule.nn.relevance.extensions import CplxLinearVDBogus as CplxLinearARD
from cplxmodule.nn.relevance.extensions import CplxConv2dVDBogus as CplxConv2dARD
from cplxmodule.nn.masked import CplxConv2dMasked, CplxLinearMasked


class CIFAR10Model(torch.nn.Sequential):
    """Small complex-valued CNN for CIFAR10, after the PyTorch tutorial net.

    The input is expected to hold real and imaginary parts concatenated along
    the channel axis (dim=-3). Two conv / ReLU / avg-pool stages are followed
    by a two-layer complex MLP, and ``CplxReal`` keeps the real part of the
    complex logits. Assuming 32x32 inputs, the spatial size goes
    32 -> 30 (conv1) -> 15 (pool1) -> 13 (conv2) -> 6 (pool2), which is why
    ``lin_1`` takes ``6 * 6 * 64`` features.

    Subclasses override the ``Linear`` / ``Conv2d`` class attributes to reuse
    this topology (and its submodule names) with other layer types.
    """
    Linear = CplxLinear
    Conv2d = CplxConv2d

    def __init__(self):
        def relu():
            return CplxToCplx[torch.nn.ReLU]()

        def avgpool():
            return CplxToCplx[torch.nn.AvgPool2d](2, 2)

        # Submodules are built in the order listed, which fixes the order in
        # which parameter initialisation consumes the RNG.
        super().__init__(OrderedDict([
            ("cplx", ConcatenatedRealToCplx(copy=False, dim=-3)),
            ("conv1", self.Conv2d(3, 32, 3, 1)),
            ("relu1", relu()),
            ("pool1", avgpool()),
            ("conv2", self.Conv2d(32, 64, 3, 1)),
            ("relu2", relu()),
            ("pool2", avgpool()),
            ("flat_", CplxToCplx[torch.nn.Flatten](-3, -1)),
            ("lin_1", self.Linear(6 * 6 * 64, 2048)),
            ("relu3", relu()),
            ("lin_2", self.Linear(2048, 10)),
            ("real", CplxReal()),
        ]))


class CIFAR10ModelARD(CIFAR10Model):
    """`CIFAR10Model` topology built from the layers imported as ``*ARD``.

    Intended as the model of the 'sparsify' stage.
    """
    Conv2d = CplxConv2dARD
    Linear = CplxLinearARD


class CIFAR10ModelMasked(CIFAR10Model):
    """`CIFAR10Model` topology built from masked complex layers.

    Intended as the model of the 'fine-tune' stage.
    """
    Conv2d = CplxConv2dMasked
    Linear = CplxLinearMasked


In [None]:
# Select the small-CNN family: dense baseline, ARD variant for 'sparsify',
# masked variant for 'fine-tune'.
# NOTE: the VGG cell further down overwrites these same three entries.
options["model"].update(cls="<class '__main__.CIFAR10Model'>")
options["stages"]['sparsify']["model"].update(cls="<class '__main__.CIFAR10ModelARD'>")
options["stages"]['fine-tune']["model"].update(cls="<class '__main__.CIFAR10ModelMasked'>")

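A complex-valued VGG adapted from [kuangliu's pytorch-cifar](https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py). Running this section overrides the CNN model classes selected above.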

In [None]:
from cplxmodule.nn.batchnorm import CplxBatchNorm2d

# VGG layouts: each of the five stages is `n` 3x3 convs of a fixed width
# followed by a 2x2 max-pool ('M'). Only the per-stage conv counts differ.
cfg = {
    name: [
        layer
        for width, n_convs in zip((64, 128, 256, 512, 512), counts)
        for layer in [width] * n_convs + ['M']
    ]
    for name, counts in (
        ('VGG11', (1, 1, 2, 2, 2)),
        ('VGG13', (2, 2, 2, 2, 2)),
        ('VGG16', (2, 2, 3, 3, 3)),
        ('VGG19', (2, 2, 4, 4, 4)),
    )
}


class CplxVGG(torch.nn.Module):
    """Complex-valued VGG, adapted from kuangliu's pytorch-cifar VGG.

    The input is expected to hold real and imaginary parts concatenated along
    the channel axis (dim=-3). ``ConcatenatedRealToCplx`` turns them into a
    complex tensor ahead of the conv stack, and ``CplxReal`` keeps the real
    part of the complex logits.

    Every layout in ``cfg`` has five 2x2 max-pools, so a 32x32 input is reduced
    to 1x1. The flattened feature size therefore equals the width of the last
    convolution; other input sizes would not match the classifier.

    Subclasses override the ``Linear`` / ``Conv2d`` class attributes to get the
    ARD / masked variants with the same topology.

    Parameters
    ----------
    vgg_name : str
        Key into the module-level ``cfg`` dict, e.g. 'VGG16'.
    num_classes : int, optional
        Number of output logits (default 10, as for CIFAR10).
    in_channels : int, optional
        Number of complex input channels (default 3, RGB).

    Raises
    ------
    KeyError
        If ``vgg_name`` is not a key of ``cfg``.
    """
    Linear = CplxLinear
    Conv2d = CplxConv2d

    def __init__(self, vgg_name='VGG16', num_classes=10, in_channels=3):
        super().__init__()
        if vgg_name not in cfg:
            raise KeyError("unknown VGG configuration {!r}, expected one of {}"
                           .format(vgg_name, sorted(cfg)))

        layout = cfg[vgg_name]
        self.features = self._make_layers(layout, in_channels=in_channels)
        self.flatten = CplxToCplx[torch.nn.Flatten](-3, -1)

        # After five 2x2 pools on 32x32 inputs the map is 1x1, so the classifier
        # input width equals the last conv width (512 for every stock layout).
        n_features = [x for x in layout if x != 'M'][-1]
        self.classifier = torch.nn.Sequential(
            self.Linear(n_features, num_classes),
            CplxReal()
        )

    def forward(self, x):
        """Map concatenated real/imag inputs to real-valued logits."""
        out = self.features(x)
        out = self.flatten(out)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg, in_channels=3):
        """Build the conv stack from a layout list.

        Integers add conv(3x3, padding 1) -> complex batch-norm -> ReLU with that
        many output channels; 'M' adds a 2x2 max-pool. The stack is prefixed
        with the real->complex split and ends with a (no-op) 1x1 avg-pool kept
        from the reference implementation.
        """
        layers = []
        for x in cfg:
            if x == 'M':
                layers += [CplxToCplx[torch.nn.MaxPool2d](kernel_size=2, stride=2)]
            else:
                layers += [self.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           CplxBatchNorm2d(x),
                           CplxToCplx[torch.nn.ReLU]()]
                in_channels = x
        layers += [CplxToCplx[torch.nn.AvgPool2d](kernel_size=1, stride=1)]
        return torch.nn.Sequential(
            ConcatenatedRealToCplx(copy=False, dim=-3),
            *layers
        )

class CplxVGGARD(CplxVGG):
    """`CplxVGG` topology built from the layers imported as ``*ARD``.

    Intended as the model of the 'sparsify' stage.
    """
    Conv2d = CplxConv2dARD
    Linear = CplxLinearARD

class CplxVGGMasked(CplxVGG):
    """`CplxVGG` topology built from masked complex layers.

    Intended as the model of the 'fine-tune' stage.
    """
    Conv2d = CplxConv2dMasked
    Linear = CplxLinearMasked

In [None]:
# Select the VGG family: dense baseline, ARD variant for 'sparsify', masked
# variant for 'fine-tune' (overwrites any earlier model selection).
options["model"].update(cls="<class '__main__.CplxVGG'>")
options["stages"]['sparsify']["model"].update(cls="<class '__main__.CplxVGGARD'>")
options["stages"]['fine-tune']["model"].update(cls="<class '__main__.CplxVGGMasked'>")

In [None]:
# Per-stage overrides of the template:
# - 'fine-tune' gets reset=False (presumably: continue from the sparsified
#   weights instead of re-initialising -- confirm in cplxpaper.auto);
# - the 'kl_div' entry of the 'sparsify' objective is set to 0.1;
# - every stage's optimizer uses weight decay 5e-4.
options["stages"]['fine-tune'].update(reset=False)

options["stages"]['sparsify']['objective'].update(kl_div=0.1)

options["stages"]['dense']["optimizer"].update(weight_decay=5e-4)
options["stages"]['sparsify']["optimizer"].update(weight_decay=5e-4)
options["stages"]['fine-tune']["optimizer"].update(weight_decay=5e-4)

In [None]:
# Display the full experiment configuration before launching the run.
options

In [None]:
# Launch the staged experiment described by `options`, writing results under
# ./test with the tag "cplx-cifar-vgg16". Snapshot files presumably follow the
# pattern "<index>-<stage> <tag>.gz" -- confirm against cplxpaper.auto.
auto.run(options, './test', "cplx-cifar-vgg16")

In [None]:
from cplxpaper.auto.utils import load_snapshot

In [None]:
# Performance stored in the dense-stage snapshot of the run launched above,
# which used the tag "cplx-cifar-vgg16" (the old "cplx-cifar" tag pointed at
# a different, stale run).
load_snapshot('./test/0-dense cplx-cifar-vgg16.gz')["performance"]

In [None]:
# Performance stored in the fine-tune-stage snapshot of the run launched above,
# which used the tag "cplx-cifar-vgg16" (the old "cplx-cifar" tag pointed at
# a different, stale run).
load_snapshot('./test/2-fine-tune cplx-cifar-vgg16.gz')["performance"]

In [None]:
# Loss histories of the three stages of the "cplx-cifar-vgg16" run launched
# above, concatenated into one curve (dense -> sparsify -> fine-tune).
losses = [
    load_snapshot('./test/{} cplx-cifar-vgg16.gz'.format(stage))['history']["loss"]
    for stage in ('0-dense', '1-sparsify', '2-fine-tune')
]

fig, ax = plt.subplots()
ax.semilogy(np.concatenate(losses))
ax.set(title="Loss history across stages", xlabel="history entry", ylabel="loss");

In [None]:
# Deliberate stop so "Run All" halts here with a clear message; a bare
# `assert False` gives no explanation and is skipped when Python runs with -O.
raise RuntimeError("End of the CIFAR10 prototype notebook: stopping here on purpose.")

<br>