In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Model and tokenizer must come from the same checkpoint; keep the id in one place
# so the two can never drift apart.
MODEL_NAME = "billingsmoore/mlotsawa-ground-base"
# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA instead of failing at load time.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, device_map=DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [2]:
from datasets import load_from_disk

# Pre-tokenized input split read from local disk. Judging by the final schema,
# it already carries a 'small_predictions' column that we add to below.
ds = load_from_disk(dataset_path='small-pred-ds')

In [3]:
import torch

def generate_and_decode(
    batch,
    gen_model=None,
    gen_tokenizer=None,
    output_column="base_predictions",
    max_new_tokens=128,
    num_beams=4,
):
    """Generate and decode predictions for one batch of pre-tokenized examples.

    Designed for ``Dataset.map(..., batched=True)``. The extra keyword arguments
    can be supplied through ``fn_kwargs`` to reuse the same function for another
    checkpoint or output column.

    Args:
        batch: Mapping with ``"input_ids"`` and ``"attention_mask"`` entries, one
            list of token ids / mask values per example. Sequences need not share
            a length; they are padded per batch before generation.
        gen_model: Seq2seq model to generate with. Defaults to the notebook's
            global ``model``.
        gen_tokenizer: Tokenizer used for padding and decoding. Defaults to the
            notebook's global ``tokenizer``.
        output_column: Name of the column the decoded strings are returned under.
        max_new_tokens: Maximum number of generated tokens per example; longer
            outputs are cut off.
        num_beams: Beam width. Sampling is disabled, so decoding is deterministic
            beam search.

    Returns:
        ``{output_column: list[str]}`` with one decoded prediction per example.
    """
    # Resolve the notebook globals lazily so explicit arguments take precedence.
    gen_model = model if gen_model is None else gen_model
    gen_tokenizer = tokenizer if gen_tokenizer is None else gen_tokenizer

    # Pad per batch: torch.tensor() raises on ragged lists, while batches that
    # are already uniform in length come back unchanged.
    padded = gen_tokenizer.pad(
        {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]},
        return_tensors="pt",
    )
    input_ids = padded["input_ids"].to(gen_model.device)
    attention_mask = padded["attention_mask"].to(gen_model.device)

    with torch.inference_mode():
        output_ids = gen_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=False,
        )

    decoded_preds = gen_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return {output_column: decoded_preds}


In [4]:
# Beam-search predictions over the whole split, 16 examples per generate() call.
pred_ds = ds.map(
    function=generate_and_decode,
    batched=True,
    batch_size=16,
)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [5]:
# Sanity-check the mapped dataset before it is saved and uploaded:
# every input row is kept and the new prediction column is present.
assert len(pred_ds) == len(ds), "map() changed the number of rows"
assert "base_predictions" in pred_ds.column_names, "prediction column missing"
pred_ds

Dataset({
    features: ['bo', 'en', 'input_ids', 'token_type_ids', 'attention_mask', 'labels', 'small_predictions', 'base_predictions'],
    num_rows: 100000
})

In [6]:
# Persist locally before uploading so the 100k-row generation pass need not be redone.
pred_ds.save_to_disk(dataset_path='pred-ds')

Saving the dataset (0/1 shards):   0%|          | 0/100000 [00:00<?, ? examples/s]

In [8]:
# Upload the predictions as a dataset repo on the Hugging Face Hub.
pred_ds.push_to_hub(repo_id='billingsmoore/temp')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/100 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/billingsmoore/temp/commit/5f56d9ee49e20879d0ce0dc0a0575b9244239ae7', commit_message='Upload dataset', commit_description='', oid='5f56d9ee49e20879d0ce0dc0a0575b9244239ae7', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/billingsmoore/temp', endpoint='https://huggingface.co', repo_type='dataset', repo_id='billingsmoore/temp'), pr_revision=None, pr_num=None)