<a href="https://colab.research.google.com/github/nicole-sb/atari-HEAD/blob/main/pacman_attention_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import glob
import torch
import numpy as np
from PIL import Image
import pandas as pd
from tqdm import tqdm
import pickle

In [3]:
parent_path = "/content/drive/MyDrive/ErdosBootcampProject" 

In [4]:
train_path = "{}/raw_data/highscore".format(parent_path)

In [5]:
tars = glob.glob("{}/*.tar.bz2".format(train_path)) #get directories of tar files
tars

['/content/drive/MyDrive/ErdosBootcampProject/raw_data/highscore/118_RZ_4303947_Sep-01-17-15-39.tar.bz2',
 '/content/drive/MyDrive/ErdosBootcampProject/raw_data/highscore/593_RZ_5037271_Aug-05-15-35-12.tar.bz2']

In [6]:
!tar xjf {tars[0]} #untar first trial '118_RZ_4303947

In [7]:
meta_data = pd.read_csv("{}/raw_data/combined.csv".format(parent_path))
prefix = "RZ_4303947_"
meta_data = meta_data[meta_data['frame_id'].str.contains(prefix)]

In [8]:
#Filter rows that have gaze positions outside of bounds (x > 161 and y > 211)
meta_data = meta_data[(meta_data['gaze_position_x'] <= 161)]
meta_data = meta_data[(meta_data['gaze_position_y'] <= 211)]
data_len = int(meta_data.tail(1)['frame_id'].values.tolist()[0].split('_')[-1])


In [9]:
data_len = int(meta_data.tail(1)['frame_id'].values.tolist()[0].split('_')[-1])
tar_name = tars[0].split('/')[-1].split('.tar')[0]

def load_img(index):
  return Image.open("{}/{}{}.png".format(tar_name, prefix, index+1)).convert('RGB')

#Read entire dataset into memory
images = [np.array(load_img(i)) for i in range(data_len)]

In [10]:
with open('{}/processed_data/images.pkl'.format(parent_path), 'wb') as f:
  pickle.dump(images, f)

In [11]:
with open('{}/processed_data/images.pkl'.format(parent_path), 'rb') as f:
  images = pickle.load(f)

# Creating training data and binning of gaze coordinates with only one pass through dataset

In [None]:
training_data = []
prev_frame = meta_data.iloc[0]
agg_list = []
threshold = 10
for curr_index, _ in enumerate(tqdm(range(len(meta_data)))) :
  curr_frame = meta_data.iloc[curr_index]
  curr_gaze_coords = (curr_frame['gaze_position_x'], curr_frame['gaze_position_y'] )
  #First check if new frame, if so then must bin the rest that we left on
  if curr_frame['frame_id'] != prev_frame['frame_id']:
    if len(agg_list) > 0:
      #Average gaze coords for aggregated list of rows
      avg_gaze_coords = tuple(sum(y) / len(y) for y in zip(*agg_list))
      #Make the training example and add it to training data
      example = (int(prev_frame['frame_id'].split('_')[-1])-1, avg_gaze_coords[0], avg_gaze_coords[1], int(prev_frame['action_int']) )
      training_data.append(example)
      #reset and add the gaze coords for the new frame
      agg_list = []
    agg_list.append(curr_gaze_coords )
    prev_frame = curr_frame
    continue

  if len(agg_list)+1 == threshold:
    #First add the current frame's gaze coords before averaging
    agg_list.append(curr_gaze_coords )
    #Average the current bin's gaze values
    avg_gaze_coords = tuple(sum(y) / len(y) for y in zip(*agg_list))
    #Make training example and add to training, then reset
    example = (int(curr_frame['frame_id'].split('_')[-1])-1, avg_gaze_coords[0], avg_gaze_coords[1], int(curr_frame['action_int']) )
    training_data.append(example)
    agg_list = []
    prev_frame = curr_frame
    continue
  
  #Otherwise we're in the same frame and don't need to bin yet, so we just add to agg_list and update prev
  agg_list.append(curr_gaze_coords)
  prev_frame = curr_frame

if len(agg_list) > 0:
  #Still have left over in agg_list after going through entire meta_data
  #Average gaze coords for aggregated list of rows
  avg_gaze_coords = tuple(sum(y) / len(y) for y in zip(*agg_list))
  #Make the training example and add it to training data
  example = (data_len - 1, avg_gaze_coords[0], avg_gaze_coords[1], int(meta_data.tail(1)['action_int']) )
  training_data.append(example)

 18%|█▊        | 156843/879295 [00:33<01:56, 6179.43it/s]

In [None]:
with open('{}/processed_data/X_train.pkl'.format(parent_path), 'wb') as f:
  pickle.dump(training_data, f)

In [None]:
"""
def get_gaze_list(data, frame_index, threshold=10):
  sub_df = data.query("`frame_id` == '{}{}'".format(prefix, frame_index+1))
  sub_df = sub_df.groupby(np.arange(len(sub_df))//threshold).mean()
  gaze_tups = list(zip(sub_df.gaze_position_x, sub_df.gaze_position_y))
  rep_img = [frame_index for _ in range(len(gaze_tups))]
  return zip(rep_img, gaze_tups)

#get_gaze_list(meta_data, 0)
gaze_dict = { str(frame_index) : get_gaze_list(frame_index) for frame_index in range(data_len)}
"""

In [None]:
class ErdosDataset(torch.utils.data.Dataset):
    def __init__(self, image_path, train_path):
        # load pickle serializations for data grabbing
        with open(image_path, 'rb') as f:
          self.images = pickle.load(f)

        with open(train_path, 'rb') as f:
          self.X_train = pickle.load(f)

    def __len__(self):
        # we will return the number of bins
        return len(self.X_train)

    def __getitem__(self,index):
      """
      Returns the training example for the given index.
      Args:
      - index (int): index of the example to grab
      Returns:
      - img (numpy arr): Frame image with shape (210, 160, 3)
      - gaze_x (float): average x-coordinate for the given bin
      - gaze_y (float): average y-coordinate for the given bin
      - y (int): Integer value of the true class (action)
      """
      img_idx, gaze_x, gaze_y, y = self.X_train[index] #Grab the data

      #Grab the image associated with img_idx
      img = self.images[img_idx]

      return img, gaze_x, gaze_y, y

In [None]:
h,w = 210,160
hidden_size = 256
batch_size = 64
to_print = False

In [None]:
train_data = ErdosDataset('{}/processed_data/images.pkl'.format(parent_path), '{}/processed_data/X_train.pkl'.format(parent_path))
train_data_loader = torch.utils.data.DataLoader(train_data,batch_size=batch_size, drop_last=True, shuffle=True, num_workers=8, pin_memory=True)

Create Learn Weighted Mask 

In [None]:
class Mask(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(in_features=1000, out_features= 64 ),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=64, out_features=h*w*1)
        )
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, z, gaze_bias):
      """
      Given a random Tensor of shape (batch_size, 1000), this makes a learned weighted mask
      Args:
        - z (Tensor): Our input random vector with shape (batch_size, 1000)
        - gaze_bias (Tensor): Our one-hot-encoded tensor with shape (batch_size, h, w)
      Returns:
        - out (Tensor): Our learned weighted mask, which will be applied to our input image later
      """
      #Start we have z with shape (batch_size, 1000)
      #Apply our fully connected layer
      out = self.MLP(z) #This should now have a shape of (batch_size, h*w*1)
      out = out.view((out.shape[0], 1, h, w)) #Unflatten, so this should now have a shape of (batch_size, h, w, 1)
      
      #Reshape gaze_biase from (batch_size, h, w) to (batch_size, 1, h, w) to match out shape
      gaze_bias = gaze_bias.unsqueeze(1)

      #Apply gaze_bias to learned mask
      if to_print:
        print("[Mask] Out shape: ", out.shape)
        print("[Mask] Gaze Bias shape: ", gaze_bias.shape)
      out = out + gaze_bias
      #Apply sigmoid now to make sure values go between 0 and 1
      return self.sigmoid(out) #This is the learned weight mask with gaze information

In [None]:
class CNN(torch.nn.Module):
  def __init__(self):
        super().__init__()
        self.learned_mask = Mask()
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=100, kernel_size=3, stride=1, padding=1)
        self.relu1 = torch.nn.ReLU()
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.flatten = torch.nn.Flatten()
        self.MLP = torch.nn.Sequential(
            torch.nn.Linear(in_features=100*h*w, out_features=16),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=16, out_features=10)
        )
        #self.final_activation = torch.nn.Softmax()

  def forward(self, x, z, gaze_bias):
    """
    This applies our input image with the mask and then runs it through the rest of the CNN
    Args:
      - x (Tensor): Our input image with shape (batch_size, 3, h, w)
      - z (Tensor): Random tensor used to make mask with shape (batch_size, 1000)
      - gaze_bias (Tensor): Our one-hot-encoded tensor with shape (batch_size, h, w)
    Returns:
      - prediction: Which action?
    """

    learned_weight_mask = self.learned_mask(z, gaze_bias)
    if to_print:
      print("[CNN] Mask shape: ", learned_weight_mask.shape)

    #Apply the mask to the image to get initial input to CNN parts
    out = torch.mul(x, learned_weight_mask) #Shape (batch_size, 3,h,w)
    if to_print:
      print("[CNN] Element mult shape: ", out.shape)

    #Apply the convs etc.
    out = self.conv1(out) #Shape (batch_size, 100, h, w)
    if to_print:
      print("[CNN] Conv1 shape: ", out.shape)
    out = self.relu1(out)
    out = self.maxpool1(out) #Shape (batch_size, 100, h, w)

    #Flatten the image for the fully connected layers
    out = self.flatten(out)
    if to_print:
      print("[CNN] Flatten shape: ", out.shape)

    #Apply the Fully connected layers
    out = self.MLP(out)
    if to_print:
      print("[CNN] MLP shape: ", out.shape)
    #out = self.final_activation(out) #This will get you probability vector with probs for each class

    #Then output the scores
    return out

In [None]:
num_epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_model = CNN().to(device)
optimizer = torch.optim.SGD(cnn_model.parameters(), lr=1e-2, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()
data_iterator = iter(train_data_loader)

#This is training
for epoch in range(num_epochs):
  correct = 0
  for step in tqdm(range(len(data_iterator))):
    #Grab the train data
    imgs, gazes_x, gazes_y, labels = data_iterator.next()
    imgs = imgs.permute((0, 3, 1, 2)).to(device)
    gazes_x = gazes_x.long().to(device)
    gazes_y = gazes_y.long().to(device)
    labels = labels.long().to(device)

    if to_print:
      print("Image shape: ", imgs.shape)
      print("Gazes_X shape: ", gazes_x.shape)
      print("Labels shape: ", labels.shape)

    # #Make one-hot-encoded gaze_bias
    with torch.no_grad():
      gaze_bias = torch.zeros((batch_size, h, w), requires_grad=False).to(device)
      for b in range(batch_size):
        gaze_bias[b, gazes_y[b]-1, gazes_x[b]-1] = 1
    if to_print:
      print("Gaze bias shape: ", gaze_bias.shape)

    # #Make random noise that will be used to make mask
    z = torch.rand((batch_size, 1000), requires_grad=True).to(device)

    # #Pass image into cnn_model
    logits = cnn_model(imgs, z, gaze_bias)
    if to_print:
      print("[Train] logits shape: ", logits.shape)
    # #Calculate loss 
    loss = loss_fn(logits, labels)

    # #Zero gradients before calculating gradients of loss with respect to the image
    optimizer.zero_grad()

    # #Calcualte gradients using backprob
    loss.backward()

    # #Step optimizer to update weights using the new grads
    optimizer.step()
    correct += ((torch.argmax(logits) == labels).float().sum()*(1/batch_size))

    if step % 1000 == 0:
            loss  = loss.item()
            print(f"Loss: {loss:>7f}")
  acc = 100 * correct / len(data_iterator)
  print(f"Acc: {acc:>7f}")

  cpuset_checked))
  0%|          | 1/1411 [00:33<13:07:32, 33.51s/it]

Loss: 8.254302


In [None]:
def test(test_data, cnn_model, loss_fn):
    size = len(test_data.dataset)
    num_batches = len(test_data)
    cnn_model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for image, gaze_coords, y in test_data:
            image, gaze_coords, y = image.to(device), gaze_coords.to(device), y.to(device)
            y_hat = cnn_model(image, gaze_coords, y)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
cnn_model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')