In [None]:

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorWithPadding

from datasets import DatasetDict, load_dataset

import torch
from torch.utils.data import DataLoader
import gc

import os
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent.parent))

from src import paths

from itertools import chain

import pandas as pd

import tqdm

from typing import Tuple
from vllm import LLM, SamplingParams


# Checkpoint directory loaded by main().
MODEL_PATH = paths.MODEL_PATH/'llama2-chat'
# Weight format for load_model_and_tokenizer: "4bit" (bitsandbytes NF4) or "bfloat16".
QUANTIZATION = "4bit"

# Dataset split to classify: "train", "validation", "test", or "all" (the three concatenated).
SPLIT = "train"

# Llama-2-chat style prompt; the model's answer is generated right after {answer_init}.
BASE_PROMPT = "<s>[INST]\n<<SYS>>\n{system_prompt}\n<</SYS>>\n\n{user_prompt}[/INST]\n\n{answer_init}"
SYSTEM_PROMPT = "Is the MS diagnosis in the text of type \"Sekundär progrediente Multiple Sklerose (SPMS)\", \"primäre progrediente Multiple Sklerose (PPMS)\" or \"schubförmig remittierende Multiple Sklerose (RRMS)\"?"
# Backward-compatible alias for the original (misspelled) constant name.
SYSTEM_PROMP = SYSTEM_PROMPT
ANSWER_INIT = "Based on the information provided in the text, the most likely diagnosis for the patient is: "
# Input texts are cut to this many *characters* (not tokens) before being inserted into the prompt.
TRUNCATION_SIZE = 300

# Inference / generation settings used by main().
BATCH_SIZE = 4
DO_SAMPLE = False
NUM_BEAMS = 1
MAX_NEW_TOKENS = 20
# NOTE(review): with DO_SAMPLE=False, NUM_BEAMS=1 and PENALTY_ALPHA=0.0, transformers' generate
# decodes greedily, so TEMPERATURE / TOP_P / TOP_K presumably have no effect — confirm.
TEMPERATURE = 1
TOP_P = 1
TOP_K = 4
PENALTY_ALPHA = 0.0

def check_gpu_memory():
    """Print the name and memory statistics of every visible CUDA device.

    Sizes are divided by 1024**3 before printing. When CUDA is not
    available, a single "No GPU available." line is printed instead.
    """
    if not torch.cuda.is_available():
        print("No GPU available.")
        return

    gib = 1024 ** 3
    for device in range(torch.cuda.device_count()):
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        name = torch.cuda.get_device_properties(device).name
        allocated_bytes = torch.cuda.memory_allocated(device)
        reserved_bytes = torch.cuda.memory_reserved(device)
        print(f"GPU {device}: {name}")
        print(f"   Total Memory: {total_bytes / gib:.2f} GB")
        print(f"   Free Memory: {free_bytes / gib:.2f} GB")
        print(f"   Allocated Memory : {allocated_bytes / gib:.2f} GB")
        print(f"   Reserved Memory : {reserved_bytes / gib:.2f} GB")



In [None]:
# Quick vLLM smoke test on a few generic prompts (independent of the MS-Diag pipeline).
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Sampling settings for this smoke test only; main() uses the generation constants instead.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create the vLLM engine. An empty `model=` argument is a SyntaxError, so point it at MODEL_PATH.
# NOTE(review): quantization="AWQ" only works if MODEL_PATH holds AWQ-quantized weights — confirm,
# or point this at an AWQ checkpoint.
# NOTE(review): vLLM pre-allocates GPU memory for its KV cache; this engine stays alive while main()
# loads a second copy of the model, which may exhaust GPU memory.
llm = LLM(model=str(MODEL_PATH), quantization="AWQ")
# generate returns a list of RequestOutput objects holding the prompt and its completions.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

In [None]:

# Load Model and tokenizer

def load_model_and_tokenizer(model_path:os.PathLike, quantization:str = QUANTIZATION)->Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """Load a causal LM and its tokenizer from ``model_path``.

    The model is loaded with ``device_map="auto"``, either in bfloat16 or in
    4-bit NF4 (bitsandbytes, double quantization, bfloat16 compute). The
    tokenizer pads on the left. If it has no pad token, ``<pad>`` is registered
    as the pad token. The input embeddings are enlarged only when the tokenizer
    has grown beyond the embedding matrix.

    Args:
        model_path (os.PathLike): Path to the model.
        quantization (str, optional): Either "4bit" or "bfloat16". Defaults to QUANTIZATION.

    Returns:
        tuple(AutoModelForCausalLM, AutoTokenizer): The ``torch.compile``-wrapped model and the tokenizer.

    Raises:
        ValueError: If ``quantization`` is neither "4bit" nor "bfloat16".
    """
    # ### Model
    if quantization == "bfloat16":
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.bfloat16)
    elif quantization == "4bit":
        bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                        bnb_4bit_use_double_quant=True,
                                        bnb_4bit_quant_type="nf4",
                                        bnb_4bit_compute_dtype=torch.bfloat16,
                                        )
        model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", quantization_config=bnb_config)
    else:
        raise ValueError("Quantization must be one of 4bit or bfloat16")

    ### Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")

    # Decide on whether a pad token is *set*, not on whether "<pad>" is in the vocabulary:
    # a vocab entry alone leaves pad_token_id as None and breaks padding. add_special_tokens
    # reuses an existing "<pad>" entry instead of duplicating it.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "<pad>"})

    # Grow the embedding matrix only if the tokenizer outgrew it; resizing unconditionally
    # would shrink models whose embedding matrix is larger than the tokenizer vocabulary.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))

    # Configure the pad token in the model config and in the generation config used by generate().
    model.config.pad_token_id = tokenizer.pad_token_id
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = tokenizer.pad_token_id

    assert model.config.pad_token_id == tokenizer.pad_token_id, "The model's pad token ID does not match the tokenizer's pad token ID!"

    print('Tokenizer pad token ID:', tokenizer.pad_token_id)
    print('Model config pad token ID:', model.config.pad_token_id)
    print("Vocabulary Size with Pad Token: ", len(tokenizer))

    # Compile Model for faster inference. To-Do https://pytorch.org/blog/pytorch-compile-to-speed-up-inference/
    # NOTE(review): attribute access such as `.generate` on the compiled wrapper is forwarded to the
    # original module, so generation may not benefit from compilation — confirm with a benchmark.
    return torch.compile(model), tokenizer


In [None]:

def load_data()->DatasetDict:
    """Load the preprocessed MS-Diag CSV splits.

    The CSV files are read from the ``ms-diag`` folder inside
    ``paths.DATA_PATH_PREPROCESSED``.

    Returns:
        DatasetDict: The "train", "validation" and "test" splits.
    """
    split_files = dict(
        train="ms-diag_clean_train.csv",
        validation="ms-diag_clean_val.csv",
        test="ms-diag_clean_test.csv",
    )
    data_dir = os.path.join(paths.DATA_PATH_PREPROCESSED, 'ms-diag')
    return load_dataset(data_dir, data_files=split_files)

def prepare_data(df:DatasetDict, tokenizer:AutoTokenizer, split:str=SPLIT, truncation_size:int = TRUNCATION_SIZE)->list:
    """Build and tokenize one classification prompt per input text.

    Each text is cut to its first ``truncation_size`` characters and inserted
    into ``BASE_PROMPT`` together with the system prompt and ``ANSWER_INIT``.

    Args:
        df (DatasetDict): Dataset dictionary with a "text" column per split.
        tokenizer (AutoTokenizer): Tokenizer.
        split (str, optional): Split name, or "all" to concatenate train, validation and test. Defaults to SPLIT.
        truncation_size (int, optional): Maximum number of characters kept per text. Defaults to TRUNCATION_SIZE.

    Returns:
        list: One tokenizer encoding (with "input_ids" / "attention_mask") per text.
    """

    def format_prompt(text:str)->str:
        """Truncate ``text`` to ``truncation_size`` characters and wrap it in the prompt template."""
        return BASE_PROMPT.format(system_prompt=SYSTEM_PROMP,
                                  user_prompt=text[:truncation_size],
                                  answer_init=ANSWER_INIT)

    if split == "all":
        texts = df["train"]["text"] + df["validation"]["text"] + df["test"]["text"]
    else:
        texts = df[split]["text"]

    # BASE_PROMPT already starts with the BOS marker; letting the tokenizer prepend its own
    # BOS as well would give every sequence two of them.
    bos = tokenizer.bos_token
    add_special_tokens = not (bos is not None and BASE_PROMPT.startswith(bos))

    return [tokenizer(format_prompt(t), add_special_tokens=add_special_tokens) for t in texts]

def get_DataLoader(tokens:list[str], tokenizer:AutoTokenizer, batch_size:int = BATCH_SIZE, padding:bool = True)->DataLoader:
    """Wrap pre-tokenized examples in a sequential DataLoader that pads each batch.

    Args:
        tokens (List): Tokenized examples to batch, kept in their original order.
        tokenizer (AutoTokenizer): Tokenizer whose padding settings are used by the collator.
        batch_size (int, optional): Examples per batch. Defaults to BATCH_SIZE.
        padding (bool, optional): Forwarded to DataCollatorWithPadding. Defaults to True.

    Returns:
        DataLoader: Non-shuffling DataLoader over ``tokens``.
    """
    return DataLoader(
        dataset=tokens,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=DataCollatorWithPadding(tokenizer, padding=padding),
    )


In [None]:

def main():
    """Run zero-shot MS-diagnosis classification and save the generated answers.

    Loads the MS-Diag data and the model, builds prompts for ``SPLIT``, and
    generates up to ``MAX_NEW_TOKENS`` tokens per prompt in batches of
    ``BATCH_SIZE``. It then writes one CSV row per example to
    ``paths.RESULTS_PATH``. The final row holds the run configuration.
    """
    # Load Data, Model and Tokenizer
    df = load_data()

    print("GPU Memory before Model is loaded:\n")
    check_gpu_memory()
    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, quantization=QUANTIZATION)
    print("GPU Memory after Model is loaded:\n")
    check_gpu_memory()

    tokens = prepare_data(df, tokenizer, split=SPLIT, truncation_size=TRUNCATION_SIZE)
    dataloader = get_DataLoader(tokens, tokenizer, batch_size=BATCH_SIZE, padding=True)

    # Inference
    results = []
    for idx, batch in enumerate(tqdm.tqdm(dataloader)):
        # Drop dead Python references first so empty_cache can release the freed blocks.
        gc.collect()
        torch.cuda.empty_cache()

        input_ids = batch["input_ids"].to("cuda")
        attention_mask = batch["attention_mask"].to("cuda")
        with torch.inference_mode():
            generated_ids = model.generate(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           max_new_tokens=MAX_NEW_TOKENS,
                                           num_beams=NUM_BEAMS,
                                           do_sample=DO_SAMPLE,
                                           temperature=TEMPERATURE,
                                           num_return_sequences=1,
                                           top_p=TOP_P,
                                           top_k=TOP_K,
                                           penalty_alpha=PENALTY_ALPHA).to("cpu")

        # generate() returns prompt + continuation. Prompts are left-padded, so every prompt
        # ends at column input_ids.shape[1]; decoding only the tail yields the answer that
        # follows ANSWER_INIT without splitting the decoded text (which could raise IndexError).
        new_token_ids = generated_ids[:, input_ids.shape[1]:]
        results.extend(tokenizer.batch_decode(new_token_ids, skip_special_tokens=True))

        print("Memory after batch {}:\n".format(idx))
        check_gpu_memory()

    # Record the run configuration as the last row (no `args` object exists in this notebook).
    run_config = {
        "MODEL_PATH": str(MODEL_PATH),
        "QUANTIZATION": QUANTIZATION,
        "SPLIT": SPLIT,
        "TRUNCATION_SIZE": TRUNCATION_SIZE,
        "BATCH_SIZE": BATCH_SIZE,
        "DO_SAMPLE": DO_SAMPLE,
        "NUM_BEAMS": NUM_BEAMS,
        "MAX_NEW_TOKENS": MAX_NEW_TOKENS,
        "TEMPERATURE": TEMPERATURE,
        "TOP_P": TOP_P,
        "TOP_K": TOP_K,
        "PENALTY_ALPHA": PENALTY_ALPHA,
    }
    results.append(str(run_config))

    # NOTE(review): JOB_ID is not defined anywhere in this notebook; presumably a cluster job ID.
    # It is taken from SLURM_JOB_ID when set, "local" otherwise — confirm against the launch script.
    job_id = os.environ.get("SLURM_JOB_ID", "local")
    file_name = f"ms_diag-llama2-chat_zero-shot_{job_id}.csv"
    pd.Series(results).to_csv(paths.RESULTS_PATH/file_name)

if __name__ == "__main__":
    main()
