In [8]:

from toleft.reference_game.features import ImageNetFeat

import os

import numpy as np
import torch
import torch.nn.parallel
import torch.utils.data as data
import argparse


In [2]:
# Root directory of the signaling-game data; the 'train' split is read from a subfolder of it.
# NOTE(review): absolute, machine-specific path — adjust before running on another machine.
base_dir = '/home/xappma/to-your-left/signaling_game_data'

In [3]:
# Load the feature set stored under <base_dir>/train via the project's ImageNetFeat class.
f = ImageNetFeat(os.path.join(base_dir, 'train'))

  data = torch.FloatTensor(list(fc[key]))


In [5]:
# Inspect the obj2id entry with key 0: a dict holding a 'labels' value and an 'ims' index list.
f.obj2id[0]

{'labels': 'accordion',
 'ims': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99]}

In [88]:
class RotFeat(data.Dataset):
    """Dataset of pre-extracted image features for two roles at four rotations.

    Features are read from eight HDF5 files named
    ``new-vgg-layers-2-{role}-{rotation}.h5`` under ``root`` (role is
    ``target`` or ``distractor``; rotation is ``rot0``/``rot90``/``rot180``/
    ``rot270``). Only the first key of each file is read.

    Attributes:
        root: ``root`` with ``~`` expanded.
        train: stored as given; not otherwise used by this class.
        flat_features: every feature row concatenated in loading order and
            L2-normalised per row.
        features: maps ``(role, rotation)`` to the matching rows of
            ``flat_features``.
        obj2id: maps a role key (1 for ``target``, 0 otherwise) to a list of
            four dicts, one per rotation, each with ``'labels'`` (the
            ``(role, rotation)`` pair) and ``'ims'`` (row indices into
            ``flat_features``).
    """

    # Key of each role in ``obj2id`` and position of each rotation in ``obj2id[key]``.
    TARGET_TO_IDX = {'target': 1, 'distractor': 0}
    ROTATION_TO_IDX = {'rot0': 0, 'rot90': 1, 'rot180': 2, 'rot270': 3}

    def __init__(self, root='/home/xappma/spatial-dataset/features',  train=True):
        """Load all feature files below ``root`` and build the index structures.

        Args:
            root: directory containing the eight ``.h5`` feature files.
            train: flag stored on the instance (training set or test set).
        """
        import h5py  # local import: only required when features are read from disk

        self.root = os.path.expanduser(root)
        self.train = train  # training set or test set

        features = {}
        for target in ['target', 'distractor']:
            for rot in ['rot0', 'rot90', 'rot180', 'rot270']:
                # Build the path from self.root so that '~' in `root` is honoured.
                fc_file = os.path.join(self.root, f'new-vgg-layers-2-{target}-{rot}.h5')
                # Context manager closes the HDF5 handle once the array is in memory.
                with h5py.File(fc_file, 'r') as fc:
                    key = list(fc.keys())[0]
                    raw = fc[key][()]  # read the whole dataset as one ndarray
                # Single bulk conversion instead of building a tensor from a list of rows.
                features[(target, rot)] = torch.as_tensor(raw, dtype=torch.float32)

        self.features = features
        self.create_obj2id(features)

        # Normalise each row to unit L2 norm.
        self.flat_features = self._l2_normalise(self.flat_features)

        # Re-point every (role, rotation) entry at its normalised rows.
        for t, r in list(features.keys()):
            ims = self.obj2id[self.TARGET_TO_IDX[t]][self.ROTATION_TO_IDX[r]]['ims']
            self.features[(t, r)] = self.flat_features[ims]

    @staticmethod
    def _l2_normalise(rows, eps=1e-12):
        """Return ``rows`` scaled to unit L2 norm per row; all-zero rows stay zero."""
        norms = torch.norm(rows, p=2, dim=1, keepdim=True)
        # Clamp so that an all-zero row divides by eps instead of producing NaN.
        return rows / norms.clamp_min(eps)

    def __getitem__(self, index):
        """Return ``(feature_row, index)`` for the given row of ``flat_features``."""
        return self.flat_features[index], index

    def __len__(self):
        """Return the number of rows in ``flat_features``."""
        return self.flat_features.size(0)

    def create_obj2id(self, feature_dict):
        """Build ``self.obj2id`` and ``self.flat_features`` from ``feature_dict``.

        Args:
            feature_dict: maps ``(role, rotation)`` to a tensor of feature
                rows. Rows are concatenated in the dict's iteration order.
                A rotation name other than ``rot0``/``rot90``/``rot180`` is
                stored in the last (``rot270``) slot, and a role other than
                ``'target'`` uses key 0.

        Side effects:
            Sets ``self.obj2id`` and ``self.flat_features`` (``None`` when
            ``feature_dict`` is empty).
        """
        self.obj2id = {}
        role_keys = {}
        chunks = []
        next_row = 0

        for (target, rot), feats in feature_dict.items():
            r = self.ROTATION_TO_IDX.get(rot, 3)

            if target not in role_keys:
                key = 1 if target == 'target' else 0
                role_keys[target] = key
                self.obj2id[key] = [
                    {'labels': (target, name), 'ims': []}
                    for name in ('rot0', 'rot90', 'rot180', 'rot270')
                ]

            end = next_row + len(feats)
            # Contiguous block of row indices this chunk occupies in flat_features.
            self.obj2id[role_keys[target]][r]['ims'] = np.arange(next_row, end, dtype=np.int64)
            next_row = end
            chunks.append(feats)

        # Concatenate once rather than re-copying the growing tensor on every iteration.
        self.flat_features = torch.cat(chunks) if chunks else None

In [89]:
# Build the rotation feature dataset from its default feature directory (reads the .h5 files).
dataset = RotFeat()

In [20]:
# Sanity check: row 20 of each (role, rotation) block must be the same vector whether it is
# reached through obj2id -> flat_features or through the features dict.
# Uses `dataset` (the RotFeat instance created above) so the cell runs on a fresh kernel.
for role_key, rot_idx, label in [
    (1, 0, ('target', 'rot0')),
    (1, 2, ('target', 'rot180')),
    (0, 2, ('distractor', 'rot180')),
]:
    row = dataset.obj2id[role_key][rot_idx]['ims'][20]
    assert torch.equal(dataset.flat_features[row], dataset.features[label][20])

In [109]:
class _BatchIterator:
    def __init__(self, loader, n_batches, seed=None):
        self.loader = loader
        self.n_batches = n_batches
        self.batches_generated = 0
        self.random_state = np.random.RandomState(seed)

    def __iter__(self):
        return self

    def __next__(self):
        if self.batches_generated > self.n_batches:
            raise StopIteration()

        batch_data = self.get_batch()
        self.batches_generated += 1
        return batch_data

    def get_batch(self):
        loader = self.loader
        opt = loader.opt

        # C = len(self.loader.dataset.obj2id.keys())  # number of concepts
        images_indexes_sender = np.zeros((opt.batch_size, opt.game_size))
        
        target_ims = loader.dataset.obj2id[1][0]['ims'] # get the target image with rotation 0
        distractor_ims = loader.dataset.obj2id[0][0]['ims']

        assert(target_ims[0] != distractor_ims[0])
                
        idxs = self.random_state.choice(list(range(len(target_ims))), opt.batch_size).astype(int)
        # print('idx', len(idxs))      
        
        target = target_ims[idxs]
        distractor = distractor_ims[idxs]

        # print('target', target.shape)

        assert(target[0] != distractor[0])
        
        images_indexes_sender[:, 1] = target
        images_indexes_sender[:, 0] = distractor

        print('sender indexes', images_indexes_sender.shape)
        
        images_vectors_sender = []

        for i in range(opt.game_size):
            x = loader.dataset.flat_features[images_indexes_sender[:, i]]
            images_vectors_sender.append(x)

        images_vectors_sender = torch.stack(images_vectors_sender).contiguous()

        print('images vector', images_vectors_sender.shape)
        
        y = torch.zeros(opt.batch_size).long()

        print('y', y.shape)
        
        images_vectors_receiver = torch.zeros_like(images_vectors_sender)
        for i in range(opt.batch_size):
            permutation = torch.randperm(opt.game_size)

            # print(permutation)
            
            images_vectors_receiver[:, i, :] = images_vectors_sender[permutation, i, :]
            y[i] = permutation.argmin()

        return images_vectors_sender, y, images_vectors_receiver



class RotationLoader(torch.utils.data.DataLoader):
    """DataLoader whose iteration yields ``_BatchIterator`` game batches.

    Besides the regular ``DataLoader`` arguments it takes these keyword-only
    arguments, which are removed before the parent constructor is called:

    * ``opt`` (required): options object, stored as ``self.opt``.
    * ``batches_per_epoch`` (required): number of batches per iteration pass.
    * ``seed`` (optional, default ``None``): seed handed to the batch
      iterator; when ``None`` a fresh seed is drawn from NumPy's global RNG
      every time iteration starts.
    """

    def __init__(self, *args, **kwargs):
        self.opt = kwargs.pop("opt")
        # Optional: __iter__ already handles None, so callers need not pass seed=None.
        self.seed = kwargs.pop("seed", None)
        self.batches_per_epoch = kwargs.pop("batches_per_epoch")

        super(RotationLoader, self).__init__(*args, **kwargs)

    def __iter__(self):
        if self.seed is None:
            # Explicit 64-bit dtype: the bound 2**32 does not fit a 32-bit default integer.
            seed = int(np.random.randint(0, 2 ** 32, dtype=np.int64))
        else:
            seed = self.seed
        return _BatchIterator(self, n_batches=self.batches_per_epoch, seed=seed)


In [110]:
class Args(argparse.Namespace):
    """Fixed game options, exposed as class-level defaults on a Namespace."""

    game_size = 2    # number of image slots per game
    batch_size = 32  # number of games per batch


args = Args()

# Loader over the rotation dataset: 50 batches per pass, fixed seed.
l = RotationLoader(
    dataset=dataset,
    opt=args,
    batches_per_epoch=50,
    seed=3,
)

In [112]:
# Preview a few examples from the first batch only.
N_PREVIEW = 3

for batch in l:
    s, y, r = batch  # sender features, labels, receiver features
    # Bound by the number of examples in the batch, not by len(batch) (the 3-tuple's length),
    # so the preview also works when batch_size is smaller than 3.
    for i in range(min(N_PREVIEW, len(y))):
        print(s[0][i], s[1][i])
        print(r[0][i], r[1][i])

        print(y[i])
        print()
    break

sender indexes (32, 2)
images vector torch.Size([2, 32, 4096])
y torch.Size([32])
tensor([-0.0248, -0.0127,  0.0027,  ..., -0.0025, -0.0217,  0.0226]) tensor([-0.0276, -0.0116,  0.0002,  ..., -0.0051, -0.0221,  0.0190])
tensor([-0.0248, -0.0127,  0.0027,  ..., -0.0025, -0.0217,  0.0226]) tensor([-0.0276, -0.0116,  0.0002,  ..., -0.0051, -0.0221,  0.0190])
tensor(0)

tensor([-0.0331,  0.0073,  0.0079,  ...,  0.0034, -0.0145,  0.0035]) tensor([-0.0331,  0.0009,  0.0141,  ..., -0.0039, -0.0177,  0.0066])
tensor([-0.0331,  0.0009,  0.0141,  ..., -0.0039, -0.0177,  0.0066]) tensor([-0.0331,  0.0073,  0.0079,  ...,  0.0034, -0.0145,  0.0035])
tensor(1)

tensor([-0.0386, -0.0056,  0.0203,  ..., -0.0058, -0.0184,  0.0072]) tensor([-0.0264, -0.0049,  0.0139,  ...,  0.0021, -0.0191,  0.0053])
tensor([-0.0386, -0.0056,  0.0203,  ..., -0.0058, -0.0184,  0.0072]) tensor([-0.0264, -0.0049,  0.0139,  ...,  0.0021, -0.0191,  0.0053])
tensor(0)



In [104]:
# Take example 1 of the sender tensor (batch[0]) and unpack its two game slots:
# a is slot 0 and b is slot 1. The unpacking only works when game_size == 2.
a, b = batch[0][:, 1, :]
# Not the last expression of the cell, so this result is computed but not displayed.
all(a == b)
# Length of slot 0 of the sender tensor, i.e. the number of examples in the batch.
len(batch[0][0])

32

In [50]:
# Element-wise comparison of `a` and `b`, reduced to a single bool.
# NOTE(review): relies on `a` and `b` left in the kernel by another cell — confirm which state
# this recorded output belongs to.
all(a == b)

True

In [84]:
# Scratch check of RandomState.choice: seed a private generator, draw two distinct values
# from [1, 2, 3] (replace=False) and use them as positions into a small array.
state = np.random.RandomState(3)
np.array([1,2,3,4,5])[state.choice([1,2, 3], 2, replace=False)]

array([3, 2])

In [82]:
# Draws again from the same `state` generator, so the result depends on how many draws
# were made from it before this cell ran.
state.choice([1,2, 3], 2, replace=False).astype(int)

array([1, 3])