In [None]:
import torch
import json
import random
import re
import pandas as pd
import numpy as np
import os

import time
import google.generativeai as genai
from typing import List, Dict, Tuple
import requests
from datasets import Dataset
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
from pathlib import Path
from trl import DPOTrainer, DPOConfig
from datetime import datetime
os.environ["WANDB_DISABLED"] = "true"

os.environ["GOOGLE_API_KEY"] = "your-api-key"
genai.configure(api_key="your-api-key")


# Load FLAN-T5 Small model and tokenizer
model_name = "google/flan-t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Check model size
print(f"Model parameters: {model.num_parameters():,}")

Model parameters: 76,961,152


In [29]:
class FlanT5PairGenerator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.gemini_model = genai.GenerativeModel('gemini-1.5-flash')

    def generate_response_pair(self, question: str, max_length: int = 150) -> Tuple[str, str]:
        """Generate two different responses from FLAN-T5 using different parameters"""

        formatted_prompt = f"Answer this legal question: {question}"

        # Tokenize input
        inputs = self.tokenizer(
            formatted_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        )

        # Generate first response (more conservative)
        with torch.no_grad():
            outputs1 = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.3,  # Lower temperature
                do_sample=True,
                top_p=0.8,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

        # Generate second response (more creative)
        with torch.no_grad():
            outputs2 = self.model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=max_length,
                num_return_sequences=1,
                temperature=0.8,  # Higher temperature
                do_sample=True,
                top_p=0.9,
                top_k=50,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

        # Decode responses
        response1 = self.tokenizer.decode(outputs1[0], skip_special_tokens=True)
        response2 = self.tokenizer.decode(outputs2[0], skip_special_tokens=True)

        return response1, response2

    def judge_with_gemini(self, question: str, response1: str, response2: str) -> Dict:
        """Use Gemini to judge which response is better and provide reasoning"""

        judge_prompt = f"""
        You are evaluating two AI responses to a legal question. Please determine which response is better and explain why.

        Question: {question}

        Response A: {response1}

        Response B: {response2}

        Evaluate based on:
        1. Legal accuracy and completeness
        2. Clarity and organization
        3. Use of proper legal terminology
        4. Comprehensiveness of the answer
        5. Professional tone

        Respond with:
        - "WINNER: A" or "WINNER: B"
        - Brief explanation (2-3 sentences) of why that response is better
        - If responses are very similar in quality, choose "WINNER: TIE"

        Format your response as:
        WINNER: [A/B/TIE]
        REASONING: [Your explanation]
        """

        try:
            response = self.gemini_model.generate_content(judge_prompt)
            response_text = response.text.strip()

            # Parse the response
            lines = response_text.split('\n')
            winner_line = next((line for line in lines if line.startswith('WINNER:')), '')
            reasoning_line = next((line for line in lines if line.startswith('REASONING:')), '')

            winner = winner_line.replace('WINNER:', '').strip()
            reasoning = reasoning_line.replace('REASONING:', '').strip()

            return {
                'winner': winner,
                'reasoning': reasoning,
                'full_response': response_text
            }

        except Exception as e:
            print(f"Error with Gemini judging: {e}")
            return {
                'winner': 'ERROR',
                'reasoning': f'Error occurred: {e}',
                'full_response': ''
            }

    def create_dpo_training_pair(self, question: str, response1: str, response2: str, judgment: Dict) -> Dict:
        """Create a DPO training example from the judged responses"""

        if judgment['winner'] == 'A':
            chosen = response1
            rejected = response2
        elif judgment['winner'] == 'B':
            chosen = response2
            rejected = response1
        else:  # TIE or ERROR
            return None

        return {
            'question': f"Answer this legal question: {question}",
            'chosen': chosen,
            'rejected': rejected,
            'judgment_reasoning': judgment['reasoning'],
            'original_question': question
        }

In [None]:
from IPython.display import clear_output
clear_output(wait=True)

In [30]:
def generate_training_questions(num_questions: int = 200) -> List[str]:
    """Generate diverse legal questions for training (separate from test set)"""

    question_templates = [
        "What are the key requirements for establishing {} under federal law?",
        "How do courts typically analyze {} in criminal cases?",
        "What defenses are available against {} charges?",
        "Explain the procedural requirements for {} in civil litigation.",
        "What constitutes {} under the Model Penal Code?",
        "How does the {} doctrine apply in corporate law contexts?",
        "What are the elements of a successful {} claim?",
        "When can {} be used as a defense in criminal proceedings?",
        "What standards do courts apply when evaluating {} motions?",
        "How has recent case law affected {} requirements?",
        "What are the constitutional limitations on {} powers?",
        "Explain the relationship between {} and due process rights.",
        "What factors determine {} in federal court?",
        "How do {} rules differ between state and federal courts?",
        "What are the ethical considerations surrounding {} in legal practice?"
    ]

    legal_concepts = [
        "insider trading", "racketeering", "conspiracy", "money laundering",
        "securities fraud", "antitrust violations", "mail fraud", "wire fraud",
        "civil forfeiture", "habeas corpus", "qualified immunity", "sovereign immunity",
        "preliminary injunctions", "summary judgment", "class certification", "discovery sanctions",
        "attorney-client privilege", "work product doctrine", "judicial review", "administrative exhaustion",
        "constitutional standing", "political question doctrine", "abstention doctrines", "preemption",
        "substantive due process", "procedural due process", "equal protection", "strict scrutiny",
        "commercial speech", "prior restraint", "content-based restrictions", "viewpoint discrimination",
        "search warrants", "probable cause", "reasonable suspicion", "Miranda warnings",
        "exclusionary rule", "fruit of poisonous tree", "inevitable discovery", "good faith exception"
    ]

    questions = []

    # Generate from templates
    for _ in range(num_questions // 2):
        template = random.choice(question_templates)
        concept = random.choice(legal_concepts)
        questions.append(template.format(concept))

    # Add specific questions
    specific_questions = [
        "What constitutes a Brady violation in criminal proceedings?",
        "How does the business judgment rule protect corporate directors?",
        "What are the requirements for piercing the corporate veil?",
        "When does the attorney-client privilege apply in corporate investigations?",
        "What constitutes deliberate indifference under Section 1983?",
        "How do courts analyze First Amendment retaliation claims?",
        "What are the elements of a successful RICO prosecution?",
        "When can law enforcement conduct warrantless searches?",
        "What constitutes effective assistance of counsel?",
        "How does qualified immunity protect government officials?",
        "What are the requirements for federal diversity jurisdiction?",
        "When must courts abstain from hearing state law claims?",
        "What constitutes a taking under the Fifth Amendment?",
        "How do courts analyze dormant Commerce Clause challenges?",
        "What are the requirements for preliminary injunctive relief?",
        "When can courts impose discovery sanctions?",
        "What constitutes prosecutorial misconduct?",
        "How does the exclusionary rule apply to digital evidence?",
        "What are the elements of mail and wire fraud?",
        "When can corporations assert constitutional rights?"
    ]

    questions.extend(specific_questions[:num_questions // 2])
    random.shuffle(questions)

    return questions[:num_questions]

In [31]:
def generate_dpo_dataset(model, tokenizer, num_examples: int = 100) -> List[Dict]:
    """Generate complete DPO dataset using FLAN-T5 pairs + Gemini judging"""

    generator = FlanT5PairGenerator(model, tokenizer)
    questions = generate_training_questions(num_examples)

    training_data = []
    successful_examples = 0

    for i, question in enumerate(questions):
        print(f"Processing {i+1}/{len(questions)}: {question[:60]}...")

        try:
            # Generate response pair from FLAN-T5
            response1, response2 = generator.generate_response_pair(question)

            # Skip if responses are identical
            if response1.strip() == response2.strip():
                print("  Skipping - identical responses")
                continue

            print(f"  Response A: {response1[:100]}...")
            print(f"  Response B: {response2[:100]}...")

            # Judge with Gemini
            judgment = generator.judge_with_gemini(question, response1, response2)
            print(f"  Winner: {judgment['winner']}")
            print(f"  Reasoning: {judgment['reasoning'][:100]}...")

            # Create training pair
            if judgment['winner'] in ['A', 'B']:
                training_pair = generator.create_dpo_training_pair(question, response1, response2, judgment)
                if training_pair:
                    training_data.append(training_pair)
                    successful_examples += 1

            time.sleep(1)  # Rate limiting for Gemini

        except Exception as e:
            print(f"  Error processing question: {e}")
            continue

    print(f"\nGenerated {successful_examples} successful training examples out of {len(questions)} questions")
    return training_data

In [35]:
# Usage with your existing model and tokenizer
if __name__ == "__main__":
    # Assuming your model and tokenizer are already loaded
    print("Generating DPO training dataset...")

    # Generate training dataset
    training_data = generate_dpo_dataset(model, tokenizer, num_examples=50)  # Start small

    # Save the data
    with open('flan_t5_dpo_training.json', 'w') as f:
        json.dump(training_data, f, indent=2)

    print(f"Saved {len(training_data)} training examples to 'flan_t5_dpo_training.json'")

    # Show sample
    if training_data:
        print("\nSample training example:")
        sample = training_data[0]
        print(f"Question: {sample['original_question']}")
        print(f"Chosen: {sample['chosen'][:200]}...")
        print(f"Rejected: {sample['rejected'][:200]}...")
        print(f"Reasoning: {sample['judgment_reasoning']}")

Generating DPO training dataset...
Processing 1/45: Explain the relationship between content-based restrictions ...
  Response A: The content-based restrictions are governed by the judicial system and are governed by the judicial ...
  Response B: The rights of those who are entitled to content are the same as those who have access to the content...
  Winner: TIE
  Reasoning: Both responses are inadequate and fail to accurately address the relationship between content-based ...
Processing 2/45: What constitutes a taking under the Fifth Amendment?...
  Response A: a reversal of the law...
  Response B: a death penalty...
  Winner: TIE
  Reasoning: Both responses are completely inaccurate and fail to address the legal concept of a "taking" under t...
Processing 3/45: When can attorney-client privilege be used as a defense in c...
  Response A: a court hearing...
  Response B: August 16, 2017...
  Winner: A
  Reasoning: Response A, while extremely brief, at least provides a relevant conte

In [48]:
class LegalT5DPOTrainer:
    def __init__(self, model_name: str = "google/flan-t5-small"):
        self.model_name = model_name
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)

        # Set pad token if not already set
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print(f"Loaded model: {model_name}")
        print(f"Model parameters: {self.model.num_parameters():,}")

    def prepare_dpo_dataset(self, training_data_file: str):
        """Load and prepare DPO training dataset"""

        with open(training_data_file, 'r') as f:
            data = json.load(f)

        # Prepare data for DPO format - simplified for CPU training
        dataset_dict = {
            'prompt': [item['question'] for item in data],
            'chosen': [item['chosen'] for item in data],
            'rejected': [item['rejected'] for item in data]
        }

        dataset = Dataset.from_dict(dataset_dict)
        print(f"Prepared dataset with {len(dataset)} examples")
        return dataset

    def train_with_dpo(self, training_data_file: str, output_dir: str = "./flan-t5-legal-dpo"):
        """Train the model using DPO (simplified for compatibility)"""

        try:
            from trl import DPOTrainer, DPOConfig
        except ImportError:
            raise Exception("DPO requires 'trl' library: pip install trl")

        # Prepare dataset
        train_dataset = self.prepare_dpo_dataset(training_data_file)

        # CPU-friendly DPO Training configuration
        training_args = DPOConfig(
            output_dir=output_dir,
            num_train_epochs=2,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=20,
            logging_steps=5,
            save_steps=100,
            learning_rate=1e-5,
            fp16=False,
            bf16=False,
            remove_unused_columns=False,
            report_to=None,
            dataloader_pin_memory=False,
            dataloader_num_workers=0,
        )

        print("Initializing DPO trainer for CPU training...")

        # Try minimal DPOTrainer initialization
        try:
            dpo_trainer = DPOTrainer(
                model=self.model,
                args=training_args,
                train_dataset=train_dataset,
            )
        except Exception as e:
            print(f"DPOTrainer initialization failed: {e}")
            # Try with different parameter combinations
            try:
                dpo_trainer = DPOTrainer(
                    self.model,
                    training_args,
                    train_dataset,
                )
            except Exception as e2:
                print(f"Alternative DPOTrainer init failed: {e2}")
                raise Exception("DPO initialization failed - falling back to standard fine-tuning")

        print("Starting DPO training on CPU (this may take a while)...")
        dpo_trainer.train()

        # Save the fine-tuned model
        dpo_trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        print(f"DPO training complete! Model saved to {output_dir}")
        return output_dir

    def train_with_standard_finetuning(self, training_data_file: str, output_dir: str = "./flan-t5-legal-ft"):
        """Alternative: Standard fine-tuning approach (if DPO doesn't work)"""

        with open(training_data_file, 'r') as f:
            data = json.load(f)

        # Prepare data for standard fine-tuning (using only chosen responses)
        train_data = []
        for item in data:
            train_data.append({
                'input_text': item['question'],
                'target_text': item['chosen']
            })

        # Create dataset
        def preprocess_function(examples):
            inputs = [ex['input_text'] for ex in examples]
            targets = [ex['target_text'] for ex in examples]

            model_inputs = self.tokenizer(
                inputs,
                max_length=512,
                truncation=True,
                padding='max_length'
            )

            labels = self.tokenizer(
                targets,
                max_length=256,
                truncation=True,
                padding='max_length'
            )

            model_inputs["labels"] = labels["input_ids"]
            return model_inputs

        dataset = Dataset.from_list(train_data)
        tokenized_dataset = dataset.map(preprocess_function, batched=True)

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=3,
            per_device_train_batch_size=4,
            gradient_accumulation_steps=2,
            warmup_steps=100,
            logging_steps=10,
            save_steps=500,
            learning_rate=3e-4,
            fp16=torch.cuda.is_available(),
            report_to=None,
        )

        # Data collator
        data_collator = DataCollatorForSeq2Seq(
            tokenizer=self.tokenizer,
            model=self.model,
            padding=True
        )

        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset,
            data_collator=data_collator,
        )

        print("Starting standard fine-tuning...")
        trainer.train()

        # Save model
        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        print(f"Fine-tuning complete! Model saved to {output_dir}")
        return output_dir

In [51]:
def generate_text_with_model(model, tokenizer, prompt, max_length=200):
    """Generate text using a specific model"""
    formatted_prompt = f"Answer this legal question: {prompt}"

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    with torch.no_grad():
        outputs = model.generate(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_length=max_length,
            num_return_sequences=1,
            temperature=0.7,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

In [52]:
def evaluate_before_after(original_model_name: str, finetuned_model_path: str, test_questions: list):
    """Compare original and fine-tuned model responses"""

    print("Loading models for comparison...")

    # Load original model
    original_tokenizer = T5Tokenizer.from_pretrained(original_model_name)
    original_model = T5ForConditionalGeneration.from_pretrained(original_model_name)

    # Load fine-tuned model
    finetuned_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    finetuned_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)

    print("Generating responses...")

    results = []

    for i, question in enumerate(test_questions, 1):
        print(f"Processing question {i}/{len(test_questions)}")

        # Generate responses
        original_response = generate_text_with_model(original_model, original_tokenizer, question)
        finetuned_response = generate_text_with_model(finetuned_model, finetuned_tokenizer, question)

        results.append({
            'question': question,
            'original_response': original_response,
            'finetuned_response': finetuned_response
        })

    return results

In [53]:
def save_comparison_results(results: list, filename: str = "before_after_comparison.json"):
    """Save comparison results to file"""
    with open(filename, 'w') as f:
        json.dump(results, f, indent=2)

    # Also create a readable text version
    text_filename = filename.replace('.json', '.txt')
    with open(text_filename, 'w') as f:
        f.write("FLAN-T5 Legal Fine-tuning: Before vs After Comparison\n")
        f.write("=" * 60 + "\n\n")

        for i, result in enumerate(results, 1):
            f.write(f"QUESTION {i}:\n{result['question']}\n\n")
            f.write(f"ORIGINAL MODEL:\n{result['original_response']}\n\n")
            f.write(f"FINE-TUNED MODEL:\n{result['finetuned_response']}\n\n")
            f.write("-" * 60 + "\n\n")

    print(f"Results saved to {filename} and {text_filename}")


In [56]:
# Main execution
if __name__ == "__main__":
    # Your test questions
    test_questions = [
        "What is the legal standard for establishing proximate cause in tort law?",
        "Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.",
        "What constitutes a material breach of contract versus a minor breach?",
        "Define the elements required to prove negligence in a personal injury case.",
        "What is the doctrine of respondeat superior and when does it apply?",
        "Explain the difference between joint tenancy and tenancy in common.",
        "What are the requirements for a valid will under most state laws?",
        "Define the burden of proof in criminal cases versus civil cases.",
        "What is the statute of limitations and how does it vary by type of legal claim?",
        "Explain the concept of adverse possession and its legal requirements.",
        "What constitutes defamation and what defenses are available?",
        "Define the difference between an easement and a license in property law.",
        "What are the elements of a valid contract formation?",
        "Explain the doctrine of comparative negligence versus contributory negligence.",
        "What is the legal concept of standing to sue in federal court?"
    ]

    # Step 1: Initialize trainer
    trainer = LegalT5DPOTrainer("google/flan-t5-small")

    # Step 2: Train the model (make sure you have generated training data first)
    training_data_file = "flan_t5_dpo_training.json"  # From the previous script

    if os.path.exists(training_data_file):
        print("Training data found. Starting training...")

        try:
            # Try DPO training first (with better error handling)
            finetuned_model_path = trainer.train_with_dpo(training_data_file)
        except Exception as e:
            print(f"DPO training failed: {e}")
            print("Falling back to standard fine-tuning (which works just as well for your portfolio)...")
            finetuned_model_path = trainer.train_with_standard_finetuning(training_data_file)

        # Step 3: Load both models for comparison
        print("\n" + "="*80)
        print("LOADING MODELS FOR COMPARISON")
        print("="*80)

        # Load original model
        print("Loading original FLAN-T5 model...")
        original_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
        original_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

        # Load fine-tuned model
        print("Loading fine-tuned model...")
        finetuned_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
        finetuned_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path)

        # Step 4: Generate and display side-by-side comparison
        print("\n" + "="*80)
        print("BEFORE vs AFTER TRAINING COMPARISON")
        print("="*80)

        for i, question in enumerate(test_questions, 1):
            print(f"\n{'='*20} QUESTION {i} {'='*20}")
            print(f"Q: {question}")

            # Generate response with original model
            original_response = generate_text_with_model(original_model, original_tokenizer, question, max_length=200)
            print(f"ORIGINAL MODEL: {original_response}")


            # Generate response with fine-tuned model
            finetuned_response = generate_text_with_model(finetuned_model, finetuned_tokenizer, question, max_length=200)
            print(f"FINE-TUNED MODEL: {finetuned_response}")

            print("\n" + "="*100)

        print("\nTraining and evaluation complete!")
        print("You can see the side-by-side comparison above to evaluate improvements.")

    else:
        print(f"Training data file '{training_data_file}' not found.")
        print("Please run the training data generation script first.")

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Loaded model: google/flan-t5-small
Model parameters: 76,961,152
Training data found. Starting training...
Prepared dataset with 11 examples
Initializing DPO trainer for CPU training...


Extracting prompt in train dataset:   0%|          | 0/11 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/11 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/11 [00:00<?, ? examples/s]

Starting DPO training on CPU (this may take a while)...


Step,Training Loss
5,1.0943


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


DPO training complete! Model saved to ./flan-t5-legal-dpo

LOADING MODELS FOR COMPARISON
Loading original FLAN-T5 model...
Loading fine-tuned model...

BEFORE vs AFTER TRAINING COMPARISON

Q: What is the legal standard for establishing proximate cause in tort law?
ORIGINAL MODEL: Prohibition of a claim
FINE-TUNED MODEL: proximate cause


Q: Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.
ORIGINAL MODEL: a warranty deed is a separate claim from a reclaim on the property.
FINE-TUNED MODEL: Preventive compensation is typically the same as a loss of income or the loss of money, including the loss of a house in a mortgage or a financial gain.


Q: What constitutes a material breach of contract versus a minor breach?
ORIGINAL MODEL: contract-only agreement
FINE-TUNED MODEL: contract


Q: Define the elements required to prove negligence in a personal injury case.
ORIGINAL MODEL: The elements in the case are the following:
FINE-TUNED MODEL: A pe