In [None]:
import torch
print('Version', torch.__version__)
print('CUDA enabled:', torch.cuda.is_available())

Version 1.7.0+cu101
CUDA enabled: True


In [None]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [None]:
import os
BASE_PATH = '/gdrive/My Drive/colab_files/Final Project/'
os.chdir(BASE_PATH)

In [None]:
import numpy as np

In [None]:
import torch.nn as nn
import torchvision
import torch.nn.functional as f
from torch.distributions.normal import Normal

In [None]:
# Select the compute device: CUDA only when it is both requested and available,
# otherwise fall back to the CPU.
USE_CUDA = True
use_cuda = USE_CUDA and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Using device', device)
import multiprocessing
# Number of CPU cores, printed for reference.
# NOTE(review): the DataLoaders built further down do not pass num_workers,
# so this value is currently informational only.
NUM_WORKERS = multiprocessing.cpu_count()
print('num workers:', NUM_WORKERS)

Using device cuda
num workers: 2


In [None]:
def gmm_loss(batch, mus, sigmas, logpi, reduce=True): # pylint: disable=too-many-arguments
    """ Computes the Gaussian Mixture Model (GMM) loss.
    Compute minus the log probability of batch under the GMM model described
    by mus, sigmas, pi. Precisely, with bs1, bs2, ... the sizes of the batch
    dimensions (several batch dimensions are useful when you have both a batch
    axis and a time step axis), gs the number of mixtures and fs the number of
    features.
    :args batch: (bs1, bs2, *, fs) torch tensor
    :args mus: (bs1, bs2, *, gs, fs) torch tensor
    :args sigmas: (bs1, bs2, *, gs, fs) torch tensor
    :args logpi: (bs1, bs2, *, gs) torch tensor
    :args reduce: if not reduce, the mean in the following formula is omitted
        and a (bs1, bs2, *) tensor of per-sample losses is returned
    :returns:
    loss(batch) = - mean_{i1=0..bs1, i2=0..bs2, ...} log(
        sum_{k=1..gs} pi[i1, i2, ..., k] * N(
            batch[i1, i2, ..., :] | mus[i1, i2, ..., k, :], sigmas[i1, i2, ..., k, :]))
    NOTE: The loss is not reduced along the feature dimension (i.e. it should scale ~linearly
    with fs).
    """
    # Insert a mixture axis so the targets broadcast against the
    # (..., gs, fs) mixture parameters.
    batch = batch.unsqueeze(-2)
    # Log density of every feature under every component, summed over the
    # feature axis and weighted (in log space) by the mixture coefficients.
    g_log_probs = logpi + Normal(mus, sigmas).log_prob(batch).sum(dim=-1)
    # Numerically stable log(sum_k exp(.)) over the mixture axis only. Reducing
    # exactly one axis keeps every batch axis intact, including size-1 axes
    # (a blanket squeeze() would drop those and mis-broadcast the result).
    log_prob = torch.logsumexp(g_log_probs, dim=-1)
    if reduce:
        return - torch.mean(log_prob)
    return - log_prob

In [None]:
def total_loss(latent_obs, action, reward, terminal, latent_next_obs, model):
  """ Combined MDRNN training loss for one batch.

  All tensor arguments arrive batch-first and are transposed to sequence-first
  before being handed to the model.
  :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
  :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
  :args reward: (BSIZE, SEQ_LEN) torch tensor
  :args terminal: (BSIZE, SEQ_LEN) torch tensor
  :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
  :args model: callable as model(action, latent_obs) returning
      (mus, sigmas, logpi, rs, ds)
  :returns: scalar tensor, (gmm + bce + mse) / (LSIZE + 2) where gmm is the
      mixture negative log-likelihood of the next latent, bce the terminal
      logit loss and mse the reward regression loss.
  """
  # Swap the first two axes of every input: (BSIZE, SEQ_LEN, ...) -> (SEQ_LEN, BSIZE, ...)
  latent_obs, action, reward, terminal, latent_next_obs = [
      arr.transpose(1, 0)
      for arr in (latent_obs, action, reward, terminal, latent_next_obs)]
  mus, sigmas, logpi, rs, ds = model(action, latent_obs)
  gmm = gmm_loss(latent_next_obs, mus, sigmas, logpi)
  bce = f.binary_cross_entropy_with_logits(ds, terminal)
  mse = f.mse_loss(rs, reward)
  # The GMM term grows with the latent size, so normalise by that size plus
  # one for each of the two scalar losses. Reading the size off the tensor
  # keeps the scale right if the latent dimension ever changes (34 for 32-d).
  scale = latent_next_obs.size(-1) + 2
  return (gmm + bce + mse) / scale

In [None]:
class _MDRNNBase(nn.Module):
    def __init__(self, latents, actions, hiddens, gaussians):
        super(_MDRNNBase, self).__init__()
        self.latents = latents
        self.actions = actions
        self.hiddens = hiddens
        self.gaussians = gaussians

        self.gmm_linear = nn.Linear(
            hiddens, (2 * latents + 1) * gaussians + 2)

    def forward(self, *inputs):
        pass

In [None]:
class MDRNN(_MDRNNBase):
    """ MDRNN model for multi steps forward """
    def __init__(self, latents, actions, hiddens, gaussians):
        super().__init__(latents, actions, hiddens, gaussians)
        self.rnn = nn.LSTM(latents + actions, hiddens)

    def forward(self, actions, latents): # pylint: disable=arguments-differ
        """ MULTI STEPS forward.
        :args actions: (SEQ_LEN, BSIZE, ASIZE) torch tensor
        :args latents: (SEQ_LEN, BSIZE, LSIZE) torch tensor
        :returns: mus, sigmas, logpi, rs, ds -- the GMM parameters predicted
        for the next latent, the reward prediction and the terminal logit:
            - mus: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
            - sigmas: (SEQ_LEN, BSIZE, N_GAUSS, LSIZE) torch tensor
            - logpi: (SEQ_LEN, BSIZE, N_GAUSS) torch tensor
            - rs: (SEQ_LEN, BSIZE) torch tensor
            - ds: (SEQ_LEN, BSIZE) torch tensor
        """
        seq_len, bs = actions.size(0), actions.size(1)

        # Run the LSTM over the concatenated (action, latent) features and
        # project every step through the shared output head.
        rnn_out, _ = self.rnn(torch.cat([actions, latents], dim=-1))
        head = self.gmm_linear(rnn_out)

        # Layout of the head vector: [means | log-sigmas | mixture logits | r | d]
        n_gauss, n_lat = self.gaussians, self.latents
        mu_end = n_gauss * n_lat
        sigma_end = 2 * mu_end
        pi_end = sigma_end + n_gauss

        mus = head[:, :, :mu_end].view(seq_len, bs, n_gauss, n_lat)

        # exp() keeps the standard deviations strictly positive.
        log_sigmas = head[:, :, mu_end:sigma_end].view(seq_len, bs, n_gauss, n_lat)
        sigmas = torch.exp(log_sigmas)

        pi_logits = head[:, :, sigma_end:pi_end].view(seq_len, bs, n_gauss)
        logpi = f.log_softmax(pi_logits, dim=-1)

        # The last two entries are the reward and terminal-logit predictions.
        rs = head[:, :, -2]
        ds = head[:, :, -1]

        return mus, sigmas, logpi, rs, ds

In [None]:
class MDRNNCell(_MDRNNBase):
    """ MDRNN model for one step forward """
    def __init__(self, latents, actions, hiddens, gaussians):
        super().__init__(latents, actions, hiddens, gaussians)
        self.rnn = nn.LSTMCell(latents + actions, hiddens)

    def forward(self, action, latent, hidden): # pylint: disable=arguments-differ
        """ ONE STEP forward.
        :args action: (BSIZE, ASIZE) torch tensor
        :args latent: (BSIZE, LSIZE) torch tensor
        :args hidden: recurrent state handed unchanged to the LSTM cell
            NOTE(review): previously described as a (BSIZE, RSIZE) tensor, but
            nn.LSTMCell takes an (h, c) pair -- confirm against the caller.
        :returns: mus, sigmas, logpi, r, d, next_hidden -- the GMM parameters
        predicted for the next latent, the reward prediction, the terminal
        logit and the state returned by the LSTM cell:
            - mus: (BSIZE, N_GAUSS, LSIZE) torch tensor
            - sigmas: (BSIZE, N_GAUSS, LSIZE) torch tensor
            - logpi: (BSIZE, N_GAUSS) torch tensor
            - r: (BSIZE) torch tensor
            - d: (BSIZE) torch tensor
        """
        cell_input = torch.cat([action, latent], dim=1)

        next_hidden = self.rnn(cell_input, hidden)
        # The first element of the returned state feeds the output head.
        head = self.gmm_linear(next_hidden[0])

        # Layout of the head vector: [means | log-sigmas | mixture logits | r | d]
        n_gauss, n_lat = self.gaussians, self.latents
        mu_end = n_gauss * n_lat
        sigma_end = 2 * mu_end
        pi_end = sigma_end + n_gauss

        mus = head[:, :mu_end].view(-1, n_gauss, n_lat)

        # exp() keeps the standard deviations strictly positive.
        sigmas = torch.exp(head[:, mu_end:sigma_end].view(-1, n_gauss, n_lat))

        pi_logits = head[:, sigma_end:pi_end].view(-1, n_gauss)
        logpi = f.log_softmax(pi_logits, dim=-1)

        # The last two entries are the reward and terminal-logit predictions.
        r = head[:, -2]
        d = head[:, -1]

        return mus, sigmas, logpi, r, d, next_hidden

In [None]:
class VAE(nn.Module):
    """ Convolutional variational auto-encoder with a 32-d latent space.

    The encoder output is flattened to 1024 features, which matches
    3 x 64 x 64 inputs given the four stride-2 convolutions below.
    :args device: device the sampling noise is moved to
    :args batch_size: stored on the instance; not used by the visible methods
    """
    def __init__(self, device, batch_size=250):
        super().__init__()

        self.device = device

        # Encoder: four 4x4 stride-2 convolutions, each followed by a ReLU.
        enc_channels = (3, 32, 64, 128, 256)
        enc_layers = []
        for c_in, c_out in zip(enc_channels, enc_channels[1:]):
            enc_layers.append(nn.Conv2d(c_in, c_out, 4, stride=2))
            enc_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*enc_layers)

        # Heads producing the mean and log-variance of q(z | x).
        self.mufc = nn.Linear(1024, 32)
        self.logvarfc = nn.Linear(1024, 32)

        self.decoder_fc = nn.Linear(32, 1024)

        # Decoder: four stride-2 transposed convolutions; ReLU after the first
        # three, Sigmoid after the last so outputs lie in [0, 1].
        dec_specs = ((1024, 128, 5), (128, 64, 5), (64, 32, 6), (32, 3, 6))
        dec_layers = []
        for c_in, c_out, kernel in dec_specs:
            dec_layers.append(nn.ConvTranspose2d(c_in, c_out, kernel, stride=2))
            dec_layers.append(nn.ReLU())
        dec_layers[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*dec_layers)

        self.batch_size = batch_size

    def reparameterize(self, mu, logvar):
        """ Draw z = mu + std * eps with eps ~ N(0, I) and std = exp(logvar / 2). """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std).to(self.device)
        return mu + std * eps

    def forward(self, x):
        """ Encode x, sample a latent and decode it.
        :returns: (reconstruction, mu, logvar)
        """
        features = self.encoder(x).reshape(-1, 1024)
        mu = self.mufc(features)
        logvar = self.logvarfc(features)
        z = self.reparameterize(mu, logvar)
        decoder_in = self.decoder_fc(z).reshape(-1, 1024, 1, 1)
        return self.decoder(decoder_in.float()), mu, logvar

    def get_z(self, x):
        """ Sample a latent for x without tracking gradients. """
        with torch.no_grad():
            features = self.encoder(x).reshape(-1, 1024)
            mu = self.mufc(features)
            logvar = self.logvarfc(features)
            return self.reparameterize(mu, logvar)

    def loss_func(self, x, x_prime, mu, logvar):
        """ Summed binary cross-entropy reconstruction loss plus the KL term
        -0.5 * sum(1 + logvar - mu^2 - exp(logvar)).
        """
        reconstruction = nn.BCELoss(reduction='sum')(x_prime, x)
        kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - torch.exp(logvar))
        return reconstruction + kl_divergence

In [None]:
# Build the VAE on the selected device and restore its trained weights.
vae = VAE(device)
vae.to(device, dtype=torch.float)
vae_path = BASE_PATH + "vae_original.pt"
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written from a GPU session still loads when `device` is the CPU.
vae.load_state_dict(torch.load(vae_path, map_location=device))

<All keys matched successfully>

In [None]:
class RolloutVaeDataset(torch.utils.data.Dataset):
    """ In-memory dataset of (frame, action) pairs read from rollout archives.

    Every file found in dir_path is opened with np.load and must expose the
    keys 'obs' and 'action'; entry i of 'obs' is paired with entry i of
    'action'.
    :args dir_path: directory containing the rollout files
    :args transform: optional callable applied to each frame on access
    """
    def __init__(self, dir_path, transform=None):
        super(RolloutVaeDataset, self).__init__()

        self.transform = transform

        self.data = []
        # os.listdir returns names in arbitrary order; sorting makes the item
        # order reproducible from run to run.
        for fname in sorted(os.listdir(dir_path)):
            # np.load on an archive keeps the file open until it is closed;
            # the context manager releases the handle once both arrays are read.
            with np.load(os.path.join(dir_path, fname)) as rollout:
                imgs = rollout['obs'] # original note: 1000 x 64 x 64 x 3
                actions = rollout['action']
            # Index into actions explicitly so a short 'action' array still
            # fails loudly instead of silently dropping frames.
            for i, img in enumerate(imgs):
                self.data.append((img, actions[i]))

    def __len__(self):
        """ Total number of (frame, action) pairs across all files. """
        return len(self.data)

    def __getitem__(self, idx):
        """ Return (frame, action) as tensors; the frame goes through
        `transform` instead of torch.tensor when a transform was given.
        """
        img, action = self.data[idx]
        if self.transform:
            return self.transform(img), torch.tensor(action)
        return torch.tensor(img), torch.tensor(action)

In [None]:
""" Some data loading utilities """
from bisect import bisect
from os import listdir
from os.path import join, isdir
from tqdm import tqdm
import torch
import torch.utils.data
import numpy as np

class _RolloutDataset(torch.utils.data.Dataset): # pylint: disable=too-few-public-methods
    def __init__(self, root, transform, buffer_size=200, train=True): # pylint: disable=too-many-arguments
        self._transform = transform

        self._files = [root + "/" + file for file in os.listdir(root)]

        # if train:
        #     self._files = self._files[:-600]
        # else:
        #     self._files = self._files[-600:]

        self._cum_size = None
        self._buffer = None
        self._buffer_fnames = None
        self._buffer_index = 0
        self._buffer_size = buffer_size

    def load_next_buffer(self):
        """ Loads next buffer """
        self._buffer_fnames = self._files[self._buffer_index:self._buffer_index + self._buffer_size]
        self._buffer_index += self._buffer_size
        self._buffer_index = self._buffer_index % len(self._files)
        self._buffer = []
        self._cum_size = [0]

        # progress bar
        pbar = tqdm(total=len(self._buffer_fnames),
                    bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} {postfix}')
        pbar.set_description("Loading file buffer ...")

        for f in self._buffer_fnames:
            with np.load(f) as data:
                self._buffer += [{k: np.copy(v) for k, v in data.items()}]
                self._cum_size += [self._cum_size[-1] +
                                   self._data_per_sequence(data['reward'].shape[0])]
            pbar.update(1)
        pbar.close()

    def __len__(self):
        # to have a full sequence, you need self.seq_len + 1 elements, as
        # you must produce both an seq_len obs and seq_len next_obs sequences
        if not self._cum_size:
            self.load_next_buffer()
        return self._cum_size[-1]

    def __getitem__(self, i):
        # binary search through cum_size
        file_index = bisect(self._cum_size, i) - 1
        seq_index = i - self._cum_size[file_index]
        data = self._buffer[file_index]
        return self._get_data(data, seq_index)

    def _get_data(self, data, seq_index):
        pass

    def _data_per_sequence(self, data_length):
        pass


class RolloutSequenceDataset(_RolloutDataset): # pylint: disable=too-few-public-methods
    """ Encapsulates rollouts.
    Rollouts should be stored as npz files directly inside the root directory,
    each containing a dictionary with the keys:
        - obs: (rollout_len, *obs_shape)
        - action: (rollout_len, action_size)
        - reward: (rollout_len,)
        - done: (rollout_len,), boolean
     As the dataset is too big to be entirely stored in RAM, only chunks of it
     are stored, consisting of a constant number of files (determined by the
     buffer_size parameter).  Once built, buffers must be loaded with the
     load_next_buffer method.
    Data are then provided in the form of tuples (obs, action, reward, terminal, next_obs):
    - obs: (seq_len, *obs_shape)
    - actions: (seq_len, action_size)
    - reward: (seq_len,)
    - terminal: (seq_len,) float32 copy of the 'done' flags
    - next_obs: (seq_len, *obs_shape)
    NOTE: seq_len < rollout_len in most use cases
    :args root: root directory of data sequences
    :args seq_len: number of timesteps extracted from each rollout
    :args transform: transformation of the observations
    :args train: if True, train data, else test
    """
    def __init__(self, root, seq_len, transform, buffer_size=200, train=True): # pylint: disable=too-many-arguments
        super().__init__(root, transform, buffer_size, train)
        self._seq_len = seq_len

    def _get_data(self, data, seq_index):
        """ Cut one training sample of seq_len steps out of a rollout. """
        # seq_len + 1 consecutive observations give both the obs window
        # (all but the last) and the next_obs window (all but the first).
        obs_data = data['obs'][seq_index:seq_index + self._seq_len + 1]
        obs_data = self._transform(obs_data.astype(np.float32))
        obs, next_obs = obs_data[:-1], obs_data[1:]
        # action / reward / done are read over the same index range as next_obs.
        action = data['action'][seq_index+1:seq_index + self._seq_len + 1]
        action = action.astype(np.float32)
        reward, terminal = [data[key][seq_index+1:
                                      seq_index + self._seq_len + 1].astype(np.float32)
                            for key in ('reward', 'done')]
        # data is given in the form
        # (obs, action, reward, terminal, next_obs)
        return obs, action, reward, terminal, next_obs

    def _data_per_sequence(self, data_length):
        """ Number of full windows a rollout of data_length steps can provide. """
        # A rollout shorter than seq_len provides no window at all. Without
        # the clamp its negative count would make the cumulative sizes
        # decrease and corrupt the index -> file lookup.
        return max(0, data_length - self._seq_len)

In [None]:
# Root folder of the recorded rollouts; the "train" / "test" sub-folders are
# appended where the data loaders are built.
path = BASE_PATH + "record/"
# ToTensor followed by Normalize with mean 0 and std 1 (which leaves the values
# unchanged).
# NOTE(review): the MDRNN loaders in this notebook are built with a separate
# `transform` lambda rather than this pipeline -- confirm it is still needed.
transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                             torchvision.transforms.Normalize(mean=(0,), std=(1,))])

In [None]:
# MDRNN hyper-parameters.
# Latent vector size; matches the 32-d output of the VAE's mufc / logvarfc layers.
LATENT_SIZE=32
# Size of one action vector fed to the MDRNN.
ACTION_SIZE=3
# Hidden state size of the LSTM.
HIDDEN_SIZE=64
# Number of mixture components predicted for the next latent.
GAUSSIAN_SIZE=5
# Number of timesteps per training sequence.
SEQ_LEN=32

In [None]:
# Model and optimizer for the MDRNN.
mdrnn = MDRNN(LATENT_SIZE, ACTION_SIZE, HIDDEN_SIZE, GAUSSIAN_SIZE).to(device)
optimizer = torch.optim.RMSprop(mdrnn.parameters(), lr=1e-3, alpha=.9)
# Observations are moved from (T, H, W, C) to (T, C, H, W) and divided by 255.
transform = torchvision.transforms.Lambda(lambda x: np.transpose(x, (0, 3, 1, 2)) / 255)
# NOTE(review): each dataset loads at most buffer_size=30 files on its first
# len() call (presumably triggered while the shuffling DataLoader is built --
# see the progress bars printed below). Nothing in this notebook calls
# load_next_buffer() again, so training and evaluation only ever see that
# first chunk of files.
mdrnn_train_loader = torch.utils.data.DataLoader(RolloutSequenceDataset(path + "train", SEQ_LEN, transform, buffer_size=30), batch_size=256, shuffle=True)
mdrnn_test_loader = torch.utils.data.DataLoader(RolloutSequenceDataset(path + "test", SEQ_LEN, transform, buffer_size=30), batch_size=256, shuffle=True)

Loading file buffer ...: 100%|██████████| 30/30 
Loading file buffer ...: 100%|██████████| 20/20 


In [89]:
# Pull a single batch from the training loader and move every tensor to the
# target device, so that its shapes can be inspected in the next cell.
for i, data in enumerate(mdrnn_train_loader):
  obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]
  break

In [90]:
obs.shape, action.shape, reward.shape, terminal.shape, next_obs.shape

(torch.Size([256, 32, 3, 64, 64]),
 torch.Size([256, 32, 3]),
 torch.Size([256, 32]),
 torch.Size([256, 32]),
 torch.Size([256, 32, 3, 64, 64]))

In [None]:
def get_latent_obs(obs):
  """ Encode every frame sequence of a batch with the (global) VAE.

  Iterates over the first axis of obs, samples latents for each element with
  vae.get_z and stacks the results along a new leading axis.
  :args obs: (BSIZE, SEQ_LEN, *obs_shape) torch tensor
  :returns: (BSIZE, SEQ_LEN, LSIZE) torch tensor; an empty tensor on the
      global device when obs has no elements
  """
  # Collect first and concatenate once: growing the result with torch.cat
  # inside the loop re-copies everything gathered so far on every iteration.
  latents = [torch.unsqueeze(vae.get_z(frames), 0) for frames in obs]
  if not latents:
    return torch.Tensor().to(device)
  return torch.cat(latents)

In [None]:
import tqdm
def train(model, device, optimizer, train_loader, epoch, log_interval):
    """ Run one optimisation pass over train_loader.

    :args model: MDRNN being trained
    :args device: device every batch tensor is moved to
    :args optimizer: optimizer stepping the model parameters
    :args train_loader: iterable of (obs, action, reward, terminal, next_obs) batches
    :args epoch: epoch number, used only in the log line
    :args log_interval: a progress line is printed every log_interval batches
    :returns: mean of the per-batch losses
    """
    model.train()
    batch_losses = []
    for batch_idx, batch in enumerate(train_loader):
        obs, action, reward, terminal, next_obs = [tensor.to(device) for tensor in batch]
        # Encode current and next frames into latents before scoring them.
        latent_obs = get_latent_obs(obs)
        next_latent_obs = get_latent_obs(next_obs)
        loss = total_loss(latent_obs, action, reward, terminal, next_latent_obs, model)
        batch_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval != 0:
            continue
        seen = batch_idx * len(obs)
        percent = 100. * batch_idx / len(train_loader)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, seen, len(train_loader.dataset), percent, loss.item()))
    return np.mean(batch_losses)

def test(model, device, optimizer, test_loader, epoch, log_interval):
  """ Evaluate the model over test_loader without updating it.

  :args model: MDRNN being evaluated
  :args device: device every batch tensor is moved to
  :args optimizer: unused; kept so the signature mirrors train()
  :args test_loader: iterable of (obs, action, reward, terminal, next_obs) batches
  :args epoch: epoch number, used only in the log line
  :args log_interval: a progress line is printed every log_interval batches
  :returns: mean of the per-batch losses
  """
  model.eval()
  losses = []
  # Evaluation never calls backward(), so disable autograd: no graph is built
  # and no activations are kept alive for the model's forward pass.
  with torch.no_grad():
    for batch_idx, data in enumerate(test_loader):
      obs, action, reward, terminal, next_obs = [arr.to(device) for arr in data]
      latent_obs, next_latent_obs = get_latent_obs(obs), get_latent_obs(next_obs)
      loss = total_loss(latent_obs, action, reward, terminal, next_latent_obs, model)
      losses.append(loss.item())
      if batch_idx % log_interval == 0:
        # Labelled 'Test' so evaluation lines can be told apart from the
        # training progress lines.
        print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(obs), len(test_loader.dataset),
            100. * batch_idx / len(test_loader), loss.item()))
  return np.mean(losses)

In [None]:
# Train for a fixed number of epochs; after each training pass the model is
# evaluated on the test loader and both mean losses are printed.
epochs = 20
for epoch in range(epochs):
  # The last argument is the logging interval (in batches).
  train_loss = train(mdrnn, device, optimizer, mdrnn_train_loader, epoch, 20)
  test_loss = test(mdrnn, device, optimizer, mdrnn_test_loader, epoch, 20)
  print(f"TRAIN LOSS: {train_loss}")
  print(f"TEST LOSS: {test_loss}")

TRAIN LOSS: 1.2080542418691846
TEST LOSS: 1.1278328318749704
TRAIN LOSS: 1.0585249556435479
TEST LOSS: 1.0352778338616895
TRAIN LOSS: 1.0031214469008976
TEST LOSS: 0.9956617672597209
TRAIN LOSS: 0.9790780961513519
TEST LOSS: 0.975283436236843
TRAIN LOSS: 0.9662396550178528
TEST LOSS: 0.9674212653790751
TRAIN LOSS: 0.9583552281061808
TEST LOSS: 0.9616825090300652
TRAIN LOSS: 0.9524621751573351
TEST LOSS: 0.9585492293680867
TRAIN LOSS: 0.948262666993671
TEST LOSS: 0.9555230871323617
TRAIN LOSS: 0.9444746236006419
TEST LOSS: 0.9528036723213811
TRAIN LOSS: 0.9416323277685378
TEST LOSS: 0.9499958590153725
TRAIN LOSS: 0.9390058716138204
TEST LOSS: 0.949752627841888
TRAIN LOSS: 0.9367363552252451
TEST LOSS: 0.9518741148133432
TRAIN LOSS: 0.9348812586731381
TEST LOSS: 0.9485522374030082
TRAIN LOSS: 0.9335455709033542
TEST LOSS: 0.9467083782918991
TRAIN LOSS: 0.9316480861769783
TEST LOSS: 0.9465707636648609
TRAIN LOSS: 0.9303020656108856
TEST LOSS: 0.9474343599811677
TRAIN LOSS: 0.9293936358557

In [None]:
# Persist the MDRNN weights (state_dict only) next to the other project files.
torch.save(mdrnn.state_dict(), BASE_PATH + "mdrnn.pt")