In [None]:
import torch
import torch.nn as nn
import numpy as np
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Compose, Normalize
import matplotlib.pyplot as plt
from tqdm import tqdm

from models import SimpleMNISTCNN, VanillaVAE
from train import train_mnist, train_vae, generate_meta_dataloader, MNISTNetDataset

In [None]:
def get_data(num_nets):
    """Call ``train_mnist(SimpleMNISTCNN)`` ``num_nets`` times and collect the networks.

    ``train_mnist`` returns a ``(net, losses)`` pair; only the network is kept,
    the per-run loss history is discarded.

    Args:
        num_nets: number of networks to produce (0 returns an empty list).

    Returns:
        List of the networks returned by ``train_mnist``, in call order.
    """
    nets = []
    for _ in range(num_nets):
        net, _losses = train_mnist(SimpleMNISTCNN)
        nets.append(net)
    return nets

# Produce 200 networks via get_data -> train_mnist; presumably each call trains a
# model from scratch, so this is likely the slowest cell (cache with the pickle cells below).
nets = get_data(200)

In [None]:
# NOTE(review): in top-to-bottom order this instance is never trained — `vae` is
# re-created with different arguments in the cell just before training.
# Argument meanings come from models.VanillaVAE; 3015 presumably equals the
# flattened SimpleMNISTCNN parameter count (TODO confirm).
vae = VanillaVAE(3015, 500, 6, 200, 500, 6)

In [None]:
import pickle

In [None]:
# Cache the list of trained networks to disk so training need not be repeated.
with open("200_trained_nets.pkl", mode='wb') as out_file:
    pickle.dump(nets, out_file)

In [None]:
# NOTE(review): this reads "20_trained_nets.pkl", not the "200_trained_nets.pkl"
# written by the previous cell — confirm which cache is intended.
# pickle can execute arbitrary code on load: only open files you created yourself.
with open("20_trained_nets.pkl", mode='rb') as in_file:
    nets = pickle.loads(in_file.read())

In [None]:
# Second name bound to the SAME list object (not a copy). The meta-dataloader cell
# below reads it; presumably kept so the loaded networks survive if `nets` is rebound.
oldestest_nets = nets

In [None]:
# Sanity check: how many networks are currently bound to `nets`.
print(len(nets))

In [None]:
# NOTE(review): `dset` is not used by any later cell in this notebook; the meaning
# of 64 is defined by train.MNISTNetDataset (TODO confirm/document or remove).
dset = MNISTNetDataset(nets, 64)


In [None]:
# Meta-dataloader over the loaded networks; batch_size=50 presumably means 50
# networks per batch (see train.generate_meta_dataloader — TODO confirm).
dl = generate_meta_dataloader(oldestest_nets, batch_size=50)

In [None]:
# TODO: try a simpler model.

In [None]:


def flat_to_net(flat_net):
    """Build a fresh ``SimpleMNISTCNN`` whose parameters are copied from a flat vector.

    The vector is consumed in ``SimpleMNISTCNN().parameters()`` order: each
    parameter takes the next ``p.numel()`` entries, reshaped to ``p.shape``.
    Extra trailing entries are ignored.

    The copy is done under ``torch.no_grad()``, so the returned network is
    detached from ``flat_net``'s autograd graph. Use ``_flat_to_params`` with
    ``torch.func.functional_call`` when gradients must flow back to ``flat_net``.

    Args:
        flat_net: 1-D tensor with at least as many elements as the network has
            parameters.

    Returns:
        A new ``SimpleMNISTCNN`` instance.
    """
    with torch.no_grad():
        test_net = SimpleMNISTCNN()
        index = 0
        for p in test_net.parameters():
            # numel() also handles 0-dim parameters, where cumprod over an
            # empty shape would raise IndexError.
            end_index = index + p.numel()
            p.copy_(flat_net[index:end_index].reshape(p.shape))
            index = end_index
        return test_net


def _flat_to_params(flat_net, template):
    """Slice a flat vector into a ``{name: tensor}`` dict matching ``template``'s parameters.

    Same layout as ``flat_to_net`` (``named_parameters()`` order), but built from
    slices/reshapes without ``no_grad``, so the tensors stay in ``flat_net``'s
    autograd graph. Each tensor is cast to the template parameter's device and
    dtype, mirroring what ``Tensor.copy_`` does in ``flat_to_net``.
    """
    params = {}
    index = 0
    for name, p in template.named_parameters():
        end_index = index + p.numel()
        params[name] = flat_net[index:end_index].reshape(p.shape).to(device=p.device, dtype=p.dtype)
        index = end_index
    return params


def _symmetric_kl(p, q, eps=1e-8):
    """Symmetrized KL, 0.5 * (KL(p||q) + KL(q||p)), with ``eps`` smoothing.

    Summed over all elements and divided by ``p.shape[0]`` (per-sample mean for
    2-D ``(batch, classes)`` inputs). Algebraically equal to
    ``sum(0.5*log((p+eps)/(q+eps))*p + 0.5*log((q+eps)/(p+eps))*q) / p.shape[0]``.
    """
    log_ratio = torch.log(p + eps) - torch.log(q + eps)
    return torch.sum(0.5 * (p - q) * log_ratio) / p.shape[0]


def train_vae(vae, dl, optimizer, losses=None, epochs=5, use_reconstruction=True, batch_size=1, mnist_batch_size=256,
              kl_weight=100, reconstruction_weight=100, verbose=True):
    """Train a VAE on batches of flattened ``SimpleMNISTCNN`` weight vectors.

    Per meta-batch the loss is
    ``kl_weight * symmetric_kl + reconstruction_weight * reconstruction + dist_loss``:

    * ``symmetric_kl``: symmetrized KL between the softmax outputs of each input
      network and of its VAE reconstruction, on one shuffled MNIST *test* batch
      per network pair, averaged over the pairs. Gradients flow into the VAE via
      ``torch.func.functional_call``.
    * ``reconstruction``: squared error between input and reconstructed weight
      vectors divided by ``nets.shape[0] * nets.shape[1]`` (0 when
      ``use_reconstruction`` is False).
    * ``dist_loss``: ``vae.calc_loss(mean, std, constant=0.005)``.

    Note: this redefines the ``train_vae`` imported from ``train`` in the first cell.

    Args:
        vae: model exposing ``forward_train(nets) -> (mean, std, output)`` and ``calc_loss``.
        dl: iterable of 2-D tensors of flattened network weights — presumably
            ``(n_nets, n_params)``; TODO confirm against generate_meta_dataloader.
        optimizer: optimizer over ``vae.parameters()``.
        losses: list to append per-step losses to; a new list is created when None.
        epochs: number of passes over ``dl``.
        use_reconstruction: include the weight-space reconstruction term.
        batch_size: unused; kept for backward compatibility.
        mnist_batch_size: MNIST test batch size used for the KL term.
        kl_weight: weight of the symmetric-KL term (previously hard-coded 100).
        reconstruction_weight: weight of the reconstruction term (previously 100).
        verbose: print the weighted loss components at every step.

    Returns:
        ``losses`` with one detached loss tensor appended per optimizer step.
    """
    # torch.func.functional_call (torch >= 2.0) runs a module with externally
    # supplied parameter tensors while keeping them in the autograd graph.
    from torch.func import functional_call

    # A fresh list per call; a ``losses=[]`` default would be shared between calls.
    if losses is None:
        losses = []
    softmax = torch.nn.Softmax(dim=-1)
    test = DataLoader(MNIST('.', train=False, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])),
                      batch_size=mnist_batch_size, shuffle=True)
    # Supplies only the architecture for functional_call: every parameter is
    # overridden, so its own random initialization is never used.
    template = SimpleMNISTCNN()

    for _ in tqdm(range(epochs)):
        for nets in dl:
            optimizer.zero_grad()
            mean, std, output = vae.forward_train(nets)
            dist_loss = vae.calc_loss(mean, std, constant=0.005)

            symmetric_kl = 0
            total = 0
            for old_flat, new_flat in zip(nets, output):
                # One freshly shuffled MNIST test batch per (input, reconstruction) pair.
                mnist_batch, _ = next(iter(test))
                # Built from views of `output`, so this term produces gradients for
                # the VAE (copying into a module under no_grad would not).
                new_out = softmax(functional_call(template, _flat_to_params(new_flat, template), (mnist_batch,)))
                with torch.no_grad():
                    old_out = softmax(functional_call(template, _flat_to_params(old_flat, template), (mnist_batch,)))
                symmetric_kl = symmetric_kl + _symmetric_kl(old_out, new_out)
                total += 1
            # Mean over network pairs (dividing by total * nets.shape[0] would count
            # the number of networks twice).
            symmetric_kl = symmetric_kl / max(total, 1)

            if use_reconstruction:
                reconstruction = torch.sum((nets - output) ** 2) / nets.shape[0] / nets.shape[1]
            else:
                reconstruction = 0
            if verbose:
                print("Symmetric kl loss: ", symmetric_kl * kl_weight)
                print("Reconstruction loss: ", reconstruction * reconstruction_weight)
                print("Distribution loss: ", dist_loss)
            vae_loss = symmetric_kl * kl_weight + reconstruction * reconstruction_weight + dist_loss
            vae_loss.backward()
            losses.append(vae_loss.detach())
            optimizer.step()
    return losses

In [None]:
# The VAE that is actually trained below (replaces any earlier `vae`). Argument
# meanings are defined by models.VanillaVAE — TODO document; 3015 presumably
# equals the flattened SimpleMNISTCNN parameter count.
vae = VanillaVAE(3015, 500, 3, 2, 500, 3)
optimizer = torch.optim.Adam(vae.parameters())  # default Adam hyperparameters
# Fresh loss history; train_vae appends one entry per optimizer step.
losses = []

In [None]:
# Train for 300 epochs, appending to `losses`. Re-running this cell continues
# training the same `vae` / `optimizer` and keeps extending `losses`.
losses = train_vae(vae, dl, optimizer, losses=losses, epochs=300)

# Variant without the weight-space reconstruction term:
# losses = train_vae(vae, dl, optimizer, losses=losses, epochs=100, use_reconstruction = False)
# 1.12  (NOTE(review): meaning of this number is not recorded)
# Experiment log — VanillaVAE(...) argument tuples tried (author's notes):
# 3015, 500, 3, 100, 500, 3 works :) visible on epoch ~35-40
# 3015, 500, 3, 50, 500, 3 works :) visible on epoch ~35-40
# 3015, 500, 3, 30, 500, 3 works (ish) :) visible on epoch ~35-40
# 3015, 500, 3, 20, 500, 3 works (ish) :) visible on epoch ~35-40
# 3015, 500, 3, 10, 500, 3 works (ish) :) visible on epoch ~35-40
# 3015, 500, 3, 5, 500, 3 works (ish) :) visible on epoch ~35-40

In [None]:
# Training curve: one point per optimizer step recorded by train_vae.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title="VAE training loss", xlabel="optimizer step", ylabel="total loss")
plt.show()

In [None]:
# Inspect raw inputs and VAE outputs for every meta-batch. The loop variable is
# deliberately NOT named `nets`: that would rebind the global list of trained
# networks (e.g. a later re-run of the pickle-dump cell would then save a batch tensor).
with torch.no_grad():  # inspection only; no autograd graph needed
    for batch in dl:
        print(batch)
        print(vae(batch))

In [None]:
# Collect `vae.forward_train` means for every batch in `dl`, concatenated along
# dim 0. Under no_grad: the means are only plotted, so there is no reason to keep
# an autograd graph alive for every batch.
mean_batches = []
with torch.no_grad():
    for batch in dl:
        mean, _std, _ = vae.forward_train(batch)
        mean_batches.append(mean)
means = torch.cat(mean_batches)

In [None]:
print(means.shape)
# Scatter of the first two coordinates of the collected means (one point per row).
latent = means.detach()
fig, ax = plt.subplots()
ax.scatter(latent[:, 0], latent[:, 1])
ax.set(title="VAE latent means of the trained networks", xlabel="mean[:, 0]", ylabel="mean[:, 1]")
plt.show()

In [None]:
# Prof. Gu meeting notes — not worth trying to interpret without the context of the conversation.

# Two directions -> use MNIST data
# Look independently of the data -> a "metric" on neural networks
# Want theoretical results; think about the cost function -> convex, semi-convex, or bounded by a convex function. Very interesting problem
# Next thing I should try -> the analysis is easier with the entropy idea:
# let v_i = |x_i - y_i|; cost = -sum_i v_i * log(v_i)
# First make it intuitive, then find a Riemannian metric -> take the features of the last layer directly somehow
# Currently have image data; basically done
# Synthetic aperture radar (SAR) data -> images. So many open problems
# Weather, earthquakes, detecting people underground
# Ricci flow

# Graph Ricci curvature
# PointNet