In [9]:
import torch
import torch.nn as nn 
import torch.optim as optim
import torch.nn.functional as F

import os
from glob import glob
import json
import numpy as np
from datetime import datetime

import gym

# ---------- reproducibility: pin every RNG this notebook relies on ----------

SEED = 89

# Restrict cuDNN to deterministic algorithms.
torch.backends.cudnn.deterministic = True

# Seed NumPy's global generator first, then PyTorch (default generator and CUDA).
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

class Agent(nn.Module):
    """Small convolutional policy network that emits one value per input frame.

    Two ReLU-activated convolutions feed a global max-pool, which leaves one
    value per channel (8 in total). A linear layer followed by a sigmoid turns
    those 8 values into a single number between 0 and 1, used to decide
    whether to go up or down.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 2 -> 8 channels, no padding.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=8, kernel_size=2, stride=1)
        # Global max-pool: collapses every feature map to a single value.
        self.pool = nn.AdaptiveMaxPool2d(output_size=1)
        # Decision head: decide to go up or down.
        self.fc_out = nn.Linear(in_features=8, out_features=1)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid outputs for ``x``.

        ``x`` needs exactly one channel in dimension 1, because ``conv1`` is
        built with a single input channel.
        """
        batch_size = x.size(0)
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        pooled = self.pool(features)
        # Drop the trailing 1x1 spatial dims so the linear layer sees (batch, 8).
        logits = self.fc_out(pooled.view(batch_size, -1))
        return F.sigmoid(logits)

######## helper functions ############

def prepro(img):
    """Turn a raw 210x160x3 uint8 frame into an 80x80 binary float32 image.

    The frame is cropped to rows 35..194, downsampled by a factor of 2 in both
    directions using channel 0 only, and binarised: the two background values
    (144 and 109) become 0 and every other pixel (ball, paddles) becomes 1.

    The input array is left unmodified. The result is a new 2-D ``np.float32``
    array; it is not flattened.
    """
    # Crop + downsample + take channel 0. Basic slicing only yields a view of
    # the caller's frame, so copy before the in-place masking below; otherwise
    # the masking would overwrite pixels in the frame that was passed in.
    frame = img[35:195:2, ::2, 0].copy()

    frame[frame == 144] = 0  # erase background (type I), set to black
    frame[frame == 109] = 0  # erase background (type II), set to black

    # Everything that is left (ball, paddles) is set to 1.
    frame[frame != 0] = 1

    return frame.astype(np.float32)

In [3]:
# Create the gym environment registered under the id "Pong-v0".
env = gym.make("Pong-v0")

In [4]:
# Start a new episode and keep what reset() returns as the first frame.
# NOTE(review): assumes reset() returns the observation alone (older gym API)
# — confirm against the installed gym version.
state = env.reset()

In [8]:
# Preprocess the first frame; prepro returns a float32 NumPy array
# (80x80 when given a 210x160x3 frame).
x = prepro(state)

In [10]:
# Instantiate the policy network with freshly initialised (untrained) weights.
model = Agent()

In [12]:
# NOTE(review): Agent.forward requires an input tensor, so this call with no
# arguments raises a TypeError as written. The RuntimeError recorded below
# reports input[1, 210, 160, 3], so it appears to come from an earlier version
# of this cell that passed the raw frame — confirm, and pass a
# (batch, 1, H, W) tensor instead.
model()

RuntimeError: Given groups=1, weight of size 2 1 3 3, expected input[1, 210, 160, 3] to have 1 channels, but got 210 channels instead

In [17]:
# Add batch and channel dimensions: (80, 80) -> (1, 1, 80, 80), so the frame
# has the single channel in dimension 1 that Agent.conv1 is built for.
# torch.as_tensor accepts both an ndarray and a Tensor, so re-running this
# cell after x has already been converted no longer raises.
x = torch.as_tensor(x).view(1, 1, 80, 80)

In [20]:
# Run the network on the single preprocessed frame and show the output shape.
# NOTE(review): Agent as defined in this notebook ends in Linear(8, 1), which
# would give torch.Size([1, 1]); the recorded torch.Size([1, 8]) looks stale
# (possibly from an earlier forward()) — re-run to confirm.
model(x).view(1, -1).size()

torch.Size([1, 8])