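## Note
Before running the code below, download the required data files from Time-Sensitive QA
(https://github.com/wenhuchen/Time-Sensitive-QA). The code reads its inputs from a local `dataset/` directory:
`dataset/relations.json`, `dataset/annotated_train.json`, and `dataset/annotated_dev.json`.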

In [2]:
import json
import random
import re
import copy
import gzip
from tqdm import tqdm

# Relation metadata keyed by relation type; later code reads
# relations[<type>]['template'] to get question templates.
# JSON is UTF-8 by specification, so don't rely on the platform default encoding.
with open('dataset/relations.json', 'r', encoding='utf-8') as f:
    relations = json.load(f)
# Validate with an explicit raise: an assert would be stripped under `python -O`.
if not isinstance(relations, dict):
    raise TypeError(
        f'dataset/relations.json must contain a JSON object, got {type(relations).__name__}'
    )

# Month number (1-12) -> three-letter abbreviation, used when rendering Time values.
mapping = dict(enumerate(
    ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
     'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'],
    start=1,
))

# Month name -> month number, used when parsing "Month YYYY" strings.
# Accepts both the abbreviations above and the full English names
# ('May' is shared by both forms).
imapping = {abbr: number for number, abbr in mapping.items()}
imapping.update(zip(
    ['January', 'February', 'March', 'April', 'May', 'June', 'July',
     'August', 'September', 'October', 'November', 'December'],
    range(1, 13),
))
class Time(object):
    """A coarse calendar point: a year plus an optional month and day.

    A value with ``month == 0`` and ``date == 0`` means "year only". The
    constructor folds 1 January, and any input whose month or day is 0, into
    that year-only form, and it clamps the year to at least 1. Ordering and
    equality are lexicographic on ``(year, month, date)``.
    """

    def __init__(self, time_str):
        # time_str looks like "YYYY-M-D"; every dash-separated field must be an int.
        parts = [int(piece) for piece in time_str.split('-')]

        self.year = max(parts[0], 1)
        self.month = parts[1]
        self.date = parts[2]

        # Jan 1 is treated as "year only", as is any missing month/day.
        collapse_to_year = (
            (self.month == 1 and self.date == 1)
            or self.month == 0
            or self.date == 0
        )
        if collapse_to_year:
            self.month = self.date = 0

        assert self.year > 0

    def _key(self):
        """Tuple used for ordering and equality."""
        return (self.year, self.month, self.date)

    def __gt__(self, other):
        assert isinstance(other, Time)
        return self._key() > other._key()

    def __eq__(self, other):
        assert isinstance(other, Time), other
        return self._key() == other._key()

    def __lt__(self, other):
        assert isinstance(other, Time)
        return self._key() < other._key()

    def __repr__(self):
        # "Mar 2004" when a month is known, otherwise just "2004".
        if self.month != 0:
            return '{} {}'.format(mapping[self.month], str(self.year))
        return '{}'.format(self.year)

    def __str__(self):
        return repr(self)

    @classmethod
    def _from_parts(cls, year, month, date):
        """Build a Time via the normal "Y-M-D" constructor path."""
        return cls('{}-{}-{}'.format(year, month, date))

    @classmethod
    def parse(cls, time):
        """Parse "YYYY" (year only) or "<Month> YYYY" (month name, any case)."""
        assert isinstance(time, str)
        if ' ' in time:
            month_name, year = time.split(' ')
            month_number = imapping[month_name.lower().capitalize()]
            return cls(f'{year}-{month_number}-1')
        return cls(f'{time}-0-0')

    @classmethod
    def minus_one_year(cls, time):
        return cls._from_parts(time.year - 1, time.month, time.date)

    @classmethod
    def minus_k_year(cls, time, k):
        # The shifted year never goes below 2.
        return cls._from_parts(max(time.year - k, 2), time.month, time.date)

    @classmethod
    def add_one_year(cls, time):
        return cls._from_parts(time.year + 1, time.month, time.date)

    @classmethod
    def add_k_year(cls, time, k):
        return cls._from_parts(time.year + k, time.month, time.date)

    @classmethod
    def add_one_month(cls, time):
        """Return a copy of ``time`` advanced by one month; the input is untouched.

        A year-only value (month 0) becomes January of the same year. The copy
        does not go through the constructor, so it is not re-normalized.
        """
        bumped = copy.deepcopy(time)
        if bumped.month >= 12:
            bumped.month = 1
            bumped.year += 1
        else:
            bumped.month += 1
        return bumped

def random_pop(time_range):
    """Pick a random time inside ``time_range`` at one-month steps.

    Candidates are obtained by repeatedly adding one month to
    ``time_range[0]`` (the start itself is excluded) for as long as the result
    does not exceed ``time_range[1]``. If no candidate exists (the range spans
    less than a month, or is reversed), one of the two endpoints is chosen at
    random instead.
    """
    start, stop = time_range[0], time_range[1]
    pool = []
    probe = Time.add_one_month(start)
    # Time ordering is total, so "not >" is the same as "< or ==".
    while not probe > stop:
        pool.append(probe)
        probe = Time.add_one_month(probe)
    return random.choice(pool if pool else time_range)

def too_close(time1, time2):
    """Return True when ``time2`` is at most two months after ``time1``.

    Only year and month are compared (a year-only value has month 0). A
    ``time2`` that precedes ``time1`` gives a negative gap and also counts as
    too close.
    """
    months_between = 12 * (time2.year - time1.year) + (time2.month - time1.month)
    return months_between <= 2
    
def prop(time, first_last=None, difficulty='easy'):
    """Render a natural-language time specifier such as "in 2004".

    ``time`` is either a single Time, giving "in <time>", or a 2-element
    tuple/list ``[start, end]``. For an interval:

    * if ``end`` is at most two months after ``start`` (see ``too_close``),
      the result is "in <start>";
    * ``difficulty='easy'`` gives "from <start> to <end>";
    * ``difficulty='hard'`` picks randomly among "in <month inside>" (or an
      "early/late <decade>s" phrase when the interval crosses a decade) and
      "between <x> and <y>" for a random sub-interval. ``first_last='first'``
      adds "before <x>", and ``'last'`` adds "after <x>".

    Raises ValueError for an unknown ``difficulty`` or ``first_last`` value.
    """
    if not isinstance(time, (tuple, list)):
        return 'in {}'.format(str(time))

    assert len(time) == 2, time
    assert isinstance(time[0], Time) and isinstance(time[1], Time)
    start, end = time[0], time[1]

    if too_close(start, end):
        return 'in {}'.format(str(start))

    # The random draws below must stay in this order to reproduce the same
    # output for a given random state.
    if difficulty == 'easy':
        option = random.choice(['between'])
    elif difficulty == 'hard':
        if first_last == 'first':
            extra = ['before']
        elif first_last == 'last':
            extra = ['after']
        elif first_last is None:
            extra = []
        else:
            raise ValueError()
        option = random.choice(['in', 'between-subset'] + extra)
    else:
        raise ValueError()

    if option == 'between':
        return 'from {} to {}'.format(str(start), str(end))

    if option == 'in':
        phrases = ['in {}'.format(str(random_pop(time)))]
        start_decade, end_decade = start.year // 10, end.year // 10
        if end_decade > start_decade:
            # The interval covers at least the first years of end's decade ...
            if end.year % 10 >= 3:
                phrases.append('in early {}s'.format(end_decade * 10))
            # ... and/or the last years of start's decade.
            if start.year % 10 <= 7:
                phrases.append('in late {}s'.format(start_decade * 10))
        return random.choice(phrases)

    if option == 'between-subset':
        lower = random_pop(time)
        upper = random_pop((lower, end))
        return 'between {} and {}'.format(str(lower), str(upper))

    if option == 'before':
        return 'before {}'.format(str(random_pop(time)))

    if option == 'after':
        return 'after {}'.format(str(random_pop(time)))

    raise ValueError('Not Existing')

def link_2_name(string):
    """Convert a Wikipedia link such as ``/wiki/Barack_Obama`` to a display name.

    Removes every ``/wiki/`` segment and decodes percent-escapes
    (``%C3%A9`` -> ``é``, ``%27`` -> ``'``). Invalid escapes such as ``%_P``
    are left unchanged. Finally, underscores become spaces.
    """
    from urllib.parse import unquote  # local import keeps this helper self-contained

    string = string.replace('/wiki/', '')
    # Decode before replacing underscores so an escaped underscore (%5F) also
    # ends up as a space (MediaWiki titles treat '_' and ' ' as the same).
    string = unquote(string)
    return string.replace('_', ' ')

def get_neagetive_times(time):
    """Build two intervals that lie before and after the interval ``time``.

    Returns ``(earlier, later)``, where
    ``earlier = [start - k_year, start - 1 year]`` and
    ``later = [end + 1 year, end + k_year]``. ``k_year`` is drawn uniformly
    from ``1..max(10, end.year - start.year)``. Each shifted endpoint gets a
    random month (1-12), except that when its reference endpoint is year-only
    it instead takes that endpoint's month (0).
    """
    start, end = time[0], time[1]
    span = end.year - start.year
    k_year = random.randint(1, span if span >= 10 else 10)

    earlier = [Time.minus_k_year(start, k_year), Time.minus_one_year(start)]
    later = [Time.add_one_year(end), Time.add_k_year(end, k_year)]

    # NOTE(review): earlier[1] is derived from `start` but takes its
    # granularity from `end`, and later[0] is derived from `end` but follows
    # `start`. This mirrors endpoint positions (index 0 <- start,
    # index 1 <- end); confirm that is intended.
    # The order of this tuple fixes the order of random draws; keep it.
    for shifted, reference in ((earlier[0], start), (earlier[1], end),
                               (later[0], start), (later[1], end)):
        shifted.month = random.randint(1, 12) if reference.month != 0 else reference.month

    return earlier, later


def get_negative_specifier(time, first_last=None, difficulty='easy'):
    """Render a time specifier for an interval outside ``time``.

    The first question of an entry uses an interval before ``time``, and the
    last uses one after it. Any other position picks one side at random and
    always renders it with the 'easy' difficulty.
    """
    earlier, later = get_neagetive_times(time)
    if first_last == "first":
        return prop(earlier, first_last=first_last, difficulty=difficulty)
    if first_last == "last":
        return prop(later, first_last=first_last, difficulty=difficulty)
    side = random.choice([earlier, later])
    return prop(side, first_last=first_last, difficulty='easy')


In [None]:
# Build the Contriever fine-tuning file. Each annotated paragraph becomes a
# "question" record. Its positive_ctxs are time-qualified questions whose answer
# comes from that paragraph. Its negative_ctxs are the same questions with a
# time specifier moved outside the answer's time interval.
version = 'v3'
splits = ['train', 'dev']
difficulties = ['easy', 'hard']
question_num = 1  # passes over each entry's question list
dataset_dict = {}
for split in splits:
    with open(f'dataset/annotated_{split}.json', 'r', encoding='utf-8') as f:
        data = json.load(f)
    for difficulty in difficulties:
        for d in tqdm(data, desc=f'{split}-{difficulty}'):
            assert isinstance(d['type'], str)

            paragraphs = d['paras']
            assert isinstance(paragraphs, list)

            templates = relations[d['type']]['template']
            template = random.choice(templates)
            template = template.replace('$1', link_2_name(d['link']))
            last_index = len(d['questions']) - 1
            for _ in range(question_num):
                for i, entry in enumerate(d['questions']):
                    # Raw string: '\?' in a plain literal is an invalid escape sequence.
                    assert len(re.findall(r'\?$', template)) == 1, template
                    time_step = [Time.parse(entry[0][0]), Time.parse(entry[0][1])]
                    assert isinstance(entry[1], list), entry[1]
                    assert isinstance(entry[1][0], dict), entry[1]

                    # The first/last questions of an entry get open-ended specifiers.
                    if i == 0:
                        position = 'first'
                    elif i == last_index:
                        position = 'last'
                    else:
                        position = None
                    specifier = prop(time_step, position, difficulty)
                    negative_specifier = get_negative_specifier(time_step, position, difficulty)

                    if '$4' in template:
                        slot = '$4'
                    elif '$2' in template:
                        slot = '$2'
                    else:
                        # Raising a bare string is itself a TypeError in Python 3.
                        raise ValueError(f"It's not a template: {template!r}")
                    question = template.replace(slot, specifier)
                    negative_question = template.replace(slot, negative_specifier)

                    for text_dict in entry[1]:
                        if text_dict['answer'] == '':
                            continue
                        para = paragraphs[text_dict['para']]
                        # Keep only paragraphs that mention either endpoint's year.
                        if str(time_step[0].year) not in para and str(time_step[1].year) not in para:
                            continue
                        ctxs = dataset_dict.setdefault(para, {'positive_ctxs': [], 'negative_ctxs': []})
                        ctxs['positive_ctxs'].append({'text': question})
                        ctxs['negative_ctxs'].append({'text': negative_question})

dataset_list = [{"question": para, **ctxs} for para, ctxs in dataset_dict.items()]
# Use a context manager so the file is flushed and closed before the next
# cell reads it back.
with open(f'dataset/contriever_finetune_dataset_{version}.jsonl', 'w', encoding='utf-8') as out_file:
    for data_item in dataset_list:
        out_file.write(json.dumps(data_item) + '\n')


In [None]:
# Split the generated dataset into train / eval JSONL files.
dataset = []
with open(f'dataset/contriever_finetune_dataset_{version}.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():  # tolerate stray blank lines
            dataset.append(json.loads(line))
ratio = 0.98
total_nums = len(dataset)
print(total_nums)
train_nums = int(total_nums * ratio)
test_nums = total_nums - train_nums
# Sample positions rather than items. Membership is then an O(1) set lookup
# instead of an O(n) scan of dict comparisons per item. random.sample picks
# positions from the population length alone, so the same items are selected.
train_indices = random.sample(range(total_nums), train_nums)
train_items = [dataset[i] for i in train_indices]
print(len(train_items))
train_index_set = set(train_indices)
test_items = [item for i, item in enumerate(dataset) if i not in train_index_set]
with open(f'dataset/contriever_finetune_train_{version}.jsonl', 'w', encoding='utf-8') as f:
    for data_item in train_items:
        f.write(json.dumps(data_item) + '\n')
with open(f'dataset/contriever_finetune_eval_{version}.jsonl', 'w', encoding='utf-8') as f:
    for data_item in test_items:
        f.write(json.dumps(data_item) + '\n')

