# F-COREF Exploration
[Paper](https://arxiv.org/pdf/2209.04280)

## Basic Setup

In [1]:
# Automatic reloading
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

# Resolve the project root. In a notebook `os.path.abspath('')` is the kernel's
# current working directory, so `current_dir` is actually its PARENT and
# `project_dir` ends up two levels above the working directory.
# Guarded because os.chdir() below moves the working directory: without the
# guard, re-running this cell would climb two more levels every time.
if "project_dir" not in globals():
    current_dir = os.path.dirname(os.path.abspath(''))
    project_dir = os.path.abspath(os.path.join(current_dir, '..'))

# Make project packages (e.g. `src.*`) importable and resolve relative paths from
# the project root; skip the append if an earlier run already added it.
if project_dir not in sys.path:
    sys.path.append(project_dir)
os.chdir(project_dir)
os.getcwd()

'c:\\Users\\Ryan Lee\\Desktop\\AISG Internship\\rag'

In [3]:
import transformers

# Show which transformers release is installed: the fastcoref patch in the
# next section was written against 4.44.2.
transformers_version = transformers.__version__
print(transformers_version)

  from .autonotebook import tqdm as notebook_tqdm


4.44.2


## Patch Code

In [4]:
# pip install fastcoref
# fastcoref must be patched to work with our version of transformers (4.44.2):
# while a fastcoref model is being constructed, force
# attn_implementation='eager' in AutoModel.from_config.
# See https://github.com/shon-otmazgin/fastcoref/issues/59

import contextlib
import functools

from fastcoref import LingMessCoref as OriginalLingMessCoref
from fastcoref import FCoref as OriginalFCoref
from transformers import AutoModel


@contextlib.contextmanager
def _eager_attention_from_config():
    """Temporarily make ``AutoModel.from_config`` request eager attention.

    While the context is active, every ``AutoModel.from_config(config, ...)``
    call has ``attn_implementation='eager'`` forced into its keyword arguments
    (overriding any value the caller passed).

    On exit, including when the body raises, ``AutoModel`` is restored exactly:
    - If the class defined ``from_config`` itself, the original class-level
      attribute (descriptor) is put back.
    - Otherwise the override is deleted so the inherited one is visible again.

    Restoring the saved *bound* method instead, as a plain save/restore does,
    would replace the original descriptor with a method bound to AutoModel.
    """
    missing = object()
    # Raw class-level attribute (e.g. a classmethod object), if AutoModel has one.
    own_attr = vars(AutoModel).get("from_config", missing)
    # Bound method used to delegate the real work.
    original_from_config = AutoModel.from_config

    def patched_from_config(config, *args, **kwargs):
        kwargs["attn_implementation"] = "eager"
        return original_from_config(config, *args, **kwargs)

    # staticmethod so the replacement is never re-bound on attribute lookup.
    AutoModel.from_config = staticmethod(patched_from_config)
    try:
        yield
    finally:
        if own_attr is missing:
            del AutoModel.from_config
        else:
            AutoModel.from_config = own_attr


class PatchedLingMessCoref(OriginalLingMessCoref):
    """``LingMessCoref`` constructed while ``AutoModel.from_config`` is forced to eager attention.

    All constructor arguments are forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        with _eager_attention_from_config():
            super().__init__(*args, **kwargs)


class PatchedFCoref(OriginalFCoref):
    """``FCoref`` constructed while ``AutoModel.from_config`` is forced to eager attention.

    All constructor arguments are forwarded unchanged.
    """

    def __init__(self, *args, **kwargs):
        with _eager_attention_from_config():
            super().__init__(*args, **kwargs)


# Example construction (not run here):
# model1 = PatchedLingMessCoref(nlp="en_core_web_lg", device="cpu")
# model2 = PatchedFCoref(nlp="en_core_web_lg", device="cpu")

## Inference

In [5]:
# pip install fastcoref

# FCoref is the fast, distilled model; use the patched class so it builds on
# our transformers version.
model = PatchedFCoref(device='cpu')

texts = [
    'We are AISG. We are so happy to see you using the coref package. This package is very fast!',
    'Alice goes down the rabbit hole. Where she would discover a new reality beyond her expectations.',
    'Mary saw Susan at the park. She was playing with a frisbee. They then conversed.',
    'Alice went to the library because she wanted to borrow a book. She found a novel by Kenrick and decided to check it out. As Alice walked home, she bumped into her friend Clara, who asked her what she had borrowed. Alice showed it to Clara, and they talked about the author for a while.'
]

# Batch inference: one CorefResult per input text, in the same order.
preds = model.predict(texts=texts)
preds

11/22/2024 17:01:57 - INFO - 	 missing_keys: []
11/22/2024 17:01:57 - INFO - 	 unexpected_keys: []
11/22/2024 17:01:57 - INFO - 	 mismatched_keys: []
11/22/2024 17:01:57 - INFO - 	 error_msgs: []
11/22/2024 17:01:57 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M
11/22/2024 17:01:57 - INFO - 	 Tokenize 4 inputs...
Map: 100%|██████████| 4/4 [00:00<00:00, 117.63 examples/s]
11/22/2024 17:01:58 - INFO - 	 ***** Running Inference on 4 texts *****
Inference: 100%|██████████| 4/4 [00:00<00:00,  5.38it/s]


[CorefResult(text="We are AISG. We are so happy to see you using the ...", clusters=[['We', 'We'], ['the coref package', 'This package']]),
 CorefResult(text="Alice goes down the rabbit hole. Where she would d...", clusters=[['Alice', 'she', 'her']]),
 CorefResult(text="Mary saw Susan at the park. She was playing with a...", clusters=[['Susan', 'She']]),
 CorefResult(text="Alice went to the library because she wanted to bo...", clusters=[['Alice', 'she', 'She', 'Alice', 'she', 'her', 'her', 'she', 'Alice'], ['a novel by Kenrick', 'it', 'it'], ['her friend Clara, who asked her what she had borrowed', 'Clara'], ['Kenrick', 'the author']])]

In [6]:
# Clusters of the first text as character-offset spans (start, end) into the text;
# judging by the output, `end` is exclusive (e.g. (0, 2) covers "We").
first_result = preds[0]
first_result.get_clusters(as_strings=False)

[[(0, 2), (13, 15)], [(46, 63), (65, 77)]]

In [7]:
# The same clusters of the first text, rendered as mention strings.
first_result = preds[0]
first_result.get_clusters()

[['We', 'We'], ['the coref package', 'This package']]

## Remarks
- fastcoref must be patched to work with our version of transformers (4.44.2); see [fastcoref issue #59](https://github.com/shon-otmazgin/fastcoref/issues/59).
- `LingMessCoref` is the larger s2e model: it accepts longer inputs (Longformer: up to 4096 tokens) but is slower and has a larger memory footprint. In contrast, `FCoref` (a student model obtained via distillation) replaces Longformer with DistilRoBERTa, which is roughly 8 times faster than Longformer but accepts shorter inputs (512 tokens).
    - Longformer uses sliding-window attention, which reduces the cost of attention from $O(n^2)$ to $O(nw)$: linear in the sequence length $n$ for a fixed window size $w$.
- Some clusters look odd, e.g. ['her friend Clara, who asked her what she had borrowed', 'Clara']: the first mention is overly long, spanning the whole relative clause rather than just "her friend Clara".
- **TODO:** decide how to handle overlapping spans in the post-processing step.
- Clusters come with no 'representative value' (a canonical mention for the whole cluster); **TODO:** decide how to pick one (maybe have an LLM choose it?).
- Some failures, e.g. the model does not associate "We" with "AISG".

In [8]:
from src.components.coreference_models import FastCoreferenceModel

# Project wrapper (its logs match the fastcoref run above); compare its output
# with the raw fastcoref result for the same text.
# NOTE(review): Mention.char_idx looks end-inclusive ((0, 1) for "We"), unlike
# fastcoref's end-exclusive spans ((0, 2)) — confirm this is intended.
sample_text = texts[0]
my_model = FastCoreferenceModel(device="cpu")
my_model.predict(text=sample_text)

11/22/2024 17:02:02 - INFO - 	 missing_keys: []
11/22/2024 17:02:02 - INFO - 	 unexpected_keys: []
11/22/2024 17:02:02 - INFO - 	 mismatched_keys: []
11/22/2024 17:02:02 - INFO - 	 error_msgs: []
11/22/2024 17:02:02 - INFO - 	 Model Parameters: 90.5M, Transformer: 82.1M, Coref head: 8.4M
11/22/2024 17:02:02 - INFO - 	 Tokenize 1 inputs...
Map: 100%|██████████| 1/1 [00:00<00:00, 111.10 examples/s]
11/22/2024 17:02:02 - INFO - 	 ***** Running Inference on 1 texts *****
Inference: 100%|██████████| 1/1 [00:00<00:00, 16.42it/s]


[Cluster(mentions=[Mention(char_idx=(0, 1), content='We'), Mention(char_idx=(13, 14), content='We')]),
 Cluster(mentions=[Mention(char_idx=(46, 62), content='the coref package'), Mention(char_idx=(65, 76), content='This package')])]