In [1]:
import os
import sys
# Resolve the parent of the current working directory to an absolute path and
# register it on the import search path (once), so imports can be resolved
# relative to that directory.
module_path = os.path.abspath(os.pardir)
already_registered = module_path in sys.path
if not already_registered:
    sys.path += [module_path]


In [2]:
from os.path import join
from pathlib import Path

import matplotlib.pylab as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid

from customLoader import CustomMinecraftData
from models.VQVAE import VectorQuantizerEMA, Encoder, Decoder


In [3]:
class VQVAE(pl.LightningModule):
    """Vector-quantised autoencoder for 3-channel images as a Lightning module.

    The pipeline is ``Encoder -> vector quantiser -> Decoder``. The module also
    owns its optimiser and its train/validation dataloaders.

    Args:
        num_hiddens: Channel width passed to the encoder and decoder.
        num_residual_layers: Number of residual layers in encoder/decoder.
        num_residual_hiddens: Hidden width of each residual layer.
        num_embeddings: Number of codebook entries.
        embedding_dim: Dimensionality of one codebook entry.
        commitment_cost: Commitment weight forwarded to the quantiser.
        decay: EMA decay; a value ``> 0`` selects ``VectorQuantizerEMA``,
            otherwise ``VectorQuantizer`` is used.
        batch_size: Batch size of both dataloaders and of the example input.
        lr: Adam learning rate.
        split: Split value forwarded to ``CustomMinecraftData``.
        img_size: Side length of the square example input.
    """

    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0,
                 batch_size=256, lr=0.001, split=0.95, img_size=64):
        super().__init__()

        self.batch_size = batch_size
        self.lr = lr
        self.split = split

        self._encoder = Encoder(3, num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)
        # There is deliberately no 1x1 "pre-VQ" convolution: the encoder output
        # is handed to the quantiser as-is (see `forward`).
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim,
                                              commitment_cost, decay)
        else:
            # Imported lazily so the notebook's import cell is unaffected when
            # this branch is not taken.
            # NOTE(review): assumes VectorQuantizer lives next to
            # VectorQuantizerEMA in models.VQVAE -- confirm.
            from models.VQVAE import VectorQuantizer
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
        self._decoder = Decoder(num_hiddens,
                                num_hiddens,
                                num_residual_layers,
                                num_residual_hiddens)

        self.example_input_array = torch.rand(batch_size, 3, img_size, img_size)

        # Normalize with mean 0.5 / std 1.0 shifts every channel by -0.5 and
        # leaves the scale untouched.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (1.0, 1.0, 1.0)),
        ])

    def forward(self, x):
        """Encode, quantise and decode ``x``.

        Returns:
            Tuple ``(vq_loss, x_recon, perplexity)``; the fourth value returned
            by the quantiser is discarded.
        """
        z = self._encoder(x)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)

        return loss, x_recon, perplexity

    def _shared_step(self, batch, stage):
        """Compute the total loss for ``batch`` and log it under ``stage``.

        The total loss is the MSE reconstruction error plus the quantiser loss.
        Loss and perplexity are logged once per epoch as ``loss/<stage>`` and
        ``perplexity/<stage>``.

        Returns:
            Tuple ``(loss, data_recon)``.
        """
        vq_loss, data_recon, perplexity = self(batch)
        recon_error = F.mse_loss(data_recon, batch)
        loss = recon_error + vq_loss

        self.log(f'loss/{stage}', loss, on_step=False, on_epoch=True)
        self.log(f'perplexity/{stage}', perplexity, on_step=False, on_epoch=True)

        return loss, data_recon

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch."""
        loss, _ = self._shared_step(batch, 'train')
        return loss

    def validation_step(self, batch, batch_idx):
        """Return the validation loss; log a reconstruction grid for batch 0."""
        loss, data_recon = self._shared_step(batch, 'val')

        if batch_idx == 0:
            # Only the first (up to) 64 reconstructions of the first batch are
            # uploaded, as one channels-last image grid.
            # NOTE(review): `wandb` is not imported in the visible notebook
            # cells; it must be in scope before validation is run.
            grid = make_grid(data_recon[:64].detach().cpu())
            grid = grid.permute(1, 2, 0)
            self.logger.experiment.log({"Images": [wandb.Image(grid.numpy())]})

        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a weight decay of 1e-5."""
        return torch.optim.Adam(params=self.parameters(), lr=self.lr, weight_decay=1e-5)

    def _build_dataloader(self, subset, shuffle):
        """Create a dataloader over the ``subset`` part of 'CustomTrajectories1'."""
        dataset = CustomMinecraftData('CustomTrajectories1', subset, self.split,
                                      transform=self.transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=2)

    def train_dataloader(self):
        """Shuffled dataloader over the 'train' subset."""
        return self._build_dataloader('train', shuffle=True)

    def val_dataloader(self):
        """Unshuffled dataloader over the 'val' subset."""
        return self._build_dataloader('val', shuffle=False)

    def get_centroids(self, idx):
        """Decode the codebook entries selected by ``idx``.

        Args:
            idx: A single integer index or a sequence/tensor of indices into
                the codebook.

        Returns:
            The decoder output for the selected entries, one batch item per
            index.

        Note:
            Each entry is reshaped to a 2x2 map with 64 channels, so this only
            works when ``embedding_dim == 2 * 2 * 64``.
        """
        codebook = self._vq_vae._embedding.weight.detach()
        # Build the index on the codebook's own device (instead of forcing
        # CUDA) and make it 1-D, as `index_select` expects.
        z_idx = torch.as_tensor(idx, dtype=torch.long,
                                device=codebook.device).reshape(-1)
        embeddings = torch.index_select(codebook, dim=0, index=z_idx)
        # Channels-last (N, 2, 2, 64) -> channels-first (N, 64, 2, 2).
        embeddings = embeddings.view((z_idx.numel(), 2, 2, 64))
        embeddings = embeddings.permute(0, 3, 1, 2).contiguous()

        return self._decoder(embeddings)

    def save_encoding_indices(self, x):
        """Return the encoding indices the quantiser assigns to ``x``.

        Mirrors `forward`: the encoder output goes straight into the quantiser
        (there is no pre-VQ convolution on this model), and only the fourth
        value returned by the quantiser is kept.
        """
        z = self._encoder(x)
        _, _, _, encoding_indices = self._vq_vae(z)
        return encoding_indices

In [4]:
# Keyword arguments for the VQVAE constructor, grouped by purpose.
conf = dict(
    # data / optimisation
    split=0.95,
    lr=0.001,
    batch_size=256,
    # encoder / decoder widths
    num_hiddens=64,
    num_residual_hiddens=32,
    num_residual_layers=2,
    # codebook: 10 entries of 256 values each (256 == 2 * 2 * 64, the shape
    # VQVAE.get_centroids reshapes an entry to)
    embedding_dim=256,
    num_embeddings=10,
    commitment_cost=0.25,
    # decay > 0 selects the EMA quantiser
    decay=0.99,
)

In [5]:
# Build the model from `conf` and switch it to inference mode. The GPU is used
# when one is available; otherwise the model stays on the CPU instead of
# crashing on a hard-coded `.cuda()` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vqvae = VQVAE(**conf).to(device)
vqvae.eval()

VQVAE(
  (_encoder): Encoder(
    (_conv_1): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_3): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_4): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_conv_5): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (_residual_stack): ResidualStack(
      (_layers): ModuleList(
        (0): Residual(
          (_block): Sequential(
            (0): ReLU(inplace=True)
            (1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (2): ReLU(inplace=True)
            (3): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          )
        )
        (1): Residual(
          (_block): Sequential(
            (0): ReLU(inplace=True)
            (1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bi

In [6]:
# Show the codebook of the freshly constructed model, i.e. before any
# checkpoint weights have been loaded into it.
codebook = vqvae._vq_vae._embedding
codebook.weight

Parameter containing:
tensor([[ 0.4548, -1.6205,  0.9078,  ...,  2.1332, -2.9290,  0.2161],
        [-0.1996,  0.4195, -0.6094,  ...,  0.7398, -0.9086,  0.3735],
        [-2.0932,  2.0800, -1.5830,  ...,  0.3970, -2.2560,  0.7920],
        ...,
        [ 1.7407,  0.2839,  0.6586,  ..., -2.5021, -0.2781,  0.6295],
        [ 0.5438,  1.0325,  0.5431,  ..., -1.2149,  0.4331,  0.1137],
        [-0.1712, -0.0969, -0.1637,  ..., -0.6019, -0.5315,  0.1484]],
       device='cuda:0', requires_grad=True)

In [7]:
# Checkpoints of the available training runs, keyed by run name. Exactly one
# of them is selected below instead of overwriting `path` several times.
CHECKPOINTS = {
    'vqvae_0.2': '../results/vqvae_0.2/mineRL/y77fc26u/checkpoints/epoch=808-step=61483.ckpt',
    'vqvae_0.1': '../results/vqvae_0.1/mineRL/2wgoga4p/checkpoints/epoch=833-step=62549.ckpt',
    'vqvae_0.3': '../results/vqvae_0.3/mineRL/1c4o6jgy/checkpoints/epoch=499-step=37999.ckpt',
}

path = CHECKPOINTS['vqvae_0.3']
# Deserialise onto the CPU so loading does not depend on the device the
# checkpoint was saved from; `load_state_dict` copies the tensors onto the
# model's own device afterwards.
checkpoint = torch.load(path, map_location='cpu')


In [8]:
# Replace the model's parameters with the weights stored under the
# checkpoint's 'state_dict' key; the key-matching report returned by the call
# is shown as the cell output.
trained_state = checkpoint['state_dict']
vqvae.load_state_dict(trained_state)

<All keys matched successfully>

In [9]:
# Show the codebook again: after `load_state_dict` it holds the checkpoint's
# values rather than the ones the model was constructed with.
codebook = vqvae._vq_vae._embedding
codebook.weight

Parameter containing:
tensor([[0.0807, 0.1703, 0.1083,  ..., 0.0195, 0.4289, 0.1417],
        [0.1158, 0.2472, 0.3529,  ..., 1.0685, 0.2352, 0.1642],
        [0.0281, 1.5890, 1.1100,  ..., 2.0561, 0.2099, 0.2285],
        ...,
        [0.0887, 0.4713, 0.4223,  ..., 0.3388, 0.3095, 0.1533],
        [0.0366, 0.5491, 0.3956,  ..., 0.7268, 0.3135, 0.1216],
        [0.0397, 1.0619, 0.7591,  ..., 1.2831, 0.2222, 0.1758]],
       device='cuda:0', requires_grad=True)

In [11]:
# Decode every codebook entry into an image and save it as a PNG.
centroid_dir = Path("../goal_states/flat_biome_vqvae")
# `plt.imsave` does not create missing directories.
centroid_dir.mkdir(parents=True, exist_ok=True)

# Pure inference: no autograd graph is needed for decoding.
with torch.no_grad():
    for i in range(conf['num_embeddings']):
        out = vqvae.get_centroids(i)
        # (1, C, H, W) -> (H, W, C) NumPy array for matplotlib.
        img = out.squeeze().permute(1, 2, 0).detach().cpu().numpy()
        # Undo the -0.5 shift applied by the model's input transform, then
        # clamp BOTH ends: float RGB data outside [0, 1] is rejected by
        # matplotlib, and the decoder output is not bounded.
        img = np.clip(img + 0.5, 0.0, 1.0)
        plt.imsave(centroid_dir / f"centroid_{i}.png", img)