In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Hyper-parameters shared by the rest of the notebook.
epochs = 2                                  # number of training epochs
z_size, n_hidden, n_gaussians = 32, 256, 5  # latent width, LSTM hidden width, mixture components
batch_size = 10                             # samples per DataLoader batch

In [None]:
from MemoryDataset import MemoryCellDataset
from vision_module import *

In [None]:
# Load the pickled VAE object (the file holds a whole module, not just a
# state dict, despite the variable name) onto the CPU and display its layers.
# NOTE(review): the star import presumably makes the VAE class definitions
# available for unpickling, so it is kept ahead of torch.load -- confirm.
from vision_module import *
state_dict = '/content/drive/MyDrive/weights/vae(1).torch'
vae = torch.load(state_dict, map_location='cpu')
vae

VAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): ReLU()
    (8): Flatten()
  )
  (fc1): Linear(in_features=1024, out_features=32, bias=True)
  (fc2): Linear(in_features=1024, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=1024, bias=True)
  (decoder): Sequential(
    (0): UnFlatten()
    (1): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
    (2): ReLU()
    (3): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
    (4): ReLU()
    (5): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))
    (6): ReLU()
    (7): ConvTranspose2d(32, 1, kernel_size=(6, 6), stride=(2, 2))
    (8): Sigmoid()
  )
)

In [None]:
# Locations of the recorded rollouts and of the pre-trained VAE checkpoint.
img_path = './drive/MyDrive/data_rollouts2/CarRacing_random_with_act'
state_dict = './drive/MyDrive/weights/vae(1).torch'

# Dataset built from the rollout CSV, the image folder and the VAE checkpoint.
dataset = MemoryCellDataset(
    state_dict_path=state_dict,
    csv_path='./data_memory_cell.csv',
    img_data_path=img_path,
)

# Batches are served in dataset order (no shuffling is requested).
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size)

./drive/MyDrive/weights/vae(1).torch


In [None]:
# Fetch the first batch from the dataloader and display it as a sanity check.
next(iter(dataloader))

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
device

device(type='cuda')

In [None]:
class MDNRNN(nn.Module):
  """LSTM whose outputs are mapped to the parameters of a Gaussian mixture.

  The LSTM consumes sequences of width ``z_size + act_size`` (batch-first).
  Three linear heads turn every LSTM output into mixture weights, means and
  standard deviations, each reshaped to
  ``(batch, rollout_length, n_gaussians, z_size + act_size)``.

  NOTE(review): the heads predict all ``z_size + act_size`` dimensions, so the
  targets are presumably of that width as well -- confirm against the dataset.

  Args:
    z_size: width of the latent vector part of the input.
    act_size: width of the action part of the input.
    n_hidden: LSTM hidden size.
    n_gaussians: number of mixture components.
    n_layers: number of stacked LSTM layers.
  """

  def __init__(self, z_size=32, act_size=3, n_hidden=256, n_gaussians=5, n_layers=1):
    super(MDNRNN, self).__init__()

    self.z_size = z_size
    self.act_size = act_size  # action vector: [steering, throttle, brake]
    self.input_size = z_size + act_size
    self.n_gaussians = n_gaussians
    self.n_layers = n_layers
    self.n_hidden = n_hidden
    self.lstm = nn.LSTM(self.input_size, n_hidden, n_layers, batch_first=True)
    # One head per mixture parameter: fc1 -> pi, fc2 -> mu, fc3 -> sigma.
    self.fc1 = nn.Linear(n_hidden, n_gaussians*self.input_size)
    self.fc2 = nn.Linear(n_hidden, n_gaussians*self.input_size)
    self.fc3 = nn.Linear(n_hidden, n_gaussians*self.input_size)

  def get_mixture_coef(self, y):
    """Convert LSTM outputs into mixture parameters.

    Args:
      y: LSTM output; its second dimension is used as the rollout length.

    Returns:
      ``(pi, mu, sigma)``, each shaped
      ``(-1, rollout_length, n_gaussians, input_size)``. ``pi`` is normalised
      with a softmax over the component axis and ``sigma`` is made positive
      with ``exp``.
    """
    rollout_length = y.size(1)
    pi, mu, sigma = self.fc1(y), self.fc2(y), self.fc3(y)

    mixture_shape = (-1, rollout_length, self.n_gaussians, self.input_size)
    pi = pi.view(*mixture_shape)
    mu = mu.view(*mixture_shape)
    sigma = sigma.view(*mixture_shape)

    pi = F.softmax(pi, 2)     # weights sum to one across the components
    sigma = torch.exp(sigma)  # raw head output is treated as log-sigma
    return pi, mu, sigma

  def forward(self, x, h=None):
    """Run the LSTM over ``x`` and return mixture parameters and new state.

    Args:
      x: batch-first input sequence of width ``input_size``.
      h: ``(hidden, cell)`` state pair; ``None`` lets the LSTM start from
        zeros.

    Returns:
      ``((pi, mu, sigma), (h, c))``.
    """
    y, (h, c) = self.lstm(x, h)
    pi, mu, sigma = self.get_mixture_coef(y)
    return (pi, mu, sigma), (h, c)

  def init_hidden(self, batch):
    """Return a zero ``(hidden, cell)`` state pair for ``batch`` sequences.

    The tensors are created on the device (and with the dtype) of the model's
    own parameters, so the method does not rely on any notebook-level
    ``device`` variable and keeps working after the model is moved.
    """
    reference = next(self.parameters())
    shape = (self.n_layers, batch, self.n_hidden)
    return (torch.zeros(shape, dtype=reference.dtype, device=reference.device),
            torch.zeros(shape, dtype=reference.dtype, device=reference.device))


In [None]:
# Build the network from the notebook's hyper-parameters (instead of relying
# on the class defaults happening to match them) and move it to the device.
model = MDNRNN(z_size=z_size, n_hidden=n_hidden, n_gaussians=n_gaussians).to(device)
# Adam with its default settings over all model parameters.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
def weighted_logsumexp(x, w, dim=None, keepdim=False):
    """Compute ``log(sum(w * exp(x)))`` along ``dim`` in a stable way.

    The maximum of ``x`` along ``dim`` is subtracted before exponentiating
    and added back afterwards.

    Args:
        x: tensor of log-values.
        w: weights multiplied with ``exp(x)``; must broadcast against ``x``.
        dim: dimension to reduce. ``None`` reduces over every element.
        keepdim: keep the reduced dimension with size one.

    Returns:
        The reduced tensor. Where the maximum along ``dim`` is ``+inf`` or
        ``-inf`` that maximum is returned directly.
    """
    if dim is None:
        # Flatten x and w together so the weights stay aligned with their
        # values; reshape (unlike view) also accepts non-contiguous input.
        if torch.is_tensor(w):
            x, w = torch.broadcast_tensors(x, w)
            w = w.reshape(-1)
        x, dim = x.reshape(-1), 0
    x_max, _ = torch.max(x, dim, keepdim=True)
    x = torch.where( # to prevent nasty nan's
        (x_max == float('inf')) | (x_max == float('-inf')),
        x_max,
        x_max + torch.log(torch.sum(torch.exp(x - x_max)*w, dim, keepdim=True)))

    return x if keepdim else x.squeeze(dim)


def mdn_loss_fn(y, pi, mu, sigma):
  """Mean negative log-likelihood of ``y`` under a Gaussian mixture.

  Each component is ``Normal(mu, sigma)``; the per-component log-densities
  are combined with the weights ``pi`` along dimension 2 and the negated
  result is averaged over all remaining elements.
  """
  component_log_prob = torch.distributions.Normal(loc=mu, scale=sigma).log_prob(y)
  mixture_log_likelihood = weighted_logsumexp(component_log_prob, pi, dim=2)
  return (-mixture_log_likelihood).mean()

def criterion(y, pi, mu, sigma):
  """Mixture-density loss for targets that lack the component axis.

  A size-one axis is inserted at position 2 of ``y`` so that it broadcasts
  against the per-component parameters before ``mdn_loss_fn`` is applied.
  """
  return mdn_loss_fn(y.unsqueeze(2), pi, mu, sigma)

def detach(states):
  """Return a list holding a detached version of every tensor in ``states``."""
  detached_states = []
  for tensor in states:
    detached_states.append(tensor.detach())
  return detached_states

In [None]:
# Training loop. The recurrent state is carried from one batch to the next
# and detached each step so back-propagation stays within a single batch.
# NOTE(review): carrying the state assumes consecutive batches continue the
# same sequences -- confirm against MemoryCellDataset.
epochs = 30     # overrides the value set in the hyper-parameter cell
log_every = 50  # print the loss every `log_every` iterations to avoid flooding the output
for epoch in range(epochs):

  hidden = model.init_hidden(batch_size)
  for i, data in enumerate(dataloader):
    inputs = data['x'].to(device)
    targets = data['y'].to(device)

    if hidden[0].size(1) != inputs.size(0):
      # The final batch of an epoch can be smaller than batch_size; the
      # carried state would not match it, so start from a fresh zero state.
      hidden = model.init_hidden(inputs.size(0))
    else:
      hidden = detach(hidden)
    (pi, mu, sigma), hidden = model(inputs, hidden)
    loss = criterion(targets, pi, mu, sigma)

    model.zero_grad()
    loss.backward()
    clip_grad_norm_(model.parameters(), 0.5)  # cap the gradient norm at 0.5
    optimizer.step()
    if i % log_every == 0:
      print(f'Epoch [{epoch}/{epochs}], Iter[{i}] -- Loss: {loss.item()}')

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [0/30], Iter[158] -- Loss: 0.9990720748901367
Epoch [0/30], Iter[159] -- Loss: 1.2068008184432983
Epoch [0/30], Iter[160] -- Loss: 0.25466230511665344
Epoch [0/30], Iter[161] -- Loss: 0.48139524459838867
Epoch [0/30], Iter[162] -- Loss: 0.9303758144378662
Epoch [0/30], Iter[163] -- Loss: 0.9023280739784241
Epoch [0/30], Iter[164] -- Loss: 0.2310640960931778
Epoch [0/30], Iter[165] -- Loss: 0.07966527342796326
Epoch [0/30], Iter[166] -- Loss: 0.10623180866241455
Epoch [0/30], Iter[167] -- Loss: 0.49160680174827576
Epoch [0/30], Iter[168] -- Loss: 0.7389270067214966
Epoch [0/30], Iter[169] -- Loss: 0.5639302134513855
Epoch [0/30], Iter[170] -- Loss: 0.19536060094833374
Epoch [0/30], Iter[171] -- Loss: 0.5817736983299255
Epoch [0/30], Iter[172] -- Loss: 0.025737015530467033
Epoch [0/30], Iter[173] -- Loss: 0.23947428166866302
Epoch [0/30], Iter[174] -- Loss: 0.37199166417121887
Epoch [0/30], Iter[175] -- Loss: 0.420159

KeyboardInterrupt: ignored

In [None]:
# Serialise the whole model object to disk and hand the file to the browser
# for download through the Colab helper.
from google.colab import files

checkpoint_name = 'rnn-mdn_14_ep.torch'
torch.save(model, checkpoint_name)
files.download(checkpoint_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>