In [359]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import gym
from gym_utils import AtariEnv
from gym_utils import AtariFrame

import numpy as np
import random

# environment_name = "SpaceInvaders-v4"
# typical_bad_game_frame_count = 250
# reward_frame_shift = -15

environment_name = "Pong-v4"
typical_bad_game_frame_count = 1050
reward_frame_shift = -1

In [360]:
# define a pytorch model.  for now, accept a 210 x 160 greyscale image and output an array of actions


class AtariModel(nn.Module):

    def __init__(self, action_count, dropout=0.25):
        """
        Initialize the PyTorch AtariModel Module
        :param dropout: dropout to add in between LSTM/GRU layers
        """
        super(AtariModel, self).__init__()
        
        # convolutional layer 1  (in_channels, out_channels, kernel_size, stride=1, padding=0)
        self.conv1 = nn.Conv2d(3, 16, 3, stride=2, padding=1)
        # convolutional layer 2
        self.conv2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        # convolutional layer 3
        self.conv3 = nn.Conv2d(32, 64, 3, stride=1, padding=1)

        # max pooling layer
        self.maxpool = nn.MaxPool2d(2, 2)
        
        self.fc1 = nn.Linear(8320, 512)    #64 * 14 * 14 = 12544
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, action_count)
        
        self.dropout = nn.Dropout(0.25)

    def forward(self, img_array):
        """
        Forward propagation of the neural network
        :param img_array: The input img array to the neural network
        :return
        """
        ## Define forward behavior
        
        #print("forward received img_array of shape: {}".format(img_array.shape))
        
        #convolutional layers
        x = self.maxpool(F.relu(self.conv1(img_array)))
        x = self.maxpool(F.relu(self.conv2(x)))
        x = self.maxpool(F.relu(self.conv3(x)))  
        
        #flatten
        x = x.view(-1, 8320)  
        #print("x.view shape: {}".format(x.shape))  #torch.Size([1, 8320])
        
        #fc layers
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    


In [361]:
#play a game. feed each frame into the model and see what we get
def play_game(env_name, model, max_frames=5000):
    model.eval()
    atari_env = AtariEnv(environment_name, reward_frame_shift)
    current_action = 0
    done = False
    frame_counter = 0
    
    while not done:
        atari_frame = atari_env.step(current_action)
        img_array = atari_frame.img_array
        img_array = img_array.reshape((3,160,210))
        img_array = img_array.reshape((1,3,160,210))
        img_tensor = torch.from_numpy(img_array).float().cuda()
        output = model(img_tensor)
        action_array = output.detach().cpu().numpy()[0]
        atari_frame.action_array = action_array
        current_action = np.argmax(action_array)
        #print("{} - {}".format(current_action, output.detach().cpu().numpy()[0]))
        done = atari_frame.done_bool
        frame_counter += 1
        if frame_counter > max_frames:
            break

    atari_env.close()
    return atari_env

# def get_batch(atari_env, discounted_rewards, batch_size):
#     rand_arr = np.arange(len(discounted_rewards))
#     np.random.shuffle(rand_arr)
    
#     frame_batch = np.zeros(batch_size, 3, 160, 210)
#     reward_batch = np.zeros(batch_size)
    
#     for i in range(batch_size):
        
        
    
    
    
    
    
    
    



def train(atari_env, model, optimizer, criterion):
    model.train()
    action_count = atari_env.env.action_space.n
    discounted_rewards = atari_env.get_discounted_rewards()
    frame_buffer = atari_env.frame_buffer
    action_tally = np.zeros(action_count)
    train_tally = np.zeros(action_count)
    
    rand_arr = np.arange(len(discounted_rewards))
    np.random.shuffle(rand_arr)
    
    
    print("discounted_rewards mean: {}".format(np.mean(discounted_rewards)))
    if len(discounted_rewards) > typical_bad_game_frame_count:
        sorted_rewards = np.sort(discounted_rewards)
        desired_median = sorted_rewards[typical_bad_game_frame_count//2]
        discounted_rewards_mean = np.mean(discounted_rewards)
        #reward_mean_shift = discounted_rewards_mean - desired_median
        reward_mean_shift = (discounted_rewards_mean - desired_median)/2.0
        print("Shifting rewards by {}".format(reward_mean_shift))
        discounted_rewards = discounted_rewards + reward_mean_shift
        print("new discounted_rewards mean: {}".format(np.mean(discounted_rewards)))
        
    
    total_loss = 0
    for ii, reward_ii in enumerate(discounted_rewards):
        #print("{}: {}".format(i, reward))
        optimizer.zero_grad()
        i = rand_arr[ii]
        
        #get frame from the frame buffer and run it through the model
        atari_frame = atari_env.frame_buffer[i]
        reward = atari_frame.discounted_reward
        img_array = atari_frame.img_array
        img_array = img_array.reshape((3,160,210))
        img_array = img_array.reshape((1,3,160,210))
        img_tensor = torch.from_numpy(img_array).float().cuda()
        output = model(img_tensor)
        #print("train output: {}".format(output))
        
        #if the reward was positive, keep the same.  if not, choose lowest option
        action_array_from_model_in_training = output.detach().cpu().numpy()[0]
        action_array = atari_frame.action_array
        train_action = np.argmax(action_array)
        action_tally[train_action] += 1
        
        if reward < 0:
            train_action = np.argmin(action_array)
            #train_action = np.argsort(action_array)[1] #second highest
            #train_action = random.randint(0,action_count-1)

        if reward > -0.2 and reward < 0.2:
            continue
            
        if np.argmax(train_tally) == train_action and np.sum(train_tally) != 0:
            #keep things even to not introduce bias that will get it stuck on one action
            continue

        train_tally[train_action] += 1
        
        target = torch.empty(1, dtype=torch.int64)
        target[0] = int(train_action)
        target = target.cuda()
        
        loss = criterion(output, target)
        total_loss += loss
        # backward pass: compute gradient of the loss with respect to model parameters
        loss.backward()
        # perform a single optimization step (parameter update)
        optimizer.step()
    
    print("avg loss: {:.3f}".format(total_loss / np.sum(train_tally)))
    print("model action_tally: {}".format(action_tally))
    print("train_tally:        {}".format(train_tally))
    
    
    

In [362]:
#new model
action_count = gym.make(environment_name).action_space.n
atari_model = AtariModel(action_count)
atari_model.cuda()

### loss function
atari_criterion = nn.CrossEntropyLoss()

### optimizer
atari_optimizer = optim.Adam(atari_model.parameters(), lr=0.00001)


In [None]:
for i in range(200):
    #play a game
    atari_env = play_game(environment_name, atari_model)

    #discounted_rewards = atari_env.get_discounted_rewards()
    #print()
    #print(discounted_rewards)
    print("\n{}) frames played: {}, score: {}".format(i, len(atari_env.frame_buffer), atari_env.get_total_score()))

    #train the model
    for ii in range(3):
        train(atari_env, atari_model, atari_optimizer, atari_criterion)



0) frames played: 993, score: -21.0
discounted_rewards mean: -1.6099910931120096e-17
avg loss: 0.862
model action_tally: [  0.   0. 993.   0.   0.   0.]
train_tally:        [  0.   0. 206.   0. 207.   0.]
discounted_rewards mean: -1.6099910931120096e-17
avg loss: 0.845
model action_tally: [  0.   0. 993.   0.   0.   0.]
train_tally:        [  0.   0. 203.   0. 203.   0.]
discounted_rewards mean: -1.6099910931120096e-17
avg loss: 0.776
model action_tally: [  0.   0. 993.   0.   0.   0.]
train_tally:        [  0.   0. 222.   0. 222.   0.]

1) frames played: 1033, score: -21.0
discounted_rewards mean: 1.1005502199575608e-16
avg loss: 0.906
model action_tally: [   0.    0.    0.    0. 1033.    0.]
train_tally:        [  0. 221.   0.   0. 221.   0.]
discounted_rewards mean: 1.1005502199575608e-16
avg loss: 0.783
model action_tally: [   0.    0.    0.    0. 1033.    0.]
train_tally:        [  0. 229.   0.   0. 230.   0.]
discounted_rewards mean: 1.1005502199575608e-16
avg loss: 0.754
model 

avg loss: 0.781
model action_tally: [723.   0.   0.   0. 293.   0.]
train_tally:        [212. 212.   0.   0.  66.   0.]
discounted_rewards mean: -1.3987061727561027e-17
avg loss: 0.728
model action_tally: [723.   0.   0.   0. 293.   0.]
train_tally:        [205. 205.   0.   0.  66.   0.]

15) frames played: 1029, score: -21.0
discounted_rewards mean: -8.631471522838924e-17
avg loss: 0.926
model action_tally: [701. 328.   0.   0.   0.   0.]
train_tally:        [212.  69.   0.   0.   0. 212.]
discounted_rewards mean: -8.631471522838924e-17
avg loss: 0.762
model action_tally: [701. 328.   0.   0.   0.   0.]
train_tally:        [198.  68.   0.   0.   0. 199.]
discounted_rewards mean: -8.631471522838924e-17
avg loss: 0.717
model action_tally: [701. 328.   0.   0.   0.   0.]
train_tally:        [208.  69.   0.   0.   0. 209.]

16) frames played: 1006, score: -21.0
discounted_rewards mean: -1.0594573594832507e-17
avg loss: 1.360
model action_tally: [484.   0.   0.   0.   0. 522.]
train_tally:

avg loss: 1.161
model action_tally: [362. 302.   0.   0. 241. 109.]
train_tally:        [134.  97.   0. 134.  78. 102.]

29) frames played: 1046, score: -21.0
discounted_rewards mean: 2.5473568442642216e-16
avg loss: 1.068
model action_tally: [550. 236.   0.  88.  33. 139.]
train_tally:        [197.  87. 197.  12.  19.  64.]
discounted_rewards mean: 2.5473568442642216e-16
avg loss: 0.866
model action_tally: [550. 236.   0.  88.  33. 139.]
train_tally:        [188.  87. 188.  12.  19.  64.]
discounted_rewards mean: 2.5473568442642216e-16
avg loss: 0.805
model action_tally: [550. 236.   0.  88.  33. 139.]
train_tally:        [185.  87. 185.  12.  19.  64.]

30) frames played: 1017, score: -21.0
discounted_rewards mean: -1.0130648641614014e-16
avg loss: 1.549
model action_tally: [357. 262. 273.   0.   0. 125.]
train_tally:        [103. 103.  67.   8.  73.  89.]
discounted_rewards mean: -1.0130648641614014e-16
avg loss: 1.338
model action_tally: [357. 262. 273.   0.   0. 125.]
train_tally:

avg loss: 0.909
model action_tally: [385. 151. 359.   0. 122.   0.]
train_tally:        [144.  81. 141. 144.  59. 144.]
discounted_rewards mean: -2.72479515188239e-16
avg loss: 0.769
model action_tally: [385. 151. 359.   0. 122.   0.]
train_tally:        [138.  82. 139. 139.  59. 140.]

42) frames played: 1021, score: -21.0
discounted_rewards mean: 1.0612905700628332e-16
avg loss: 1.300
model action_tally: [264. 151. 285. 123.  39. 159.]
train_tally:        [154. 155. 130.   4. 154.  45.]
discounted_rewards mean: 1.0612905700628332e-16
avg loss: 1.027
model action_tally: [264. 151. 285. 123.  39. 159.]
train_tally:        [150. 150. 136.   4. 151.  46.]
discounted_rewards mean: 1.0612905700628332e-16
avg loss: 0.929
model action_tally: [264. 151. 285. 123.  39. 159.]
train_tally:        [154. 155. 134.   4. 152.  46.]

43) frames played: 1000, score: -21.0
discounted_rewards mean: 1.9184653865522705e-16
avg loss: 1.146
model action_tally: [277. 227. 169.   0. 285.  42.]
train_tally:   

avg loss: 0.550
model action_tally: [253. 551.   0.  27.   0. 187.]
train_tally:        [169. 171. 172.   1.  44.  75.]

54) frames played: 1020, score: -21.0
discounted_rewards mean: 1.2713142085903753e-16
avg loss: 1.027
model action_tally: [294. 262. 328.   0.  73.  63.]
train_tally:        [162. 154.  44. 162.  14. 163.]
discounted_rewards mean: 1.2713142085903753e-16
avg loss: 0.804
model action_tally: [294. 262. 328.   0.  73.  63.]
train_tally:        [160. 159.  44. 156.  14. 161.]
discounted_rewards mean: 1.2713142085903753e-16
avg loss: 0.748
model action_tally: [294. 262. 328.   0.  73.  63.]
train_tally:        [157. 157.  44. 158.  14. 157.]

55) frames played: 1028, score: -21.0
discounted_rewards mean: 2.3500440676890473e-16
avg loss: 0.952
model action_tally: [288. 269.   0. 197.   0. 274.]
train_tally:        [180. 181.   0.   7.  42.  86.]
discounted_rewards mean: 2.3500440676890473e-16
avg loss: 0.745
model action_tally: [288. 269.   0. 197.   0. 274.]
train_tally:  

avg loss: 1.039
model action_tally: [548.  37.   0.   9.   9. 505.]
train_tally:        [142. 142.  43.  12.   4.  84.]

67) frames played: 1035, score: -21.0
discounted_rewards mean: -3.0721533744216894e-16
avg loss: 0.649
model action_tally: [260. 494.   0.   0.   0. 281.]
train_tally:        [176. 172.   0.   0. 177. 106.]
discounted_rewards mean: -3.0721533744216894e-16
avg loss: 0.491
model action_tally: [260. 494.   0.   0.   0. 281.]
train_tally:        [170. 171.   0.   0. 172. 106.]
discounted_rewards mean: -3.0721533744216894e-16
avg loss: 0.423
model action_tally: [260. 494.   0.   0.   0. 281.]
train_tally:        [166. 166.   0.   0. 166. 106.]

68) frames played: 1017, score: -21.0
discounted_rewards mean: -1.7117302877209885e-16
avg loss: 0.704
model action_tally: [281. 227.   0.   0. 406. 103.]
train_tally:        [141. 141. 141. 103.  25.  99.]
discounted_rewards mean: -1.7117302877209885e-16
avg loss: 0.511
model action_tally: [281. 227.   0.   0. 406. 103.]
train_tal


79) frames played: 1012, score: -21.0
discounted_rewards mean: 1.4042346556523718e-17
avg loss: 1.011
model action_tally: [220. 232.  92. 163.   2. 303.]
train_tally:        [141. 142.  79.  53.  53.  90.]
discounted_rewards mean: 1.4042346556523718e-17
avg loss: 0.745
model action_tally: [220. 232.  92. 163.   2. 303.]
train_tally:        [135. 135.  79.  53.  53.  90.]
discounted_rewards mean: 1.4042346556523718e-17
avg loss: 0.678
model action_tally: [220. 232.  92. 163.   2. 303.]
train_tally:        [137. 137.  78.  53.  53.  90.]

80) frames played: 998, score: -21.0
discounted_rewards mean: 4.9837666836880775e-17
avg loss: 0.742
model action_tally: [368. 213. 142.  77. 112.  86.]
train_tally:        [134. 132. 110.   3. 117. 135.]
discounted_rewards mean: 4.9837666836880775e-17
avg loss: 0.500
model action_tally: [368. 213. 142.  77. 112.  86.]
train_tally:        [135. 137. 110.   3. 117. 137.]
discounted_rewards mean: 4.9837666836880775e-17
avg loss: 0.476
model action_tally:

avg loss: 0.745
model action_tally: [223. 382. 182. 125.  28.  69.]
train_tally:        [131. 130.  85.  44. 125. 132.]
discounted_rewards mean: -5.633639133875918e-17
avg loss: 0.465
model action_tally: [223. 382. 182. 125.  28.  69.]
train_tally:        [134. 133.  85.  44. 122. 135.]
discounted_rewards mean: -5.633639133875918e-17
avg loss: 0.422
model action_tally: [223. 382. 182. 125.  28.  69.]
train_tally:        [134. 134.  85.  44. 122. 135.]

92) frames played: 1175, score: -20.0
discounted_rewards mean: 9.675475550775832e-17
Shifting rewards by 0.04276449026526491
new discounted_rewards mean: 0.04276449026526499
avg loss: 1.307
model action_tally: [150.  57. 190.  23.  31. 724.]
train_tally:        [187. 188.  61.  31.  20. 188.]
discounted_rewards mean: 9.675475550775832e-17
Shifting rewards by 0.04276449026526491
new discounted_rewards mean: 0.04276449026526499
avg loss: 1.058
model action_tally: [150.  57. 190.  23.  31. 724.]
train_tally:        [190. 187.  61.  31.  20.


104) frames played: 1019, score: -21.0
discounted_rewards mean: -2.091882440903141e-17
avg loss: 0.830
model action_tally: [612. 164.  26.  25. 121.  71.]
train_tally:        [152. 152. 153.   5. 154.  80.]
discounted_rewards mean: -2.091882440903141e-17
avg loss: 0.643
model action_tally: [612. 164.  26.  25. 121.  71.]
train_tally:        [151. 151. 151.   5. 149.  80.]
discounted_rewards mean: -2.091882440903141e-17
avg loss: 0.598
model action_tally: [612. 164.  26.  25. 121.  71.]
train_tally:        [151. 152. 153.   5. 151.  80.]

105) frames played: 1027, score: -21.0
discounted_rewards mean: 1.2453524093166314e-16
avg loss: 0.664
model action_tally: [153. 158. 297.   0. 350.  69.]
train_tally:        [129. 128.  57. 129. 100. 127.]
discounted_rewards mean: 1.2453524093166314e-16
avg loss: 0.401
model action_tally: [153. 158. 297.   0. 350.  69.]
train_tally:        [127. 122.  57. 129. 101. 130.]
discounted_rewards mean: 1.2453524093166314e-16
avg loss: 0.342
model action_tal

avg loss: 0.861
model action_tally: [196.  76. 145.   1. 564. 176.]
train_tally:        [ 79. 151.  83.  74. 151.  75.]
discounted_rewards mean: 4.908758105423835e-17
Shifting rewards by 0.027250641415619475
new discounted_rewards mean: 0.02725064141561955
avg loss: 0.783
model action_tally: [196.  76. 145.   1. 564. 176.]
train_tally:        [ 79. 154.  83.  74. 154.  75.]

117) frames played: 1161, score: -20.0
discounted_rewards mean: -1.9584295903809824e-16
Shifting rewards by 0.030004636237914117
new discounted_rewards mean: 0.030004636237913957
avg loss: 1.071
model action_tally: [122. 414. 131.  72. 221. 201.]
train_tally:        [168. 170.  29.   6. 170.  95.]
discounted_rewards mean: -1.9584295903809824e-16
Shifting rewards by 0.030004636237914117
new discounted_rewards mean: 0.030004636237913957
avg loss: 0.876
model action_tally: [122. 414. 131.  72. 221. 201.]
train_tally:        [177. 172.  29.   6. 177.  95.]
discounted_rewards mean: -1.9584295903809824e-16
Shifting rewar

avg loss: 1.043
model action_tally: [517. 143. 132.   2. 382.  70.]
train_tally:        [140.  73.  92. 164. 164.  39.]
discounted_rewards mean: -4.5620721397117187e-17
Shifting rewards by 0.09414801767794362
new discounted_rewards mean: 0.0941480176779436
avg loss: 0.945
model action_tally: [517. 143. 132.   2. 382.  70.]
train_tally:        [142.  73.  92. 178. 179.  39.]

130) frames played: 1521, score: -20.0
discounted_rewards mean: 1.4014649620514797e-16
Shifting rewards by 0.1631584226899891
new discounted_rewards mean: 0.16315842268998926
avg loss: 1.519
model action_tally: [289. 247. 281. 283. 262. 159.]
train_tally:        [134. 171. 136. 171. 170. 172.]
discounted_rewards mean: 1.4014649620514797e-16
Shifting rewards by 0.1631584226899891
new discounted_rewards mean: 0.16315842268998926
avg loss: 1.306
model action_tally: [289. 247. 281. 283. 262. 159.]
train_tally:        [134. 177. 135. 175. 167. 178.]
discounted_rewards mean: 1.4014649620514797e-16
Shifting rewards by 0.1


142) frames played: 1015, score: -21.0
discounted_rewards mean: 2.24450999658209e-16
avg loss: 0.496
model action_tally: [141. 223. 329.  64.  90. 168.]
train_tally:        [115. 115.  96.  42.  86.  91.]
discounted_rewards mean: 2.24450999658209e-16
avg loss: 0.305
model action_tally: [141. 223. 329.  64.  90. 168.]
train_tally:        [116. 116.  94.  42.  86.  91.]
discounted_rewards mean: 2.24450999658209e-16
avg loss: 0.272
model action_tally: [141. 223. 329.  64.  90. 168.]
train_tally:        [113. 113.  92.  42.  86.  91.]

143) frames played: 1012, score: -21.0
discounted_rewards mean: -9.303054593696964e-17
avg loss: 0.637
model action_tally: [441. 140.  98.  59. 183.  91.]
train_tally:        [120. 139. 139.  16.  60. 133.]
discounted_rewards mean: -9.303054593696964e-17
avg loss: 0.313
model action_tally: [441. 140.  98.  59. 183.  91.]
train_tally:        [122. 136. 137.  16.  60. 132.]
discounted_rewards mean: -9.303054593696964e-17
avg loss: 0.303
model action_tally: [4

In [276]:
from gym_utils import AtariEnv
from gym_utils import AtariFrame

#display frame
frame_num=600
discounted_rewards = atari_env.get_discounted_rewards()
discounted_rewards_mean_shifted = atari_env.get_discounted_rewards()

print("discounted_rewards mean: {}".format(np.mean(discounted_rewards)))
if len(discounted_rewards_mean_shifted) > typical_bad_game_frame_count:
    sorted_rewards = np.sort(discounted_rewards_mean_shifted)
    desired_median = sorted_rewards[typical_bad_game_frame_count//2]
    discounted_rewards_mean = np.mean(discounted_rewards_mean_shifted)
    reward_shift = (discounted_rewards_mean - desired_median)/2.0
    print("Shifting rewards by {}".format(reward_shift))
    discounted_rewards_mean_shifted = discounted_rewards_mean_shifted + reward_shift
    print("new discounted_rewards mean: {}".format(np.mean(discounted_rewards_mean_shifted)))

discounted_rewards mean: 5.4898445367778163e-17
Shifting rewards by 0.48340174895687293
new discounted_rewards mean: 0.48340174895687305


In [358]:
frame_num += 4
atari_frame = atari_env.frame_buffer[frame_num]

print("frame: {}, original reward: {:.3f}, shifted reward: {:.3f}".format(
    frame_num, discounted_rewards[frame_num], discounted_rewards_shifted[frame_num]))
atari_frame.show_frame()

IndexError: deque index out of range