In [None]:
# ! pip install --upgrade vllm

In [None]:
import csv
import json
import math
import os
import re
from enum import Enum

import numpy as np
import pandas as pd
import torch
from openai import OpenAI
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

In [None]:
# from huggingface_hub import login
# HF_TOKEN = ""
# login(HF_TOKEN)  # This logs you in for the session. This is needed for some specific models that require some consent.

# import requests
# response = requests.get("https://huggingface.co")
# print(response.status_code)  # Should print 200 if the connection is successful

# Hugging Face repository id of the quantized Llama 3.1 70B Instruct checkpoint.
model_name = "neuralmagic/Meta-Llama-3.1-70B-Instruct-quantized.w4a16"

# `device` used to be passed to LLM() without ever being defined, which raised
# a NameError on a fresh kernel. Use the GPU whenever PyTorch can see one.
# NOTE(review): the "cpu" fallback presumably needs a CPU-capable vLLM install;
# confirm before relying on it for a model of this size.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Inference engine; prefix caching is enabled and the context is capped at
# 16384 tokens.
llm = LLM(
    model=model_name,
    device=device,
    max_model_len=16384,
    tensor_parallel_size=1,
    enable_prefix_caching=True,
)

# Tokenizer of the same checkpoint; later cells use it only to render the chat
# template into prompt strings.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
class DrugStopJSON(BaseModel):
    # Shape of the JSON object the model must return for each text.
    # Documented with `#` comments rather than a docstring so the schema
    # generated from this class stays exactly as it was.
    drug_stop_phrase: str      # sentence(s) saying the drug was stopped / never started
    reason_for_stopping: str   # the stated reason, if any
    drug_name: str             # name of the drug, if mentioned


# JSON schema derived from the model above; it constrains the decoder.
json_schema = DrugStopJSON.model_json_schema()

# Guided decoding against that schema, with temperature 0.
guided_decoding_params = GuidedDecodingParams(json=json_schema)
sampling_params = SamplingParams(
    temperature=0,
    max_tokens=16384,
    guided_decoding=guided_decoding_params,
)
# Unconstrained alternative kept for experiments:
# sampling_params = SamplingParams(temperature=0.5, max_tokens=1000)

# System turn: task description plus the required JSON keys.
# The prompt is assembled line by line so that the two trailing spaces several
# lines end with (they are part of the prompt text) stay visible and cannot be
# stripped by an editor.
system_prompt = dict(
    role="system",
    content=(
        'You are a highly intelligent medical expert. Your task is to examine Estonian medical texts written by doctors and extract cases where patients have stopped taking their medications or have never started taking them.  \n'
        '\n'
        'Your response must be a JSON object with the keys "drug_stop_phrase", "reason_for_stopping", and "drug_name":  \n'
        '- "drug_stop_phrase" must be the entire sentence or sentences that explicitly mention the patient stopped or never started the drug AND the reason why. Do not truncate this phrase. If necessary, include multiple sentences to ensure completeness. Do not select vague statements that only describe improvement or side effects without confirming that the patient discontinued the drug.  \n'
        "- \"reason_for_stopping\" should only include the reason why the patient stopped taking the medication, if mentioned. Do not include generic phrases like 'patsient lõpetas ravimi võtmise' that do not have an actual reason.\n"
        '- "drug_name" should only contain the name of the drug, if mentioned.  \n'
        '\n'
        'If the text does not explicitly state that the patient stopped or never started the medication, do not extract it.  \n'
        'Return the text exactly as it appears in the original input. Do not return anything else.'
    ),
)

# Few-shot example 1: a user turn and the assistant answer expected for it.
# The whitespace around the user text is part of the example as sent.
example_1_user = {"role": "user", "content": """
                  Patsient lõpetas xanaxi võtmise, kartis jääda sõltuvusse
                   
                                """}
example_1_assistant = {
    "role": "assistant",
    "content": '{"drug_stop_phrase": "lõpetas xanaxi võtmise, kartis jääda sõltuvusse", "reason_for_stopping": "kartis jääda sõltuvusse", "drug_name": "xanax"}',
}

# Conversation prefix shared by every request; each text is appended to it as
# the final user turn.
messages_template = [
    system_prompt,
    example_1_user,
    example_1_assistant,
    # Feel free to add more examples in a similar manner
]

# Three hand-written notes for a quick check of the prompt.
texts = [
    "Patsient lõpetas ise ravimite võtmise (MTX), kuna ei tundnud ravimitest efekti",
    "Patsient lõpetas ravimi võtmise.",
    "Haiguse anamnees: Haige toodi vastuvõtuosakonda, kus ta käitus agressiivselt. Olevat lõpetanud ravimite võtmise, kuna ei tundnud, et need aitaksid.",
]

# One conversation per note: the shared prefix followed by the note itself as
# the last user turn.
messages_list = [
    [*messages_template, {"role": "user", "content": note}]
    for note in texts
]

# Render each conversation to a prompt string. With tokenize=False the chat
# template returns text, not token ids.
# NOTE(review): confirm that the engine does not add a second
# beginning-of-text token when it tokenizes these already-templated strings.
prompts = [
    tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    for conversation in messages_list
]

# Generate for all rendered prompt strings in a single batched call.
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

# Show every input next to the first completion returned for it.
for i, (note, request_output) in enumerate(zip(texts, outputs)):
    print(f"Input {i}: {note}")
    print(f"Output {i}: {request_output.outputs[0].text}")


In [None]:
# If you want to run texts from a csv

texts_file = "path_to_csv_file.csv"

df = pd.read_csv(texts_file)

# Drop incomplete rows from the frame as a whole. Calling dropna() on each
# column separately (as before) shortens the two lists independently, so a
# single missing anamnesis would shift every later id onto the wrong text.
df_complete = df.dropna(subset=["anamnesis", "id"])

# Parallel lists: texts[k] is the anamnesis belonging to ids[k].
texts = df_complete["anamnesis"].tolist()
ids = df_complete["id"].tolist()
nr_of_texts = len(texts)

# The CSV writer further down pairs ids with outputs purely by position.
assert len(ids) == nr_of_texts, "texts and ids must stay aligned"

nr_dropped = len(df) - nr_of_texts
if nr_dropped:
    print(f"Skipped {nr_dropped} row(s) with a missing 'anamnesis' or 'id'.")

print(nr_of_texts)

In [None]:
# One conversation per text: the shared prefix plus the text as the final user
# turn. Every entry of `texts` is used; slice `texts` here for a smaller run.
messages_list = [
    [*messages_template, {"role": "user", "content": note}]
    for note in texts
]

# Render all conversations to prompt strings (tokenize=False returns text).
prompts = [
    tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    for conversation in messages_list
]
print(len(prompts))

# Generate for all prompt strings in one batched call
# (the progress bar seems to be bugged).
outputs = llm.generate(prompts=prompts, sampling_params=sampling_params)

In [None]:
output_dir = "output_folder/"
run_name = "llama_date_prompt1"

output_file = output_dir + run_name + ".csv"


def _as_json_cell(generated_text):
    """Return the generated text as a single-level JSON string for the CSV.

    The generated text is already JSON text. Passing it straight to
    ``json.dumps`` (as was done before) encodes it a second time, producing a
    quoted, backslash-escaped string that has to be decoded twice when read
    back. Valid JSON is parsed and re-serialized once instead; text that does
    not parse (for example a generation cut off by the token limit) is
    returned unchanged so nothing is lost.
    """
    try:
        return json.dumps(json.loads(generated_text), ensure_ascii=False)
    except json.JSONDecodeError:
        return generated_text


# Ids and outputs are matched purely by position, so the two sequences must
# have the same length. zip() would otherwise silently drop the surplus and
# could pair results with the wrong ids.
if len(ids) != len(outputs):
    raise ValueError(
        f"Cannot map results to ids: {len(ids)} ids but {len(outputs)} outputs."
    )

# First completion of every request, in input order.
output_list = [output.outputs[0].text for output in outputs]

# Create the target folder up front so the write cannot fail on a missing
# directory after the expensive generation step.
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

with open(output_file, mode="w", newline="", encoding="utf-8") as file:
    writer = csv.writer(file, quoting=csv.QUOTE_MINIMAL)

    # Writing header
    writer.writerow(["id", "result"])

    # One row per text: its id and the model's JSON answer.
    writer.writerows(zip(ids, map(_as_json_cell, output_list)))

print(f"CSV file '{output_file}' has been created successfully.")