In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange
from IPython import display

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

In [2]:
# Implementing the core modules

class ConvStack(nn.Module):
    """
    Implements the conv_stack module as described in figure 6.

    Three convolutions with a skip connection around the second one::

        h1  = relu(conv1(X))
        h2  = relu(conv2(h1) + h1)
        out = conv3(h2)

    Every convolution is padded with ``(k - 1) // 2``, so the spatial size of
    the input is preserved for odd kernel sizes.

    Args:
        k1, c1: kernel size / output channels of the first convolution.
        k2, c2: kernel size / output channels of the second convolution.
            The skip connection adds the outputs of the first and second
            convolution, so ``c1`` and ``c2`` have to be compatible
            (normally equal).
        k3, c3: kernel size / output channels of the last convolution.
        in_channels: number of channels of the input tensor. Defaults to 3,
            the value that used to be hard-coded.
    """
    def __init__(self, k1, c1, k2, c2, k3, c3, in_channels=3):
        super(ConvStack, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=k1, padding=(k1-1)//2)
        self.conv2 = nn.Conv2d(in_channels=c1, out_channels=c2, kernel_size=k2, padding=(k2-1)//2)
        self.conv3 = nn.Conv2d(in_channels=c2, out_channels=c3, kernel_size=k3, padding=(k3-1)//2)

    def forward(self, X):
        """
        Apply the stack to ``X``.

        Args:
            X: tensor (or numpy array, which is converted to a float tensor)
                laid out as (batch, in_channels, height, width).

        Returns:
            Tensor with ``c3`` channels.
        """
        if not torch.is_tensor(X):
            X = torch.from_numpy(X).type(torch.FloatTensor)

        conv1_relu = F.relu(self.conv1(X))

        # The second convolution must see the activated features; feeding it
        # the raw conv1 output would chain two linear maps with no
        # non-linearity in between.
        conv2_X = self.conv2(conv1_relu)
        conv2_relu = F.relu(conv2_X + conv1_relu)

        return self.conv3(conv2_relu)
    
    
class ResConv(nn.Module):
    """
    Implements the res_conv module as described in figure 7 in the paper

    Computes ``c + conv3(relu(conv2(relu(conv1(c)))))`` where ``c`` is either
    the input itself or, with ``use_extra_convolution=True``, the input
    projected to 64 channels by a 1x1 convolution. The output always has
    64 channels and the same spatial size as the input.

    Args:
        use_extra_convolution: project the input to 64 channels with a 1x1
            convolution before the residual branch.
        in_channels: number of channels of the input tensor. Defaults to 3,
            the value that used to be hard-coded.

    Raises:
        ValueError: if ``use_extra_convolution`` is False and ``in_channels``
            is not 64. Without the projection the input is added directly to
            the 64-channel branch output, so the widths have to match.
    """
    def __init__(self, use_extra_convolution=True, in_channels=3):
        super(ResConv, self).__init__()
        self.use_extra_convolution = use_extra_convolution
        out_channels = 64

        if use_extra_convolution:
            self.extra_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
            conv1_in_channels = out_channels
        else:
            if in_channels != out_channels:
                raise ValueError(
                    "ResConv without the extra convolution adds its input to a "
                    "%d-channel output, so in_channels must be %d (got %d)"
                    % (out_channels, out_channels, in_channels)
                )
            # Do not register a layer that forward() never uses.
            self.extra_conv = None
            conv1_in_channels = in_channels

        self.conv1 = nn.Conv2d(in_channels=conv1_in_channels, out_channels=32, kernel_size=3, padding=(3-1)//2)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=(5-1)//2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=out_channels, kernel_size=3, padding=(3-1)//2)

    def forward(self, X):
        """
        Args:
            X: tensor (or numpy array, which is converted to a float tensor)
                laid out as (batch, in_channels, height, width).

        Returns:
            Tensor with 64 channels and the spatial size of ``X``.
        """
        if not torch.is_tensor(X):
            X = torch.from_numpy(X).type(torch.FloatTensor)

        if self.use_extra_convolution:
            c = self.extra_conv(X)
        else:
            c = X

        conv1_relu = F.relu(self.conv1(c))
        conv2_relu = F.relu(self.conv2(conv1_relu))
        rc3 = self.conv3(conv2_relu)
        rc = c + rc3
        return rc
    

class StateTransitionModule(nn.Module):
    """
    Implements the state transition function g(s,z,a) as described in figure 9 in the paper
    An action, a state and a latent variable at time t-1 is transitioned to 
    a state at time t.

    Pipeline: ``cat([a, s, z]) -> res_conv -> relu -> pool & inject -> res_conv``.

    Args:
        in_channels: total number of channels of ``a``, ``s`` and ``z`` once
            concatenated along the channel dimension. Defaults to 3, the
            input width of a default-constructed ``ResConv``.
    """
    def __init__(self, in_channels=3):
        super(StateTransitionModule, self).__init__()
        self.res_conv1 = self._res_conv(in_channels)
        # pool_inject concatenates the pooled copy with its input, so the
        # second block receives twice as many channels as the first emits.
        self.res_conv2 = self._res_conv(2 * self.res_conv1.conv3.out_channels)

    @staticmethod
    def _res_conv(in_channels):
        """Build a ResConv whose 1x1 input projection accepts `in_channels`."""
        block = ResConv()
        if block.extra_conv.in_channels != in_channels:
            # Resize only the input projection; the rest of the block works
            # on the projection's output and is unaffected.
            block.extra_conv = nn.Conv2d(
                in_channels, block.extra_conv.out_channels, kernel_size=1
            )
        return block

    def pool_inject(self, X):
        """
        Implements the pool & inject module as described in figure 8 in the paper

        Max-pools every channel of ``X`` over its whole spatial extent, tiles
        the result back to the spatial size of ``X`` and concatenates it with
        ``X`` along the channel dimension.

        Args:
            X: tensor (or numpy array, which is converted to a float tensor)
                laid out as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, 2 * channels, height, width); the first
            half of the channels holds the tiled maxima, the second half ``X``.
        """
        if not torch.is_tensor(X):
            X = torch.from_numpy(X).type(torch.FloatTensor)

        # A kernel covering the full feature map yields one value per channel.
        pooled = F.max_pool2d(X, kernel_size=tuple(X.shape[2:]))
        tiled = pooled.expand_as(X)
        return torch.cat([tiled, X], dim=1)  # concat on the channel dimension

    def forward(self, a, s, z):
        """
        Args:
            a, s, z: action, state and latent tensors. They are concatenated
                along dim 1, so they must agree in every other dimension and
                their channel counts must add up to ``in_channels``.

        Returns:
            The next state, as produced by the second residual block.
        """
        concat = torch.cat([a, s, z], 1)
        rc1_relu = F.relu(self.res_conv1(concat))
        pi = self.pool_inject(rc1_relu)
        s_next = self.res_conv2(pi)

        return s_next

In [3]:
class InitialModule(nn.Module):
    """
    Thin wrapper that runs its input through a single ConvStack
    (kernel sizes 3, 5, 3 with 16, 16 and 64 output channels).
    """
    def __init__(self):
        super().__init__()
        self.conv_stack1 = ConvStack(k1=3, c1=16, k2=5, c2=16, k3=3, c3=64)

    def forward(self, o):
        """Return the output of the convolution stack applied to `o`."""
        stack_output = self.conv_stack1(o)
        return stack_output
        

class ObservationEncoder(nn.Module):
    """
    Encodes an observation into a feature map at 1/8 of its resolution.

    Pipeline: ``space_to_depth(4) -> conv_stack1 -> space_to_depth(2) ->
    conv_stack2 -> relu``.

    Args:
        in_channels: number of channels of the observation (default 3).
    """
    def __init__(self, in_channels=3):
        super(ObservationEncoder, self).__init__()
        # Folding a b x b block into the channel axis multiplies the channel
        # count by b * b; each stack has to be sized for the folded input.
        self.conv_stack1 = self._conv_stack(in_channels * 4 * 4, 3, 16, 5, 16, 3, 64)
        self.conv_stack2 = self._conv_stack(64 * 2 * 2, 3, 32, 5, 32, 3, 64)

    @staticmethod
    def _conv_stack(in_channels, k1, c1, k2, c2, k3, c3):
        """Build a ConvStack whose first convolution accepts `in_channels`."""
        stack = ConvStack(k1, c1, k2, c2, k3, c3)
        if stack.conv1.in_channels != in_channels:
            # Same kernel size and padding as the layer being replaced, only
            # the input width differs.
            stack.conv1 = nn.Conv2d(in_channels, c1, kernel_size=k1, padding=(k1 - 1) // 2)
        return stack

    @staticmethod
    def _space_to_depth(x, block):
        """Move every `block` x `block` spatial patch of `x` into the channel axis.

        (B, C, H, W) -> (B, C * block * block, H // block, W // block).
        A plain ``view`` to the target shape would only reinterpret memory and
        scramble the image, hence the explicit split / permute.
        """
        batch, channels, height, width = x.shape
        if height % block or width % block:
            raise ValueError(
                "spatial size (%d, %d) is not divisible by block size %d"
                % (height, width, block)
            )
        x = x.reshape(batch, channels, height // block, block, width // block, block)
        # -> (B, block_row, block_col, C, H // block, W // block)
        x = x.permute(0, 3, 5, 1, 2, 4)
        return x.reshape(batch, channels * block * block, height // block, width // block)

    def forward(self, o):
        """
        Args:
            o: tensor laid out as (batch, in_channels, H, W); H and W must be
                divisible by 8.

        Returns:
            Non-negative tensor of shape (batch, 64, H // 8, W // 8).

        Raises:
            ValueError: if a spatial dimension cannot be folded evenly.
        """
        std1 = self._space_to_depth(o, 4)
        cs1 = self.conv_stack1(std1)
        # The second fold operates on the first stack's output, not its input.
        std2 = self._space_to_depth(cs1, 2)
        cs2 = self.conv_stack2(std2)
        e = F.relu(cs2)
        return e
    
class ObservationDecoder(nn.Module):
    """
    Decodes a state / latent pair into a map at 8x the input resolution.

    Pipeline: ``cat([s, z]) -> conv_stack1 -> depth_to_space(2) ->
    conv_stack2 -> depth_to_space(4)``.

    Args:
        in_channels: total number of channels of ``s`` and ``z`` once
            concatenated along the channel dimension. Defaults to 3, the
            input width that was previously hard-coded.
    """
    def __init__(self, in_channels=3):
        super(ObservationDecoder, self).__init__()
        self.conv_stack1 = self._conv_stack(in_channels, 1, 32, 5, 32, 3, 64)
        # depth_to_space(2) turns the 64 channels of conv_stack1 into 64 / 4.
        self.conv_stack2 = self._conv_stack(64 // (2 * 2), 3, 64, 3, 64, 1, 48)

    @staticmethod
    def _conv_stack(in_channels, k1, c1, k2, c2, k3, c3):
        """Build a ConvStack whose first convolution accepts `in_channels`."""
        stack = ConvStack(k1, c1, k2, c2, k3, c3)
        if stack.conv1.in_channels != in_channels:
            # Same kernel size and padding as the layer being replaced, only
            # the input width differs.
            stack.conv1 = nn.Conv2d(in_channels, c1, kernel_size=k1, padding=(k1 - 1) // 2)
        return stack

    @staticmethod
    def _depth_to_space(x, block):
        """Spread groups of `block` * `block` channels of `x` over space.

        (B, C, H, W) -> (B, C // (block * block), H * block, W * block).
        A plain ``view`` to the target shape would only reinterpret memory
        instead of interleaving the channels spatially, hence the explicit
        split / permute.
        """
        batch, channels, height, width = x.shape
        if channels % (block * block):
            raise ValueError(
                "channel count %d is not divisible by %d" % (channels, block * block)
            )
        out_channels = channels // (block * block)
        x = x.reshape(batch, block, block, out_channels, height, width)
        # -> (B, C_out, H, block_row, W, block_col)
        x = x.permute(0, 3, 4, 1, 5, 2)
        return x.reshape(batch, out_channels, height * block, width * block)

    def forward(self, s, z):
        """
        Args:
            s, z: state and latent tensors. They are concatenated along
                dim 1, so they must agree in every other dimension and their
                channel counts must add up to ``in_channels``.

        Returns:
            Tensor of shape (batch, 3, 8 * H, 8 * W) for inputs of spatial
            size (H, W).
        """
        concat = torch.cat([s, z], 1)
        cs1 = self.conv_stack1(concat)
        dts1 = self._depth_to_space(cs1, 2)
        cs2 = self.conv_stack2(dts1)
        dts2 = self._depth_to_space(cs2, 4)

        return dts2