In [7]:
!pip install sacred

Collecting sacred
[?25l  Downloading https://files.pythonhosted.org/packages/f4/8c/b99f668e8ca9747dcd374bb46cac808e58f3cb8e446df1b3e667f6be9778/sacred-0.8.2-py2.py3-none-any.whl (106kB)
[K     |████████████████████████████████| 112kB 5.0MB/s 
Collecting py-cpuinfo>=4.0
[?25l  Downloading https://files.pythonhosted.org/packages/e6/ba/77120e44cbe9719152415b97d5bfb29f4053ee987d6cb63f55ce7d50fadc/py-cpuinfo-8.0.0.tar.gz (99kB)
[K     |████████████████████████████████| 102kB 4.9MB/s 
Collecting GitPython
[?25l  Downloading https://files.pythonhosted.org/packages/bc/91/b38c4fabb6e5092ab23492ded4f318ab7299b19263272b703478038c0fbc/GitPython-3.1.18-py3-none-any.whl (170kB)
[K     |████████████████████████████████| 174kB 9.0MB/s 
Collecting munch<3.0,>=2.0.2
  Downloading https://files.pythonhosted.org/packages/cc/ab/85d8da5c9a45e072301beb37ad7f833cd344e04c817d97e0cc75681d248f/munch-2.5.0-py2.py3-none-any.whl
Collecting jsonpickle<2.0,>=1.2
  Downloading https://files.pythonhosted.org/pack

## Data loader

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import numpy as np
from tqdm import tqdm

import sys

sys.path.append("src/")

# import utils


def get_MNIST_loaders(batch_size, shuffle=False, train_batch=None, test_batch=None):
    """Build the MNIST train and test loaders.

    Args:
        batch_size: Batch size used for any split without its own override.
        shuffle: Forwarded unchanged to both loaders.
        train_batch: Optional batch size for the training split only.
        test_batch: Optional batch size for the test split only.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Identity comparison: an override counts as missing only when it is
    # literally None (``== None`` would invoke arbitrary ``__eq__`` overloads).
    train_size = batch_size if train_batch is None else train_batch
    test_size = batch_size if test_batch is None else test_batch

    train_loader = get_MNIST_loader(train_size, trainable=True, shuffle=shuffle)
    test_loader = get_MNIST_loader(test_size, trainable=False, shuffle=shuffle)
    return train_loader, test_loader


def get_MNIST_loader(batch_size, trainable=True, shuffle=False):
    """Return a DataLoader over one MNIST split stored under ``../data``.

    ``trainable`` is passed as the dataset's ``train=`` flag, the dataset is
    created with ``download=True``, and images go through ``ToTensor``.
    """
    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    mnist_split = torchvision.datasets.MNIST(
        "../data", train=trainable, download=True, transform=to_tensor
    )
    return torch.utils.data.DataLoader(
        mnist_split, batch_size=batch_size, shuffle=shuffle
    )


class SparseVectorDataset(Dataset):
    """Dataset of synthetic sparse vectors produced by ``generate_sparse_samples``.

    ``self.samples`` holds one sample per column, so the dataset length is the
    number of columns.
    """

    def __init__(self, n, dim, ones, transform=None, seed=None):
        # Arguments are forwarded positionally to generate_sparse_samples.
        self.transform = transform
        self.samples = generate_sparse_samples(n, dim, ones, seed)

    def __len__(self):
        num_samples = self.samples.shape[1]
        return num_samples

    def __getitem__(self, idx):
        # Column ``idx`` reshaped to (len, 1, 1) before the optional transform.
        column = self.samples[:, idx].reshape(-1, 1, 1)
        if not self.transform:
            return column
        return self.transform(column).float()


class SparseCompImageDataset(Dataset):
    """Dataset pairing ``phi.T @ img`` samples with their ``img`` targets.

    Built from sparse codes: ``comp_img = real_H @ codes``,
    ``img = phi @ comp_img`` and ``samples = phi.T @ img``. All arrays keep one
    sample per column.
    """

    def __init__(self, n, dim, ones, real_H, phi, transform=None):
        codes = generate_sparse_samples(n, dim, ones)
        self.sparse_vectors = codes
        print(codes.shape)
        self.comp_img = np.dot(real_H, codes)
        self.img = np.dot(phi, self.comp_img)
        self.samples = np.dot(phi.T, self.img)
        self.transform = transform

    def __len__(self):
        num_samples = self.samples.shape[1]
        return num_samples

    def __getitem__(self, idx):
        # Both items are column ``idx`` reshaped to (len, 1, 1).
        sample = self.samples[:, idx].reshape(-1, 1, 1)
        img = self.img[:, idx].reshape(-1, 1, 1)
        if not self.transform:
            return sample, img
        return self.transform(sample).float(), self.transform(img).float()


class EncodingDataset(Dataset):
    """Dataset of encodings obtained by running ``net`` over ``data_loader``.

    Every batch of the loader is reshaped to ``(-1, net.D_org, 1)`` and passed
    through ``net``; the second element of the network output is stored as the
    encoding, together with the batch labels.

    Args:
        data_loader: Iterable yielding ``(img, label)`` batches.
        net: Callable returning a 3-tuple whose second item is the encoding.
            It must expose ``D_org``, ``D_enc`` and ``phi``.
        device: Device the images are moved to before encoding.
        transform: Optional callable applied to each sample in ``__getitem__``.
        seed: Unused; kept for signature compatibility.
    """

    def __init__(self, data_loader, net, device=None, transform=None, seed=None):
        encodings = []
        labels = []
        print("create encoding dataset.")

        # A 3-D phi is a stack of projections selected per batch; a 2-D phi is
        # a single projection, in which case the net is called without an index
        # (previously the index variable was simply unbound in that case).
        multi_phi = len(net.phi.size()) == 3

        # Encodings are data here, not part of a training graph: disable
        # autograd so stored samples do not keep computation graphs alive.
        with torch.no_grad():
            for idx, (img, c) in tqdm(enumerate(data_loader)):
                img = img.to(device)
                img = img.view(-1, net.D_org, 1)

                if multi_phi:
                    _, enc, _ = net((idx % net.phi.size(0), img))
                else:
                    _, enc, _ = net(img)

                encodings.append(enc.detach())
                labels.append(c)

        self.samples = torch.cat(encodings)
        self.c = torch.cat(labels)
        self.D_enc = net.D_enc
        self.transform = transform

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, idx):
        sample = self.samples[idx].reshape(-1, self.D_enc, 1)

        if self.transform:
            sample = self.transform(sample).float()

        return sample, self.c[idx]


def generate_sparse_samples(n, dim, ones, seed=None, unif=True):
    """Generate ``n`` random sparse vectors of length ``dim``.

    Each vector has exactly ``ones`` non-zero entries at positions drawn
    without replacement. With ``unif=True`` the magnitudes are uniform in
    [4, 5) with a random sign; otherwise every non-zero entry is +1 or -1.

    A private ``RandomState`` is used instead of reseeding NumPy's global
    generator, so calling this function no longer disturbs (or, with
    ``seed=None``, silently re-randomises) the caller's global RNG state.
    For a given ``seed`` the draws are the same as with the former
    ``np.random.seed(seed)`` approach.

    Args:
        n: Number of samples.
        dim: Length of each sample.
        ones: Number of non-zero entries per sample.
        seed: Seed for the private generator; ``None`` seeds from OS entropy.
        unif: Selects the amplitude distribution (see above).

    Returns:
        ``numpy.ndarray`` of shape ``(dim, n)``, one sample per column.
    """
    rng = np.random.RandomState(seed)
    samples = np.zeros((n, dim))
    for row in samples:
        ind = rng.choice(dim, ones, replace=False)
        # Draw order (support, magnitudes, signs) is kept so seeded output
        # matches the previous implementation.
        if unif:
            # draws amplitude from [-5,-4] U [4,5] uniformly
            magnitudes = rng.uniform(4, 5, ones)
        else:
            # amplitude is 1 or -1 .5 prob of each
            magnitudes = np.ones(ones)
        signs = (rng.uniform(0, 1, ones) > .5) * 2 - 1
        row[ind] = magnitudes * signs
    return samples.T


def generate_sparse_phi(sparsity, num_phi, D_enc, D_img):
    """Return ``num_phi`` sparse +/-1 matrices stacked into one float tensor.

    Each matrix comes from ``generate_sparse_samples(D_img, D_enc, sparsity,
    unif=False)`` and is transposed, giving shape ``(D_img, D_enc)`` with
    ``sparsity`` non-zero entries per row. The result has shape
    ``(num_phi, D_img, D_enc)``.
    """
    matrices = []
    for _ in range(num_phi):
        entries = generate_sparse_samples(D_img, D_enc, sparsity, unif=False)
        matrices.append(torch.tensor(entries).float().t())
    return torch.stack(matrices)


def generate_simulated_data(hyp):
    """Create the simulated-data setup described by ``hyp``.

    Reads ``seed``, ``D_enc``, ``D_org``, ``D_comp``, ``sparsity``,
    ``randomness``, ``num_phis``, ``num_nonzero``, ``num_samples`` and
    ``batch_size`` from ``hyp``.

    Returns:
        ``(real_H, H_init, phis, data_loader)``: a random normalised
        dictionary, a noisy normalised copy of it, the sparse projections from
        ``generate_sparse_phi`` and a loader over a ``SparseVectorDataset``.
    """
    # NOTE(review): relies on a module-level `utils`; its import is commented
    # out in this cell -- confirm it is available at call time.
    keys = (
        "seed",
        "D_enc",
        "D_org",
        "D_comp",
        "sparsity",
        "randomness",
        "num_phis",
        "num_nonzero",
        "num_samples",
        "batch_size",
    )
    (
        seed,
        D_enc,
        D_org,
        D_comp,
        sparsity,
        randomness,
        num_phis,
        num_nonzero,
        num_samples,
        batch_size,
    ) = (hyp[key] for key in keys)

    torch.manual_seed(seed)

    # Ground-truth dictionary, then a perturbed copy used as initialisation.
    real_H = utils.normalize(torch.randn(D_org, D_enc)).float()
    perturbation = utils.normalize(torch.randn(D_org, D_enc)) * randomness
    blended = real_H * (1 - randomness) + perturbation
    H_init = utils.normalize(blended)

    phis = generate_sparse_phi(sparsity, num_phis, D_org, D_comp)

    to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
    codes = SparseVectorDataset(
        num_samples, D_enc, num_nonzero, transform=to_tensor, seed=seed
    )
    data_loader = DataLoader(codes, batch_size=batch_size)

    return real_H, H_init, phis, data_loader


def get_encoding_loaders(train_loader, test_loader, net, hyp):
    """Encode both loaders with ``net`` and wrap the results in DataLoaders.

    Uses ``hyp["device"]`` while encoding and ``hyp["batch_size"]`` for the
    returned loaders. Returns ``(enc_tr_loader, enc_te_loader)``.
    """
    # Train split is encoded first, then the test split.
    encoded_sets = [
        EncodingDataset(loader, net, hyp["device"])
        for loader in (train_loader, test_loader)
    ]
    enc_tr_loader, enc_te_loader = [
        DataLoader(encoded, batch_size=hyp["batch_size"]) for encoded in encoded_sets
    ]
    return enc_tr_loader, enc_te_loader

## CRsAE Model

In [9]:
import torch
import torch.nn.functional as F
import numpy as np

# import utils


class CRsAEDense(torch.nn.Module):
    """Dense sparse auto-encoder with a single learnable dictionary ``H``.

    The encoder runs ``T`` FISTA iterations (soft-thresholding with momentum)
    using step size ``1 / L`` and threshold ``lam / L``; the decoder is
    ``H @ code``.

    Args:
        hyp: Dict providing ``num_iters``, ``L``, ``lam``, ``D_in``, ``D_enc``
            and ``device``.
        H: Optional initial dictionary of shape ``(D_in, D_enc)``. When
            omitted, a random dictionary with unit-norm columns is drawn.
    """

    def __init__(self, hyp, H=None):
        super(CRsAEDense, self).__init__()

        self.T = hyp["num_iters"]
        self.L = hyp["L"]
        self.lam = hyp["lam"]
        self.D_in = hyp["D_in"]
        self.D_enc = hyp["D_enc"]
        self.device = hyp["device"]

        if H is None:
            H = F.normalize(torch.randn(self.D_in, self.D_enc), dim=0)

        # Move the raw tensor first and wrap it afterwards. Calling ``.to`` on
        # an already-registered Parameter returns a plain Tensor whenever a
        # copy is made (e.g. CPU -> CUDA), and assigning that back to a
        # parameter attribute raises TypeError.
        self.H = torch.nn.Parameter(H.detach().to(self.device))

        self.relu = torch.nn.ReLU()

    def normalize(self):
        """Rescale every column of ``H`` to unit L2 norm (in place on ``.data``)."""
        self.H.data = F.normalize(self.H.data, dim=0)

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Input batch; the matmuls below expect shape ``(batch, D_in, 1)``.

        Returns:
            ``(z, code)``: reconstruction ``H @ code`` and the sparse code of
            shape ``(batch, D_enc, 1)``.
        """
        num_batches = x.shape[0]

        x_old = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        yk = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        x_new = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        t_old = torch.tensor(1.0, device=self.device)
        threshold = self.lam / self.L

        for _ in range(self.T):
            # Gradient step on the reconstruction error at the momentum point.
            residual = x - torch.matmul(self.H, yk.reshape(-1, self.D_enc, 1))
            x_new = yk + torch.matmul(torch.t(self.H), residual) / self.L
            # Soft-thresholding (shrinkage) written with ReLU.
            x_new = self.relu(torch.abs(x_new) - threshold) * torch.sign(x_new)

            # FISTA momentum update.
            t_new = (1 + torch.sqrt(1 + 4 * t_old * t_old)) / 2
            yk = x_new + (t_old - 1) / t_new * (x_new - x_old)

            x_old = x_new
            t_old = t_new

        z = torch.matmul(self.H, x_new)

        return z, x_new


class CRsAERandProj(torch.nn.Module):
    """Sparse auto-encoder operating on randomly projected data.

    The effective dictionary is ``phi @ H`` where ``H`` is learnable and
    ``phi`` is a fixed projection (or a stack of projections, selected by
    index). Encoding runs ``T`` FISTA iterations with step ``1 / L`` and
    threshold ``lam / L``.

    Args:
        hyp: Dict providing ``num_iters``, ``L``, ``lam``, ``D_in``, ``D_org``,
            ``D_enc`` and ``device``.
        H: Optional initial dictionary ``(D_org, D_enc)``; random unit-norm
            columns when omitted.
        phi: Optional projection(s), ``(D_in, D_org)`` or
            ``(num_phi, D_in, D_org)``; a single random projection with
            unit-norm columns when omitted.
    """

    def __init__(self, hyp, H=None, phi=None):
        super(CRsAERandProj, self).__init__()

        self.T = hyp["num_iters"]
        self.L = hyp["L"]
        self.lam = hyp["lam"]
        self.D_in = hyp["D_in"]
        self.D_org = hyp["D_org"]
        self.D_enc = hyp["D_enc"]
        self.device = hyp["device"]
        self.eval_mode = False
        # Filled in by the training loops with the best dictionary seen so far.
        self.bestH = None

        if H is None:
            H = F.normalize(torch.randn(self.D_org, self.D_enc), dim=0)

        if phi is None:
            # Normalise each column (dim=1 of the stacked tensor). Normalising
            # over dim=0, which has size 1, would turn every entry into +/-1.
            phi = F.normalize(torch.randn(1, self.D_in, self.D_org), dim=1)

        # Move the raw tensors first, then register them: ``Parameter.to``
        # returns a plain Tensor whenever it copies (e.g. to CUDA), and that
        # cannot be assigned back onto a registered parameter attribute.
        self.H = torch.nn.Parameter(H.detach().to(self.device))
        self.phi = torch.nn.Parameter(
            phi.detach().to(self.device), requires_grad=False
        )

        self.relu = torch.nn.ReLU()

    def normalize(self):
        """Rescale every column of ``H`` to unit L2 norm (in place on ``.data``)."""
        self.H.data = F.normalize(self.H.data, dim=0)

    def forward(self, x):
        """Encode and reconstruct projected data.

        Args:
            x: Either a batch expected as ``(batch, D_in, 1)`` or a tuple
                ``(i, batch)`` where ``i`` selects ``phi[i]``.

        Returns:
            ``(z, code)``: reconstruction ``phi @ H @ code`` and the sparse
            code of shape ``(batch, D_enc, 1)``.
        """
        # if testing use the H with the lowest err_H; before any training has
        # recorded one, fall back to the current H instead of failing.
        if self.eval_mode and self.bestH is not None:
            H = self.bestH.to(self.device)
        else:
            H = self.H

        # for multiple phi use ith phi for image x
        if isinstance(x, tuple):
            i, x = x
            phiH = torch.matmul(self.phi[i], H)
        else:
            phiH = torch.matmul(self.phi, H)

        phiH = phiH.to(self.device)
        # transpose(-2, -1) equals torch.t for 2-D input and also handles a
        # stacked (3-D) phi used without an index.
        phiH_t = phiH.transpose(-2, -1)

        num_batches = x.shape[0]

        x_old = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        yk = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        x_new = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        t_old = torch.tensor(1.0, device=self.device)
        threshold = self.lam / self.L

        for _ in range(self.T):
            residual = x - torch.matmul(phiH, yk.reshape(-1, self.D_enc, 1))
            x_new = yk + torch.matmul(phiH_t, residual) / self.L
            # Soft-thresholding (shrinkage) written with ReLU.
            x_new = self.relu(torch.abs(x_new) - threshold) * torch.sign(x_new)

            # FISTA momentum update.
            t_new = (1 + torch.sqrt(1 + 4 * t_old * t_old)) / 2
            yk = x_new + (t_old - 1) / t_new * (x_new - x_old)

            x_old = x_new
            t_old = t_new

        z = torch.matmul(phiH, x_new)

        return z, x_new


class CRsAERandProjClassifier(torch.nn.Module):
    """Random-projection sparse encoder followed by a linear 10-way classifier.

    The input is first projected with ``phi``; the sparse code of the
    projection (w.r.t. the dictionary ``phi @ H``) is computed with ``T``
    FISTA iterations and fed to ``self.classifier``.

    Args:
        hyp: Dict providing ``num_iters``, ``L``, ``lam``, ``D_in``, ``D_org``,
            ``D_enc`` and ``device``.
        H: Optional initial dictionary ``(D_org, D_enc)``; random unit-norm
            columns when omitted.
        phi: Optional projection(s), ``(D_in, D_org)`` or
            ``(num_phi, D_in, D_org)``; a single random projection with
            unit-norm columns when omitted.
    """

    def __init__(self, hyp, H=None, phi=None):
        super(CRsAERandProjClassifier, self).__init__()

        self.T = hyp["num_iters"]
        self.L = hyp["L"]
        self.lam = hyp["lam"]
        self.D_in = hyp["D_in"]
        self.D_org = hyp["D_org"]
        self.D_enc = hyp["D_enc"]
        self.device = hyp["device"]
        self.eval_mode = False
        # Filled in by the training loops with the best dictionary seen so far.
        self.bestH = None

        if H is None:
            H = F.normalize(torch.randn(self.D_org, self.D_enc), dim=0)

        if phi is None:
            # Normalise each column (dim=1 of the stacked tensor). Normalising
            # over dim=0, which has size 1, would turn every entry into +/-1.
            phi = F.normalize(torch.randn(1, self.D_in, self.D_org), dim=1)

        # Move the raw tensors first, then register them: ``Parameter.to``
        # returns a plain Tensor whenever it copies (e.g. to CUDA), and that
        # cannot be assigned back onto a registered parameter attribute.
        self.H = torch.nn.Parameter(H.detach().to(self.device))
        self.phi = torch.nn.Parameter(
            phi.detach().to(self.device), requires_grad=False
        )

        self.relu = torch.nn.ReLU()
        self.classifier = torch.nn.Linear(self.D_enc, 10).to(self.device)

    def normalize(self):
        """Rescale every column of ``H`` to unit L2 norm (in place on ``.data``)."""
        self.H.data = F.normalize(self.H.data, dim=0)

    def forward(self, x):
        """Classify ``x`` from its sparse code.

        Args:
            x: Either a batch expected as ``(batch, D_org, 1)`` or a tuple
                ``(i, batch)`` where ``i`` selects ``phi[i]``.

        Returns:
            Classifier output of shape ``(batch, 10)``.
        """
        # if testing use the H with the lowest err_H; before any training has
        # recorded one, fall back to the current H instead of failing.
        if self.eval_mode and self.bestH is not None:
            H = self.bestH.to(self.device)
        else:
            H = self.H

        # for multiple phi use ith phi for image x
        if isinstance(x, tuple):
            i, x = x
            phiH = torch.matmul(self.phi[i], H)
            x = torch.matmul(self.phi[i], x)
        else:
            phiH = torch.matmul(self.phi, H)
            x = torch.matmul(self.phi, x)

        phiH = phiH.to(self.device)
        # transpose(-2, -1) equals torch.t for 2-D input and also handles a
        # stacked (3-D) phi used without an index.
        phiH_t = phiH.transpose(-2, -1)

        num_batches = x.shape[0]

        x_old = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        yk = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        x_new = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        t_old = torch.tensor(1.0, device=self.device)
        threshold = self.lam / self.L

        for _ in range(self.T):
            residual = x - torch.matmul(phiH, yk.view(-1, self.D_enc, 1))
            x_new = yk + torch.matmul(phiH_t, residual) / self.L
            # Soft-thresholding (shrinkage) written with ReLU.
            x_new = self.relu(torch.abs(x_new) - threshold) * torch.sign(x_new)

            # FISTA momentum update.
            t_new = (1 + torch.sqrt(1 + 4 * t_old * t_old)) / 2
            yk = x_new + (t_old - 1) / t_new * (x_new - x_old)

            x_old = x_new
            t_old = t_new

        return self.classifier(x_new.view(-1, self.D_enc))


class CRsAERandProjAeClassifier(torch.nn.Module):
    """Random-projection sparse auto-encoder with an attached linear classifier.

    In normal mode the forward pass returns the reconstruction, the sparse
    code and the classifier output. With ``encoding_mode = True`` the input is
    treated as an already computed code and only the classifier is applied.

    Args:
        hyp: Dict providing ``num_iters``, ``L``, ``lam``, ``D_in``, ``D_org``,
            ``D_enc`` and ``device``.
        H: Optional initial dictionary ``(D_org, D_enc)``; random unit-norm
            columns when omitted.
        phi: Optional projection(s), ``(D_in, D_org)`` or
            ``(num_phi, D_in, D_org)``; a single random projection with
            unit-norm columns when omitted.
    """

    def __init__(self, hyp, H=None, phi=None):
        super(CRsAERandProjAeClassifier, self).__init__()

        self.T = hyp["num_iters"]
        self.L = hyp["L"]
        self.lam = hyp["lam"]
        self.D_in = hyp["D_in"]
        self.D_org = hyp["D_org"]
        self.D_enc = hyp["D_enc"]
        self.device = hyp["device"]
        self.eval_mode = False
        # Filled in by the training loops with the best dictionary seen so far.
        self.bestH = None

        if H is None:
            H = F.normalize(torch.randn(self.D_org, self.D_enc), dim=0)

        if phi is None:
            # Normalise each column (dim=1 of the stacked tensor). Normalising
            # over dim=0, which has size 1, would turn every entry into +/-1.
            phi = F.normalize(torch.randn(1, self.D_in, self.D_org), dim=1)

        # Move the raw tensors first, then register them: ``Parameter.to``
        # returns a plain Tensor whenever it copies (e.g. to CUDA), and that
        # cannot be assigned back onto a registered parameter attribute.
        self.H = torch.nn.Parameter(H.detach().to(self.device))
        self.phi = torch.nn.Parameter(
            phi.detach().to(self.device), requires_grad=False
        )

        self.relu = torch.nn.ReLU()
        # Keep the classifier on the same device as H and phi.
        self.classifier = torch.nn.Linear(self.D_enc, 10).to(self.device)
        self.encoding_mode = False

    def normalize(self):
        """Rescale every column of ``H`` to unit L2 norm (in place on ``.data``)."""
        self.H.data = F.normalize(self.H.data, dim=0)

    def forward(self, x):
        """Run the auto-encoder (or only the classifier in encoding mode).

        Args:
            x: In encoding mode a tuple ``(i, code)`` (``i`` is ignored).
                Otherwise a batch expected as ``(batch, D_org, 1)`` or a tuple
                ``(i, batch)`` where ``i`` selects ``phi[i]``.

        Returns:
            Encoding mode: classifier output ``(batch, 10)``.
            Otherwise: ``(z, code, logits)``.
        """
        if self.encoding_mode:
            i, x = x
            return self.classifier(x.view(-1, self.D_enc))

        # if testing use the H with the lowest err_H; before any training has
        # recorded one, fall back to the current H instead of failing.
        if self.eval_mode and self.bestH is not None:
            H = self.bestH.to(self.device)
        else:
            H = self.H

        # for multiple phi use ith phi for image x
        if isinstance(x, tuple):
            i, x = x
            phiH = torch.matmul(self.phi[i], H)
            x = torch.matmul(self.phi[i], x)
        else:
            phiH = torch.matmul(self.phi, H)
            x = torch.matmul(self.phi, x)

        phiH = phiH.to(self.device)
        # transpose(-2, -1) equals torch.t for 2-D input and also handles a
        # stacked (3-D) phi used without an index.
        phiH_t = phiH.transpose(-2, -1)

        num_batches = x.shape[0]

        x_old = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        yk = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        x_new = torch.zeros(num_batches, self.D_enc, 1, device=self.device)
        t_old = torch.tensor(1.0, device=self.device)
        threshold = self.lam / self.L

        for _ in range(self.T):
            residual = x - torch.matmul(phiH, yk.view(-1, self.D_enc, 1))
            x_new = yk + torch.matmul(phiH_t, residual) / self.L
            # Soft-thresholding (shrinkage) written with ReLU.
            x_new = self.relu(torch.abs(x_new) - threshold) * torch.sign(x_new)

            # FISTA momentum update.
            t_new = (1 + torch.sqrt(1 + 4 * t_old * t_old)) / 2
            yk = x_new + (t_old - 1) / t_new * (x_new - x_old)

            x_old = x_new
            t_old = t_new

        z = torch.matmul(phiH, x_new)
        return (z, x_new, self.classifier(x_new.view(-1, self.D_enc)))

## Train

In [12]:
import torch
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pickle
import os
from datetime import datetime
from sacred import Experiment

import sys

sys.path.append("src/")

# import model, generator, trainer, utils, conf

# from conf import config_ingredient

import warnings

warnings.filterwarnings("ignore")

# NOTE(review): `config_ingredient`, `generator`, `model` and `trainer` are
# referenced below, but their imports are commented out in this cell.
ex = Experiment("train", ingredients=[config_ingredient])


@ex.automain
def run(cfg):
    """Train an auto-encoder (and optionally a classifier) as configured.

    Reads the hyper-parameter dict from ``cfg["hyp"]``, creates a timestamped
    results directory under ``../results/<experiment_name>/``, stores the
    hyper-parameters and the initial dictionary there, trains the selected
    network and, when ``hyp["classification"]`` is set, trains and evaluates
    the classifier on the learned encodings.

    Raises:
        ValueError: If ``hyp["dataset"]`` or ``hyp["network"]`` is not one of
            the supported names.
    """

    hyp = cfg["hyp"]

    print(hyp)

    random_date = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

    PATH = "../results/{}/{}".format(hyp["experiment_name"], random_date)
    os.makedirs(PATH)

    filename = os.path.join(PATH, "hyp.pickle")
    with open(filename, "wb") as file:
        pickle.dump(hyp, file)

    print("load data.")
    if hyp["dataset"] == "MNIST":
        train_loader, test_loader = generator.get_MNIST_loaders(
            hyp["batch_size"], shuffle=hyp["shuffle"]
        )
        phis = F.normalize(
            torch.randn(hyp["num_phis"], hyp["D_comp"], hyp["D_org"]), dim=1
        )
        H_init = None
    elif hyp["dataset"] == "simulated":
        real_H, H_init, phis, train_loader = generator.generate_simulated_data(hyp)
    else:
        # Fail here with a clear message instead of printing and crashing
        # later on an unbound `train_loader`.
        raise ValueError("ERROR: dataset loader is not implemented.")

    print("create model.")
    if hyp["network"] == "CRsAEDense":
        net = model.CRsAEDense(hyp, H_init)
    elif hyp["network"] == "CRsAERandProj":
        net = model.CRsAERandProj(hyp, H_init, phis)
    elif hyp["network"] == "CRsAERandProjClassifier":
        net = model.CRsAERandProjClassifier(hyp, H_init, phis)
    elif hyp["network"] == "CRsAERandProjAeClassifier":
        net = model.CRsAERandProjAeClassifier(hyp, H_init, phis)
    else:
        # Same reasoning as above: `net` would be unbound below.
        raise ValueError("model does not exist!")

    torch.save(net.H, os.path.join(PATH, "H_init.pt"))

    criterion = torch.nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=hyp["lr"], eps=1e-3)

    if hyp["classification"]:
        net.H.requires_grad = True
        # `requires_grad` set as an attribute on a Module has no effect; the
        # flag has to be applied to the module's parameters.
        net.classifier.requires_grad_(False)

    print("train auto-encoder.")
    if hyp["dataset"] == "simulated":
        if hyp["network"] == "CRsAEDense":
            err = trainer.train_ae_simulated(
                net, train_loader, hyp, criterion, optimizer, real_H, PATH
            )
        elif hyp["network"] == "CRsAERandProj":
            err = trainer.train_randproj_ae_simulated(
                net, train_loader, hyp, criterion, optimizer, real_H, phis, PATH
            )

    else:
        err = trainer.train_ae(net, train_loader, hyp, criterion, optimizer, PATH)

    if hyp["classification"]:
        # Freeze the dictionary and train only the classifier.
        net.H.requires_grad = False
        net.classifier.requires_grad_(True)

        optimizer.zero_grad()
        # NOTE(review): `test_loader` is only created for the MNIST dataset.
        enc_tr_loader, enc_te_loader = generator.get_encoding_loaders(
            train_loader, test_loader, net, hyp
        )

        criterion_class = torch.nn.CrossEntropyLoss()

        print("train classifier.")
        net.encoding_mode = True
        acc = trainer.train_classifier_encodings(
            net, enc_tr_loader, hyp, criterion_class, optimizer, enc_te_loader
        )

        final_acc = (
            trainer.test_network(train_loader, net, hyp),
            trainer.test_network(test_loader, net, hyp),
        )
        net.encoding_mode = False
        print("final_acc", final_acc)

NameError: ignored

## Trainer + Test

In [13]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import os

import sys

sys.path.append("src/")

# import utils



def train_ae_simulated(net, data_loader, hyp, criterion, optimizer, real_H, PATH):
    """Train ``net`` on data synthesised from the ground-truth dictionary.

    Every batch of codes from ``data_loader`` is mapped through ``real_H`` to
    form the training target. After each epoch ``utils.err_H`` between
    ``real_H`` and the learned ``net.H`` is recorded, printed and saved to
    ``PATH/err_epoch<epoch>.pt``.

    Returns:
        List with one ``err_H`` value per epoch.
    """
    num_epochs, device, info_period = (
        hyp["num_epochs"],
        hyp["device"],
        hyp["info_period"],
    )

    err = []
    for epoch in range(num_epochs):
        for idx, code in tqdm(enumerate(data_loader)):
            target = torch.matmul(real_H, code.reshape(-1, net.D_enc, 1)).to(device)

            # forward
            reconstruction, _ = net(target)
            loss = criterion(reconstruction, target)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # renormalise the dictionary after every optimiser step
            net.normalize()

            if idx % info_period == 0:
                print("loss:{:.4f}".format(loss.item()))

        epoch_err = utils.err_H(real_H, net.H.data)
        err.append(epoch_err)

        # log
        summary = "epoch [{}/{}], loss:{:.4f}, err_H:{:.4f}".format(
            epoch + 1, num_epochs, loss.item(), epoch_err
        )
        print(summary)

        torch.save(epoch_err, os.path.join(PATH, "err_epoch{}.pt".format(epoch)))

    return err


def train_randproj_ae_simulated(
    net, data_loader, hyp, criterion, optimizer, real_H, phi, PATH, test_loader=None
):
    """Train a random-projection auto-encoder on simulated data.

    Targets are synthesised as ``(phi @ real_H) @ code``; with a stack of
    projections the one at ``batch_index % num_phi`` is used for each batch.
    After every epoch ``err_H`` is recorded, and the dictionary with the
    lowest error so far is stored on ``net.bestH``.

    Args:
        net: Model returning ``(reconstruction, code)``.
        data_loader: Yields batches of sparse codes.
        hyp: Dict providing ``num_epochs``, ``device`` and ``info_period``.
        criterion: Loss between reconstruction and target.
        optimizer: Optimiser over ``net``'s parameters.
        real_H: Ground-truth dictionary.
        phi: Projection ``(D_in, D_org)`` or stack ``(num_phi, D_in, D_org)``.
        PATH: Directory receiving ``err_epoch*.pt`` and ``loss_epoch*.pt``.
        test_loader: Optional loader; when given, training stops once the
            test loss changes by less than 5e-4 between epochs (that epoch's
            files are not written, as before).

    Returns:
        List with one ``err_H`` value per completed epoch.
    """
    # NOTE(review): `utils` is used below, but its import is commented out in
    # this cell.
    num_epochs = hyp["num_epochs"]
    device = hyp["device"]
    info_period = hyp["info_period"]

    single_phi = len(phi.size()) == 2
    # detach() instead of assigning requires_grad on a possibly non-leaf result.
    true_decoder = torch.matmul(phi, real_H).detach()
    # guarantee the decoder can always be indexed by a projection index
    if single_phi:
        true_decoder = true_decoder.unsqueeze(0)
    true_decoder = true_decoder.to(device)
    num_decoders = true_decoder.size(0)

    def batch_loss(batch_idx, sample):
        """Loss of one batch, using the projection assigned to ``batch_idx``."""
        phi_idx = batch_idx % num_decoders
        sample = sample.to(device)
        img = torch.matmul(
            true_decoder[phi_idx], sample.view(-1, net.D_enc, 1)
        ).view(-1, net.D_in, 1)
        if single_phi:
            img_hat, _ = net(img)
        else:
            img_hat, _ = net((phi_idx, img))
        return criterion(img_hat, img)

    err = []
    min_errH = 1
    last_test_loss = 0

    for epoch in range(num_epochs):
        # The batch index is kept separate from the projection index; the
        # progress print previously referenced an undefined `idx`.
        for idx, sample in tqdm(enumerate(data_loader)):
            loss = batch_loss(idx, sample)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            net.normalize()

            if idx % info_period == 0:
                print("loss:{:.4f}".format(loss.item()))

        err.append(utils.err_H(real_H.cpu(), net.H.cpu().data))

        test_loss = None
        if test_loader is not None:
            # Evaluation only: no autograd graph needed.
            with torch.no_grad():
                for idx, sample in tqdm(enumerate(test_loader)):
                    test_loss = batch_loss(idx, sample).item()

        # ===================log========================
        if err[-1] < min_errH:
            min_errH = err[-1]
            # clone(): a bare `.data` view shares storage with the live
            # parameter and would be modified by the next optimiser step.
            net.bestH = net.H.detach().clone()

        if test_loss is None:
            print(
                "epoch [{}/{}], loss:{:.4f}, err_H:{:.4f}".format(
                    epoch + 1, num_epochs, loss.item(), err[-1]
                )
            )
        else:
            print(
                "epoch [{}/{}], loss:{:.4f}, test_loss:{:.4f}, err_H:{:.4f}".format(
                    epoch + 1, num_epochs, loss.item(), test_loss, err[-1]
                )
            )
            # Early stopping on a plateau of the test loss.
            if abs(test_loss - last_test_loss) < 5e-4:
                return err
            last_test_loss = test_loss

        torch.save(err[-1], os.path.join(PATH, "err_epoch{}.pt".format(epoch)))
        torch.save(loss.item(), os.path.join(PATH, "loss_epoch{}.pt".format(epoch)))

    return err


def train_ae(net, data_loader, hyp, criterion, optimizer, PATH):
    """Train ``net`` to reconstruct projected inputs.

    Each batch is reshaped to ``(-1, net.D_org, 1)``; the loss compares the
    first network output with ``phi @ data``. With a stack of projections
    (3-D ``net.phi``) the projection at ``batch_index % num_phi`` is used and
    the net receives ``(index, data)``; with a single 2-D projection the net
    receives ``data`` directly.

    After every epoch the last batch loss is recorded, saved to
    ``PATH/loss_epoch<epoch>.pt``, and the dictionary of the best epoch so far
    is stored on ``net.bestH`` (a CPU copy).

    Returns:
        List with the last batch loss of each epoch.
    """
    num_epochs = hyp["num_epochs"]
    device = hyp["device"]
    info_period = hyp["info_period"]

    multi_phi = len(net.phi.size()) == 3

    err = []
    min_err = None
    for epoch in range(num_epochs):
        for idx, (img, c) in tqdm(enumerate(data_loader)):

            img = img.to(device)
            data = img.view(-1, net.D_org, 1)

            # Previously the projection index was left unbound for a 2-D phi.
            if multi_phi:
                i = idx % net.phi.size(0)
                phi_i = net.phi[i]
                net_input = (i, data)
            else:
                phi_i = net.phi
                net_input = data

            # ===================forward=====================
            output = net(net_input)
            loss = criterion(output[0], torch.matmul(phi_i, data))
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            net.normalize()

            if idx % info_period == 0:
                print("loss:{:.4f}".format(loss.item()))

        # ===================log========================
        loss_value = loss.item()
        if min_err is None or min_err >= loss_value:
            min_err = loss_value
            # clone(): on CPU `.cpu().data` shares storage with the live
            # parameter, so the snapshot would change with the next step.
            net.bestH = net.H.detach().cpu().clone()
        err.append(loss_value)
        print("epoch [{}/{}], loss:{:.4f} ".format(epoch + 1, num_epochs, loss_value))

        torch.save(loss_value, os.path.join(PATH, "loss_epoch{}.pt".format(epoch)))

    return err


def train_classifier_encodings(
    net, data_loader, hyp, criterion, optimizer, val_loader=None, getHs=False
):
    """Train ``net`` as a classifier and track accuracy per epoch.

    Args:
        net: Model whose output is scored against the labels by ``criterion``.
        data_loader: Yields ``(input, label)`` batches.
        hyp: Dict providing ``num_epochs``, ``device`` and ``info_period``.
        criterion: Classification loss.
        optimizer: Optimiser over ``net``'s parameters.
        val_loader: Optional validation loader. When omitted, validation is
            skipped and the returned ``val_acc`` list stays empty.
        getHs: When true, per-epoch copies of ``net.H`` are returned as well.

    Returns:
        ``(train_acc, val_acc)`` or ``(train_acc, val_acc, Hs)``.
    """
    num_epochs = hyp["num_epochs"]
    device = hyp["device"]
    info_period = hyp["info_period"]

    multi_phi = len(net.phi.size()) == 3

    train_acc = []
    val_acc = []
    Hs = []
    for epoch in tqdm(range(num_epochs)):
        for idx, (img, c) in enumerate(data_loader):
            img = img.to(device)
            c = c.to(device)

            # Previously the projection index was left unbound for a 2-D phi.
            if multi_phi:
                net_input = (idx % net.phi.size(0), img)
            elif getattr(net, "encoding_mode", False):
                # Encoding mode unpacks a tuple and ignores the index.
                net_input = (0, img)
            else:
                net_input = img
            # ===================forward=====================

            output = net(net_input)

            loss = criterion(output, c)
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            net.normalize()

            if idx % info_period == 0:
                print("loss:{:.4f}".format(loss.item()))

        # ===================log========================
        train_acc.append(test_network(data_loader, net, hyp))
        # clone(): keep a real snapshot rather than a view of live storage.
        Hs.append(net.H.detach().cpu().clone())
        if val_loader is not None:
            val_acc.append(test_network(val_loader, net, hyp))
            print(
                "epoch [{}/{}], loss:{:.4f}, train acc:{:.4f}, val acc:{:.4f}".format(
                    epoch + 1, num_epochs, loss.item(), train_acc[-1], val_acc[-1]
                )
            )
        else:
            # The default val_loader=None used to crash inside test_network.
            print(
                "epoch [{}/{}], loss:{:.4f}, train acc:{:.4f}".format(
                    epoch + 1, num_epochs, loss.item(), train_acc[-1]
                )
            )
    if getHs:
        return train_acc, val_acc, Hs
    return train_acc, val_acc


def test_network(data_loader, net, hyp, getExamples=False, getClasses=False):

    device = hyp["device"]

    with torch.no_grad():
        num_correct = 0
        num_total = 0
        correct_ex = []
        incorrect_ex = []
        examples = 300
        for idx, (img, c) in tqdm(enumerate(data_loader)):

            img = img.to(device)
            c = c.to(device)

            img = img.view(-1, net.D_enc, 1)

            i = idx % net.phi.size(0)
            # ===================forward=====================
            output = net((i, img))

            correct_indicators = output.max(1)[1].data == c
            num_correct += correct_indicators.sum().item()
            num_total += c.size()[0]

            if getExamples:
                count = 0
                for j, indicator in enumerate(correct_indicators):
                    if indicator and len(correct_ex) <= examples:
                        correct_ex.append((i, img[j], c[j]))
                    elif not indicator and len(incorrect_ex) <= examples:
                        incorrect_ex.append((i, img[j], c[j]))
                    count += 1
                    if count > 4:
                        break
            if getClasses:
                correct
        # ===================log========================

    acc = num_correct / num_total
    if getExamples:
        return (acc, correct_ex, incorrect_ex)
    return acc

## Utils

In [14]:
import torch
import torch.nn.functional as F
import numpy as np


def normalize(x):
    """Divide ``x`` by its norm taken along dimension 0 (kept for broadcasting)."""
    column_norms = x.norm(dim=0, keepdim=True)
    return x / column_norms


def err_H(H, H_hat):
    """Largest per-column error ``1 - <H[:, i], H_hat[:, i]>**2`` (at least 0)."""
    column_errors = [
        1 - np.dot(H[:, col], H_hat[:, col]) ** 2 for col in range(H.size()[1])
    ]
    # Seeding with 0 mirrors a running maximum that starts at 0.
    return max([0, *column_errors])


def err_H_min(H, H_hat):
    """Smallest per-column error ``1 - <H[:, i], H_hat[:, i]>**2`` (at most 1)."""
    column_errors = [
        1 - np.dot(H[:, col], H_hat[:, col]) ** 2 for col in range(H.size()[1])
    ]
    # Seeding with 1 mirrors a running minimum that starts at 1.
    return min([1, *column_errors])


def err_H_avg(H, H_hat):
    """Mean over columns of ``1 - <H[:, i], H_hat[:, i]>**2``."""
    num_columns = H.size()[1]
    column_errors = (
        1 - np.dot(H[:, col], H_hat[:, col]) ** 2 for col in range(num_columns)
    )
    return sum(column_errors) / num_columns


def err_H_all(H, H_hat):
    """List of the per-column errors ``1 - <H[:, i], H_hat[:, i]>**2``."""
    return [
        1 - np.dot(H[:, col], H_hat[:, col]) ** 2 for col in range(H.size()[1])
    ]


def sample_var(dataset, real_H):
    """Mean of the row-wise variances of ``dataset.samples.T @ real_H.t()``."""
    projected = np.dot(dataset.samples.T, real_H.t())
    row_variances = projected.var(1)
    return row_variances.mean()


def display_imgs(net, test_sparse, real_H, D_in):
    """Plot a synthesised signal together with the network's reconstruction.

    The signal is ``test_sparse @ real_H.t()`` reshaped to ``(-1, D_in, 1)``;
    the reconstruction is ``net.H @ net.last_encoding[0]`` after running the
    net on the projected signal.

    NOTE(review): ``plt`` is not imported anywhere in the visible cells, and
    none of the model classes in this file set ``net.last_encoding`` --
    confirm both are provided elsewhere before relying on this helper.
    """
    img = torch.matmul(test_sparse.view(1, -1), real_H.t()).view(-1, D_in, 1)
    plt.plot(img.flatten().data.numpy())
    comp_img = torch.matmul(net.phi.cpu().data, img)
    # The return value is discarded; only `net.last_encoding` is read afterwards.
    net(comp_img)
    plt.plot(torch.matmul(net.H.cpu(), net.last_encoding[0]).flatten().data.numpy())
    plt.legend(["Real image", "Recovered image"])
    plt.show()


def display_img_enc(net, real_H, dataset):
    """Plot reconstruction and encodings for the first item of ``dataset``.

    Switches the net to ``eval_mode``, runs it on the projected first sample,
    calls ``display_imgs`` and then scatters the reference encoding
    ``dataset[0][0]`` against ``net.last_encoding[0]``.

    NOTE(review): ``plt`` is not imported in the visible cells, and
    ``net.last_encoding`` / ``net.use_cuda`` are not defined by the model
    classes in this file. The net's result is also used with ``.view`` here,
    while the visible models return tuples -- confirm which model this helper
    is meant for.
    """
    i = 0
    net.eval_mode = True
    net.use_cuda = False
    # `recon_img` is not used again below.
    recon_img = net(
        torch.matmul(torch.matmul(net.phi.cpu().data, real_H.cpu().data), dataset[i][0])
    ).view(1, -1)
    display_imgs(net, dataset[i][0], real_H.cpu(), net.D_org)
    plt.scatter(range(net.D_enc), dataset[i][0])
    plt.scatter(range(net.D_enc), net.last_encoding[0].cpu())
    plt.title("Learned H encodings lam = " + str(net.lam))
    plt.legend(["Real encoding", "Recovered encoding"])
    plt.show()

    # `eval_mode` is left enabled; only `use_cuda` is switched back.
    net.use_cuda = True


def display_err_plot(errs, initial_err):
    """Plot the error per epoch, with ``initial_err`` prepended as epoch 0."""
    epochs = range(len(errs) + 1)
    history = [initial_err] + list(errs)
    plt.plot(epochs, history)
    plt.title("Err vs Epoch")
    plt.xlabel("Epoch")
    plt.ylabel("Err")
    plt.show()


def display_plots(net, real_H, dataset, errs, initial_err):
    """Show the encoding/reconstruction figures, then the error curve."""
    display_img_enc(net=net, real_H=real_H, dataset=dataset)
    display_err_plot(errs=errs, initial_err=initial_err)


def save_model(net, acc, initial_H, name, num_iters, lam, mse):
    """Save model tensors and metrics to files sharing a common name stem.

    The stem is ``<name>_Din<D_in>_Denc<D_enc>_iters<num_iters>_lam<lam>``.
    Written in order: ``H.pt`` (``net.H.data``), ``classifier.pt``
    (``net.classifier``), ``TrainAcc.pt`` (``acc[0]``), ``TestAcc.pt``
    (``acc[1]``), ``MSE.pt`` (``mse``) and ``initial_H.pt`` (``initial_H``).
    """
    stem = "".join(
        [
            name,
            "_Din",
            str(net.D_in),
            "_Denc",
            str(net.D_enc),
            "_iters",
            str(num_iters),
            "_lam",
            str(lam),
        ]
    )

    def dump(obj, suffix):
        torch.save(obj, stem + suffix)

    # Each object is built right before it is written, one file at a time.
    dump(net.H.data, "H.pt")
    dump(net.classifier, "classifier.pt")
    dump(torch.tensor(acc[0]), "TrainAcc.pt")
    dump(torch.tensor(acc[1]), "TestAcc.pt")
    dump(torch.tensor(mse), "MSE.pt")
    dump(initial_H, "initial_H.pt")