### TODO
1. ~~Code up VAE~~
2. ~~Code up M (RNN)~~
3. Code up Train Loop for M
3. Code up C (Controller)
4. Set up Car Racing
5. Set up Doom

In [1]:
import torch

import numpy as np

from torch import nn

from torch.autograd import Variable

from torch.nn import functional as F
from torch.nn import init
import time
from torch import optim
from copy import deepcopy
from tqdm import tqdm
from matplotlib import pyplot as plt
from tensorboardX import SummaryWriter
import os

import sys

if "ipykernel_launcher" in sys.argv[0]:
    sys.argv = [""]

In [2]:
import gym

In [3]:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weights_file",type=str,default="None")
parser.add_argument("--lr",type=float, default=0.0001)
parser.add_argument("--ctl_type",type=str, default="lstm")
parser.add_argument("--opt",type=str, default="adam")
parser.add_argument("--iters",type=int, default=100000)
args = parser.parse_args()

In [4]:
import matplotlib.pyplot as plt
%matplotlib inline

# Imports specifically so we can render outputs in Jupyter.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames):
    """
    Displays a list of frames as a gif, with controls
    """
    #plt.figure(figsize=(frames[0].shape[1] / 72.0, frames[0].shape[0] / 72.0), dpi = 72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames = len(frames), interval=50)
    display(display_animation(anim, default_mode='loop'))

In [5]:
def data_gen(bs=128,nb=1):
    for _ in range(nb):
        x = Variable(torch.Tensor(bs,3,64,64).uniform_(-1,1)).cuda()
        yield x

In [6]:
class VAE(nn.Module):
    def __init__(self,env="CarRacing"):
        super(VAE,self).__init__()
        if env == "CarRacing":
            nz = 32
        elif env == "Doom":
            nz = 64
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4, stride=2),
            nn.ReLU())
        self.sigma_fc = nn.Linear(in_features=256*2*2,out_features=nz)
        self.mu_fc = nn.Linear(in_features=256*2*2,out_features=nz)
        
        self.decode_fc = nn.Linear(in_features=32,out_features=1024)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024,out_channels=128,kernel_size=5,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=5,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64,out_channels=32,kernel_size=6,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32,out_channels=3,kernel_size=6,stride=2),
            nn.Sigmoid())
        
        self._initialize_weights()

    def forward(self,x):
        vec = self.encoder(x)
        
        #flatten
        vec = vec.view(vec.size(0),-1)
        mu, sigma = self.mu_fc(vec), self.sigma_fc(vec)
        z = self.reparameterize(mu,sigma)
        im = self.decode_fc(z)
        
        #reshape into im
        im = im[:,:,None,None]
        
        xh = self.decoder(im)
        
        return xh,mu,sigma
        
    
    def reparameterize(self,mu,sigma):
        eps = Variable(torch.Tensor(*sigma.size()).normal_()).cuda()
        z = mu + eps*sigma
        return z
    
    def _initialize_weights(self):
        # Official init from torch repo.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()
        

In [7]:
def vae_loss(x,xh,mu,sigma):
    mu_sum_sq = (mu*mu).sum(dim=1)
    sig_sum_sq = (sigma*sigma).sum(dim=1)
    log_term = (1 + torch.log(sigma**2)).sum(dim=1)
    kldiv = 0.5 * (log_term - mu_sum_sq - sig_sum_sq)
    
    rec = F.mse_loss(xh,x)
    
    return rec + kldiv.mean()
    
    

In [8]:
def get_optim(name,  model, lr, momentum):
    if name == "adam":
        return optim.Adam(params=model.parameters(),
                        lr=lr)
    elif name == "sgd":
        return optim.SGD(params=model.parameters(),
                        lr=lr,
                        momentum=momentum)
    elif name == "rmsprop":
          return optim.RMSprop(params=model.parameters(),
                        lr=lr,
                        momentum=momentum)

def print_info(mode,loss,t0,it):
    print("time: %8.4f"% (time.time() - t0))
    print("%s Loss for it %i: %8.4f"%(mode.capitalize(),it,loss))
    #print("%s Accuracy for epoch %i: %8.4f"%(mode.capitalize(),epoch,acc))

In [9]:
def do_epoch(mode, model,opt,epoch):
    criterion = vae_loss
    dataloader = data_gen()
    if mode == "train":
        model.train()
    else:
        model.eval()
    it_losses = []
    for it, xv in enumerate(dataloader):

        if mode == "train":
            opt.zero_grad()
        xh,mu,sigma = model(xv)

        loss = criterion(xv,xh,mu,sigma)

        it_losses.append(loss.data[0])
        if mode == "train":
            loss.backward()
            opt.step()


    loss = np.mean(it_losses)
    #writer.add_scalar("%s/loss"%mode,scalar_value=loss,global_step=epoch)
    return loss

In [10]:
def train(
            model,
            lr=0.1,
            momentum=0.9,
            test=False,
            val=True,
            verbose=True,
            num_epochs=5, optimizer="adam"):
    
        if test:
            modes = ["test"]
            num_epochs = 1
            if args.weights_file == "None":
                assert False, "I don't think you meant to run on the test set without a weights file!"
        else:
            modes = ["train"]
            if val:
                modes.append("val")
        
        if args.weights_file != "None":
            model.load_state_dict(torch.load(args.weights_file))


        model = model.cuda()
        opt = get_optim(optimizer,model,lr,momentum)
        for epoch in range(num_epochs):
            for mode in modes:
                t0 = time.time()
                loss = do_epoch(mode,model,opt,epoch)
                if verbose:
                    print_info(mode,loss,t0,epoch)
#                 if mode == "train":
#                     torch.save(model.state_dict(), saved_model_dir +'/epoch_%i.pt' % epoch)

In [4]:
env = gym.make('CarRacing-v0')

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [None]:
# x = next(data_gen().__iter__())

# V = VAE().cuda()

# xh,mu,sigma = V(x)

# loss = vae_loss(x,xh,mu,sigma)

In [12]:
if __name__ == "__main__":
    V = VAE()

    train(V,lr=0.0001)

time:   4.1645
Train Loss for it 0: -41.1274
time:   0.1297
Val Loss for it 0: -42.6591
time:   0.1040
Train Loss for it 1: -42.8650
time:   0.1204
Val Loss for it 1: -44.2696
time:   0.0991
Train Loss for it 2: -43.7002
time:   0.1206
Val Loss for it 2: -44.4730
time:   0.1000
Train Loss for it 3: -44.8610
time:   0.1177
Val Loss for it 3: -44.9915
time:   0.1001
Train Loss for it 4: -45.5676
time:   0.1203
Val Loss for it 4: -46.3006


In [210]:
class M(nn.Module):
    def __init__(self,env="CarRacing"):
        super(M,self).__init__()
        if env == "CarRacing":
            self.nz = 32
            self.nh = 256
            self.action_len = 3 #3 continuous values
        elif env == "Doom":
            pass
#             self.nz = 64
#             self.nh = 512

        self.num_gaussians = 5
        
        self.sigma_len = self.nz
        self.mu_len = self.nz
        self.pi_len = 1
        self.len_mdp_output = 5*(self.sigma_len + self.mu_len + self.pi_len)
        

        self.rnn = nn.LSTM(input_size=self.nz+self.action_len,
                           hidden_size=self.nh,num_layers=1)
        
        self.mdn_fc = nn.Linear(in_features=self.nh,
                                out_features=self.len_mdp_output)
        
        self.h_prev = Variable(torch.Tensor(1,batch_size,m.nh).normal_()).cuda()
        self.c_prev = Variable(torch.Tensor(1,batch_size,m.nh).normal_()).cuda()
    
    def reset(self):
        self.h_prev = Variable(torch.Tensor(1,batch_size,m.nh).normal_()).cuda()
        self.c_prev = Variable(torch.Tensor(1,batch_size,m.nh).normal_()).cuda()
    
    def postproc_mdp_out(self,mdp_out):
        mus = mdp_out[:,:self.num_gaussians*self.nz]

        sigmas= mdp_out[:,self.num_gaussians*self.nz:2*self.num_gaussians*self.nz]
        
        pis = mdp_out[:,-self.num_gaussians:]
        
        
        mus = mus.resize(mus.size(0),self.num_gaussians,self.nz)
        
        sigmas = torch.exp(sigmas)
        sigmas = sigmas.resize(sigmas.size(0),self.num_gaussians,self.nz)
        
        pis = F.softmax(pis,dim=1)
        return mus, sigmas, pis
                        
    
    def forward(self, az):
        

        lstm_out, (self.h_prev,self.c_prev) = self.rnn(az[None,:],(self.h_prev,self.c_prev))
        
        raw_mdp_out = self.mdn_fc(lstm_out[0])
        
        mus, sigmas, pis = self.postproc_mdp_out(raw_mdp_out)
        return mus, sigmas, pis


In [217]:
batch_size = 128

m = M().cuda()

seq_len = 20

In [218]:
z = Variable(torch.Tensor(seq_len,batch_size,m.nz).normal_()).cuda()

a = Variable(torch.Tensor(seq_len,batch_size,3).normal_()).cuda()

az = torch.cat((z,a),dim=2)

In [228]:
az

Variable containing:
( 0 ,.,.) = 
  8.3264e-01 -1.5553e+00  6.7560e-01  ...  -6.4484e-01 -8.5330e-02 -1.5184e+00
  9.6795e-01 -6.8876e-01  3.3358e-01  ...   3.8021e-02 -4.4560e-02 -1.6403e-01
  3.5661e-01 -1.4102e-01 -5.0923e-01  ...  -1.0018e+00  1.6424e-02 -2.3967e-01
                 ...                   ⋱                   ...                
  7.9887e-01 -2.3134e+00  1.3460e+00  ...   6.8913e-02 -8.7446e-01 -1.6146e+00
  1.1519e+00 -3.2552e-01  2.7568e-01  ...   2.3266e+00  9.0345e-01 -2.2876e-01
 -5.4487e-01 -4.3077e-01  3.7654e-01  ...   7.7189e-02  4.1499e-01  1.6228e+00

( 1 ,.,.) = 
  2.1760e+00 -1.9260e+00  8.6509e-01  ...  -6.1966e-01  5.1640e-01  8.0364e-01
 -1.5310e+00  5.9219e-01 -4.4613e-01  ...  -9.7488e-01  2.6564e+00 -1.1863e+00
 -3.9626e-01 -1.8857e+00  4.7410e-01  ...  -5.9115e-01  1.2153e+00 -1.3075e+00
                 ...                   ⋱                   ...                
  1.1667e+00 -2.1198e-02  7.2850e-01  ...  -1.5191e-01 -6.4944e-01  6.1119e-01
 -1.

In [222]:
outs = []
for azi in az:
    msp = m(azi)
    outs.append(msp)


mus,sigmas,pi = msp

def calc_normal_pdf(z,mus,sigmas):
    num_mixtures = mus.size(1)
    pdfs = []
    for i in range(num_mixtures):
        zmmu = z[0] - mus[:,0]

        zmmu = zmmu[:,:,None]


        sigma = sigmas[:,0]


        exp_term = -0.5 * zmmu.transpose(2,1) @ (zmmu / sigma[:,:,None])

        torch.exp(exp_term[:,0,0])

        torch.prod(sigma,dim=1).size()

        coeff = (1 / (2*np.pi)) * torch.prod(sigma,dim=1)

        pdf = coeff * torch.exp(exp_term[:,0,0])
        pdfs.append(pdf[:,None])
    return torch.cat(pdfs,dim=1)

normals = calc_normal_pdf(z,mus,sigmas)

pdf = (pis * normals).sum(dim=1)

nll = -torch.log(pdf)

nll