# In-context learning for Citation Prediction

In [16]:
import dspy
import pandas as pd
import ast
import numpy as np
import os
from numpy.linalg import norm
from tqdm import tqdm
from pathlib import Path
# from operator import add
from PyPDF2 import PdfReader
from openai import OpenAI
from dspy.evaluate import Evaluate

## Get the test data

In [17]:
query_candidate_data = pd.read_csv('darwin/test.qrel.cid', sep=' ', header=None, names=['query', 'candidate', 'bool'])

In [18]:

with open('darwin/qpaper_to_emb', 'r') as f:
    query_papers = [line.strip() for line in f]

with open('darwin/cpaper_to_emb', 'r') as f:
    candidate_papers = [line.strip() for line in f]

print(f'len(query_papers): {len(query_papers)}')
print(f'len(candidate_papers): {len(candidate_papers)}')


len(query_papers): 115
len(candidate_papers): 637


In [19]:
valid_rows = pd.DataFrame()
query_dir = 'darwin/query_papers'
candidate_dir = 'darwin/candidate_papers'
# Iterate over the rows of the data
for _, row in query_candidate_data.iterrows():
    query_file = os.path.join(query_dir, str(row['query']) + '.pdf')
    candidate_file = os.path.join(candidate_dir, str(row['candidate']) + '.pdf')

    # Check if both files exist
    if os.path.isfile(query_file) and os.path.isfile(candidate_file):
        # If both files exist, append the row to valid_rows
        valid_rows = valid_rows._append(row)

# Reset the index of valid_rows
valid_rows.reset_index(drop=True, inplace=True)
print(valid_rows.head())
print(f'Number of query candidate pairs with valid files: {len(valid_rows)}')

     query candidate  bool
0  3498240   1824499     1
1  3498240  53645322     0
2  3498240   1915951     0
3  3498240   3048298     0
4  3498240   3627503     0
Number of query candidate pairs with valid files: 651


In [20]:
data = [{"query_file": query_file, "candidate_file": candidate_file, "cites": bool(bool_)} for query_file, candidate_file, bool_ in zip(valid_rows['query'], valid_rows['candidate'], valid_rows['bool'])]
data = [dspy.Example(**x).with_inputs('query_file', 'candidate_file') for x in data]

def split_data(data, split_ratio, seed=42):
    np.random.seed(seed)
    indices = np.random.permutation(len(data))
    split_index = int(split_ratio * len(data))
    train_indices = indices[:split_index]
    test_indices = indices[split_index:]
    trainset = [data[i] for i in train_indices]
    testset = [data[i] for i in test_indices]
    return trainset, testset

# trainset, testset = split_data(data, 0)
trainset = data


## Chunker

In [21]:
llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm, rm=None)

client = OpenAI(
    # this is also the default, it can be omitted
    api_key=os.environ['OPENAI_API_KEY'],
)

In [22]:
class Chunker:
    def __init__(self, context_window=3000, max_windows=5):
        self.context_window = context_window
        self.max_windows = max_windows
        self.window_overlap = 0.02

    def __call__(self, paper):
        snippet_idx = 0

        while snippet_idx < self.max_windows and paper:
            endpos = int(self.context_window * (1.0 + self.window_overlap))
            snippet, paper = paper[:endpos], paper[endpos:]

            next_newline_pos = snippet.rfind('\n')
            if paper and next_newline_pos != -1 and next_newline_pos >= self.context_window // 2:
                paper = snippet[next_newline_pos+1:] + paper
                snippet = snippet[:next_newline_pos]

            yield snippet_idx, snippet.strip()
            snippet_idx += 1

## DSPy Module

In [39]:
def get_embeddings(texts, model="text-embedding-3-small", save_file=None):
    if save_file and Path(save_file).exists():
        with open(save_file, 'r') as f:
            # print(f"Loading embeddings from {save_file}")
            embeddings = [ast.literal_eval(line.strip()) for line in f]
        return embeddings
        
    try:
        response = client.embeddings.create(input=texts, model=model)
        embeddings = [embedding.embedding for embedding in response.data]
        if save_file: # Save the embeddings to a file
            with open(save_file, 'w') as f:
                # print(f"Saving embeddings to {save_file}")
                for embedding in embeddings:
                    f.write(str(embedding) + '\n')
        return embeddings
    except Exception as e:
        print("Error during API call:", e)
        return []
    
def get_most_similar_chunk(query_embedding, candidate_embeddings, candidate_chunks):
    similarities = np.dot(candidate_embeddings, query_embedding) / (norm(candidate_embeddings, axis=1) * norm(query_embedding))
    most_similar_idx = np.argmax(similarities)
    return candidate_chunks[most_similar_idx]
    
    
class PredictCitation(dspy.Signature):
    __doc__ = """Predict if the two chunks are related by a citation. Consider all possible ways in which a citation could occur, such as direct quotes, paraphrasing, or referring to the same ideas or data. Don't be afraid to predict that the chunks are related by a citation. If you're not sure, it's better to predict that they are related."""   
    query_chunk: str = dspy.InputField(desc='Query chunk to compare to the candidate chunk.')
    candidate_chunk: str = dspy.InputField(desc='Candidate chunk to compare to the query chunk.')
    answer: bool = dspy.OutputField(desc="either True or False", prefix="Answer:")


class PredictCitationAndResolve(dspy.Module):
    def __init__(self, context_window=3000, max_windows=5, resolve_function=any,
                 candidate_folder='darwin/candidate_papers', query_folder='darwin/query_papers',
                 reset_embedding=False):
        super().__init__()
        
        self.chunk = Chunker(context_window=context_window, max_windows=max_windows)
        # self.predict = dspy.TypedPredictor(PredictCitation)
        # self.predict = dspy.TypedChainOfThought(PredictCitation)
        self.predict = dspy.ChainOfThought(PredictCitation)
        self.resolve_function = resolve_function
        self.query_folder = query_folder
        self.candidate_folder = candidate_folder
        os.makedirs('embeddings', exist_ok=True)
        if reset_embedding:
            for emb_file in os.listdir('embeddings'):
                os.remove(f'embeddings/{emb_file}')

    def forward(self, query_file, candidate_file):
        predictions = []
        
        # Get the text from the pdfs
        query_pdf = PdfReader(f'{self.query_folder}/{query_file}.pdf')
        query_text = ""
        for page in query_pdf.pages:
            page_text = page.extract_text()
            if page_text:
                query_text += page_text + " "  # Adding space to separate text between pages
        query_text = query_text.replace("\n", " ")
        
        candidate_pdf = PdfReader(f'{self.candidate_folder}/{candidate_file}.pdf')
        candidate_text = ""
        for page in candidate_pdf.pages:
            page_text = page.extract_text()
            if page_text:
                candidate_text += page_text + " "
        candidate_text = candidate_text.replace("\n", " ")
        
        # for each chunk in the paper
        query_chunks = [snippet for _, snippet in self.chunk(query_text)]
        candidate_chunks = [snippet for _, snippet in self.chunk(candidate_text)]
        
        # Create embeddings for the chunks
        candidate_embeddings = get_embeddings(candidate_chunks, save_file=f'embeddings/candidate_{candidate_file}.emb')
        query_embeddings = get_embeddings(query_chunks, save_file=f'embeddings/query_{query_file}.emb')
        
        for snippet, query_embedding in zip(query_chunks, query_embeddings):
            # Get the candidate chunk that is most similar to the snippet
            candidate_chunk = get_most_similar_chunk(query_embedding, candidate_embeddings, candidate_chunks)
            prediction = self.predict(query_chunk=snippet, candidate_chunk=candidate_chunk)
            # print(prediction)
            predictions.append(prediction.answer=='True')

        return dspy.Prediction(predictions=predictions, resolved=self.resolve_function(predictions))

In [40]:
pipeline_chunking = PredictCitationAndResolve(max_windows=15, context_window=1000, reset_embedding=False)

## Example

In [41]:
chunker = Chunker(context_window=1000, max_windows=15)
query_pdf = PdfReader(f'darwin/query_papers/1323414.pdf')
query_text = ""
for page in query_pdf.pages:
    page_text = page.extract_text()
    if page_text:
        query_text += page_text + " "  # Adding space to separate text between pages
query_text = query_text.replace("\n", " ")
query_chunks = [snippet for _, snippet in chunker(query_text)]
print(query_chunks)



In [42]:
print(len(query_chunks[0]))
print(len(query_chunks))

1020
15


In [43]:
# get an example
example = trainset[-2]
example_x = example.inputs()
example_y = example.labels()
print(example_x)
print(example_y)

prediction = pipeline_chunking(**example_x)
print(prediction)
print(example_y.cites)

Example({'query_file': 1323414, 'candidate_file': '3324808'}) (input_keys=None)
Example({'cites': False}) (input_keys=None)
Prediction(
    predictions=[False, False, False, True, False, False, False, False, True, False, True, True, False, False, False],
    resolved=True
)
False


In [44]:
llm.inspect_history(n=5)





Predict if the two chunks are related by a citation. Consider all possible ways in which a citation could occur, such as direct quotes, paraphrasing, or referring to the same ideas or data. Don't be afraid to predict that the chunks are related by a citation. If you're not sure, it's better to predict that they are related.

---

Follow the following format.

Query Chunk: Query chunk to compare to the candidate chunk.

Candidate Chunk: Candidate chunk to compare to the query chunk.

Reasoning: Let's think step by step in order to ${produce the answer}. We ...

Answer: either True or False

---

Query Chunk: the BrainProducts company. The same distribution of elec- trodes was used and the signals were registered with a sampling frequency of 500 Hz. Salazar-Varas et al. Journal of NeuroEngineering and Rehabilitation (2015) 12:101 Page 3 of 15 Inertial measurement units During the tests, kinematic information is also recorded in order to know when the subject has reacted to the obsta-

## Evaluate

In [45]:
def metric(example, result):
    '''Match metric'''
    return 1 if example.cites == result.resolved else 0

In [46]:
evaluate = Evaluate(devset=trainset, metric=metric, num_threads=8, display_progress=True, display_table=0, max_errors=100, return_outputs=True)
outputs = evaluate(pipeline_chunking)



Error for example in dev set: 		 negative seek value -1



Average Metric: 50.0 / 102  (49.0):  16%|█▌        | 102/651 [06:41<22:25,  2.45s/it]

Error during API call: Error code: 400 - {'error': {'message': "'$.input' is invalid. Please check the API reference: https://platform.openai.com/docs/api-reference.", 'type': 'invalid_request_error', 'param': None, 'code': None}}
Error for example in dev set: 		 shapes (0,) and (1536,) not aligned: 0 (dim 0) != 1536 (dim 0)


[A
  return v1_cached_gpt3_turbo_request_v2(**kwargs)



Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


Average Metric: 110.0 / 243  (45.3):  37%|███▋      | 242/651 [15:55<23:41,  3.48s/it][A

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


  return v1_cached_gpt3_turbo_request_v2(**kwargs)
  return v1_cached_gpt3_turbo_request_v2(**kwargs)

Average Metric: 136.0 / 313  (43.5):  48%|████▊     | 313/651 [20:26<17:48,  3.16s/it]

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


[A


Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


Average Metric: 164.0 / 393  (41.7):  60%|██████    | 392/651 [25:24<12:41,  2.94s/it][A

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)
Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)



Average Metric: 164.0 / 395  (41.5):  61%|██████    | 395/651 [25:27<06:57,  1.63s/it]

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


[A
Average Metric: 164.0 / 397  (41.3):  61%|██████    | 396/651 [25:27<05:56,  1.40s/it]

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


[A

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)




Error during API call: Error code: 400 - {'error': {'message': "'$.input' is invalid. Please check the API reference: https://platform.openai.com/docs/api-reference.", 'type': 'invalid_request_error', 'param': None, 'code': None}}
Error for example in dev set: 		 shapes (0,) and (1536,) not aligned: 0 (dim 0) != 1536 (dim 0)


  return v1_cached_gpt3_turbo_request_v2(**kwargs)


Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)




Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)




Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)



Average Metric: 216.0 / 519  (41.6):  80%|███████▉  | 519/651 [32:19<07:45,  3.53s/it]

Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


  return v1_cached_gpt3_turbo_request_v2(**kwargs)
  return v1_cached_gpt3_turbo_request_v2(**kwargs)
  return v1_cached_gpt3_turbo_request_v2(**kwargs)
  return v1_cached_gpt3_turbo_request_v2(**kwargs)


Error for example in dev set: 		 PyCryptodome is required for AES algorithm




Error for example in dev set: 		 negative seek value -1




Error for example in dev set: 		 leading zeros in decimal integer literals are not permitted; use an 0o prefix for octal integers (<unknown>, line 1)




Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)




Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)




Error for example in dev set: 		 unmatched ']' (<unknown>, line 1)




Error for example in dev set: 		 '[' was never closed (<unknown>, line 1)


Average Metric: 270.0 / 651  (41.5): 100%|██████████| 651/651 [40:21<00:00,  3.72s/it]

Average Metric: 270.0 / 651  (41.5%)





In [48]:
all_predictions = []
for x in outputs[1]:
    if type(x[1])==dspy.Prediction:
        all_predictions.append(x[1].resolved)
    else:
        all_predictions.append(np.nan)
    

all_labels = [x[0].cites for x in outputs[1]]
print(len(all_predictions))

with open('darwin/eval/predictions_COT_large_prompt_1000.txt', 'w') as f:
    for pred in all_predictions:
        f.write(str(pred) + '\n')

651


In [49]:
# Compute the accuracy of the final predictions
correct_predictions = [prediction == label for prediction, label in zip(all_predictions, all_labels)]
accuracy = sum(correct_predictions) / len(correct_predictions)
print(f'Accuracy: {accuracy:.2f}')

# Compute the recall of the final predictions
true_positives = sum([prediction and label for prediction, label in zip(all_predictions, all_labels)])
false_negatives = sum([not prediction and label for prediction, label in zip(all_predictions, all_labels)])
recall = true_positives / (true_positives + false_negatives)
print(f'Recall: {recall: .2f}')

# Compute the precision of the final predictions
true_positives = sum([prediction and label for prediction, label in zip(all_predictions, all_labels)])
false_positives = sum([prediction and not label for prediction, label in zip(all_predictions, all_labels)])
precision = true_positives / (true_positives + false_positives)
print(f'Precision: {precision:.2f}')

# F1 score
f1 = 2 * (precision * recall) / (precision + recall)
print(f'F1 Score: {f1:.2f}')

Accuracy: 0.41
Recall:  0.70
Precision: 0.18
F1 Score: 0.29


In [33]:
all_predictions

[False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 True,
 False,
 False,
 nan,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 nan,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False,
 False

Weird paper

In [34]:
PdfReader('darwin/query_papers/53079158.pdf').pages[-1].extract_text()

'237TomasMikolov,QuocVLe,andIlyaSutskever.2013b.\nExploitingsimilaritiesamonglanguagesformachine\ntranslation. CoRR.\nTomas Mikolov, Ilya Sutskever, Kai Chen, Gregory S.\nCorrado, and Jeffrey Dean. 2013c. Distributed rep-\nresentations of words and phrases and their compo-\nsitionality. In NIPS.\nDavid Milne and Ian H. Witten. 2008. An effective,\nlow-cost measure of semantic relatedness obtained\nfromwikipedialinks. In AAAI.\nMike Mintz, Steven Bills, Rion Snow, and Daniel Ju-\nrafsky.2009. Distantsupervisionforrelationextrac-\ntionwithoutlabeleddata. In ACL/IJCNLP .\nAditya Mogadala and Achim Rettinger. 2016. Bilin-\ngual word embeddings from parallel and non-\nparallel corpora for cross-language text classiﬁca-\ntion. In HLT-NAACL .\nThien Huu Nguyen, Nicolas Fauceglia, Mariano Ro-\ndriguez Muro, Oktie Hassanzadeh, Alﬁo Massimil-\nianoGliozzo,andMohammadSadoghi.2016. Joint\nlearning of local and global features for entity link-\ningvianeuralnetworks. In COLING.\nSebastian Ruder, Iva

In [35]:
llm.inspect_history(n=5)





Predict if the two chunks are related by a citation. Consider all possible ways in which a citation could occur, such as direct quotes, paraphrasing, or referring to the same ideas or data. Don't be afraid to predict that the chunks are related by a citation. If you're not sure, it's better to predict that they are related.

---

Follow the following format.

Query Chunk: Query chunk to compare to the candidate chunk.
Candidate Chunk: Candidate chunk to compare to the query chunk.
Answer: either True or False

---

Query Chunk: s/article/19/16/2088/242445 by guest on 05 April 2024 S.Oba et al. Factorscores x=(x1,...,xK)fortheexpressionvector y are obtained by minimization of the residual error: err=/vextenddouble/vextenddouble/vextenddoubleyobs−Wobsx/vextenddouble/vextenddouble/vextenddouble2 . Thisisawell-knownregressionproblem,andtheleastsquare solution is given by x=(WobsTWobs)−1WobsTyobs. Using x, the missing part is estimated as ymiss=Wmissx.( 2 ) InthePCregressionabove, Wshou