In [51]:
import os
import torch
import numpy as np
import json
import collections

In [22]:
# Absolute path of the train-split annotations JSON that this notebook reads.
original_vqa_link_train = "/private/home/sash/mmf/mmf/v2_mscoco_train2014_annotations.json"

# Parse the file and keep only the value stored under its "annotations" key;
# everything else in the parsed document is dropped.
with open(original_vqa_link_train) as f:
    original_vqas = json.load(f)["annotations"]  # 443757 (count noted by the original author)

In [71]:
# Group the annotation records by the value of their "answer_type" field.
# Missing keys start out as an empty list, so records can be appended directly.
question_types_dict = collections.defaultdict(list)

for original_vqa in original_vqas:
    question_types_dict[original_vqa["answer_type"]].append(original_vqa)

In [74]:
# Lay the groups out back to back in a fixed order -- "other", then "yes/no",
# then "number" -- so that each group occupies one contiguous index range.
sorted_vqa_questions = [
    vqa
    for answer_type in ("other", "yes/no", "number")
    for vqa in question_types_dict[answer_type]
]
len(sorted_vqa_questions)

443757

In [115]:
# Absolute path of the train-split questions JSON that this notebook reads.
original_vqa_question_link_train = "/private/home/sash/mmf/mmf/v2_OpenEnded_mscoco_train2014_questions.json"

# Only the list stored under the "questions" key is kept.
with open(original_vqa_question_link_train) as f:
    original_vqa_questions = json.load(f)["questions"]

# Lookup table from each entry's "question_id" to its "question" text.
# If an id appears more than once, the last occurrence wins.
question_id_to_question_text = {
    question["question_id"]: question["question"]
    for question in original_vqa_questions
}

In [58]:
def stratify_sampling(x, n_samples, stratify):
    """Draw a stratified random sample, without replacement, from `x`.

    The first dimension of `x` is treated as a sequence of consecutive
    subgroups (strata) whose sizes are given by `stratify`. Each stratum
    contributes ``int(size * n_samples / n_total)`` rows. Because that
    quota is rounded down, a few rows may still be missing; they are drawn
    uniformly from the rows that have not been selected yet, so the result
    never contains the same row twice.

    Randomness comes from NumPy's global generator (``np.random``); seed it
    with ``np.random.seed`` for reproducible results.

    Parameters
    ----------
    x: np.ndarray or torch.Tensor
        Array to sample from. Samples from first dimension.

    n_samples: int
        Number of samples to draw. Must satisfy
        ``0 <= n_samples <= x.shape[0]``.

    stratify: sequence of int
        Size of each subgroup, in the order the subgroups appear in `x`.
        Note that the sum of all the sizes needs to be equal to
        `x.shape[0]`.

    Returns
    -------
    Same type as `x`
        The selected rows, i.e. `x` indexed along its first dimension. Rows
        are ordered stratum by stratum, followed by any rows added to make
        up for rounding.

    Raises
    ------
    AssertionError
        If the sizes in `stratify` do not add up to `x.shape[0]`.
    ValueError
        If `n_samples` is negative or larger than `x.shape[0]`.
    """
    n_total = x.shape[0]
    assert sum(stratify) == n_total

    if not 0 <= n_samples <= n_total:
        raise ValueError(
            f"n_samples must be between 0 and {n_total}, got {n_samples}"
        )
    if n_samples == 0:
        # Nothing to draw; also sidesteps the division by `n_total` below
        # when `x` itself is empty.
        return x[np.empty(0, dtype=np.intp), ...]

    # Per-stratum quota, rounded down.
    n_strat_samples = [int(i * n_samples / n_total) for i in stratify]
    # cum_n_samples[i]:cum_n_samples[i+1] is the index range of stratum i.
    cum_n_samples = np.cumsum([0] + list(stratify))
    sampled_idcs = []
    for i, n_strat_sample in enumerate(n_strat_samples):
        sampled_idcs.append(np.random.choice(range(cum_n_samples[i], cum_n_samples[i + 1]),
                                             replace=False,
                                             size=n_strat_sample))

    # Rounding down can leave us short of `n_samples`. Top up from the
    # indices that were NOT picked above, so no index is returned twice.
    delta_n_samples = n_samples - sum(n_strat_samples)
    if delta_n_samples > 0:
        is_available = np.ones(n_total, dtype=bool)
        is_available[np.concatenate(sampled_idcs)] = False
        remaining_idcs = np.flatnonzero(is_available)
        sampled_idcs.append(np.random.choice(remaining_idcs,
                                             replace=False,
                                             size=delta_n_samples))

    samples = x[np.concatenate(sampled_idcs), ...]

    return samples

In [110]:
# Sample positions into `sorted_vqa_questions` rather than the records
# themselves, so each pick can be mapped back to its group afterwards.
n_samples = 30
samples = np.arange(len(sorted_vqa_questions))

# Group sizes, in the same order the groups were concatenated.
other_len, yesno_len, number_len = (
    len(question_types_dict[answer_type])
    for answer_type in ("other", "yes/no", "number")
)
stratify = [other_len, yesno_len, number_len]
output_indexes = stratify_sampling(samples, n_samples, stratify)

# Tally how many picks fall into each group's contiguous index range.
num_selected_types = collections.defaultdict(int)
yesno_end = other_len + yesno_len  # first index past the "yes/no" range
for output_index in output_indexes:
    if output_index < other_len:
        selected_type = "other"
    elif output_index < yesno_end:
        selected_type = "yes/no"
    else:
        selected_type = "number"
    num_selected_types[selected_type] += 1

print(f"Total Selected: {n_samples}")
print(f"Other Selected: {num_selected_types['other']}/{other_len}")
print(f"YesNo Selected: {num_selected_types['yes/no']}/{yesno_len}")
print(f"Numbe Selected: {num_selected_types['number']}/{number_len}")

# Keep only the (question_id, image_id) pair of every sampled record.
selected_vqas = (sorted_vqa_questions[output_index] for output_index in output_indexes)
selected_vqa_questions = [(vqa["question_id"], vqa["image_id"]) for vqa in selected_vqas]

Total Selected: 30
Other Selected: 15/219269
YesNo Selected: 12/166882
Numbe Selected: 3/57606


In [124]:
# Print one record per sampled question: its id, the question text looked up
# by that id, and the image id, followed by two blank lines as a separator.
for question_id, image_id in selected_vqa_questions:
    print(f"question_id: {question_id}")
    question_text = question_id_to_question_text[question_id]
    print(f"text: {question_text}")
    print(f"image_id: {image_id}")
    print(f"\n")

question_id: 395456001
text: What are the spots on the floor?
image_id: 395456


question_id: 283809004
text: Which remote is the biggest?
image_id: 283809


question_id: 555586014
text: What color is the man's tie?
image_id: 555586


question_id: 109816002
text: What is the guy in black holding in his hand?
image_id: 109816


question_id: 474601003
text: What gender is the birthday person?
image_id: 474601


question_id: 360441000
text: What street sign is at the bottom?
image_id: 360441


question_id: 480890001
text: Where is the man staring?
image_id: 480890


question_id: 524866016
text: What brand is this phone?
image_id: 524866


question_id: 88527004
text: What type of tie is he wearing?
image_id: 88527


question_id: 462512002
text: Which way is the convertible turning?
image_id: 462512


question_id: 576809001
text: Where are the orange stripes?
image_id: 576809


question_id: 333848003
text: What breed is the dog?
image_id: 333848


question_id: 180098000
text: Is the bridge 