In [1]:
from pathlib import Path
import os
from glob import glob
import re

import torch
import torch.optim
import torch.nn.functional as F

import seaborn as sns

from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import matplotlib.pyplot as plt
import matplotlib

import numpy as np

from IPython.display import display, clear_output
from PIL import Image
from tqdm.auto import tqdm

from torchvision.io import write_video
from torchvision.transforms import ToTensor, Normalize

from predictive_coding.models.models import Autoencoder
from predictive_coding.dataset import EnvironmentDataset, collate_fn


# Autoencoding

In this Google Colab notebook, we apply a pre-trained autoencoding neural network to a dataset of observations recorded by an agent navigating a Minecraft environment. First, we load the autoencoder. Next, we import the validation dataset, which contains episodes of the agent moving through various Minecraft terrains. We then pass each observation through the network's encoder and extract latent codes from its decoder (`model.decoder.get_codes`). These latent vectors provide a condensed representation of the agent's visual input; later cells train a small readout network to predict the agent's position from them.


In [2]:
# Prefer the first GPU, but fall back to CPU so the model can still be loaded
# on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Initialize the autoencoding architecture; the hyper-parameters must match the
# ones the checkpoint was trained with, otherwise load_state_dict() will fail.
model = Autoencoder(in_channels=3, out_channels=3, layers=[2, 2, 2, 2])
model = model.to(device)

# map_location remaps tensors that were saved on another device onto `device`.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load('../weights/autoencoding.ckpt', map_location=device)
model.load_state_dict(ckpt)
model.eval()  # the model is only used for inference in this notebook
clear_output()


In [3]:
# Load the recorded validation episodes and iterate over them in a fixed order
# (shuffle=False), 64 episodes per batch, using the project's collate_fn.
dataset = EnvironmentDataset(Path("../datasets/val-dataset"))
loader_options = dict(
    batch_size=64,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=2,
    pin_memory=True,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)


In [None]:
# Run every validation frame through the autoencoder and keep element 1 of the
# decoder's get_codes() output as the latent vector for that frame, together
# with the matching state row. Both end up as flat (frames, ...) numpy arrays.
latents, positions = [], []
with torch.no_grad():
    for batch in tqdm(dataloader):
        images, actions, states = batch
        B, L, C, H, W = images.shape
        # Fold the time axis into the batch axis: (B, L, C, H, W) -> (B*L, C, H, W).
        images = images.to(device).reshape(B * L, C, H, W)
        states = states.reshape(B * L, -1)
        positions.append(states)

        codes = model.decoder.get_codes(model.encoder(images))
        latents.append(codes[1].cpu())

latents = torch.cat(latents, dim=0).cpu().numpy()
positions = torch.cat(positions, dim=0).cpu().numpy()


In [None]:
# Load the "ae-images" trajectories. Each sub-folder (named "<x>_<y>") holds up
# to N_FRAMES frames "{0..N_FRAMES-1}.png" plus a states.npy array; the first two
# columns of its last row are used as the trajectory's target position.
N_FRAMES = 10
IMG_SIZE = 64
# NOTE(review): these statistics look like 0-255 pixel scale, but ToTensor()
# already rescales pixels to [0, 1] -- confirm they match the training pipeline.
normalize = Normalize([121.6697, 149.3242, 154.9510], [40.7521, 47.7267, 103.2739])

# Sort so the sample order (and any later index-based subset) is reproducible
# across file systems; glob() order is arbitrary.
folds = sorted(glob('../datasets/ae-images/*'))
# 5-D (trajectories, frames, C, H, W): the encoding cell unpacks exactly this
# shape. Frames that are missing on disk stay all-zero.
images = torch.zeros(len(folds), N_FRAMES, 3, IMG_SIZE, IMG_SIZE)
positions = torch.zeros(len(folds), 2)
for idx, fold in enumerate(folds):
    # The folder name is "<x>_<y>"; x and y are only used in the warning below.
    x, y = Path(fold).name.rsplit('_', 1)
    positions[idx] = torch.from_numpy(np.load(f'{fold}/states.npy'))[-1, :2]

    for tidx in range(N_FRAMES):
        frame_path = f'{fold}/{tidx}.png'
        if not os.path.exists(frame_path):
            print(tidx, x, y)  # report the missing frame; its slot stays zero
            continue

        # `with` closes the file handle; convert('RGB') drops any alpha channel
        # so the tensor always has the 3 channels Normalize expects.
        with Image.open(frame_path) as image:
            images[idx, tidx] = normalize(ToTensor()(image.convert('RGB')))
    

In [None]:
# Encode the ae-images frames in chunks of `bsz` trajectories and keep element 1
# of the decoder's get_codes() output, reshaped back to (trajectory, frame, ...).
latents = []

model = model.to(device)

bsz = 100
with torch.no_grad():
    # range(0, n, bsz) never produces an empty trailing chunk (the former
    # `len(images) // bsz + 1` did whenever len was a multiple of bsz, and the
    # subsequent reshape of the empty batch then raised).
    for start in range(0, len(images), bsz):
        batch = images[start:start + bsz]
        B, L, C, H, W = batch.shape
        # Fold frames into the batch axis for the 2-D encoder.
        batch = batch.to(device).reshape(B * L, C, H, W)

        features = model.encoder(batch)
        codes = model.decoder.get_codes(features)
        latents.append(codes[1].cpu().reshape(B, L, -1, 8, 8))

latents = torch.cat(latents, dim=0)


In [None]:
from torch import nn
import torch
import torch.nn.functional as F

class Lambda(nn.Module):
    """Adapt an arbitrary callable into an ``nn.Module`` layer.

    Useful for dropping a one-off tensor transform into ``nn.Sequential``.
    The wrapped callable has no parameters of its own.
    """

    def __init__(self, func):
        super(Lambda, self).__init__()
        # Callable applied to the input in forward().
        self.func = func

    def forward(self, x):
        transform = self.func
        return transform(x)

# Small readout network that regresses a 2-D position from one latent code map.
# The first conv expects 128 channels, and MaxPool2d(2) turns an 8x8 map into
# 4x4, so the input is assumed to be (N, 128, 8, 8).
net = nn.Sequential(
    nn.Conv2d(128, 256, 3, padding=1),
    nn.MaxPool2d(2),
    # Flatten keeps the batch axis explicit; the former
    # Lambda(reshape(-1, 256*4*4)) would silently regroup rows if the spatial
    # size ever changed, and a lambda layer cannot be pickled. Module index 2
    # is unchanged, so state_dict keys stay the same.
    nn.Flatten(),
    nn.Linear(256 * 4 * 4, 64),
    nn.ReLU(),
    nn.Linear(64, 2),
)

net = net.to(device)

optimizer = torch.optim.AdamW(net.parameters(), lr=1e-4)
# The training loop steps this once per epoch: LR drops 10x every 5000 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5000, gamma=0.1)


In [None]:
# Train the readout to map the latent code of each trajectory's last frame
# (latents[:, -1]) to its position. Targets are divided by POS_SCALE so they are
# O(1) during training; the monitored loss is reported back in position units.
POS_SCALE = 30
batch_size = 512
eval_every = 100
for epoch in range(10000):
    # Shuffle, then drop the ragged tail so every step sees a full batch.
    batch_idx = np.random.permutation(len(latents))
    batch_idx = batch_idx[:len(batch_idx) // batch_size * batch_size].reshape(-1, batch_size)
    for it, idx in enumerate(batch_idx):
        optimizer.zero_grad()
        batch = latents[idx, -1].to(device)
        pos = positions[idx, :2].to(device) / POS_SCALE
        pred = net(batch)
        loss = F.mse_loss(pred, pos)
        loss.backward()
        optimizer.step()
        # Evaluate every `eval_every` steps (the former `if it % 100:` was
        # true on 99 of every 100 steps).
        if it % eval_every == 0:
            # NOTE(review): the first 1000 samples are also drawn for training,
            # so this monitors training loss, not held-out error.
            with torch.no_grad():
                pred = net(latents[:1000, -1].to(device)).cpu() * POS_SCALE
                print(F.mse_loss(pred, positions[:1000, :2]), end='\r')
    scheduler.step()
        

In [None]:
# Predict positions for every trajectory. ae_pos and ae_pred are kept in the
# normalised (divided by POS_SCALE) units used for training; `ae` is the
# per-sample Euclidean error in the original position units.
POS_SCALE = 30  # must match the normalisation used during training
with torch.no_grad():
    batch = latents[:, -1].to(device)
    ae_pos = positions[:, :2].cpu() / POS_SCALE
    ae_pred = net(batch).cpu()
    ae = torch.linalg.norm(ae_pred * POS_SCALE - positions[:, :2], dim=1).cpu().numpy()


In [None]:
# Reference error distributions for the histogram, drawn from a seeded generator
# so the figure is reproducible:
# - null: distance between two independent uniform integer points in [0, 30)^2,
#   1000 pairs (the "Random Pairs" chance-level reference).
# - rand: norm of a 2-D standard-normal offset (the "Noise Model, sigma=1"),
#   one sample per point of `grid`; only grid's shape is used in this cell.
rng = np.random.default_rng(0)
null = np.linalg.norm(rng.integers(30, size=(2, 1000)) - rng.integers(30, size=(2, 1000)), axis=0)
grid = np.stack(np.mgrid[-20:20, -30:35])
rand = np.linalg.norm(rng.standard_normal(grid.shape), axis=0).reshape(-1)


In [None]:
from matplotlib.lines import Line2D

# Scatter actual vs. predicted positions, with a grey segment joining each pair.
# The horizontal axis shows -position[:, 1] and the vertical axis position[:, 0].
# ae_pos / ae_pred are in the normalised training units, so they are scaled
# back by POS_SCALE to match the "lattice units" axis labels.
POS_SCALE = 30      # must match the training normalisation
marker_scale = 1    # marker-size multiplier (previously an undefined `scale`)
actual = np.asarray(ae_pos) * POS_SCALE
predicted = np.asarray(ae_pred) * POS_SCALE

fig, ax = plt.subplots()
# A single plot() call with 2xN arrays draws N independent error segments.
ax.plot(np.stack([-predicted[:, 1], -actual[:, 1]]),
        np.stack([predicted[:, 0], actual[:, 0]]),
        alpha=0.5, color=plt.cm.Greys(0.3), zorder=-1)
# `color=` (not `c=`) so a single RGBA tuple is never mistaken for per-point values.
ax.scatter(-actual[:, 1], actual[:, 0], s=1 * marker_scale, color=plt.cm.Greys(0.6), label='Actual')
ax.scatter(-predicted[:, 1], predicted[:, 0], s=1 * marker_scale, color=plt.cm.Blues(0.75), label='Predicted')
ax.set_aspect("equal")
ax.set_xlabel("x (lattice units)")
ax.set_ylabel("y (lattice units)")
ax.set_title("Positions decoded from autoencoder latents")
# Proxy artists so the legend shows clean markers and an "Error" line swatch.
ax.legend([Line2D([0], [0], marker='o', color=plt.cm.Greys(0.6), markersize=2, linestyle="None"),
           Line2D([0], [0], marker='o', color=plt.cm.Blues(0.75), markersize=2, linestyle="None"),
           Line2D([0], [0], color=plt.cm.Greys(0.5), lw=2.0)], ['Actual', 'Predicted', 'Error'],
          loc='upper right', bbox_to_anchor=(1.1, 1.0))
plt.show()


In [None]:
import seaborn as sns
from matplotlib.lines import Line2D

# Distribution of decoding errors for the autoencoder readout (`ae`, computed in
# the evaluation cell), compared with the random-pair and Gaussian-noise
# baselines. Raw strings keep the LaTeX backslashes from being parsed as escapes.
fig, ax = plt.subplots()
sns.histplot(ae, kde=True, stat="density", label="Autoencoding", ax=ax)
sns.kdeplot(null, fill=True, label="Random Pairs", color=plt.cm.Dark2(0), ax=ax)
sns.kdeplot(rand, fill=True, label=r"Noise Model ($\sigma=1$)", color=plt.cm.Dark2(1), ax=ax)
ax.set_xlabel(r"Error ($\Vert x - \hat{x}(z) \Vert_{\ell_2}$) (lattice units)")
ax.set_ylabel("Density")
ax.legend()
plt.show()
