In [12]:
import os

from transformers import AutoTokenizer

from shortcutfm.__main__ import parse_config

In [6]:
# Move from the notebook's directory to its parent so that relative paths used
# later (e.g. "configs/...") resolve.
# A bare os.chdir("..") climbs one level further every time the cell is re-run;
# anchoring on the directory recorded at first execution keeps the cell
# idempotent within a kernel session while behaving identically on a fresh kernel.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [7]:
# Parse the QQP training configuration file. The path is relative to the current
# working directory; the second argument is an empty list
# (presumably CLI-style overrides -- confirm against parse_config).
config_file = "configs/training/qqp.yaml"
cfg = parse_config(config_file, [])

In [20]:
from torch.utils.data import DataLoader

from datasets import Dataset
from shortcutfm.batch import collate
from shortcutfm.text_datasets import TextDataset

# Load the training split from disk and wrap it in the project's TextDataset.
train_ds = Dataset.load_from_disk(cfg.training_data_path)
train_text_ds = TextDataset(train_ds)

# DataLoader options gathered in one place; shuffling is off so the batch order
# is the same on every run.
loader_options = {
    "batch_size": cfg.batch_size,
    "collate_fn": collate,
    "shuffle": False,
    "num_workers": 8,
    "persistent_workers": True,
}
train = DataLoader(train_text_ds, **loader_options)

In [21]:
# Display the length of the DataLoader built above. For a standard PyTorch
# DataLoader this is the number of batches per pass, not the number of examples.
len(train)

18090

In [22]:
# Tokenizer used below to turn token IDs back into text for inspection.
# NOTE(review): presumably this must match the tokenizer used to build the
# dataset -- confirm against the data preparation step.
tokenizer_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

In [31]:
# Sanity check of the masks on the first batch only: for every example, print
# the decoded full sequence, the tokens selected as the source sequence, and the
# tokens that contribute to the loss.
for batch in train:
    texts = tokenizer.batch_decode(batch.seqs, skip_special_tokens=False)

    # strict=True: the decoded texts, sequences and both masks must all have one
    # entry per example. A length mismatch means the batch is malformed and
    # should fail loudly rather than be silently truncated.
    for idx, (text, seq, input_mask, pad_mask) in enumerate(
        zip(texts, batch.seqs, batch.input_ids_mask, batch.padding_mask, strict=True)
    ):
        print(f"\nExample {idx + 1}:")
        print(f"Decoded full sequence:\n{text}\n")

        # Loss mask: positions kept by both the padding mask and the input mask.
        loss_mask = pad_mask * input_mask

        # Source sequence (input part): positions where the input mask is 0.
        # The comparison already yields a boolean mask, so no extra cast is needed.
        src_token_ids = seq[input_mask == 0].tolist()
        # Each ID is decoded on its own, giving one string per token.
        decoded_src_tokens = tokenizer.batch_decode(src_token_ids, skip_special_tokens=False)

        print(f"Token IDs (Source Sequence):\n{src_token_ids}\n")
        print(f"Decoded words (Source Sequence):\n{decoded_src_tokens}\n")

        # Target part: tokens at positions where the loss mask is non-zero.
        loss_token_ids = seq[loss_mask.bool()].tolist()
        decoded_loss_tokens = tokenizer.batch_decode(loss_token_ids, skip_special_tokens=False)

        print(f"Token IDs contributing to loss:\n{loss_token_ids}\n")
        print(f"Decoded words contributing to loss:\n{decoded_loss_tokens}\n")

    break  # Process only the first batch for verification


Example 1:
Decoded full sequence:
[CLS] academic and educational advice : what can i do after completing bcom? [SEP] [CLS] what should i do after bcom? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

Token IDs (Source Sequence):
[101, 3834, 1998, 4547, 6040, 1024, 2054, 2064, 1045, 2079, 2044, 7678, 4647, 5358, 1029, 102]

Decoded words (Source Sequence):
['[CLS]', 'academic', 'and', 'educational', 'advice', ':', 'what', 'ca