In [None]:
import torch

import numpy as np

from torch import nn

from torch.autograd import Variable

from torch.nn import functional as F
from torch.nn import init
import time
from torch import optim
from copy import deepcopy
from tqdm import tqdm
from matplotlib import pyplot as plt
from tensorboardX import SummaryWriter
import os

import sys

# Inside a Jupyter kernel, sys.argv holds the kernel launcher's own arguments,
# which parse_args() in the next cell would reject as unrecognised; replace
# them with an empty program name so the CLI defaults are used instead.
if sys.argv[0].find("ipykernel_launcher") != -1:
    sys.argv = [""]

In [None]:
import gym

import argparse

parser = argparse.ArgumentParser()
# Path to a saved state_dict; the *string* "None" (not Python None) means
# "start from freshly initialised weights".
parser.add_argument("--weights_file", default="None", type=str)
# Optimiser settings.
parser.add_argument("--lr", default=0.0001, type=float)
parser.add_argument("--opt", default="adam", type=str)
# Controller type and iteration budget.
parser.add_argument("--ctl_type", default="lstm", type=str)
parser.add_argument("--iters", default=100000, type=int)
args = parser.parse_args()

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Imports specifically so we can render outputs in Jupyter.
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames):
    """
    Render a sequence of image frames as a looping, controllable animation
    in the notebook output.

    Args:
        frames: sequence of images accepted by ``plt.imshow``. The first one
            seeds the plot; animation step ``i`` swaps in ``frames[i]``.
    """
    image_artist = plt.imshow(frames[0])
    plt.axis('off')

    def _show_frame(index):
        # Replace the pixel data of the existing image instead of redrawing.
        image_artist.set_data(frames[index])

    frame_count = len(frames)
    anim = animation.FuncAnimation(plt.gcf(), _show_frame, frames=frame_count, interval=50)
    display(display_animation(anim, default_mode='loop'))

In [None]:
def data_gen(bs=128, nb=1, device=None, low=0.0, high=1.0):
    """Yield ``nb`` batches of random images shaped (bs, 3, 64, 64).

    Placeholder data source. Values are drawn uniformly from [low, high).
    The default [0, 1] matches the Sigmoid output range of the VAE decoder,
    so every reconstruction target is actually reachable. The previous
    [-1, 1] range made half of the targets unreachable.

    Args:
        bs: batch size.
        nb: number of batches to yield.
        device: device for the tensors; defaults to CUDA when available,
            otherwise CPU.
        low: lower bound of the uniform distribution.
        high: upper bound of the uniform distribution.

    Yields:
        float32 tensors of shape (bs, 3, 64, 64).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    for _ in range(nb):
        yield torch.empty(bs, 3, 64, 64, device=device).uniform_(low, high)

In [None]:
def get_optim(name,  model, lr, momentum):
    """Build an optimizer over ``model.parameters()``.

    Args:
        name: one of "adam", "sgd" or "rmsprop".
        model: module whose parameters will be optimized.
        lr: learning rate.
        momentum: momentum for SGD / RMSprop (not passed to Adam).

    Returns:
        A ``torch.optim`` optimizer instance.

    Raises:
        ValueError: if ``name`` is not a supported optimizer. Raising here,
            rather than returning None, fails fast instead of crashing later
            at ``opt.zero_grad()``.
    """
    if name == "adam":
        return optim.Adam(params=model.parameters(), lr=lr)
    if name == "sgd":
        return optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
    if name == "rmsprop":
        return optim.RMSprop(params=model.parameters(), lr=lr, momentum=momentum)
    raise ValueError("Unknown optimizer %r; expected 'adam', 'sgd' or 'rmsprop'" % (name,))

def print_info(mode,loss,t0,it):
    """Print the wall-clock seconds elapsed since ``t0``, then the loss for step ``it``.

    ``mode`` is shown capitalised (e.g. "train" -> "Train").
    """
    elapsed = time.time() - t0
    print("time: %8.4f" % elapsed)
    label = mode.capitalize()
    print("%s Loss for it %i: %8.4f" % (label, it, loss))

In [None]:
class VAE(nn.Module):
    """Convolutional VAE for 3x64x64 images.

    The encoder's four stride-2, kernel-4 convolutions reduce a 64x64 input
    to a 256x2x2 feature map. Two linear heads map that map to ``mu`` and
    ``sigma`` of size ``nz``. A sample ``z`` is projected to 1024 values and
    decoded back to a 3x64x64 image in [0, 1] (final Sigmoid).

    Args:
        env: "CarRacing" (nz = 32) or "Doom" (nz = 64).

    Raises:
        ValueError: if ``env`` is not a supported name.

    NOTE(review): ``sigma`` is the raw output of a Linear layer, so it can be
    negative or zero. The reparameterisation is still valid for negative
    values, but ``log(sigma**2)`` in the loss diverges at exactly zero.
    Predicting log-variance would be more robust.
    """

    # Latent dimensionality per supported environment.
    _LATENT_SIZES = {"CarRacing": 32, "Doom": 64}

    def __init__(self,env="CarRacing"):
        super(VAE,self).__init__()
        if env not in self._LATENT_SIZES:
            raise ValueError("Unknown env %r; expected one of %s"
                             % (env, sorted(self._LATENT_SIZES)))
        nz = self._LATENT_SIZES[env]
        self.nz = nz
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3,out_channels=32,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32,out_channels=64,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64,out_channels=128,kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=128,out_channels=256,kernel_size=4, stride=2),
            nn.ReLU())
        self.sigma_fc = nn.Linear(in_features=256*2*2,out_features=nz)
        self.mu_fc = nn.Linear(in_features=256*2*2,out_features=nz)

        # Must take the latent size as input. It was hard-coded to 32, which
        # broke env="Doom" (nz = 64) with a shape mismatch.
        self.decode_fc = nn.Linear(in_features=nz,out_features=1024)
        # Spatial sizes: 1 -> 5 -> 13 -> 30 -> 64.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=1024,out_channels=128,kernel_size=5,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128,out_channels=64,kernel_size=5,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64,out_channels=32,kernel_size=6,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32,out_channels=3,kernel_size=6,stride=2),
            nn.Sigmoid())

        self._initialize_weights()

    def forward(self,x):
        """Encode ``x``, sample a latent and decode it.

        Args:
            x: image batch; the layer sizes require shape (batch, 3, 64, 64).

        Returns:
            (xh, mu, sigma): the reconstruction (same shape as ``x``) and the
            latent mean and scale, each of shape (batch, nz).
        """
        vec = self.encoder(x)

        # Flatten (batch, 256, 2, 2) -> (batch, 1024).
        vec = vec.view(vec.size(0),-1)
        mu, sigma = self.mu_fc(vec), self.sigma_fc(vec)
        z = self.reparameterize(mu,sigma)

        # Treat the 1024-d projection as a 1024-channel 1x1 image for the
        # transposed convolutions.
        im = self.decode_fc(z)[:,:,None,None]

        xh = self.decoder(im)
        return xh,mu,sigma

    def reparameterize(self,mu,sigma):
        """Reparameterisation trick: ``z = mu + eps * sigma`` with ``eps ~ N(0, I)``.

        ``eps`` is created with the same shape, dtype and device as ``sigma``.
        The model therefore runs on CPU as well as GPU; previously ``.cuda()``
        was hard-coded here.
        """
        eps = torch.randn_like(sigma)
        return mu + eps*sigma

    def _initialize_weights(self):
        """Kaiming-normal init for Conv2d weights; zero every Linear bias.

        NOTE(review): the decoder's ConvTranspose2d layers are not instances
        of nn.Conv2d, so they keep PyTorch's default initialisation. The
        BatchNorm2d branch is currently unused because the model has no
        BatchNorm layers.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # In-place variant; the un-suffixed kaiming_normal is deprecated.
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.zeros_(m.bias)
        

def vae_loss(x,xh,mu,sigma):
    """VAE objective: reconstruction MSE plus KL(N(mu, sigma^2) || N(0, I)).

    Args:
        x: target images.
        xh: reconstructions, same shape as ``x``.
        mu: latent means, shape (batch, nz).
        sigma: latent scales (standard deviations), shape (batch, nz).

    Returns:
        Scalar tensor: the MSE averaged over all elements, plus the per-sample
        KL divergence (summed over latent dims) averaged over the batch.
    """
    # Closed-form KL against a standard normal prior, per sample:
    #   KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)  >= 0
    # The leading minus is essential. The previous version omitted it,
    # computing -KL, so minimising the loss pushed the posterior *away*
    # from the prior.
    var = sigma * sigma
    kldiv = -0.5 * (1 + torch.log(var) - mu * mu - var).sum(dim=1)

    rec = F.mse_loss(xh, x)

    return rec + kldiv.mean()
    
    

In [None]:
def do_epoch(mode, model,opt,epoch):
    """Run one pass over ``data_gen()`` in ``mode`` and return the mean loss.

    Args:
        mode: "train" updates the weights with ``opt``. Any other value
            (e.g. "val", "test") puts the model in eval mode and only
            evaluates.
        model: module whose forward returns ``(xh, mu, sigma)``.
        opt: optimizer, stepped only in train mode.
        epoch: epoch index; currently unused (kept for logging hooks).

    Returns:
        Mean of the per-batch losses (numpy float).
    """
    criterion = vae_loss
    is_train = mode == "train"
    if is_train:
        model.train()
    else:
        model.eval()

    it_losses = []
    # Build autograd graphs only when we will backpropagate through them.
    with torch.set_grad_enabled(is_train):
        for xv in data_gen():
            if is_train:
                opt.zero_grad()
            xh, mu, sigma = model(xv)

            loss = criterion(xv, xh, mu, sigma)

            # .item() extracts the Python scalar. The old `loss.data[0]`
            # raises IndexError on 0-dim tensors in modern PyTorch.
            it_losses.append(loss.item())
            if is_train:
                loss.backward()
                opt.step()

    return np.mean(it_losses)

def train(
        model,
        lr=0.1,
        momentum=0.9,
        test=False,
        val=True,
        verbose=True,
        num_epochs=5, optimizer="adam"):
    """Train (or evaluate) ``model`` via ``do_epoch`` and print per-epoch losses.

    Reads the module-level ``args`` for ``args.weights_file``. The string
    "None" means "no checkpoint".

    Args:
        model: module to train; it is moved to CUDA.
        lr: learning rate passed to ``get_optim``.
        momentum: momentum passed to ``get_optim``.
        test: if True, run a single "test" pass instead of training.
        val: when not testing, also run a "val" pass after each train pass.
        verbose: print time and loss after every pass.
        num_epochs: number of epochs (forced to 1 when ``test`` is True).
        optimizer: optimizer name understood by ``get_optim``.

    Raises:
        ValueError: if ``test`` is requested without a weights file. This is
            an explicit raise rather than ``assert False``, which ``python -O``
            strips, so the test set would be silently evaluated on random
            weights.
    """
    if test:
        if args.weights_file == "None":
            raise ValueError("I don't think you meant to run on the test set without a weights file!")
        modes = ["test"]
        num_epochs = 1
    else:
        modes = ["train", "val"] if val else ["train"]

    if args.weights_file != "None":
        model.load_state_dict(torch.load(args.weights_file))

    model = model.cuda()
    opt = get_optim(optimizer, model, lr, momentum)
    for epoch in range(num_epochs):
        for mode in modes:
            t0 = time.time()
            loss = do_epoch(mode, model, opt, epoch)
            if verbose:
                print_info(mode, loss, t0, epoch)
            # TODO: save a checkpoint per train epoch once a checkpoint
            # directory is configured.

In [None]:
if __name__ == "__main__":
    # NOTE(review): `env` is created but never used below. It is kept so the
    # script's side effects (and its gym dependency) stay unchanged.
    env = gym.make('CarRacing-v0')
    V = VAE()

    # Honour the --lr / --opt CLI flags; their defaults (0.0001, "adam") match
    # the values previously hard-coded here. --iters and --ctl_type are still
    # not consumed by this cell.
    train(V, lr=args.lr, optimizer=args.opt)

In [None]:
# x = next(data_gen().__iter__())

# V = VAE().cuda()

# xh,mu,sigma = V(x)

# loss = vae_loss(x,xh,mu,sigma)