In [167]:
import datasets
from transformers import MT5Tokenizer
from functools import reduce

dataset = datasets.load_dataset("lecslab/usp-igt")

tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small", legacy=False)

# Collect the unique set of gloss labels
all_glosses = sorted(set([gloss for glosses in dataset['train']['pos_glosses'] +
                            dataset['eval']['pos_glosses'] +
                            dataset['test']['pos_glosses'] for gloss in glosses.replace("-", " ").split()]))
all_glosses = ["<sep>", "<pad>"] + all_glosses
SEP_TOKEN_ID = all_glosses.index("<sep>")
PAD_TOKEN_ID = all_glosses.index("<pad>")
print(f"{len(all_glosses)} unique glosses")

def encode_gloss_labels(label_string: str):
    """Encodes glosses as an id sequence. Each morpheme gloss is assigned a unique id."""
    word_glosses = label_string.split()
    glosses = [word_gloss.split("-") for word_gloss in word_glosses]
    glosses = [[all_glosses.index(gloss) for gloss in word if gloss != ''] for word in glosses]
    glosses = reduce(lambda a, b: a + [SEP_TOKEN_ID] + b, glosses)
    return glosses + [PAD_TOKEN_ID]

def tokenize(batch):
    inputs = tokenizer(batch['transcription'], truncation=True, padding=False, max_length=MAX_INPUT_LENGTH)
    inputs['labels'] = [encode_gloss_labels(label) for label in batch['pos_glosses']]
    return inputs

dataset = dataset.map(tokenize, batched=True)

dataset['train'][0]

67 unique glosses


Map:   0%|          | 0/9774 [00:00<?, ? examples/s]

Map:   0%|          | 0/232 [00:00<?, ? examples/s]

Map:   0%|          | 0/633 [00:00<?, ? examples/s]

{'transcription': 'o sey xtok rixoqiil',
 'segmentation': "o' sea x-tok r-ixóqiil",
 'pos_glosses': 'CONJ ADV COM-VT E3S-S',
 'glosses': 'o sea COM-buscar E3S-esposa',
 'translation': 'O sea busca esposa.',
 'input_ids': [259, 268, 303, 276, 259, 329, 11207, 1418, 329, 159121, 696, 1],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 'labels': [19, 0, 8, 0, 17, 66, 0, 31, 56, 1]}

In [171]:
from transformers import DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
# collator(batch.select_columns(['input_ids', 'attention_mask', 'labels']).to_list()


collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)
dataloader = DataLoader(dataset['train'].select_columns(['input_ids', 'attention_mask', 'labels']), batch_size=32, collate_fn=collator)

for batch in dataloader:
    # print(batch['input_ids'].shape)
    print(batch['labels'].unsqueeze(-2))
    break

tensor([[[  19,    0,    8,    0,   17,   66,    0,   31,   56,    1, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100]],

        [[  43,    0,    8,    0,   51,    0,   33,    0,   21,    0,   31,
            56,    1, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100]],

        [[   8,    0,   21,    0,   49,   12,    0,   56,    0,   46,    1,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100]],

        [[   7,   48,    0,   56,    0,   46,    1, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100]],

        [[  59,    0,    8,    1, -100, -100, -100, -100, -100, -100, -100,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
          -100]],

        [[  53,    0,    7,   32,    0,   31,   56,   48,    0,    8,    1,
          -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,

In [71]:
import torch
torch.tensor(dataset['train'][0]['labels'])

tensor([[18,  0,  7,  0, 16, 65,  0, 30, 55]])

In [165]:
from typing import List

def greedy_decode(feature_map, feature_logits: List[torch.Tensor]):
        """Decodes a bundle of feature logits into gloss ID predictions. 
        The output should align with the input vocabulary of the decoder.

        Args:
            feature_logits (List[torch.Tensor]): List of feature logit tensors each of size `(batch_size, seq_length, feature_size)`
        """
        batch_size = feature_logits[0].shape[0]
        seq_length = feature_logits[0].shape[1]

        primary_features = torch.argmax(feature_logits[0], -1)

        def _decode_row(row_index: int):
            """Decodes the tokens in a given row of the batch. Generator function that `yields` a single prediction (id) for each token."""
            for token_index in range(seq_length):
                # Filter based on the primary feature
                possible_label_ids = [index for index, value in enumerate(
                    feature_map) if value[0] == primary_features[row_index, token_index].item()]

                if len(possible_label_ids) == 1:
                    yield possible_label_ids[0]
                    continue

                # Now we have a list of possible feature matrices
                # Compute the joint probability based on the logits of each feature
                softmax = torch.nn.LogSoftmax(dim=-1)
                # (num_features, feature_size)
                feature_probs = [softmax(feature_logits[feature_index][row_index, token_index])
                                for feature_index in range(1, len(feature_logits))]  # Omits the first feature
                
                # (prob, id)
                most_probable_label = (None, None)
                # Try each possible label, choose one with highest probability
                for possible_label_id in possible_label_ids:
                    possible_label_features = feature_map[possible_label_id]
                    label_prob = 0
                    # Sum up log probabilities for each true feature of the label
                    for feature_index, feature_value in enumerate(possible_label_features):
                        if feature_index == 0:
                            continue
                        label_prob += feature_probs[feature_index - 1][feature_value]

                    if most_probable_label[0] is None or label_prob > most_probable_label[0]:
                        most_probable_label = (label_prob, possible_label_id)
                yield most_probable_label[1]

        next_tokens = [list(_decode_row(row_index)) for row_index in range(batch_size)]
        return torch.tensor(next_tokens)

primary_feature = torch.tensor([[[0.1, 0.2, 0.1], [0.1, 0.2, 0.1]],
                                [[0.05, 0.01, 10], [0.05, 0.01, 10]]])
secondary_feature =  torch.tensor([[[8, 2, 1], [8, 2, 1]],
                                [[1, 5, 10], [1, 5, 10]]], dtype=float)
feature_map = [[0, 0], [0, 1], [1, 0], [1,1], [2,0], [2,2]]
greedy_decode(feature_map, [primary_feature, secondary_feature])

tensor([[2, 2],
        [5, 5]])

In [143]:
torch.log(torch.softmax(torch.tensor([1, 5, 10], dtype=float), 0))

tensor([1.2257e-04, 6.6920e-03, 9.9319e-01], dtype=torch.float64)

In [141]:
softmax = torch.nn.LogSoftmax(dim=0)
softmax(torch.tensor([1, 5, 10], dtype=float))

tensor([-9.0068e+00, -5.0068e+00, -6.8379e-03], dtype=torch.float64)

In [159]:
max([len(x) for x in dataset['train']['labels']])

39

In [160]:
dataset

DatasetDict({
    train: Dataset({
        features: ['transcription', 'segmentation', 'pos_glosses', 'glosses', 'translation', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 9774
    })
    eval: Dataset({
        features: ['transcription', 'segmentation', 'pos_glosses', 'glosses', 'translation', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 232
    })
    test: Dataset({
        features: ['transcription', 'segmentation', 'pos_glosses', 'glosses', 'translation', 'input_ids', 'attention_mask', 'labels'],
        num_rows: 633
    })
})