In [21]:
import torch
import jsonlines
import nltk
from nltk.corpus import stopwords

from typing import *
import re

# Stopword lookup set used by `preprocess` when removing stopwords.
# NOTE(review): `stopwords.words()` is called without a language/fileid, which in
# NLTK concatenates the lists of every language in the corpus (so English content
# words like "war", "man", "hat" that are stopwords elsewhere get removed too).
# Confirm this is intended; `stopwords.words("english")` restricts it to English.
set_stopwords = set(stopwords.words())

In [11]:
def custom_tokenizer(sentence: str, marker: str) -> Tuple[List[str], int]:
    """Whitespace-tokenize ``sentence`` and locate the token carrying ``marker``.

    The marker is removed from the token it was attached to, so the returned
    tokens contain only the original words.

    Args:
        sentence: Text to split on whitespace.
        marker: Non-empty sentinel string inserted next to the target word.

    Returns:
        A ``(tokens, target_position)`` pair, where ``target_position`` is the
        index of the first token that contained ``marker``.

    Raises:
        ValueError: if ``marker`` is empty or does not occur in any token.
    """
    if not marker:
        # An empty marker is "in" every token and would silently pick one.
        raise ValueError("marker must be a non-empty string")

    tokens = sentence.split()
    target_position = None
    for i, tk in enumerate(tokens):
        if marker in tk:
            # Remove the marker wherever it sits in the token (not only as a
            # prefix), so a marker inserted mid-word does not eat letters.
            tokens[i] = tk.replace(marker, "")
            if target_position is None:
                target_position = i

    if target_position is None:
        raise ValueError(f"marker {marker!r} not found in sentence")
    return tokens, target_position


def preprocess(
    sentence: str,
    target_word=None,
    remove_stopwords=True,
    remove_digits=True,
    stop_words: Optional[Set[str]] = None,
) -> str:
    """Normalize ``sentence`` into a lowercase, single-space-separated string.

    Steps: lowercase; replace every character that is neither a word character
    nor whitespace with a space; optionally delete digits; collapse runs of
    spaces and strip the ends; optionally drop stopwords.

    Args:
        sentence: Raw input text.
        target_word: Word to keep even if it is a stopword. It is lowercased
            before comparison so it can match the lowercased tokens.
        remove_stopwords: Drop tokens that appear in the stopword set.
        remove_digits: Delete every digit character (including inside words).
        stop_words: Stopword set to use. Defaults to the module-level
            ``set_stopwords`` when ``None``.

    Returns:
        The cleaned sentence.
    """
    sentence = sentence.lower()
    # Punctuation becomes a space so "word,word" splits into two tokens.
    sentence = re.sub(r"[^\w\s]", " ", sentence)
    if remove_digits:
        sentence = re.sub(r"\d", "", sentence)
    # Collapse repeated spaces; strip removes any leading/trailing whitespace.
    sentence = re.sub(" +", " ", sentence).strip()

    if remove_stopwords:
        if stop_words is None:
            stop_words = set_stopwords
        # The sentence is already lowercased, so the protected target must be
        # lowercased too; otherwise e.g. "Can" would never protect "can".
        keep = target_word.lower() if target_word is not None else None
        sentence = " ".join(
            word for word in sentence.split() if word not in stop_words or word == keep
        )

    return sentence


def get_neighbourhood(
    tokens: List[str], target_position: int, width: int = 2
) -> Tuple[List[str], int]:
    """Return the tokens within ``width`` positions of ``target_position``.

    The window ``[target_position - width, target_position + width]`` is
    clipped to the valid indices of ``tokens``.

    Returns:
        ``(window_tokens, index)``. ``index`` is ``width`` minus the number of
        window positions that fell before the start of ``tokens``, i.e. where
        ``target_position`` lands inside ``window_tokens``.
    """
    window = range(target_position - width, target_position + width + 1)
    in_bounds = [tokens[p] for p in window if 0 <= p < len(tokens)]
    clipped_left = sum(1 for p in window if p < 0)
    return in_bounds, width - clipped_left


def tokens2indices(
    word_index, tokens: List[str], default: Optional[int] = None
) -> torch.Tensor:
    """Map ``tokens`` to a 1-D ``torch.long`` tensor of vocabulary indices.

    Args:
        word_index: Mapping from token to integer index.
        tokens: Tokens to look up.
        default: Index used for tokens missing from ``word_index`` (e.g. an
            unknown-word id). When ``None`` (the default), a missing token
            raises ``KeyError``.

    Returns:
        A tensor of shape ``(len(tokens),)`` with dtype ``torch.long``.
    """
    if default is None:
        ids = [word_index[word] for word in tokens]
    else:
        ids = [word_index.get(word, default) for word in tokens]
    return torch.tensor(ids, dtype=torch.long)


def compute_pos_tag_indexes(
    tokens: List[str], pos_index: Optional[Dict[str, int]] = None
) -> torch.Tensor:
    """POS-tag ``tokens`` with ``nltk.pos_tag`` and map each tag to an index.

    Args:
        tokens: Already-tokenized words.
        pos_index: Mapping from POS tag string to integer index. Defaults to the
            global ``pos_indexes``. NOTE(review): that global is not defined in
            the visible notebook; it must exist before calling with the default.

    Returns:
        1-D ``torch.long`` tensor with one index per token.

    Raises:
        KeyError: if the tagger emits a tag missing from ``pos_index``.
    """
    if pos_index is None:
        pos_index = pos_indexes
    tags = [tag for _, tag in nltk.pos_tag(tokens)]
    # Explicit dtype: without it, an empty tag list would give a float32 tensor.
    return torch.tensor([pos_index[tag] for tag in tags], dtype=torch.long)

In [12]:
def read_data(dataset_path):
    """Load a JSON Lines file and return its records as a list, in file order."""
    with jsonlines.open(dataset_path, "r") as reader:
        return list(reader.iter())

In [13]:
# Training split in JSON Lines format; the path is relative to the notebook's working directory.
train_path = "../../data/train.jsonl"

In [14]:
# Parsed training records, one per line of the JSON Lines file.
data = read_data(dataset_path=train_path)

In [26]:
from collections import Counter
import string
import random


counter = Counter()
# Random lowercase sentinel. Lowercase letters survive `preprocess` (lowercasing,
# punctuation and digit removal), so the target token can be found after cleaning.
marker = "".join(random.choices(string.ascii_lowercase, k=20))


def _mark_and_tokenize(
    sentence: str, start: int, end: int, marker: str
) -> Tuple[List[str], int]:
    """Prefix the target span ``sentence[start:end]`` with ``marker``, clean, tokenize.

    Returns the cleaned tokens and the index of the target token among them.
    NOTE(review): if preprocessing deletes the whole target (e.g. a digits-only
    word with the default ``remove_digits=True``), the target token becomes "".
    """
    target_word = sentence[start:end]
    marked = sentence[:start] + marker + sentence[start:]
    cleaned = preprocess(marked, target_word)
    return custom_tokenizer(cleaned, marker)


# Count every cleaned token of both sentences of each training pair.
for line in data:
    tokens1, _ = _mark_and_tokenize(
        line["sentence1"], int(line["start1"]), int(line["end1"]), marker
    )
    tokens2, _ = _mark_and_tokenize(
        line["sentence2"], int(line["start2"]), int(line["end2"]), marker
    )
    counter.update(tokens1 + tokens2)

In [30]:
# Vocabulary size: number of distinct tokens collected in `counter`.
len(counter)

26119