In [1]:
import os

paper_source_directory = '/home/louis/research/pdf_processor/processed_data/superconductivity_processed/physrevb.71.134526'
#paper_source_directory = '/home/louis/research/pdf_processor/processed_data/superconductivity_processed/physrevb.88.144511'
paper_source_directory = '/home/louis/research/pdf_processor/processed_data/superconductivity_processed/physrevb.86.214518'
file_name = 'text.txt'

with open(os.path.join(paper_source_directory, file_name)) as f:
    paper_text = f.read()


In [42]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import transformers
import torch

SYS_PROMPT = """You are a graduate student research assistant in physics. 
You are given the extracted parts of a long document and a question. Read the document and don't make up an answer."""

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

# use quantization to lower GPU usage                                                
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config
)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [47]:
def format_prompt(prompt, paper_text):
  PROMPT = f"Question: {prompt}\nContext: " + paper_text
  return PROMPT
    \

def generate(formatted_prompt):
  formatted_prompt = formatted_prompt[:16000] # to avoid GPU OOM                      
  messages = [{"role":"system","content":SYS_PROMPT}, {"role":"user","content":formatted_prompt}]
  # tell the model to generate                                                       
  input_ids = tokenizer.apply_chat_template(
      messages,
      add_generation_prompt=True,
      return_tensors="pt"
  ).to(model.device)

  print(input_ids)
  print(input_ids.shape)
  outputs = model.generate(
      input_ids,
      max_new_tokens=1024,
      eos_token_id=terminators,
      do_sample=True,
      temperature=0.6,
      top_p=0.9,
  )
  response = outputs[0][input_ids.shape[-1]:]
  return tokenizer.decode(response, skip_special_tokens=True)

In [48]:
generate(format_prompt("What is the material studied in this paper? Format the answer as MATERIAL: {Chemical Formula}. If there are multiple materials, separate them with commas.", paper_text))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


tensor([[128000, 128006,   9125,  ...,  78191, 128007,    271]],
       device='cuda:0')
torch.Size([1, 4334])


'MATERIAL: Sr0.9La0.1CuO2, Sr0.9Gd0.1CuO2, Gd.'

In [40]:
generate(format_prompt("What is the critical temperature at zero-field of the material studied in this paper? Just give a number and do not provide any explanation. The critical temperature is sometimes expressed as Tc, T_c, $T_c$, or $T_{c}$. Format the answer as CRITICAL TEMPERATURE: {Number} K. If there are multiple critical temperatures, separate them with commas.", paper_text))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


'CRITICAL TEMPERATURE: 43 K'

In [41]:
generate(format_prompt("What is upper critical field of the material studied in this paper? Format the answer as MAGNETIC FIELD: {Number} T", paper_text))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


'MAGNETIC FIELD: 160 T'