<a href="https://colab.research.google.com/github/chandini2595/Stanford_Hackathon_AI_ArgumentCounter/blob/main/Solution/Lawgorithms_StanfordHackathon_HighlightTokens.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 🧠 Install dependencies

!pip install -q sentence-transformers gradio transformers


In [None]:
import gradio as gr
import json
import torch
import numpy as np
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModel

# Sentence-level encoder for matching: embeds whole argument texts so that a
# moving argument can be compared against every response argument.
model = SentenceTransformer("all-MiniLM-L6-v2")

# Token-level model for token embeddings: the tokenizer and bare transformer are
# used by get_token_embeddings() to obtain per-token hidden states.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
base_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

In [None]:
# Parsed content of the most recently uploaded JSON file (assigned by handle_file_upload).
brief_pairs = []
# Lookup rebuilt by handle_file_upload:
#   moving brief_id -> {"pair": <pair dict>, "headings": [<moving argument heading>, ...]}
moving_brief_index = {}

def combine_argument_text(arg):
    """Return the argument's heading and content as ``"<heading>. <content>"``."""
    pieces = (arg["heading"], ". ", arg["content"])
    return "".join(pieces)

def get_full_texts(pair):
    """Flatten both briefs of ``pair`` into one string each.

    Every argument is rendered with ``combine_argument_text`` and the results
    are joined with a single space.

    Returns:
        tuple: ``(full_moving_text, full_response_text)``.
    """
    def _flatten(brief_key):
        arguments = pair[brief_key]["brief_arguments"]
        return " ".join(map(combine_argument_text, arguments))

    return _flatten("moving_brief"), _flatten("response_brief")


In [None]:
def handle_file_upload(file_obj):
    """Load an uploaded brief-pairs JSON file and rebuild ``moving_brief_index``.

    Args:
        file_obj: Value delivered by the ``gr.File`` component. Handled forms:
            an object exposing the file path as ``.name``, a plain path
            string, or ``None`` (e.g. when the upload is cleared), in which
            case all state is reset.

    Returns:
        tuple: Five values matching the outputs wired to ``file_upload.change``:
        an update for the brief dropdown (choices = known moving brief ids),
        an update that empties the heading dropdown, and three empty strings
        that clear the result table and both highlighted views.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON. In both failure
            cases the previously loaded state is left untouched.
    """
    global brief_pairs, moving_brief_index

    if file_obj is None:
        loaded_pairs = []
    else:
        # Accept both file-like wrappers (path in `.name`) and bare path strings.
        path = getattr(file_obj, "name", file_obj)
        # Context manager guarantees the handle is closed; JSON is UTF-8 by specification.
        with open(path, encoding="utf-8") as fh:
            loaded_pairs = json.load(fh)

    brief_pairs = loaded_pairs

    moving_brief_index = {}
    for pair in brief_pairs:
        brief_id = pair["moving_brief"]["brief_id"]
        # First pair wins when the same moving brief id appears more than once.
        if brief_id not in moving_brief_index:
            moving_brief_index[brief_id] = {
                "pair": pair,
                "headings": [arg["heading"] for arg in pair["moving_brief"]["brief_arguments"]]
            }

    return gr.update(choices=list(moving_brief_index.keys()), value=None), gr.update(choices=[], value=None), "", "", ""


In [None]:
def get_headings_for_brief(brief_id):
    """Build the heading-dropdown update for the selected moving brief.

    The dropdown receives the headings stored for ``brief_id`` in
    ``moving_brief_index`` (or no choices when the id is unknown) and its
    current selection is cleared.
    """
    is_known = brief_id in moving_brief_index
    headings = moving_brief_index[brief_id]["headings"] if is_known else []
    return gr.update(choices=headings, value=None)

In [None]:
def get_token_embeddings(text):
    """Compute per-token hidden states for ``text`` with the token-level model.

    The text is tokenized with truncation at 512 tokens and the model is run
    without gradient tracking.

    Returns:
        tuple: ``(last_hidden_state, input_ids)``, each with dimension 0 squeezed.
    """
    encoded = tokenizer(text, truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        hidden_states = base_model(**encoded).last_hidden_state
    token_ids = encoded["input_ids"]
    return hidden_states.squeeze(0), token_ids.squeeze(0)

def get_top_contributing_words(moving_text, response_text, top_n=5):
    """Pick the moving-text tokens that best match some token of the response text.

    Each moving token is scored by its maximum cosine similarity against all
    response tokens; tokens are then taken in descending score order.

    Special tokens (``tokenizer.all_special_tokens``) are skipped *before*
    counting towards ``top_n``, so they cannot crowd real tokens out of the
    result.

    Args:
        moving_text: Text whose tokens are ranked.
        response_text: Text the moving tokens are compared against.
        top_n: Maximum number of tokens to return.

    Returns:
        list[str]: Up to ``top_n`` tokens as produced by the tokenizer (sub-word
        pieces may carry a ``##`` prefix). On any failure the error is printed
        and ``["❌ Token match error"]`` is returned instead (best effort).
    """
    try:
        mov_emb, mov_ids = get_token_embeddings(moving_text)
        res_emb, _ = get_token_embeddings(response_text)
        sim_matrix = util.pytorch_cos_sim(mov_emb, res_emb).cpu().numpy()
        # Best response-side match for every moving token.
        best_scores = np.max(sim_matrix, axis=1)
        tokens = tokenizer.convert_ids_to_tokens(mov_ids)
        special_tokens = set(tokenizer.all_special_tokens)

        top_tokens = []
        for i in best_scores.argsort()[::-1]:
            if len(top_tokens) >= top_n:
                break
            if tokens[i] in special_tokens:
                continue
            top_tokens.append(tokens[i])
        return top_tokens
    except Exception as e:
        print("🔴 Error in get_top_contributing_words():", e)
        return ["❌ Token match error"]

def highlight_top_tokens_in_text(text, top_tokens):
    """Return ``text`` as HTML with every occurrence of the top tokens wrapped in ``<mark>``.

    Matching is case-insensitive and the casing found in ``text`` is kept.
    All tokens are located in a single pass over the original text, so the
    markup inserted for one token can never be matched (and corrupted) by
    another token such as "mark" or "style". Everything taken from ``text`` is
    HTML-escaped because the result is rendered as HTML.

    Args:
        text: Plain text to highlight.
        top_tokens: Tokens to look for. A leading ``##`` (word-piece prefix)
            is stripped; tokens of two characters or fewer are ignored.

    Returns:
        str: HTML fragment.
    """
    import html
    import re

    needles = []
    seen = set()
    for tok in top_tokens:
        if tok.startswith("##"):
            tok = tok[2:]
        key = tok.lower()
        if len(tok) > 2 and key not in seen:
            seen.add(key)
            needles.append(tok)

    if not needles:
        return html.escape(text, quote=False)

    # Longest first so the alternation prefers the longest token at a position.
    needles.sort(key=len, reverse=True)
    pattern = re.compile("|".join(re.escape(tok) for tok in needles), re.IGNORECASE)

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        found = html.escape(match.group(0), quote=False)
        pieces.append(f"<mark style='background-color: #ffff99'>{found}</mark>")
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)


In [None]:
def render_html_table(data):
    """Render a list of rows as an HTML table.

    Args:
        data: Sequence of rows; ``data[0]`` is the header row and the remaining
            rows are body rows. Values are converted with ``str`` and
            HTML-escaped, so text coming from uploaded files cannot inject
            markup into the page.

    Returns:
        str: The table wrapped in a horizontally scrollable ``<div>``, or
        ``"<p>No results.</p>"`` when there is no body row.
    """
    from html import escape

    if not data or len(data) < 2:
        return "<p>No results.</p>"

    def _text(value):
        return escape(str(value), quote=False)

    parts = [
        "<div style='overflow-x:auto'><table style='width:100%; border-collapse: collapse;'>",
        "<thead><tr>",
    ]
    for col in data[0]:
        parts.append(
            "<th style='border: 1px solid #ccc; padding: 8px; background-color: #f0f0f0; text-align: left'>"
            f"{_text(col)}</th>"
        )
    parts.append("</tr></thead><tbody>")

    for row in data[1:]:
        parts.append("<tr>")
        for cell in row:
            parts.append(
                "<td style='border: 1px solid #ddd; padding: 8px; vertical-align: top; word-break: break-word'>"
                f"{_text(cell)}</td>"
            )
        parts.append("</tr>")

    parts.append("</tbody></table></div>")
    return "".join(parts)


In [None]:
def match_counter_arguments(brief_id, selected_heading, top_k=5):
    """Rank the response brief's arguments against one moving-brief argument.

    Args:
        brief_id: Moving brief id, a key of ``moving_brief_index``.
        selected_heading: Heading of the moving argument to match.
        top_k: Maximum number of response arguments to list.

    Returns:
        tuple: ``(table_html, highlighted_moving, highlighted_response)``.
        ``table_html`` lists the ``top_k`` response arguments with the highest
        cosine similarity to the selected argument; the other two values are
        the full moving / response brief texts with the top tokens highlighted.
        On a lookup or similarity failure ``table_html`` carries an error
        message and the other two values are empty strings. When either
        selection is ``None`` all three values are empty strings.
    """
    from html import escape

    # A cleared dropdown delivers None; that is "nothing selected", not an error.
    if brief_id is None or selected_heading is None:
        return "", "", ""

    if brief_id not in moving_brief_index:
        return "<p>Error: Brief not found.</p>", "", ""

    pair = moving_brief_index[brief_id]["pair"]
    moving_args = pair["moving_brief"]["brief_arguments"]
    response_args = pair["response_brief"]["brief_arguments"]

    moving_index = next((i for i, arg in enumerate(moving_args) if arg["heading"] == selected_heading), None)
    if moving_index is None:
        return "<p>Error: Heading not found.</p>", "", ""

    moving_text = combine_argument_text(moving_args[moving_index]).strip()
    response_texts = [combine_argument_text(arg).strip() for arg in response_args]

    try:
        moving_emb = model.encode([moving_text], convert_to_tensor=True)
        response_emb = model.encode(response_texts, convert_to_tensor=True)
        sim_scores = util.pytorch_cos_sim(moving_emb, response_emb)[0].cpu().numpy()
    except Exception as e:
        # The message is embedded in HTML, so escape it.
        return f"<p>Error computing similarity: {escape(str(e), quote=False)}</p>", "", ""

    # Indices of the response arguments, best match first.
    top_indices = sim_scores.argsort()[::-1][:top_k]

    result_table = [["#", "Response Brief ID", "Heading", "Match %", "Excerpt", "Top Tokens"]]

    # Token-level explanation is best effort: a failure here must not hide the
    # ranking. It is computed over the full briefs (not the selected argument);
    # get_token_embeddings() truncates each text to the model's token limit.
    try:
        full_moving_text, full_response_text = get_full_texts(pair)
        top_tokens = get_top_contributing_words(full_moving_text, full_response_text)
        highlighted_moving = highlight_top_tokens_in_text(full_moving_text, top_tokens)
        highlighted_response = highlight_top_tokens_in_text(full_response_text, top_tokens)
    except Exception as e:
        print("🔴 Error while highlighting tokens:", e)
        highlighted_moving = "<i>Error in highlight</i>"
        highlighted_response = "<i>Error in highlight</i>"
        top_tokens = ["❌"]

    excerpt_chars = 300
    for i, idx in enumerate(top_indices):
        resp_arg = response_args[idx]
        flat_content = resp_arg["content"].replace("\n", " ").strip()
        # Only add the ellipsis when something was actually cut off.
        if len(flat_content) > excerpt_chars:
            excerpt = flat_content[:excerpt_chars] + "..."
        else:
            excerpt = flat_content
        result_table.append([
            str(i + 1),
            pair["response_brief"]["brief_id"],
            resp_arg["heading"],
            f"{sim_scores[idx]*100:.2f}%",
            excerpt,
            ", ".join(top_tokens)
        ])

    return render_html_table(result_table), highlighted_moving, highlighted_response


In [None]:
# Build the Gradio UI. Components are declared top to bottom: file upload,
# a row with the two dropdowns, then the three HTML output areas.
with gr.Blocks() as demo:
    gr.Markdown("## ⚖️ Upload & Match Legal Brief Arguments with Explainability")
    gr.Markdown("Upload `brief_pairs.json`, select a moving brief and heading. View matched counter-arguments with <mark>highlighted top tokens</mark>.")

    file_upload = gr.File(label="📁 Upload `brief_pairs.json`")

    with gr.Row():
        dropdown_brief = gr.Dropdown(choices=[], label="📂 Select Moving Brief ID")
        dropdown_heading = gr.Dropdown(choices=[], label="📘 Select Argument Heading")

    # All three outputs are gr.HTML components (the `_md` suffix is historical naming only).
    output_table_html = gr.HTML(label="📊 Top Counter-Argument Matches")
    highlighted_moving_md = gr.HTML(label="🧾 Moving Brief (Highlighted)")
    highlighted_response_md = gr.HTML(label="📕 Response Brief (Highlighted)")

    # Upload -> rebuild the index, repopulate the brief dropdown, clear everything else.
    file_upload.change(fn=handle_file_upload, inputs=file_upload,
                       outputs=[dropdown_brief, dropdown_heading, output_table_html, highlighted_moving_md, highlighted_response_md])

    # Brief selected -> offer that brief's argument headings.
    dropdown_brief.change(fn=get_headings_for_brief, inputs=dropdown_brief, outputs=dropdown_heading)

    # Heading selected -> compute matches and highlighted texts.
    dropdown_heading.change(fn=match_counter_arguments,
                            inputs=[dropdown_brief, dropdown_heading],
                            outputs=[output_table_html, highlighted_moving_md, highlighted_response_md])

# Start the app server for the interface defined above.
demo.launch()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://2db0518797b901b325.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


