In [None]:
from pathlib import Path
import os
from glob import glob
import re

import torch
import torch.optim
import torch.nn.functional as F

import seaborn as sns

from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib

import numpy as np

from IPython.display import display, clear_output
from PIL import Image
from tqdm.auto import tqdm

from torchvision.io import write_video
from torchvision.transforms import ToTensor, Normalize

from predictive_coding.models.models import PredictiveCoder
from predictive_coding.dataset import EnvironmentDataset, collate_fn


In [None]:
# Fall back to CPU when no GPU is visible so the notebook still runs (slowly) on
# CPU-only machines; on a CUDA host this is the same 'cuda:0' as before.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Initialize the predictive coding architecture
model = PredictiveCoder(in_channels=3, out_channels=3, layers=[2, 2, 2, 2], seq_len=10, num_skip=3)
model = model.to(device)

# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# The result is passed straight to load_state_dict, so if the installed torch
# supports it, weights_only=True would be a safer choice here.
ckpt = torch.load('../weights/predictive_coding.ckpt', map_location=device)
model.load_state_dict(ckpt)
model.eval()  # switch to evaluation mode for the feature extraction below
clear_output()  # hide any output/warnings printed while loading


In [None]:
# Validation-split environment observations, served in a fixed (unshuffled) order
# and batched with the project's own collate function.
val_dir = Path("../datasets/val-dataset")
dataset = EnvironmentDataset(val_dir)
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
    num_workers=2,
)


In [None]:
# Generate the latent vectors from the predictive coding neural network.
# Per batch: encode every frame, run each encoder level through its attention
# module together with model.mask, then ask the decoder for its codes and keep
# codes[1] on the CPU. The matching `states` rows are collected alongside so
# latents[k] and positions[k] refer to the same frame.
latent_chunks, position_chunks = [], []
with torch.no_grad():
    for images, actions, states in tqdm(dataloader):
        n_seqs, seq_len, n_chan, height, width = images.shape
        flat_frames = images.to(device).reshape(n_seqs * seq_len, n_chan, height, width)

        levels = model.encoder(flat_frames)
        # Restore a (batch, seq, ...) layout for attention over the sequence, then flatten back.
        levels = [lvl.reshape(n_seqs, seq_len, *lvl.shape[1:4]) for lvl in levels]
        levels = [attn(lvl, lvl, lvl, model.mask)[0] for attn, lvl in zip(model.attention, levels)]
        levels = [lvl.reshape(n_seqs * seq_len, *lvl.shape[2:5]) for lvl in levels]

        codes = model.decoder.get_codes(levels)
        latent_chunks.append(codes[1].cpu())
        position_chunks.append(states.reshape(n_seqs * seq_len, -1))

latents = torch.cat(latent_chunks, dim=0).cpu().numpy()
positions = torch.cat(position_chunks, dim=0).cpu().numpy()


In [None]:
# Per-unit position histograms. For each of the 128 latent channels, keep the
# frames whose activation (averaged over axes 2 and 3 of the code maps) exceeds
# that channel's 85th percentile, and bin the matching positions (columns 0 and 1)
# on a 41 x 66 grid over [-22, 22] x [-30, 36].
#
# `positions` is already a NumPy array (built with `.numpy()` in the extraction
# cell), so `.numpy()` on its slices would raise AttributeError. np.histogram2d
# returns the same counts and edges as plt.hist2d without drawing 128 overlapping
# meshes onto the current axes. `tqdm` is already imported in the first cell.
# NOTE(review): 41 bins span a 44-wide range while 66 bins span a 66-wide range --
# confirm the x-bin count is intentional.
unit_means = np.mean(latents, axis=(2, 3))  # reduce once instead of twice per unit

histogram = []
for idx in tqdm(range(128)):
    activation = unit_means[:, idx]
    quant = np.quantile(activation, 0.85)
    units = positions[activation > quant]
    counts, xedges, yedges = np.histogram2d(
        units[:, 0], units[:, 1], bins=(41, 66), range=[[-22, 22], [-30, 36]]
    )
    histogram.append(counts)

histogram = np.stack(histogram, axis=0)  # (128, 41, 66)


In [None]:
from scipy.spatial import ConvexHull
fig, axes = plt.subplots(nrows=16, ncols=8, figsize=(20, 40))

# Initialise the accumulators here so the cell works on a fresh kernel and does
# not keep appending to stale lists when re-run.
approx_areas = []  # pi * sqrt(prod of cov singular values) per unit (1-sigma ellipse area)
areas = []         # number of non-empty histogram bins per unit

# Loop-invariant work, hoisted out of the per-unit loop.
unit_means = np.mean(latents, axis=(2, 3))
grid = np.mgrid[-22:22:0.1, -30:36:0.1]  # (2, 440, 660) evaluation points
dalpha = 0.9

for idx in range(128):
    ax = axes[idx // 8, idx % 8]

    # NOTE(review): this uses the 90th percentile while the histogram cell uses
    # the 85th -- confirm the mismatch is intentional.
    activation = unit_means[:, idx]
    quant = np.quantile(activation, 0.9)
    units = positions[activation > quant]
    cov = np.cov(units, rowvar=False)
    # `units` is already NumPy, so no `.numpy()`; broadcasting against `grid`
    # requires positions to have exactly 2 columns.
    mu = units.mean(axis=0).reshape(-1, 1, 1)
    approx_areas.append(np.multiply.reduce(np.sqrt(np.linalg.svd(cov)[1])) * np.pi)
    areas.append((histogram[idx] > 0).sum())

    # Bivariate normal density N(mu, cov) evaluated on the grid.
    centered = grid - mu
    gauss = 1 / (2 * np.pi) * np.linalg.det(cov) ** (-0.5) * np.exp(
        -0.5 * np.einsum(
            "ijk,ijk->jk",
            centered,
            np.einsum("ij,jkl", np.linalg.inv(cov), centered),
        )
    )

    # Occupied histogram bins (opaque where non-empty) overlaid with the fitted Gaussian.
    im = histogram[idx] > 0
    ax.imshow(im, cmap="Blues", alpha=im * dalpha, extent=[-30, 36, 22, -22])
    ax.imshow(gauss, cmap="Blues", alpha=0.6, extent=[-30, 36, 22, -22])
    ax.axis("off")



In [None]:
# Distribution of the per-unit Gaussian-fit areas, with a KDE overlay.
area_plot = sns.displot(data=approx_areas, kde=True, fill=True)
area_plot.ax.set_xlabel("Area (Gaussian approximation, lattice units)")


In [None]:
# For every environment block (histogram bin) covered by at least one unit, count
# how many latent units have a non-empty bin there; plot the counts sorted ascending.
mask = histogram.sum(axis=0) != 0  # bins covered by any unit
n_units = histogram.shape[0]
n_blocks = int(mask.sum())  # was hard-coded as 1381 for the original dataset
units_per_block = (histogram[:, mask] > 0).sum(axis=0)  # (n_blocks,)

plt.bar(np.arange(n_blocks), np.sort(units_per_block), width=1)
plt.ylim([0, n_units])
plt.yticks(np.linspace(0, n_units, 4).astype(np.int32))
plt.xticks(np.linspace(0, n_blocks, 4).astype(np.int32))
plt.ylabel("Number of\nactive latent units")
plt.xlabel("Environment block")


In [None]:
# For every latent unit, count how many of the covered environment blocks fall in
# its non-empty histogram bins; plot the counts sorted ascending.
mask = histogram.sum(axis=0) != 0  # bins covered by any unit
n_units = histogram.shape[0]
blocks_per_unit = (histogram[:, mask] > 0).sum(axis=1)  # (n_units,)

plt.bar(np.arange(n_units), np.sort(blocks_per_unit), width=1)

# NOTE(review): hand-picked zoom for this dataset; bars outside [200, 240] are clipped.
plt.ylim([200, 240])
plt.yticks(np.linspace(200, 240, 5).astype(np.int32))
plt.xticks(np.linspace(0, n_units, 4).astype(np.int32))
plt.ylabel("Number of\nactive environment\n blocks")
plt.xlabel("Latent unit")
