<a href="https://colab.research.google.com/github/dannyjwan/ErdosBootcampProject/blob/main/pacman_attention_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Spatial size used for the gaze bias / learned mask tensors, which are built as (1, h, w).
# NOTE(review): Atari frames are commonly 210 rows x 160 columns, so h and w may be
# swapped here -- confirm against the actual image tensors before training.
h = 160
w = 210

# Not referenced by the cells visible in this notebook excerpt.
hidden_size = 256
batch_size = 16

In [None]:
class Mask(torch.nn.Module):
  """Learned soft mask generated from a noise vector plus a gaze bias.

  A two-layer MLP maps the noise vector to one logit per pixel. The gaze bias
  is added to those logits and a sigmoid squashes the sum into [0, 1] so the
  result can be multiplied onto an image.
  """

  def __init__(self, height=None, width=None, noise_dim=1000, hidden_dim=64):
    """
    Args:
      - height (int, optional): Mask height. Defaults to the notebook-level `h`.
      - width (int, optional): Mask width. Defaults to the notebook-level `w`.
      - noise_dim (int): Length of the random input vector.
      - hidden_dim (int): Width of the hidden layer of the MLP.
    """
    super().__init__()
    # Fall back to the notebook-wide frame size when no explicit size is given,
    # so `Mask()` keeps behaving as before.
    self.height = h if height is None else height
    self.width = w if width is None else width

    self.MLP = torch.nn.Sequential(
        torch.nn.Linear(in_features=noise_dim, out_features=hidden_dim),
        torch.nn.ReLU(),
        # in_features must match the previous layer's out_features
        # (the original had 65 here against 64 above, which cannot run).
        torch.nn.Linear(in_features=hidden_dim, out_features=self.height * self.width),
    )
    self.sigmoid = torch.nn.Sigmoid()

  def forward(self, random_vector, gaze_bias):
    """
    Turn a random vector into a learned weighted mask.

    Args:
      - random_vector (Tensor): Noise input with shape (noise_dim,) or
        (batch, noise_dim).
      - gaze_bias (Tensor): One-hot-encoded gaze tensor with shape
        (1, height, width); it is broadcast over any batch dimension.
    Returns:
      - out (Tensor): Mask with values in [0, 1] and shape (1, height, width),
        or (batch, 1, height, width) for batched noise.
    """
    # One logit per pixel: shape (..., height*width)
    out = self.MLP(random_vector)
    # Unflatten into a single-channel image, keeping any leading batch dims
    out = out.reshape(*out.shape[:-1], 1, self.height, self.width)
    # Push the logits up at the gaze location before squashing
    out = out + gaze_bias
    return self.sigmoid(out)

class CNN(torch.nn.Module):
  """Masked-image action classifier.

  The input image is multiplied by a learned mask (see `self.learned_mask`),
  passed through one conv / ReLU / max-pool stage, flattened, and classified
  into 9 classes by a small MLP.
  """

  def __init__(self):
    super().__init__()

    self.learned_mask = Mask()

    self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=100, kernel_size=3, stride=1)
    self.relu = torch.nn.ReLU()
    self.maxpool1 = torch.nn.MaxPool2d(kernel_size=2)

    # The unpadded 3x3 conv trims 2 pixels from each spatial dim and the 2x2
    # max-pool then floor-halves it, so the flattened feature size is NOT
    # 100*h*w.
    pooled_h = (h - 2) // 2
    pooled_w = (w - 2) // 2

    self.flatten = torch.nn.Flatten()
    self.MLP = torch.nn.Sequential(
        torch.nn.Linear(in_features=100 * pooled_h * pooled_w, out_features=16),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features=16, out_features=9),
    )
    # Only used by `predict`; an explicit dim avoids the implicit-dim warning.
    self.final_activation = torch.nn.Softmax(dim=-1)

  def forward(self, x, random_vector, mask):
    """
    Mask the input image and run it through the rest of the CNN.

    Args:
      - x (Tensor): Input image with shape (3, h, w) or (batch, 3, h, w)
      - random_vector (Tensor): Random vector used to make the mask, shape (1000,)
      - mask (Tensor): One-hot-encoded gaze bias with shape (1, h, w)
    Returns:
      - logits (Tensor): Raw class scores with shape (9,) for an unbatched
        image or (batch, 9) for a batch. These are returned un-normalised
        because torch.nn.CrossEntropyLoss applies log-softmax itself and
        argmax is not differentiable; use `predict` for a class index.
    """
    learned_weight_mask = self.learned_mask(random_vector, mask)

    # Apply the mask to the image to get the initial input to the CNN parts
    out = torch.mul(x, learned_weight_mask)

    # Flatten keeps dim 0 as the batch dim, so give single images one
    unbatched = out.dim() == 3
    if unbatched:
      out = out.unsqueeze(0)

    out = self.conv1(out)     # (batch, 100, h-2, w-2)
    out = self.relu(out)
    out = self.maxpool1(out)  # (batch, 100, (h-2)//2, (w-2)//2)

    # Apply the fully connected layers
    logits = self.MLP(self.flatten(out))

    return logits.squeeze(0) if unbatched else logits

  def predict(self, x, random_vector, mask):
    """
    Return the index of the most probable class for each input image.

    Takes the same arguments as `forward`; no gradients are tracked.
    """
    with torch.no_grad():
      probabilities = self.final_activation(self(x, random_vector, mask))
    return torch.argmax(probabilities, dim=-1)

In [None]:
num_epochs = 17814
training_data = None #TODO: load training data (images, gaze_coords)
cnn_model = CNN()
opt_cnn = torch.optim.SGD(cnn_model.parameters(), lr=1e-2, momentum=0.9)
# CrossEntropyLoss applies log-softmax internally, so the model output passed
# to it must be raw, differentiable class scores (not probabilities or an argmax).
loss_func = torch.nn.CrossEntropyLoss()

# Fail with a clear message instead of "'NoneType' object is not iterable"
if training_data is None:
  raise ValueError("training_data has not been loaded; expected an iterable of (image, gaze_coords, y)")

#This is training
cnn_model.train()
for epoch in range(num_epochs):
  for image, gaze_coords, y in training_data:
    #Every step

    #Grab the gaze coordinates to do one hot encoding.
    #Cast to int because tensor indices must be integers.
    #NOTE(review): assumes gaze coords are pixel positions inside the (h, w) frame -- confirm
    gaze_x, gaze_y = int(gaze_coords[0]), int(gaze_coords[1])

    #Make one hot encoded gaze_bias (row index = y, column index = x)
    gaze_bias = torch.zeros((1, h, w))
    gaze_bias[:, gaze_y, gaze_x] = 1

    #Make random vector that will be used to make mask.
    #It is plain input noise: only the model parameters are optimised,
    #so it does not need to track gradients.
    random_vector = torch.rand(1000)

    #Zero gradients left over from the previous step
    opt_cnn.zero_grad()

    #Pass image into cnn_model
    y_hat = cnn_model(image, random_vector, gaze_bias)

    #Calculate loss
    loss = loss_func(y_hat, y)

    #Calculate gradients with backprop
    loss.backward()

    #Step optimizer to update weights with new grads
    opt_cnn.step()


SyntaxError: ignored