In [None]:
# Reload edited Python modules (e.g. the nupic.research mandp package)
# automatically before executing each cell, so library edits are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import ipywidgets as widgets
import matplotlib.pyplot as plt
import matplotlib.animation
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from ipywidgets import interact, interactive, fixed, interact_manual

from nupic.research.frameworks.mandp.autoencoder import MandpAutoencoder
from nupic.research.frameworks.mandp.foliage import FoliageDataset

# DataLoader worker processes caused problems when run from Jupyter, so use
# the "spawn" start method for torch.multiprocessing instead. force=True lets
# the call succeed even if a start method was already set (e.g. on re-run).
# NOTE(review): root cause unconfirmed — possibly related to FoliageDataset
# reporting an infinite length.
from torch import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

In [None]:
# Seed Python, NumPy and torch RNGs so that weight initialisation (and,
# presumably, any sampling done by the data pipeline) is reproducible
# under Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

# Hyperparameters are passed straight through to MandpAutoencoder.
model = MandpAutoencoder(step_size=200,
                         num_modules=30,
                         krecon=20.0,
                         kmag=40.0,
                         kphase=25.0,
                         local_delta=False)
# A single epoch capped at 2000 batches (FoliageDataset may be unbounded).
trainer = pl.Trainer(limit_train_batches=2000, max_epochs=1)
trainer.fit(model)

In [None]:
# Visualise the learned encoder filters: each row of the encoder weight is
# reshaped into a 64x64 image, and consecutive rows are shown side by side
# on a shared colour scale.
# NOTE(review): pairing assumes rows (2k, 2k+1) are one module's
# (real, imaginary) components, matching how encodings are viewed as complex
# pairs below — confirm against MandpAutoencoder.
# .cpu() so imshow also works if the model was trained on a GPU.
w = model.encoder.weight.detach().cpu().view(-1, 64, 64)

for i in range(0, w.shape[0], 2):
    # Slicing (instead of indexing i + 1) keeps a trailing unpaired row from
    # raising IndexError when the number of rows is odd.
    pair = w[i:i + 2]
    # Shared limits so both panels of a pair use the same colour mapping.
    vmin = pair.min().item()
    vmax = pair.max().item()

    # squeeze=False keeps `axes` 2-D even for a single-panel (unpaired) row.
    fig, axes = plt.subplots(1, len(pair), figsize=(4, 2), squeeze=False)
    for panel, filt in zip(axes[0], pair):
        panel.imshow(filt, vmin=vmin, vmax=vmax)
        panel.axis("off")
    plt.show()

In [None]:
# FoliageDataset is consumed by iteration (iter()/next()), not by indexing.
# Keeping the iterator in `it` allows calling next(it) again for another item.
# NOTE(review): the first item is treated as a whole batch downstream (it is
# reshaped into several 64x64 frames) — confirm this is what the dataset yields.
dataset = FoliageDataset()
it = iter(dataset)
batch = next(it)

In [None]:
# Encode the batch, then decode by reusing the encoder weight as the decoder:
# F.linear(x, W.T) computes x @ W. These tensors are only inspected and
# plotted, so run under no_grad instead of building a graph and detaching it.
with torch.no_grad():
    encodings = model(batch)
    decodings = F.linear(encodings, model.encoder.weight.transpose(0, 1))

# Pair consecutive encoding units as (real, imag): shape (batch, n_pairs).
complex_encodings = torch.view_as_complex(encodings.view(encodings.shape[0], -1, 2))

In [None]:
# Per-frame views of the batch. -1 infers the frame count from the batch
# instead of hard-coding 10.
imgs = batch.view(-1, 64, 64)
dec_imgs = decodings.view(-1, 64, 64)
# complex_encodings is (batch, n_pairs); transposing gives one "ring"
# (complex trajectory over the batch) per encoding pair.
rings = complex_encodings.transpose(0, 1)


def draw_img(t):
    """Show frame ``t``: input image, reconstruction, and ring positions.

    The third panel draws every ring's full trajectory over the batch as a
    line with small markers, and marks its value at frame ``t`` with a bigger
    marker. The origin is marked with a black X. The axis limits are symmetric
    and computed over all frames, so moving the slider does not rescale the plot.

    Args:
        t: Frame index into ``imgs``, ``dec_imgs`` and each ring.
    """
    fig, ax = plt.subplots(1, 3)
    fig.set_figheight(4)
    fig.set_figwidth(14)
    ax[0].imshow(imgs[t])
    ax[1].imshow(dec_imgs[t])

    for i, ring in enumerate(rings):
        color = f"C{i % 10}"
        ax[2].plot(torch.real(ring), torch.imag(ring), "-o", color=color, markersize=2)
        ax[2].plot(torch.real(ring[t]), torch.imag(ring[t]), "o", color=color)

    # Largest |real| or |imag| over all rings and frames gives square limits.
    # float() avoids handing 0-d tensors to matplotlib; fall back to 1.0 if
    # every encoding is zero, to avoid identical (degenerate) limits.
    mval = float(max(torch.real(rings).abs().max(), torch.imag(rings).abs().max())) or 1.0

    ax[2].plot(0, 0, "X", color="black")
    ax[2].set_ylim(-mval, mval)
    ax[2].set_xlim(-mval, mval)

    for a in ax:
        a.axis("off")
    plt.show()


# Trailing ";" suppresses the function repr that interact() would echo.
interact(draw_img, t=widgets.IntSlider(min=0, max=imgs.shape[0] - 1, step=1, value=0));

In [None]:
# Render the same three-panel view as the interactive cell for every frame and
# save it as rings.mp4 (FFMpegFileWriter needs an ffmpeg executable installed).
fig, ax = plt.subplots(1, 3)
fig.set_figheight(4)
fig.set_figwidth(14)

# `rings` (one complex trajectory per encoding pair) comes from the previous
# cell. The symmetric axis limit does not depend on the frame, so compute it
# once here instead of on every update. float() avoids passing 0-d tensors to
# matplotlib; fall back to 1.0 when every encoding is zero.
ring_limit = float(max(torch.real(rings).abs().max(), torch.imag(rings).abs().max())) or 1.0


def update_figure(t):
    """Redraw all three panels of the shared figure ``fig`` for frame ``t``.

    Every axis is cleared first. Otherwise each call would stack another image
    on top of the previous frame's, and rendering cost would grow with every frame.

    Args:
        t: Frame index into ``imgs``, ``dec_imgs`` and each ring.
    """
    for a in ax:
        a.clear()

    ax[0].imshow(imgs[t])
    ax[1].imshow(dec_imgs[t])

    for i, ring in enumerate(rings):
        color = f"C{i % 10}"
        ax[2].plot(torch.real(ring), torch.imag(ring), "-o", color=color, markersize=2)
        ax[2].plot(torch.real(ring[t]), torch.imag(ring[t]), "o", color=color)

    ax[2].plot(0, 0, "X", color="black")
    ax[2].set_ylim(-ring_limit, ring_limit)
    ax[2].set_xlim(-ring_limit, ring_limit)

    # clear() restores the axis decorations, so hide them again each frame.
    for a in ax:
        a.axis("off")


moviewriter = matplotlib.animation.FFMpegFileWriter()
with moviewriter.saving(fig, "rings.mp4", dpi=100):
    # One movie frame per image in the batch (previously hard-coded to 10).
    for t in range(imgs.shape[0]):
        update_figure(t)
        moviewriter.grab_frame()