## Get Subset for Training

In [None]:
from datasets import load_dataset

ds = load_dataset('openpecha/tagged_cleaned_MT_v1.0.3', split='train')

In [None]:
def condition(example):
    return example['Tag'] != ''

ds = ds.filter(condition)

In [None]:
# Define the size of the random subset
subset_size = 50_000

# Shuffle the dataset
shuffled_ds = ds.shuffle(seed=0)  # Use a fixed seed for reproducibility

# Select the first `subset_size` examples
random_subset = shuffled_ds.select(range(subset_size))

In [None]:
random_subset = random_subset.train_test_split(.1)

In [None]:
random_subset

In [None]:
random_subset.save_to_disk('rat-poc-ds')

## Add Contexts

Add similar sentences as context to mimic retrieval augmentation. The context for both train and eval come from the train set to mimic having a set of contexts from the training data.

In [1]:
from datasets import load_from_disk

ds = load_from_disk('rat-poc-ds')

In [2]:
from sentence_transformers import SentenceTransformer

# Load pre-trained embedding model
embedding_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')  # Or a domain-specific model

# Encode all source sentences into vectors
sentences = ds['train']['Target']
embeddings = embedding_model.encode(sentences, convert_to_tensor=True)




In [3]:
import torch
from torch.nn.functional import cosine_similarity

def get_top_n_contexts(source_idx, source_embeddings, n=3):
    # Get the embedding of the target sentence
    query_embedding = source_embeddings[source_idx]

    # Compute cosine similarities with all other embeddings
    similarities = cosine_similarity(query_embedding.unsqueeze(0), source_embeddings)

    # Get the indices of the top-N most similar sentences (excluding itself)
    top_n_indices = torch.topk(similarities, n + 1).indices[1:]  # Skip the first (self)
    
    # Retrieve the corresponding sentences
    return list(top_n_indices)

# Example usage for the first sentence
context_idxs = get_top_n_contexts(0, embeddings, n=3)
print("Contexts for the first sentence:", context_idxs)

Contexts for the first sentence: [tensor(34123, device='cuda:0'), tensor(18495, device='cuda:0'), tensor(34931, device='cuda:0')]


In [4]:
# Add contexts to each example using `dataset.map`
def add_contexts(example, idx):
    # Call get_top_n_contexts with only necessary arguments
    context_idxs = get_top_n_contexts(idx, embeddings, n=3)
    
    # Construct the context by accessing the dataset using the indices
    example["context"] = [
        ds['train'][int(context_idx)]['Source'] + ' -> ' + ds['train'][int(context_idx)]['Target']
        for context_idx in context_idxs
    ]
    return example

# Apply the mapping
dataset_with_contexts = ds.map(
    add_contexts, 
    with_indices=True,  # Pass the index to `add_contexts`
    batched=False  # Process one example at a time
)


In [5]:
# Inspect the new dataset
print(dataset_with_contexts['train'][0])

{'Source': 'འཇིག་ལས་འདས་པའི་གང་འདུལ་ལོ།།', 'Target': 'Taming with transcendent beings.', 'File_Name': 'TM3076', 'Machine Aligned': False, '__index_level_0__': 1176089, 'Tag': 'Intrinsic Existence, Conventional Existence', 'context': ['འགྲོ་ཀུན་སྒྲིབ་པ་གཉིས་སྤངས་ཏེ།\xa0། -> May all beings conquer the two obscurations', 'དགེ་བས་མཁའ་མཉམ་ལུས་ཅན་མ་ལུས་པ།། ཐེག་མཆོག་གོ་གྱོན་ཤེས་རབ་མཚོན་ཐོགས་ནས།། བདུད་བཞིའི་དགྲ་སྡེ་མ་ལུས་ཀུན་བཅོམ་སྟེ།། སྐུ་གསུམ་ནོར་བུའི་ཁྲི་ལ་འཁོད་གྱུར་ཅིག། -> Through this virtue, may all embodied beings throughout space without exception, Put on the armor of the Supreme Vehicle and having raised the weapon of wisdom, May they overcome all without exception of the host of enemies which are the four demons And be set on the jeweled throne of the three bodies.', 'སྐྱེ་འགག་ཡོད་མེད་ལ་སོགས་པའི་དམིགས་པ་དང་འཛིན་པའི་ཡུལ་ལས་འདས་པའི་རིག་སྟོང་སྤྲོས་བྲལ་མཉམ་པ་ཉིད་ཀྱི་ཁོར་ཡུག་ཡིན་ཏེ། -> Phenomena therefore transcend all objects of reference and clinging, such as origin and cessation, exist

In [7]:
dataset_with_contexts.save_to_disk('rat-poc-ds-w-context')

Saving the dataset (0/1 shards):   0%|          | 0/45000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5000 [00:00<?, ? examples/s]