# CIFAR10 experiment prototype

In [None]:
import torch
import numpy as np

In [None]:
import json

from pkg_resources import resource_stream

# Start from the experiment template shipped in the ``cplxpaper.mnist``
# package; later cells override its dataset, feature, model and stage
# settings to turn it into a CIFAR10 run.
with resource_stream("cplxpaper.mnist", "template.json") as template_stream:
    options = json.load(template_stream)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from cplxpaper.auto import auto

In [None]:
# Prefer the second GPU ("cuda:1") as originally intended. A hard-coded
# "cuda:1" fails with an invalid-device-ordinal error on hosts with fewer
# than two GPUs, so fall back to the first GPU, and then to the CPU.
# NOTE(review): assumes the runner accepts "cpu" as a device string; confirm.
if torch.cuda.device_count() > 1:
    options["device"] = "cuda:1"
elif torch.cuda.is_available():
    options["device"] = "cuda:0"
else:
    options["device"] = "cpu"

In [None]:
# Swap the template's datasets for CIFAR10 train/test splits stored under
# ./data. The "mnist-*" keys are kept even though they now point at CIFAR10.
# Presumably other parts of the template refer to datasets by these names;
# verify against template.json before renaming them.
options["datasets"] = {
    "mnist-train": dict(
        cls="<class 'cplxpaper.cifar.dataset.CIFAR10Train'>",
        root="./data",
    ),
    "mnist-test": dict(
        cls="<class 'cplxpaper.cifar.dataset.CIFAR10Test'>",
        root="./data",
    ),
}

In [None]:
# Feature feed for the model: FeedFourierFeatures over a 2-D signal with
# complex output and ``shift`` enabled.
# NOTE(review): ``shift`` presumably centres the spectrum (fftshift-style);
# confirm in cplxpaper.auto.feeds.
# Alternative kept from earlier experiments: raw pixels via
#   "<class 'cplxpaper.auto.feeds.FeedRawFeatures'>"
options["features"] = dict(
    cls="<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
    signal_ndim=2,
    cplx=True,
    shift=True,
)

In [None]:
import torch
from collections import OrderedDict

from cplxmodule.nn import CplxToCplx
from cplxmodule.nn import CplxConv2d, CplxLinear

from cplxmodule.nn.layers import CplxReal
from cplxmodule.nn.layers import ConcatenatedRealToCplx
from cplxmodule.nn.layers import CplxToConcatenatedReal
from cplxmodule.nn.relevance import CplxConv2dARD, CplxLinearARD
# from cplxmodule.nn.relevance.extensions import CplxLinearVDBogus as CplxLinearARD
# from cplxmodule.nn.relevance.extensions import CplxConv2dVDBogus as CplxConv2dARD
from cplxmodule.nn.masked import CplxConv2dMasked, CplxLinearMasked


class CIFAR10Model(torch.nn.Sequential):
    """Complex-valued LeNet-style convnet used for the CIFAR10 experiment.

    Two conv -> ReLU -> average-pool stages, a flatten, and two fully
    connected layers. The last complex layer has 10 outputs, and its real
    part (``CplxReal``) is the network output.

    The layer types are looked up through the ``Linear`` and ``Conv2d``
    class attributes. Subclasses can therefore swap in other complex layer
    implementations without repeating the architecture.
    """

    # Factories resolved via ``self`` in ``__init__`` so subclasses can override.
    Linear = CplxLinear
    Conv2d = CplxConv2d

    def __init__(self):
        # Spatial sizes, assuming 32x32 inputs (CIFAR10):
        #   32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5
        # This is where the 5 * 5 * 50 input width of ``lin_1`` comes from.
        stages = OrderedDict()

        # NOTE(review): assumes the real and imaginary parts arrive concatenated
        # along the channel axis (dim=-3), i.e. 6 real -> 3 complex channels;
        # confirm against the feature feed.
        stages["cplx"] = ConcatenatedRealToCplx(copy=False, dim=-3)

        # First conv stage: 3 -> 20 channels.
        stages["conv1"] = self.Conv2d(3, 20, 5, 1)
        stages["relu1"] = CplxToCplx[torch.nn.ReLU]()
        stages["pool1"] = CplxToCplx[torch.nn.AvgPool2d](2, 2)

        # Second conv stage: 20 -> 50 channels.
        stages["conv2"] = self.Conv2d(20, 50, 5, 1)
        stages["relu2"] = CplxToCplx[torch.nn.ReLU]()
        stages["pool2"] = CplxToCplx[torch.nn.AvgPool2d](2, 2)

        # Classifier head.
        stages["flat_"] = CplxToCplx[torch.nn.Flatten](-3, -1)
        stages["lin_1"] = self.Linear(5 * 5 * 50, 500)
        stages["relu3"] = CplxToCplx[torch.nn.ReLU]()
        stages["lin_2"] = self.Linear(500, 10)

        # Output is the real part of the complex outputs. An earlier variant used
        # CplxToConcatenatedReal(dim=-1) followed by torch.nn.Linear(20, 10).
        stages["real"] = CplxReal()

        super().__init__(stages)


class CIFAR10ModelARD(CIFAR10Model):
    """``CIFAR10Model`` architecture built from the ARD (relevance) layers.

    The options configure this class as the model of the "sparsify" stage.
    """

    Conv2d = CplxConv2dARD
    Linear = CplxLinearARD


class CIFAR10ModelMasked(CIFAR10Model):
    """``CIFAR10Model`` architecture built from the masked layers.

    The options configure this class as the model of the "fine-tune" stage.
    """

    Conv2d = CplxConv2dMasked
    Linear = CplxLinearMasked


In [None]:
# Point the dense, sparsify and fine-tune stages at the model classes defined
# in this notebook. ``str(cls)`` renders as "<class 'module.QualName'>", the
# same text as the previously hand-typed "<class '__main__.CIFAR10Model'>"
# strings. It is derived from the actual class objects, so it cannot drift
# out of sync if a class is renamed.
options["model"]["cls"] = str(CIFAR10Model)
options["stages"]['sparsify']["model"]["cls"] = str(CIFAR10ModelARD)
options["stages"]['fine-tune']["model"]["cls"] = str(CIFAR10ModelMasked)

In [None]:
# Fine-tune stage: set the template's ``reset`` flag to False.
# NOTE(review): presumably this keeps state carried over from the sparsify
# stage instead of re-initialising; confirm in cplxpaper.auto.
options["stages"]["fine-tune"].update(reset=False)

# Sparsify stage: weight of the KL-divergence term in the objective.
options["stages"]["sparsify"]["objective"].update(kl_div=0.1)

In [None]:
# Echo the fully assembled configuration as this cell's output, for
# inspection before launching the run.
options

In [None]:
# Launch the staged experiment described by ``options``. The later cells read
# snapshots named "./test/<stage> cplx-cifar.gz", so './test' is presumably
# the output directory and "cplx-cifar" the run tag used in snapshot names.
auto.run(options, './test', "cplx-cifar")

In [None]:
from cplxpaper.auto.utils import load_snapshot

In [None]:
# Show the "performance" entry recorded in the dense (stage 0) snapshot.
load_snapshot('./test/0-dense cplx-cifar.gz')["performance"]

In [None]:
# Show the "performance" entry recorded in the fine-tune (stage 2) snapshot.
load_snapshot('./test/2-fine-tune cplx-cifar.gz')["performance"]

In [None]:
# Loss histories of the three consecutive stages (dense -> sparsify ->
# fine-tune), concatenated into a single log-scale curve.
snapshot_paths = (
    './test/0-dense cplx-cifar.gz',
    './test/1-sparsify cplx-cifar.gz',
    './test/2-fine-tune cplx-cifar.gz',
)

# Each snapshot is loaded only long enough to pull out its loss history, so no
# full snapshot stays alive in a notebook global afterwards.
losses = [load_snapshot(path)['history']["loss"] for path in snapshot_paths]

plt.semilogy(np.concatenate(losses))

In [None]:
# Presumably a stop guard: raises AssertionError so "Run All" halts here.
# NOTE(review): assert statements are stripped under ``python -O``, in which
# case this line does nothing; use ``raise RuntimeError(...)`` if the guard
# must always fire.
assert False

<br>