In [2]:
from typing import Iterable
from datasets import Dataset, DatasetDict, Features, Sequence, Value

import sys
sys.path.append("..")
from common.token import Token
from common.sentence import Sentence
from common.parse_conllu import parse_conllu_incr


def convert_sentence_to_dict(sentence: Sentence) -> dict:
    """Flatten a parsed sentence into a plain ``column name -> value`` dict.

    Every output column is read from one attribute of ``sentence``. The names
    match one-to-one except for ``"deepslots"``, which is read from
    ``sentence.semslots``.
    """
    # (output column, attribute on the sentence object), in output order.
    column_to_attribute = (
        ("words", "words"),
        ("lemmas", "lemmas"),
        ("upos", "upos"),
        ("xpos", "xpos"),
        ("feats", "feats"),
        ("heads", "heads"),
        ("deprels", "deprels"),
        ("deps", "deps"),
        ("miscs", "miscs"),
        ("deepslots", "semslots"),
        ("semclasses", "semclasses"),
        ("metadata", "metadata"),
    )
    return {
        column: getattr(sentence, attribute)
        for column, attribute in column_to_attribute
    }


def convert_conllu_to_hf(file_path: str) -> Iterable[dict]:
    """Lazily yield one column dict per sentence of a CoNLL-U file.

    Args:
        file_path: Path of the CoNLL-U file to read.

    Yields:
        The dict produced by ``convert_sentence_to_dict`` for each sentence
        returned by ``parse_conllu_incr``, in file order.

    The file stays open only while the generator is being consumed.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding, which would mis-decode non-ASCII tokens on some systems.
    with open(file_path, "r", encoding="utf-8") as file:
        for sentence in parse_conllu_incr(file):
            yield convert_sentence_to_dict(sentence)


def build_raw_dataset(file_paths: dict[str, str]) -> DatasetDict:
    """Load every CoNLL-U split into a ``DatasetDict`` with a fixed raw schema.

    Args:
        file_paths: Mapping from split name to the path of its CoNLL-U file.

    Returns:
        A ``DatasetDict`` holding one generator-backed ``Dataset`` per split.
        All per-token columns are string sequences except ``heads`` (int32
        sequence); ``metadata`` is a single string per example.
    """
    # Assemble the schema in the same column order as the generator's dicts.
    schema = {
        column: Sequence(Value("string"))
        for column in ("words", "lemmas", "upos", "xpos", "feats")
    }
    schema["heads"] = Sequence(Value("int32"))
    schema.update(
        (column, Sequence(Value("string")))
        for column in ("deprels", "deps", "miscs", "deepslots", "semclasses")
    )
    schema["metadata"] = Value("string")
    features = Features(schema)

    return DatasetDict({
        split_name: Dataset.from_generator(
            convert_conllu_to_hf,
            gen_kwargs={"file_path": split_path},
            features=features,
        )
        for split_name, split_path in file_paths.items()
    })


def create_dataset():
    """Build the raw train / validation / test ``DatasetDict`` from the data directory."""
    return build_raw_dataset(dict(
        train="../data/train.conllu",
        validation="../data/validation.conllu",
        test="../data/test_clean.conllu",
    ))

In [3]:
import ast
import itertools
from typing import Iterable
from datasets import Dataset, DatasetDict, Features, Sequence, Value, ClassLabel, Array2D

import torch

from lemmatize_helper import predict_lemma_rule


def dict_from_str(s: str) -> dict:
    """Turn the Python-literal text of a dict (e.g. ``"{'Number': 'Sing'}"``) back into a dict.

    Parsing is delegated to ``ast.literal_eval``; whatever it returns for
    ``s`` is passed straight back to the caller.
    """
    parsed = ast.literal_eval(s)
    return parsed


def preprocess(batch: dict[str, list]) -> dict[str, list]:
    """Return lemma rules and joint pos-feats columns for a batch."""
    words = batch["words"]
    lemmas = batch["lemmas"]
    upos = batch["upos"]
    xpos = batch["xpos"]
    feats = batch["feats"]
    heads = batch["heads"]
    deprels = batch["deprels"]
    deps = batch["deps"]

    lemma_rules: list[str] = None
    if lemmas is not None:
        lemma_rules = [
            str(predict_lemma_rule(word if word is not None else '', lemma if lemma is not None else ''))
            for word, lemma in zip(words, lemmas, strict=True)
        ]

    joint_pos_feats: list[str] = None
    if upos is not None and xpos is not None and feats is not None:
        joint_feats = [
            '|'.join([f"{k}={v}" for k, v in dict_from_str(feat).items()]) if 0 < len(dict_from_str(feat)) else '_'
            for feat in feats
        ]
        joint_pos_feats = [
            f"{token_upos}#{token_xpos}#{token_joint_feats}"
            for token_upos, token_xpos, token_joint_feats in zip(upos, xpos, joint_feats, strict=True)
        ]
        
    sequence_length = len(words)
    
    deps_matrix_ud = None
    if heads is not None and deprels is not None:
        deps_matrix_ud = [[''] * sequence_length for _ in range(sequence_length)]
        for index, (head, relation) in enumerate(zip(heads, deprels, strict=True)):
            # Skip nulls.
            if head == -1:
                continue
            assert 0 <= head
            # Hack: start indexing at 0 and replace ROOT with self-loop.
            # It makes parser implementation much easier.
            if head == 0:
                # Replace ROOT with self-loop.
                head = index
            else:
                # If not ROOT, shift token left.
                head -= 1
                assert head != index, f"head = {head + 1} must not be equal to index = {index + 1}"
            deps_matrix_ud[index][head] = relation

    deps_matrix_eud = None
    if deps is not None:
        deps_matrix_eud = [[''] * sequence_length for _ in range(sequence_length)]
        for index, dep in enumerate(deps):
            dep = dict_from_str(dep) # Convert string representation of dict to a dict.
            assert 0 < len(dep), f"Deps must not be empty"
            for head, relation in dep.items():
                assert 0 <= head
                # Hack: start indexing at 0 and replace ROOT with self-loop.
                # It makes parser implementation much easier.
                if head == 0:
                    # Replace ROOT with self-loop.
                    head = index
                else:
                    # If not ROOT, shift token left.
                    head -= 1
                    assert head != index, f"head = {head + 1} must not be equal to index = {index + 1}"
                deps_matrix_eud[index][head] = relation

    return {
        "lemma_rules": lemma_rules,
        "joint_pos_feats": joint_pos_feats,
        "deps_ud": deps_matrix_ud,
        "deps_eud": deps_matrix_eud
    }


def update_schema_with_class_labels(dataset_dict: DatasetDict) -> Features:
    """Build a schema in which every categorical column is encoded as ``ClassLabel``.

    The label vocabulary of each categorical column is the sorted set of
    values seen in the ``'train'`` split only, on the premise that all labels
    must be present in training data. ``words`` and ``metadata`` stay strings.
    """

    def extract_unique_labels(dataset, column_name, is_matrix=False) -> list[str]:
        """Return the sorted distinct values of a column (flattening one extra level for matrices)."""
        rows = dataset[column_name]
        if is_matrix:
            # Each example is a matrix: unwrap it into its rows first.
            rows = itertools.chain.from_iterable(rows)
        # Sorting keeps the label -> id assignment stable between runs.
        return sorted(set(itertools.chain.from_iterable(rows)))

    train_split = dataset_dict['train']

    def label_sequence(column_name):
        """ClassLabel sequence feature for a per-token column."""
        return Sequence(ClassLabel(names=extract_unique_labels(train_split, column_name)))

    def label_matrix(column_name):
        """ClassLabel matrix feature for a token-by-token column."""
        labels = extract_unique_labels(train_split, column_name, is_matrix=True)
        return Sequence(Sequence(ClassLabel(names=labels)))

    return Features({
        "words": Sequence(Value("string")),
        "lemma_rules": label_sequence("lemma_rules"),
        "joint_pos_feats": label_sequence("joint_pos_feats"),
        "deps_ud": label_matrix("deps_ud"),
        "deps_eud": label_matrix("deps_eud"),
        "miscs": label_sequence("miscs"),
        "deepslots": label_sequence("deepslots"),
        "semclasses": label_sequence("semclasses"),
        "metadata": Value("string")
    })


# Build the raw train / validation / test splits from the CoNLL-U files.
dataset = create_dataset()
# Add the lemma_rules, joint_pos_feats, deps_ud and deps_eud columns and drop the raw
# columns they were derived from.
# NOTE(review): map() is called without batched=True, so preprocess presumably receives
# one example (sentence) per call — confirm against the datasets documentation.
dataset = dataset.map(preprocess, remove_columns=['lemmas', 'upos', 'xpos', 'feats', 'heads', 'deprels', 'deps'])

# Collect the label vocabularies from the train split, then cast the dataset to that schema.
class_features = update_schema_with_class_labels(dataset)
dataset = dataset.cast(class_features)

# Return examples as torch tensors, on the GPU when one is available and on the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = dataset.with_format("torch", device=device)

# Display the first training example as a sanity check.
dataset['train'][0]

Map:   0%|          | 0/6912 [00:00<?, ? examples/s]

Map:   0%|          | 0/1729 [00:00<?, ? examples/s]

Map:   0%|          | 0/1296 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/6912 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1729 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/1296 [00:00<?, ? examples/s]

{'words': ['The',
  'firm',
  "'s",
  'snowmobile',
  'division',
  'and',
  'defence',
  'services',
  'unit',
  'were',
  'also',
  'sold',
  'and',
  'Bombardier',
  'started',
  'the',
  'development',
  'of',
  'a',
  'new',
  'aircraft',
  'seating',
  '110',
  'to',
  '135',
  'passengers',
  '.'],
 'lemma_rules': tensor([  0,   0,   0,   0,   0,   0,   0,  13,   0, 143,   0,  74,   0,   0,
          32,   0,   0,   0,   0,   0,   0,  63,   0,   0,   0,  13,   0],
        device='cuda:0'),
 'joint_pos_feats': tensor([ 53,  79,  96,  79,  79,  50,  79,  78,  79,  34,  21, 192,  50, 152,
         179,  53,  79,  12,  54,   2,  79, 195,  85,  12,  85,  78, 160],
        device='cuda:0'),
 'deps_ud': tensor([[ 0, 20,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0,  0,  0,  0, 34,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 0, 11,  0,  0,  0, 

In [4]:
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# utils.py
def pad_matrices(matrices: Tensor, padding_value: int = -1) -> Tensor:
    # Determine the maximum size in each dimension
    max_height = max(t.size(0) for t in matrices)
    max_width = max(t.size(1) for t in matrices)
    assert max_height == max_width, "UD and E-UD matrices must be square."
    
    # Create a single tensor for stacking with padding
    # Initialize with -1 and then copy the tensors into it
    padded_tensor = torch.full((len(matrices), max_height, max_width), padding_value)
    
    # Stack tensors directly into the larger tensor
    for i, matrix in enumerate(matrices):
        padded_tensor[i, :matrix.size(0), :matrix.size(1)] = matrix
    return padded_tensor


# data.py
def collate_fn(batches: list[dict[str, list | Tensor]]) -> dict[str, list | Tensor]:
    """Merge individual dataset examples into one padded batch.

    Args:
        batches: The examples to merge; each one is a dict keyed by column name.

    Returns:
        A dict with the same columns. ``words`` and ``metadata`` are plain
        lists with one entry per example; sequence columns are padded with
        ``-1`` via ``pad_sequence`` (batch first); the ``deps_ud`` and
        ``deps_eud`` matrices are padded with ``-1`` via ``pad_matrices``.
    """
    padding_value = -1

    def gather(column):
        """Collect one column's value from every example, in order."""
        return [example[column] for example in batches]

    def padded_sequences(column):
        """Pad a column of 1-D tensors to a common length."""
        return pad_sequence(gather(column), padding_value=padding_value, batch_first=True)

    def padded_matrices(column):
        """Pad a column of 2-D tensors to a common size."""
        return pad_matrices(gather(column), padding_value=padding_value)

    collated = {}
    collated["words"] = gather('words')
    for column in ('lemma_rules', 'joint_pos_feats'):
        collated[column] = padded_sequences(column)
    for column in ('deps_ud', 'deps_eud'):
        collated[column] = padded_matrices(column)
    for column in ('miscs', 'deepslots', 'semclasses'):
        collated[column] = padded_sequences(column)
    collated["metadata"] = gather('metadata')
    return collated


# Batch the train split four examples at a time; collate_fn pads each batch to a common size.
dataloader = DataLoader(dataset['train'], batch_size=4, collate_fn=collate_fn)

# Smoke test: print only the first collated batch, then stop iterating.
for batch in dataloader:
    print(batch)
    break

{'words': [['The', 'firm', "'s", 'snowmobile', 'division', 'and', 'defence', 'services', 'unit', 'were', 'also', 'sold', 'and', 'Bombardier', 'started', 'the', 'development', 'of', 'a', 'new', 'aircraft', 'seating', '110', 'to', '135', 'passengers', '.'], ['Mr', 'Majumdar', 'also', 'said', 'an', 'assessment', 'should', 'be', 'made', 'as', 'to', 'whether', 'foreign', 'investment', 'is', 'indeed', 'beneficial', 'to', 'the', 'country', '-', 'in', 'terms', 'of', 'employment', 'and', 'money', 'generated', '-', 'or', 'just', 'another', 'way', 'of', 'international', 'companies', 'filling', 'their', 'deep', 'pockets', '.'], ['This', '#NULL', 'means', 'mobile', 'companies', 'have', 'to', 'think', 'carefully', 'about', 'what', 'they', 'are', 'offering', 'in', 'new', 'models', 'so', 'that', 'people', 'see', 'a', 'compelling', 'reason', 'to', 'upgrade', ',', 'said', 'Gartner', '.'], ['The', 'agenda', 'was', 'just', 'too', 'broad', 'and', 'as', 'a', 'result', 'nothing', 'was', 'prioritised', '.']],