In [1]:
from transformers import (
    T5Tokenizer,
    AutoTokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    TrainingArguments,
)
import torch
from datasets import load_dataset
from wasabi import msg

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the experiment config.
# safe_load only builds plain YAML types (dicts, lists, strings, numbers), which is all a
# config file needs, and it cannot instantiate arbitrary Python objects from tags.
import yaml
with open('config/config_testing.yaml', encoding='utf-8') as f:
    config = yaml.safe_load(f)

In [3]:
# Number of examples kept per split for this quick test run.
N_SAMPLES = 500

dataset = load_dataset(
    config['dataset_vars']['type'],
    data_dir=config['dataset_vars']['dir'],
    column_names=config['dataset_vars']['column_names'],
)


def _skip_header_and_take(split, n_samples):
    """Drop row 0 and keep at most `n_samples` of the rows that follow.

    Row 0 contains the column names (presumably read in as data because
    `column_names` is supplied explicitly to `load_dataset` — TODO confirm).
    The upper bound is clamped to the split length so a split smaller than
    `n_samples + 1` rows yields fewer examples instead of raising IndexError.
    """
    stop = min(n_samples + 1, len(split))
    return split.select(range(1, stop))


dataset_train = _skip_header_and_take(dataset['train'], N_SAMPLES)
dataset_eval = _skip_header_and_take(dataset['validation'], N_SAMPLES)

In [4]:
# Use the first evaluation example as a fixed probe for comparing the model
# before and after fine-tuning.
first_eval_example = dataset_eval[0]
test_text = first_eval_example['input']
expected_output = first_eval_example['relations']

for heading, body in (("Input:", test_text), ("Expected output:", expected_output)):
    msg.info(heading)
    print(body)

[38;5;4mℹ Input:[0m
Tricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant. A newborn with massive tricuspid regurgitation, atrial flutter, congestive heart failure, and a high serum lithium level is described. This is the first patient to initially manifest tricuspid regurgitation and atrial flutter, and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy. Sixty-three percent of these infants had tricuspid valve involvement. Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy. It also causes neurologic depression, cyanosis, and cardiac arrhythmia when consumed prior to delivery.
[38;5;4mℹ Expected output:[0m
lithium carbonate @CHEMICAL@ neurologic depression @DISEASE@ @CID@ lithium carbonate @CHEMICAL@ cyanosis @DISEASE@ @CID@ lithium carbonate @CHEMICAL@ cardiac arrhythmia @DISEASE@ @CID@


In [5]:
# Load model and tokenizer
model_name = config['model_name']
# Put the whole model on GPU 0 when CUDA is available; otherwise fall back to CPU
# so the notebook still runs (slowly) on a machine without a GPU.
device_map = {"": 0} if torch.cuda.is_available() else {"": "cpu"}

# NOTE(review): trust_remote_code=True allows the model repo to execute its own code at
# load time. Stock T5 tokenizers do not need it, so consider dropping it.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, legacy=False)
# T5ForConditionalGeneration is used (rather than the bare encoder-decoder) because it
# includes the language-modeling head required for text generation.
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map=device_map,
)

In [6]:
# Baseline: generate with the pre-trained model before any fine-tuning.
# Inputs are moved to whatever device the model was placed on (not a hard-coded
# 'cuda'), and the attention mask is passed to generate() together with the input ids.
inputs = tokenizer(test_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

isosorbide dehydrogenase, and lithium carbonate toxicity..,.,,,,.,,.,.,.,.,.,,.,,,.,.,,,.,.,.,.,.,.,.,.,.,..,..s, cinq


In [7]:
# Load the fine-tuned weights saved by the training run.
# The checkpoint directory can be overridden via an optional `checkpoint_dir` config key;
# the default is the path this cell originally hard-coded.
CHECKPOINT_DIR = config.get('checkpoint_dir', "fine_tune_results/checkpoint-1200")

# NOTE: this rebinds `model`, so re-running the baseline inference cell after this one
# would use the fine-tuned weights. Run the notebook top to bottom.
model = T5ForConditionalGeneration.from_pretrained(
    CHECKPOINT_DIR,
    device_map=device_map,
    local_files_only=True,  # load from disk only; do not try to download from the hub
)

In [8]:
# Inference with the fine-tuned checkpoint on the same probe text as the baseline.
# Inputs are moved to whatever device the model was placed on (not a hard-coded
# 'cuda'), and the attention mask is passed to generate() together with the input ids.
inputs = tokenizer(test_text, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

lithium carbonate @CHEMICAL@ tricuspid regurgitation @DISEASE@ @CID@ lithium carbonate @CHEMICAL@ atrial flutter @DISEASE@ @CID@ lithium carbonate @CHEMICAL@ congenital heart disease @DISEASE@ @CID@
