<a href="https://colab.research.google.com/github/darshandugar2004/CorporateMailHandler/blob/main/Gemma_Finetuing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Artificial Data generation

In [None]:
# Install required libraries for fine-tuning
!pip install -q -U transformers datasets accelerate peft trl
!pip install -U bitsandbytes

# Artificial Data Generation

In [None]:
# Block 1: Generate an Enhanced and Diverse Training Dataset

import json

# ==============================================================================
# DATASET ENHANCEMENT NOTES:
# 1. Quality Check: Manually review each "completion" for perfect grammar and tone.
# 2. Further Expansion: To generate more high-quality data, consider using a
#    powerful LLM with a prompt like:
#    "Generate 10 diverse training examples for a corporate assistant LLM with the
#     intent 'Sustainability Initiative'. Include successes, failures, and requests
#     for clarification. Format as a JSON object with 'details' and 'completion' keys."
# ==============================================================================

# Input side of the training pairs: for every intent, a list of "Details: ..."
# strings describing what the upstream systems (verification API, RAG lookup)
# reported. Entries are paired by position with the email bodies in
# `completions`, so ticket numbers and figures here must match the ones quoted
# in the corresponding completion.
scenarios = {
    "Merger Announcement": [
        "Details: Verification API: SUCCESS. Ticket #T5821 raised for SM approval.",
        # Fixed ticket number: was "#T522", which disagreed with the paired
        # completion (#T5822) and would teach the model to alter ticket IDs.
        "Details: Verification API: FAILED - Inconsistent company registration number. Ticket #T5822 raised for manual review.",
        "Details: Verification API: PARTIAL SUCCESS. Director names verified, financial statements pending upload. Ticket #T5823 for follow-up.",
        "Details: Verification API: ERROR - Service timed out. Ticket #T5824 raised with IT. Will re-attempt verification.",
        "Details: Senior Manager approval received for Ticket #T5821. Notifying the legal team to proceed.",
        # --- NEW DIVERSITY EXAMPLE (URGENCY) ---
        "Details: URGENT - Verification API SUCCESS. Ticket #T5825 raised for IMMEDIATE SM approval due to market sensitivity.",
    ],
    "Sustainability Initiative": [
        "Details: RAG system extracted the following from Q3 report: 'Carbon emissions reduced by 15% year-over-year, exceeding the 10% target.'",
        "Details: User asked for water usage metrics. RAG system found no specific data for 'water usage' in the latest report.",
        "Details: User inquired about waste reduction. RAG found: 'Our Zero Waste to Landfill initiative successfully diverted 95% of waste in the past year.'",
        "Details: User asked for data on renewable energy. RAG found: 'As of Q3 2025, 45% of our global energy consumption is from renewable sources.'",
        "Details: User asked for volunteering hours and packaging sources. RAG found volunteering data: '10,000+ employee volunteer hours contributed.' No data found for 'packaging sources.'",
        # --- NEW DIVERSITY EXAMPLE (AMBIGUITY) ---
        "Details: User sent a vague email: 'send sustainability info'. RAG system cannot proceed without more specific keywords.",
    ],
}

# Target email bodies, keyed by the same intent names as `scenarios`.
# The dataset-building loop pairs entries by position (the i-th completion
# answers the i-th scenario of the same intent), so the two lists for an intent
# must stay the same length and in the same order.
# NOTE(review): some completions state facts that do not appear in the paired
# details (e.g. "manufacturing waste", "primarily solar and wind",
# "community projects this year") — confirm this is intended, since it may
# train the model to add unsupported specifics.
completions = {
    "Merger Announcement": [
        "Dear Team,\n\nThis is to confirm that the preliminary details for the upcoming merger have been successfully verified. A ticket (#T5821) has been raised for the Senior Manager's final approval.\n\nFurther updates will be shared post-approval.\n\nBest regards,\nCorporate Assistant",
        "Dear Team,\n\nAn issue was encountered while verifying the details for the upcoming merger. The company registration number was found to be inconsistent. A ticket (#T5822) has been created for the corporate governance team to conduct a manual review.\n\nWe will provide an update once the review is complete.\n\nBest regards,\nCorporate Assistant",
        "Dear Team,\n\nA partial verification was completed for the upcoming merger. While director names have been confirmed, the required financial statements are pending upload. A follow-up ticket (#T5823) has been created to track this item.\n\nRegards,\nCorporate Assistant",
        "Dear Team,\n\nWe encountered a temporary system issue while attempting to verify the merger details, as the verification service timed out. The IT department has been notified via ticket #T5824. The verification process will be re-attempted shortly.\n\nThank you for your patience.\n\nBest regards,\nCorporate Assistant",
        "Dear Team,\n\nThis is an update regarding the proposed merger. The Senior Manager has formally approved the initial verification (Ticket #T5821). The legal team has been notified and will now proceed with the subsequent steps.\n\nFurther communication will follow from the legal department.\n\nBest regards,\nCorporate Assistant",
        # --- NEW DIVERSITY COMPLETION (URGENCY) ---
        "URGENT ACTION REQUIRED\n\nDear Team,\n\nThis is to confirm that the preliminary details for the time-sensitive merger have been successfully verified. An urgent ticket (#T5825) has been raised for IMMEDIATE Senior Manager approval.\n\nYour prompt attention is required.\n\nBest regards,\nCorporate Assistant",
    ],
    "Sustainability Initiative": [
        "Dear Stakeholder,\n\nIn response to your query regarding our sustainability efforts, our Q3 report confirms a significant achievement: carbon emissions were reduced by 15% year-over-year, which surpasses our stated goal of 10%.\n\nThank you for your interest in our initiatives.\n\nSincerely,\nCorporate Assistant",
        "Dear Stakeholder,\n\nThank you for your inquiry. While our latest sustainability report details our progress in areas like carbon emissions, it does not contain specific metrics on water usage.\n\nYour query has been forwarded to the sustainability department for more detailed information.\n\nSincerely,\nCorporate Assistant",
        "Dear Stakeholder,\n\nRegarding your inquiry on waste reduction, our latest report confirms that our 'Zero Waste to Landfill' initiative has been highly effective, successfully diverting 95% of our manufacturing waste from landfills in the past year.\n\nWe appreciate your interest in our environmental programs.\n\nSincerely,\nCorporate Assistant",
        "Dear Stakeholder,\n\nIn response to your query about renewable energy, our data from Q3 2025 shows that 45% of our global energy consumption is now sourced from renewable providers, primarily solar and wind.\n\nThank you for your engagement on this important topic.\n\nSincerely,\nCorporate Assistant",
        "Dear Stakeholder,\n\nThank you for your inquiry. In response to your questions:\n- Our employees have contributed over 10,000 volunteer hours to community projects this year.\n- Our current sustainability report does not contain specific data on packaging material sources.\n\nYour query on packaging has been forwarded to the supply chain department for further details.\n\nSincerely,\nCorporate Assistant",
        # --- NEW DIVERSITY COMPLETION (AMBIGUITY) ---
        "Dear Stakeholder,\n\nThank you for your interest in our sustainability initiatives. To provide you with the most relevant information, could you please specify which area you are interested in (e.g., carbon emissions, waste reduction, renewable energy)?\n\nWe look forward to providing you with the data you need.\n\nSincerely,\nCorporate Assistant",
    ]
}


# Create the formatted dataset.
# Each (details, completion) pair becomes one {"text": ...} training example.
# The text is wrapped in Gemma's chat-turn markup -- the same markup the
# inference cells in this notebook feed to the model -- rather than the
# Llama/Mistral-style "<s>[INST] ... [/INST]" wrapper, so that training and
# inference prompts share one format. The leading <bos> is left to the tokenizer.
dataset = []
for intent, details_list in scenarios.items():
    completion_list = completions.get(intent)
    if completion_list is None:
        # Intents without completions are still skipped, but no longer silently.
        print(f"Warning: no completions defined for intent '{intent}'; skipping.")
        continue
    if len(completion_list) != len(details_list):
        # Pairs are matched by position, so a length mismatch means at least one
        # scenario would be paired with the wrong (or no) completion.
        raise ValueError(
            f"Intent '{intent}' has {len(details_list)} scenarios but "
            f"{len(completion_list)} completions; they must be paired one-to-one."
        )
    for details, completion in zip(details_list, completion_list):
        prompt = f"As a corporate assistant, write a formal email based on the following intent and details. Intent: {intent}. {details}"
        formatted_text = (
            f"<start_of_turn>user\n{prompt}<end_of_turn>\n"
            f"<start_of_turn>model\n{completion}<end_of_turn>\n"
        )
        dataset.append({"text": formatted_text})

# Save the dataset to a JSONL file (one JSON object per line).
output_file = "style_dataset.jsonl"
with open(output_file, "w", encoding="utf-8") as f:
    f.writelines(json.dumps(entry) + "\n" for entry in dataset)

print(f"Enhanced dataset with {len(dataset)} examples created and saved to '{output_file}'")

In [None]:
# Report how many examples were written, then show the first one.
# The trailing bare expression is what the notebook renders as the cell output.
print(
    f"Dataset with {len(dataset)} examples created and saved to '{output_file}'",
    "\n--- Sample Entry ---",
    sep="\n",
)
dataset[0]

## Fine-tune Gemma for email generation

In [None]:
from datasets import load_dataset
# JSONL file produced by the data-generation cell (one {"text": ...} object per line).
dataset_file = "style_dataset.jsonl"
# --- 4. Load Dataset ---
# Rebinds `dataset` from the in-memory list of dicts to the object returned by
# load_dataset for the "train" split of that file.
dataset = load_dataset("json", data_files=dataset_file, split="train")

In [None]:
!pip install trl
!pip install -U bitsandbytes

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from huggingface_hub import login
import os
from google.colab import userdata


# Read the Hugging Face access token from the Colab secret named 'HF_TOKEN'
# and log in with it.
# NOTE(review): presumably required because the Gemma weights are gated — confirm.
HF_TOKEN = userdata.get('HF_TOKEN')
login(token=HF_TOKEN)

# --- 3. Configuration ---
# CHANGED: Model ID is now Gemma-2B
model_id = "google/gemma-2b-it"
# JSONL training file written by the data-generation cell.
dataset_file = "style_dataset.jsonl"
# CHANGED: New output directory for the new model
output_dir = "./fine_tuned_gemma_2b_adapters"

# --- 4. Load Dataset ---
# Loads the JSONL file as the "train" split and rebinds `dataset`.
dataset = load_dataset("json", data_files=dataset_file, split="train")

In [None]:
# --- 4. Define Quantization Configuration ---
# 4-bit quantization settings. NOTE: this config is currently NOT applied --
# `quantization_config` is not passed to from_pretrained below, so the model is
# loaded unquantized. To enable it, pass `quantization_config=bnb_config` and
# consider also setting bnb_4bit_quant_type="nf4" and
# bnb_4bit_compute_dtype=torch.bfloat16 here.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
)

# --- 5. Load Model and Tokenizer ---
# device_map="auto" lets the library decide where to place the weights.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
# `token=True` replaces the deprecated `use_auth_token=True`; it reuses the
# credentials stored by the earlier login() call.
tokenizer = AutoTokenizer.from_pretrained(model_id, token=True)


In [None]:
# --- Prepare a Test Prompt ---
# Baseline check: ask the untouched base model for an email so its output can
# be compared with the fine-tuned model later.
prompt = "As a corporate assistant, write a formal email based on the following intent and details. Intent: Merger Announcement. Details: Verification API: FAILED - Inconsistent company registration number. Ticket #T9999 raised for manual review."
# Use the Gemma-specific prompt format
formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

# --- Generate the Response ---
# Send the inputs to wherever the model actually lives instead of assuming
# "cuda" (the model was placed with device_map="auto").
inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=200)

# Full decoded sequence (prompt + answer), kept for inspection.
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Keep only the newly generated tokens. Splitting the decoded text on
# "<start_of_turn>model\n" is unreliable: when the turn markers are removed by
# skip_special_tokens=True the split finds nothing and the prompt leaks into
# the "completion". Slicing by prompt length does not depend on the markers.
prompt_token_count = inputs["input_ids"].shape[-1]
completion = tokenizer.decode(
    outputs[0][prompt_token_count:], skip_special_tokens=True
).strip()

print("--- RESPONSE FROM BASE GEMMA-2B MODEL (BEFORE TRAINING) ---")
print(completion)

In [None]:
# Only fall back to EOS as the padding token when the tokenizer does not define
# one. Unconditionally aliasing pad to EOS can cause real end-of-sequence tokens
# to be treated as padding during training (and masked out of the loss), which
# makes it harder for the model to learn when to stop.
# NOTE(review): the Gemma tokenizer is expected to ship its own pad token, in
# which case this is a no-op — confirm with `tokenizer.pad_token`.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# --- 6. Configure LoRA ---
# The target modules are the same for Gemma
lora_config = LoraConfig(
    r=16,  # rank of the LoRA update matrices
    lora_alpha=32,  # LoRA scaling parameter
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # module names that receive adapters
    lora_dropout=0.05,  # dropout applied inside the LoRA layers
    bias="none",  # bias terms are not trained
    task_type="CAUSAL_LM"
)
# `prepare_model_for_kbit_training` is not called here.
# NOTE(review): the model-loading cell has `quantization_config` commented out,
# so k-bit preparation would only become relevant if quantization is re-enabled.
# Rebinds `model` to the adapter-wrapped model; later cells train and run this
# wrapped object.
model = get_peft_model(model, lora_config)

In [None]:
# --- 7. Set Up Training ---
# NOTE: Increased batch size as Gemma-2B is smaller and uses less memory
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,  # Increased from 2
    gradient_accumulation_steps=2,  # Decreased to keep effective batch size (4*2=8)
    learning_rate=2e-4,
    logging_steps=10,
    num_train_epochs=5,
    save_strategy="epoch",  # write a checkpoint to output_dir after every epoch
    report_to="none",  # disable external experiment trackers
)

# `model` has already been wrapped with the LoRA adapters by get_peft_model(),
# so `peft_config` is deliberately NOT passed again: supplying both is redundant
# and recent TRL releases reject an already-wrapped PEFT model combined with a
# new peft_config. No tokenizer is passed either, so the trainer resolves one
# for the model itself.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
)

# --- 8. Start Training ---
print("Starting Gemma-2B fine-tuning...")
trainer.train()
print("Fine-tuning complete!")

# --- 9. Save the Final Adapters ---
trainer.save_model(output_dir)
print(f"Gemma-2B model adapters saved to {output_dir}")

In [None]:
# ==============================================================================
# 🚀 PART 3: RUNNING A TEST INFERENCE
# ==============================================================================
print("\n--- Running test inference ---")

# --- Step 3.1: Prepare a test prompt ---
# The instruction text mirrors the one used to build the training data, and it
# is wrapped in Gemma's chat-turn markup before being sent to the model.
test_intent = "Merger Announcement"
test_details = "Details: Verification API: FAILED - Inconsistent company registration number. Ticket #T9999 raised for manual review."

prompt = f"As a corporate assistant, write a formal email based on the following intent and details. Intent: {test_intent}. {test_details}"
formatted_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

# --- Step 3.2: Generate the response ---
# Set the model to evaluation mode (disables dropout, including LoRA dropout)
model.eval()

# Tokenize the input and move it to the model's device rather than assuming
# "cuda" (the model was placed with device_map="auto").
inputs = tokenizer(formatted_prompt, return_tensors="pt", padding=True).to(model.device)

# Generate the output with optimized parameters
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.2,  # Low temperature for professional, predictable output
        do_sample=True,
    )

# --- Step 3.3: Decode and print the result ---
# Full decoded sequence (prompt + answer), kept for inspection.
response_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Keep only the newly generated tokens. Splitting the decoded text on
# "<start_of_turn>model\n" is unreliable: when the turn markers are removed by
# skip_special_tokens=True the split finds nothing and the prompt leaks into
# the "completion". Slicing by prompt length does not depend on the markers.
prompt_token_count = inputs["input_ids"].shape[-1]
completion = tokenizer.decode(
    outputs[0][prompt_token_count:], skip_special_tokens=True
).strip()

print(f"\nPROMPT:\n{prompt}")
print("-" * 20)
print(f"MODEL RESPONSE:\n{completion}")