## Notes: 
- Some parts of this notebook are not rendered properly on GitHub; please download it and view it locally to use it correctly.
- This notebook has been tested only in Google Colab. It should also work in other Jupyter environments, but you might run into issues due to differences in dependencies or setup. 

## model

In [None]:
"""Note
Unsloth frequently updates their dependency requirements for installation on colab. So please do check their latest dependency requirements for colab by visiting any notebook listed on their github.
"""

# The cell magic below is commented out; with it enabled the cell's output would be captured
# instead of displayed.
# %%capture
import os
# Colab is detected by checking whether any environment-variable name contains "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # The first and fourth commands use --no-deps, so pip installs only the packages named on
    # those lines and skips their declared dependencies. xformers, trl and transformers are
    # pinned to exact versions.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install transformers==4.51.3
    !pip install --no-deps unsloth
    !pip install flash-attn

In [None]:
from unsloth import FastLanguageModel
import torch

# Load the base checkpoint and its tokenizer through Unsloth. Both quantised-loading flags
# are off and full_finetuning is disabled; adapters are attached in a later cell via
# FastLanguageModel.get_peft_model.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "nis12ram/Nemotron-4-Mini-Hindi-4B-Instruct",
    max_seq_length = 4000,  # sequence-length setting handed to Unsloth for this model
    load_in_4bit = False,
    load_in_8bit = False,
    full_finetuning = False,
)

In [None]:
model

In [None]:
# Wrap the loaded model with LoRA adapters on the projection layers listed in target_modules.
model = FastLanguageModel.get_peft_model(
    model,
    r = 512,  # LoRA rank
    # NOTE(review): no "gate_proj" is listed — presumably this architecture's MLP has none;
    # the `model` printout in the previous cell can confirm.
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "up_proj", "down_proj"],
    lora_alpha = 512,  # equal to r, i.e. an alpha / r ratio of 1
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

In [None]:
model

## dataset

#### dataset utils

In [None]:
# User-turn prompt template for the NER task. It is filled in with
# ner_user_prompt.format(text=...) when the training examples are built.
# The braces of the JSON example are doubled ({{ and }}) so that str.format() emits them as
# literal single braces; {text} is the only placeholder.
ner_user_prompt = '''You are a Hindi language expert who specializes in extracting entities from text. Given a piece of text, extract all crucial entities along with their respective context-aware entity types. Ensure that entity type is in Hindi. The output should be in JSON format.

## Output format:
```json
{{
  "entities": [
    {{
      "type": "_",
      "value": ["_", "_"]
    }},
    {{
      "type": "_",
      "value": ["_"]
    }}
  ]
}}
```

## Text:
""" {text} """'''



In [None]:
import re, json
def desired_json_structure(json_obj, indent) -> str:
    """Serialise *json_obj* as indented JSON with every ``"value"`` list on one line.

    The object is dumped with ``json.dumps(..., ensure_ascii=False, indent=indent)``
    and each multi-line list stored under a ``"value"`` key is then rewritten as
    ``["a", "b"]``, matching the output format shown in the prompt template.

    Only the whitespace *between* list items is collapsed. The contents of the
    string items are copied verbatim, so runs of spaces, non-breaking spaces or a
    ``" ]"`` sequence inside an entity value survive unchanged and the result
    always parses back to the original object. ``"value"`` lists that contain
    nested lists or objects are left in their expanded form.

    Args:
        json_obj: Any JSON-serialisable object.
        indent: Forwarded to ``json.dumps``. With ``None`` nothing is rewritten,
            because the dump is already single-line.

    Returns:
        The formatted JSON text.
    """
    json_str = json.dumps(json_obj, ensure_ascii=False, indent=indent)

    # One complete JSON string literal, including escaped quotes and backslashes.
    string_token = r'"(?:[^"\\]|\\.)*"'
    # A "value" list whose items are strings or other scalars (no nested [ ] { }).
    list_pattern = r'("value": )\[\s+((?:' + string_token + r'|[^"\[\]{}])*?)\s+\]'
    # Either a string literal (kept as-is) or a whitespace run (collapsed).
    item_pattern = '(' + string_token + r')|\s+'

    def _inline_list(match):
        # String literals are matched first so whitespace inside them is never touched.
        items = re.sub(item_pattern, lambda m: m.group(1) or ' ', match.group(2))
        return match.group(1) + '[' + items.strip() + ']'

    return re.sub(list_pattern, _inline_list, json_str)

In [None]:
def extract_entity_type(question: str) -> str:
    pattern = r"what describes (.*?) in the text\??"
    match = re.search(pattern, question, re.IGNORECASE)
    return match.group(1).strip() if match else None

In [None]:
from datasets import load_dataset, Dataset, concatenate_datasets

#### entity_type_hi4_pilener

In [None]:
gliner_dataset = load_dataset("nis12ram/entity_type_hi4_pilener", split="train")

In [None]:
import json
from pprint import pprint
def map_func(datapoint):
  """Add a ``text`` column holding one chat-formatted training example.

  The user turn is ``ner_user_prompt`` filled with the row's ``input_text``; the
  assistant turn is a fenced JSON block built from the row's ``hi_entities``
  wrapped under an ``"entities"`` key.
  """
  user_turn = ner_user_prompt.format(text=datapoint["input_text"])
  answer_json = desired_json_structure({"entities": datapoint["hi_entities"]}, indent=1)
  # Joined with "\n", these pieces give the same layout as a single multi-line template.
  datapoint["text"] = "\n".join([
      "<extra_id_0>System",
      "",
      "<extra_id_1>User",
      user_turn,
      "<extra_id_1>Assistant",
      "```json",
      answer_json,
      "```<extra_id_1>",
  ])
  return datapoint


gliner_dataset = gliner_dataset.map(map_func)

In [None]:
# Drop rows whose "hi_entities" list is empty; the dataset is printed before and after
# so the two summaries can be compared.
print(gliner_dataset)
gliner_dataset = gliner_dataset.filter(lambda datapoint: len(datapoint["hi_entities"])>0)
print(gliner_dataset)

In [None]:
print(gliner_dataset[0]["text"])

In [None]:
# Keep only rows whose complete formatted "text" tokenizes to at most 2000 input ids.
gliner_dataset = gliner_dataset.filter(lambda datapoint: len(tokenizer(datapoint["text"])["input_ids"])<=2000)

In [None]:
gliner_dataset

## golden dataset

In [None]:
golden_dataset = load_dataset("nis12ram/HindiNER-golden-dataset", split = "train")

In [None]:
def map_func(datapoint):
  """Add a ``text`` column holding one chat-formatted training example.

  The user turn is ``ner_user_prompt`` filled with the row's ``input``; the
  assistant turn is a fenced JSON block produced from the row's ``labels``,
  which are passed to ``desired_json_structure`` as they are.
  """
  question = ner_user_prompt.format(text=datapoint["input"])
  answer = desired_json_structure(datapoint["labels"], indent=1)
  datapoint["text"] = (
      "<extra_id_0>System\n"
      "\n"
      "<extra_id_1>User\n"
      f"{question}\n"
      "<extra_id_1>Assistant\n"
      "```json\n"
      f"{answer}\n"
      "```<extra_id_1>"
  )
  return datapoint


golden_dataset = golden_dataset.map(map_func)

In [None]:
golden_dataset

In [None]:
print(golden_dataset[0]["text"])

In [None]:
# @title oversampling the dataset
# Concatenate three copies of the golden dataset so each of its rows appears three times;
# the length is printed before and after.
print(len(golden_dataset))
golden_dataset = concatenate_datasets([golden_dataset]*3)
print(len(golden_dataset))

#### concatenate

In [None]:
# Pool the formatted "text" of both corpora (gliner rows first, then golden rows) into a
# single text-only dataset and shuffle it with a fixed seed.
data = [
    example
    for ds in (gliner_dataset, golden_dataset)
    for example in ds["text"]
]
train_dataset = Dataset.from_dict({"text": data}).shuffle(seed=12)
train_dataset

In [None]:
print(train_dataset[10002]["text"])

## main

In [None]:
# @title MAX CTX
# Take the first batch from an iterator whose batch size equals the dataset length (i.e. all
# rows at once) and tokenize every "text" in a single call.
res = tokenizer(next(train_dataset.iter(batch_size=len(train_dataset)))["text"])
# Length, in token ids, of the longest training example.
max_ctx = max(len(lst) for lst in res["input_ids"])
print(f"MAX_CTX: {max_ctx}")

In [None]:
from trl import SFTTrainer, SFTConfig


# Supervised fine-tuning run over the combined "text" dataset; no evaluation set is supplied.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    eval_dataset = None,
    args = SFTConfig(
        dataset_text_field = "text",  # column holding the fully formatted examples
        per_device_train_batch_size = 16,
        gradient_accumulation_steps = 1,
        warmup_ratio = 0.03,
        num_train_epochs = 1,
        learning_rate = 5e-5,
        logging_steps = 120,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "phase1-model",
        report_to = "none",
        # Dataset size integer-divided by 16. NOTE(review): the 16 duplicates
        # per_device_train_batch_size above — presumably meant as "steps per epoch" on a
        # single device; keep the two values in sync if the batch size changes.
        save_steps=(len(train_dataset)*1)//16
    ),
)



In [None]:
from unsloth.chat_templates import train_on_responses_only

# Re-wrap the trainer with Unsloth's train_on_responses_only helper. The two markers are the
# exact user / assistant turn headers (including the trailing newline) used when the "text"
# column is built in the map_func cells.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<extra_id_1>User\n",
    response_part = "<extra_id_1>Assistant\n",
)

In [None]:
tokenizer.decode(trainer.train_dataset[5]["input_ids"])

In [None]:
# Token id the tokenizer produces for a single space (first id, no special tokens added).
space = tokenizer(" ", add_special_tokens = False).input_ids[0]
# Decode example 5's labels for inspection: every -100 entry is swapped for the space token
# so the sequence can be decoded, leaving only the remaining label tokens readable.
tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5]["labels"]])

In [None]:
trainer_stats = trainer.train()

In [None]:
from huggingface_hub import login
login()

In [None]:
# Upload the model together with the tokenizer to the given Hub repository, using the
# "merged_16bit" save method.
model.push_to_hub_merged("nis12ram/Nemotron-4-Mini-Hindi-4B-data-mixing-exp1", tokenizer, save_method = "merged_16bit")