In [1]:
import json
import pandas as pd
import numpy as np
import torch.utils.data as data
import networkx as nx
import nltk
import spacy
import gensim
import en_core_web_sm
from nltk.data import find
from nltk.corpus import wordnet



In [16]:
import os
import h5py

In [8]:
SRC_SG_PATH = "../VG-data/scene_graphs.json"

SRC_ANSWER_VOCAB_FILE = "./intermediate_files/answer_vocab.txt"
SRC_SGQAS_OF_INTEREST_QA_DATA_FILE = "./intermediate_files/filtered_qa_data.json"

DST_SG_FEATURES_DATA_FOLDER = "./intermediate_files/sg_features/"

In [7]:
global_sg_data = json.load(open(SRC_SG_PATH, 'r'))
global_qa_data = json.load(open(SRC_SGQAS_OF_INTEREST_QA_DATA_FILE, 'r'))

In [4]:
class VQAVisualGenomeDataset(data.Dataset):

    def __init__(self, sg_data_path, qa_data_path, ans_vocab_data_path):
        self.sg_data_path = sg_data_path
        self.qa_data_path = qa_data_path
        self.ans_vocab_data_path = ans_vocab_data_path  
        
        self.sample_cnt = 0
        self.data_sgvqa = []
        self._load_dataset()
        
    def _load_dataset(self):
        print('-> Loading filtered dataset ...')
        self.ans_vocab_data = sorted(open(self.ans_vocab_data_path, 'r').read().strip().split("\n"))
        sg_data = global_sg_data
        qa_data = global_qa_data
#         sg_data = json.load(open(self.sg_data_path, 'r'))
#         qa_data = json.load(open(self.qa_data_path, 'r'))        
        
        for sample_img, sample_ans in zip(sg_data, qa_data):
            if sample_img['image_id'] != sample_ans['id']:
                print("IDs did not match !")
                continue                
            for qa_index, qa in enumerate(sample_ans['qas']):
                if (qa['qas_skip']): continue
                question = qa['question']
                answer = qa['answer'].replace(".", "").lower()
                # sg = sample_img                
                self.data_sgvqa.append({"question": question, "answer": answer, "sg": sample_img})
                
                self.sample_cnt += 1
                if (self.sample_cnt > 10): break # todo: remove                       
        
        print('-> Finished loading data : num. samples -> {}'.format(self.sample_cnt))
        
    def __len__(self):
        return self.sample_cnt
                
    def __getitem__(self, index):
        if index < self.sample_cnt:
            item = self.data_sgvqa[index]
        else:
            item = self.data_sgvqa[index - self.sample_cnt]
        return item

    def num_classes(self):
        return len(self.ans_vocab_data)

#     def vocab_words(self):
#         return self.dataset_vqa.vocab_words()

#     def vocab_answers(self):
#         return self.dataset_vqa.vocab_answers()

    def data_loader(self, batch_size=10, num_workers=4, shuffle=False):
        return DataLoader(self, 
                          batch_size=batch_size, 
                          shuffle=shuffle,
                          num_workers=num_workers, 
                          pin_memory=True)

    def split_name(self, testdev=False):
        return self.data_sgvqa.split_name(testdev=testdev)

In [5]:
vqa_dataloader = VQAVisualGenomeDataset(SRC_SG_PATH, SRC_SGQAS_OF_INTEREST_QA_DATA_FILE, SRC_ANSWER_VOCAB_FILE)

-> Loading filtered dataset ...
-> Finished loading data : num. samples -> 95751


In [6]:
vqa_dataloader.__getitem__(3)

{'question': 'What are the men doing?',
 'answer': 'interacting',
 'sg': {'relationships': [{'synsets': ['along.r.01'],
    'predicate': 'ON',
    'relationship_id': 15927,
    'object_id': 5046,
    'subject_id': 5045},
   {'synsets': ['wear.v.01'],
    'predicate': 'wears',
    'relationship_id': 15928,
    'object_id': 5048,
    'subject_id': 1058529},
   {'synsets': ['have.v.01'],
    'predicate': 'has',
    'relationship_id': 15929,
    'object_id': 5050,
    'subject_id': 5049},
   {'synsets': ['along.r.01'],
    'predicate': 'ON',
    'relationship_id': 15930,
    'object_id': 1058508,
    'subject_id': 1058507},
   {'synsets': ['along.r.01'],
    'predicate': 'ON',
    'relationship_id': 15931,
    'object_id': 1058534,
    'subject_id': 5055},
   {'synsets': ['have.v.01'],
    'predicate': 'has',
    'relationship_id': 15932,
    'object_id': 1058511,
    'subject_id': 1058529},
   {'synsets': ['next.r.01'],
    'predicate': 'next to',
    'relationship_id': 15933,
    'object

# Simple Dataset for GCN

In [46]:
from scipy.sparse import csgraph
import scipy.sparse as sp
import torch


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    
def normalize_adj(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv_sqrt = np.power(rowsum, -0.5).flatten()
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = sp.diags(r_inv_sqrt)

    return mx.dot(r_mat_inv_sqrt).transpose().dot(r_mat_inv_sqrt).tocoo()

In [93]:
class VQAVGSimpleDataset(data.Dataset):

    def __init__(self, sg_data_path, qa_data_path, ans_vocab_data_path, feature_path):
        self.sg_data_path = sg_data_path
        self.qa_data_path = qa_data_path
        self.ans_vocab_data_path = ans_vocab_data_path  
        self.feature_path = feature_path
        
        self.sample_cnt = 0
        self.data_sgvqa = []
        self._load_dataset()
        
    def _load_dataset(self):
        print('-> Loading filtered dataset ...')
        self.ans_vocab_data = sorted(open(self.ans_vocab_data_path, 'r').read().strip().split("\n"))
        sg_data = global_sg_data[:10]
        qa_data = global_qa_data[:10]
#         sg_data = json.load(open(self.sg_data_path, 'r'))
#         qa_data = json.load(open(self.qa_data_path, 'r'))            
        
        for sample_img, sample_ans in zip(sg_data, qa_data):
            if sample_img['image_id'] != sample_ans['id']:
                print("IDs did not match !")
                continue                
                
            feature_filename = os.path.join(DST_SG_FEATURES_DATA_FOLDER, "{}.h5".format(sample_img['image_id']))
            with h5py.File(feature_filename, 'r') as hf:
                
                g = nx.Graph()
                feature_matrix = []
                for obj in sample_img['objects']:
                    obj_name = obj['names'][0]
                    obj_id = obj['object_id']                    
                    emb_vec = np.array(hf.get(str(obj_id)))

                    g.add_node(obj_id, feature=emb_vec)
                    feature_matrix.append(emb_vec)
                        
                for rel in sample_img['relationships']:
                    g.add_edge(rel['subject_id'], rel['object_id'], id=rel['relationship_id'])

                adj = nx.adjacency_matrix(g)
                # print(adj.todense())
                adj = normalize_adj(adj + sp.eye(adj.shape[0]))
                sparse_mx = adj.tocoo().astype(np.float32)
                # adj = torch.FloatTensor(np.array(adj.todense()))
                adj = np.array(adj.todense())        
                
                # feature_matrix = normalize(np.asarray(feature_matrix))
                # feature_matrix = torch.FloatTensor(np.array(feature_matrix.todense()))
#                 print(np.array(feature_matrix, np.float32).shape)
#                 feature_matrix = torch.FloatTensor(np.array(feature_matrix, np.float32))
                
                
                adj_matrix_per_sample = []    
                feature_matrix_per_sample = []
                for qa_index, qa in enumerate(sample_ans['qas']):
                    if (qa['qas_skip']): continue
                    question = qa['question']
                    answer = qa['answer'].replace(".", "").lower()
                    # sg = sample_img  
                    
                    adj_matrix_per_sample.append(adj)
                    feature_matrix_per_sample.append(feature_matrix)
                
                print(len(feature_matrix_per_sample))
                print(len(adj_matrix_per_sample))
                print(len(feature_matrix_per_sample[0]))
                print(len(adj_matrix_per_sample[0]))
                print(adj_matrix_per_sample[0].shape)
                f = np.array(feature_matrix_per_sample, np.float32)
                a = np.array(adj_matrix_per_sample)
                print(f.shape)
                print(a.shape)
                feature_matrix_per_sample = torch.FloatTensor(np.array(feature_matrix_per_sample, np.float32))
                adj_matrix_per_sample = torch.FloatTensor(np.array(adj_matrix_per_sample))
                
                self.data_sgvqa.append({"question": question, "answer": answer,
                                        "sg_adj": adj_matrix_per_sample, "sg_feat": feature_matrix})
                
                self.sample_cnt += 1
                if (self.sample_cnt > 10): break # todo: remove                       
        
        print('-> Finished loading data : num. samples -> {}'.format(self.sample_cnt))
        
    def __len__(self):
        return self.sample_cnt
                
    def __getitem__(self, index):
        if index < self.sample_cnt:
            item = self.data_sgvqa[index]
        else:
            item = self.data_sgvqa[index - self.sample_cnt]
        return item

    def num_classes(self):
        return len(self.ans_vocab_data)

#     def vocab_words(self):
#         return self.dataset_vqa.vocab_words()

#     def vocab_answers(self):
#         return self.dataset_vqa.vocab_answers()

    def data_loader(self, batch_size=10, num_workers=4, shuffle=False):
        return DataLoader(self, 
                          batch_size=batch_size, 
                          shuffle=shuffle,
                          num_workers=num_workers, 
                          pin_memory=True)

    def split_name(self, testdev=False):
        return self.data_sgvqa.split_name(testdev=testdev)

In [94]:
vqa_dataloader = VQAVGSimpleDataset(SRC_SG_PATH, SRC_SGQAS_OF_INTEREST_QA_DATA_FILE, SRC_ANSWER_VOCAB_FILE,
                                    DST_SG_FEATURES_DATA_FOLDER)

-> Loading filtered dataset ...
61
61
40
40
(40, 40)


ValueError: setting an array element with a sequence.