[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/dbamman/anlp25/blob/main/11.nlp/HW11_LLM_Coref.ipynb)

# HW11: Coreference with LLMs

In this homework, you will experiment with using LLMs for zero-shot or few-shot coreference resolution.

In [None]:
import torch

from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_brat.conll -O 1342_pride_and_prejudice_brat.conll
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/data/1342_pride_and_prejudice_sample.txt -O 1342_pride_and_prejudice_sample.txt
!wget https://raw.githubusercontent.com/dbamman/anlp25/refs/heads/main/11.nlp/coref_utils.py

## The coreference resolution task

We formulate the coreference resolution task as follows: given an input text, output a sequence of coref chains $(C_1, \ldots, C_n)$, each of which contains a sequence of coref mentions $C_i = (m_{i1}, \ldots, m_{ik_i})$ ordered by start index. Each coref mention is a tuple of start and end character indices $m_{ij} = (\text{start\_index}, \text{end\_index})$ denoting the span of the mention in the text.

To formalize this in the code, we set up the following classes:

```python
from dataclasses import dataclass, field


@dataclass
class CorefMention:
    start_idx: int
    end_idx: int


@dataclass
class CorefChain:
    mentions: list[CorefMention] = field(default_factory=list)


CorefOutput = list[CorefChain]
```
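
To make the indexing concrete, here is a minimal sketch on a made-up sentence. It repeats the class definitions so that it runs on its own, and it assumes that `end_idx` is exclusive, so that a mention can be recovered with an ordinary Python slice. The sentence and the offsets are illustrative only and are not taken from the evaluation data.

```python
from dataclasses import dataclass, field


@dataclass
class CorefMention:
    start_idx: int  # character offset where the mention starts
    end_idx: int    # character offset just past the mention (assumed exclusive)


@dataclass
class CorefChain:
    mentions: list[CorefMention] = field(default_factory=list)


# Made-up example sentence, not from the homework data.
toy_text = "Elizabeth smiled because she was pleased."

# One chain linking "Elizabeth" and "she", ordered by start index.
chain = CorefChain(mentions=[CorefMention(0, 9), CorefMention(25, 28)])

# Recover the surface strings by slicing the text with each mention's offsets.
spans = [toy_text[m.start_idx:m.end_idx] for m in chain.mentions]
print(spans)
```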

To avoid cluttering up the notebook, we provide a utility file with these types and other useful functions.

In [None]:
from coref_utils import CorefMention, CorefChain, CorefOutput

## Evaluation data

We will use data from [LitBank](https://github.com/dbamman/litbank), which contains coreference annotations for novels in the public domain. In particular, we will evaluate the coreference output on approximately 2,000 words of _Pride and Prejudice_ by Jane Austen.

Because most systems can't handle coreference on such long texts (and a long input would also strain the LLM's memory), we do inference on chunks of sentences instead. The `load_conll_data` function takes care of this for us.
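
As a rough illustration of what chunking by sentence means, the sketch below groups consecutive sentences into fixed-size chunks. The function name `chunk_sentences` and the `chunk_size` parameter are hypothetical, and this is not necessarily how `load_conll_data` splits the text. Whatever the splitting rule, mention offsets then have to be interpreted relative to the chunk that contains them.

```python
def chunk_sentences(sentences: list[str], chunk_size: int = 10) -> list[str]:
    """Group consecutive sentences into chunks of at most chunk_size sentences.

    Hypothetical helper for illustration; the real chunking lives in coref_utils.
    """
    chunks = []
    for i in range(0, len(sentences), chunk_size):
        # Join each group of sentences back into a single string.
        chunks.append(" ".join(sentences[i:i + chunk_size]))
    return chunks


example_chunks = chunk_sentences(["A.", "B.", "C."], chunk_size=2)
print(example_chunks)
```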

In [None]:
from coref_utils import load_conll_data

In [None]:
text, gold_coref_chains = load_conll_data("./1342_pride_and_prejudice_sample.txt", "./1342_pride_and_prejudice_brat.conll")

Let's examine the first text chunk as an example.

In [None]:
print(text[0])

In [None]:
for chain in gold_coref_chains[0][:10]:
    print("===")
    for mention in chain.mentions:
        print(f'"{text[0][mention.start_idx:mention.end_idx]}"')


## Baseline with `stanza`

Here, we will evaluate a baseline with the `stanza` coreference resolution system.

The $B^3$ precision and recall metrics are defined at the entity mention level. Following previous work, we evaluate mention detection and coreference separately:
- We calculate span F1 to measure the performance of mention detection.
- We calculate the $B^3$ metrics on only the mentions that are shared between the gold and system outputs.
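
The two metrics can be sketched as follows for a single chunk. This is a simplified illustration, not the implementation behind `evaluate`: it represents each mention as a plain `(start, end)` tuple, assumes every mention belongs to exactly one chain, and uses the hypothetical names `span_f1` and `b_cubed`.

```python
Span = tuple[int, int]
Chains = list[list[Span]]


def span_f1(gold: Chains, pred: Chains) -> float:
    """F1 over exact-match mention spans, ignoring chain membership."""
    gold_spans = {m for chain in gold for m in chain}
    pred_spans = {m for chain in pred for m in chain}
    overlap = len(gold_spans & pred_spans)
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_spans)
    recall = overlap / len(gold_spans)
    return 2 * precision * recall / (precision + recall)


def b_cubed(gold: Chains, pred: Chains) -> tuple[float, float]:
    """B^3 precision and recall, restricted to mentions found in both outputs."""
    gold_spans = {m for chain in gold for m in chain}
    pred_spans = {m for chain in pred for m in chain}
    shared = gold_spans & pred_spans
    if not shared:
        return 0.0, 0.0
    # Map each shared mention to its chain, keeping only shared mentions.
    gold_of = {m: set(c) & shared for c in gold for m in c if m in shared}
    pred_of = {m: set(c) & shared for c in pred for m in c if m in shared}
    # Average the per-mention overlap ratios over all shared mentions.
    precision = sum(len(gold_of[m] & pred_of[m]) / len(pred_of[m]) for m in shared) / len(shared)
    recall = sum(len(gold_of[m] & pred_of[m]) / len(gold_of[m]) for m in shared) / len(shared)
    return precision, recall


# Gold has two entities; the prediction wrongly merges them into one chain.
toy_gold = [[(0, 9), (25, 28)], [(40, 45)]]
toy_pred = [[(0, 9), (25, 28), (40, 45)]]
toy_precision, toy_recall = b_cubed(toy_gold, toy_pred)
```

In this toy case every span is detected, so span F1 is perfect, and merging the two entities costs $B^3$ precision while leaving recall untouched.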

In [None]:
from coref_utils import evaluate

# Usage:
# evaluate(gold, pred)
# Returns an EvaluationOutput object that contains B3 precision and recall, as well as span precision/recall/f1

Now let's implement and run the Stanza baseline.

In [None]:
## STANZA BASELINE
import stanza

def stanza_baseline(text_chunks: list[str]) -> list[CorefOutput]:
    """
    Run the stanza baseline on the text chunks, returning one list of chains per chunk.
    """
    pipe = stanza.Pipeline("en", processors="tokenize,lemma,pos,depparse,coref")
    results = []
    for text_chunk in tqdm(text_chunks):
        chains = []
        doc = pipe(text_chunk)
        for coref_chain in doc.coref:
            mentions = []
            for mention in coref_chain.mentions:
                span = doc.sentences[mention.sentence].words[mention.start_word:mention.end_word]
                mentions.append(CorefMention(start_idx=span[0].start_char, end_idx=span[-1].end_char))
            chains.append(CorefChain(mentions=mentions))
        results.append(chains)
    return results

baseline = stanza_baseline(text)

In [None]:
evaluate(gold_coref_chains, baseline)

## LLM prompting

How can we prompt an LLM to do this task? How can we post-process the output into the structure that we desire?
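
One post-processing step that many approaches are likely to need: an LLM returns text rather than character offsets, so mention strings have to be mapped back onto the chunk. The sketch below shows one possible way to do that with a whole-word, case-sensitive search. The helper name `find_mention_spans` is hypothetical, and deciding which of several matches the model meant is left open.

```python
import re


def find_mention_spans(text: str, mention: str) -> list[tuple[int, int]]:
    """Return (start, end) character offsets of every whole-word match of mention.

    Hypothetical helper; the end offset is exclusive, as in Python slicing.
    """
    # Lookarounds reject matches inside longer words (e.g. "the" in "other").
    pattern = r"(?<!\w)" + re.escape(mention) + r"(?!\w)"
    return [(m.start(), m.end()) for m in re.finditer(pattern, text)]


print(find_mention_spans("She said she knew her.", "she"))
```

The search is case-sensitive, so the sentence-initial "She" is not returned for the query "she"; whether to normalize case is a design choice to consider.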

In [None]:
# use the 4B model

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-4B", device_map="cuda", dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")

We use the same code to call the LLM as in previous assignments. Feel free to modify this if you wish.

In [None]:
def call_llm(prompt, system_prompt="You are a helpful assistant.", generation_config=None):  
    if generation_config is None:
        generation_config = {
            "max_new_tokens": 10,
            "temperature": 0.01
        }
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    
    # conduct text completion
    generated = model.generate(
        **model_inputs,
        **generation_config
    )

    # let's break this down:
    #                      | we take the element of the batch (our batch size is 1)
    #                      |  |-----------------------------| skip our original input
    output_ids = generated[0][len(model_inputs.input_ids[0]):].tolist()

    # decode the generated token ids back into text
    return tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

### Question 1

What is the simplest solution you can think of to do this task with an LLM? **In a few sentences**, explain your approach. Then, **implement this method** and evaluate its performance.

In [None]:
def llm_baseline(text_chunks) -> list[CorefOutput]:
    """For you to implement! Return one list of CorefChain objects per text chunk"""
    return [[CorefChain([])] for _ in text_chunks]

output = llm_baseline(text)
evaluate(gold_coref_chains, output)

### Question 2

What are some mistakes that you notice the system making? What are some improvements you can make? Name at least **two** improvements you want to test. Then, **implement these** and report how they affect the performance of the system. Feel free to experiment with more!

In [None]:
def llm_improved(text_chunks) -> list[CorefOutput]:
    """For you to implement! Return one list of CorefChain objects per text chunk"""
    return [[CorefChain([])] for _ in text_chunks]

output = llm_improved(text)
evaluate(gold_coref_chains, output)