In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import gym
from gym_utils import AtariEnv
from gym_utils import AtariFrame

import numpy as np

# Gym environment id that AtariEnv is created with in the play cell.
environment_name: str = "SpaceInvaders-v4"



In [2]:
# Define a PyTorch policy network: it takes a batch of 3-channel frames
# (default spatial size 160 x 210, as fed by the play cell) and outputs one
# unnormalized score per action.


class AtariModel(nn.Module):
    """Small CNN mapping Atari frames to per-action scores.

    Architecture: three conv -> ReLU -> 2x2 max-pool stages (16, 32, 64
    channels), then dropout + three fully connected layers
    (512 -> 512 -> n_actions). The final layer has no activation, so the
    output is raw scores (e.g. for argmax action selection).
    """

    def __init__(self, dropout=0.25, in_channels=3, n_actions=6, input_hw=(160, 210)):
        """
        Initialize the PyTorch AtariModel Module
        :param dropout: dropout probability applied to the flattened conv
            features (before fc1) and to the fc1 activations (before fc2)
        :param in_channels: number of channels in each input frame
        :param n_actions: number of action scores produced by fc3
        :param input_hw: (height, width) of the input frames; only used to size
            fc1. The defaults give 64 * 10 * 13 = 8320 flattened features.
        """
        super(AtariModel, self).__init__()

        # convolutional layer 1  (in_channels, out_channels, kernel_size, stride=1, padding=0)
        self.conv1 = nn.Conv2d(in_channels, 16, 3, stride=2, padding=1)
        # convolutional layer 2
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        # convolutional layer 3
        self.conv3 = nn.Conv2d(32, 64, 3, stride=1, padding=1)

        # max pooling layer (halves height and width after every conv)
        self.maxpool = nn.MaxPool2d(2, 2)

        # Size fc1 from a real pass through the conv stack rather than a
        # hand-computed constant, so it stays correct if input_hw or the conv
        # layers change. A zero tensor consumes no RNG state, so seeded
        # weight initialisation is unaffected.
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, *input_hw)
            self.flat_features = self._features(dummy).numel()

        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, n_actions)

        # Honour the constructor argument (was previously hard-coded to 0.25).
        self.dropout = nn.Dropout(dropout)

    def _features(self, x):
        """Apply the three conv -> ReLU -> max-pool stages."""
        x = self.maxpool(F.relu(self.conv1(x)))
        x = self.maxpool(F.relu(self.conv2(x)))
        return self.maxpool(F.relu(self.conv3(x)))

    def forward(self, img_array):
        """
        Forward propagation of the neural network
        :param img_array: float tensor of shape (N, in_channels, H, W); an
            unbatched (in_channels, H, W) tensor is treated as N = 1
        :return: tensor of shape (N, n_actions) with one raw score per action
        :raises RuntimeError: if the frame size does not match the flattened
            size fc1 was built for
        """
        x = self._features(img_array)
        if x.dim() == 3:
            # unbatched input: restore a batch dimension of 1
            x = x.unsqueeze(0)

        # Flatten per sample. Unlike view(-1, 8320), a wrong frame size now
        # fails in fc1 instead of being silently split into extra "samples".
        x = torch.flatten(x, 1)

        # fc layers
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        return self.fc3(x)
    


In [3]:
# Play one episode with the (untrained) model: feed every frame through the
# network and take the highest-scoring action.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AtariModel().to(device)
# Inference only: disable dropout so the chosen action depends on the frame
# alone, not on random dropout masks.
model.eval()


def frame_to_tensor(img_array, device):
    """Convert one environment frame into a (1, 3, H, W) float tensor on `device`.

    NOTE(review): assumes frames arrive channels-last, (210, 160, 3), as gym's
    Atari observations do -- confirm against gym_utils.AtariFrame. Reshaping
    such an array straight to (3, 160, 210) scrambles pixels across channels,
    so channels-last frames are transposed instead. Any other layout falls
    back to the original reshape.
    """
    frame = np.asarray(img_array)
    if frame.ndim == 3 and frame.shape[-1] == 3:
        # (H, W, C) -> (C, H, W). The conv stack flattens to 8320 features
        # for both 210 x 160 and 160 x 210 frames, so fc1 still matches.
        frame = frame.transpose(2, 0, 1)
    else:
        frame = frame.reshape((3, 160, 210))
    chw = np.ascontiguousarray(frame)
    return torch.from_numpy(chw).float().unsqueeze(0).to(device)


current_action = 0
done = False
atari_env = AtariEnv(environment_name)
try:
    with torch.no_grad():
        while not done:
            atari_frame = atari_env.step(current_action)
            img_tensor = frame_to_tensor(atari_frame.img_array, device)
            action_scores = model(img_tensor)[0].cpu().numpy()
            current_action = int(np.argmax(action_scores))
            print("{} - {}".format(current_action, action_scores))
            done = atari_frame.done_bool
finally:
    # Release the emulator even if a step or forward pass raises.
    atari_env.close()

discounted_rewards = atari_env.get_discounted_rewards()

print()
print(discounted_rewards)
    
    

0 - [ 2.2037475   1.5563657   0.48850822 -0.5162107  -1.2929356  -0.43742296]
0 - [ 1.9128141   0.4270519   0.77913254  0.30288914 -0.6352037  -0.8877074 ]
0 - [ 1.1330154   0.71238685  0.8408416  -0.2990056  -0.45433557 -1.4575402 ]
0 - [ 0.8252804   0.6519688   0.24410674 -0.75893027 -1.4665937  -1.7605662 ]
0 - [ 1.6772327   0.4708508  -0.3934226  -0.33061656 -0.92311645 -0.42581016]
0 - [ 2.2349193   1.5096482   0.11856144 -0.782941   -1.1498011  -1.1057075 ]
1 - [ 1.5229372   1.5834614   0.7312812  -0.31299338 -1.7057703  -1.9977542 ]
0 - [ 2.2876418   1.8001051   0.11021099 -1.0306253  -1.878144   -1.2458704 ]
0 - [ 2.0185952   0.78385407  0.44315806 -0.7768797  -0.52215284 -0.44142392]
0 - [ 1.3961701   1.1992224   0.66818446 -1.2778066  -1.0978649  -1.921906  ]
0 - [ 1.8605754  1.1170897  0.540122  -0.7183042 -1.4194356 -0.8800545]
1 - [ 1.8573169   1.8721999  -0.51114684  0.36848596 -0.43318418  0.25579187]
0 - [ 2.7966983   0.96871024  0.74010736 -1.1010034  -1.0960648  -0.91

0 - [ 2.1024199   0.72147864  0.4608227  -0.1726866  -0.5540132  -0.4715723 ]
0 - [ 1.3642528   0.6271391   0.18799037 -0.05068608 -1.9178315  -1.4184349 ]
0 - [ 1.8569596   1.5390594  -0.26728556 -0.5003101  -1.8933046  -2.1635792 ]
2 - [ 1.3539994   0.72971827  1.5024441  -0.5968494  -0.93498135 -2.9158785 ]
0 - [ 1.5575924   0.91017306  0.8494722  -0.38442284 -0.30672094 -1.9571953 ]
1 - [ 0.7185502   1.9496075   0.7838564   0.29346806 -0.7047962  -1.7811178 ]
0 - [ 1.9540731   0.8763759   1.3161393  -0.44613984 -0.6251205  -1.4460945 ]
0 - [ 0.900096   -0.07898754  0.26098168 -1.1657333  -0.5253048  -1.5192754 ]
0 - [ 2.6989522   1.4169788   0.15406352 -0.16858767 -1.5630522  -2.5105608 ]
0 - [ 0.38926363 -0.3025565  -0.1583885  -0.8745155  -0.53799933 -1.484368  ]
0 - [ 2.20116     0.9477549  -0.01897758 -0.20717368 -0.5591454  -0.96974367]
0 - [ 2.0768702   1.0081071   0.8654377  -0.4298694   0.23305476 -0.3144493 ]
0 - [ 1.3958658   1.0020578   0.58417255  0.7350782  -1.6493782 

0 - [ 1.7488731   0.6444527   0.20125955 -0.12723671 -0.70308816 -0.44364473]
0 - [ 2.9209468e+00  1.3492420e+00  5.0972331e-01  9.0461597e-04
 -8.4092265e-01 -4.5079866e-01]
0 - [ 2.2102826   1.1707256  -0.5019786  -0.09723693 -1.1768064  -0.17285961]
0 - [ 2.5061142  0.3696126  0.2787233 -1.1558063 -1.2543105 -2.2407389]
0 - [ 1.4648983   1.0558476  -0.20434788 -0.92160004 -0.7541744  -1.5173575 ]
0 - [ 1.7044269  -0.30570024 -1.3493823  -1.1305605  -0.24010462 -1.8471202 ]
0 - [ 1.1617175   0.8687662  -0.09775054  0.8251633  -0.36450824 -2.393726  ]
0 - [ 2.615662    1.3387781   0.4709492   0.22549622 -1.2269818  -2.6299386 ]
1 - [ 0.6836314   1.0109553   0.64066494 -0.41428944  0.10632887 -1.1247447 ]
1 - [ 1.1609802   1.4995476   1.3646576  -1.3709086  -0.35733262 -2.410575  ]
0 - [ 2.0281775   1.5606738   0.37873617 -0.846067   -0.47783467 -1.8173572 ]
0 - [ 2.3341463   1.6004928   0.36223292  0.37536106  0.11536092 -1.8323961 ]
0 - [ 1.8937066  0.9592312  1.2808831 -0.8570316 -0

0 - [ 1.8180029   1.1839497   1.2093794  -0.2541181  -0.92903936 -0.997075  ]
0 - [ 1.4754885   1.2116903   0.9893915  -0.65525866 -0.5176006  -1.2830138 ]
0 - [ 2.3758678   1.6837522   0.842156   -0.37491783 -0.8808092  -1.5005459 ]
0 - [ 1.5849265   0.25353354  0.7665445  -0.38285956  0.22102445 -1.5399497 ]
0 - [ 2.861126   1.2803915  1.0179437 -0.9655003 -1.0308908 -0.8942005]
0 - [ 1.5651388   1.5014275   0.15563837 -1.002891   -1.3451306  -1.8479534 ]
0 - [ 1.7578545   0.9803216   1.1011512  -0.8469371  -0.77599335 -0.53200215]
0 - [ 1.9168444   1.3560575   0.97855765 -0.06934543 -0.5042182  -1.092753  ]
0 - [ 1.4980841   1.1703193   0.12052817 -0.704292   -0.6541359  -1.6054801 ]
0 - [ 1.8357779   1.0045129   0.5537155  -0.8290616  -1.2428933  -0.53623754]
0 - [ 1.215237    0.4212921   0.29973894 -0.2639375  -1.4827379  -1.4626315 ]
0 - [ 1.759163    1.1405112   1.5539769  -1.0256811   0.39784893 -0.76144004]
1 - [ 0.6263964   1.0847708   0.6123753   0.15599094 -1.7378505  -0.77

0 - [ 0.92059726 -0.34599972  0.3195837  -0.6070302  -1.2185344  -1.3860183 ]
0 - [ 2.0077767   0.09975764 -0.8782604  -0.6448444  -1.0883033  -2.2720978 ]
0 - [ 2.0177572  -0.4986518   0.07237803 -0.6353593  -1.0958298  -1.8657523 ]
0 - [ 2.2971344   0.47396523  0.5769637  -0.4228854  -0.3712366  -2.1391497 ]
0 - [ 1.9940598  -0.15543047  1.6678134  -0.60546094 -0.10038983 -1.8798486 ]
0 - [ 1.2462138   0.43392694  0.6902884  -0.9657635  -0.02421228  0.32351902]
0 - [ 2.2565143   1.0633001   0.5493036   0.90807754 -0.6290555  -1.4564822 ]
0 - [ 1.6981758   1.2939366   1.0572735  -0.41414842 -0.13039514 -1.3423041 ]
0 - [ 2.0303576   0.86589116 -0.03070914 -0.5418844  -0.57350093 -1.1791778 ]
2 - [ 1.0101473   0.19474147  1.1061116  -0.5904141  -0.02735902 -1.6585481 ]
0 - [ 1.8464547   0.20005232 -0.18203455 -0.881465   -1.01566    -1.2958653 ]
0 - [ 0.44811395  0.38408422  0.0527794  -0.4105492  -0.9608094  -0.8352801 ]
0 - [ 1.2500609   0.29299453  0.9974906   0.01669575 -0.49113843

0 - [ 1.1504694  -0.37251937  0.8692853  -0.8882046  -1.1583873  -1.8847808 ]
0 - [ 0.96748114 -0.23489329 -0.5954248  -0.3460866  -1.1268692  -1.1939718 ]
0 - [ 1.6502287  0.9599282 -0.6193301 -0.5875147 -0.8519133 -1.6088403]
0 - [ 2.0751853   0.67660886  0.54792    -0.06957336 -0.53066033 -0.67099696]
0 - [ 2.31139     0.5693854   0.28581527 -0.33743486 -0.6524154  -0.81626457]
0 - [ 2.719381    0.2837746   0.05143198  0.24480085 -1.004945   -1.6259726 ]
0 - [ 2.5311327   1.0117214   0.31763622 -0.01588923 -0.46717355 -1.2229806 ]
0 - [ 3.2248297   1.2160203   0.95880395 -0.05166652 -0.7681483  -2.3565874 ]
0 - [ 1.8840419   1.0863605   0.36583972 -0.7119295  -0.79769075 -0.86770344]
0 - [ 1.6775572  -0.26419142  1.2149214  -0.4774045  -1.056762   -1.4256197 ]
0 - [ 2.6536334   0.01286688  0.2353678  -0.43386978 -1.318728   -2.4481583 ]
0 - [ 1.6636255   0.2401475   0.6151871  -1.6835011  -1.1376636  -0.62997884]
0 - [ 2.4391434   1.2263318   0.26834044 -1.8024008  -1.5727189  -1.57