In [1]:
import os
import sys
import json
import numpy as np
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
import torchvision
from PIL import Image
from skimage.transform import resize

sys.path.append('./src')
from dataset import HDMdataset
from models import IT2P_history, IT2P_nonhistory
from utils import generate_spatial_batch

In [2]:
# Vocabulary used to turn each instruction word into an integer index
# (looked up as dictionary[word] when sentences are encoded for the model).
# The file is opened in a `with` block so the handle is closed right after parsing.
with open('./data/dictionary.json', 'r') as dictionary_file:
    dictionary = json.load(dictionary_file)

In [3]:
# Seed every random number generator the notebook touches (Python, NumPy,
# PyTorch CPU and all CUDA devices) with the same value for repeatable runs.
seed = 0
for _seed_rng in (random.seed,
                  np.random.seed,
                  torch.manual_seed,
                  torch.cuda.manual_seed,
                  torch.cuda.manual_seed_all):
    _seed_rng(seed)

# cuDNN: turn off the benchmark-based algorithm search and request
# deterministic kernels so GPU results do not vary between runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Set `history_flag` below to `True` to use history information (this selects the `IT2P_history` model instead of `IT2P_nonhistory`).

In [4]:
# Switch between the two model variants used later in the notebook:
#   True  -> IT2P_history, whose forward pass also takes a time index and the
#            list of histories returned by the previous step
#   False -> IT2P_nonhistory, which only sees the current image and sentence
history_flag = False

In [5]:
# PIL image -> torch tensor converter, applied to every test image when it is
# loaded; the loading code then keeps the first 3 channels and rescales the
# values with `* 2 - 1`.
transform = torchvision.transforms.ToTensor()

In [6]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook does not crash at `.to(device)` on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
temp = 2 # 2 for models proposed in paper. 

# Both variants take the same constructor arguments, so pick the class once
# and build the model a single time.
# NOTE(review): the meaning of the positional arguments (512, 2, 300) is defined
# in src/models.py and is not visible from this notebook.
model_cls = IT2P_history if history_flag else IT2P_nonhistory
model = model_cls(512, 2, dictionary, 300, temp, depth=4).to(device)

In [7]:
# Evaluation cell: runs the loaded model over every task folder in `data_dir`,
# writes the predicted pick/place pixels of each instruction file to
# `result_dir`, and counts how many predictions land within `thres` pixels of
# the ground-truth box centres.


def _bbox_center(bbox):
    """Return the midpoint [(b0 + b2) / 2, (b1 + b3) / 2] of a 4-element box.

    Presumably the box is [x_min, y_min, x_max, y_max] in pixels of the
    256x256 prediction map -- TODO confirm against the meta files.
    """
    return [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]


def _peak_xy(heatmap):
    """Return [column, row] of the first maximum (row-major order) of a 2-D map."""
    hits = np.where(heatmap == np.max(heatmap))
    return [hits[1][0], hits[0][0]]


def _pixel_dist(point_a, point_b):
    """Return the Euclidean distance between two [x, y] points."""
    return np.sqrt((point_a[0] - point_b[0]) ** 2 + (point_a[1] - point_b[1]) ** 2)


spatial_coords = torch.FloatTensor(generate_spatial_batch(1)).permute(0, 3, 1, 2).to(device)

# map_location lets the checkpoint load on whatever device the model lives on,
# even if it was saved from a different one.
# NOTE(review): this path is the "wo_history" checkpoint regardless of
# history_flag -- point it at the matching weights when history_flag is True.
model.load_state_dict(torch.load('./best_models/proposed_wo_history.pth', map_location=device))
model.eval()

print('loaded model')

data_dir = './test_tasks/'
result_dir = './performance/best_proposed_wo_history_rep/'

pick_correct = 0 
place_correct = 0
both_correct = 0
pp_cnt = 0
# task_correct / task_cnt are initialised but never updated in this cell.
task_correct = 0
task_cnt = 0
thres = 15  # maximum pixel distance for a prediction to count as correct
for d in sorted(os.listdir(data_dir)):
    # Skip hidden entries such as .DS_Store or .ipynb_checkpoints.
    if d.startswith('.'):
        continue

    img_dir = os.path.join(data_dir, d, 'image')
    meta_dir = os.path.join(data_dir, d, 'meta')

    os.makedirs(os.path.join(result_dir, d), exist_ok=True)

    # Load every image of the task once: keep the first 3 channels and
    # rescale with `* 2 - 1`, then add a batch dimension.
    images = []
    for fp in sorted(os.listdir(img_dir)):
        img_fp = os.path.join(img_dir, fp)
        # Close the image file as soon as it has been converted to a tensor.
        with Image.open(img_fp) as img:
            start_img = transform(img)[:3] * 2 - 1
        images.append(start_img.unsqueeze(0))

    num_pp = len(images)
    for fp in sorted(os.listdir(meta_dir)):
        with open(os.path.join(meta_dir, fp), 'r') as meta_file:
            meta = json.load(meta_file)
        dist_results = {'pick':[], 'place':[]}
        # History state is only carried across the steps of one meta file.
        histories = []

        for i in range(num_pp):
            curr_sentence = meta['sentence'][i]
            gt_bbox = meta['bbox'][i]
            gt_pick = _bbox_center(gt_bbox['pick'])
            gt_place = _bbox_center(gt_bbox['place'])

            # Encode the sentence as a batch of one sequence of word indices.
            words = curr_sentence.split()
            language = torch.LongTensor([dictionary[w] for w in words]).unsqueeze(0)
            # Named time_step (not `time`) to avoid shadowing the stdlib module name.
            time_step = torch.LongTensor([i]).to(device)
            lang_lengths = torch.LongTensor([len(words)])

            with torch.no_grad():
                if history_flag:
                    pred, histories = model(images[i].float().to(device), 
                                            language.long().to(device),
                                            lang_lengths, spatial_coords, time_step, histories)
                else:
                    pred = model(images[i].float().to(device), 
                                 language.long().to(device),
                                 lang_lengths, 
                                 spatial_coords)

            # Channel 0 is the pick map and channel 1 the place map; both are
            # resized to 256x256 before the peak is located.
            pick_pred = resize(pred[:, 0, :, :].squeeze().detach().cpu().numpy(), (256, 256))
            place_pred = resize(pred[:, 1, :, :].squeeze().detach().cpu().numpy(), (256, 256))

            esti_pick = _peak_xy(pick_pred)
            esti_place = _peak_xy(place_pred)

            pick_dist = _pixel_dist(gt_pick, esti_pick)
            place_dist = _pixel_dist(gt_place, esti_place)

            # Cast NumPy integers to plain ints so the results are JSON serialisable.
            dist_results['pick'].append([int(esti_pick[0]), int(esti_pick[1])])
            dist_results['place'].append([int(esti_place[0]), int(esti_place[1])])

            if pick_dist < thres:
                pick_correct += 1
            if place_dist < thres:
                place_correct += 1
            if pick_dist < thres and place_dist < thres:
                both_correct += 1
            pp_cnt += 1

        # Drop references to the history tensors before releasing cached GPU memory.
        histories = []
        torch.cuda.empty_cache()
        with open(os.path.join(result_dir, d, fp), 'w') as result_file:
            json.dump(dist_results, result_file)
print(pp_cnt, pick_correct, place_correct, both_correct)

loaded model
3345 2978 2439 2199


In [8]:
# Re-display the evaluation counters without re-running the loop:
# total pick-and-place steps, correct picks, correct places, both correct.
print(pp_cnt, pick_correct, place_correct, both_correct)

3345 2978 2439 2199
