<a href="https://colab.research.google.com/github/ghoshankur102/Judge_It_Well/blob/main/LegalAI_250153.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# ^ Hides the massive wall of text output to keep your notebook clean

import torch
major_version, minor_version = torch.cuda.get_device_capability()

# ------------------------------------------------------------------------
# OPTIMIZATION EXPLANATION:
# 1. unsloth[colab-new]: Colab frequently updates its Python/PyTorch versions.
#    The 'colab-new' branch is specifically patched to work with the latest
#    Google Colab environment (Torch 2.3+).
# ------------------------------------------------------------------------
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# ------------------------------------------------------------------------
# 2. --no-deps: This is the SECRET SAUCE.
#    If you just pip install xformers, it will try to uninstall the
#    Colab-native PyTorch and install an older version.
#    This causes the runtime to crash immediately.
#    "--no-deps" forces it to use the pre-installed, optimized PyTorch.
# ------------------------------------------------------------------------
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

# 3. Fix Locale Issue: Colab sometimes defaults to ASCII, which breaks
#    Unsloth's loading bars. This forces UTF-8 encoding.
import locale
locale.getpreferredencoding = lambda: "UTF-8"

In [13]:
!pip install --upgrade sympy # Add this line to upgrade sympy

from unsloth import FastLanguageModel
import torch

# ------------------------------------------------------------------------
# CONFIGURATION:
# ------------------------------------------------------------------------
# 1. Max Sequence Length:
#    Llama-3 supports up to 8192 tokens. However, on a free Colab GPU,
#    loading 8192 tokens with LoRA will cause an OutOfMemory (OOM) error.
#    We limit it to 2048 (approx 1500 words) to ensure stability.
# ------------------------------------------------------------------------
max_seq_length = 2048

# 2. Dtype (Data Type):
#    We set this to None so Unsloth auto-detects your GPU.
#    On T4 (Colab Free), it uses Float16. On Ampere (A100), it uses Bfloat16.
# ------------------------------------------------------------------------
dtype = None

# 3. 4-Bit Quantization:
#    MANDATORY for free Colab. It shrinks the model from 16GB -> 5.5GB.
#    Without this, the model won't even load.
# ------------------------------------------------------------------------
load_in_4bit = True

print("⏳ Loading Llama-3 Model... this might take 1-2 minutes...")

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

print(f"✅ Model Loaded Successfully. Context Window set to: {max_seq_length} tokens.")

⏳ Loading Llama-3 Model... this might take 1-2 minutes...
==((====))==  Unsloth 2026.1.4: Fast Llama patching. Transformers: 4.57.6.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

✅ Model Loaded Successfully. Context Window set to: 2048 tokens.


In [12]:
import json
import glob # To find all JSON files

# 1. Merge into a single JSON Array (most common)
all_data = []
# Exclude 'merged.json' from the list of files to process
for filename in [f for f in glob.glob('*.json') if f != 'merged.json']:
    print(f"Processing file: {filename}")
    try:
        with open(filename, 'r') as f:
            data = json.load(f)
            # If data is a list, extend; if an object, append it as an element
            if isinstance(data, list):
                all_data.extend(data)
            else:
                all_data.append(data)
    except json.JSONDecodeError as e:
        print(f"❌ Error: JSONDecodeError in file '{filename}': {e}")
        print("Please check the file for syntax errors, especially unclosed strings.")
    except Exception as e:
        print(f"❌ Error processing file '{filename}': {e}")

with open('constitution_qa.json', 'w') as outfile:
    json.dump(all_data, outfile, indent=4)

Processing file: constitution_qa.json


In [11]:
import json
from datasets import Dataset

# 1. Load the raw file
# Make sure 'merged.json' is uploaded in the Files section on the left
file_path = "constitution_qa.json" # Changed from "constitution_qa.json" to "merged.json"

raw_data = [] # Initialize to empty list

try:
    with open(file_path, "r") as f:
        raw_data = json.load(f)
    print(f"File found. Parsing {len(raw_data)} records...")

except FileNotFoundError:
    print(f"❌ Error: '{file_path}' not found. Please upload it to the Colab Files folder (folder icon on the left).")
except Exception as e:
    print(f"❌ Error reading or parsing '{file_path}': {e}")


formatted_data = []
# 2. Convert Dictionary-of-Dictionaries to List-of-Dictionaries
# Your file has keys like "0", "1", "2". We need to strip these and just get the values.
if raw_data: # Only proceed if raw_data is not empty
    # Iterate directly over the list as raw_data is a list of dicts
    for entry in raw_data:
        # strict checking to ensure no empty rows crash the training
        if isinstance(entry, dict) and entry.get("question") and entry.get("answer"):
            formatted_data.append({
                "question": entry["question"],
                "answer": entry["answer"],
                "source": "Legal_Corpus" # Adding a tag helps the model know the context
            })
    if not formatted_data:
        print("⚠️ Warning: No valid 'question' and 'answer' pairs found in the data.")
else:
    print("⚠️ Warning: raw_data is empty, skipping formatting.")


dataset = None # Initialize dataset to None
if formatted_data: # Create dataset only if formatted_data is not empty
    # 3. Create the Hugging Face Dataset object
    dataset = Dataset.from_list(formatted_data)
    print(f"✅ Success! Converted to training dataset with {len(dataset)} rows.")
    print("Sample row:", dataset[0])
else:
    print("❌ Error: Cannot create dataset, formatted_data is empty.")
    print("Please ensure your JSON file contains valid 'question' and 'answer' entries.")


File found. Parsing 4082 records...
✅ Success! Converted to training dataset with 4082 rows.
Sample row: {'question': 'What is India according to the Union and its Territory?', 'answer': 'India, that is Bharat, shall be a Union of States.', 'source': 'Legal_Corpus'}


In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [15]:
from unsloth import FastLanguageModel # Added this line to ensure FastLanguageModel is defined

# 1. Define the Llama-3 Chat Template
# This specific format is required for Llama-3 Instruct models
legal_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert Indian Legal Assistant. Answer strictly based on the provided context.<|eot_id|><|start_header_id|>user<|end_header_id|>

Question: {}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{}"""

# Only proceed with tokenizer and dataset mapping if dataset is not None and tokenizer is defined
if dataset is not None:
    # Ensure tokenizer is in scope. If not, this will still raise a NameError.
    # However, given previous cell executed, it should be available.
    try:
        EOS_TOKEN = tokenizer.eos_token # Must add this so the model knows when to stop talking
    except NameError:
        print("❌ Error: 'tokenizer' not found. Please ensure the model loading cell (cell 7xJQ2OAgKN0e) has been run successfully.")
        EOS_TOKEN = "<|end_of_text|>" # Fallback, though generation might not work as expected

    def formatting_prompts_func(examples):
        questions = examples["question"]
        answers   = examples["answer"]
        texts     = []

        for q, a in zip(questions, answers):
            # Format: System -> User Question -> Assistant Answer -> EOS
            text = legal_prompt.format(q, a) + EOS_TOKEN
            texts.append(text)

        return { "text" : texts }

    # 2. Apply the format to the dataset
    # batched=True processes multiple rows at once (much faster)
    dataset = dataset.map(formatting_prompts_func, batched = True)

    print("✅ Data formatted successfully.")
    print("Sample Input to Model:\n",dataset[0]["text"])
else:
    print("Skipping data formatting and model input preparation as dataset is empty.")
# ------------------------------------------------------------------------
# LORA CONFIGURATION
# ------------------------------------------------------------------------
# r = 16: The "rank". Higher numbers (32, 64) learn more complex patterns
#         but use more VRAM. 16 is standard for a T4.
# target_modules: We target ALL linear layers. This results in smarter models
#                 than just targeting "q_proj" and "v_proj".
# ------------------------------------------------------------------------

# Check if model and tokenizer are defined (they should be from cell 7xJQ2OAgKN0e)
if 'model' not in locals() or 'tokenizer' not in locals():
    print("❌ Error: 'model' and/or 'tokenizer' are not defined. Please ensure cell 7xJQ2OAgKN0e (Llama-3 Model Loading) has been run successfully to initialize them.")
    model = None # Set model to None to prevent a NameError if the next line tries to use it
else:
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"],
        lora_alpha = 16,
        lora_dropout = 0, # Dropout = 0 is faster
        bias = "none",    # Bias = "none" is faster
        use_gradient_checkpointing = "unsloth", # The secret to not running out of VRAM
        random_state = 3407,
        use_rslora = False,
        loftq_config = None,
    )
    print("✅ LoRA Adapters attached. Model is ready for training.")

Map:   0%|          | 0/4082 [00:00<?, ? examples/s]

✅ Data formatted successfully.
Sample Input to Model:
 <|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are an expert Indian Legal Assistant. Answer strictly based on the provided context.<|eot_id|><|start_header_id|>user<|end_header_id|>

Question: What is India according to the Union and its Territory?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

India, that is Bharat, shall be a Union of States.<|end_of_text|>


Unsloth 2026.1.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


✅ LoRA Adapters attached. Model is ready for training.


In [16]:
from trl import SFTTrainer
from transformers import TrainingArguments

# ------------------------------------------------------------------------
# TRAINING ARGUMENTS (Optimized for Colab T4)
# ------------------------------------------------------------------------
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2, # Keep low to prevent OOM
        gradient_accumulation_steps = 4, # Simulates batch_size = 8
        warmup_steps = 5,
        max_steps = 60, # 60 steps for a quick test. Set to 0 for full epoch.
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit", # Saves massive VRAM
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

# Print memory stats before starting
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU: {gpu_stats.name} | Max Memory: {max_memory} GB")
print(f"Reserved Memory: {start_gpu_memory} GB")

print("🚀 Starting Training...")
trainer_stats = trainer.train()
print("✅ Training Complete!")

Unsloth: Tokenizing ["text"] (num_proc=6):   0%|          | 0/4082 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.


GPU: Tesla T4 | Max Memory: 14.741 GB
Reserved Memory: 6.967 GB
🚀 Starting Training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 4,082 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice:

 1


wandb: You chose 'Create a W&B account'
wandb: Create an account here: https://wandb.ai/authorize?signup=true&ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: ghoshankur102 (ghoshankur102-indian-institute-of-technology-kanpur) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


wandb: Detected [huggingface_hub.inference, openai] in use.
wandb: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
wandb: For more information, check out the docs at: https://weave-docs.wandb.ai/


Step,Training Loss
1,4.2364
2,4.4398
3,4.479
4,4.6948
5,4.0398
6,3.6414
7,2.9107
8,2.7957
9,2.0001
10,2.3065


0,1
train/epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇████
train/grad_norm,▇█▆▅▆▄▄▃▃▂▂▃▂▁▂▂▂▃▂▂▁▂▁▂▂▂▂▁▂▁▃▂▂▂▂▂▂▂▁▁
train/learning_rate,▁▂▄▅▇█▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▁▁
train/loss,▇▇██▇▄▃▂▃▃▂▁▁▁▁▂▂▁▂▁▁▂▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁

0,1
total_flos,2144699053867008.0
train/epoch,0.11759
train/global_step,60.0
train/grad_norm,1.0844
train/learning_rate,0.0
train/loss,1.4361
train_loss,1.97206
train_runtime,273.9292
train_samples_per_second,1.752
train_steps_per_second,0.219


✅ Training Complete!


In [25]:
# 1. Enable native 2x faster inference
FastLanguageModel.for_inference(model)

# 2. Define your test question
test_question = "?"

# 3. Format the input using the same template as training
input_text = legal_prompt.format(test_question, "is dual citizenship legal?") # Empty answer for generation

# 4. Tokenize and move to GPU
inputs = tokenizer([input_text], return_tensors = "pt").to("cuda")

# 5. Generate response
# max_new_tokens = 128 (limits the length of the answer)
outputs = model.generate(**inputs, max_new_tokens = 256, use_cache = True)

# 6. Decode the numbers back to text
response = tokenizer.batch_decode(outputs)
print(response[0].split("<|start_header_id|>assistant<|end_header_id|>")[-1].strip())

is dual citizenship legal?<|reserved_special_token_75|><|reserved_special_token_94|>assistant<|reserved_special_token_153|>

No, it is not legal. The Constitution of India does not recognize dual citizenship.<|end_of_text|>


In [18]:
# Save to a local folder named "Legal_Llama_LoRA"
model.save_pretrained("Legal_Llama_LoRA")
tokenizer.save_pretrained("Legal_Llama_LoRA")

print("✅ Model saved to folder 'Legal_Llama_LoRA'")

✅ Model saved to folder 'Legal_Llama_LoRA'
