<a href="https://colab.research.google.com/github/nicole-sb/erdos-project-2022--atari-HEAD/blob/main/pacman_attention_cnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import glob
import torch
import numpy as np
from PIL import Image
import pandas as pd
from tqdm import tqdm
import pickle
import matplotlib.pyplot as plt

In [None]:
parent_path = "/content/drive/MyDrive/erdos-project-2022--atari-HEAD" 

In [None]:
train_path = "{}/final_data/ms_pacman/highscore".format(parent_path)

In [None]:
tars = glob.glob("{}/*.tar.bz2".format(train_path)) #get directories of tar files
tars

['/content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/118_RZ_4303947_Sep-01-17-15-39.tar.bz2',
 '/content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/593_RZ_5037271_Aug-05-15-35-12.tar.bz2']

In [None]:
!tar xjf {tars[0]} #untar first trial '118_RZ
!tar xjf {tars[1]} #untar second trial '593_RZ

In [None]:
!unzip /content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/combined_trial_data.csv.zip -d /content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/

Archive:  /content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/combined_trial_data.csv.zip
replace /content/drive/MyDrive/erdos-project-2022--atari-HEAD/final_data/ms_pacman/highscore/combined_trial_data.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: n


In [None]:
trial_text_df = pd.read_csv("{}/combined_trial_data.csv".format(train_path))
trial_1 = trial_text_df[trial_text_df['trial_id'] == '118_RZ']
trial_2 = trial_text_df[trial_text_df['trial_id'] == '593_RZ']

In [None]:
# Number of frames in trial 1: the trailing integer of the last row's frame_id
data_len = int(trial_1['frame_id'].iloc[-1].split('_')[-1])

def load_img(index):
  """Open the trial-1 frame PNG for the zero-based `index` and return it as an RGB PIL image."""
  # Frame files on disk are numbered from 1, hence index+1
  frame_file = "./118_RZ_4303947_Sep-01-17-15-39/RZ_4303947_{}.png".format(index+1)
  return Image.open(frame_file).convert('RGB')

#Read every trial-1 frame into memory as a numpy array
images_1 = [np.array(load_img(frame_idx)) for frame_idx in range(data_len)]

In [None]:
# Number of frames in trial 2: the trailing integer of the last row's frame_id.
# NOTE: this rebinds the notebook-level `data_len` (and `load_img` below) set for trial 1.
data_len = int(trial_2['frame_id'].iloc[-1].split('_')[-1])

def load_img(index):
  """Open the trial-2 frame PNG for the zero-based `index` and return it as an RGB PIL image."""
  # Frame files on disk are numbered from 1, hence index+1
  frame_file = "./593_RZ_5037271_Aug-05-15-35-12/RZ_5037271_{}.png".format(index+1)
  return Image.open(frame_file).convert('RGB')

#Read every trial-2 frame into memory as a numpy array
images_2 = [np.array(load_img(frame_idx)) for frame_idx in range(data_len)]

In [None]:
images_1.extend(images_2)

In [None]:
images = images_1

In [None]:
with open('{}/processed_data/images.pkl'.format(train_path), 'wb') as f:
  pickle.dump(images, f)

# Creating training data and binning of gaze coordinates with only one pass through dataset

In [None]:
def bin(df_train):
  """Bin consecutive gaze samples of one trial into training examples in a single pass.

  Rows are consumed in order. Gaze samples are accumulated while the frame_id stays
  the same; a bin is closed (and averaged) when it reaches 10 samples, when the
  frame_id changes, or when the data ends.

  Args:
  - df_train (DataFrame): rows with columns 'frame_id' (string ending in '_<frame number>'),
    'gaze_position_x', 'gaze_position_y' and 'action_int'.
  Returns:
  - list of tuples (frame_index, avg_gaze_x, avg_gaze_y, action) where frame_index is the
    frame number minus one. Empty input yields an empty list.
  """
  threshold = 10  # maximum number of gaze samples averaged into one example

  # Pull the columns out once; per-row DataFrame.iloc lookups are far slower
  frame_ids = df_train['frame_id'].to_numpy()
  gaze_xs = df_train['gaze_position_x'].to_numpy()
  gaze_ys = df_train['gaze_position_y'].to_numpy()
  actions = df_train['action_int'].to_numpy()

  def make_example(row, coords):
    #Average the aggregated gaze coords and label them with the frame/action of `row`
    avg_x, avg_y = (sum(axis) / len(axis) for axis in zip(*coords))
    return (int(frame_ids[row].split('_')[-1]) - 1, avg_x, avg_y, int(actions[row]))

  training_data = []
  agg_list = []
  for row in tqdm(range(len(frame_ids))):
    curr_gaze_coords = (gaze_xs[row], gaze_ys[row])

    #New frame: flush what was collected for the previous frame, then start a fresh bin
    if row > 0 and frame_ids[row] != frame_ids[row - 1]:
      if agg_list:
        training_data.append(make_example(row - 1, agg_list))
      agg_list = [curr_gaze_coords]
      continue

    #Same frame: keep aggregating, and close the bin once it is full
    agg_list.append(curr_gaze_coords)
    if len(agg_list) == threshold:
      training_data.append(make_example(row, agg_list))
      agg_list = []

  if agg_list:
    #Left-over samples belong to the last row's frame. Read the frame index from that row
    #rather than from the notebook-level data_len, which may describe a different trial.
    training_data.append(make_example(len(frame_ids) - 1, agg_list))
  return training_data

In [None]:
binned_data_1 = bin(trial_1)

# `images` holds all trial-1 frames followed by all trial-2 frames, while bin() returns
# frame indices that start at 0 for each trial. Shift trial-2 indices past the trial-1
# frames so each trial-2 example is paired with its own image.
num_frames_1 = int(trial_1['frame_id'].iloc[-1].split('_')[-1])
binned_data_2 = [
    (img_idx + num_frames_1, gaze_x, gaze_y, action)
    for img_idx, gaze_x, gaze_y, action in bin(trial_2)
]

100%|██████████| 882316/882316 [02:17<00:00, 6396.11it/s]
100%|██████████| 881897/881897 [01:57<00:00, 7533.16it/s]


In [None]:
# split into train and test
# first 80% of frames for training, last 20% for testing

def train_test_split(binned_data, ratio = 0.8):
  """Split a sequence chronologically: the leading `ratio` share for training, the rest for testing.

  Args:
  - binned_data (sequence): examples in their original order
  - ratio (float): fraction of examples assigned to the training part
  Returns:
  - (train, test): two slices of `binned_data`
  """
  cut = int(round(len(binned_data) * ratio, 0))
  return binned_data[:cut], binned_data[cut:]

# Split each trial separately so both train and test contain examples from both trials
train_1, test_1 = train_test_split(binned_data_1)
train_2, test_2 = train_test_split(binned_data_2)

# Despite the X_ prefix these lists hold complete examples as produced by bin():
# (frame index, avg gaze x, avg gaze y, action)
X_train = train_1 + train_2
X_test = test_1 + test_2

# Persist both splits; ErdosDataset later reads these pickle files
with open('{}/processed_data/X_train.pkl'.format(train_path), 'wb') as f:
  pickle.dump(X_train, f)

with open('{}/processed_data/X_test.pkl'.format(train_path), 'wb') as f:
  pickle.dump(X_test, f)

# trial 1
# total_gaze_1 = len(binned_data)
# train_len_1  = int(round(total_gaze_1*0.8, 0))

# df_train_1 = trial_1[:train_len_1]
# df_test_1 = trial_1[train_len_1:]

# # trial 2
# total_gaze_2 = len(trial_2)
# train_len_2 = int(round(total_gaze_2*0.8, 0))

# df_train_2 = trial_2[:train_len_2]
# df_test_2 = trial_2[train_len_2:]

# # combine trial 1+2 train and save
# df_train = pd.concat([df_train_1, df_train_2])

# # combine trial 1+2 test and save
# df_test = pd.concat([df_test_1, df_test_2])

# df_train.to_csv("{}/train_data.csv".format(train_path), index=False)
# df_test.to_csv("{}/test_data.csv".format(train_path), index=False)

In [None]:
class ErdosDataset(torch.utils.data.Dataset):
    """Dataset pairing pickled (img_idx, gaze_x, gaze_y, y) records with in-memory frames."""

    def __init__(self, train_path, images):
        """
        Args:
        - train_path (str): path of a pickle file holding the list of records
        - images (sequence): frames, indexed by each record's img_idx
        """
        # NOTE: pickle.load can execute arbitrary code; only open files you created yourself
        with open(train_path, 'rb') as record_file:
          records = pickle.load(record_file)

        self.data = records
        self.images = images

    def __len__(self):
        """Number of records (bins) in the dataset."""
        return len(self.data)

    def __getitem__(self,index):
      """
      Returns the example stored at the given position.
      Args:
      - index (int): position of the record to fetch
      Returns:
      - img: the entry of `images` at the record's img_idx
      - gaze_x: x-coordinate stored in the record
      - gaze_y: y-coordinate stored in the record
      - y: label (action) stored in the record
      """
      frame_idx, gaze_x, gaze_y, label = self.data[index]
      return self.images[frame_idx], gaze_x, gaze_y, label

In [None]:
image_path = '{}/processed_data/images.pkl'.format(train_path)
with open(image_path, 'rb') as f:
  images = pickle.load(f)

In [None]:
h,w = 210,160
batch_size = 512
to_print = False

In [None]:
train_data = ErdosDataset('{}/processed_data/X_train.pkl'.format(train_path), images)
train_data_loader = torch.utils.data.DataLoader(train_data,batch_size=batch_size, drop_last=True, shuffle=True, num_workers=2, pin_memory=True)

In [None]:
test_data = ErdosDataset('{}/processed_data/X_test.pkl'.format(train_path), images)
test_data_loader = torch.utils.data.DataLoader(test_data,batch_size=batch_size, drop_last=True, shuffle=True, num_workers=2, pin_memory=True)

Create the Learned Weighted Mask

In [None]:
class Mask(torch.nn.Module):
    """Maps a 1000-dim vector to a per-pixel mask in (0, 1), biased by a gaze map.

    Reads the notebook-level `h` and `w` for the mask size and `to_print` for
    optional shape logging.
    """

    def __init__(self):
        super().__init__()
        # Two-layer MLP: 1000 -> 64 -> one value per pixel
        hidden_layer = torch.nn.Linear(in_features=1000, out_features=64)
        pixel_layer = torch.nn.Linear(in_features=64, out_features=h*w*1)
        self.MLP = torch.nn.Sequential(hidden_layer, torch.nn.ReLU(), pixel_layer)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, z, gaze_bias):
      """
      Builds the weighted mask from z and adds the gaze bias before squashing.
      Args:
        - z (Tensor): input vector with shape (batch_size, 1000)
        - gaze_bias (Tensor): gaze map with shape (batch_size, h, w)
      Returns:
        - mask (Tensor): values in (0, 1) with shape (batch_size, 1, h, w)
      """
      flat = self.MLP(z) #(batch_size, h*w)
      mask_logits = flat.view(flat.size(0), 1, h, w) #Unflatten to (batch_size, 1, h, w)

      #Add a channel axis so the bias broadcasts against the mask: (batch_size, 1, h, w)
      bias = gaze_bias.unsqueeze(1)

      if to_print:
        print("[Mask] Out shape: ", mask_logits.shape)
        print("[Mask] Gaze Bias shape: ", bias.shape)

      #Sigmoid keeps every mask value between 0 and 1
      return self.sigmoid(mask_logits + bias)

In [None]:
class CNN(torch.nn.Module):
  """Three strided conv blocks followed by a small MLP that scores 10 classes.

  A `Mask` sub-module is created as `learned_mask`, but forward() currently does not
  apply it: the image goes straight into the conv stack. Reads the notebook-level
  `to_print` flag for optional shape logging.
  """

  def __init__(self):
    super().__init__()
    self.learned_mask = Mask()

    def conv_block(c_in, c_out, kernel, stride):
      #Conv -> BatchNorm -> ReLU, always with padding=1
      return [
          torch.nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=stride, padding=1),
          torch.nn.BatchNorm2d(c_out),
          torch.nn.ReLU(),
      ]

    self.convs = torch.nn.Sequential(
        *conv_block(3, 16, 7, 4),
        *conv_block(16, 32, 3, 2),
        *conv_block(32, 32, 3, 2),
    )

    self.flatten = torch.nn.Flatten()
    #4160 = 32 channels * 13 * 10, the conv output size for a 210x160 input
    self.MLP = torch.nn.Sequential(
        torch.nn.Linear(in_features=4160, out_features=32),
        torch.nn.Dropout(p=0.5),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features=32, out_features=10)
    )

  def forward(self, x, z, gaze_bias):
    """
    Runs the image through the conv stack and the MLP.
    Args:
      - x (Tensor): input image batch with shape (batch_size, 3, h, w)
      - z (Tensor): accepted for the mask path; currently unused
      - gaze_bias (Tensor): accepted for the mask path; currently unused
    Returns:
      - scores (Tensor): unnormalised class scores with shape (batch_size, 10)
    """
    #Mask path is disabled. To re-enable it, feed the convs with
    #torch.mul(x, self.learned_mask(z, gaze_bias)) instead of x.
    features = self.convs(x)
    if to_print:
      print("[CNN] Conv1 shape: ", features.shape)

    #Flatten the feature maps for the fully connected layers
    flat = self.flatten(features)
    if to_print:
      print("[CNN] Flatten shape: ", flat.shape)

    scores = self.MLP(flat)
    if to_print:
      print("[CNN] MLP shape: ", scores.shape)

    #Raw scores; no softmax is applied here
    return scores

In [None]:
def train(model, data_loader, loss_fn, optimizer, batch_size = 64):
  """Run one training epoch over `data_loader`.

  Args:
    - model: module called as model(imgs, z, gaze_bias) that returns class logits
    - data_loader: iterable of (imgs, gazes_x, gazes_y, labels) batches; imgs are
      channel-last, i.e. (N, H, W, C)
    - loss_fn: callable taking (logits, labels) and returning a scalar loss tensor
    - optimizer: optimizer over the model's parameters
    - batch_size: kept for backward compatibility; the real size of every batch is
      read from the batch itself
  Returns:
    - accuracy (float): fraction of correctly classified examples in this epoch
    - mean_loss (float): loss averaged over the batches of this epoch
  """
  #Work on whatever device the model lives on instead of a notebook-level global
  device = next(model.parameters()).device
  use_amp = device.type == 'cuda'

  #An earlier evaluation pass may have left the model in eval mode; BatchNorm and
  #Dropout must be in training mode here.
  model.train()

  total_correct = 0.0
  total_loss = 0.0
  total_imgs = 0
  num_batches = 0
  for imgs, gazes_x, gazes_y, labels in tqdm(data_loader):
    #Move the batch to the device; images go from (N, H, W, C) to (N, C, H, W)
    imgs = imgs.permute((0, 3, 1, 2)).float().to(device)
    gazes_x = gazes_x.long().to(device)
    gazes_y = gazes_y.long().to(device)
    labels = labels.long().to(device)

    if to_print:
      print("Image shape: ", imgs.shape)
      print("Gazes_X shape: ", gazes_x.shape)
      print("Labels shape: ", labels.shape)

    #One-hot gaze map per sample, built in one vectorised assignment. Gaze values are
    #used as 1-based pixel positions; clamping keeps off-screen samples from wrapping
    #around to the opposite edge or raising an IndexError.
    n, _, height, width = imgs.shape
    rows = (gazes_y - 1).clamp(0, height - 1)
    cols = (gazes_x - 1).clamp(0, width - 1)
    gaze_bias = torch.zeros((n, height, width), device=device)
    gaze_bias[torch.arange(n, device=device), rows, cols] = 1
    if to_print:
      print("Gaze bias shape: ", gaze_bias.shape)

    #Forward pass and loss under mixed precision (only active on CUDA).
    #NOTE(review): no GradScaler is used, so fp16 gradients can underflow - consider adding one.
    with torch.cuda.amp.autocast(enabled=use_amp):
      logits = model(imgs, None, gaze_bias)
      loss = loss_fn(logits, labels)
    if to_print:
      print("[Train] logits shape: ", logits.shape)

    #Clear stale gradients, backpropagate, and update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_correct += (torch.argmax(logits, dim = 1) == labels).float().sum().item()
    total_loss += loss.item()
    total_imgs += n
    num_batches += 1

  if num_batches == 0:
    #Nothing to train on; avoid dividing by zero
    return 0.0, 0.0

  #loss_fn already averages within a batch, so average the per-batch losses over batches
  return total_correct / total_imgs, total_loss / num_batches

In [None]:
def test(test_data, cnn_model, loss_fn):
    #size = len(test_data.dataset)
    cnn_model.eval()
    test_loss, correct = 0, 0
    total_imgs = 0
    with torch.no_grad():
        for step, values in enumerate(tqdm(test_data)):
            image, gaze_x, gaze_y, y = values
            image, gaze_x, gaze_y, y = image.permute((0, 3, 1, 2)).float().to(device), gaze_x.long().to(device), gaze_y.long().to(device), y.long().to(device)
            y_hat = cnn_model(image, None, y)
            test_loss += loss_fn(y_hat, y).item()
            correct += (y_hat.argmax(1) == y).type(torch.float).sum().item()
            total_imgs+=batch_size
    test_loss /= batch_size
    correct /= total_imgs
    return correct, test_loss

In [None]:
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn_model = CNN().to(device)
optimizer = torch.optim.SGD(cnn_model.parameters(), lr=1e-3, momentum=0.9)
loss_fn = torch.nn.CrossEntropyLoss()

#Per-epoch loss histories, used for the learning-curve plot below
losses = []
test_losses = []
for epoch in range(num_epochs):
  #One pass over the training set, then one evaluation pass over the test set
  correct, total_loss = train(cnn_model, train_data_loader, loss_fn, optimizer, batch_size)
  test_correct, test_loss = test(test_data_loader, cnn_model, loss_fn)
  acc = 100.0 * correct
  print(f"Epoch: {epoch}, Acc: {acc:>7f}, Loss: {total_loss:>7f}")
  test_acc = test_correct*100.0
  print(f"Test Acc: {test_acc:>7f}, Test Loss: {test_loss:>7f}")
  losses.append(total_loss)
  test_losses.append(test_loss)

#Graph epoch and loss, labelled so the figure can be read on its own
fig, ax = plt.subplots()
ax.plot(np.array(losses), 'r', label='train loss')
ax.plot(np.array(test_losses), 'b', label='test loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Loss per epoch')
ax.legend();


100%|██████████| 284/284 [00:40<00:00,  7.04it/s]
100%|██████████| 71/71 [00:06<00:00, 11.61it/s]


Epoch: 0, Acc: 25.605192, Loss: 0.988807
Test Acc: 25.022007, Test Loss: 0.233251


100%|██████████| 284/284 [00:40<00:00,  7.01it/s]
100%|██████████| 71/71 [00:06<00:00, 11.74it/s]


Epoch: 1, Acc: 33.525665, Loss: 0.903018
Test Acc: 28.036972, Test Loss: 0.238552


100%|██████████| 284/284 [00:40<00:00,  7.02it/s]
100%|██████████| 71/71 [00:06<00:00, 11.58it/s]


Epoch: 2, Acc: 42.516918, Loss: 0.825277
Test Acc: 28.270797, Test Loss: 0.290207


100%|██████████| 284/284 [00:40<00:00,  6.98it/s]
100%|██████████| 71/71 [00:06<00:00, 11.64it/s]


Epoch: 3, Acc: 49.172672, Loss: 0.738977
Test Acc: 28.391835, Test Loss: 0.350520


100%|██████████| 284/284 [00:40<00:00,  7.04it/s]
100%|██████████| 71/71 [00:06<00:00, 11.78it/s]


Epoch: 4, Acc: 53.389084, Loss: 0.663694
Test Acc: 28.251540, Test Loss: 0.403950


100%|██████████| 284/284 [00:40<00:00,  7.01it/s]
100%|██████████| 71/71 [00:06<00:00, 11.79it/s]


Epoch: 5, Acc: 56.115894, Loss: 0.601484
Test Acc: 27.462038, Test Loss: 0.441829


100%|██████████| 284/284 [00:40<00:00,  7.10it/s]
100%|██████████| 71/71 [00:05<00:00, 11.84it/s]


Epoch: 6, Acc: 57.608932, Loss: 0.557949
Test Acc: 27.937940, Test Loss: 0.550282


100%|██████████| 284/284 [00:40<00:00,  7.07it/s]
100%|██████████| 71/71 [00:05<00:00, 11.89it/s]


Epoch: 7, Acc: 58.438324, Loss: 0.527481
Test Acc: 27.869168, Test Loss: 0.545614


100%|██████████| 284/284 [00:40<00:00,  7.02it/s]
100%|██████████| 71/71 [00:06<00:00, 11.77it/s]


Epoch: 8, Acc: 58.941734, Loss: 0.507401
Test Acc: 27.445533, Test Loss: 0.592916


100%|██████████| 284/284 [00:40<00:00,  7.01it/s]
100%|██████████| 71/71 [00:06<00:00, 11.61it/s]


Epoch: 9, Acc: 59.273903, Loss: 0.489130
Test Acc: 28.642165, Test Loss: 0.588025


100%|██████████| 284/284 [00:40<00:00,  7.06it/s]
100%|██████████| 71/71 [00:06<00:00, 11.80it/s]


Epoch: 10, Acc: 59.518730, Loss: 0.473902
Test Acc: 28.023217, Test Loss: 0.579083


100%|██████████| 284/284 [00:40<00:00,  7.05it/s]
100%|██████████| 71/71 [00:06<00:00, 11.68it/s]


Epoch: 11, Acc: 59.810329, Loss: 0.461628
Test Acc: 28.133253, Test Loss: 0.577285


100%|██████████| 284/284 [00:40<00:00,  7.02it/s]
100%|██████████| 71/71 [00:06<00:00, 11.80it/s]


Epoch: 12, Acc: 59.987759, Loss: 0.450051
Test Acc: 28.130502, Test Loss: 0.643221


100%|██████████| 284/284 [00:39<00:00,  7.12it/s]
100%|██████████| 71/71 [00:06<00:00, 11.79it/s]


Epoch: 13, Acc: 60.027649, Loss: 0.440068
Test Acc: 28.069982, Test Loss: 0.648651


100%|██████████| 284/284 [00:40<00:00,  7.03it/s]
100%|██████████| 71/71 [00:06<00:00, 11.46it/s]


Epoch: 14, Acc: 60.103298, Loss: 0.435389
Test Acc: 28.790713, Test Loss: 0.625796


100%|██████████| 284/284 [00:40<00:00,  7.03it/s]
100%|██████████| 71/71 [00:06<00:00, 11.67it/s]


Epoch: 15, Acc: 60.131489, Loss: 0.427171
Test Acc: 27.797645, Test Loss: 0.619957


100%|██████████| 284/284 [00:39<00:00,  7.11it/s]
100%|██████████| 71/71 [00:06<00:00, 11.79it/s]


Epoch: 16, Acc: 60.373569, Loss: 0.422246
Test Acc: 28.529379, Test Loss: 0.629300


100%|██████████| 284/284 [00:39<00:00,  7.12it/s]
100%|██████████| 71/71 [00:06<00:00, 11.78it/s]


Epoch: 17, Acc: 60.284851, Loss: 0.415480
Test Acc: 28.204776, Test Loss: 0.651899


100%|██████████| 284/284 [00:39<00:00,  7.13it/s]
100%|██████████| 71/71 [00:06<00:00, 11.64it/s]


Epoch: 18, Acc: 60.477413, Loss: 0.410020
Test Acc: 28.025968, Test Loss: 0.655806


100%|██████████| 284/284 [00:39<00:00,  7.11it/s]
100%|██████████| 71/71 [00:05<00:00, 11.84it/s]


Epoch: 19, Acc: 60.498734, Loss: 0.407393
Test Acc: 28.039723, Test Loss: 0.671041


100%|██████████| 284/284 [00:39<00:00,  7.12it/s]
100%|██████████| 71/71 [00:06<00:00, 11.67it/s]


Epoch: 20, Acc: 60.557194, Loss: 0.403155
Test Acc: 28.325814, Test Loss: 0.650775


100%|██████████| 284/284 [00:39<00:00,  7.13it/s]
100%|██████████| 71/71 [00:05<00:00, 11.96it/s]


Epoch: 21, Acc: 60.575069, Loss: 0.398621
Test Acc: 28.331316, Test Loss: 0.656347


100%|██████████| 284/284 [00:39<00:00,  7.11it/s]
100%|██████████| 71/71 [00:05<00:00, 11.85it/s]


Epoch: 22, Acc: 60.374947, Loss: 0.396281
Test Acc: 28.047975, Test Loss: 0.651595


100%|██████████| 284/284 [00:40<00:00,  7.09it/s]
100%|██████████| 71/71 [00:06<00:00, 11.75it/s]


Epoch: 23, Acc: 60.621834, Loss: 0.393106
Test Acc: 28.801717, Test Loss: 0.710221


100%|██████████| 284/284 [00:40<00:00,  7.04it/s]
100%|██████████| 71/71 [00:06<00:00, 11.50it/s]


Epoch: 24, Acc: 60.715366, Loss: 0.388995
Test Acc: 28.639415, Test Loss: 0.635867


100%|██████████| 284/284 [00:40<00:00,  7.08it/s]
100%|██████████| 71/71 [00:06<00:00, 11.63it/s]


Epoch: 25, Acc: 60.534496, Loss: 0.388622
Test Acc: 27.926937, Test Loss: 0.677933


100%|██████████| 284/284 [00:40<00:00,  7.07it/s]
100%|██████████| 71/71 [00:06<00:00, 11.77it/s]


Epoch: 26, Acc: 60.595703, Loss: 0.386060
Test Acc: 28.383583, Test Loss: 0.695243


100%|██████████| 284/284 [00:39<00:00,  7.12it/s]
100%|██████████| 71/71 [00:06<00:00, 11.65it/s]


Epoch: 27, Acc: 60.686481, Loss: 0.383455
Test Acc: 28.433099, Test Loss: 0.681532


100%|██████████| 284/284 [00:40<00:00,  7.04it/s]
100%|██████████| 71/71 [00:06<00:00, 11.55it/s]


Epoch: 28, Acc: 60.641777, Loss: 0.380843
Test Acc: 28.097491, Test Loss: 0.678971


100%|██████████| 284/284 [00:40<00:00,  7.03it/s]
100%|██████████| 71/71 [00:06<00:00, 11.29it/s]


Epoch: 29, Acc: 60.810268, Loss: 0.380617
Test Acc: 28.694432, Test Loss: 0.624050


100%|██████████| 284/284 [00:40<00:00,  6.99it/s]
100%|██████████| 71/71 [00:06<00:00, 11.64it/s]


Epoch: 30, Acc: 60.762131, Loss: 0.378190
Test Acc: 28.543134, Test Loss: 0.671434


100%|██████████| 284/284 [00:40<00:00,  7.05it/s]
100%|██████████| 71/71 [00:06<00:00, 11.73it/s]


Epoch: 31, Acc: 60.761440, Loss: 0.375931
Test Acc: 27.695863, Test Loss: 0.704699


100%|██████████| 284/284 [00:40<00:00,  7.09it/s]
100%|██████████| 71/71 [00:06<00:00, 11.62it/s]


Epoch: 32, Acc: 60.882481, Loss: 0.373469
Test Acc: 28.364327, Test Loss: 0.703534


100%|██████████| 284/284 [00:39<00:00,  7.16it/s]
100%|██████████| 71/71 [00:05<00:00, 11.85it/s]


Epoch: 33, Acc: 60.784824, Loss: 0.372567
Test Acc: 28.911752, Test Loss: 0.723709


100%|██████████| 284/284 [00:39<00:00,  7.14it/s]
100%|██████████| 71/71 [00:06<00:00, 11.83it/s]


Epoch: 34, Acc: 60.925121, Loss: 0.370689
Test Acc: 28.113996, Test Loss: 0.690098


100%|██████████| 284/284 [00:39<00:00,  7.11it/s]
100%|██████████| 71/71 [00:06<00:00, 11.78it/s]


Epoch: 35, Acc: 60.777260, Loss: 0.369679
Test Acc: 27.396017, Test Loss: 0.699701


100%|██████████| 284/284 [00:39<00:00,  7.13it/s]
100%|██████████| 71/71 [00:06<00:00, 11.80it/s]


Epoch: 36, Acc: 60.915493, Loss: 0.368994
Test Acc: 28.292804, Test Loss: 0.657808


100%|██████████| 284/284 [00:40<00:00,  6.98it/s]
100%|██████████| 71/71 [00:06<00:00, 11.58it/s]


Epoch: 37, Acc: 60.877670, Loss: 0.367609
Test Acc: 28.430348, Test Loss: 0.673623


100%|██████████| 284/284 [00:40<00:00,  7.03it/s]
100%|██████████| 71/71 [00:06<00:00, 11.62it/s]


Epoch: 38, Acc: 60.861851, Loss: 0.366739
Test Acc: 28.273548, Test Loss: 0.732515


100%|██████████| 284/284 [00:39<00:00,  7.11it/s]
100%|██████████| 71/71 [00:06<00:00, 11.71it/s]


Epoch: 39, Acc: 61.039280, Loss: 0.364876
Test Acc: 27.398768, Test Loss: 0.624488


100%|██████████| 284/284 [00:40<00:00,  7.10it/s]
100%|██████████| 71/71 [00:06<00:00, 11.54it/s]


Epoch: 40, Acc: 60.844654, Loss: 0.363899
Test Acc: 28.433099, Test Loss: 0.687483


100%|██████████| 284/284 [00:40<00:00,  7.10it/s]
100%|██████████| 71/71 [00:06<00:00, 11.67it/s]


Epoch: 41, Acc: 60.820587, Loss: 0.364082
Test Acc: 28.064481, Test Loss: 0.727643


100%|██████████| 284/284 [00:40<00:00,  7.09it/s]
100%|██████████| 71/71 [00:06<00:00, 11.79it/s]


Epoch: 42, Acc: 60.863228, Loss: 0.362208
Test Acc: 28.323063, Test Loss: 0.703512


100%|██████████| 284/284 [00:40<00:00,  7.09it/s]
100%|██████████| 71/71 [00:06<00:00, 11.63it/s]


Epoch: 43, Acc: 60.814400, Loss: 0.362113
Test Acc: 27.536312, Test Loss: 0.704366


 47%|████▋     | 134/284 [00:18<00:20,  7.25it/s]

In [None]:
# Show the architecture of the model created in the training cell. Building a new CNN()
# here would rebind `cnn_model` and silently throw away the trained weights.
print(cnn_model)

In [None]:
"""
cnn_model.eval()
x, y = test_data[0][0], test_data[0][1]
with torch.no_grad():
    pred = model(x)
    predicted, actual = classes[pred[0].argmax(0)], classes[y]
    print(f'Predicted: "{predicted}", Actual: "{actual}"')
  """