## Imports

In [192]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from pprint import pprint
from tqdm import tqdm
from typing import  List, Tuple, Dict
from collections import Counter
from pathlib import Path


In [138]:
%%time

words = []
labels = []
sentences = []
with open ('../../data/train.tsv', 'r') as f:
    # strip lines, remove empty ones and the first
    lines = list(filter(None, map(str.strip, f.readlines())))[1:]
    for line in tqdm(lines, desc='Reading data', total=len(lines)):
        line: List[str] = line.split('\t')
        if line[0] == '#':
            sentences.append((words, labels))
            words = []
            labels = []
        else:
            words.append(line[0])
            labels.append(line[1])

print(f'{len(sentences)=}')
# pprint(sentences[0], compact=True)
# print(sentences[1])

Reading data: 100%|██████████| 254592/254592 [00:00<00:00, 1067203.39it/s]

len(sentences)=14534
CPU times: user 292 ms, sys: 2.59 ms, total: 295 ms
Wall time: 294 ms





In [295]:
class Vocabulary():
    """Implements a vocabulary of both words and labels. Automatically adds '<unk>' and '<pad>' word types.
    """

    def __init__(self, sentences: List[Tuple[List[str], List[str]]], threshold: int = 1):
        """Initialize the vocabulary from a dataset

        Args:
            sentences (List[Tuple[List[str], List[str]]]):
                The dataset as a list of tuples. 
                Each tuple contains two lists: the words of a sentence
                and the corresponding labels

            threshold (int, optional): 
                Number of appearances needed for a word to
                be inserted in the dictionary. Defaults to 1.
        """
        
        self.threshold: int = threshold
        self.counts: Counter = Counter()
        self.lcounts: Counter  = Counter()

        for sentence, labels in sentences:
            for word, label in zip(sentence, labels):
                self.counts[word] += 1
                self.lcounts[label] += 1

                if label == 'id':
                    print(f'{sentence=}')
                    print(f'{labels=}')
            
        # word vocabularies
        self.itos: List[str] = sorted(list(filter(lambda x: v.counts[x] >= threshold, self.counts.keys())) + ['<unk>', '<pad>'])
        self.stoi: Dict[str, int] = {s: i for i, s in enumerate(self.itos)}

        # label vocabularies
        self.ltos: List[str] = sorted(list(self.lcounts.keys()))
        self.stol: Dict[str, int] = {s: i for i, s in enumerate(self.ltos)}

        # unk and pad for ease of use
        self.unk: int = self.stoi['<unk>']
        self.pad: int = self.stoi['<pad>']


    def __contains__(self, word: str):
        return word in self.stoi


    def getWord(self, id: int) -> str:
        """Return the word at a given index

        Args:
            id (int): the index of a word

        Returns:
            str: the word corresponding to the given index
        """
        return self.itos[id]


    def getWordId(self, word: str) -> int:
        """Get the index of a given word

        Args:
            word (str): The word to retrieve the index of

        Returns:
            int: Index of the word if present, otherwise the index of '<unk>'
        """
        return self.stoi[word] if word in self.stoi else self.stoi['<unk>']


    def getLabel(self, id: int) -> str:
        """Get a label name from its index

        Args:
            id (int): the index of a label

        Returns:
            str: the correpsonding label name
        """
        return self.ltos[id]


    def getLabelId(self, label: str) -> int:
        """Get the id of a label

        Args:
            label (str): the name of the label

        Returns:
            int: the corresponding index
        """
        return self.stol[label]


    def __getitem__(self, idx: int or str) -> str or int:
        if isinstance(idx, str):
            return self.getWordId(idx)
        elif isinstance(idx, int):
            return self.getWord(idx)
        raise NotImplementedError()





class NerDataset(Dataset):

    def __init__(self, path: Path = Path('../../data/train.tsv'), vocab: Vocabulary = None, threshold: int = 2, window_size: int = 7, window_shift: int = None):
        """Build a Named Entity Recognition dataset from a .tsv file, which loads data as fixed-size windows

        Args:
            path (Path, optional): Path of the .tsv dataset file. Defaults to Path('../../data/train.tsv').
            vocab (Vocabulary, optional): Vocabulary to index the data. If none, build one. Defaults to None.
            threshold (int, optional): If vocab is None, threshold for the vocabulary. Defaults to 1.
            window_size (int, optional): Size of the windows. Defaults to 5.
            window_shift (int, optional): Shift of the windows. Defaults to None.
        """
        super().__init__()
        self.path: Path = path
        self.sentences: List[Tuple[List[str], List[str]]] = self.loadData(self.path)
        self.vocab: Vocabulary = vocab or Vocabulary(sentences, threshold=threshold)
        self.indexed_data: List[Tuple[List[int], List[int]]] = self.indexData()
        self.window_size: int = window_size
        self.window_shift: int = window_shift or window_size
        assert self.window_shift <= self.window_size and self.window_shift >= 0 and self.window_size > 0, \
            "Window shift must be equal or less than window size, both must be positive"
        self.windows: List[Tuple[torch.Tensor, torch.Tensor]] = self.build_windows()
        

    def loadData(self, path: Path):
        """Loads the dataset from file

        Args:
            path (Path): path of the .tsv dataset

        Returns:
            sentences (List[Tuple[List[str], List[str]]]):
                a list of sentences. Each sentences is a tuple made of:
                - list of words in the sentence
                - list of labels of the words
        """
        words = []
        labels = []
        sentences = []
        with open (path, 'r') as f:
            # strip lines, remove empty ones and the first
            lines = list(filter(None, map(str.strip, f.readlines())))[1:]
            for line in tqdm(lines, desc='Reading data', total=len(lines)):
                line: List[str] = line.split('\t')
                if line[0] == '#':
                    sentences.append((words, labels))
                    words = []
                    labels = []
                else:
                    words.append(line[0])
                    labels.append(line[1])
        return sentences


    def indexData(self) -> List[Tuple[List[int], List[int]]]:
        """Builds self.indexed_data transforming both words and labels in integers

        Args:
            vocab (Vocabulary): the vocabulary to use to convert words to indices
        """
        data = list(map(
            lambda sentence: (
                [self.vocab[w] for w in sentence[0]],
                [self.vocab.getLabelId(l) for l in sentence[1]]
            ),
            self.sentences
        ))
        return data


    def build_windows(self) -> List[Tuple[List[int], List[int]]]: 
        """Builds fixed-size windows from the indexed data

        Returns:
            List[Tuple[Tensor, Tensor]]: List of fixed-size windows
        """
        windows: List[Tuple[List[int], List[int]]] = []
        for word_ids, label_ids in self.indexed_data:
            start = 0
            while start < len(word_ids):
                # generate window
                word_window = word_ids[start: start+self.window_size]
                label_window = label_ids[start: start+self.window_size]
                # pad
                word_window += [self.vocab.pad] * (self.window_size - len(word_window))
                label_window += [self.vocab.getLabelId('O')] * (self.window_size - len(label_window))
                # append
                windows.append((torch.tensor(word_window), torch.tensor(label_window)))
                start += self.window_shift
        return windows


    def __len__(self):
        return len(self.windows)


    def __getitem__(self, idx):
        return self.windows[idx]



d = NerDataset(threshold=2, window_size=100, window_shift=0)
v = d.vocab


print('Examples of Vocabulary usage\n---------------')
# as an examples, 'is?' has count=1
print(f'{v.counts["is?"]=}')
print(f"'is?' in v? {'is?' in v}")

print(f"{v['is?']=}")
print(f"{v['<unk>']=}")
print(f'{v[605]=}')

# this is equal to 605 because the string "605" never figures in the dataset, so it is an unknown token with index 605
print(f'{v["605"]=}')
print(f'{v["60006665"]=}')

print(f'{v.unk=}')
print(f'{v.pad=}')
print('---------------')


Reading data: 100%|██████████| 254592/254592 [00:00<00:00, 2585703.91it/s]


Examples of Vocabulary usage
---------------
v.counts["is?"]=1
'is?' in v? False
v['is?']=605
v['<unk>']=605
v[605]='<unk>'
v["605"]=605
v["60006665"]=605
v.unk=605
v.pad=604
---------------


In [324]:
# Try the dataset and make sure everything is as expected
torch.manual_seed(777)

trainset = NerDataset()
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

windows, labels = next(iter(trainloader))
print(f'{windows[:5]=}')
print(f'{labels[:5]=}')

v = trainset.vocab
print(f'word 4792: {v[4792]} - label n.3: {v.getLabel(3)}')
print(f"word 6109: {v[6109]} - label n.9: {v.getLabel(9)}")
print(f"word 3284: {v[3284]} - label n.2: {v.getLabel(2)}")

Reading data: 100%|██████████| 254592/254592 [00:00<00:00, 2403201.71it/s]


windows[:5]=tensor([[ 4595,  4792,  6109,    14,   604,   604,   604],
        [  605,     2,    10,  1003,    78,  3830, 11274],
        [ 5110,  5676,   612,  6805,  7452, 10572,  3284],
        [ 5110,  7471,  1095,  5470,  6293,  8664, 11521],
        [   14,   604,   604,   604,   604,   604,   604]])
labels[:5]=tensor([[12,  3,  9, 12, 12, 12, 12],
        [12, 12, 12, 12, 12, 12, 12],
        [12, 12, 12, 12, 12, 12,  2],
        [12, 12, 12, 12,  1, 12, 12],
        [12, 12, 12, 12, 12, 12, 12]])
word 4792: glen - label n.3: B-LOC
word 6109: lake - label n.9: I-LOC
word 3284: democratic - label n.2: B-GRP


'blake'