# Relation Network
## Sort-of-CLEVR Dataset

## Data

### Exploring the Data

In [5]:
from data_generator import build_sample, translate_sample

In [6]:
sample = build_sample()

In [12]:
sample[2][1]

[3, 0, 3, 0, 0, 0, 1, 1, 0, 1]
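
The answers are stored as integer indices, not strings. The sketch below decodes them. It assumes the indices point into the same vocabulary as `answer_map` in the Predictions section, and that `sample[2]` is the (questions, answers) pair for the non-relational questions, which is how `preprocess_sample` reads it. `ANSWER_MAP` and `decode_answers` are hypothetical helper names.

```python
# Answer vocabulary, copied from answer_map in the Predictions section
ANSWER_MAP = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6']

def decode_answers(indices):
    '''Map integer answer indices to their answer strings.'''
    return [ANSWER_MAP[i] for i in indices]

# Answer indices printed above for sample[2][1]
print(decode_answers([3, 0, 3, 0, 0, 0, 1, 1, 0, 1]))
# ['circle', 'yes', 'circle', 'yes', 'yes', 'yes', 'no', 'no', 'yes', 'no']
```

The first five decoded answers match Q10 to Q14 in the `translate_sample` output that follows. This suggests the non-relational questions are listed after the ten relational ones.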

In [13]:
translate_sample(sample, show_img=True)

Q0. How many objects of the same shape as the orange object are there? ==> 1
Q1. What is the furthest shape from the red object? ==> rectangle
Q2. How many objects of the same shape as the green object are there? ==> 5
Q3. What is the furthest shape from the orange object? ==> circle
Q4. What is the closest shape to the yellow object? ==> circle
Q5. How many objects of the same shape as the blue object are there? ==> 5
Q6. How many objects of the same shape as the yellow object are there? ==> 5
Q7. What is the furthest shape from the yellow object? ==> rectangle
Q8. What is the furthest shape from the red object? ==> rectangle
Q9. How many objects of the same shape as the yellow object are there? ==> 5
Q10. What is the shape of the red object? ==> circle
Q11. Is there a yellow object on the top? ==> yes
Q12. What is the shape of the blue object? ==> circle
Q13. Is there a yellow object on the left? ==> yes
Q14. Is there a yellow object on the top? ==> yes
Q15. Is there a gray object on
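
The questions fall into two families. Relational questions (closest shape, furthest shape, count of same shape) require comparing objects. Non-relational questions (shape, top, left) concern a single object. The sketch below shows how each answer could be computed from a scene description. It is not `data_generator`'s code. The `Obj` representation, the 75x75 image size, and the "center in the top/left half" rule are all assumptions. Counting includes the referenced object itself, because that is what the printed answers appear to imply: the orange object's count is 1 while the green object's is 5, which is consistent with six objects, one rectangle and five circles.

```python
import math
from dataclasses import dataclass

@dataclass
class Obj:
    color: str
    shape: str   # 'circle' or 'rectangle'
    x: float     # center column in pixels (hypothetical layout)
    y: float     # center row in pixels, growing downward as in image coordinates

def find(objs, color):
    '''Return the object with the given color (one object per color).'''
    return next(o for o in objs if o.color == color)

def dist(a, b):
    return math.hypot(a.x - b.x, a.y - b.y)

# Relational questions: compare the referenced object with the others
def closest_shape(objs, color):
    ref = find(objs, color)
    return min((o for o in objs if o is not ref), key=lambda o: dist(ref, o)).shape

def furthest_shape(objs, color):
    ref = find(objs, color)
    return max((o for o in objs if o is not ref), key=lambda o: dist(ref, o)).shape

def count_same_shape(objs, color):
    # Includes the referenced object itself
    ref = find(objs, color)
    return sum(o.shape == ref.shape for o in objs)

# Non-relational questions: look at the referenced object only
def shape_of(objs, color):
    return find(objs, color).shape

def on_top(objs, color, img_size):
    return 'yes' if find(objs, color).y < img_size / 2 else 'no'

def on_left(objs, color, img_size):
    return 'yes' if find(objs, color).x < img_size / 2 else 'no'

# Toy 75x75 scene: one orange rectangle and five circles
scene = [
    Obj('orange', 'rectangle', 65, 65), Obj('red', 'circle', 10, 20),
    Obj('green', 'circle', 40, 50), Obj('yellow', 'circle', 15, 10),
    Obj('blue', 'circle', 50, 20), Obj('gray', 'circle', 30, 60),
]
print(furthest_shape(scene, 'red'), count_same_shape(scene, 'green'), on_top(scene, 'yellow', 75))
# rectangle 5 yes
```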

### Generating the Data

In [1]:
!python data_generator.py

Building Train Dataset...
Building Test Dataset...
Saving Datasets...
Datasets saved at ./data


## Train

In [None]:
!python train.py
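
`train.py` trains `RNModel`, whose internals are not shown here. For orientation, the relation network of Santoro et al. (2017) scores a question q against every ordered pair of objects: RN(O) = f_phi(sum over i, j of g_theta(o_i, o_j, q)). In their image setting, each "object" is one cell of a CNN feature map. The following is a minimal PyTorch sketch of that formula only. The class name `TinyRN`, the layer sizes, and the input dimensions are assumptions and do not describe `RNModel`.

```python
import torch
import torch.nn as nn

class TinyRN(nn.Module):
    '''Minimal relation network: f_phi(sum_{i,j} g_theta(o_i, o_j, q)).'''

    def __init__(self, obj_dim, q_dim, hidden=64, n_answers=10):
        super().__init__()
        # g_theta scores one (object, object, question) triple
        self.g = nn.Sequential(
            nn.Linear(2 * obj_dim + q_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        # f_phi maps the summed relation vector to answer logits
        self.f = nn.Sequential(
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, n_answers),
        )

    def forward(self, objs, q):
        # objs: (B, N, D) object vectors; q: (B, Q) question encoding
        b, n, d = objs.shape
        o_i = objs.unsqueeze(2).expand(b, n, n, d)   # o_i[:, i, j] = objs[:, i]
        o_j = objs.unsqueeze(1).expand(b, n, n, d)   # o_j[:, i, j] = objs[:, j]
        q_rep = q[:, None, None, :].expand(b, n, n, q.shape[-1])
        pairs = torch.cat([o_i, o_j, q_rep], dim=-1)  # (B, N, N, 2D + Q)
        relations = self.g(pairs).sum(dim=(1, 2))      # sum over all ordered pairs
        return self.f(relations)                       # (B, n_answers) logits

# Usage with arbitrary sizes: 4 images, 25 objects of 26 features, 11-dim questions
model = TinyRN(obj_dim=26, q_dim=11)
print(model(torch.randn(4, 25, 26), torch.randn(4, 11)).shape)  # torch.Size([4, 10])
```

Summing g_theta over all pairs makes the output independent of the order in which objects are listed. This is the property that lets the network reason about pairs without knowing where each object sits in the input.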

## Predictions

In [14]:
import cv2
import numpy as np
import torch
from model import RNModel

In [15]:
translate_sample(sample)

Q0. How many objects of the same shape as the orange object are there? ==> 1
Q1. What is the furthest shape from the red object? ==> rectangle
Q2. How many objects of the same shape as the green object are there? ==> 5
Q3. What is the furthest shape from the orange object? ==> circle
Q4. What is the closest shape to the yellow object? ==> circle
Q5. How many objects of the same shape as the blue object are there? ==> 5
Q6. How many objects of the same shape as the yellow object are there? ==> 5
Q7. What is the furthest shape from the yellow object? ==> rectangle
Q8. What is the furthest shape from the red object? ==> rectangle
Q9. How many objects of the same shape as the yellow object are there? ==> 5
Q10. What is the shape of the red object? ==> circle
Q11. Is there a yellow object on the top? ==> yes
Q12. What is the shape of the blue object? ==> circle
Q13. Is there a yellow object on the left? ==> yes
Q14. Is there a yellow object on the top? ==> yes
Q15. Is there a gray object on

In [16]:
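# Scale the image (assumed to hold values in [0, 1]) to 8-bit before resizing and saving
cv2.imwrite('sample.jpg', cv2.resize((sample[0] * 255).astype(np.uint8), (512, 512)))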

True

In [17]:
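def preprocess_sample(sample):
    '''Turn one generated sample into (images, questions, answers) tensors, one row per question.'''
    # Move the channel axis first (swaps axes 0 and 2)
    img = np.swapaxes(sample[0], 0, 2)
    relations = sample[1]    # (questions, answers) for the relational questions
    norelations = sample[2]  # (questions, answers) for the non-relational questions
    
    # Pair every question with the same image: relational first, then non-relational
    sample_data = []
    for ques, ans in zip(relations[0], relations[1]):
        sample_data.append((img, ques, ans))
    for ques, ans in zip(norelations[0], norelations[1]):
        sample_data.append((img, ques, ans))
        
    # Stack into single arrays first; building tensors from lists of arrays is slow
    imgs = np.stack([e[0] for e in sample_data])
    ques = np.array([e[1] for e in sample_data])
    ans = np.array([e[2] for e in sample_data])
    
    return torch.from_numpy(imgs).float(), torch.from_numpy(ques).float(), torch.from_numpy(ans).long()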
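
`preprocess_sample` copies the same image once per question. At inference time, a view is enough. The sketch below is a hypothetical variant (`preprocess_sample_expand` is not part of the repo). It reads the sample layout the same way `preprocess_sample` does and uses `Tensor.expand`, so every batch row shares one copy of the image.

```python
import numpy as np
import torch

def preprocess_sample_expand(sample):
    '''Hypothetical variant of preprocess_sample that shares one image across all questions.'''
    # Same channel-axis swap as preprocess_sample, made contiguous before conversion
    img = torch.from_numpy(np.ascontiguousarray(np.swapaxes(sample[0], 0, 2))).float()
    relations, norelations = sample[1], sample[2]

    # Relational pairs first, then non-relational, as in preprocess_sample
    pairs = list(zip(relations[0], relations[1])) + list(zip(norelations[0], norelations[1]))
    ques = torch.tensor(np.array([q for q, _ in pairs]), dtype=torch.float32)
    ans = torch.tensor(np.array([a for _, a in pairs]), dtype=torch.long)

    # expand() returns a view with stride 0 on the batch axis: no per-question copies
    imgs = img.unsqueeze(0).expand(len(pairs), -1, -1, -1)
    return imgs, ques, ans
```

The trade-off is that `imgs` must not be modified in place, since all rows alias the same memory. Any layer that needs contiguous input will make its own copy anyway.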

In [18]:
# Preprocess the inputs
imgs, ques, ans = preprocess_sample(sample)

In [19]:
# Create model and load the weights (map_location lets a GPU-trained checkpoint load on a CPU-only machine)
model = RNModel(None)
model.load_state_dict(torch.load('models/epoch_40.pth', map_location='cpu'))

<All keys matched successfully>

In [20]:
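# Make prediction in inference mode
model.eval()  # disables dropout / batch-norm updates, if RNModel uses them
with torch.no_grad():  # no gradients needed for prediction
    output = model(imgs, ques)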

In [21]:
answer_map = ['yes', 'no', 'rectangle', 'circle', '1', '2', '3', '4', '5', '6']

In [22]:
pred = output.argmax(1)
accuracy = pred.eq(ans).sum() * 100. / len(ans)

In [23]:
pred_ans = [answer_map[i] for i in pred.tolist()]
print(f'Predicted Answers:\n{pred_ans}')
print('\nAccuracy:', accuracy.item())

Predicted Answers:
['2', 'rectangle', '4', 'rectangle', 'circle', '4', '4', 'rectangle', 'rectangle', '4', 'circle', 'yes', 'circle', 'yes', 'yes', 'yes', 'no', 'no', 'yes', 'no']

Accuracy: 70.0
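
The overall 70% hides a split between question families. The sketch below scores each group separately. It assumes, as the printed order suggests, that the first ten answers are relational and the last ten non-relational. The ground-truth strings are copied from the `translate_sample` output for Q0 to Q9 and from the decoded `sample[2][1]` indices for Q10 to Q19. In the notebook itself, the same split would be `pred[:10]` / `pred[10:]` against `ans`. `group_accuracy` is a hypothetical helper.

```python
def group_accuracy(predicted, expected, n_relational):
    '''Percent correct for relational (first n) answers, non-relational (rest) answers, and overall.'''
    def acc(p, t):
        return 100.0 * sum(a == b for a, b in zip(p, t)) / len(t)
    return {
        'relational': acc(predicted[:n_relational], expected[:n_relational]),
        'non-relational': acc(predicted[n_relational:], expected[n_relational:]),
        'overall': acc(predicted, expected),
    }

# Predicted answers printed above
predicted = ['2', 'rectangle', '4', 'rectangle', 'circle', '4', '4', 'rectangle', 'rectangle', '4',
             'circle', 'yes', 'circle', 'yes', 'yes', 'yes', 'no', 'no', 'yes', 'no']
# Ground truth: Q0-Q9 from translate_sample, Q10-Q19 decoded from sample[2][1]
expected = ['1', 'rectangle', '5', 'circle', 'circle', '5', '5', 'rectangle', 'rectangle', '5',
            'circle', 'yes', 'circle', 'yes', 'yes', 'yes', 'no', 'no', 'yes', 'no']
print(group_accuracy(predicted, expected, n_relational=10))
# {'relational': 40.0, 'non-relational': 100.0, 'overall': 70.0}
```

On this single sample, the model answers every non-relational question correctly but only 4 of the 10 relational ones. Five of the six errors are counting questions. Twenty questions from one image are far too few to generalize from.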
