In [14]:
from transformers import AutoTokenizer
from tqdm import tqdm
from typing import List, Dict

import json
import os
import matplotlib.pyplot as plt
import numpy as np

In [15]:
def get_context_lengths_cosmosqa(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['context']
        question = instance['question']
        answers = [instance['answer0'], instance['answer1'], instance['answer2'], instance['answer3']]
        longest_answer = sorted(answers, key=lambda e: len(e))[-1]

        context = " </s> </s>".join([context, question, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [20]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
data_path = '../irt_data/cosmosqa'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_cosmosqa(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 25262/25262 [00:12<00:00, 2105.04it/s]
100%|██████████| 1492/1492 [00:00<00:00, 2161.15it/s]
100%|██████████| 1493/1493 [00:00<00:00, 2274.73it/s]

train 0.0 0.0
val 0.0 0.0
test 0.0 0.0





In [23]:
def get_context_lengths_boolq(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['passage']
        question = instance['question']

        context = " </s> </s>".join([context, question])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [24]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
data_path = '../irt_data/boolq'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_boolq(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 9427/9427 [00:06<00:00, 1347.04it/s]
100%|██████████| 1635/1635 [00:00<00:00, 1755.38it/s]
100%|██████████| 1635/1635 [00:00<00:00, 1761.49it/s]

train 5.66458046037976 0.1803330858173332
val 4.8318042813455655 0.3669724770642202
test 5.321100917431193 0.24464831804281345





In [26]:
def get_context_lengths_hellaswag(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['ctx']
        question = instance['ctx_b']
        longest_answer = sorted(instance['endings'], key=lambda e: len(e))[-1]

        context = " </s> </s>".join([context, question, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [27]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
data_path = '../irt_data/hellaswag'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_hellaswag(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 39905/39905 [00:17<00:00, 2217.04it/s]
100%|██████████| 5021/5021 [00:02<00:00, 2388.43it/s]
100%|██████████| 5021/5021 [00:02<00:00, 2400.63it/s]

train 0.0 0.0
val 0.0 0.0
test 0.0 0.0





In [28]:
def get_context_lengths_mutual(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['article']
        longest_answer = sorted(instance['options'], key=lambda e: len(e))[-1]

        context = " </s> </s>".join([context, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [30]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
data_path = '../irt_data/mutual/data/mutual'
histograms = []
num_examples = []
for phase in ["train", "dev", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_mutual(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "dev", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 7088/7088 [00:03<00:00, 2248.37it/s]
100%|██████████| 443/443 [00:00<00:00, 2458.35it/s]
100%|██████████| 443/443 [00:00<00:00, 2413.17it/s]

train 5.431715575620768 0.0
dev 3.386004514672686 0.0
test 6.772009029345372 0.0





In [31]:
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
data_path = '../irt_data/mutual/data/mutual_plus'
histograms = []
num_examples = []
for phase in ["train", "dev", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_mutual(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "dev", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 7088/7088 [00:03<00:00, 2066.18it/s]
100%|██████████| 443/443 [00:00<00:00, 1425.89it/s]
100%|██████████| 443/443 [00:00<00:00, 1955.44it/s]

train 4.500564334085778 0.0
dev 3.8374717832957113 0.0
test 4.063205417607223 0.0





In [54]:
def get_context_lengths_mcscript(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['passage']['text']
        questions = instance['passage']['questions']
        
        for q in questions:
            qas = []
            for a in q['answers']:
                qa = q['question'] + ' ' + a['text']
                qas.append(qa)
            
            longest_answer = sorted(qa, key=lambda e: len(e))[-1]

            context = " </s> </s>".join([context, longest_answer])
            tokens = context.split(" ")
            if tokenizer:
                tokens = tokenizer.tokenize(context)
            length = len(tokens)
            context_lengths.append(length)
    return context_lengths

In [55]:
data_path = '../irt_data/mcscript_2.0'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_mcscript(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 2500/2500 [00:08<00:00, 291.46it/s]
100%|██████████| 355/355 [00:01<00:00, 293.76it/s]
100%|██████████| 632/632 [00:02<00:00, 282.81it/s]

train 5.862870833626947 0.10570079627933197
val 4.603960396039604 0.0
test 8.9196675900277 0.08310249307479224





In [46]:
def get_context_lengths_quail(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            raw_data.append(json.loads(line))

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['context']
        question = instance['question']
        longest_answer = sorted(instance['answers'], key = lambda e: len(e))[-1]

        context = " </s> </s>".join([context, question, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [47]:
data_path = '../irt_data/quail'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_quail(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 10246/10246 [00:13<00:00, 763.72it/s]
100%|██████████| 2164/2164 [00:02<00:00, 776.40it/s]
100%|██████████| 555/555 [00:00<00:00, 803.00it/s]

train 95.93987897716183 4.060121022838181
val 95.00924214417745 4.990757855822551
test 89.90990990990991 10.09009009009009





In [52]:
def get_context_lengths_mrqa_nq(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            items = json.loads(line)
            if "header" in items:
                continue
            raw_data.append(items)

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['context']
        question = instance['qas'][0]['question']
        answers = instance['qas'][0]['answers']
        longest_answer = sorted(answers, key = lambda e: len(e))[-1]

        context = " </s> </s>".join([context, question, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [53]:
data_path = '../irt_data/mrqa_natural_questions'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_mrqa_nq(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 104071/104071 [01:32<00:00, 1130.11it/s]
100%|██████████| 6418/6418 [00:05<00:00, 1198.76it/s]
100%|██████████| 6418/6418 [00:05<00:00, 1273.16it/s]


train 10.15364510766688 10.608142518088613
val 9.909629167965099 12.29354939233406
test 10.003116235587411 12.480523527578685


In [56]:
def get_context_lengths_newsqa(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            items = json.loads(line)
            if "header" in items:
                continue
            raw_data.append(items)

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['text']
        qas = instance['qas']
        for qa in qas:
            question = qa['question']
            answer = context[qa['answer']['s']:qa['answer']['e']+1]

            context = " </s> </s>".join([context, question, answer])
            tokens = context.split(" ")
            if tokenizer:
                tokens = tokenizer.tokenize(context)
            length = len(tokens)
            context_lengths.append(length)
    return context_lengths

In [57]:
data_path = '../irt_data/newsqa'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_newsqa(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 11469/11469 [03:19<00:00, 57.37it/s]
100%|██████████| 638/638 [00:10<00:00, 59.27it/s]
100%|██████████| 637/637 [00:10<00:00, 60.16it/s]

train 15.84604534531397 83.28022150245533
val 16.624453142988717 82.5926778724384
test 16.538551129746097 82.64616818075937





In [58]:
num_examples

[76568, 4343, 4293]

In [68]:
def get_context_lengths_quoref(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        for line in f:
            items = json.loads(line)
            if "header" in items:
                continue
            raw_data.append(items)

    context_lengths = []

    for instance in tqdm(raw_data):
        context = instance['context']
        question = instance['question']
        answers = []
        for a in instance['answers']['text']:
            answers.append(a)
        
        longest_answer = sorted(answers, key=lambda e: len(e))[-1]
        context = " </s> </s>".join([context, question, longest_answer])
        tokens = context.split(" ")
        if tokenizer:
            tokens = tokenizer.tokenize(context)
        length = len(tokens)
        context_lengths.append(length)
    return context_lengths

In [69]:
data_path = '../irt_data/quoref'
histograms = []
num_examples = []
for phase in ["train", "val", "test"]:
    path = os.path.join(data_path, phase + ".jsonl")
    context_lengths = get_context_lengths_quoref(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "val", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 19399/19399 [00:25<00:00, 746.91it/s]
100%|██████████| 1209/1209 [00:01<00:00, 760.14it/s]
100%|██████████| 1209/1209 [00:01<00:00, 806.09it/s]

train 62.91045930202588 35.23377493685241
val 65.34325889164599 33.41604631927213
test 66.3358147229115 32.09263854425145





In [82]:
data_path = '../irt_data/squad_v2/dev-v2.0.json'
f = open(data_path)
json_data = json.load(f)

In [85]:
data = json_data['data']

In [90]:
data[0]['paragraphs'][0]

{'qas': [{'question': 'In what country is Normandy located?',
   'id': '56ddde6b9a695914005b9628',
   'answers': [{'text': 'France', 'answer_start': 159},
    {'text': 'France', 'answer_start': 159},
    {'text': 'France', 'answer_start': 159},
    {'text': 'France', 'answer_start': 159}],
   'is_impossible': False},
  {'question': 'When were the Normans in Normandy?',
   'id': '56ddde6b9a695914005b9629',
   'answers': [{'text': '10th and 11th centuries', 'answer_start': 94},
    {'text': 'in the 10th and 11th centuries', 'answer_start': 87},
    {'text': '10th and 11th centuries', 'answer_start': 94},
    {'text': '10th and 11th centuries', 'answer_start': 94}],
   'is_impossible': False},
  {'question': 'From which countries did the Norse originate?',
   'id': '56ddde6b9a695914005b962a',
   'answers': [{'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
    {'text': 'Denmark, Iceland and Norway', 'answer_start': 256},
    {'text': 'Denmark, Iceland and Norway', 'answer_star

In [107]:
def get_context_lengths_squad(raw_file_path, tokenizer = None):
    raw_data = []
    with open(raw_file_path, "r") as f:
        raw_data = json.load(f)
    
    context_lengths = []
    for data in tqdm(raw_data['data']):
        for instance in data['paragraphs']:
            context = instance['context']
            questions = instance['qas']
            for q in questions:
                question = q['question']
                if q['answers']:
                    answers = [x['text'] for x in q['answers']]
                else:
                    answers = [x['text'] for x in q['plausible_answers']]

                if answers:
                    longest_answer = sorted(answers, key=lambda e: len(e))[-1]
                else:
                    longest_answer = ""
                context = " </s> </s>".join([context, question, longest_answer])
                tokens = context.split(" ")
                if tokenizer:
                    tokens = tokenizer.tokenize(context)
                length = len(tokens)
                context_lengths.append(length)
    return context_lengths

In [108]:
data_path = '../irt_data/squad_v2'
histograms = []
num_examples = []
for phase in ["train", "dev", "test"]:
    path = os.path.join(data_path, phase + "-v2.0.json")
    context_lengths = get_context_lengths_squad(path, tokenizer)
    hist, _ = np.histogram(context_lengths, bins=[0, 129, 257, 513, 10000])
    histograms.append(hist)
    num_examples.append(len(context_lengths))

for phase, hist, total in zip(["train", "dev", "test"], histograms, num_examples):
    print(phase, hist[2] * 100 / total, hist[3] * 100 / total)

100%|██████████| 442/442 [02:05<00:00,  3.53it/s]
100%|██████████| 17/17 [00:06<00:00,  2.71it/s]
100%|██████████| 17/17 [00:06<00:00,  2.73it/s]

train 47.641556488309455 2.6557907902915154
dev 58.37885462555066 7.365638766519824
test 58.37885462555066 7.365638766519824





In [109]:
num_examples

[130319, 5675, 5675]