In [16]:
from typing import Any, Dict, List, Optional, TypedDict, Union

from PIL.Image import Image

class Doc(TypedDict):
    """A retrieved passage attached to a DataQA record (see DataQA.docs)."""
    doc_id: str            # passage id as returned by the retriever, e.g. '248462'
    # Retrieval score. The retriever returns float32 scores, so this is a float.
    # Convert with float(score) before json.dump, which rejects numpy scalars.
    doc_score: float
    title: Optional[str]
    text: Optional[str]

class _QuestionRequired(TypedDict):
    # Keys written for every dataset converted in this notebook.
    # Id type depends on the source dataset (e.g. OK-VQA ids are ints,
    # ScienceQA uses the problem-id string).
    qid: Union[int, str]
    text: str  # question text (for ScienceQA: question plus rendered options)


class Question(_QuestionRequired, total=False):
    """Question payload of a DataQA record.

    `qid` and `text` are always present. The remaining keys are optional:
    text-only ScienceQA problems have no image, and only ScienceQA writes
    `without_choices` / `choices`.
    """
    image_id: Optional[Union[int, str]]  # int for the COCO-based datasets
    image_file: Optional[str]            # image path relative to the image root
    without_choices: str                 # ScienceQA: bare question without options
    choices: List[str]                   # ScienceQA: answer options in order

class Response(TypedDict):
    """A model response attached to a DataQA record (DataQA.responses / top_response).

    NOTE(review): the field meanings below are inferred from their names;
    confirm against the code that produces these records.
    """
    text: str                # response text
    reward: Optional[float]  # presumably a scalar score for this response; None if unscored
    rid: str                 # presumably a response identifier

class _DataQARequired(TypedDict):
    # Keys that every converter in this notebook writes.
    question: Question  # question text plus optional image reference / choices
    ref_ans: Dict       # reference-answer fields; every converter sets 'text' to the canonical answer
    split: str          # split name, e.g. "train" / "val"
    dataset: str        # source dataset name, e.g. "aokvqa"


class DataQA(_DataQARequired, total=False):
    """One QA example in the shared format produced by the conversion loop.

    The four keys inherited from _DataQARequired are always present. The keys
    below are optional because most converters omit them.
    """
    # Reference captions. NOTE(review): the A-OKVQA converter currently stores
    # its raw rationale list here (marked WIP), not a dict.
    ref_cap: Optional[Dict]
    blip_cap: Optional[str]
    # Not produced by the converters here; presumably filled in by later
    # retrieval / generation stages -- TODO confirm.
    docs: Optional[List[Doc]]
    responses: Optional[List[Response]]
    top_response: Optional[Response]
    misc: Optional[Dict]  # dataset-specific extras, e.g. OK-VQA confidence, ScienceQA hint/solution

In [17]:
import pandas as pd
import os
import json
from collections import Counter

In [18]:
from rampa.passage_retriever import ContrieverRetriever

# Precomputed Contriever (MS MARCO) passage embeddings; the trailing * is a glob
# pattern, presumably matching the embedding shard files.
passages_embeddings = "data/contriever_msmarco/wikipedia_embeddings/passages_*"
# Passage collection (TSV). Constructing the retriever loads it eagerly, which is
# slow -- the saved run of this cell was interrupted while loading passages.
passages = "data/contriever_msmarco/psgs_w100.tsv"
retriever = ContrieverRetriever(passages=passages)

Loading the passages...


461148it [00:02, 174624.36it/s]


KeyboardInterrupt: 

In [4]:
# Load the Contriever encoder weights and the FAISS index over the passage
# embeddings (the saved log reports a flat index of ~21M vectors).
retriever.prepare_model(
    model_path="models/contriever/",
    passages_embeddings=passages_embeddings,
)

Loading index from data/contriever_msmarco/wikipedia_embeddings/index.faiss, meta data from data/contriever_msmarco/wikipedia_embeddings/index_meta.faiss
Loaded index of type %s and size %d <class 'faiss.swigfaiss_avx2.IndexFlat'> 21015324


Some weights of the model checkpoint at models/contriever/ were not used when initializing Contriever: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
# Smoke test: one query repeated 128 times, encoded with a per-GPU batch size of 64.
doc_ids = retriever.retrieve_doc_ids(
    128 * ['What is LLaMA?'],
    per_gpu_batch_size=64,
)

Questions embeddings shape: torch.Size([128, 768])


100%|██████████| 1/1 [01:06<00:00, 66.88s/it]


In [15]:
# Retrieval result for the first query: a (passage-id list, score array) pair.
doc_ids[0]

(['248462',
  '15467386',
  '248463',
  '14387423',
  '248472',
  '6419735',
  '15467387',
  '20273272'],
 array([1.1628287 , 1.0403614 , 1.0338705 , 1.0075382 , 0.9987637 ,
        0.99689585, 0.99152833, 0.98600364], dtype=float32))

In [128]:
# Load one raw A-OKVQA split for ad-hoc inspection.
# NOTE(review): absolute cluster paths; consider a configurable DATA_DIR.
in_path = "/mmfs1/gscratch/ark/chan0369/rampa-project/data"
out_path = "/mmfs1/gscratch/ark/chan0369/rampa-project/data_prep/iter0"
dataset, split = 'aokvqa', 'train'
in_data_path = os.path.join(in_path, dataset)
in_file = os.path.join(in_data_path, f"aokvqa_v1p0_{split}.json")
df = pd.read_json(in_file)

In [19]:
from collections import Counter

# Convert each raw VQA dataset into the shared DataQA record format and write one
# JSON list per split to <out_path>/<dataset>/<split>.json.
# NOTE(review): absolute cluster paths; consider a configurable DATA_DIR.
in_path = "/mmfs1/gscratch/ark/chan0369/rampa-project/data"
out_path = "/mmfs1/gscratch/ark/chan0369/rampa-project/data_prep/iter0"
datasets = ['aokvqa', 's3vqa', 'okvqa', 'scienceqa']
splits = ['train', 'val']

# ScienceQA prompt helpers, defined once instead of being re-created inside the
# loop for every (dataset, split) pair.
choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z

def format_options(options, choice_prefixes):
    """Render options as '(A) opt0 (B) opt1 ...' (zip stops at the shorter input)."""
    return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])

def format_prompt(r, choice_prefixes):
    """Return the question followed by an 'Options:' line; the hint is not included."""
    options = format_options(r['choices'], choice_prefixes)
    # Alternative that prepends the hint as context:
    # context = f"Context: {r['hint']}\n" if r['hint'].strip() != "" else ""
    # return f'''{context}Question: {r["question"]}\nOptions:{options}'''
    return f'''{r["question"]}\nOptions: {options}'''

def format_label(r, choice_prefixes):
    """Return (letter_answer, direct_answer) for the correct option index r['answer']."""
    return choice_prefixes[r['answer']], r['choices'][r['answer']]

for dataset in datasets:
    out_data_path = os.path.join(out_path, dataset)
    # makedirs (unlike mkdir) also creates out_path when it is missing and is a
    # no-op on re-runs.
    os.makedirs(out_data_path, exist_ok=True)
    for split in splits:
        out = []
        if dataset == 'aokvqa':
            in_file = os.path.join(in_path, dataset, f"aokvqa_v1p0_{split}.json")
            df = pd.read_json(in_file)
            # Drop questions flagged difficult_direct_answer. Every kept row has the
            # flag False, so it is not stored in 'misc'.
            df = df[~df['difficult_direct_answer']].reset_index(drop=True)
            for _, row in df.iterrows():
                line: DataQA = {
                    'question': {
                        'qid': row['question_id'],
                        'text': row['question'],
                        'image_id': row['image_id'],
                        'image_file': f"coco/{split}2017/{str(row['image_id']).zfill(12)}.jpg",
                    },
                    'ref_ans': {
                        'direct_answers': row['direct_answers'],
                        'text': row['choices'][row['correct_choice_idx']],
                        'rationales': row['rationales']
                    },
                    'ref_cap': row['rationales'], # WIP: subject to change
                    'split': split,
                    'dataset': dataset
                }
                out.append(line)
        elif dataset == 's3vqa':
            # S3-VQA names its validation split "dev".
            s3vqa_split = split if split != 'val' else 'dev'
            df_questions = pd.read_json(os.path.join(in_path, dataset, f"S3-VQA_{s3vqa_split}_questions.json"))
            df_annotations = pd.read_json(os.path.join(in_path, dataset, f"S3-VQA_{s3vqa_split}_annotations.json"))
            # Implicit join on all shared columns (presumably question_id / image_id).
            df = df_questions.merge(df_annotations, how='inner')
            for _, row in df.iterrows():
                line: DataQA = {
                    'question': {
                        'qid': row['question_id'],
                        'text': row['question'],
                        'image_id': row['image_id'],
                        'image_file': f"openimages/s3vqa/{row['image_id']}.jpg",
                    },
                    'ref_ans': {
                        'text': row['answer']['raw'],
                        'stem': row['answer']['answer'],
                        'hyponym': row['hyponym'],
                        'hypernym': row['hypernym']
                    },
                    'split': split,
                    'dataset': dataset
                }
                out.append(line)
        elif dataset == 'okvqa':
            with open(os.path.join(in_path, dataset, f"OpenEnded_mscoco_{split}2014_questions.json"), encoding='utf-8') as f:
                df_questions = pd.DataFrame(json.load(f)['questions'])
            with open(os.path.join(in_path, dataset, f"mscoco_{split}2014_annotations.json"), encoding='utf-8') as f:
                df_annotations = pd.DataFrame(json.load(f)['annotations'])
            df = df_annotations.merge(df_questions, how='inner')
            df = df[df['confidence'] == 5].reset_index(drop=True)
            for _, row in df.iterrows():
                line: DataQA = {
                    'question': {
                        'qid': row['question_id'],
                        'text': row['question'],
                        'image_id': row['image_id'],
                        # NOTE(review): OK-VQA uses the COCO 2014 splits, and COCO 2017
                        # re-partitions those images (most val2014 images fall under
                        # train2017), so {split}2017 may not contain this image --
                        # verify these paths exist.
                        'image_file': f"coco/{split}2017/{str(row['image_id']).zfill(12)}.jpg",
                    },
                    'ref_ans': {
                        # Most frequent annotator answer: raw text and processed form.
                        'text': Counter([a['raw_answer'] for a in row['answers']]).most_common(1)[0][0],
                        'stem': Counter([a['answer'] for a in row['answers']]).most_common(1)[0][0]
                    },
                    'misc': {
                        'confidence': row['confidence'],
                        'question_type': row['question_type'],
                        'answer_type': row['answer_type'],
                    },
                    'split': split,
                    'dataset': dataset
                }
                out.append(line)
        elif dataset == 'scienceqa':
            with open(os.path.join(in_path, dataset, "problems.json"), encoding='utf-8') as f:
                df = pd.DataFrame(json.load(f)).T
                # Do not reset the index: it holds the problem id, used below as
                # qid and in image_file.
                df = df[df['split']==split]
            for i, row in df.iterrows():
                question_with_choices = format_prompt(row, choice_prefixes=choice_prefixes)
                letter_answer, direct_answer = format_label(row, choice_prefixes=choice_prefixes)
                line: DataQA = {
                    'question': {
                        'qid': i,
                        'without_choices': row['question'],
                        'choices': row['choices'],
                        'text': question_with_choices
                    },
                    'ref_ans': {
                        'index': row['answer'],
                        'letter_answer': letter_answer,
                        'direct_answer': direct_answer,
                        'text': f"{letter_answer} {direct_answer}"
                    },
                    'misc': {
                        'hint': row['hint'],
                        'task': row['task'],
                        'solution': row['solution'],
                    },
                    'split': split,
                    'dataset': dataset
                }
                if row['image'] is not None:
                    # NOTE(review): this is the per-problem directory, not the image
                    # file (row['image'] presumably holds the file name), whereas the
                    # other datasets store a full file path -- confirm what the
                    # loader expects.
                    line['question']['image_file'] = f"scienceqa/{split}/{i}"
                out.append(line)
        else:
            raise ValueError("dataset name not featured")
        out_file = os.path.join(out_data_path, f"{split}.json")
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump(out, f)
    print(f"{dataset} conversion finished!")

aokvqa conversion finished!
s3vqa conversion finished!
okvqa conversion finished!
scienceqa conversion finished!


In [174]:
# Raw ScienceQA problems, transposed so each row is one problem and each column one field.
sqa = pd.read_json(
    "/mmfs1/gscratch/ark/chan0369/rampa-project/data/scienceqa/problems.json"
).transpose()

In [178]:
# `sqa` is already problems-by-fields (it was transposed when loaded), so index
# it directly. Re-transposing here turned fields back into rows, which made
# `sqa['image']` and every later column access fail.
sqa.loc[4, 'image']

In [164]:
# Number of ScienceQA problems per split.
sqa['split'].value_counts()

split
train    12726
test      4241
val       4241
Name: count, dtype: int64

In [155]:
# Letter labels for multiple-choice options: ['A', 'B', ..., 'Z'].
choice_prefixes = [chr(code) for code in range(ord('A'), ord('Z') + 1)]


def format_options(options, choice_prefixes):
    """Render options as '(A) first (B) second ...'.

    Options beyond the available prefixes are dropped, because zip stops at
    the shorter input.
    """
    labelled = (f'({letter}) {option}' for letter, option in zip(choice_prefixes, options))
    return ' '.join(labelled)


def format_prompt(r, choice_prefixes):
    """Return the question, a newline, then 'Options: (A) ... (B) ...'.

    Any hint/context on the problem is not included.
    """
    question = r["question"]
    rendered = format_options(r['choices'], choice_prefixes)
    return f"{question}\nOptions: {rendered}"


def format_label(r, choice_prefixes):
    """Return (letter_answer, direct_answer) for the correct option index r['answer']."""
    idx = r['answer']
    return choice_prefixes[idx], r['choices'][idx]
# Sanity-check format_label on an explicit ScienceQA problem instead of on a
# `row` variable leaked from a loop in another cell (hidden state).
format_label(sqa.iloc[0], choice_prefixes=choice_prefixes)

('A', 'West Virginia')

In [159]:
# Distribution of the correct-option index (the saved output shows indices 0-4, i.e. letters A-E).
sqa['answer'].value_counts()

answer
1    8542
0    8399
2    2961
3    1275
4      31
Name: count, dtype: int64

In [78]:
# Five random ScienceQA problems for eyeballing question wording. `random_samples`
# was previously never defined in this notebook; the fixed seed makes the sample
# reproducible (the saved output predates this and will differ).
random_samples = sqa.sample(n=5, random_state=0)
random_samples['question']

219      Which bird's beak is also adapted to get necta...
20021            Which wax candle has more thermal energy?
5912     Compare the average kinetic energies of the pa...
17250                        Which is a complete sentence?
10129    Compare the motion of two blue jays. Which blu...
Name: question, dtype: object

In [13]:
# Prototype of a single A-OKVQA record in the DataQA format.
sample_path = os.path.join(in_path, 'aokvqa', 'aokvqa_v1p0_train.json')
# Named aokvqa_df rather than `dataset`, so the string dataset name used by
# the other cells is not shadowed by a DataFrame.
aokvqa_df = pd.read_json(sample_path)
split = 'train'
row = aokvqa_df.loc[0]
line: DataQA = {
    'question': {
        'text': row['question'],
        'qid': row['question_id'],
        'image_id': row['image_id'],
        'image_file': f"coco/{split}2017/{str(row['image_id']).zfill(12)}.jpg"
    },
    'ref_cap': {
        'rationales': row['rationales']
    },
    'ref_ans': {
        'direct_answers': row['direct_answers']
    },
    # Required DataQA keys, matching what the conversion loop writes.
    'split': split,
    'dataset': 'aokvqa',
    'misc': {
        'difficult_direct_answer': row['difficult_direct_answer']
    }
}

In [134]:
# Display the most recently built `line` record.
# NOTE(review): the saved output comes from an older run (an OK-VQA record with
# 'raw'/'answer' keys) and does not match the record built in the cell above.
line

{'question': {'qid': 3281015,
  'text': 'What are the people riding?',
  'image_id': 328101,
  'image_file': 'coco/val2017/000000328101.jpg'},
 'ref_ans': {'raw': 'elephants', 'answer': 'elephant'},
 'misc': {'confidence': 5, 'question_type': 'eight', 'answer_type': 'other'}}