In [None]:
import torch
import json
import random
import re
from datasets import Dataset
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)
import requests
from pathlib import Path
import os
os.environ["WANDB_DISABLED"] = "true"

# Load FLAN-T5 Small model and tokenizer.
# `model` and `tokenizer` are bound at module level; generate_text() reads
# these two names directly instead of taking them as arguments.
model_name = "google/flan-t5-small"
model = T5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Check model size: print the parameter count with thousands separators
print(f"Model parameters: {model.num_parameters():,}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Model parameters: 76,961,152


In [None]:
def generate_text(prompt, max_length=100):
    """Sample one answer to a legal question from the module-level model.

    Args:
        prompt: Question text. It is wrapped in the instruction prefix
            "Answer this legal question: " before tokenization.
        max_length: Forwarded to ``model.generate`` as ``max_length``.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # Instruction-style wrapper around the raw question
    question_text = f"Answer this legal question: {prompt}"

    # Encode the prompt as PyTorch tensors, truncating anything beyond 512 tokens
    encoded = tokenizer(
        question_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    # Sampling configuration shared by every call
    sampling_options = dict(
        max_length=max_length,
        num_return_sequences=1,
        temperature=0.7,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Inference only, so gradient tracking is switched off
    with torch.no_grad():
        sequences = model.generate(
            input_ids=encoded['input_ids'],
            attention_mask=encoded['attention_mask'],
            **sampling_options
        )

    # Only one sequence is requested, so decode the first row
    return tokenizer.decode(sequences[0], skip_special_tokens=True)

In [None]:
# Legal questions used to probe the model; test_mlm_improvement() iterates
# over this same module-level list.
test_questions = [
    "What is the legal standard for establishing proximate cause in tort law?",
    "Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.",
    "What constitutes a material breach of contract versus a minor breach?",
    "Define the elements required to prove negligence in a personal injury case.",
    "What is the doctrine of respondeat superior and when does it apply?",
    "Explain the difference between joint tenancy and tenancy in common.",
    "What are the requirements for a valid will under most state laws?",
    "Define the burden of proof in criminal cases versus civil cases.",
    "What is the statute of limitations and how does it vary by type of legal claim?",
    "Explain the concept of adverse possession and its legal requirements.",
    "What constitutes defamation and what defenses are available?",
    "Define the difference between an easement and a license in property law.",
    "What are the elements of a valid contract formation?",
    "Explain the doctrine of comparative negligence versus contributory negligence.",
    "What is the legal concept of standing to sue in federal court?"
]

# Print a header, then one question/answer pair per test question
print("FLAN-T5 Small Baseline Responses:")
print("=" * 50)

for number, legal_question in enumerate(test_questions, start=1):
    answer = generate_text(legal_question, max_length=150)
    print(f"Q{number}: {legal_question}")
    print(f"A{number}: {answer}")
    print("-" * 50)

FLAN-T5 Small Baseline Responses:
Q1: What is the legal standard for establishing proximate cause in tort law?
A1: statute of limitations
--------------------------------------------------
Q2: Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.
A2: List of warranties in real estate
--------------------------------------------------
Q3: What constitutes a material breach of contract versus a minor breach?
A3: legal obligation
--------------------------------------------------
Q4: Define the elements required to prove negligence in a personal injury case.
A4: A person who is injure or is in a negligent manner may be liable for the injury by negligence.
--------------------------------------------------
Q5: What is the doctrine of respondeat superior and when does it apply?
A5: the anthem of
--------------------------------------------------
Q6: Explain the difference between joint tenancy and tenancy in common.
A6: Joint tenancy is a general t

In [None]:
# ============================================================================
# PART 1: DOWNLOAD LEGAL DOCUMENTS
# ============================================================================

def download_legal_documents():
    """Collect topic-labelled legal texts for training.

    Tries to scrape a fixed set of Cornell Law (Wex) encyclopedia pages. If
    fewer than five pages yield usable text, every scraped page is discarded
    and a built-in corpus of ten hand-written passages is used instead.

    Each text travels together with its own topic label, so the
    "Legal topic: ..." prefix always describes the document it is attached to
    (labels were previously assigned by list position, which mislabelled
    scraped pages).

    Returns:
        list[str]: One string per document, each of the form
        ``"Legal topic: <label>. <text>"``.
    """
    # source key -> (human-readable topic label, URL)
    legal_sources = {
        "restatement_torts": ("tort law", "https://www.law.cornell.edu/wex/tort"),
        "restatement_contracts": ("contract law", "https://www.law.cornell.edu/wex/contract"),
        "property_law": ("property law", "https://www.law.cornell.edu/wex/property"),
        "negligence_law": ("negligence law", "https://www.law.cornell.edu/wex/negligence"),
        "real_estate_law": ("real property law", "https://www.law.cornell.edu/wex/real_property"),
        "evidence_law": ("evidence law", "https://www.law.cornell.edu/wex/evidence"),
        "corporate_law": ("corporate law", "https://www.law.cornell.edu/wex/corporation"),
        "family_law": ("family law", "https://www.law.cornell.edu/wex/family_law")
    }

    print("Downloading relevant legal documents...")
    labelled_docs = _scrape_wex_articles(legal_sources)

    # If downloads failed or insufficient content, use comprehensive modern legal examples
    if len(labelled_docs) < 5:
        print("Using modern legal training examples...")
        labelled_docs = _builtin_legal_corpus()

    print(f"Total modern legal documents collected: {len(labelled_docs)}")

    # Add brief descriptions to help with training
    return [f"Legal topic: {topic_label}. {doc}" for topic_label, doc in labelled_docs]


def _scrape_wex_articles(legal_sources, max_chars=3000, min_chars=500):
    """Fetch each source page and return ``(topic_label, text)`` pairs for the usable ones.

    A page is usable when the request returns HTTP 200 and the extracted,
    whitespace-normalised text is longer than ``min_chars``; the stored text
    is cut to ``max_chars`` characters. Failures are printed, never raised.
    """
    import re

    # Imported lazily: if a scraping dependency is missing, the caller can
    # still fall back to the built-in corpus instead of crashing.
    try:
        import requests
        from bs4 import BeautifulSoup
    except ImportError as e:
        print(f"✗ Scraping dependencies unavailable: {e}")
        return []

    content_markers = ['content', 'main', 'body', 'text', 'article']
    scraped = []

    for topic, (topic_label, url) in legal_sources.items():
        try:
            response = requests.get(url, timeout=10)
            if response.status_code != 200:
                # Report the skip instead of dropping the page silently
                print(f"✗ Failed to download {topic}: HTTP {response.status_code}")
                continue

            soup = BeautifulSoup(response.content, 'html.parser')

            # Extract main content (skip navigation, ads, etc.)
            content_divs = soup.find_all(['div', 'p', 'article'], class_=lambda x: x and any(
                term in x.lower() for term in content_markers
            ))

            if not content_divs:
                # Fallback: get all paragraph text
                content_divs = soup.find_all('p')

            # Join the pieces and collapse every whitespace run (newlines included)
            text_content = " ".join(div.get_text() for div in content_divs)
            text_content = re.sub(r'\s+', ' ', text_content).strip()

            if len(text_content) > min_chars:  # Only add if substantial content
                scraped.append((topic_label, text_content[:max_chars]))
                print(f"✓ Downloaded {topic}: {len(text_content)} characters")

        except Exception as e:
            # Best-effort scraping: one failing page must not abort the others
            print(f"✗ Failed to download {topic}: {e}")

    return scraped


def _builtin_legal_corpus():
    """Return the offline fallback corpus as ``(topic_label, text)`` pairs."""
    return [
        (
            "negligence law",
            """
            Negligence is established when a plaintiff proves four essential elements: duty of care, breach of duty,
            causation, and damages. The duty of care requires the defendant to conform to a standard of reasonable
            care to avoid unreasonable risks of harm to others. This standard is typically that of a reasonable
            person under similar circumstances. A breach occurs when the defendant's conduct falls below this
            standard. Causation has two components: factual causation, often determined by the but-for test, and
            proximate causation, which limits liability to foreseeable consequences. The plaintiff must prove that
            the defendant's breach was both a factual and proximate cause of the harm. Finally, the plaintiff must
            demonstrate actual damages, which can include economic losses, physical injury, and in some cases,
            emotional distress. The reasonable person standard is objective and considers what a hypothetical
            reasonable person would do in the defendant's position.
            """,
        ),
        (
            "contract law",
            """
            Contract formation requires mutual assent, consideration, and legal capacity. Mutual assent consists
            of an offer and acceptance. An offer is a manifestation of willingness to enter into a bargain, made
            in such a way that another person would be justified in understanding that assent would conclude the
            bargain. The offer must be sufficiently definite and certain in its terms. Acceptance is a manifestation
            of assent to the terms of the offer. Under the mirror image rule, acceptance must be on the exact terms
            of the offer, though the UCC has modified this for sales of goods. Consideration is a bargained-for
            exchange of something of legal value. It can be a promise, performance, or forbearance. Past consideration
            is generally not sufficient. A material breach occurs when the failure to perform substantially defeats
            the purpose of the contract, while a minor breach does not excuse the other party's performance but may
            give rise to damages.
            """,
        ),
        (
            "property law",
            """
            Real property ownership includes the bundle of rights associated with land ownership: the right to
            exclude others, the right to use and enjoy the property, and the right to transfer ownership. Fee simple
            absolute is the most complete form of ownership, providing the owner with the full bundle of rights
            during life and the power to transfer the property at death. A warranty deed provides the grantee with
            certain covenants or warranties about the title, including the covenant of seisin, the covenant of right
            to convey, the covenant against encumbrances, the covenant of quiet enjoyment, and the covenant of warranty.
            A quitclaim deed transfers only whatever interest the grantor has, if any, without any warranties about
            the quality of title. Recording statutes protect bona fide purchasers by providing constructive notice
            of property interests through the public records. Joint tenancy includes the right of survivorship,
            meaning that when one joint tenant dies, the surviving joint tenants automatically acquire the deceased
            tenant's interest.
            """,
        ),
        (
            "causation",
            """
            Proximate cause in tort law serves to limit liability to consequences that bear a reasonable relationship
            to the defendant's negligent conduct. The test for proximate cause varies by jurisdiction, but commonly
            involves foreseeability analysis. Under the Palsgraf rule, liability extends only to those plaintiffs
            who are within the zone of foreseeable danger created by the defendant's negligent act. The Wagon Mound
            test focuses on whether the type of harm that occurred was reasonably foreseeable, even if the exact
            manner of occurrence was not. Intervening causes can break the chain of proximate causation if they are
            unforeseeable and sufficient to produce the harm independently. However, foreseeable intervening causes
            do not relieve the original tortfeasor of liability. The but-for test establishes factual causation by
            asking whether the harm would have occurred but for the defendant's conduct. In cases of multiple
            sufficient causes, courts may apply the substantial factor test instead.
            """,
        ),
        (
            "burden of proof",
            """
            The burden of proof in civil cases is typically preponderance of the evidence, meaning the plaintiff
            must show that it is more likely than not that the defendant is liable. This standard requires proof
            that tips the scales slightly in favor of the plaintiff, often described as 51% certainty. In criminal
            cases, the prosecution must prove guilt beyond a reasonable doubt, which is a much higher standard
            requiring proof to a moral certainty that leaves no reasonable doubt about the defendant's guilt. Some
            civil cases require clear and convincing evidence, an intermediate standard between preponderance and
            beyond reasonable doubt. This standard is used in cases involving fraud, termination of parental rights,
            and civil commitment proceedings. The allocation of the burden of proof can be outcome-determinative,
            as the party bearing the burden loses if the evidence is insufficient to meet the required standard.
            """,
        ),
        (
            "corporate law",
            """
            Corporate governance involves the relationships between a corporation's shareholders, board of directors,
            and management. The board of directors has the authority and responsibility to manage the corporation's
            business and affairs. Directors owe fiduciary duties to the corporation and its shareholders, including
            the duty of care and the duty of loyalty. The duty of care requires directors to act with the care that
            an ordinarily prudent person would exercise in similar circumstances. The business judgment rule protects
            directors from liability for decisions made in good faith, with due care, and in the honest belief that
            the action is in the corporation's best interests. The duty of loyalty requires directors to act in the
            corporation's best interests and prohibits self-dealing transactions unless properly approved. Shareholders
            have limited liability, meaning they are not personally responsible for corporate debts and obligations
            beyond their investment in the corporation. Piercing the corporate veil is an exceptional remedy that
            holds shareholders personally liable when the corporate form is used to perpetrate fraud or injustice.
            """,
        ),
        (
            "evidence law",
            """
            Evidence must be relevant to be admissible in court proceedings. Relevant evidence is evidence having
            any tendency to make the existence of any fact that is of consequence to the determination of the action
            more probable or less probable than it would be without the evidence. However, relevant evidence may be
            excluded if its probative value is substantially outweighed by the danger of unfair prejudice, confusion
            of the issues, misleading the jury, undue delay, waste of time, or needless presentation of cumulative
            evidence. Hearsay is an out-of-court statement offered in evidence to prove the truth of the matter
            asserted and is generally inadmissible unless it falls within a recognized exception. Common hearsay
            exceptions include present sense impressions, excited utterances, statements of then-existing mental or
            physical condition, business records, and public records. The confrontation clause requires that testimonial
            hearsay be excluded in criminal cases unless the declarant is unavailable and the defendant had a prior
            opportunity for cross-examination.
            """,
        ),
        (
            "family law",
            """
            Family law encompasses marriage, divorce, child custody, and support obligations. Marriage creates a
            legal relationship with rights and responsibilities including property rights, inheritance rights, and
            decision-making authority. Divorce or dissolution terminates the marriage relationship and typically
            involves division of marital property and determination of spousal support. Community property states
            presume that property acquired during marriage belongs equally to both spouses, while common law property
            states follow principles of equitable distribution. Child custody determinations are made based on the
            best interests of the child standard, which considers factors such as the child's physical and emotional
            needs, the stability of each parent's home environment, the child's relationship with each parent, and
            any history of domestic violence. Joint custody arrangements are increasingly common and can involve
            joint legal custody, joint physical custody, or both. Child support obligations are typically determined
            by state guidelines that consider both parents' income and the amount of time the child spends with each parent.
            """,
        ),
        (
            "intellectual property",
            """
            Intellectual property protection includes patents, copyrights, trademarks, and trade secrets. Patents
            protect inventions that are novel, non-obvious, and useful, granting the patent holder the exclusive
            right to make, use, and sell the invention for a limited period. Copyright protects original works of
            authorship fixed in a tangible medium, including literary works, musical compositions, artistic works,
            and software. Copyright protection arises automatically upon creation and generally lasts for the life
            of the author plus 70 years. Trademarks protect words, phrases, symbols, or designs that identify and
            distinguish the source of goods or services. Trademark rights can last indefinitely as long as the mark
            continues to be used in commerce and is properly maintained. Trade secrets protect confidential business
            information that derives economic value from being secret. Unlike other forms of intellectual property,
            trade secret protection can last indefinitely but is lost if the information becomes publicly known or
            is independently discovered.
            """,
        ),
        (
            "administrative law",
            """
            Administrative law governs the creation and operation of administrative agencies and their relationship
            to the legislative, executive, and judicial branches of government. Agencies exercise legislative power
            through rulemaking, executive power through enforcement actions, and judicial power through adjudication.
            The Administrative Procedure Act establishes procedures for agency rulemaking and adjudication and provides
            for judicial review of agency actions. Notice-and-comment rulemaking requires agencies to publish proposed
            rules in the Federal Register, provide an opportunity for public comment, and consider comments before
            issuing final rules. Courts review agency actions under different standards of review depending on the
            type of agency action. Questions of law are reviewed de novo, factual findings are reviewed under the
            substantial evidence standard, and policy decisions are reviewed for abuse of discretion. The doctrine
            of primary jurisdiction requires courts to defer to agency expertise on matters within the agency's
            specialized knowledge and statutory authority.
            """,
        ),
    ]

In [None]:
# ============================================================================
# PART 2: CREATE MLM-STYLE TRAINING DATA FOR T5
# ============================================================================

def create_span_corruption_data(texts, mask_probability=0.15):
    """
    Build span-corruption training pairs from raw documents.

    Every document is split on '.', fragments with fewer than five
    whitespace-separated words are dropped, and each remaining sentence is
    corrupted three times via corrupt_sentence(). A corruption that masked
    nothing produces no example.

    Args:
        texts: Iterable of document strings.
        mask_probability: Forwarded to corrupt_sentence().

    Returns:
        List of dicts with keys "input_text" (prompt plus corrupted sentence)
        and "target_text" (the masked words joined by single spaces).
    """
    versions_per_sentence = 3
    min_words = 5

    examples = []

    for document in texts:
        for fragment in document.split('.'):
            sentence = fragment.strip()
            if not sentence:
                continue

            tokens = sentence.split()
            if len(tokens) < min_words:
                # Too short to be worth masking
                continue

            for _ in range(versions_per_sentence):
                masked_sentence, masked_words = corrupt_sentence(tokens, mask_probability)
                if not masked_words:
                    # Nothing was masked this time, so there is nothing to predict
                    continue

                examples.append({
                    "input_text": f"Fill in the missing legal terms: {masked_sentence}",
                    "target_text": " ".join(masked_words)
                })

    return examples

def corrupt_sentence(words, mask_probability=0.15):
    """
    Corrupt a sentence by masking spans (T5-style).

    Walks the word list left to right. At each position a span of 1-3 words
    is masked with probability ``mask_probability``; the whole span is
    replaced by one sentinel token ``<extra_id_N>`` (N counts spans from 0)
    and the masked words are collected in order.

    The corrupted sequence is built incrementally. Replacing slices of a copy
    in place, using indices computed on the original list, misaligns every
    span after the first multi-word one, because the copy has already shrunk.

    Args:
        words: List of word strings forming the sentence.
        mask_probability: Chance of starting a masked span at each position.

    Returns:
        ``(corrupted_sentence, target_spans)``: the space-joined sentence with
        sentinels, and the flat list of masked words. Returns ``(None, None)``
        when ``words`` has fewer than three entries.
    """
    if len(words) < 3:
        return None, None

    corrupted_words = []
    target_spans = []

    i = 0
    span_id = 0

    while i < len(words):
        if random.random() < mask_probability:
            # Span of 1-3 words, clipped so it never runs past the sentence end
            span_length = random.randint(1, min(3, len(words) - i))

            # The masked words become prediction targets...
            target_spans.extend(words[i:i + span_length])
            # ...and a single sentinel stands in for the whole span
            corrupted_words.append(f"<extra_id_{span_id}>")

            span_id += 1
            i += span_length
        else:
            corrupted_words.append(words[i])
            i += 1

    corrupted_sentence = " ".join(corrupted_words)
    return corrupted_sentence, target_spans

In [None]:
# ============================================================================
# PART 3: ALTERNATIVE: FILL-IN-THE-BLANK STYLE
# ============================================================================

def create_legal_fill_in_blanks(texts):
    """
    Build fill-in-the-blank training pairs that target known legal terms.

    Documents are split on '.' into sentences. For every word whose
    lower-cased form, stripped of non-word characters, is one of the listed
    legal terms, one example is emitted in which that single word is replaced
    by "______". The target is the word exactly as it appeared in the text.

    Args:
        texts: Iterable of document strings.

    Returns:
        List of dicts with keys "input_text" and "target_text".
    """
    # Vocabulary of terms eligible for blanking
    target_vocabulary = frozenset((
        "contract", "consideration", "negligence", "damages", "liability",
        "plaintiff", "defendant", "jurisdiction", "statute", "precedent",
        "tort", "breach", "warranty", "deed", "property", "constitutional",
        "evidence", "burden", "proof", "reasonable", "standard", "care",
        "fiduciary", "corporation", "shareholders", "directors", "legal",
        "court", "judge", "jury", "trial", "appeal", "ruling", "decision",
    ))
    non_word_chars = re.compile(r'[^\w]')

    examples = []

    for document in texts:
        for fragment in document.split('.'):
            sentence = fragment.strip()
            if not sentence:
                continue

            tokens = sentence.split()

            for position, token in enumerate(tokens):
                normalized = non_word_chars.sub('', token.lower())
                if normalized not in target_vocabulary:
                    continue

                # Same sentence with only this occurrence blanked out
                blanked_tokens = tokens[:position] + ["______"] + tokens[position + 1:]
                blanked_sentence = " ".join(blanked_tokens)

                examples.append({
                    "input_text": f"Fill in the blank with the correct legal term: {blanked_sentence}",
                    "target_text": token
                })

    return examples


In [None]:
# ============================================================================
# PART 4: FINE-TUNING PIPELINE
# ============================================================================

def fine_tune_t5_on_legal_mlm(seed=42):
    """Complete MLM-style fine-tuning pipeline for T5.

    Steps: collect legal documents, turn them into span-corruption examples,
    split 80/20 into train/test, tokenize, fine-tune google/flan-t5-small
    with the Hugging Face Trainer, and save the result to
    ``./t5-legal-mlm-final``.

    Args:
        seed: Seeds Python's ``random`` module before the (random) span
            corruption and is reused for the train/test split, so repeated
            runs build the same dataset.

    Returns:
        ``(model, tokenizer)``: the fine-tuned model and its tokenizer.

    Raises:
        ValueError: If no training examples could be generated.
    """

    # Step 1: Get legal documents
    print("Loading legal documents...")
    legal_docs = download_legal_documents()

    # Step 2: Create training data (choose one approach)
    print("Creating MLM-style training data...")

    # Span corruption draws from Python's global RNG; seed it for reproducibility
    random.seed(seed)

    # Option A: Span corruption (more like traditional T5)
    training_data = create_span_corruption_data(legal_docs, mask_probability=0.15)

    # Option B: Fill-in-the-blank (more interpretable)
    #training_data = create_legal_fill_in_blanks(legal_docs)

    print(f"Created {len(training_data)} training examples")

    # Fail early with a clear message instead of an obscure error in the split
    if not training_data:
        raise ValueError("No training examples were generated from the legal documents")

    # Step 3: Create dataset
    print("Creating dataset...")
    dataset = Dataset.from_list(training_data)
    dataset = dataset.train_test_split(test_size=0.2, seed=seed)

    # Step 4: Load model and tokenizer
    # These are function-local and shadow any module-level `model`/`tokenizer`.
    print("Loading T5 model...")
    model_name = "google/flan-t5-small"
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    tokenizer = T5Tokenizer.from_pretrained(model_name)

    # Step 5: Define tokenization function and tokenize
    def tokenize_function(examples):
        """Tokenize a batch of inputs and targets; padding is left to the collator."""
        model_inputs = tokenizer(
            examples["input_text"],
            max_length=512,
            truncation=True,
            padding=False
        )

        labels = tokenizer(
            examples["target_text"],
            max_length=128,
            truncation=True,
            padding=False
        )

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset["train"].column_names
    )

    # Step 6: Training arguments
    # NOTE(review): warmup_steps=100 is a large share of the run for a dataset
    # of a few hundred examples at batch size 4 over 3 epochs; confirm intended.
    training_args = TrainingArguments(
        output_dir="./t5-legal-mlm",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=50,
        save_steps=100,  # must stay a multiple of eval_steps for load_best_model_at_end
        save_total_limit=2,
        load_best_model_at_end=True,
        # The string "none" disables logging integrations; Python None is
        # treated as "use the default", which enables all installed ones.
        report_to="none",
    )

    # Step 7: Data collator (pads inputs and labels per batch)
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True
    )

    # Step 8: Create trainer
    # NOTE(review): newer transformers releases appear to prefer
    # `processing_class=` over `tokenizer=`; kept as-is for compatibility with
    # older releases. Confirm against the pinned version.
    print("Creating trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset["train"],
        eval_dataset=tokenized_dataset["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Step 9: Train
    print("Starting MLM-style training...")
    trainer.train()

    # Step 10: Save
    print("Saving model...")
    trainer.save_model("./t5-legal-mlm-final")
    tokenizer.save_pretrained("./t5-legal-mlm-final")

    return model, tokenizer

In [None]:
def test_mlm_improvement(seed=42):
    """Test improvement on legal understanding.

    Loads the base google/flan-t5-small checkpoint and the fine-tuned
    checkpoint saved in ``./t5-legal-mlm-final``, then prints both models'
    sampled answers for every entry of the module-level ``test_questions``.

    Args:
        seed: If not ``None``, torch's RNG is re-seeded with this value before
            every generation. Both models then sample under the same random
            state, so runs are repeatable and differences between the two
            answers are not just sampling noise. ``None`` leaves the RNG
            untouched.
    """

    # Load models
    original_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
    finetuned_model = T5ForConditionalGeneration.from_pretrained("./t5-legal-mlm-final")
    tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

    def generate_response(model, question):
        """Sample one answer from `model` for `question`."""
        # Same instruction prefix as generate_text(), so these answers are
        # comparable with the baseline cell.
        formatted_prompt = f"Answer this legal question: {question}"
        inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=512, truncation=True)

        if seed is not None:
            torch.manual_seed(seed)

        with torch.no_grad():
            outputs = model.generate(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=200,
                temperature=0.7,
                do_sample=True
            )

        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    print("TESTING MLM-TRAINED MODEL vs ORIGINAL")
    print("=" * 60)

    for i, question in enumerate(test_questions, 1):
        print(f"\nQuestion {i}: {question}")
        print("-" * 50)

        original_response = generate_response(original_model, question)
        mlm_response = generate_response(finetuned_model, question)

        print(f"ORIGINAL: {original_response}")
        print(f"MLM-TRAINED: {mlm_response}")
        print("=" * 60)

In [None]:
if __name__ == "__main__":
    # Run MLM-style fine-tuning.
    # The returned pair is bound to the module-level names `model` and
    # `tokenizer`, which generate_text() reads, so later calls to
    # generate_text() use the fine-tuned model.
    model, tokenizer = fine_tune_t5_on_legal_mlm()

    # Test improvements: prints base vs. fine-tuned answers; the function
    # reloads both checkpoints itself rather than using the objects above.
    test_mlm_improvement()

    print("Training complete!")

Loading legal documents...
Downloading relevant legal documents...
Using modern legal training examples...
Total modern legal documents collected: 10
Creating MLM-style training data...
Created 211 training examples
Creating dataset...
Loading T5 model...
Tokenizing dataset...


Map:   0%|          | 0/168 [00:00<?, ? examples/s]

Map:   0%|          | 0/43 [00:00<?, ? examples/s]

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  trainer = Trainer(


Creating trainer...
Starting MLM-style training...




Step,Training Loss,Validation Loss
50,5.2921,4.972179
100,4.8027,4.509517


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Saving model...
TESTING MLM-TRAINED MODEL vs ORIGINAL

Question 1: What is the legal standard for establishing proximate cause in tort law?
--------------------------------------------------
ORIGINAL: 
MLM-TRAINED: proximate cause

Question 2: Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.
--------------------------------------------------
ORIGINAL: a quitclaim deed is a deed that is a non-deed.
MLM-TRAINED: In the case of a disclaim deed, the deed is generally not a quitclaim deed to a real estate transaction

Question 3: What constitutes a material breach of contract versus a minor breach?
--------------------------------------------------
ORIGINAL: contract
MLM-TRAINED: contract breach

Question 4: Define the elements required to prove negligence in a personal injury case.
--------------------------------------------------
ORIGINAL: a).
MLM-TRAINED: a person who is negligent and negligent

Question 5: What is the doctrine of responde

In [None]:
# ============================================================================
# PART 6: LEGAL Q&A DATASETS
# ============================================================================

def _first_answer_text(answers):
    """Return the first answer string found in an ``answers`` field, or "".

    Accepted shapes:
      * SQuAD-style dict: ``{"text": ["..."]}`` or ``{"text": "..."}``
      * list/tuple of answer records: ``[{"text": "...", ...}, ...]``
      * list/tuple of plain strings
      * a bare string
    Anything empty or missing yields "".
    """
    if not answers:
        return ""

    if isinstance(answers, dict):
        texts = answers.get("text")
        if isinstance(texts, (list, tuple)):
            return str(texts[0]) if texts else ""
        return "" if texts is None else str(texts)

    if isinstance(answers, (list, tuple)):
        first = answers[0]
        if isinstance(first, dict):
            text = first.get("text")
            return "" if text is None else str(text)
        return "" if first is None else str(first)

    return str(answers)


def load_legal_qa_datasets():
    """Assemble Q&A fine-tuning examples from two online datasets plus synthetic pairs.

    Sources, in order:
      1. Up to 50 examples from the ``isaacus/LegalQAEval`` test split
         (question longer than 10 chars, answer longer than 20 chars).
      2. Up to 30 ``microsoft/wiki_qa`` train examples with ``label == 1`` whose
         question or answer contains a legal keyword.
      3. Every pair returned by ``create_enhanced_synthetic_legal_qa()``.

    Both online sources are best-effort: any exception while loading one is
    printed and that source is skipped, so the synthetic pairs are always
    returned.

    Returns:
        list[dict]: one dict per example with keys ``"input_text"`` (question
        prefixed with ``"Answer this legal question: "``) and ``"target_text"``
        (answer, truncated to 400 characters for the online sources).
    """

    print("Loading legal Q&A datasets...")
    all_qa = []

    # Substring keywords used to pick "legal-sounding" WikiQA rows.
    legal_keywords = ('law', 'legal', 'court', 'judge', 'rule', 'regulation', 'statute', 'contract')

    # Option 1: Try LegalQAEval dataset
    print("Attempting to load LegalQAEval dataset...")
    try:
        from datasets import load_dataset
        dataset = load_dataset("isaacus/LegalQAEval", split="test")

        count = 0
        for example in dataset:
            if count >= 50:  # Limit to 50 examples
                break

            question = example.get('question', '')
            # The recorded run of this notebook loaded 0 examples when only a
            # SQuAD-style dict was accepted here, so the field is evidently not
            # (always) shaped that way. Presumably it is a list of answer
            # records with a "text" key -- TODO confirm against the dataset card.
            answer = _first_answer_text(example.get('answers'))

            if question and len(question) > 10 and len(answer) > 20:
                all_qa.append({
                    "input_text": f"Answer this legal question: {question}",
                    "target_text": answer[:400]  # Limit answer length
                })
                count += 1

        print(f"✓ Loaded {count} examples from LegalQAEval")

    except Exception as e:
        print(f"✗ Failed to load LegalQAEval: {e}")

    # Option 2: Try WikiQA (general Q&A that might include some legal)
    print("Attempting to load WikiQA dataset...")
    try:
        # Imported again so this source does not depend on the block above
        # having reached its own import.
        from datasets import load_dataset
        dataset = load_dataset("microsoft/wiki_qa", split="train")

        count = 0
        for example in dataset:
            if count >= 30:  # Limit to 30 examples
                break

            question = example.get('question', '')
            answer = example.get('answer', '')
            label = example.get('label', 0)

            # Only use positive examples of sufficient length.
            if not (label == 1 and question and answer and
                    len(question) > 10 and len(answer) > 30):
                continue

            # Join with a space so a keyword cannot match across the
            # question/answer boundary (e.g. "...l" + "aw...").
            searchable = f"{question} {answer}".lower()
            if any(keyword in searchable for keyword in legal_keywords):
                all_qa.append({
                    "input_text": f"Answer this legal question: {question}",
                    "target_text": answer[:400]
                })
                count += 1

        print(f"✓ Loaded {count} legal-related examples from WikiQA")

    except Exception as e:
        print(f"✗ Failed to load WikiQA: {e}")

    # Option 3: Enhanced synthetic legal Q&A (main source).
    # The helper prints its own progress message, so none is printed here
    # (previously the same line appeared twice in the output).
    synthetic_qa = create_enhanced_synthetic_legal_qa()
    all_qa.extend(synthetic_qa)

    print(f"Total Q&A examples: {len(all_qa)}")
    return all_qa

def create_enhanced_synthetic_legal_qa():
    """Build the hand-written legal Q&A training pairs.

    The pairs span several practice areas and are meant not to overlap with
    the evaluation questions (the evaluation list itself is defined elsewhere
    in the notebook).

    Returns:
        list[dict]: one dict per pair with ``"input_text"`` (the question
        prefixed with ``"Answer this legal question: "``) and ``"target_text"``
        (the reference answer).
    """

    print("Creating enhanced synthetic legal Q&A pairs...")

    instruction = "Answer this legal question: "

    # (question, reference answer) rows, grouped by practice area.
    qa_rows = [
        # Contract Law (different from test questions)
        (
            "What is the statute of frauds?",
            "The statute of frauds requires certain types of contracts to be in writing to be enforceable, including contracts for the sale of land, contracts that cannot be performed within one year, contracts for the sale of goods over a certain value, and suretyship agreements.",
        ),
        (
            "What is the parol evidence rule?",
            "The parol evidence rule prohibits the introduction of extrinsic evidence to contradict, vary, or add to the terms of a written contract that appears to be whole and complete. It aims to preserve the integrity of written agreements.",
        ),
        (
            "What is specific performance?",
            "Specific performance is an equitable remedy that requires a party to perform exactly what they promised in a contract, typically used when monetary damages would be inadequate, such as in real estate transactions or unique goods.",
        ),
        (
            "What is unconscionability in contract law?",
            "Unconscionability refers to a contract or contract term that is so unfair or oppressive that it shocks the conscience of the court. It can be procedural (unfair bargaining process) or substantive (unfair terms).",
        ),

        # Tort Law (different from test questions)
        (
            "What is the difference between assault and battery?",
            "Assault is the intentional creation of a reasonable apprehension of imminent harmful or offensive contact, while battery is the intentional harmful or offensive touching of another person. Assault requires fear; battery requires contact.",
        ),
        (
            "What is strict liability?",
            "Strict liability imposes legal responsibility for damages or injuries without requiring proof of negligence or intent to harm. It commonly applies to abnormally dangerous activities, defective products, and ownership of certain animals.",
        ),
        (
            "What is intentional infliction of emotional distress?",
            "Intentional infliction of emotional distress requires extreme and outrageous conduct that intentionally or recklessly causes severe emotional distress to another person. The conduct must exceed all bounds of decency.",
        ),

        # Criminal Law
        (
            "What is the difference between felonies and misdemeanors?",
            "Felonies are serious crimes typically punishable by imprisonment for more than one year or death, while misdemeanors are less serious offenses usually punishable by fines or imprisonment for less than one year.",
        ),
        (
            "What is habeas corpus?",
            "Habeas corpus is a legal action through which a person can seek relief from unlawful detention. It requires the person detaining someone to justify the detention before a court and is considered a fundamental safeguard of individual liberty.",
        ),
        (
            "What is the Miranda warning?",
            "The Miranda warning is a notification given to suspects in police custody before interrogation, informing them of their right to remain silent, that anything they say can be used against them, and their right to an attorney.",
        ),
        (
            "What is double jeopardy?",
            "Double jeopardy is a constitutional protection that prevents a person from being prosecuted twice for the same offense by the same sovereign after acquittal or conviction. It's guaranteed by the Fifth Amendment.",
        ),

        # Constitutional Law
        (
            "What is due process?",
            "Due process is the constitutional requirement that government actions depriving a person of life, liberty, or property must be fair and follow established legal procedures. It includes both procedural and substantive due process.",
        ),
        (
            "What is equal protection under the law?",
            "Equal protection requires that all persons in similar circumstances be treated alike by the law. It prohibits arbitrary discrimination and requires that government classifications be rationally related to legitimate government purposes.",
        ),

        # Evidence Law
        (
            "What is the exclusionary rule?",
            "The exclusionary rule prohibits the use of evidence obtained in violation of a defendant's constitutional rights, particularly evidence obtained through illegal searches and seizures in violation of the Fourth Amendment.",
        ),
        (
            "What is attorney-client privilege?",
            "Attorney-client privilege protects confidential communications between an attorney and client from disclosure, encouraging open and honest communication necessary for effective legal representation. The privilege belongs to the client.",
        ),
        (
            "What is hearsay in evidence law?",
            "Hearsay is an out-of-court statement offered to prove the truth of the matter asserted. It is generally inadmissible unless it falls within a recognized exception, such as excited utterances or business records.",
        ),

        # Property Law (different from test questions)
        (
            "What is eminent domain?",
            "Eminent domain is the power of the government to take private property for public use, provided that just compensation is paid to the property owner. It's based on the Fifth Amendment's Takings Clause.",
        ),
        (
            "What is a life estate?",
            "A life estate is a type of property ownership that lasts for the duration of a person's life. The holder has the right to use and occupy the property during their lifetime, but cannot transfer full ownership to others.",
        ),
        (
            "What is a restrictive covenant?",
            "A restrictive covenant is a legal agreement that limits how property can be used. It typically runs with the land and binds future owners, commonly used in residential developments to maintain property values and neighborhood character.",
        ),

        # Business Law
        (
            "What is a limited liability company (LLC)?",
            "An LLC is a business structure that combines elements of corporations and partnerships. It provides limited liability protection for owners (called members) while offering flexibility in management structure and tax treatment.",
        ),
        (
            "What is the business judgment rule?",
            "The business judgment rule protects corporate directors from liability for business decisions made in good faith, with due care, and in the honest belief that the action is in the corporation's best interests.",
        ),

        # Procedure
        (
            "What is a motion for summary judgment?",
            "A motion for summary judgment asks the court to rule in favor of one party without a trial, arguing that there are no genuine disputes of material fact and the moving party is entitled to judgment as a matter of law.",
        ),
        (
            "What is voir dire?",
            "Voir dire is the process of questioning prospective jurors to determine their qualifications and suitability to serve on a jury. It allows attorneys and judges to identify and remove biased or unsuitable jurors.",
        ),

        # Family Law (different from test questions)
        (
            "What is community property?",
            "Community property is a system where most property acquired during marriage is owned equally by both spouses, regardless of who earned it. It contrasts with common law property systems used in most states.",
        ),
        (
            "What is child support?",
            "Child support is financial assistance paid by a non-custodial parent to help cover the costs of raising a child. It typically covers basic needs like housing, food, clothing, medical care, and education expenses.",
        ),

        # Additional diverse questions
        (
            "What is bankruptcy?",
            "Bankruptcy is a legal process that allows individuals or businesses unable to pay their debts to seek relief from some or all of their obligations. It's governed by federal law and provides a fresh start for debtors.",
        ),
        (
            "What is intellectual property?",
            "Intellectual property refers to creations of the mind protected by law, including patents (inventions), copyrights (creative works), trademarks (brand identifiers), and trade secrets (confidential business information).",
        ),
        (
            "What is workers' compensation?",
            "Workers' compensation is a system of insurance that provides medical benefits and wage replacement to employees injured in the course of employment, in exchange for mandatory relinquishment of the employee's right to sue the employer for negligence.",
        ),
    ]

    # Expand each row into the record layout the tokenizer step expects.
    return [
        {"input_text": instruction + question, "target_text": answer}
        for question, answer in qa_rows
    ]

def fine_tune_qa_model():
    """Fine-tune FLAN-T5 Small on the examples from ``load_legal_qa_datasets()``.

    Steps: load the Q&A examples, split them 80/20 (seed 42), tokenize inputs
    and targets, train with ``Trainer`` for 5 epochs while evaluating every
    15 steps, then save the model and tokenizer to
    ``./t5-legal-qa-real-final``.

    Returns:
        tuple: ``(model, tokenizer)`` after training, or ``(None, None)`` when
        no training examples were loaded.
    """

    banner = "=" * 60
    print(banner)
    print("TRAINING Q&A MODEL ON REAL LEGAL DATA")
    print(banner)

    # Gather the Q&A examples (online datasets + synthetic pairs)
    qa_examples = load_legal_qa_datasets()
    print(f"Loaded {len(qa_examples)} Q&A training examples")

    if not qa_examples:
        print("No training data available! Exiting...")
        return None, None

    # Build the train/test split
    from datasets import Dataset
    splits = Dataset.from_list(qa_examples).train_test_split(test_size=0.2, seed=42)

    # Base checkpoint
    print("Loading T5 model...")
    checkpoint = "google/flan-t5-small"
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)

    def encode_batch(batch):
        """Tokenize a batch of prompts and attach the tokenized targets as labels."""
        encoded = tokenizer(
            batch["input_text"],
            max_length=512,
            truncation=True,
            padding=False
        )
        target_encoding = tokenizer(
            batch["target_text"],
            max_length=256,
            truncation=True,
            padding=False
        )
        encoded["labels"] = target_encoding["input_ids"]
        return encoded

    print("Tokenizing dataset...")
    # Drop the raw text columns so only tokenized fields reach the collator.
    tokenized_splits = splits.map(
        encode_batch,
        batched=True,
        remove_columns=splits["train"].column_names
    )

    # Hyperparameters; save_steps is a multiple of eval_steps, which
    # load_best_model_at_end relies on.
    training_args = TrainingArguments(
        output_dir="./t5-legal-qa-real",
        num_train_epochs=5,  # More epochs for Q&A
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        warmup_steps=30,
        weight_decay=0.01,
        logging_dir="./logs",
        logging_steps=5,
        eval_strategy="steps",
        eval_steps=15,
        save_steps=30,
        save_total_limit=2,
        load_best_model_at_end=True,
        report_to=[],
    )

    # Pads inputs and labels per batch (tokenization above used padding=False)
    seq2seq_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        model=model,
        padding=True
    )

    print("Creating trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["test"],
        tokenizer=tokenizer,
        data_collator=seq2seq_collator,
    )

    print("Starting Q&A training...")
    trainer.train()

    # Persist model and tokenizer side by side so the directory can be
    # reloaded with from_pretrained().
    print("Saving Q&A model...")
    final_dir = "./t5-legal-qa-real-final"
    trainer.save_model(final_dir)
    tokenizer.save_pretrained(final_dir)

    print("✓ Real Q&A model training complete!")
    return model, tokenizer

In [None]:
# ============================================================================
# PART 7: COMPREHENSIVE COMPARISON
# ============================================================================

def generate_response(model, tokenizer, question, max_length=200):
    """Generate one sampled answer to a legal question.

    The prompt uses the same ``"Answer this legal question: "`` prefix that
    the Q&A training examples are built with, so a fine-tuned model is
    queried in the format it was trained on.

    Args:
        model: seq2seq model exposing ``generate()``.
        tokenizer: tokenizer matching ``model``; must provide ``pad_token_id``
            and ``decode()``.
        question: question text without any prompt prefix.
        max_length: ``max_length`` passed to ``model.generate``.

    Returns:
        str: the decoded first generated sequence, special tokens removed.
        Sampling is enabled (``temperature=0.7``), so repeated calls may
        return different text.
    """
    formatted_prompt = f"Answer this legal question: {question}"
    inputs = tokenizer(formatted_prompt, return_tensors="pt", max_length=512, truncation=True)

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    # Put the inputs on the model's device so a GPU-resident model (e.g. one
    # handed over straight after training) does not fail with a device mismatch.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

def compare_all_three_models():
    """Run every question in ``test_questions`` through the three models.

    Loads the base FLAN-T5 Small checkpoint, the span-corruption checkpoint
    (``./t5-legal-mlm-final``) and the Q&A checkpoint
    (``./t5-legal-qa-real-final``), prints each model's answer per question,
    and writes all answers to ``three_model_comparison_real_data.json``.

    Relies on the notebook-level ``test_questions`` list and on
    ``generate_response``.

    Returns:
        dict: ``"questions"`` plus one answer list per model under
        ``"original"``, ``"span_corruption"`` and ``"qa_finetuned"``.
    """

    rule = "=" * 80
    print(rule)
    print("COMPREHENSIVE 3-MODEL COMPARISON")
    print(rule)

    print("Loading models...")

    # (result key, printed label, checkpoint), in the order the models are
    # loaded and queried: baseline, span corruption, Q&A fine-tuned.
    model_specs = (
        ("original", "ORIGINAL", "google/flan-t5-small"),
        ("span_corruption", "SPAN CORRUPTION", "./t5-legal-mlm-final"),
        ("qa_finetuned", "Q&A FINE-TUNED (REAL DATA)", "./t5-legal-qa-real-final"),
    )

    contenders = []
    for result_key, label, checkpoint in model_specs:
        loaded_model = T5ForConditionalGeneration.from_pretrained(checkpoint)
        loaded_tokenizer = T5Tokenizer.from_pretrained(checkpoint)
        contenders.append((result_key, label, loaded_model, loaded_tokenizer))

    total = len(test_questions)
    print(f"\nTesting all 3 models on {total} legal questions...")
    print(rule)

    all_responses = {"questions": test_questions}
    for result_key, _, _, _ in contenders:
        all_responses[result_key] = []

    for number, question in enumerate(test_questions, start=1):
        print(f"\nQuestion {number}/{total}: {question}")
        print("-" * 60)

        # Generate all three answers before recording or printing any of them.
        answers = [
            (result_key, label, generate_response(candidate, candidate_tokenizer, question))
            for result_key, label, candidate, candidate_tokenizer in contenders
        ]

        for result_key, _, answer in answers:
            all_responses[result_key].append(answer)

        for _, label, answer in answers:
            print(f"{label}: {answer}")
        print(rule)

    # Persist the answers for later evaluation
    import json
    results_path = "three_model_comparison_real_data.json"
    with open(results_path, "w") as f:
        json.dump(all_responses, f, indent=2)

    print(f"\n✓ All responses saved to '{results_path}'")
    print(f"✓ Model trained on legal Q&A data!")

    return all_responses

def create_claude_evaluation_prompt():
    """Turn the saved three-model comparison into a ranking prompt.

    Reads ``three_model_comparison_real_data.json`` from the working
    directory, formats the first five questions (or fewer if the file holds
    fewer) with each model's answer, writes the text to
    ``claude_evaluation_prompt.txt`` and returns it.

    Returns:
        str | None: the prompt text, or ``None`` (after printing instructions)
        when the comparison file does not exist.
    """

    import json
    import os

    comparison_path = "three_model_comparison_real_data.json"

    # Bail out early if the comparison step has not produced its file yet.
    if not os.path.exists(comparison_path):
        for message in (
            "❌ Comparison file not found!",
            "You need to run compare_all_three_models() first to generate the comparison data.",
            "Run this command: all_responses = compare_all_three_models()",
        ):
            print(message)
        return None

    with open(comparison_path, "r") as f:
        data = json.load(f)

    sections = [
        "Please evaluate these three AI responses to legal questions. For each question, rank the responses from best (1) to worst (3):\n"
        "\n"
        "Model A = Original FLAN-T5 (baseline)\n"
        "Model B = Span Corruption Fine-tuned\n"
        "Model C = Q&A Fine-tuned (trained on real legal data)\n"
        "\n"
        "Rate based on legal accuracy, completeness, and clarity.\n"
        "\n"
    ]

    # Use first 5 questions for manageable evaluation
    shown = min(5, len(data["questions"]))
    for idx in range(shown):
        number = idx + 1
        sections.append(f"QUESTION {number}: {data['questions'][idx]}\n\n")
        sections.append(f"Model A (Original): {data['original'][idx]}\n\n")
        sections.append(f"Model B (Span Corruption): {data['span_corruption'][idx]}\n\n")
        sections.append(f"Model C (Q&A Fine-tuned): {data['qa_finetuned'][idx]}\n\n")
        sections.append(f"Your ranking for Question {number} (1=best, 3=worst): A=_, B=_, C=_\n")
        sections.append("=" * 70 + "\n\n")

    prompt = "".join(sections)

    with open("claude_evaluation_prompt.txt", "w") as f:
        f.write(prompt)

    print("✓ Claude evaluation prompt saved to 'claude_evaluation_prompt.txt'")
    print("Copy this prompt to Claude to get rankings for DPO training!")

    return prompt

In [None]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Pipeline driver: span-corruption model (trained only if its output
    # directory is missing) -> Q&A fine-tuning -> three-way comparison.
    divider = "=" * 80

    print("LEGAL MODEL TRAINING AND COMPARISON PIPELINE")
    print(divider)

    # Step 1: reuse the span corruption checkpoint when it is already on disk
    if os.path.exists("./t5-legal-mlm-final"):
        print("Step 1: ✓ Span corruption model already exists")
    else:
        print("Step 1: Training span corruption model...")
        model, tokenizer = fine_tune_t5_on_legal_mlm()

    # Step 2: the Q&A model is retrained on every run
    print("\nStep 2: Training Q&A model...")
    qa_model, qa_tokenizer = fine_tune_qa_model()

    # Step 3: reload all three checkpoints and compare their answers
    print("\nStep 3: Comparing all three models...")
    all_responses = compare_all_three_models()

    # Step 4 is currently disabled:
    #print("\nStep 4: Creating Claude evaluation prompt...")
    #create_claude_evaluation_prompt()

    print("\n" + divider)
    for summary_line in (
        "PIPELINE COMPLETE! You now have:",
        "- Original FLAN-T5 baseline",
        "- Span corruption fine-tuned model",
        "- Q&A fine-tuned model",
        "- Comparison results for all 15 questions",
    ):
        print(summary_line)
    print(divider)

LEGAL MODEL TRAINING AND COMPARISON PIPELINE
Step 1: ✓ Span corruption model already exists

Step 2: Training Q&A model...
TRAINING Q&A MODEL ON REAL LEGAL DATA
Loading legal Q&A datasets...
Attempting to load LegalQAEval dataset...
✓ Loaded 0 examples from LegalQAEval
Attempting to load WikiQA dataset...
✓ Loaded 30 legal-related examples from WikiQA
Creating enhanced synthetic legal Q&A pairs...
Creating enhanced synthetic legal Q&A pairs...
Total Q&A examples: 58
Loaded 58 Q&A training examples
Loading T5 model...
Tokenizing dataset...


Map:   0%|          | 0/46 [00:00<?, ? examples/s]

Map:   0%|          | 0/12 [00:00<?, ? examples/s]

Creating trainer...
Starting Q&A training...


  trainer = Trainer(


Step,Training Loss,Validation Loss
15,3.3911,3.051852
30,3.2152,2.992949
45,3.1272,2.986279
60,2.8665,2.983747


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


Saving Q&A model...
✓ Real Q&A model training complete!

Step 3: Comparing all three models...
COMPREHENSIVE 3-MODEL COMPARISON
Loading models...

Testing all 3 models on 15 legal questions...

Question 1/15: What is the legal standard for establishing proximate cause in tort law?
------------------------------------------------------------
ORIGINAL: proximate grounds
SPAN CORRUPTION: a court case
Q&A FINE-TUNED (REAL DATA): Prohibition of proximate cause in tort law requires proximate cause to be found in tort law.

Question 2/15: Explain the difference between a warranty deed and a quitclaim deed in real estate transactions.
------------------------------------------------------------
ORIGINAL: Disclaims deeds are a form of legal notice with a resemblance to a disclaim statement.
SPAN CORRUPTION: If the deed is a quitclaim deed, it is a quitclaim deed.
Q&A FINE-TUNED (REAL DATA): A warranty deed is a deed which may either be a guarantee of money or a quitclaim deed in real estate tra