In [None]:
import ast
import os
import re
from pprint import pprint
from tqdm import tqdm
from transformers import pipeline
import torch
from utils import load_dataset

HF_TOKEN = os.getenv("HF_ACCESS_TOKEN")

astro_reviews = load_dataset("data/sentence_segmented/Astro_Reviews.json")

DEVICE = 'cuda' if torch.cuda.is_available(
) else 'mps' if torch.mps.is_available() else 'cpu'
DEVICE = list(range(torch.cuda.device_count())) if DEVICE == 'cuda' else DEVICE
print(f"Using device: {DEVICE}")

data/sentence_segmented/Astro_Reviews.json: 996/996 have all required keys
352


In [None]:
# Use a pipeline as a high-level helper
pipe = pipeline("text-generation",
                model="meta-llama/Llama-3.2-3B-Instruct", 
                max_new_tokens=100,
                temperature=0.01,
                device=DEVICE,
                token=HF_TOKEN)

"""
pipe(messages) -> list of one item
[{'generated_text': [dicts]}]
each inner dict has {'content': ..., 'role': ...}
The 'response' is then
response[0]['generated_text'][-1]['content']
"""

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

"\npipe(messages) -> list of one item\n[{'generated_text': [dicts]}]\neach inner dict has {'content': ..., 'role': ...}\nThe 'response' is then\nresponse[0]['generated_text'][-1]['content']\n"

In [None]:
prompt = """
Extract and output ONLY the inline citations from the text below as a list of tuples
- Each citation becomes a (string, int) tuple where the string is the first author's name and the int is the year
- If there are no citations in the text, output []
- Do not count citations 'in preparation' or lacking a year
- Do not include any introductory text, explanations, or anything before or after the array

Examples of inline citations:
'''
Sentence: "Like Caffau et al. (2008a) , we have similar findings."
Output: [('Caffau et al.', 2008)]

Sentence: "Methods for mixing below the convection zone are well understood ( Brun, Turck-Chièze Zahn 1999 , Charbonnel Talon 2005 )."
Output: [('Brun', 1999), ('Charbonnel', 2005)]

Sentence: "Momentum balance gives an expression ( Fabian 1999 ; Di Matteo, Wilman Crawford 2002 ; King 2003 , 2005 )"
Output: [('Fabian', 1999), ('Di Matteo', 2002), ('King', 2003), ('King', 2005)]

Sentence: "In the early Universe, when the metal content was extremely low, enrichment by a single supernova could dominate preexisting metal contents (e.g., Audouse Silk 1995 ; Ryan, Norris Beers 1996 )."
Output: [('Audouse', 1995), ('Ryan', 1996)]

Sentence: "This is consistent with previous results (Pereira et al., in preparation)."
Output: []
'''

Now extract the inline citations from the following text:
'''
{text}
'''

Output format: 
[('first author', year), ('first author', year), ...]
"""
text_1 = 'neglect the H collisions altogether based on the available atomic physics data for other elements, while others use the classical Drawin (1968) formula, possibly with a scaling factor S H that typically varies from 0 to 1. Holweger (2001) found log ε O = 8.71 ± 0.05 using the Holweger Müller (1974) model with granulation corrections'
text_2 = 'AGN feedback features in many theoretical, numerical, and semianalytic simulations of galaxy growth and evolution (e.g., Kauffmann Haehnelt 2000 ; Granato et al. 2004 ; Di Matteo, Springel Hernquist 2005 ; Springel, Di Matteo Hernquist 2005 ; Bower et al. 2006 ; Croton et al. 2006 ; Hopkins et al. 2006 ; Ciotti, Ostriker Proga 2010 ; Scannapieco et al. 2012 ).'


def format_prompt_for_pipe(prompt, text):
    return [{"role": "user", "content": prompt.format(text=text)}]

def get_pipe_response(pipe, prompt, text):
    msg = format_prompt_for_pipe(prompt, text)
    res = pipe(msg)
    return res[0]['generated_text'][-1]['content']

In [67]:
import csv
LIST_PATTERN = re.compile(r"\[.*?\]")

class ParseResponseError(Exception):
    def __init__(self, match_group, exception):
        super().__init__(f"Error parsing response: {match_group}")
        self.match_group = match_group
        self.exception = exception

def parse_response(response):
    match = re.search(LIST_PATTERN, response)
    if not match:
        return []
    try:
        lst = ast.literal_eval(match.group())
        return lst
    except Exception as e:
        raise ParseResponseError(match.group(), e)

def citations_from_sentence(sentence):
    try:
        res = parse_response(get_pipe_response(pipe, prompt, sentence))
        if res == []:
            with open('no_citation_sentences.csv', 'a') as f:
                csv.writer(f).writerow([sentence])
        else:
            print('in nonempty branch')
            with open('citation_sentences.csv', 'a') as f:
                csv.writer(f).writerow([res, sentence])
    except ParseResponseError as e:
        print(e)
        with open('error_citation_sentences.csv', 'a') as f:
            csv.writer(f).writerow([e.match_group, sentence, e.exception])

In [None]:
for sentence in astro_reviews[0]['body_sentences'][:100]:
    citations_from_sentence(sentence)


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


in nonempty branch


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


in nonempty branch
Error parsing response: [('']


In [51]:
print(len(astro_reviews[42]['body_sentences']))

1044


In [56]:
results = []
with open('results.csv', 'a') as f:
    for sentence in tqdm(astro_reviews[42]['body_sentences']):
        try:
            citations = citations_from_sentence(sentence)
            f.write(f'{citations},{sentence}\n')
        except Exception as e:
            print(e)
            continue
# for sentence in astro_reviews[42]['body_sentences'][:10]:
#     print('Sentence:', sentence)
#     citations = citations_from_sentence(sentence)
#     print('Citations:', citations)
#     results.append((sentence, citations))
#     print('=====')

  0%|          | 0/1044 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  0%|          | 1/1044 [00:01<23:16,  1.34s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  0%|          | 2/1044 [00:01<15:39,  1.11it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  0%|          | 3/1044 [00:02<11:39,  1.49it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  0%|          | 4/1044 [00:02<09:52,  1.75it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  0%|          | 5/1044 [00:03<08:44,  1.98it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  1%|          | 6/1044 [00:04<16:04,  1.08it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  1%|          | 7/1044 [00:05<14:57,  1.16it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
  1%|          | 8/1044 [00:06<12:17,  1.40it/s]Setting 