# Relation extraction with LLMs

In this homework, we will explore the challenges and affordances of using LLMs for relation extraction (RE), and how we can evaluate LLM-based RE systems.

In [None]:
import torch
import numpy as np

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# use the 4B model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="cuda", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

In [None]:
def call_llm(prompt, system_prompt="You are a helpful assistant.", generation_config=None):  
    if generation_config is None:
        generation_config = {
            "max_new_tokens": 500,
            "temperature": 0.01
        }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # conduct text completion
    generated = model.generate(
        **model_inputs,
        **generation_config
    )

    # let's break this down:
    #                      | we take the first (and only) element of the batch (our batch size is 1)
    #                      |  |-----------------------------| skip our original input
    output_ids = generated[0][len(model_inputs.input_ids[0]):].tolist()

    # decode the generated token ids back into text
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

## Load data

We will be using the relationship triples you extracted during the in-class activity on Tuesday. These have been preprocessed to match each triple to a paragraph.

In [None]:
import pandas as pd

In [None]:
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/11.nlp/movie_relations.json -O movie_relations.json

In [None]:
def read_data(path):
    df = pd.read_json(path)
    df = df.sample(50, random_state=42)
    texts = df.paragraph_text.to_list()
    labels = df.triples.to_list()
    return texts, labels

texts, triples = read_data("./movie_relations.json")

## Setting up the LLM

**Question 1:** Come up with **at least two different prompts or prompting methods** to perform relation extraction using LLMs, based on the relation categories we defined in the lab activity on Tuesday. Your output should be relation triples. To enforce this, we provide a `RelationTriple` wrapper class for your output.

Here's an example from the dataset:

Input:
```
Aboard the space station, Peiqiang discovers that MOSS, the station's computer commander, has decided to abandon Earth and repurpose the station as an interstellar ark to seed a new planet with Earth's biosphere. Breaking out of forced hibernation, he is joined by fellow Russian cosmonaut Maxim Makarov, whom MOSS awakens to stop Liu. While spacewalking, Makarov is killed by the spacecraft's automated security measures. Liu enters the control room, but his attempts to override the evacuation procedures are revoked. Qi's group arrives at the Sulawesi Supply Depot to find that, while most engines around the planet have been restored, the combined thrust is insufficient to divert Earth's trajectory as it approaches Jupiter's Roche limit. MOSS broadcasts a final message to the world, but Peiqiang refuses to follow the computer's instructions.
```

Output (you will want to return a list of `RelationTriple`s):
```
<Liu Peiqiang,business,Maxim Makarov>
<Liu Peiqiang,nemeses,MOSS>
<Maxim Makarov,nemeses,MOSS>
```
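
LLM output often wraps the triples in extra text or uses inconsistent spacing, so it can help to pull out well-formed triples before constructing `RelationTriple` objects. The sketch below is one possible approach, not a required one: `extract_triple_strings`, `ALLOWED_RELATIONS`, and `TRIPLE_PATTERN` are hypothetical names, and lowercasing the relation label is an assumption about how lenient you want to be.

```python
import re

# Assumption: the same five labels as the `relations` list in this notebook.
ALLOWED_RELATIONS = {"family", "nemeses", "romantic", "friends", "business"}

# Matches <head,relation,tail>; fields may contain spaces but no commas or angle brackets.
TRIPLE_PATTERN = re.compile(r"<([^<>,]+),([^<>,]+),([^<>,]+)>")


def extract_triple_strings(raw_output: str) -> list[str]:
    """Return normalized '<head,relation,tail>' strings found in raw LLM output."""
    found = []
    for head, relation, tail in TRIPLE_PATTERN.findall(raw_output):
        head, relation, tail = head.strip(), relation.strip().lower(), tail.strip()
        # Silently drop triples whose relation is outside the label set.
        if relation in ALLOWED_RELATIONS:
            found.append(f"<{head},{relation},{tail}>")
    return found


raw = """Sure! Here are the relations:
<Liu Peiqiang, business, Maxim Makarov>
<Liu Peiqiang,nemeses,MOSS>
<Liu Peiqiang,coworkers,MOSS>"""

print(extract_triple_strings(raw))
```

Each returned string should then be accepted by `RelationTriple.from_triple`. Dropping unsupported relations silently hides model errors, so you may prefer to log them instead.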


In [None]:
relations = [
    "family",
    "nemeses",
    "romantic",
    "friends",
    "business"
]

In [None]:
class RelationTriple():
    def __init__(self, head, tail, relation):
        self.head = head
        self.tail = tail
        self.relation = relation

    @classmethod
    def from_triple(cls, triple: str):
        parts = triple.strip("<>").split(",")
        parts = [part.strip() for part in parts]
        if len(parts) != 3:
            raise ValueError(f"triple {triple} is malformed")
        head, relation, tail = parts
        if relation not in relations:
            raise ValueError(f"triple {triple} has unsupported relation {relation}")
        return cls(head, tail, relation)

    @classmethod
    def validate(cls, triple: str):
        parts = triple.strip("<>").split(",")
        parts = [part.strip() for part in parts]
        if len(parts) != 3:
            return False
        head, relation, tail = parts
        if relation not in relations:
            return False
        return True

    def __str__(self):
        return f"<{self.head},{self.relation},{self.tail}>"

    def __repr__(self):
        return f"<{self.head},{self.relation},{self.tail}>"

In [None]:
def generate_relations_one(text: str) -> list[RelationTriple]:
    # TODO: fill me in!
    pass

In [None]:
def generate_relations_two(text: str) -> list[RelationTriple]:
    # TODO: fill me in!
    pass

In [None]:
def run_on_data(fn, texts) -> list[list[RelationTriple]]:
    return [
        fn(text) for text in tqdm(texts)
    ]

In [None]:
first_outputs = run_on_data(generate_relations_one, texts)
second_outputs = run_on_data(generate_relations_two, texts)

## Evaluating output

In [None]:
def get_gold_labels(labels: list[list[str]]):
    return [
        [RelationTriple.from_triple(triple) for triple in paragraph if RelationTriple.validate(triple)] for paragraph in labels
    ]

In [None]:
gold_labels = get_gold_labels(triples)

### Strict matching

**Question 2:** **Implement the following functions** in order to calculate the precision / recall / F1 of your model output on both prompts.

`get_confusion_matrix` should return a `ConfusionMatrix` containing the number of false/true positives/negatives calculated for the gold and predicted labels for one paragraph. It should use `correct_fn` to compute whether two triples match.

You should calculate `precision`, `recall`, and `f1` over the gold and predicted labels for the entire list of paragraphs by adding up the confusion matrices for each one.
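
To make the counting concrete, here is a minimal sketch on plain tuples rather than `RelationTriple` objects. It assumes greedy one-to-one matching, which is only one reasonable choice: each gold triple can be matched by at most one prediction, so a duplicated prediction counts as a false positive. `count_matches` and `prf` are hypothetical helper names. True negatives are left out because, in open-ended extraction, there is no fixed set of candidate triples to count them over.

```python
def count_matches(gold, pred, correct_fn):
    """Greedy one-to-one matching; returns (tp, fp, fn)."""
    unmatched_gold = list(gold)
    tp = 0
    fp = 0
    for p in pred:
        # Index of the first still-unmatched gold item that this prediction matches.
        match_index = next(
            (i for i, g in enumerate(unmatched_gold) if correct_fn(g, p)), None
        )
        if match_index is None:
            fp += 1
        else:
            unmatched_gold.pop(match_index)
            tp += 1
    fn = len(unmatched_gold)  # gold items that no prediction matched
    return tp, fp, fn


def prf(tp, fp, fn):
    """Precision, recall, and F1, returning 0.0 when a denominator is zero."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


gold = [("A", "family", "B"), ("A", "friends", "C")]
pred = [("A", "family", "B"), ("B", "nemeses", "C"), ("A", "romantic", "C")]

tp, fp, fn = count_matches(gold, pred, lambda g, p: g == p)
print(tp, fp, fn)        # 1 2 1
print(prf(tp, fp, fn))   # precision 1/3, recall 1/2, F1 0.4
```

Summing `tp`, `fp`, and `fn` across paragraphs before computing the metrics gives micro-averaged scores, which is what adding up the per-paragraph confusion matrices amounts to.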

In [None]:
def strict_correct_fn(gold: RelationTriple, pred: RelationTriple) -> bool:
    return gold.head == pred.head and gold.relation == pred.relation and gold.tail == pred.tail

In [None]:
class ConfusionMatrix():
    def __init__(self, tp=0, fp=0, tn=0, fn=0):
        self.tp = tp
        self.fp = fp
        self.tn = tn
        self.fn = fn

    def __add__(self, other):
        return ConfusionMatrix(
            self.tp + other.tp,
            self.fp + other.fp,
            self.tn + other.tn,
            self.fn + other.fn,
        )

    def to_numpy(self):
        return np.array([self.tp, self.fp, self.fn, self.tn])

def get_confusion_matrix(gold: list[RelationTriple], pred: list[RelationTriple], correct_fn) -> ConfusionMatrix:
    pass

In [None]:
def precision(confusion_matrix: ConfusionMatrix) -> float:
    pass

In [None]:
def recall(confusion_matrix: ConfusionMatrix) -> float:
    pass

In [None]:
def f1(confusion_matrix: ConfusionMatrix) -> float:
    pass

In [None]:
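def summed_confusion_matrix(gold: list[list[RelationTriple]], pred: list[list[RelationTriple]], correct_fn) -> ConfusionMatrix:
    # sum() starts from the int 0 by default, which cannot be added to a ConfusionMatrix,
    # so we pass an empty ConfusionMatrix as the start value
    return sum(
        [get_confusion_matrix(g, p, correct_fn) for g, p in zip(gold, pred)],
        ConfusionMatrix(),
    )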

### LLM-as-judge

**Question 3:** Use the LLM to adjudicate the output by **implementing the `llm_correct_fn`**, then computing new precision, recall, and F1 scores.
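
One way to structure a judge is to separate building the prompt from parsing the model's reply, so that each piece can be tested without loading the model. The sketch below is illustrative only: `build_judge_prompt` and `parse_verdict` are hypothetical helpers, and the prompt wording is an assumption you should adapt and test.

```python
def build_judge_prompt(gold_triple: str, pred_triple: str) -> str:
    """Build a yes/no grading prompt for one (gold, predicted) pair of triples."""
    return (
        "You are grading a relation extraction system.\n"
        f"Gold triple: {gold_triple}\n"
        f"Predicted triple: {pred_triple}\n"
        "Do the two triples describe the same relation between the same two characters? "
        "Treat name variants (e.g. a first name versus a full name) as the same character.\n"
        "Answer with exactly one word: yes or no."
    )


def parse_verdict(reply: str) -> bool:
    """Return True only if the reply's first word is 'yes' (ignoring case and punctuation)."""
    words = reply.strip().lower().split()
    return bool(words) and words[0].strip(".,!:;\"'") == "yes"


prompt = build_judge_prompt("<Liu Peiqiang,nemeses,MOSS>", "<Peiqiang,nemeses,MOSS>")
print(prompt)
print(parse_verdict("Yes."), parse_verdict("No, they differ."))
```

In the notebook, the reply would come from something like `call_llm(build_judge_prompt(str(gold), str(pred)))`. Treating any reply that does not start with "yes" as a non-match is a conservative default, which may undercount matches if the model answers verbosely.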

In [None]:
def llm_correct_fn(gold: RelationTriple, pred: RelationTriple) -> bool:
    pass

### Evaluating the evaluation

**Question 4:** For each of the evaluation methods (strict matching and LLM-as-judge), sample 10 false positives and 10 false negatives. What proportion of these are incorrectly evaluated? **In a few sentences,** compare the evaluation methods and reflect on the challenges and potential methods for evaluating relation extraction.
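
For the sampling step, a reproducible draw makes it easier to revisit the same examples later. This is a minimal sketch assuming you have already collected the false positives (or false negatives) into a list; `sample_for_audit` and `error_rate` are hypothetical helper names, and the judgements themselves still come from your own manual inspection.

```python
import random


def sample_for_audit(items, k=10, seed=0):
    """Draw up to k items without replacement, reproducibly for a given seed."""
    rng = random.Random(seed)
    items = list(items)
    return rng.sample(items, min(k, len(items)))


def error_rate(judgements):
    """Fraction of audited items where the metric's decision was wrong.

    `judgements` is a list of bools filled in by hand: True means the
    evaluation method got that item wrong.
    """
    return sum(judgements) / len(judgements) if judgements else 0.0


# Placeholder items standing in for collected false positives.
false_positives = [f"fp_{i}" for i in range(25)]
audit_set = sample_for_audit(false_positives, k=10, seed=42)
print(len(audit_set))

# Placeholder judgements, not real results.
print(error_rate([True, False, False, True]))
```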