For more information on the dataset, see the dolly-alpaca-hindi dataset page [here](https://www.kaggle.com/datasets/heyytanay/dolly-alpaca-hindi).

In [None]:
import lance

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType

In [None]:
# In this example we are fine-tuning the Gemma-2b model but you can change it to any model of your choice
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Request 8-bit weights through an explicit quantization config: passing
# `load_in_8bit=True` directly to `from_pretrained` is deprecated in transformers.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    device_map="auto",
)

In [None]:
# Define a prompt prefix, middle and suffix that we pre-tokenize once (to save redundant computation)
# and then arrange around the actual instructions, inputs and outputs.
PROMPT_PRE = """नीचे एक निर्देश है जो किसी कार्य का वर्णन करता है, जिसे एक इनपुट के साथ जोड़ा गया है जो आगे का संदर्भ प्रदान करता है। एक प्रतिक्रिया लिखें जो अनुरोध को उचित रूप से पूरा करती है।\n### निर्देश:\n"""
PROMPT_MID = """\n### इनपुट:\n"""
PROMPT_SUF = """\n### प्रतिक्रिया:\n"""

# The prefix opens every sample, so it keeps whatever special tokens the
# tokenizer adds by default (e.g. a beginning-of-sequence token).
pre_tok = tokenizer(PROMPT_PRE)['input_ids']

# The middle and suffix are spliced into the interior of each sample. Without
# `add_special_tokens=False`, a tokenizer that prepends a beginning-of-sequence
# token would place one in the middle of every training sequence.
mid_tok = tokenizer(PROMPT_MID, add_special_tokens=False)['input_ids']
suf_tok = tokenizer(PROMPT_SUF, add_special_tokens=False)['input_ids']

In [None]:
class LanceDataset(Dataset):
    """
    Map-style dataset backed by a Lance dataset on disk.

    For every row it:
     - reads the 'instructions', 'inputs' and 'outputs' columns
       (presumably already-tokenized lists of token ids -- they are sliced and
       concatenated with the pre-tokenized prompt pieces below),
     - forces each of the three to exactly `cutoff` tokens (truncating or padding),
     - stitches them together with the module-level prompt tokens
       `pre_tok`, `mid_tok` and `suf_tok`.
    """
    def __init__(self, dataset, pad_tok_id=None, cutoff=None):
        # `dataset` is handed to lance.dataset() as-is
        self.ds = lance.dataset(dataset)
        # Any falsy value (None, 0) falls back to the defaults: 128 tokens / pad id 0
        self.cutoff = cutoff or 128
        self.pad_tok_id = pad_tok_id or 0

    def __getitem__(self, idx):
        # Fetch the single row at `idx` as a plain Python dict
        record = self.ds.take([idx]).to_pylist()[0]
        fields = (record['instructions'], record['inputs'], record['outputs'])

        # Bring all three fields to the same fixed length
        instruction, context, response = self.trim(self.cutoff, self.pad_tok_id, *fields)

        # prefix + instruction + middle + input + suffix + output
        return pre_tok + instruction + mid_tok + context + suf_tok + response

    def __len__(self):
        # One row per sample
        return self.ds.count_rows()

    def trim(self, cutoff: int, pad_token: int, *args) -> list:
        # Return one new list per argument, each truncated or right-padded
        # with `pad_token` so that it is `cutoff` items long.
        fitted = []
        for tokens in args:
            shortfall = cutoff - len(tokens)
            if shortfall > 0:
                fitted.append(tokens + [pad_token] * shortfall)
            else:
                fitted.append(tokens[:cutoff])
        return fitted

In [None]:
# LoRA adapter settings for a causal language model, in training (non-inference) mode
peft_config = LoraConfig(
    r=4,
    lora_alpha=16,
    lora_dropout=0.2,
    inference_mode=False,
    task_type=TaskType.CAUSAL_LM,
)

# Wrap the base model with the adapters described by `peft_config`
model = get_peft_model(model=model, peft_config=peft_config)

In [None]:
# Build the train and validation datasets.
# Both Lance directories are given as paths relative to the current working directory.
train_dataset = LanceDataset("hindi_alpaca_dolly_train.lance/")
valid_dataset = LanceDataset("hindi_alpaca_dolly_val.lance/")

# Collator for batching samples; mlm=False turns off masked-language-modelling
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

In [None]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir="output",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=4, # change if your individual GPUs have more memory
    save_steps=10_000,
    save_total_limit=2,
    prediction_loss_only=True,
    remove_unused_columns=False,
    # The string "none" is what disables logging integrations; a literal None is
    # treated as "not set" and falls back to the library default.
    report_to="none", # remove this if you want to log to wandb
)

# Define the trainer, then train and save the model
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
)

trainer.train()

# Saved under the model's short name plus a "-hindi" suffix, e.g. "gemma-2b-hindi"
model.save_pretrained(f"{model_id.split('/')[-1]}-hindi")