In [None]:
import os

import pickle
import random
import numpy as np

import torch
from torch.autograd import Variable

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# The command-line (argparse) interface is disabled in this notebook; the
# commented lines are kept as a reminder of the script-style entry point.
#import argparse
#parser = argparse.ArgumentParser(description='PyTorch Relations-from-Stream sort-of-CLVR Example')
#args = parser.parse_args()

# An AttrDict stands in for the parsed-arguments namespace so that settings
# can be assigned and read as attributes (args.batch_size, args.cuda, ...).
from attrdict import AttrDict
args = AttrDict()

In [None]:
# Report whether PyTorch can use a CUDA device; the boolean is shown as the cell output.
torch.cuda.is_available()

In [None]:
# Run configuration; the whole `args` object is handed to the model constructor later on.
args.batch_size = 32                    # number of examples per mini-batch (used as `bs` below)
args.cuda = torch.cuda.is_available()   # move tensors/model to the GPU only when one is usable
args.lr   = 0.0001                      # presumably the optimizer learning rate, consumed inside the model -- TODO confirm
args.seed = 5                           # seed passed to the RNG seeding calls below
args.process_coords=False               # model option; semantics live in the project `model` module -- TODO confirm
args.debug = True                       # model option; semantics live in the project `model` module -- TODO confirm
args.rnn_hidden_size = 32               # presumably the hidden size of the model's RNN -- TODO confirm

In [None]:
# Seed every random number generator this notebook imports, not just torch's:
# `random` and `numpy` are imported at the top of the notebook and would
# otherwise start from a different state on every kernel restart.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    # Explicitly seed the CUDA generator as well when a GPU is in use.
    torch.cuda.manual_seed(args.seed)

In [None]:
# Allocate one batch worth of input buffers so tensor dimensions can be tested.
# Nothing is loaded here: the tensors are created by size only, and real data
# is copied into them later.
bs = args.batch_size

input_img = torch.FloatTensor(bs, 3, 75, 75)   # images
input_qst = torch.FloatTensor(bs, 11)          # question vectors (11 entries each)
label = torch.LongTensor(bs)                   # integer answer labels

# Move all three buffers to the GPU together when CUDA is enabled.
if args.cuda:
    input_img, input_qst, label = (buf.cuda() for buf in (input_img, input_qst, label))

# Wrap the buffers in autograd Variables, in the same order as above.
input_img, input_qst, label = map(Variable, (input_img, input_qst, label))

In [None]:
# Location of the pre-generated dataset, relative to the notebook's working directory.
data_dirs = './data'

filename = os.path.join(data_dirs, 'sort-of-clevr.pickle')
# NOTE(security): pickle.load can execute arbitrary code while deserialising.
# Only open a pickle file that was generated locally or comes from a trusted source.
with open(filename, 'rb') as f:
  # The pickle holds a pair: (training datasets, test datasets).
  train_datasets, test_datasets = pickle.load(f)

def cvt_data_axis(data):
    """Split a sequence of (img, qst, ans) records into three parallel lists.

    Args:
        data: iterable of indexable records; positions 0, 1 and 2 of each
            record are taken as image, question and answer respectively.

    Returns:
        A tuple ``(img, qst, ans)`` of lists, each preserving the order of
        ``data``.
    """
    img, qst, ans = [], [], []
    for record in data:
        img.append(record[0])
        qst.append(record[1])
        ans.append(record[2])
    return (img, qst, ans)

# Flatten the training set into one (image, question, answer) record per
# question, kept in two separate pools: relational and non-relational.
rel_train, norel_train = [], []
for img, relations, norelations in train_datasets:
    # Swap axes 0 and 2 of the stored image
    # (presumably channel-last -> channel-first for PyTorch -- TODO confirm).
    img = np.swapaxes(img, 0, 2)
    # Element 0 of each group holds the questions, element 1 the matching answers.
    rel_train.extend(
        [(img, question, answer) for question, answer in zip(relations[0], relations[1])]
    )
    norel_train.extend(
        [(img, question, answer) for question, answer in zip(norelations[0], norelations[1])]
    )

# Convert each list of records into (images, questions, answers) parallel lists.
rel = cvt_data_axis(rel_train)
norel = cvt_data_axis(norel_train)

In [None]:
def tensor_data(data, i):
    """Copy mini-batch ``i`` of ``data`` into the global input buffers.

    Args:
        data: indexable with three sliceable sequences at positions 0, 1, 2
            (images, questions, answers).
        i: zero-based batch index; elements ``bs*i`` up to (not including)
            ``bs*(i+1)`` of each sequence are used.

    The global ``input_img``, ``input_qst`` and ``label`` buffers are resized
    in place to the batch's shape and then overwritten. Nothing is returned.
    """
    window = slice(bs * i, bs * (i + 1))
    # Convert all three slices first, then copy them into their buffers.
    img, qst, ans = (torch.from_numpy(np.asarray(data[k][window])) for k in range(3))

    for target, source in ((input_img, img), (input_qst, qst), (label, ans)):
        target.data.resize_(source.size()).copy_(source)

tensor_data(rel, 0)    # Loads batch 0 into input_img, input_qst and label
#tensor_data(norel, 0)

In [None]:
def show_image(img):
    """Plot a single image tensor in a new 6x6 inch matplotlib figure.

    Args:
        img: 3-D tensor/Variable; it is moved to the CPU and converted to a
            numpy array before plotting.
    """
    pixels = img.cpu().data.numpy()
    # Undo the np->pytorch axis swap applied when the data was loaded.
    channels_last = np.swapaxes(pixels, 0, 2)
    # Reverse the last axis: BGR -> RGB.
    rgb = channels_last[:, :, ::-1]
    plt.figure(figsize=(6, 6))
    plt.imshow(rgb, interpolation='nearest')
    #plt.axis('off')

def show_question(q):
    """Decode an encoded question vector into a human-readable string.

    Layout of ``q`` as read by this function:
        * entries 0-5: the first entry equal to 1 selects the colour;
        * entry 6: set to 1 for non-relational questions;
        * entry 7: set to 1 for relational questions;
        * entries 8-10: question sub-type flags, interpreted according to
          entries 6 and 7.

    Args:
        q: tensor/Variable holding the encoded question.

    Returns:
        The colour name followed by the text of every sub-type flag set.

    Raises:
        ValueError: if none of the first six entries equals 1.
    """
    colors = ['red ', 'green ', 'blue ', 'orange ', 'gray ', 'yellow ']
    # (type-flag index, texts for sub-type flags 8, 9 and 10)
    subtype_texts = (
        (6, ('shape?', 'left?', 'up?')),                            # NonRel Questions
        (7, ('closest shape?', 'furthest shape?', 'count?')),      # Rel questions
    )

    bits = list(q.cpu().data.numpy())
    text = colors[bits[0:6].index(1)]

    for type_flag, suffixes in subtype_texts:
        if bits[type_flag] == 1:
            for position, suffix in enumerate(suffixes, start=8):
                if bits[position] == 1:
                    text += suffix
    return text

def show_answer(a):
    """Translate an encoded answer into its text form.

    Args:
        a: tensor/Variable holding a single integer answer index, either as a
            one-element tensor or as a zero-dimensional (scalar) tensor.

    Returns:
        The entry of the answer sheet at that index.

    Raises:
        IndexError: if the index lies outside the answer sheet.
    """
    answer_sheet = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6']
    # Flatten before indexing: selecting one element of a 1-D tensor yields a
    # zero-dimensional tensor on current PyTorch, and indexing the resulting
    # 0-d numpy array with [0] would raise IndexError.
    answer = int(a.cpu().data.numpy().reshape(-1)[0])
    return answer_sheet[answer]

def show_example(i):
    """Print the question and answer of example ``i`` of the loaded batch and plot its image.

    Args:
        i: index into the global ``input_qst``, ``label`` and ``input_img`` buffers.
    """
    question_text = show_question(input_qst[i])
    answer_text = show_answer(label[i])
    print(question_text, answer_text)
    show_image(input_img[i])

show_example(18)

In [None]:
#from model import RN, CNN_MLP, RFS
import model

import importlib
importlib.reload(model)  # re-import so edits to the model module take effect in this kernel

args.process_coords=False

# Build the model from the notebook settings and place it on the GPU if enabled.
m = model.RFS(args)
if args.cuda:
    m.cuda()
m.train();

#accuracy_rel = m.train_(input_img, input_qst, label)
#accuracy_norel = m.train_(input_img, input_qst, label)

# Load a snapshot
# The checkpoint is always deserialised onto the CPU so that a snapshot saved
# on a GPU machine can also be opened on a CPU-only one; load_state_dict then
# copies the values into the model's parameters on whichever device they are on.
m.load_state_dict(
    torch.load('model/RFS_2item-span-again-seed10_050.pth',
               map_location=lambda storage, loc: storage)
)

# Clear any stale gradients, then run one forward pass on the loaded batch.
m.optimizer.zero_grad()
output = m(input_img, input_qst)

### Sanity checks

In [None]:
# Concatenation check: joining along dim 1 places the columns side by side,
# so an (8, 2) block of zeros and an (8, 7) block of ones give an (8, 9) result.
a = torch.zeros(8, 2)
b = torch.ones(8, 7)

c = torch.cat([a, b], dim=1)
c

In [None]:
# Broadcasting check: a single random (1, 5) row is expanded to (6, 5) and
# added to a fresh random (6, 5) tensor.
p = torch.rand( (1,5) )
p  # bare expression that is not the cell's last line; with default notebook settings it is not displayed
p.expand( (6,5) )  # likewise evaluated and discarded
# Only this last expression is rendered as the cell output.
p.expand( (6,5) ) + torch.rand( (6,5) )
#p

In [None]:
# Softmax check on a small 2x3 example.
b = torch.from_numpy( np.array([[ 1.,2.,3.], [6.,1.,4. ] ], dtype=np.float32) )
b = Variable(b)
b
# Name the dimension explicitly instead of relying on the deprecated implicit
# choice: dim=1 normalises across the columns of each row.
torch.nn.functional.softmax(b, dim=1)    # This is the expected one (probs add up to 1 along rows)
#torch.nn.functional.log_softmax(b, dim=1)

In [None]:
#model.sample_gumbel(b) 
# Call the project's gumbel_softmax_sample helper on the 2x3 example `b` with
# temperature 0.4 and show the result as the cell output.
# NOTE(review): presumably this treats `b` as logits and returns a relaxed
# (soft) one-hot sample per row -- confirm against the `model` module.
model.gumbel_softmax_sample(b, temperature=0.4)