In [2]:
import arxiv
import json
import numpy as np  
import os
import pprint
import random
import re
import string
import torch
import torch.nn.functional as F
import time
import warnings

from accelerate import infer_auto_device_map, init_empty_weights, Accelerator
from bs4 import BeautifulSoup
from collections import Counter, defaultdict
from sklearn.metrics import accuracy_score, f1_score
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from torch.profiler import profile, record_function, ProfilerActivity
from tqdm import tqdm
from typing import List, Dict, Tuple

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [4]:
def set_random_seeds(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.manual_seed_all(seed_value)  
        torch.backends.cudnn.deterministic = True
set_random_seeds()

In [5]:
def set_env_vars(fname='access_keys.json'):
    with open(fname) as f:
        keys = json.load(f)
        for key in keys:
            if key not in os.environ.keys():
                os.environ[key.upper()] = keys[key]

set_env_vars()

In [6]:
model_id = "meta-llama/Meta-Llama-3-70B-Instruct"

accelerator = Accelerator()

tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
# device_map = infer_auto_device_map(model, max_memory=max_memory)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=os.getenv("HF_TOKEN")
)

model = accelerator.prepare(model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 30/30 [00:23<00:00,  1.28it/s]


In [7]:
data = []
with open("test-results-astro-ph.jsonl", "r") as f:
    for line in f:
        json_obj = json.loads(line)
        data.append(json_obj)

with open("schema.json", "r") as f:
    schema = json.load(f)

with open("constituency_tests.json", "r") as f:
    constituency_tests = json.load(f)

examples = []
with open("examples.jsonl", "r") as f:
    for line in f:
        examples.append(json.loads(line))


In [8]:
schema_readable = ""
for key, value in schema.items():
    schema_readable += f"{key.upper()}: {value}\n"
print(schema_readable)

MODEL: a representation of a (scientific) phenomenon using mathematical formalism and/or computational simulation
TASK: a specific problem, objective or goal to be accomplished
DATASET: a collection of data, measurements or observations
FIELD: an academic (sub)discipline
MODALITY: a class or type of data/observations with similar or the same structure
METHOD: an approach, technique or procedure to complete a task
OBJECT: an entity that can be studied
PROPERTY: a quantitative or qualitative descriptor, or an inherent attribute of an entity, data, modality or method
INSTRUMENT: a device or system used for making measurements



In [14]:
def get_sentences(text: str) -> List[str]:
    # TODO: consider using sentence splitter from spacy, etc.
    return text.split(". ")

# TODO: review this function
def extract_all_tagged_phrases(text: str) -> Dict[str, List[str]]:
    soup = BeautifulSoup(text, "html.parser")
    tagged_phrases = defaultdict(list)

    # Recursive function to extract text from nested tags
    def extract_text(tag):
        if tag.name:
            full_text = ' '.join(tag.stripped_strings)
            tagged_phrases[tag.name].append(full_text)
            # Recursively process all children tags
            for child in tag.find_all(True):
                extract_text(child)

    for tag in soup.find_all(True):
        extract_text(tag)

    for tag in tagged_phrases:
        tagged_phrases[tag] = list(dict.fromkeys(tagged_phrases[tag]))
    
    return dict(tagged_phrases)


def generate_instructions(schema: dict, kind: str = "json") -> str:
    instruction_parts = [
        "The following schema is provided to tag the title and abstract of a given scientific paper as shown in the examples:\n"    
    ]
    if kind == "json":
        instruction_parts.append(f"{json.dumps(schema, indent=2)}\n\n")
    elif kind == "readable":
        readable_schema = ""
        for tag, description in schema.items():
            readable_schema += f"{tag}: {description}\n"
        instruction_parts.append(f"{readable_schema}\n")
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return "".join(instruction_parts)


def generate_demonstrations(examples: List[dict], kind: str = "json") -> str:
    demonstration_parts = []
    for example in examples:
        sentences = get_sentences(example["abstract"])
        tagged_sentences = get_sentences(example["tagged_abstract"])

        for sentence, tagged_sentence in random.sample(list(zip(sentences, tagged_sentences, strict=True)), k=3):
            tag_to_phrase = extract_all_tagged_phrases(tagged_sentence)
            if kind == "json":
                extractions = f"{json.dumps(tag_to_phrase, indent=2)}\n"
            elif kind == "readable":
                extractions = "".join(
                    f"{tag}: {', '.join(phrase)}\n"
                    for tag, phrase in tag_to_phrase.items()
                )
            else:
                raise ValueError(f"Invalid kind: {kind}")
            
            demonstration_parts.append(
                f"Sentence: {sentence}\n"
                f"Extractions:\n{extractions}\n"
            ) 

    return "".join(demonstration_parts)


def generate_input(sentence: str) -> str:
    return (
        f"Sentence: {sentence}\n"
        f"Extractions:\n"
    )
    

def generate_prompt(instructions: str, demonstrations: str) -> str:
    return (
        f"{instructions}"
        f"{demonstrations}"
    )
    

def generate_prediction(prompt: str, input: str, kind: str) -> dict:
    messages = [
                {"role": "system", "content": "You are an assistant who tags papers according to given schema and only returns the tagged phrases in the format as provided in the examples without repeating anything else."},
                {"role": "user", "content": prompt + input},
            ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        # add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=1200,
        eos_token_id=terminators,
        # num_beams=8,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    response = outputs[0][input_ids.shape[-1]:]
    prediction_response = tokenizer.decode(response, skip_special_tokens=True)
    predicted_tag = extract_prediction(prediction_response, kind=kind)

    return predicted_tag


def extract_prediction(prediction: str, kind: str = "json") -> dict:
    if kind == "json":
        json_match = re.search(r'\{[^}]+\}', prediction)
        if json_match:
            # TODO: Replace single quotes with double quotes in prompt and remove code below.
            json_str = json_match.group(0)
            json_str = re.sub(r"(?<![\w'])'|'(?![\w'])", '"', json_str)
            json_str = re.sub(r'}\s*"', '}, "', json_str)
            json_str = re.sub(r']\s*"', '], "', json_str)
            try:
                pred = json.loads(json_str)
            except json.JSONDecodeError as e:
                # TODO: Use the warning module here.
                print(f"Failed to parse JSON: {json_str}")
                print(f"Error: {str(e)}")
                pred = {}
    elif kind == "readable":
        match = re.findall(rf'^({"|".join(list(schema.keys()))}): (.+)$', prediction, flags=re.MULTILINE)
        pred = {
            tag: values.split(", ")
            for tag, values in match
        }
    else:
        raise ValueError(f"Invalid kind: {kind}")

    return pred    


def classify_predictions(gold: dict, pred: dict, union=False) -> Dict[str, float]:
    """
    Returns true positives, false positives, and false negatives for one example
    If union is True, then disregards the type of the tag and only considers the union of all tags
    """
    n_tp = 0
    n_fp = 0
    n_fn = 0
    if union:
        gold_phrases = set(phrase for phrases in gold.values() for phrase in phrases)
        pred_phrases = set(phrase for phrases in pred.values() for phrase in phrases)
        n_tp = len(gold_phrases & pred_phrases)
        n_fp = len(pred_phrases - gold_phrases)
        n_fn = len(gold_phrases - pred_phrases)
        return n_tp, n_fp, n_fn

    for tag in schema.keys():
        gold_phrases = set(gold.get(tag, []))   
        pred_phrases = set(pred.get(tag, []))
            
        n_tp += len(gold_phrases & pred_phrases)
        n_fp += len(pred_phrases - gold_phrases)
        n_fn += len(gold_phrases - pred_phrases)

    return n_tp, n_fp, n_fn


train = examples[:3]
valid = examples[3:]

kind = "readable"
prompt = generate_prompt(
        instructions=generate_instructions(schema, kind),
        demonstrations=generate_demonstrations(train, kind),
)

In [15]:
n_tp = 0
n_fp = 0
n_fn = 0
union_tp = 0
union_fp = 0
union_fn = 0
running_time = 0
pred_times = []
all_inputs = []
gold_tags = []
predicted_tags = []
start = time.time()
for example in tqdm(valid): 
    abstract = example["abstract"]
    tagged_abstract = example["tagged_abstract"]
    for sentence, tagged_sentence in tqdm(zip(get_sentences(abstract), get_sentences(tagged_abstract), strict=True)):
        input = generate_input(sentence)
        s_time = time.time()
        pred = generate_prediction(prompt, input, kind)
        e_time = time.time()
        gold = extract_all_tagged_phrases(tagged_sentence)
        tp, fp, fn = classify_predictions(gold, pred)
        n_tp += tp
        n_fp += fp
        n_fn += fn
        utp, ufp, ufn = classify_predictions(gold, pred, union=True)
        union_tp += utp
        union_fp += ufp
        union_fn += ufn

        running_time += time.time() - s_time
        pred_times.append(e_time - s_time)
        
        all_inputs.append(input)
        predicted_tags.append(pred)
        gold_tags.append(gold)


  0%|          | 0/17 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Ple

In [16]:
precision = round(n_tp / (n_tp + n_fp) if (n_tp + n_fp) > 0 else 0, 4)
recall = round(n_tp / (n_tp + n_fn) if (n_tp + n_fn) > 0 else 0, 4)
f1 = round(2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0, 4)
union_precision = round(union_tp / (union_tp + union_fp) if (union_tp + union_fp) > 0 else 0, 4)
union_recall = round(union_tp / (union_tp + union_fn) if (union_tp + union_fn) > 0 else 0, 4)
union_f1 = round(2 * (union_precision * union_recall) / (union_precision + union_recall) if (union_precision + union_recall) > 0 else 0, 4)
avg_time = round(sum(pred_times) / len(pred_times), 4)

metrics = {
    "precision": precision,
    "recall": recall,
    "f1": f1,
    "union_precision": union_precision,
    "union_recall": union_recall,
    "union_f1": union_f1,
    "avg_time_per_sentence": avg_time,
    "total_time": round(running_time, 4)
}
pprint.pprint(metrics)

{'avg_time_per_sentence': 2.8936,
 'f1': 0.3226,
 'precision': 0.3962,
 'recall': 0.2721,
 'total_time': 332.7966,
 'union_f1': 0.449,
 'union_precision': 0.5559,
 'union_recall': 0.3766}


In [17]:
uuid = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(8))
few_shot = "random"
fname = f"prompt_{few_shot}_{kind}_{uuid}.jsonl"

with open(os.path.join("randomized_few_shot_outputs", fname), "w", encoding="utf-8") as f:
    for input, gold_tag, pred_tag in zip(all_inputs, gold_tags, predicted_tags):
        f.write(f"Prompt:\n{prompt + input}\n")
        f.write(f"True Tag:\n{gold_tag}\n")
        f.write(f"Predicted Tag:\n{pred_tag}\n")
        f.write("#"*50 + "\n")

mname = f"metrics_{few_shot}_{kind}_{uuid}.json"
with open(os.path.join("randomized_few_shot_outputs", mname), "w") as f:
    json.dump({
        "metrics": metrics,
        "prompt_file": fname
    }, f, indent=4)

: 

In [13]:
# TODO: add time stats to the loop
# TODO: Review code / clean it up.
# TODO: Test more thoroughly.
# TODO: try sorting the tags in the schema and examples
# TODO: try including all tags in the examples even if they dont have phrases
# TODO: dedup instruction and system prompts