# MNIST experiment prototype

In [None]:
import torch
import numpy as np

In [None]:
import json

from pkg_resources import resource_stream

# Load the experiment template shipped inside the `cplxpaper.mnist` package;
# the cells below override individual keys of this manifest.
with resource_stream("cplxpaper.mnist", "template.json") as fin:
    options = json.load(fin)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
from cplxpaper.auto import auto

In [None]:
# NOTE(review): hard-coded to the CUDA device with index 1; adjust on
# machines with a different GPU layout.
options["device"] = "cuda:1"

In [None]:
# Named dataset specs: the MNIST train and test dataset classes, both rooted
# at ./data. Classes are referenced by their `<class '...'>` string.
options["datasets"] = {
    'mnist-train': {
        'cls': "<class 'cplxpaper.mnist.dataset.MNISTTrain'>", 'root': './data'
    },
    'mnist-test': {
        'cls': "<class 'cplxpaper.mnist.dataset.MNISTTest'>", 'root': './data'
    }
}

In [None]:
# Input feature spec: Fourier features over 2 signal dimensions.
# NOTE(review): presumably `cplx` keeps the transform complex-valued and
# `shift` centres the spectrum -- confirm against FeedFourierFeatures.
# The commented-out `cls` below is the raw-feature alternative.
options["features"] = {
    'cls': "<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
    'signal_ndim': 2,
    'cplx': True,
    'shift': True,
#     'cls': "<class 'cplxpaper.auto.feeds.FeedRawFeatures'>"
}

In [None]:
# Pick the complex-valued MNIST model family: the base model at the top
# level, the ARD variant for the 'sparsify' stage and the masked variant for
# the 'fine-tune' stage.
options["model"]["cls"] = "<class 'cplxpaper.mnist.models.complex.MNISTModel'>"
options["stages"]['sparsify']["model"]["cls"] = "<class 'cplxpaper.mnist.models.complex.MNISTModelARD'>"
options["stages"]['fine-tune']["model"]["cls"] = "<class 'cplxpaper.mnist.models.complex.MNISTModelMasked'>"

In [None]:
# NOTE(review): presumably asks the 'fine-tune' stage to re-initialise the
# model parameters before training -- confirm against cplxpaper.auto.
options["stages"]['fine-tune']['reset'] = True

In [None]:
# options

In [None]:
# Run the experiment described by `options`. The cells below read per-stage
# snapshots back from ./test (e.g. './test/0-dense cplx-mnist.gz').
auto.run(options, './test', "cplx-mnist")

In [None]:
from cplxpaper.auto.utils import load_snapshot

In [None]:
# Performance record stored in the dense-stage snapshot.
load_snapshot('./test/0-dense cplx-mnist.gz')["performance"]

In [None]:
# Performance record stored in the fine-tune-stage snapshot.
load_snapshot('./test/2-fine-tune cplx-mnist.gz')["performance"]

In [None]:
# Stitch the per-stage loss histories into a single curve, following the
# numeric prefixes of the snapshot files (dense, sparsify, fine-tune), and
# plot it on a log-scaled y axis.
stage_snapshots = (
    './test/0-dense cplx-mnist.gz',
    './test/1-sparsify cplx-mnist.gz',
    './test/2-fine-tune cplx-mnist.gz',
)

losses = []
for snapshot_path in stage_snapshots:
    # `cold` is left bound to the last (fine-tune) snapshot after the loop
    cold = load_snapshot(snapshot_path)
    losses.append(cold['history']["loss"])

plt.semilogy(np.concatenate(losses))

In [None]:
# NOTE(review): presumably a deliberate stop so that "Run All" halts here and
# the scratch cells below are only executed by hand.
assert False

<br>

In [None]:
# Reload the fine-tune-stage snapshot so its stored weights can be inspected.
cold = load_snapshot('./test/2-fine-tune cplx-mnist.gz')

In [None]:

# Draw one image per conv1 kernel: the squared modulus re^2 + im^2 of its
# weights, taking index 0 of the leading axis (presumably the single input
# channel -- TODO confirm).
conv1_real = cold['model']["conv1.weight.real"].numpy()
conv1_imag = cold['model']["conv1.weight.imag"].numpy()

for real_part, imag_part in zip(conv1_real, conv1_imag):
    sq_modulus = real_part * real_part + imag_part * imag_part
    plt.imshow(sq_modulus[0])
    plt.show()

In [None]:
import torch
from collections import OrderedDict

from cplxmodule.nn import CplxToCplx
from cplxmodule.nn import CplxConv2d, CplxLinear

from cplxmodule.nn.layers import CplxReal
from cplxmodule.nn.layers import ConcatenatedRealToCplx
from cplxmodule.nn.layers import CplxToConcatenatedReal
from cplxmodule.nn.relevance import CplxConv2dARD, CplxLinearARD
from cplxmodule.nn.masked import CplxConv2dMasked, CplxLinearMasked


class MNISTModel(torch.nn.Sequential):
    """Complex-valued conv net assembled as a named ``Sequential``.

    The layer classes are read from the ``Linear`` and ``Conv2d`` class
    attributes, so a subclass can swap in compatible replacements by
    overriding just those two names.
    """

    Conv2d = CplxConv2d
    Linear = CplxLinear

    def __init__(self):
        # Modules are instantiated in forward order, i.e. the same order in
        # which a list literal of (name, module) pairs would create them.
        net = OrderedDict()

        # presumably converts a real tensor with re/im parts concatenated
        # along dim -3 into a complex one -- confirm against cplxmodule
        net["cplx"] = ConcatenatedRealToCplx(copy=False, dim=-3)

        # feature extractor: two conv -> relu -> avg-pool stages
        net["conv1"] = self.Conv2d(1, 20, 5, 1)
        net["relu1"] = CplxToCplx[torch.nn.ReLU]()
        net["pool1"] = CplxToCplx[torch.nn.AvgPool2d](2, 2)
        net["conv2"] = self.Conv2d(20, 50, 5, 1)
        net["relu2"] = CplxToCplx[torch.nn.ReLU]()
        net["pool2"] = CplxToCplx[torch.nn.AvgPool2d](2, 2)

        # classifier head; 4 * 4 * 50 matches 50 maps of 4x4, which is what
        # a 28x28 input would shrink to (28 -> 24 -> 12 -> 8 -> 4) --
        # assumes 28x28 inputs, TODO confirm
        net["flat_"] = CplxToCplx[torch.nn.Flatten](-3, -1)
        net["lin_1"] = self.Linear(4 * 4 * 50, 500)
        net["relu3"] = CplxToCplx[torch.nn.ReLU]()
        net["lin_2"] = self.Linear(500, 10)

        # presumably keeps the real part of the complex output as the logits
        net["real"] = CplxReal()

        # Disabled alternative head, kept for reference:
        #   ("real", CplxToConcatenatedReal(dim=-1)),
        #   ("lin_3", torch.nn.Linear(20, 10)),

        super().__init__(net)


class MNISTModelARD(MNISTModel):
    """``MNISTModel`` built from the ``cplxmodule.nn.relevance`` ARD layers.

    Only the layer classes differ; the architecture is inherited unchanged.
    """

    Linear = CplxLinearARD
    Conv2d = CplxConv2dARD


class MNISTModelMasked(MNISTModel):
    """``MNISTModel`` built from the ``cplxmodule.nn.masked`` layers.

    Only the layer classes differ; the architecture is inherited unchanged.
    """

    Linear = CplxLinearMasked
    Conv2d = CplxConv2dMasked


<br>

In [None]:
# Build the datasets and the batch feeds by hand, outside of `auto.run`.
# NOTE(review): the device is hard-coded to CUDA index 1.
devtype = dict(device=torch.device("cuda:1"))
datasets = auto.get_datasets({
    'mnist-train': {'cls': "<class 'cplxpaper.mnist.dataset.MNISTTrain'>", 'root': './mnist/data'},
  'mnist-test': {'cls': "<class 'cplxpaper.mnist.dataset.MNISTTest'>", 'root': './mnist/data'}
})
# Arguments: datasets, device spec, feature spec, then one loader spec per feed.
# NOTE(review): the 'train' feed reads the 'mnist-test' dataset and the
# 'test' feed reads 'mnist-train' -- confirm the swap is intentional.
feeds = auto.get_feeds(
    datasets, devtype,
    {
        'cls': "<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
        "signal_ndim": 2,
        "cplx": True,
        "shift": True
    }, {
    'train': {
        'cls': "<class 'torch.utils.data.dataloader.DataLoader'>",
        'dataset': 'mnist-test',
        'batch_size': 128,
        'shuffle': True,
        'pin_memory': True,
        'n_batches': -1
    },
    'test': {
        'cls': "<class 'torch.utils.data.dataloader.DataLoader'>",
        'dataset': 'mnist-train',
        'batch_size': 128,
        'shuffle': False,
        'pin_memory': True,
        'n_batches': -1
    }
})

In [None]:
# Instantiate the network and move it to the target device. The class
# defined in this notebook is `MNISTModel`; no `CplxMNISTModel` is defined or
# imported in the visible cells, so that name raised a NameError.
model = MNISTModel().to(**devtype)

In [None]:
# Feed used by the evaluation cells below.
feed = feeds['test']

In [None]:
# Pull a single (inputs, targets) batch from the 'train' feed.
bx, by = next(iter(feeds["train"]))

In [None]:
# Forward pass on one batch; the notebook displays the output.
model(bx)

In [None]:
from cplxpaper.auto.feeds import feed_forward_pass

# Run the model over the whole feed, then transpose the per-batch results and
# concatenate each component into one array.
# NOTE(review): assumes `feed_forward_pass` yields (logits, targets) pairs
# of array-likes -- confirm against cplxpaper.auto.feeds.
feed_pred = feed_forward_pass(feed, model)
logits, y_true = map(np.concatenate, zip(*feed_pred))

In [None]:
# Index of the largest logit along the last axis (presumably the predicted
# class of each sample).
logits.argmax(-1)

In [None]:
import tqdm

In [None]:
# NOTE(review): `predict` is not defined or imported in the visible cells;
# it has to come from an earlier notebook state for this cell to run.
y_true, y_pred, logits = predict(model, tqdm.tqdm(feeds['test']))

In [None]:
from cplxmodule.utils.stats import sparsity

In [None]:
# Per-class hit / miss counts derived from the confusion matrix.
# NOTE(review): `confusion_matrix` is not imported in the visible cells; the
# axis semantics below assume scikit-learn's convention, where cm[i, j]
# counts the samples of true class i that were predicted as class j.
cm = confusion_matrix(y_true, y_pred)
tp = cm.diagonal()

# Column sums count the predictions of each class, so removing the hits
# leaves the false positives; row sums count the samples of each true class,
# so removing the hits leaves the false negatives. The original had the two
# names swapped.
fp, fn = cm.sum(axis=0) - tp, cm.sum(axis=1) - tp

f_sparsity = sparsity(model, hard=True, threshold=-0.5)

In [None]:
# Overall accuracy: diagonal (correct) count over the total count.
tp.sum() / cm.sum()

In [None]:
import os

In [None]:
import torch

In [None]:
from cplxpaper.auto import auto

In [None]:
from cplxpaper.mnist.dataset import MNISTTrain

In [None]:
# Stand-alone manifest for exercising the dataset / feed / objective builders.
# NOTE(review): the "mnist-*" dataset entries currently point at the CIFAR10
# dataset classes (the MNIST ones are commented out) while `root` still
# points at ./mnist/data -- confirm this is intended.
# NOTE(review): each feed reads the opposite split ("mnist-train" feed ->
# "mnist-test" dataset and vice versa) -- confirm the swap is intentional.
options = {
    "datasets": {
        "mnist-train": {
#             "cls": "<class 'cplxpaper.mnist.dataset.MNISTTrain'>",
            "cls": "<class 'cplxpaper.cifar.dataset.CIFAR10Train'>",
            "root": os.path.abspath("./mnist/data")
        },
        "mnist-test": {
#             "cls": "<class 'cplxpaper.mnist.dataset.MNISTTest'>",
            "cls": "<class 'cplxpaper.cifar.dataset.CIFAR10Test'>",
            "root": os.path.abspath("./mnist/data")
        },
#         "musicnet-valid": {
#             "cls": "<class 'cplxpaper.musicnet.dataset.MusicNetRAM'>",
#             "filename": "./musicnet/data/musicnet_11khz_valid.h5",
#             "window": 4096,
#             "stride": 1
#         }
    },
    "features": {
        # The module is `cplxpaper.auto.feeds` (plural), as in every other
        # reference in this notebook; the singular `feed` path did not match.
        "cls": "<class 'cplxpaper.auto.feeds.FeedFourierFeatures'>",
        "signal_ndim": 2,
        "cplx": True,
        "shift": True
    },
    "feeds": {
        "mnist-train": {
            "cls": "<class 'torch.utils.data.dataloader.DataLoader'>",
            "dataset": "mnist-test",
            "batch_size": 128,
            "shuffle": True,
            "pin_memory": True,
            "n_batches": -1,
        },
        "mnist-test": {
            "cls": "<class 'torch.utils.data.dataloader.DataLoader'>",
            "dataset": "mnist-train",
            "batch_size": 128,
            "shuffle": True,
            "pin_memory": True,
            "n_batches": -1,
        },
#         "musicnet-valid": {
#             "cls": "<class 'cplxpaper.musicnet.dataset.MusicNetDataLoader'>",
#             "dataset": "musicnet-valid",
#             "pin_memory": True,
#             "n_batches": 20
#         }
    },
    "objective_terms": {
        "loss": {
            "cls": "<class 'torch.nn.modules.loss.CrossEntropyLoss'>",
            "reduction": "mean"
        },
        "kl_div": {
            "cls": "<class 'cplxpaper.auto.objective.ARDPenaltyObjective'>",
            "reduction": "mean",
            "coef": 1.0
        }
    },
    "scorers": {},
}

In [None]:
# Build every component from the manifest by hand, one builder at a time.
# NOTE(review): the device is hard-coded to CUDA index 1.
devtype = dict(device=torch.device("cuda:1"))

datasets = auto.get_datasets(options["datasets"])
feeds = auto.get_feeds(datasets, devtype, options["features"], options["feeds"])
objective_terms = auto.get_objective_terms(datasets, options["objective_terms"])

# `options["scorers"]` is empty in the manifest above.
scorers = auto.get_scorers(feeds, options["scorers"])

In [None]:
# Peek at a single batch. The "musicnet-valid" feed is commented out of
# `options["feeds"]`, so index one of the feeds that is actually configured.
bx, by = next(iter(feeds["mnist-train"]))

In [None]:
# Display the batch inputs.
bx

In [None]:
# [*map(tqdm.tqdm._decr_instances, list(tqdm.tqdm._instances))]

In [None]:
import tqdm

# Dry run over the whole feed under a progress bar; the batches are discarded,
# so only the iteration itself is exercised.
for bx, by in tqdm.tqdm(feeds["mnist-test"]):
    pass

In [None]:
# Pull one batch from `feed` (bound in an earlier cell).
bx, by = next(iter(feed))

In [None]:
# Random input for the FFT timing cell below.
# NOTE(review): the trailing axis of size 2 presumably holds the re/im parts
# expected by the legacy `torch.fft` function -- confirm.
z = torch.randn(320, 1, 28, 28, 2).cpu()

In [None]:
%%timeit -n 2000
# NOTE(review): `torch.fft(input, signal_ndim)` is the legacy function API;
# newer PyTorch releases replace it with the `torch.fft` module.
torch.fft(z, signal_ndim=2)
# Wait for any pending CUDA work before the timer stops.
# NOTE(review): if `z` is the CPU tensor from the previous cell, this call
# does not wait on the transform itself.
torch.cuda.synchronize()