In [None]:
import os
import pandas as pd
import torch
from huggingface_hub import login

from torchtune.models import convert_weights
from torchtune.models.llama3 import llama3_8b, llama3_tokenizer
from torchtune.training.checkpointing._checkpointer import safe_torch_load
from torchtune import utils
from transformers import AutoTokenizer
from torchtune.generation import generate

In [ ]:
def prompt_mimic(note):
    """Build the disease-extraction prompt for a single clinical report.

    Args:
        note: Report text, interpolated verbatim on the line after
            "**Report to Analyze:**".

    Returns:
        The full prompt string. It ends with an "Expected Output:" line; the
        parsing code in this notebook locates the model's answer by searching
        the decoded output for that exact marker, so keep the two in sync.

    NOTE(review): the template carries its 4-space source indentation, some
    trailing whitespace and a couple of grammar slips ("the patient have",
    "After output the list"). These are left untouched on purpose: if
    ``meta_model_0.pt`` was fine-tuned on this exact template (presumably —
    not verifiable here), editing the text would shift the input distribution.
    Change it only together with re-training.
    """
    return f"""
    You are given a clinical report. Your task is to identify the diseases that the patient have and list them. The result list should be formatted exactly as [disease1, disease2, ...] and be on a single line.

    Only use diseases from this list:
    ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema', 'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
    'Lung Opacity', 'Pleural Effusion', 'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices']

    Guidelines:
    Be careful about negation. E.g., do not include Pneumonia if the text says "no Pneumonia"
    Only include diseases that are certain. Entities that are 'likely', 'possibly' should not be returned. 
    After output the list, provide the evidence in the text for each disease.    
    
    **Report to Analyze:**
    {note}

    Expected Output:
    """

In [ ]:
# Prefer Apple-silicon GPU (MPS) when available, else run on CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Hugging Face tokenizer for the backbone (fetched from the Hub; the file imports
# `huggingface_hub.login`, presumably because this repo is gated).
backbone = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(backbone)

# torchtune Llama 3 8B module, filled with a local checkpoint whose keys are in
# Meta's format and are converted to torchtune's naming before loading.
# NOTE(review): the tokenizer is Llama 3.1, but `llama3_8b()` builds the Llama 3
# architecture (no 3.1 RoPE scaling). If the checkpoint was trained with
# torchtune's `llama3_1_8b`, build that instead — confirm against the training config.
model = llama3_8b()
model_state_dict = safe_torch_load("meta_model_0.pt")
model_state_dict = convert_weights.meta_to_tune(model_state_dict)
model.load_state_dict(model_state_dict)
model.to(device)
# Inference only: put the module in eval mode so no training-mode behaviour
# (e.g. dropout) is active during generation.
model.eval()

In [ ]:
def _resolve_stop_tokens(tok):
    """Return stop-token ids for ``generate`` from tokenizer ``tok``, or ``None``.

    torchtune tokenizers expose ``stop_tokens`` directly. Hugging Face tokenizers
    (what ``AutoTokenizer`` returns) do not, so fall back to ``eos_token_id`` plus
    the Llama 3 end-of-turn / end-of-text special tokens, skipping any the
    vocabulary does not define.
    """
    stop_tokens = getattr(tok, "stop_tokens", None)
    if stop_tokens:
        return list(stop_tokens)
    candidates = [getattr(tok, "eos_token_id", None)]
    if hasattr(tok, "convert_tokens_to_ids"):
        candidates.extend(
            tok.convert_tokens_to_ids(["<|eot_id|>", "<|eom_id|>", "<|end_of_text|>"])
        )
    unk_id = getattr(tok, "unk_token_id", None)
    # Unknown tokens come back as None or as the unk id; dict.fromkeys de-duplicates in order.
    ids = [i for i in dict.fromkeys(candidates) if isinstance(i, int) and i != unk_id]
    return ids or None


def _resolve_pad_id(tok):
    """Return a padding id: torchtune ``pad_id``, else HF ``pad_token_id``, else 0."""
    for attr in ("pad_id", "pad_token_id"):
        value = getattr(tok, attr, None)
        if isinstance(value, int):
            return value
    return 0


def _extract_answer_line(raw_output, marker="Expected Output:"):
    """Return the first non-blank line after ``marker`` (stripped), or ``""`` if none."""
    start = raw_output.find(marker)
    if start == -1:
        return ""
    lines = raw_output[start + len(marker):].strip().splitlines()
    return lines[0].strip() if lines else ""


def process_row(note):
    """Run the model on one clinical note and parse the disease list it produces.

    The note is wrapped with ``prompt_mimic``, tokenized and decoded with
    ``temperature=0``. The answer is the first line after the prompt's
    "Expected Output:" marker. Generation is attempted at most twice, and an
    attempt that raises is reported and still counts toward that limit.

    Args:
        note: Free-text clinical report.

    Returns:
        dict with keys ``note``, ``generated_disease`` (parsed answer line, ``""``
        if nothing could be parsed) and ``raw_output`` (decoded text of the last
        successful generation without special tokens, ``""`` if none succeeded).
    """
    prompt_str = prompt_mimic(note)

    tokens = tokenizer.encode(prompt_str)
    prompt = torch.tensor(tokens, dtype=torch.int, device=device)
    stop_tokens = _resolve_stop_tokens(tokenizer)
    pad_id = _resolve_pad_id(tokenizer)

    generated_disease = ""
    raw_output = ""
    max_attempts = 2
    attempts = 0
    # Retry while the answer is not a single bracketed list. With temperature=0
    # decoding is effectively greedy, so a retry most likely reproduces the same
    # text; it mainly helps when an attempt raised.
    while attempts < max_attempts and not (
        generated_disease.startswith("[") and generated_disease.endswith("]")
    ):
        # Count the attempt before running it so a failing attempt cannot loop forever.
        attempts += 1
        try:
            with torch.no_grad():
                outputs = generate(
                    model=model,
                    prompt=prompt,
                    max_generated_tokens=150,
                    temperature=0,
                    stop_tokens=stop_tokens,
                    pad_id=pad_id,
                )
            # `generate` is assumed to return (tokens, logits) with tokens shaped
            # [batch, seq], matching the `outputs[0][0]` indexing used elsewhere in
            # this notebook. Special tokens are dropped so a trailing <|eot_id|>
            # cannot end up on the answer line.
            raw_output = tokenizer.decode(outputs[0][0].tolist(), skip_special_tokens=True)
        except Exception as e:  # notebook boundary: report and try again
            print(f"Generation attempt {attempts} failed: {e}")
            continue
        generated_disease = _extract_answer_line(raw_output)

    result = {
        "note": note,
        "generated_disease": generated_disease,
        "raw_output": raw_output,
    }

    print("Generated diseases: " + generated_disease)
    return result

df = pd.read_csv("radreportx.csv")
print("CSV file read successfully.")

# Identifier and reference-label columns copied unchanged from each CSV row
# into that row's result dict (after the keys produced by `process_row`).
label_columns = (
    "study_id",
    "subject_id",
    "mimic_label",
    "negbio_label",
    "human_label1",
    "human_label2",
)

results = []

for index, row in df.iterrows():
    try:
        findings = row['note']
        print(f"Processing row {index}...")
        result = process_row(findings)
        for column in label_columns:
            result[column] = row[column]
        results.append(result)
    except Exception as e:
        # Report the failing row and continue with the next one.
        print(f"Error processing row {index}: {e}")
        # Echo the report text from the same 'note' column read above. `.get`
        # cannot raise inside this handler; a KeyError here would abort the loop.
        print(row.get('note'))

In [ ]:
# Sample radiology report used for the single-note walkthrough below.
# A plain (non-f) string: the text has no placeholders. Keep it unchanged,
# including its spacing and the repeated "has", so reruns stay comparable.
findings = """
FINDINGS: A cluster of heterogeneous opacities in the right lower lung has 
 has continued to grow since ___. 
   Otherwise, the lungs are clear. Moderate cardiomegaly, including severe left
 atrial enlargement is chronic; there is no pulmonary vascular congestion or
 edema. The thoracic aorta is heavily calcified.  There may be a new small,
 right pleural effusions or pneumothorax.
 IMPRESSION: Slowly progressive chronic right pneumonia, could be exogenous
 lipoid pneumonia, but tuberculosis is in the differential.  CT scanning
 recommended.  Nurse ___ and I discussed the findings and their
 clinical significance by telephone at the time of dictation.
"""

In [ ]:
# Start with no parsed answer; a later cell fills this in from the model output.
generated_disease = ''

# Wrap the sample `findings` in the extraction prompt and tokenize it for `generate`.
prompt_str = prompt_mimic(findings)
tokens = tokenizer.encode(prompt_str)
prompt = torch.tensor(
    tokens,
    device=device,
    dtype=torch.int,
)

In [ ]:
# Show the exact prompt text that was tokenized for the model.
print(prompt_str, end="\n")

In [ ]:
# Resolve the ids that end generation. torchtune tokenizers expose `stop_tokens`;
# the Hugging Face tokenizer loaded in this notebook does not, so fall back to its
# EOS id plus the Llama 3 end-of-turn / end-of-text special tokens. Tokens the
# vocabulary lacks come back as None or the unk id and are skipped.
stop_token_ids = getattr(tokenizer, "stop_tokens", None)
if not stop_token_ids:
    unk_id = getattr(tokenizer, "unk_token_id", None)
    candidate_ids = [getattr(tokenizer, "eos_token_id", None)]
    candidate_ids += list(
        tokenizer.convert_tokens_to_ids(["<|eot_id|>", "<|eom_id|>", "<|end_of_text|>"])
    )
    stop_token_ids = [
        tid for tid in dict.fromkeys(candidate_ids) if isinstance(tid, int) and tid != unk_id
    ] or None

with torch.no_grad():
    outputs = generate(
        model=model,
        prompt=prompt,
        # Same budget as `process_row`, so the evidence section is less likely to be cut off.
        max_generated_tokens=150,
        temperature=0,
        # Without stop tokens, decoding runs past the end of the answer and the
        # evidence slice picks up whatever the model emits afterwards.
        stop_tokens=stop_token_ids,
    )

In [ ]:
# Turn the token ids of the first (only) sequence back into text. `outputs[0]`
# is assumed to be the [batch, seq] token tensor returned by `generate`.
raw_output = tokenizer.decode(outputs[0][0].tolist())
print(raw_output)

In [ ]:
# Extract the disease list: the first line of text after the "Expected Output:"
# marker. The marker is the last line of the prompt, and `generate` appears to
# return prompt + completion, so the answer follows it directly.
# Reset first so a stale value from an earlier run is never printed.
generated_disease = ""
disease_start = raw_output.find("Expected Output:")
if disease_start != -1:
    answer_lines = raw_output[disease_start + len("Expected Output:"):].strip().splitlines()
    # An empty completion leaves nothing after the marker; avoid IndexError on [0].
    generated_disease = answer_lines[0].strip() if answer_lines else ""
print(generated_disease)

In [ ]:
# Extract the evidence text the prompt asks for after the list.
# Search only past the "Expected Output:" marker, so an "Evidence:" string inside
# the report itself is not matched. A missing label must be handled explicitly:
# str.find returns -1, and slicing from -1 + len("Evidence:") would print
# almost the whole output.
answer_start = max(raw_output.find("Expected Output:"), 0)
evidence = raw_output.find("Evidence:", answer_start)
if evidence == -1:
    generated_evidence = ""
    print('No "Evidence:" section found in the model output.')
else:
    generated_evidence = raw_output[evidence + len("Evidence:"):]
    print(generated_evidence)