# Relation Network
## Sort-of-CLEVR Dataset

## Data

### Exploring the data

In [1]:
from data_generator import build_sample, translate_sample

In [2]:
# Build one sample. Later cells read sample[0] as the image and
# sample[1] / sample[2] as the relational / non-relational (questions, answers) pairs.
sample = build_sample()

In [3]:
# Inspect the relational part of the sample: (list of question vectors, list of answer indices)
sample[1]

([array([1., 0., 0., 0., 0., 0., 1., 0., 0., 1.]),
  array([0., 0., 0., 0., 1., 0., 1., 1., 0., 0.]),
  array([0., 0., 0., 1., 0., 0., 1., 0., 0., 1.]),
  array([0., 1., 0., 0., 0., 0., 1., 0., 1., 0.]),
  array([0., 0., 0., 0., 0., 1., 1., 1., 0., 0.]),
  array([0., 0., 0., 0., 1., 0., 1., 1., 0., 0.]),
  array([0., 0., 0., 1., 0., 0., 1., 0., 1., 0.]),
  array([0., 0., 0., 0., 1., 0., 1., 1., 0., 0.]),
  array([0., 0., 1., 0., 0., 0., 1., 1., 0., 0.]),
  array([0., 0., 0., 0., 0., 1., 1., 1., 0., 0.])],
 [7, 2, 7, 3, 2, 2, 3, 2, 2, 2])

In [4]:
# Print the sample's questions and answers in readable form
# (show_img=True presumably also displays the image -- confirm in data_generator)
translate_sample(sample, show_img=True)

Q0. How many objects of the same shape as the red object are there? ==> 4
Q1. What is the closest shape to the gray object? ==> rectangle
Q2. How many objects of the same shape as the orange object are there? ==> 4
Q3. What is the furthest shape from the green object? ==> circle
Q4. What is the closest shape to the yellow object? ==> rectangle
Q5. What is the closest shape to the gray object? ==> rectangle
Q6. What is the furthest shape from the orange object? ==> circle
Q7. What is the closest shape to the gray object? ==> rectangle
Q8. What is the closest shape to the blue object? ==> rectangle
Q9. What is the closest shape to the yellow object? ==> rectangle
Q10. Is there a yellow object on the left? ==> yes
Q11. What is the shape of the green object? ==> circle
Q12. What is the shape of the orange object? ==> rectangle
Q13. Is there a orange object on the left? ==> yes
Q14. What is the shape of the gray object? ==> rectangle
Q15. Is there a yellow object on the left? ==> yes
Q16. I

### Generating the Data

In [1]:
# Run the generator script; its log reports building the train and test datasets and saving them at ./data
!python data_generator.py

Building Train Dataset...
Building Test Dataset...
Saving Datasets...
Datasets saved at ./data


## Train

In [None]:
# Run the training script (presumably the source of the 'epoch_25.pth' checkpoint loaded later -- confirm in train.py)
!python train.py

## Predictions

In [5]:
import cv2
import numpy as np
import torch
from model import RNModel

In [6]:
# Print the ground-truth questions and answers again, for comparison with the model's predictions
translate_sample(sample)

Q0. How many objects of the same shape as the red object are there? ==> 4
Q1. What is the closest shape to the gray object? ==> rectangle
Q2. How many objects of the same shape as the orange object are there? ==> 4
Q3. What is the furthest shape from the green object? ==> circle
Q4. What is the closest shape to the yellow object? ==> rectangle
Q5. What is the closest shape to the gray object? ==> rectangle
Q6. What is the furthest shape from the orange object? ==> circle
Q7. What is the closest shape to the gray object? ==> rectangle
Q8. What is the closest shape to the blue object? ==> rectangle
Q9. What is the closest shape to the yellow object? ==> rectangle
Q10. Is there a yellow object on the left? ==> yes
Q11. What is the shape of the green object? ==> circle
Q12. What is the shape of the orange object? ==> rectangle
Q13. Is there a orange object on the left? ==> yes
Q14. What is the shape of the gray object? ==> rectangle
Q15. Is there a yellow object on the left? ==> yes
Q16. I

In [7]:
# Multiply the image by 255, resize it to 512x512 and write it to disk
upscaled = cv2.resize(sample[0] * 255, (512, 512))
cv2.imwrite('sample.jpg', upscaled)

True

In [10]:
def preprocess_sample(sample):
    '''Turn one generated sample into batched model inputs.

    Args:
        sample: indexable with ``sample[0]`` the image array (at least 3-D),
            ``sample[1]`` the relational ``(questions, answers)`` pair and
            ``sample[2]`` the non-relational ``(questions, answers)`` pair.

    Returns:
        Tuple ``(imgs, ques, ans)`` of tensors with one row per question,
        relational questions first:
        ``imgs`` (float32) is the image with axes 0 and 2 swapped, repeated
        once per question; ``ques`` (float32) holds the question vectors;
        ``ans`` (int64) holds the answer values.
    '''
    img = np.swapaxes(sample[0], 0, 2)

    # zip() pairs every question with its answer and stops at the shorter
    # of the two lists within each group.
    pairs = [qa for group in (sample[1], sample[2]) for qa in zip(group[0], group[1])]
    questions = [q for q, _ in pairs]
    answers = [a for _, a in pairs]

    # Build each batch as ONE ndarray before handing it to torch: creating a
    # tensor straight from a Python list of ndarrays is a slow path.
    # np.array(...) copies the read-only broadcast view into a fresh,
    # writable, C-contiguous float32 array.
    img_batch = np.array(
        np.broadcast_to(img, (len(pairs),) + img.shape),
        dtype=np.float32,
        order='C',
    )
    ques_batch = np.asarray(questions, dtype=np.float32)
    ans_batch = np.asarray(answers)

    return (
        torch.from_numpy(img_batch),
        torch.from_numpy(ques_batch),
        torch.as_tensor(ans_batch).long(),
    )

In [11]:
# Preprocess the inputs: one (image, question, answer) row per question,
# relational questions first, then non-relational ones
imgs, ques, ans = preprocess_sample(sample)

In [13]:
# Create model and load the weights
model = RNModel(None)
# map_location='cpu' lets a checkpoint that was saved on a GPU load on a
# CPU-only machine; load_state_dict then copies the values into the model's
# own parameters.
model.load_state_dict(torch.load('epoch_25.pth', map_location='cpu'))

<All keys matched successfully>

In [16]:
# Make prediction
# eval() puts layers such as dropout / batch-norm (if the model has any) into
# inference behaviour; no_grad() skips autograd bookkeeping for the forward pass.
model.eval()
with torch.no_grad():
    output = model(imgs, ques)

In [17]:
# Answer label for each class index (list position == class index)
answer_map = [
    'yes', 'no',                    # yes / no answers
    'rectangle', 'circle',          # shape answers
    '1', '2', '3', '4', '5', '6',   # count answers
]

In [18]:
# Predicted class index for each question = position of the largest score in its row
pred = output.argmax(dim=1)
# Percentage of questions whose predicted index equals the ground-truth answer
num_correct = pred.eq(ans.data).cpu().sum()
accuracy = num_correct * 100. / len(ans)

In [32]:
# Translate each predicted class index into its answer label, then report
predicted_indices = output.argmax(1).tolist()
pred_ans = [answer_map[idx] for idx in predicted_indices]
print(f'Predicted Answers:\n{pred_ans}')
print('\nAccuracy:', accuracy.item())

Predicted Answers:
['3', 'rectangle', '3', 'rectangle', 'circle', 'rectangle', 'rectangle', 'circle', 'rectangle', 'rectangle', 'no', 'circle', 'rectangle', 'no', 'rectangle', 'no', 'yes', 'yes', 'no', 'rectangle']

Accuracy: 55.0
