In [1]:
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import pickle

In [2]:
from PIL import Image

In [4]:
names = open("image_names.txt", "w")
image_names = os.listdir('../data/ALLSTIMULI/')[2:-3]
for i in range(len(image_names)):
    names.write(image_names[i][:-5]+'\n')
names.close()

In [11]:
names = open("image_names.txt", "r")
img_names = names.readlines()
for i in range(len(img_names)):
    img_names[i]=img_names[i][:-1]
    
loc_data_xy={}
for name in img_names:
    locpath = '../data/loc_data/' + name
    f = open(locpath,'rb')
    loc_dict = pickle.load(f)
    loc_data_xy[name] = np.array(loc_dict['barycenters'])

In [12]:
len(loc_data_xy)

1000

In [13]:
def show_landmarks(image, landmarks):
    """Show image with landmarks"""
    plt.imshow(image)
    plt.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    #plt.pause(0.001)  # pause a bit so that plots are updated

In [14]:
for name in img_names[:30]:
    plt.figure()
    impath = '../data/ALLSTIMULI/' + name + ''
    print(impath)
    img = Image.open(impath)
    img_npy = np.asarray(img)
    landmarks = np.array(loc_data_xy[name])
    show_landmarks(img_npy, landmarks)
    #loc_ij = []
    #for coord in loc_xy:
    #    loc_ij += [[coord[1], coord[0]]]
    #loc_ij = np.array(loc_ij)
    #plt.plot(loc_data_ij_part[:,1], loc_data_ij_part[:,0])
    #for i, coord in enumerate(loc_xy):
    #    plt.plot(coord[0], coord[1], 'r+', ms=32)
    

../data/ALLSTIMULI/i05june05_static_street_boston_p1010764


FileNotFoundError: [Errno 2] No such file or directory: '../data/ALLSTIMULI/i05june05_static_street_boston_p1010764'

<Figure size 432x288 with 0 Axes>

# Dataset class

In [None]:
class SaccadeLandmarksDataset(Dataset):
    """Saccade Landmarks dataset."""

    def __init__(self, loc_dict, img_dir, transform=None):
        """
        Args:
            loc_dir (string): Path to the saccade location file
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.loc_dict = loc_dict
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.loc_dict)

    def __getitem__(self, idx):

        img_name = os.listdir(self.img_dir)[idx+2]
        img_path = os.path.join(self.img_dir,img_name)
        image = io.imread(img_path)
        name = img_name[:-5]
        landmarks = self.loc_dict[name]
        landmarks = np.array([landmarks])
        landmarks = landmarks.reshape(-1, 2) #.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
saccade_dataset = SaccadeLandmarksDataset(loc_dict=loc_data_xy,
                                    img_dir='../data/ALLSTIMULI/')

fig = plt.figure(figsize=(10,5))

for i in range(len(saccade_dataset)):
    sample = saccade_dataset[i]

    print(i, sample['image'].shape, sample['landmarks'].shape)

    ax = plt.subplot(1, 4, i + 1)
    plt.tight_layout()
    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    show_landmarks(**sample)

    if i == 3:
        plt.show()
        break

In [None]:
len(sample['landmarks'])

# Transforms

In [None]:
class RandomSaccadeTo(object):
    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']
        nb_sac = len(landmarks)
        sac_num =  np.random.randint(nb_sac)
        sac = landmarks[sac_num]
        N_X, N_Y = image.shape[:2]
        #img_color_sac = saccade_to(image, (N_X//2, N_Y//2), (sac[1], sac[0]))
        image_roll = np.copy(image)
        image_roll=np.roll(image_roll, N_X//2 - sac[1], axis=0)
        image_roll=np.roll(image_roll, N_Y//2 - sac[0], axis=1)
        return {'image':image_roll, 'pos':sac}

In [None]:
sample_sac = RandomSaccadeTo()(sample)
plt.imshow(sample_sac['image'])
N_X, N_Y = sample_sac['image'].shape[:2]
plt.scatter(N_Y//2, N_X//2, c='r')

In [None]:
class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image_tens = sample['image'].transpose((2, 0, 1))
        return {'image': torch.FloatTensor(image_tens), 'pos': sample['pos']}

In [None]:
sample_tens = ToTensor()(sample_sac)

### Adapted cropped pyramid (squeezed tensor)

In [None]:
from PYramid import cropped_pyramid

In [None]:
class CroppedPyramid(object):
    def __init__(self, width, base_levels, color=True, do_mask=False, verbose=False):
        self.width = width
        self.base_levels = base_levels
        self.color = color
        self.do_mask = do_mask
        self.verbose = verbose
    
    def __call__(self, sample):
        img_crop, level_size = cropped_pyramid(sample['image'].unsqueeze(0), 
                                               width=self.width, 
                                               base_levels=self.base_levels,
                                               color=self.color, 
                                               do_mask=self.do_mask, 
                                               verbose=self.verbose,
                                               squeeze=True)
        return{'img_crop':img_crop, 'level_size':level_size, 'pos':sample['pos']}
        
    

In [None]:
width=32
base_levels=2
cropped_pyr_transform = CroppedPyramid(width, base_levels)
transformed_data = cropped_pyr_transform(sample_tens)

In [None]:
transformed_data['img_crop'].shape

In [None]:
from LogGabor import LogGabor
from PYramid import local_filter
from PYramid import get_K
from PYramid import log_gabor_transform

In [None]:
n_sublevel = 2 
n_azimuth = 12 
n_theta = 12
n_phase = 2

pe = {'N_X': width, 'N_Y': width, 'do_mask': False, 'base_levels':
          base_levels, 'n_theta': 24, 'B_sf': 0.6, 'B_theta': np.pi/12 ,
      'use_cache': True, 'figpath': 'results', 'edgefigpath':
          'results/edges', 'matpath': 'cache_dir', 'edgematpath':
          'cache_dir/edges', 'datapath': 'database/', 'ext': '.pdf', 'figsize':
          14.0, 'formats': ['pdf', 'png', 'jpg'], 'dpi': 450, 'verbose': 0}   

lg = LogGabor(pe)

In [None]:
K = get_K(width=width,
            n_sublevel = n_sublevel, 
          n_azimuth = n_azimuth, 
                  n_theta = n_theta,
                  n_phase = n_phase, 
                  r_min = width/6, 
                  r_max = width/3, 
                  log_density_ratio = 2, 
                  verbose=False)

In [None]:
class LogGaborTransform(object):
    def __init__(self, K=K, color=True, verbose=False):
        self.K = K
        self.color = color
        self.verbose = verbose
    
    def __call__(self, sample):
        log_gabor_coeffs = log_gabor_transform(sample['img_crop'].unsqueeze(0), K)
        
        return{'img_gabor':log_gabor_coeffs, 'K':K}

In [None]:
my_transform = LogGaborTransform()
transformed_data = my_transform(transformed_data)

# Compose transforms

In [None]:
composed_transform = transforms.Compose([RandomSaccadeTo(),
                               ToTensor(),
                               CroppedPyramid(width, base_levels)])

In [None]:
transformed_data = composed_transform(sample)

# Iterating through the dataset

In [None]:
saccade_dataset = SaccadeLandmarksDataset(loc_dict=loc_data_xy,
                                          img_dir='../data/ALLSTIMULI/',
                                          transform=composed_transform)

In [None]:
for i in range(len(saccade_dataset)):
    sample = saccade_dataset[i]

    print(i, sample['img_crop'].size(), sample['level_size'])

    if i == 3:
        break

In [None]:
sample['pos']

In [None]:
# Helper function to show a batch
def show_landmarks_batch(sample_batched):
    """Show image with landmarks for a batch of samples."""
    for level in range(5,0,-1):
        plt.figure()
        images_batch = sample_batched['img_crop'][:,level,:,:,:]
        batch_size = len(images_batch)
        im_size = images_batch.size(2)
        grid_border_size = 2

        grid = utils.make_grid(images_batch)
        plt.imshow(grid.numpy().transpose((1, 2, 0)).clip(0,255).astype('uint8'))

        plt.title('Batch from dataloader, level=' + str(level))


In [None]:
dataloader = DataLoader(saccade_dataset, batch_size=4,
                        shuffle=True, num_workers=0)
for i_batch, sample_batched in enumerate(dataloader):
    print(i_batch, sample_batched['img_crop'].size())
    if i_batch ==3 :
        plt.figure()
        show_landmarks_batch(sample_batched)    
        break