In [None]:
import os

# GPU selection. CUDA reads CUDA_VISIBLE_DEVICES when it initializes, so this must be
# set before torch/vllm are imported further down. vLLM's tensor_parallel_size is later
# set to torch.cuda.device_count(), so e.g. '0,1' gives 2-way tensor parallelism.
CUDA_DEVICES = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_DEVICES
# Verbose NCCL logging (presumably only useful when more than one GPU is visible).
os.environ['NCCL_DEBUG'] = 'TRACE'
# Uncomment to surface CUDA errors at the failing kernel (slows execution):
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
from rosemary import jpt_setup
jpt_setup()

import ast
import json
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import vllm
from jinja2 import Template
from transformers import AutoModelForCausalLM, AutoTokenizer

# from inference import create_prompt_with_llama2_chat_format
from extract_pathology import vllm_generate, metrics_xr_pathologies_iou, map_to_canonical_names, get_pathology_confidence

print(torch.cuda.is_available(), torch.cuda.device_count())

!nvidia-smi

In [None]:
from extract_pathology import get_argparse
from rosemary import jpt_parse_args


# Model selection. Each preset lists only what differs from DEFAULT_MODEL_SETTINGS,
# so switching `model_name` can never leak a setting (e.g. gpu_memory_utilization)
# from a previously selected model.
DEFAULT_MODEL_SETTINGS = {'quantization': None, 'gpu_memory_utilization': .9}
MODEL_PRESETS = {
    'llama-3-8b': {
        'model_name_or_path': 'results/baselines/unsloth/llama-3-8b',
        'max_model_len': 8192,
        'use_chat_template': False,
    },
    'gemma-2b': {
        'model_name_or_path': 'results/baselines/google/gemma-2b',
        'max_model_len': 8192,
        'use_chat_template': False,
        'gpu_memory_utilization': .7,
    },
}
model_name = 'gemma-2b'

model_settings = {**DEFAULT_MODEL_SETTINGS, **MODEL_PRESETS[model_name]}
quantization = model_settings['quantization']
gpu_memory_utilization = model_settings['gpu_memory_utilization']
model_name_or_path = model_settings['model_name_or_path']
max_model_len = model_settings['max_model_len']
use_chat_template = model_settings['use_chat_template']

icl_example_file = 'prompts/classify_pe/examples_v1.json'
test_label_file = 'prompts/classify_pe/test_set.json'
prompt_template = 'prompts/classify_pe/prompt_icl_simple_instruct.j2'
output_dir = os.path.join('results/classify_pe_from_ct_report/', f'{model_name}')


# CLI-style argument string; the trailing backslashes join it into a single line and
# optional flags collapse to empty strings when disabled.
cmd = f"""
--test_label_file {test_label_file} \
--icl_example_file {icl_example_file} \
--prompt_template {prompt_template} \
--model_name_or_path {model_name_or_path} \
--max_model_len {max_model_len} \
{'--quantization ' + quantization if quantization else ''} \
--gpu_memory_utilization {gpu_memory_utilization} \
{'--use_chat_template' if use_chat_template else ''} \
--max_tokens 256 \
--torch_dtype float16 \
--output_dir {output_dir}
"""

print(cmd)

parser = get_argparse()
args = jpt_parse_args(parser, cmd)
args


In [None]:
# Create the results directory up front (no error if it already exists).
os.makedirs(name=args.output_dir, exist_ok=True)

In [None]:
# Load the model with vLLM. Prefix caching is enabled so the shared in-context-example
# prefix can be reused across prompts during generation.
model = vllm.LLM(
    # Weights and tokenizer both come from the same checkpoint directory.
    model=args.model_name_or_path,
    tokenizer=args.model_name_or_path,
    tokenizer_mode="auto",
    # Allows executing custom modeling code shipped with the checkpoint.
    trust_remote_code=True,
    # Precision and memory budget.
    dtype=(getattr(torch, args.torch_dtype) if args.torch_dtype else 'auto'),
    quantization=args.quantization,
    gpu_memory_utilization=args.gpu_memory_utilization,
    max_model_len=args.max_model_len,
    # Shard across every GPU left visible by CUDA_VISIBLE_DEVICES.
    tensor_parallel_size=torch.cuda.device_count(),
    enable_prefix_caching=True,
)
# Tokenizer loaded by vLLM; used below to count prompt tokens and apply chat templates.
tokenizer = model.get_tokenizer()

In [None]:
# Number of leading entries in the ICL example file that are left out of the prompt.
# NOTE(review): the reason for dropping the first 2 examples is not recorded — document it.
N_SKIP_ICL_EXAMPLES = 2

with open(args.icl_example_file, 'r') as f:
    examples = json.load(f)

print([x['label'] for x in examples])
examples = examples[N_SKIP_ICL_EXAMPLES:]
print([x['label'] for x in examples])

# Fixed seed so the in-context example order, and therefore every prompt, is reproducible.
random.seed(0)
random.shuffle(examples)
print(f"#In-context examples: {len(examples)}")

# Note: `prompt_template` is rebound here from the template path to the compiled Template.
with open(args.prompt_template, 'r') as f:
    prompt_template = Template(f.read())
prompt_prefix = prompt_template.render(examples=examples)
print(f"Prompt prefix (including icl examples) #Tokens: {len(tokenizer(prompt_prefix)['input_ids'])}")

In [None]:
# Ground-truth test set; each record is read below for its 'report' and 'accession_number'.
with open(args.test_label_file) as fh:
    data_true = json.load(fh)

print(f"Test set size: {len(data_true)}")

In [None]:
# One prompt per test report: the in-context examples followed by the test report,
# rendered with the same template used for the prefix above.
prompts = [
    prompt_template.render(examples=examples + [{'report': example['report']}])
    for example in data_true
]

if args.use_chat_template:
    # Chat templates take a *conversation* (a list of message dicts), not a single
    # message dict, so wrap each prompt in a one-turn user conversation.
    conversations = [[{'role': 'user', 'content': x}] for x in prompts]
    if 'llama-2' in args.model_name_or_path.lower():
        # Imported lazily: only needed for llama-2 chat models.
        # NOTE(review): assumes this helper accepts a list of messages — confirm its signature.
        from inference import create_prompt_with_llama2_chat_format
        prompts = [create_prompt_with_llama2_chat_format(x, tokenizer) for x in conversations]
    else:
        prompts = [tokenizer.apply_chat_template(x, tokenize=False, add_generation_prompt=True) for x in conversations]

print(prompts[0] if prompts else '(no prompts: test set is empty)')

In [None]:
# Greedy decoding (temperature 0); each generation stops at "#####" or after max_tokens.
sampling_params = vllm.SamplingParams(
    max_tokens=args.max_tokens,
    stop=["#####"],
    temperature=0,
)

start = time.time()
# Warm-up on a single prompt: the shared prefix is only cached after a first batch has
# been processed, so the full run below can reuse it.
outputs = vllm_generate(prompts[:1], model, sampling_params)
outputs = vllm_generate(prompts[:], model, sampling_params)
elapsed = time.time() - start  # includes the warm-up call

print(f'model.generate() elapsed: {elapsed:.3f} s')

In [None]:
# Turn each raw generation into structured findings. With output_reference=True each
# row is a 3-tuple (reference, pathology, confidence); otherwise a 2-tuple
# (pathology, confidence).
output_reference = True


def _parse_generation(output, output_reference):
    """Parse one model generation into a list of finding dicts.

    The text is split on '- ' (bullet markers), and each non-blank chunk is parsed as a
    Python literal such as ("...", "...", "..."). Chunks with an odd number of double
    quotes (presumably rows cut off before finishing) are skipped, as are parsed rows
    that are not a tuple/list with the expected number of fields.

    Args:
        output: raw generated text for one report.
        output_reference: whether each row carries a leading 'reference' field.

    Returns:
        List of dicts keyed ('reference', 'pathology', 'confidence'), or
        ('pathology', 'confidence') when output_reference is False.

    Raises:
        ValueError / SyntaxError (from ast.literal_eval): a complete chunk is not a
        valid Python literal.
    """
    keys = ('reference', 'pathology', 'confidence') if output_reference else ('pathology', 'confidence')
    findings = []
    for chunk in output.split('- '):
        chunk = chunk.strip()
        # Skip blank chunks (e.g. whitespace before the first bullet) and unfinished rows.
        if not chunk or chunk.count('"') % 2 != 0:
            continue
        # literal_eval only accepts literals, so model output can never execute code.
        row = ast.literal_eval(chunk)
        if isinstance(row, (tuple, list)) and len(row) == len(keys):
            findings.append(dict(zip(keys, row)))
    return findings


# Prompts were built one per test example, so generations must line up 1:1.
assert len(outputs) == len(data_true), 'expected exactly one generation per test example'

data = []
for i, (example, output) in enumerate(zip(data_true, outputs)):
    acc = example['accession_number']
    try:
        output_formatted = _parse_generation(output, output_reference)
    except Exception as e:  # best effort: report this example and keep going
        print(f'==== output cannot be evaluated/formatted properly [{i}] {acc}\n')
        print(f'({type(e).__name__}: {e})')
        print(output)
        output_formatted = []

    data.append({
        'accession_number': acc,
        'report': example['report'],
        'output': output,
        'output_formatted': output_formatted,
    })

# NOTE: this count also includes generations that parsed cleanly but contained no rows.
print(f"#Examples that cannot be parsed from model generation: {sum(x['output_formatted']==[] for x in data)}")