In [6]:
from __future__ import print_function


import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
sys.path.append('../')


from six.moves import xrange


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid



In [71]:
class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        #inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        print('flat_input.shape', flat_input.shape)
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            
            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)
            
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss
        
        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized, perplexity, encodings

In [72]:
from src.networks.cnns import MLP
class ActionVQVAE(nn.Module):
    def __init__(self,embed_dim, num_embeddings, commitment_cost, decay, epsilon=1e-5):
        super(ActionVQVAE, self).__init__()
        self.encoder = MLP(2, embed_dim, [64, 64, 64])
        self.vq = VectorQuantizerEMA(num_embeddings, embed_dim, commitment_cost, decay, epsilon)
        self.decoder = MLP(embed_dim, 2, [64, 64, 64])
    
    def forward(self, x):
        x = self.encoder(x)
        loss, quantized, perplexity, _ = self.vq(x)
        x_recons = self.decoder(quantized)
        return loss, x_recons, quantized, perplexity
    
    def compute_loss(self, x):
        loss, x_recons, perplexity = self(x)
        mse_loss = F.mse_loss(x_recons, x)
        loss = loss + mse_loss
        
        loss_dict = {'loss': loss, 'mse_loss': mse_loss, 'vq_loss': loss, 'perplexity': perplexity}
        return loss, loss_dict



In [73]:
from collections import defaultdict
from progress.bar import Bar
import einops 
def train_epoch(model,optimizer, dataloader,device):
    model.train()

    total_loss = 0
    epoch_loss_dict = defaultdict(float)
    for _, batch in enumerate(dataloader):
        
        optimizer.zero_grad()
        action = batch[1].to(device)
        action = einops.rearrange(action, 'b t (n a) -> (b t n) a', a=2)
        loss, loss_dict = model.compute_loss(action)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
        optimizer.step()
        for k in loss_dict:
            epoch_loss_dict[k] += float(loss_dict[k])
        total_loss += float(loss)
        
    for k in epoch_loss_dict:
        epoch_loss_dict[k] /= len(dataloader)
    return total_loss/len(dataloader), epoch_loss_dict

def val_epoch(model,dataloader,device):
    model.eval()
    
    epoch_loss_dict = defaultdict(float)
    total_loss = 0
    with torch.no_grad():
        for _, batch in enumerate(dataloader):
            action = batch[1].to(device)
            action = einops.rearrange(action, 'b t (n a) -> (b t n) a', a=2)
            loss,  loss_dict = model.compute_loss(action)
            for k in loss_dict:
                epoch_loss_dict[k] += float(loss_dict[k])
            total_loss += float(loss)
            
    for k in epoch_loss_dict:
        epoch_loss_dict[k] /= len(dataloader)

    return  total_loss/len(dataloader), epoch_loss_dict
                


def train(model, optimizer, train_dataloader, val_dataloader, num_epochs,device,scheduler):
    info_bar = Bar('Training', max=num_epochs)
    min_val_loss = 100000
    
    for epoch in range(num_epochs):
        train_loss, train_loss_dict = train_epoch(model,optimizer, train_dataloader,device)
        val_loss, val_loss_dict = val_epoch(model,val_dataloader,device)   
        
        scheduler.step()
        
        short_epoch_info = "Epoch: {},  train Loss: {}, Val Loss: {}".format(epoch,train_loss,val_loss )   
        epoch_info = f"Epoch: {epoch}, TRAIN: "
        for k in train_loss_dict:
            epoch_info += f"{k}: {train_loss_dict[k]:.5f}, "
        epoch_info += "VAL: "
        for k in val_loss_dict:
            epoch_info += f"{k}: {val_loss_dict[k]:.5f}, "
        #epoch_info = f"Epoch: {epoch},TRAIN : DYN Loss: {train_dyn_loss} VAE LOSS: {train_vae_loss}  INV LOSS: {train_inv_loss}||   VAL : DYN Loss: {val_dyn_loss} VAE LOSS: {val_vae_loss} INV LOSS: {val_inv_loss}"
        print(epoch_info)
        #logger.info(epoch_info)
        torch.save(model.state_dict(), f"last_dynamics.pt")
        if min_val_loss > val_loss:
            min_val_loss = val_loss
            torch.save(model.state_dict(), f"best_val_dynamics.pt")

        Bar.suffix = short_epoch_info
        info_bar.next()
    info_bar.finish()
    return model


In [74]:

from src.datasets import SequenceImageTransitionDataset
data_path = "/cluster/home/gboeshertz/patch_rl/data/visual_150transitions_4_all_sprite_mover_True_4instantmoveno_targets.npz"
dataset = SequenceImageTransitionDataset(data_path=data_path,sequence_length=2,)

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*0.8), len(dataset)-int(len(dataset)*0.8)])
train_dataloader = DataLoader(train_dataset,batch_size=128,shuffle=True,num_workers=1)
val_dataloader = DataLoader(val_dataset,batch_size=128,shuffle=False,num_workers=1)


model = ActionVQVAE(2, 16, 0.25, 0.99)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)




torch.Size([150, 1, 128, 128, 3])
Creating MLP with input size 2 and output size 2
Creating MLP with layer sizes [64, 64, 64]
Creating MLP with input size 2 and output size 2
Creating MLP with layer sizes [64, 64, 64]


Epoch: 0, TRAIN: loss: 0.32786, mse_loss: 0.32777, vq_loss: 0.32786, perplexity: 1.00000, VAL: loss: 0.32050, mse_loss: 0.31877, vq_loss: 0.32050, perplexity: 1.00000, 
Epoch: 1, TRAIN: loss: 0.32485, mse_loss: 0.32318, vq_loss: 0.32485, perplexity: 1.00000, VAL: loss: 0.31642, mse_loss: 0.31604, vq_loss: 0.31642, perplexity: 1.00000, 
Epoch: 2, TRAIN: loss: 0.31939, mse_loss: 0.31901, vq_loss: 0.31939, perplexity: 1.00000, VAL: loss: 0.31427, mse_loss: 0.31402, vq_loss: 0.31427, perplexity: 1.00000, 
Epoch: 3, TRAIN: loss: 0.31590, mse_loss: 0.31565, vq_loss: 0.31590, perplexity: 1.00000, VAL: loss: 0.31282, mse_loss: 0.31253, vq_loss: 0.31282, perplexity: 1.00000, 
Epoch: 4, TRAIN: loss: 0.31310, mse_loss: 0.31278, vq_loss: 0.31310, perplexity: 1.00000, VAL: loss: 0.31192, mse_loss: 0.31155, vq_loss: 0.31192, perplexity: 1.00000, 
Epoch: 5, TRAIN: loss: 0.31085, mse_loss: 0.31045, vq_loss: 0.31085, perplexity: 1.00000, VAL: loss: 0.31150, mse_loss: 0.31104, vq_loss: 0.31150, perplexi

ActionVQVAE(
  (encoder): MLP(
    (layers): Sequential(
      (0): Linear(in_features=2, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=64, bias=True)
      (5): ReLU()
      (6): Linear(in_features=64, out_features=2, bias=True)
    )
  )
  (vq): VectorQuantizerEMA(
    (_embedding): Embedding(16, 2)
  )
  (decoder): MLP(
    (layers): Sequential(
      (0): Linear(in_features=2, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=64, bias=True)
      (5): ReLU()
      (6): Linear(in_features=64, out_features=2, bias=True)
    )
  )
)

In [76]:
loss, recons,quant,_ = model(torch.randn(1,1,2))

flat_input.shape torch.Size([1, 2])


In [78]:
quant

tensor([[[-102.6851,   39.7359]]], grad_fn=<AddBackward0>)

In [None]:
train(model, optimizer, train_dataloader, val_dataloader, 10,"cpu",scheduler)

In [85]:
class ActionDiscretizer():
    def __init__(self,num_actions, num_discrete_bins) -> None:
        self.num_actions = num_actions
        self.num_discrete_bins = num_discrete_bins
        self.bins = torch.linspace(-1,1,self.num_discrete_bins)
    
    def discretize(self,action):
        action = action.reshape(-1,self.num_actions)
        action = torch.bucketize(action,self.bins)
        return action

In [86]:
act_disc = ActionDiscretizer(2,16)

In [87]:
act = torch.randn(1,2)
print(act)
print(act_disc.discretize(act))

tensor([[10,  0]])