# CIFAR10 experiment prototype

If you plan to use DenseNet, you'd better read [the paper](https://arxiv.org/abs/1608.06993)
* it seems that using `ADAM` is a recipe for disaster
* it is better to use `SGD` with a scheduler.

[See also this Medium post](https://medium.com/@wwwbbb8510/lessons-learned-from-reproducing-resnet-and-densenet-on-cifar-10-dataset-6e25b03328da)

In [None]:
import torch
import numpy as np

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import json
from pkg_resources import resource_filename

# Resolve the path of the VGG experiment template inside the
# `cplxpaper.cifar.models.vgg` package.
filename = resource_filename("cplxpaper.cifar.models.vgg", "template.json")

# Decode explicitly as UTF-8 instead of relying on the platform's locale
# encoding, which `open` falls back to when no encoding is given.
with open(filename, "r", encoding="utf-8") as fin:
    manifest = json.load(fin)

In [None]:
# Both splits are configured with the same `root` directory.
cifar_root = '/home/ivan.nazarov/Github/complex_paper/experiments/cifar/data'

manifest["datasets"] = {
    'train': {
        'cls': "<class 'cplxpaper.cifar.dataset.AugmentedCIFAR10_Train'>",
        'root': cifar_root,
    },
    'test': {
        'cls': "<class 'cplxpaper.cifar.dataset.AugmentedCIFAR10_Test'>",
        'root': cifar_root,
    },
}

In [None]:
# Feature feed specification. An alternative set of entries is kept here for
# reference (disabled):
#     "cls": "<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
#     "signal_ndim": 2,
#     "shift": True,
#     "cplx": True,
manifest["features"] = dict(
    cls="<class 'cplxpaper.auto.feeds.FeedRawFeatures'>",
)

# Base model specification: the class plus its keyword settings.
manifest["model"] = dict(
    cls="<class 'cplxpaper.cifar.models.vgg.complex.VGG'>",
    vgg_name='VGG16',
    n_outputs=10,
    upcast=True,
    half=True,
)

In [None]:
# Model class assigned to each stage, listed in the order the original cell
# assigned them.
stage_model_cls = {
    "dense": "<class 'cplxpaper.cifar.models.vgg.complex.VGG'>",
    "sparsify": "<class 'cplxpaper.cifar.models.vgg.complex.VGGARD'>",
    "fine-tune": "<class 'cplxpaper.cifar.models.vgg.complex.VGGMasked'>",
}

for stage_name, cls_spec in stage_model_cls.items():
    manifest["stages"][stage_name]["model"] = {"cls": cls_spec}

In [None]:
# Set the learning-rate scheduler specification of the dense stage.
# NOTE(review): the scheduler class lives in the `musicnet` subpackage --
# presumably reused here on purpose; confirm.
manifest["stages"]["dense"]["lr_scheduler"] = {
    'cls': "<class 'cplxpaper.musicnet.lr_scheduler.FastStepScheduler'>"
}

In [None]:
# Set the coefficient of the `kl_div` objective term.
manifest["objective_terms"]["kl_div"]["coef"] = 2e-5

In [None]:
# Number of epochs per stage. The trailing comments carry over the values
# noted in the original cell (presumably those of a full-length run).
epochs_per_stage = {
    "dense": 1,  # 20
    "sparsify": 2,  # 40
    "fine-tune": 1,  # 20
}

for stage_name, n_epochs in epochs_per_stage.items():
    manifest["stages"][stage_name]["n_epochs"] = n_epochs

In [None]:
# Use the same `grad_clip` setting in every stage.
for stage_name in ("dense", "sparsify", "fine-tune"):
    manifest["stages"][stage_name]["grad_clip"] = 0.5

In [None]:
# Device string, later passed to `torch.device(manifest["device"])`.
# "cuda:3" refers to the CUDA device with index 3 -- adjust on machines
# that have fewer GPUs.
manifest["device"] = "cuda:3"

<br>

In [None]:
from cplxpaper.auto import auto

In [None]:
# Run the experiment described by `manifest`.
# NOTE(review): './test_vgg' and "cplx-cifar-vgg16-test" are presumably the
# output folder and the experiment name -- confirm against `auto.run`.
auto.run(manifest, './test_vgg', "cplx-cifar-vgg16-test")

In [None]:
# Always raises AssertionError -- presumably a deliberate barrier so that
# running all cells stops here and the scratch cells below are not executed.
assert False

In [None]:
# Keyword options holding the target device taken from the manifest.
devtype = {"device": torch.device(manifest["device"])}

# Build the datasets from their manifest section, then the feeds on top of
# them using the feature and feed sections.
datasets = auto.get_datasets(manifest["datasets"])
feeds = auto.get_feeds(
    datasets,
    devtype,
    manifest["features"],
    manifest["feeds"],
)

In [None]:
# Display the training dataset object.
datasets["train"]

In [None]:
# Create a state object from the base model spec and the fine-tune stage
# settings; its `.model` attribute is inspected in the cells below.
settings = manifest["stages"]["fine-tune"]
new = auto.state_create(manifest["model"], settings, devtype)

In [None]:
# Take the first batch of the training feed for the checks below.
# `next(iter(...))` raises StopIteration right here when the feed is empty,
# whereas a `for ...: break` loop would silently leave `bx` and `by` unbound
# until their first use.
bx, by = next(iter(feeds['train']))

In [None]:
from cplxmodule.nn.utils.sparsity import named_sparsity

# Collect the entries yielded by `named_sparsity` for the model into a dict.
sparsity = dict(named_sparsity(new.model, threshold=-0.5, hard=True))

# Sum the two components of every entry separately (presumably the number of
# zeroed parameters and the total number of parameters -- confirm against
# `named_sparsity`). Accumulating explicitly gives (0, 0) when there are no
# entries, where unpacking `map(sum, zip(*values))` would raise ValueError.
n_zer, n_par = 0, 0
for zeros, total in sparsity.values():
    n_zer = n_zer + zeros
    n_par = n_par + total

In [None]:
# Display the two sparsity totals.
n_zer, n_par

In [None]:
# Run the model on the batch `bx` and display the shape of its output.
new.model(bx).shape

<br>

Model taken from [this tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)

In [None]:
import torch
from collections import OrderedDict

from cplxmodule.nn import CplxToCplx
from cplxmodule.nn import CplxConv2d, CplxLinear

from cplxmodule.nn import CplxReal
from cplxmodule.nn.modules.casting import ConcatenatedRealToCplx
from cplxmodule.nn.modules.casting import CplxToConcatenatedReal
# from cplxmodule.nn.relevance import CplxConv2dVD, CplxLinearVD
from cplxmodule.nn.relevance.extensions import CplxLinearVDBogus as CplxLinearVD
from cplxmodule.nn.relevance.extensions import CplxConv2dVDBogus as CplxConv2dVD
from cplxmodule.nn.masked import CplxConv2dMasked, CplxLinearMasked


class CIFAR10Model(torch.nn.Sequential):
    """Small complex-valued convolutional network with ten outputs.

    Layer order: `ConcatenatedRealToCplx` cast, two convolution / ReLU /
    average-pooling stages, flatten, two linear layers with a ReLU in
    between, and a final `CplxReal`.

    The convolution and linear layers are created through the `Conv2d` and
    `Linear` class attributes, so a subclass can substitute other layer
    classes that accept the same constructor arguments.
    """

    # Layer factories; subclasses override these.
    Conv2d = CplxConv2d
    Linear = CplxLinear

    def __init__(self):
        super().__init__(OrderedDict([
            # input cast along dim -3
            ("cplx", ConcatenatedRealToCplx(copy=False, dim=-3)),

            # first convolutional stage
            ("conv1", self.Conv2d(3, 32, 3, 1)),
            ("relu1", CplxToCplx[torch.nn.ReLU]()),
            ("pool1", CplxToCplx[torch.nn.AvgPool2d](2, 2)),

            # second convolutional stage
            ("conv2", self.Conv2d(32, 64, 3, 1)),
            ("relu2", CplxToCplx[torch.nn.ReLU]()),
            ("pool2", CplxToCplx[torch.nn.AvgPool2d](2, 2)),

            # classifier head; 6 * 6 * 64 presumably corresponds to 64
            # channels over a 6x6 map for 32x32 inputs -- TODO confirm
            ("flat_", CplxToCplx[torch.nn.Flatten](-3, -1)),
            ("lin_1", self.Linear(6 * 6 * 64, 2048)),
            ("relu3", CplxToCplx[torch.nn.ReLU]()),
            ("lin_2", self.Linear(2048, 10)),
            ("real", CplxReal()),

            # disabled alternative head, kept for reference:
            # ("real", CplxToConcatenatedReal(dim=-1)),
            # ("lin_3", torch.nn.Linear(20, 10)),
        ]))


class CIFAR10ModelVD(CIFAR10Model):
    """`CIFAR10Model` built from the `CplxConv2dVD` / `CplxLinearVD` layers.

    Only the layer factories differ from the base class.
    """

    Conv2d = CplxConv2dVD
    Linear = CplxLinearVD


class CIFAR10ModelMasked(CIFAR10Model):
    """`CIFAR10Model` built from the masked convolution and linear layers.

    Only the layer factories differ from the base class.
    """

    Conv2d = CplxConv2dMasked
    Linear = CplxLinearMasked


In [None]:
# NOTE(review): `options` is not defined in any cell visible in this
# notebook (the cells above build `manifest`, whose `kl_div` setting lives
# under "objective_terms") -- presumably a leftover from an earlier revision;
# confirm before running.

# Turn off the `reset` flag of the fine-tune stage.
options["stages"]['fine-tune']['reset'] = False

# Set the `kl_div` entry of the sparsify stage objective.
options["stages"]['sparsify']['objective']['kl_div'] = 1e-1

# Use the same optimizer weight decay in all three stages.
options["stages"]['dense']["optimizer"]["weight_decay"] = 5e-4
options["stages"]['sparsify']["optimizer"]["weight_decay"] = 5e-4
options["stages"]['fine-tune']["optimizer"]["weight_decay"] = 5e-4

In [None]:
# Display the resulting configuration.
options

In [None]:
# Run the experiment described by `options`.
# NOTE(review): './test' and "cplx-cifar-vgg16" are presumably the output
# folder and the experiment name -- confirm against `auto.run`.
auto.run(options, './test', "cplx-cifar-vgg16")

In [None]:
from cplxpaper.auto.utils import load_snapshot

In [None]:
# Display the "performance" entry stored in the dense-stage snapshot.
load_snapshot('./test/0-dense cplx-cifar.gz')["performance"]

In [None]:
# Display the "performance" entry stored in the fine-tune-stage snapshot.
load_snapshot('./test/2-fine-tune cplx-cifar.gz')["performance"]

In [None]:
# Snapshot files of the three stages, in the order they are concatenated.
snapshot_paths = (
    './test/0-dense cplx-cifar.gz',
    './test/1-sparsify cplx-cifar.gz',
    './test/2-fine-tune cplx-cifar.gz',
)

# Gather the loss history stored in each snapshot.
losses = []
for snapshot_path in snapshot_paths:
    cold = load_snapshot(snapshot_path)
    losses.append(cold['history']["loss"])

# Plot the concatenated histories with a logarithmic y-axis.
plt.semilogy(np.concatenate(losses))

In [None]:
# Always raises AssertionError -- presumably a deliberate barrier so that
# running all cells stops here.
assert False

<br>