In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym
from gym_utils import AtariEnv
from gym_utils import AtariFrame

import numpy as np
import random

environment_name = "SpaceInvaders-v4"
#environment_name = "Pong-v4"
typical_bad_game_frame_count = 200


Discrete(6)


In [2]:
# define a pytorch model.  for now, accept a 210 x 160 greyscale image and output an array of actions


class AtariModel(nn.Module):

    def __init__(self, action_count, dropout=0.25):
        """
        Initialize the PyTorch AtariModel Module
        :param dropout: dropout to add in between LSTM/GRU layers
        """
        super(AtariModel, self).__init__()
        
        # convolutional layer 1  (in_channels, out_channels, kernel_size, stride=1, padding=0)
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        # convolutional layer 2
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        # convolutional layer 3
        self.conv3 = nn.Conv2d(32, 64, 3, stride=1, padding=1)

        # max pooling layer
        self.maxpool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(8320, 512)    #64 * 14 * 14 = 12544
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, action_count)
        
        self.dropout = nn.Dropout(0.25)

    def forward(self, img_array):
        """
        Forward propagation of the neural network
        :param img_array: The input img array to the neural network
        :return
        """
        ## Define forward behavior
        
        #print("forward received img_array of shape: {}".format(img_array.shape))
        
        #convolutional layers
        x = self.maxpool(F.relu(self.conv1(img_array)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = self.maxpool(F.relu(self.conv3(x)))  
        
        #flatten
        x = x.view(-1, 8320)  
        #print("x.view shape: {}".format(x.shape))  #torch.Size([1, 8320])
        
        #fc layers
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    


In [3]:
#play a game. feed each frame into the model and see what we get
def play_game(env_name, model, max_frames=5000):
    model.eval()
    atari_env = AtariEnv(environment_name)
    current_action = 0
    done = False
    frame_counter = 0
    
    while not done:
        atari_frame = atari_env.step(current_action)
        img_array = atari_frame.img_array
        img_array = img_array.reshape((3,160,210))
        img_array = img_array.reshape((1,3,160,210))
        img_tensor = torch.from_numpy(img_array).float().cuda()
        output = model(img_tensor)
        action_array = output.detach().cpu().numpy()[0]
        atari_frame.action_array = action_array
        current_action = np.argmax(action_array)
        #print("{} - {}".format(current_action, output.detach().cpu().numpy()[0]))
        done = atari_frame.done_bool
        frame_counter += 1
        if frame_counter > max_frames:
            break

    atari_env.close()
    return atari_env


def train(atari_env, model, optimizer, criterion):
    model.train()
    action_count = atari_env.env.action_space.n
    discounted_rewards = atari_env.get_discounted_rewards()
    frame_buffer = atari_env.frame_buffer
    action_tally = np.zeros(action_count)
    train_tally = np.zeros(action_count)
    
    rand_arr = np.arange(len(discounted_rewards))
    np.random.shuffle(rand_arr)
    
    
    print("discounted_rewards mean: {}".format(np.mean(discounted_rewards)))
    if len(discounted_rewards) > typical_bad_game_frame_count:
        sorted_rewards = np.sort(discounted_rewards)
        desired_median = sorted_rewards[typical_bad_game_frame_count//2]
        discounted_rewards_mean = np.mean(discounted_rewards)
        reward_shift = discounted_rewards_mean - desired_median
        print("Shifting rewards by {}".format(reward_shift))
        discounted_rewards = discounted_rewards + reward_shift
        print("new discounted_rewards mean: {}".format(np.mean(discounted_rewards)))
        
    
    total_loss = 0
    for ii, reward_ii in enumerate(discounted_rewards):
        #print("{}: {}".format(i, reward))
        optimizer.zero_grad()
        i = rand_arr[ii]
        
        #get frame from the frame buffer and run it through the model
        atari_frame = atari_env.frame_buffer[i]
        reward = atari_frame.discounted_reward
        img_array = atari_frame.img_array
        img_array = img_array.reshape((3,160,210))
        img_array = img_array.reshape((1,3,160,210))
        img_tensor = torch.from_numpy(img_array).float().cuda()
        output = model(img_tensor)
        #print("train output: {}".format(output))
        
        #if the reward was positive, keep the same.  if not, choose lowest option
        action_array_from_model_in_training = output.detach().cpu().numpy()[0]
        action_array = atari_frame.action_array
        train_action = np.argmax(action_array)
        action_tally[train_action] += 1
        
        if reward < 0:
            train_action = np.argmin(action_array)
            #train_action = np.argsort(action_array)[1] #second highest
            #train_action = random.randint(0,action_count-1)

        if reward > -0.5 and reward < 0.5:
            continue
            
        if np.argmax(train_tally) == train_action and np.sum(train_tally) != 0:
            #keep things even to not introduce bias that will get it stuck on one action
            continue

        train_tally[train_action] += 1
        
        target = torch.empty(1, dtype=torch.int64)
        target[0] = int(train_action)
        target = target.cuda()
        
        loss = criterion(output, target)
        total_loss += loss
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
    
    print("avg loss: {:.3f}".format(total_loss / np.sum(train_tally)))
    print("model action_tally: {}".format(action_tally))
    print("train_tally:        {}".format(train_tally))
    
    
    

In [4]:
#new model
action_count = gym.make(environment_name).action_space.n
atari_model = AtariModel(action_count)
atari_model.cuda()

### loss function
atari_criterion = nn.CrossEntropyLoss()

### optimizer
atari_optimizer = optim.Adam(atari_model.parameters(), lr=0.00001)


In [None]:
for i in range(1000):
    #play a game
    atari_env = play_game(environment_name, atari_model)

    #discounted_rewards = atari_env.get_discounted_rewards()
    #print()
    #print(discounted_rewards)
    print("\n{}) frames played: {}, score: {}".format(i, len(atari_env.frame_buffer), atari_env.get_total_score()))

    #train the model
    for ii in range(3):
        train(atari_env, atari_model, atari_optimizer, atari_criterion)



0) frames played: 615, score: 110.0
discounted_rewards mean: 1.38642485026361e-16
Shifting rewards by 0.5671103310851955
new discounted_rewards mean: 0.5671103310851956
avg loss: 1.033
model action_tally: [  0. 404.   0.   0. 170.  41.]
train_tally:        [88. 86. 88.  0.  1. 24.]
discounted_rewards mean: 1.38642485026361e-16
Shifting rewards by 0.5671103310851955
new discounted_rewards mean: 0.5671103310851956
avg loss: 0.681
model action_tally: [  0. 404.   0.   0. 170.  41.]
train_tally:        [89. 88. 89.  0.  1. 24.]
discounted_rewards mean: 1.38642485026361e-16
Shifting rewards by 0.5671103310851955
new discounted_rewards mean: 0.5671103310851956
avg loss: 0.531
model action_tally: [  0. 404.   0.   0. 170.  41.]
train_tally:        [89. 89. 89.  0.  1. 24.]

1) frames played: 673, score: 105.0
discounted_rewards mean: -1.2669409850105798e-16
Shifting rewards by 0.4968948268678638
new discounted_rewards mean: 0.49689482686786385
avg loss: 0.467
model action_tally: [273.  67. 2

avg loss: 0.019
model action_tally: [562. 263.  30.  28.  86. 291.]
train_tally:        [ 8. 37.  0.  0.  8. 38.]
discounted_rewards mean: -9.022764898540955e-17
Shifting rewards by 0.45897697429586315
new discounted_rewards mean: 0.4589769742958631
avg loss: 0.008
model action_tally: [562. 263.  30.  28.  86. 291.]
train_tally:        [ 8. 34.  0.  0.  8. 34.]

11) frames played: 919, score: 110.0
discounted_rewards mean: 6.185355697585203e-17
Shifting rewards by 0.513159329645849
new discounted_rewards mean: 0.513159329645849
avg loss: 1.001
model action_tally: [348. 245.  20.  21.  13. 272.]
train_tally:        [59. 91. 98. 23. 62. 99.]
discounted_rewards mean: 6.185355697585203e-17
Shifting rewards by 0.513159329645849
new discounted_rewards mean: 0.513159329645849
avg loss: 0.260
model action_tally: [348. 245.  20.  21.  13. 272.]
train_tally:        [60. 91. 97. 23. 62. 97.]
discounted_rewards mean: 6.185355697585203e-17
Shifting rewards by 0.513159329645849
new discounted_reward


21) frames played: 622, score: 120.0
discounted_rewards mean: -6.854109991254986e-17
Shifting rewards by 0.800317779413856
new discounted_rewards mean: 0.8003177794138558
avg loss: 1.151
model action_tally: [ 42. 298. 127.  41.  22.  92.]
train_tally:        [74. 74. 29.  2. 52. 75.]
discounted_rewards mean: -6.854109991254986e-17
Shifting rewards by 0.800317779413856
new discounted_rewards mean: 0.8003177794138558
avg loss: 0.504
model action_tally: [ 42. 298. 127.  41.  22.  92.]
train_tally:        [77. 78. 27.  2. 52. 77.]
discounted_rewards mean: -6.854109991254986e-17
Shifting rewards by 0.800317779413856
new discounted_rewards mean: 0.8003177794138558
avg loss: 0.372
model action_tally: [ 42. 298. 127.  41.  22.  92.]
train_tally:        [77. 77. 29.  2. 52. 77.]

22) frames played: 1179, score: 335.0
discounted_rewards mean: -2.410662377472774e-17
Shifting rewards by 0.9543259457643467
new discounted_rewards mean: 0.9543259457643467
avg loss: 0.938
model action_tally: [105. 12

avg loss: 0.509
model action_tally: [ 11.  91. 153.  56. 113. 277.]
train_tally:        [75. 77. 58. 27. 28. 77.]
discounted_rewards mean: 1.4190582454552642e-16
Shifting rewards by 1.0983102338750461
new discounted_rewards mean: 1.0983102338750463
avg loss: 0.419
model action_tally: [ 11.  91. 153.  56. 113. 277.]
train_tally:        [73. 80. 60. 27. 28. 80.]

32) frames played: 498, score: 70.0
discounted_rewards mean: -5.707170568354218e-17
Shifting rewards by 0.7305356853030974
new discounted_rewards mean: 0.7305356853030973
avg loss: 0.929
model action_tally: [ 37. 175.  99.  39.  25. 123.]
train_tally:        [47. 47. 38. 43. 47. 36.]
discounted_rewards mean: -5.707170568354218e-17
Shifting rewards by 0.7305356853030974
new discounted_rewards mean: 0.7305356853030973
avg loss: 0.300
model action_tally: [ 37. 175.  99.  39.  25. 123.]
train_tally:        [46. 47. 42. 43. 48. 36.]
discounted_rewards mean: -5.707170568354218e-17
Shifting rewards by 0.7305356853030974
new discounted_


42) frames played: 552, score: 80.0
discounted_rewards mean: 0.0
Shifting rewards by 0.7028491477283041
new discounted_rewards mean: 0.7028491477283041
avg loss: 1.088
model action_tally: [ 85.  33. 251. 117.  35.  31.]
train_tally:        [64. 65. 35. 22. 25. 66.]
discounted_rewards mean: 0.0
Shifting rewards by 0.7028491477283041
new discounted_rewards mean: 0.7028491477283041
avg loss: 0.407
model action_tally: [ 85.  33. 251. 117.  35.  31.]
train_tally:        [64. 69. 34. 22. 25. 70.]
discounted_rewards mean: 0.0
Shifting rewards by 0.7028491477283041
new discounted_rewards mean: 0.7028491477283041
avg loss: 0.295
model action_tally: [ 85.  33. 251. 117.  35.  31.]
train_tally:        [70. 71. 36. 22. 24. 70.]

43) frames played: 860, score: 210.0
discounted_rewards mean: -5.783487384093839e-17
Shifting rewards by 0.7482173514686549
new discounted_rewards mean: 0.7482173514686549
avg loss: 1.302
model action_tally: [144. 101.  74.  34.  60. 447.]
train_tally:        [ 33. 101. 1

avg loss: 0.507
model action_tally: [ 68. 202. 107. 165. 208. 224.]
train_tally:        [118. 118. 112.  26.  60. 119.]
discounted_rewards mean: 5.836079965175361e-17
Shifting rewards by 1.0290934893636208
new discounted_rewards mean: 1.029093489363621
avg loss: 0.332
model action_tally: [ 68. 202. 107. 165. 208. 224.]
train_tally:        [113. 111. 114.  26.  60. 115.]

53) frames played: 527, score: 155.0
discounted_rewards mean: 2.6965568719548395e-17
Shifting rewards by 1.062374574938902
new discounted_rewards mean: 1.062374574938902
avg loss: 0.995
model action_tally: [ 34.  28. 209.  60.  43. 153.]
train_tally:        [36. 37. 32. 53. 54. 54.]
discounted_rewards mean: 2.6965568719548395e-17
Shifting rewards by 1.062374574938902
new discounted_rewards mean: 1.062374574938902
avg loss: 0.528
model action_tally: [ 34.  28. 209.  60.  43. 153.]
train_tally:        [38. 36. 32. 55. 56. 54.]
discounted_rewards mean: 2.6965568719548395e-17
Shifting rewards by 1.062374574938902
new disco


63) frames played: 1275, score: 375.0
discounted_rewards mean: 3.3437305212240006e-17
Shifting rewards by 0.9380866958276117
new discounted_rewards mean: 0.9380866958276116
avg loss: 1.082
model action_tally: [140.  31.  11. 251. 418. 424.]
train_tally:        [ 32. 149. 147. 149. 150. 146.]
discounted_rewards mean: 3.3437305212240006e-17
Shifting rewards by 0.9380866958276117
new discounted_rewards mean: 0.9380866958276116
avg loss: 0.484
model action_tally: [140.  31.  11. 251. 418. 424.]
train_tally:        [ 32. 149. 150. 149. 150. 139.]
discounted_rewards mean: 3.3437305212240006e-17
Shifting rewards by 0.9380866958276117
new discounted_rewards mean: 0.9380866958276116
avg loss: 0.345
model action_tally: [140.  31.  11. 251. 418. 424.]
train_tally:        [ 32. 151. 149. 148. 151. 144.]

64) frames played: 684, score: 140.0
discounted_rewards mean: -9.349246523159212e-17
Shifting rewards by 0.8460866998194763
new discounted_rewards mean: 0.8460866998194762
avg loss: 1.408
model a

avg loss: 0.568
model action_tally: [112.  97.  44.  63.  92. 418.]
train_tally:        [63. 79. 98. 95. 31. 98.]

74) frames played: 605, score: 105.0
discounted_rewards mean: 9.395606423274052e-17
Shifting rewards by 0.5780300536650621
new discounted_rewards mean: 0.5780300536650622
avg loss: 1.425
model action_tally: [ 68. 116. 161. 140.  23.  97.]
train_tally:        [75. 77. 21. 41. 77. 78.]
discounted_rewards mean: 9.395606423274052e-17
Shifting rewards by 0.5780300536650621
new discounted_rewards mean: 0.5780300536650622
avg loss: 0.703
model action_tally: [ 68. 116. 161. 140.  23.  97.]
train_tally:        [73. 77. 21. 41. 77. 78.]
discounted_rewards mean: 9.395606423274052e-17
Shifting rewards by 0.5780300536650621
new discounted_rewards mean: 0.5780300536650622
avg loss: 0.492
model action_tally: [ 68. 116. 161. 140.  23.  97.]
train_tally:        [75. 82. 21. 40. 81. 82.]

75) frames played: 823, score: 225.0
discounted_rewards mean: -7.770212177206442e-17
Shifting rewards b

avg loss: 0.567
model action_tally: [ 20.  59. 102. 159.  98.  56.]
train_tally:        [55. 45. 52. 41. 56. 56.]
discounted_rewards mean: -1.150676495158057e-16
Shifting rewards by 0.6357052480918024
new discounted_rewards mean: 0.6357052480918025
avg loss: 0.446
model action_tally: [ 20.  59. 102. 159.  98.  56.]
train_tally:        [57. 47. 52. 41. 57. 58.]

85) frames played: 643, score: 115.0
discounted_rewards mean: -5.525215674650857e-17
Shifting rewards by 0.7673262164584065
new discounted_rewards mean: 0.7673262164584066
avg loss: 1.381
model action_tally: [155.  41. 111.  91. 130. 115.]
train_tally:        [71. 59. 69. 65. 26. 72.]
discounted_rewards mean: -5.525215674650857e-17
Shifting rewards by 0.7673262164584065
new discounted_rewards mean: 0.7673262164584066
avg loss: 0.802
model action_tally: [155.  41. 111.  91. 130. 115.]
train_tally:        [68. 59. 66. 66. 26. 69.]
discounted_rewards mean: -5.525215674650857e-17
Shifting rewards by 0.7673262164584065
new discounted

avg loss: 1.258
model action_tally: [142. 145.  51. 208. 297. 256.]
train_tally:        [ 41. 137.  99.  48. 139. 139.]
discounted_rewards mean: -1.2930714026571432e-17
Shifting rewards by 1.0796403207409146
new discounted_rewards mean: 1.0796403207409146
avg loss: 0.707
model action_tally: [142. 145.  51. 208. 297. 256.]
train_tally:        [ 41. 144.  99.  48. 146. 146.]
discounted_rewards mean: -1.2930714026571432e-17
Shifting rewards by 1.0796403207409146
new discounted_rewards mean: 1.0796403207409146
avg loss: 0.499
model action_tally: [142. 145.  51. 208. 297. 256.]
train_tally:        [ 41. 137.  99.  48. 138. 137.]

96) frames played: 531, score: 160.0
discounted_rewards mean: 1.0704975303353675e-16
Shifting rewards by 0.8518270777275052
new discounted_rewards mean: 0.8518270777275053
avg loss: 1.100
model action_tally: [ 98. 111. 110.  32.  80. 100.]
train_tally:        [56. 57. 27. 37. 51. 58.]
discounted_rewards mean: 1.0704975303353675e-16
Shifting rewards by 0.85182707772

avg loss: 0.550
model action_tally: [ 66. 108. 105.  78. 175.  96.]
train_tally:        [52. 36. 60. 57. 60. 60.]

106) frames played: 1529, score: 915.0
discounted_rewards mean: 4.647107493525835e-17
Shifting rewards by 0.698867867125117
new discounted_rewards mean: 0.6988678671251168
avg loss: 1.398
model action_tally: [ 68. 389. 329. 261. 341. 141.]
train_tally:        [114. 115.  54.  53. 114.  90.]
discounted_rewards mean: 4.647107493525835e-17
Shifting rewards by 0.698867867125117
new discounted_rewards mean: 0.6988678671251168
avg loss: 0.869
model action_tally: [ 68. 389. 329. 261. 341. 141.]
train_tally:        [113. 115.  54.  53. 115.  93.]
discounted_rewards mean: 4.647107493525835e-17
Shifting rewards by 0.698867867125117
new discounted_rewards mean: 0.6988678671251168
avg loss: 0.727
model action_tally: [ 68. 389. 329. 261. 341. 141.]
train_tally:        [102. 103.  54.  53. 102.  94.]

107) frames played: 590, score: 120.0
discounted_rewards mean: 4.8172388865091537e-17


avg loss: 0.691
model action_tally: [141.  68. 244. 196. 228. 385.]
train_tally:        [85. 67. 53. 44. 88. 89.]
discounted_rewards mean: -6.193320200761571e-17
Shifting rewards by 0.6254066781209174
new discounted_rewards mean: 0.6254066781209174
avg loss: 0.500
model action_tally: [141.  68. 244. 196. 228. 385.]
train_tally:        [87. 61. 53. 44. 93. 94.]

117) frames played: 896, score: 275.0
discounted_rewards mean: -1.2688263138573217e-16
Shifting rewards by 0.9475619861011093
new discounted_rewards mean: 0.9475619861011092
avg loss: 1.415
model action_tally: [109.  46.  17. 105. 151. 468.]
train_tally:        [90. 92. 90. 41. 43. 92.]
discounted_rewards mean: -1.2688263138573217e-16
Shifting rewards by 0.9475619861011093
new discounted_rewards mean: 0.9475619861011092
avg loss: 0.864
model action_tally: [109.  46.  17. 105. 151. 468.]
train_tally:        [99. 94. 94. 41. 43. 99.]
discounted_rewards mean: -1.2688263138573217e-16
Shifting rewards by 0.9475619861011093
new discou


127) frames played: 811, score: 200.0
discounted_rewards mean: 7.885184490555982e-17
Shifting rewards by 0.8614179049868024
new discounted_rewards mean: 0.8614179049868023
avg loss: 1.251
model action_tally: [106. 238. 122.  77.  87. 181.]
train_tally:        [55. 96. 21. 71. 96. 97.]
discounted_rewards mean: 7.885184490555982e-17
Shifting rewards by 0.8614179049868024
new discounted_rewards mean: 0.8614179049868023
avg loss: 0.696
model action_tally: [106. 238. 122.  77.  87. 181.]
train_tally:        [ 55. 102.  21.  71. 106. 107.]
discounted_rewards mean: 7.885184490555982e-17
Shifting rewards by 0.8614179049868024
new discounted_rewards mean: 0.8614179049868023
avg loss: 0.471
model action_tally: [106. 238. 122.  77.  87. 181.]
train_tally:        [ 55. 100.  21.  71. 109. 109.]

128) frames played: 567, score: 105.0
discounted_rewards mean: -1.2531617914640214e-16
Shifting rewards by 0.8017951439456878
new discounted_rewards mean: 0.8017951439456877
avg loss: 1.243
model action_t

avg loss: 0.526
model action_tally: [ 54. 101. 157.  81.  51. 174.]
train_tally:        [78. 77. 74. 53. 66. 78.]

138) frames played: 971, score: 285.0
discounted_rewards mean: -2.9270555541095784e-17
Shifting rewards by 0.829015025958447
new discounted_rewards mean: 0.829015025958447
avg loss: 1.193
model action_tally: [193. 311.  36.  92.  90. 249.]
train_tally:        [103.  69.  98. 103. 103. 104.]
discounted_rewards mean: -2.9270555541095784e-17
Shifting rewards by 0.829015025958447
new discounted_rewards mean: 0.829015025958447
avg loss: 0.588
model action_tally: [193. 311.  36.  92.  90. 249.]
train_tally:        [109.  69. 103. 104. 109. 110.]
discounted_rewards mean: -2.9270555541095784e-17
Shifting rewards by 0.829015025958447
new discounted_rewards mean: 0.829015025958447
avg loss: 0.447
model action_tally: [193. 311.  36.  92.  90. 249.]
train_tally:        [102.  69. 101. 102. 102. 103.]

139) frames played: 802, score: 180.0
discounted_rewards mean: 8.859635109228182e-17

avg loss: 0.976
model action_tally: [176. 161. 118.  95.  39. 455.]
train_tally:        [70. 68. 33. 65. 39. 71.]
discounted_rewards mean: 5.4447719215333345e-17
Shifting rewards by 0.6879137779304241
new discounted_rewards mean: 0.687913777930424
avg loss: 0.640
model action_tally: [176. 161. 118.  95.  39. 455.]
train_tally:        [70. 68. 33. 65. 39. 71.]

149) frames played: 846, score: 210.0
discounted_rewards mean: 5.039310182695746e-17
Shifting rewards by 0.6652422840947236
new discounted_rewards mean: 0.6652422840947236
avg loss: 1.411
model action_tally: [124. 195.  13. 206.  76. 232.]
train_tally:        [109.  69. 108. 100.  57. 110.]
discounted_rewards mean: 5.039310182695746e-17
Shifting rewards by 0.6652422840947236
new discounted_rewards mean: 0.6652422840947236
avg loss: 0.878
model action_tally: [124. 195.  13. 206.  76. 232.]
train_tally:        [106.  73. 107. 101.  57. 107.]
discounted_rewards mean: 5.039310182695746e-17
Shifting rewards by 0.6652422840947236
new d