# README
revise counter_start and counter_end here to control which part of dataset to use to predict scanpath

In [1]:
counter_start = 0          # 100000 / 5, 100000 / 5 * 2...
counter_end = 100000 / 5   # 100000 / 5 * 2, 100000 / 5 * 3...

# install

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import math
from scipy.misc import face
from scipy.ndimage import zoom
from scipy.special import logsumexp
from tqdm import tqdm
import sys
import gc
from IPython.display import clear_output

import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

# # TODO: load downloaded files from google drive
# from google.colab import drive
# drive.mount('/content/gdrive')

In [None]:
!git clone https://github.com/matthias-k/DeepGaze
%cd /content/DeepGaze/deepgaze_pytorch
sys.path.append("/content/DeepGaze")
import deepgaze_pytorch

!wget https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/centerbias_mit1003.npy
centerbias_template = np.load('centerbias_mit1003.npy') # load precomputed centerbias log density (from MIT1003) over a 1024x1024 image
DEVICE = 'cuda'

Cloning into 'DeepGaze'...
remote: Enumerating objects: 108, done.[K
remote: Counting objects: 100% (108/108), done.[K
remote: Compressing objects: 100% (66/66), done.[K
remote: Total 108 (delta 64), reused 83 (delta 39), pack-reused 0[K
Receiving objects: 100% (108/108), 145.59 KiB | 11.20 MiB/s, done.
Resolving deltas: 100% (64/64), done.
/content/DeepGaze/deepgaze_pytorch
--2022-08-17 03:25:21--  https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/centerbias_mit1003.npy
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/372933216/3c267f80-c32e-11eb-9f03-c6381f7da54a?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220817%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220817T032521Z&X-Amz-Expires=300&X-Amz-Signature=24970b3586c2bf5a5f49e3

# combine
1. use deepgaze 2e to predict a fixation distribution (without fixation history input)
2. draw 4 fixations far away from each other to fake a fixation history
3. feed this fake fixation history to deepgaze 3 to simulate rest of the scanpath

In [None]:
model_deepgaze2 = deepgaze_pytorch.DeepGazeIIE(pretrained=True).to(DEVICE)
model_deepgaze3 = deepgaze_pytorch.DeepGazeIII(pretrained=True).to(DEVICE)


def deepgaze2_pred(image, model=model_deepgaze2):
  '''use deepgaze 2e to predict a fixation distribution (without fixation history input)'''
  centerbias = zoom(centerbias_template, \
                    (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), \
                    order=0, mode='nearest') # rescale to match image size
  centerbias -= logsumexp(centerbias) # renormalize log density
  centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)

  image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE)
  log_density_prediction = model(image_tensor, centerbias_tensor) # predicted log density for the next fixation location 

  del model, centerbias_tensor, image_tensor
  gc.collect()
  torch.cuda.empty_cache()
  
  return log_density_prediction


def draw_fix_from_pred(log_density_prediction, nfix=1):
  '''draw 4 fixations to fake a fixation history'''

  fix_dist = log_density_prediction.detach().cpu().numpy()[0, 0]
  fix_dist = np.exp(fix_dist)
  assert math.isclose(fix_dist.sum(), 1) # validate 2d dist of prob sum to 1

  flat = fix_dist.flatten() # sample from fix_dist
  sample_index = np.random.choice(a=flat.size, p=flat, size=nfix)  # sample an index from the 1D array with the probability distribution from the original array
  adjusted_index = np.unravel_index(sample_index, fix_dist.shape) # Take this index and adjust it so it matches the original array
  fixations_x = adjusted_index[1]
  fixations_y = adjusted_index[0] # height = axis 0 = y

  return fixations_x, fixations_y


def deepgaze3_pred(image, fixation_history_x, fixation_history_y, model=model_deepgaze3, nfix_total=20):
  '''
  feed fake fixation history to deepgaze 3 to simulate rest of the scanpath
  use log_density_prediction to draw next fixation, update log_density_prediction, until reach nfix_total
  arg:
    image: np.array of shape (xpix, ypix, channel)
    fixation_history_x
    fixation_history_y
    nfix_total: total fixation needed per image, including 4 steps of fake fixation history
  return:
    fixation_history_x
    fixation_history_y
    log_density_prediction: pred prob of next fix
  '''
  centerbias = zoom(centerbias_template, \
                    (image.shape[0]/centerbias_template.shape[0], image.shape[1]/centerbias_template.shape[1]), \
                    order=0, mode='nearest') # rescale to match image size
  centerbias -= logsumexp(centerbias) # renormalize log density
  centerbias_tensor = torch.tensor([centerbias]).to(DEVICE)

  image_tensor = torch.tensor([image.transpose(2, 0, 1)]).to(DEVICE) 
  x_hist_tensor = torch.tensor([fixation_history_x[model.included_fixations]]).to(DEVICE)
  y_hist_tensor = torch.tensor([fixation_history_x[model.included_fixations]]).to(DEVICE)
  log_density_prediction = model(image_tensor, centerbias_tensor, x_hist_tensor, y_hist_tensor)

  nstep = nfix_total - 4 # subtract 4 steps of fake history, simulate rest of the fixations
  for i in range(nstep):

    fixations_x, fixations_y = draw_fix_from_pred(log_density_prediction, nfix=1) # predict next fixation
    fixation_history_x = np.append(fixation_history_x, fixations_x)
    fixation_history_y = np.append(fixation_history_y, fixations_y)
    
    x_hist_tensor = torch.tensor([fixation_history_x[model.included_fixations]]).to(DEVICE)
    y_hist_tensor = torch.tensor([fixation_history_x[model.included_fixations]]).to(DEVICE)
    log_density_prediction = model(image_tensor, centerbias_tensor, x_hist_tensor, y_hist_tensor)

  del model, centerbias_tensor, image_tensor, x_hist_tensor, y_hist_tensor
  gc.collect()
  torch.cuda.empty_cache()

  return fixation_history_x, fixation_history_y, log_density_prediction

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://bitbucket.org/robert_geirhos/texture-vs-shape-pretrained-models/raw/60b770e128fffcbd8562a3ab3546c1a735432d03/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar" to /root/.cache/torch/hub/checkpoints/resnet50_finetune_60_epochs_lr_decay_after_30_start_resnet50_train_45_epochs_combined_IN_SF-ca06340c.pth.tar


  0%|          | 0.00/195M [00:00<?, ?B/s]

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b5-b6417697.pth


  0%|          | 0.00/117M [00:00<?, ?B/s]

Loaded pretrained weights for efficientnet-b5


Downloading: "https://github.com/pytorch/vision/zipball/v0.6.0" to /root/.cache/torch/hub/v0.6.0.zip
Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to /root/.cache/torch/hub/checkpoints/densenet201-c1103571.pth


  0%|          | 0.00/77.4M [00:00<?, ?B/s]

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0
Downloading: "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-7cdf4587.pth


  0%|          | 0.00/95.8M [00:00<?, ?B/s]

Downloading: "https://github.com/matthias-k/DeepGaze/releases/download/v1.0.0/deepgaze2e.pth" to /root/.cache/torch/hub/checkpoints/deepgaze2e.pth


  0%|          | 0.00/400M [00:00<?, ?B/s]

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0
Downloading: "https://github.com/matthias-k/DeepGaze/releases/download/v1.1.0/deepgaze3.pth" to /root/.cache/torch/hub/checkpoints/deepgaze3.pth


  0%|          | 0.00/78.9M [00:00<?, ?B/s]

In [None]:
# log_density_prediction = deepgaze2_pred(face())
# f, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
# axs[0].imshow(face())
# axs[0].set_axis_off()
# axs[1].matshow(log_density_prediction.detach().cpu().numpy()[0, 0])  # first image in batch, first (and only) channel
# axs[1].set_axis_off()


# fixations_x, fixations_y = draw_fix_from_pred(log_density_prediction, nfix=4)
# fixation_history_x = fixations_x
# fixation_history_y = fixations_y

# min 4 fixations in history: https://github.com/matthias-k/DeepGaze/blob/c33b89f08016e41e68cec4e4d9f1a73a14211386/deepgaze_pytorch/deepgaze3.py#L106
# `included_fixations=[-1, -2, -3, -4]`
# if more fixation points are provided but always at similar positions, the model always stays at the same spot
# TODO: ensure 4 fixations are far away from each other / inhibition of return


# fixation_history_x, fixation_history_y, log_density_prediction = deepgaze3_pred(face(), \
#                                                                                 fixation_history_x, fixation_history_y, \
#                                                                                 nfix_total=20)
# f, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
# axs[0].imshow(face())
# axs[0].plot(fixation_history_x, fixation_history_y, 'o-', color='red')
# axs[0].scatter(fixation_history_x[-1], fixation_history_y[-1], 100, color='yellow', zorder=100)
# axs[0].set_axis_off()
# axs[1].matshow(log_density_prediction.detach().cpu().numpy()[0, 0])  # first image in batch, first (and only) channel
# axs[1].plot(fixation_history_x, fixation_history_y, 'o-', color='red')
# axs[1].scatter(fixation_history_x[-1], fixation_history_y[-1], 100, color='yellow', zorder=100)
# axs[1].set_axis_off()


# # after running the pipeline down below:
# f, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 3))
# axs[0].imshow(images_np)
# axs[0].plot(fixation_history_x, fixation_history_y, 'o-', color='red')
# axs[0].scatter(fixation_history_x[-1], fixation_history_y[-1], 100, color='yellow', zorder=100)
# axs[0].set_axis_off()
# axs[1].matshow(log_density_prediction.detach().cpu().numpy()[0, 0])  # first image in batch, first (and only) channel
# axs[1].plot(fixation_history_x, fixation_history_y, 'o-', color='red')
# axs[1].scatter(fixation_history_x[-1], fixation_history_y[-1], 100, color='yellow', zorder=100)
# axs[1].set_axis_off()

# deepgaze try on STL10

In [None]:
# %%timeit

dataset = datasets.STL10("Dataset", split="unlabeled", download=True, \
                         transform=transforms.ToTensor(),) # download takes 2-4 min
dataloader = DataLoader(dataset, batch_size=1, shuffle=False, \
                        drop_last=False) #TODO: batch size back to larger (75)

nfix_total = 20
scanpath_arr = np.zeros((len(dataset), nfix_total, 2)) # n img, n fix per img, x & y

counter = 0
# TODO: wrap up in bash or py like https://github.com/Animadversio/Foveated_Saccade_SimCLR/blob/66c623e3abb54c03b7bac15b1cb19bca5049ee80/run_magnif.py
# take 2 arg: parallel_proc_id (1-6) that determines counter start and end
# and device cuda
# and save each img scanpath as a separate npy file (check if exist) in case .py get interrupted

for images, _ in tqdm(dataloader): # print(images.shape) # singleton, chan, x, y
  if counter < counter_start:
    counter += 1
    continue

  img_large = F.interpolate(images.to('cuda'), [256, 256]) # interpolate to increase scanpath pred perf, 256 seems best
  img_large = torch.transpose(img_large, 1, 3)
  img_large = torch.transpose(img_large, 1, 2) # print(img_large.shape) # batch_size x height x width x 3 
  # TODO: increase batch size

  images_np = img_large.cpu().detach().numpy() # tensor to np
  images_np = np.squeeze(images_np) # print(images_np.shape, type(images_np)) # height x width x 3
  # plt.matshow(images_np[:,:,0]) # ensure img is standing upright

  log_density_prediction = deepgaze2_pred(images_np)
  fixations_x, fixations_y = draw_fix_from_pred(log_density_prediction, nfix=4)
  del img_large, log_density_prediction; torch.cuda.empty_cache()
  fixation_history_x, fixation_history_y, log_density_prediction = deepgaze3_pred(images_np, \
                                                                                  fixations_x, fixations_y, \
                                                                                  nfix_total=nfix_total)
  scanpath_arr[counter,:,0] = fixation_history_x # width
  scanpath_arr[counter,:,1] = fixation_history_y # height

  gc.collect()
  torch.cuda.empty_cache()
  # clear_output(wait=True)
  # print(torch.cuda.is_available())
  # print(torch.cuda.memory_allocated() / 1024**3)
  # print(torch.cuda.memory_reserved() / 1024**3)

  counter += 1
  if counter >= counter_end: # result doesn't contain endpoint, only contain start point
    break

np.save(f"/content/stl10_unlabeled_scanpath_deepgaze_{counter_start}_{counter_end}.npy", scanpath_arr)

Files already downloaded and verified


 94%|█████████▍| 94224/100000 [15:16<6:15:26,  3.90s/it]

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
%mv fname /content/stl10_unlabeled_scanpath_deepgaze_94000_100000.npy /content/gdrive/MyDrive/stl10_unlabeled_scanpath_deepgaze_94000_100000.npy

In [None]:
# %time
# counter = 0
# for images, _ in tqdm(dataloader): # print(images.shape) # singleton, chan, x, y
#   counter += 1
#   if counter > 100000 / 5 * 4:
#     break

# 30.1/5 * 100000 / 3600/24, 132/24 # 5-7 days to run this if batch size = 1, without parallelization