In [None]:

import torch

import os
import sys
sys.path.append(os.getcwd()+"/../..")
from src import paths

from src.utils import (load_model_and_tokenizer, 
                       load_ms_data,  
                       check_gpu_memory, 
)

import argparse

from transformers import DataCollatorWithPadding, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoModel

from datasets import concatenate_datasets

from tqdm import tqdm

from typing import Callable

import json

In [2]:
check_gpu_memory()

GPU 0: NVIDIA GeForce RTX 2080 Ti
   Total Memory: 10.75 GB
   Free Memory: 10.20 GB
   Allocated Memory : 0.00 GB
   Reserved Memory : 0.00 GB


# LLAMA MedTuned 13B

In [1]:
MODEL_NAME = "Llama2-MedTuned-13b"
model, tokenizer = load_model_and_tokenizer(model_name = MODEL_NAME,
                                            task_type = "clm",
                                            quantization = "4bit")
model.config.use_cache = False
check_gpu_memory()

NameError: name 'load_model_and_tokenizer' is not defined

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
def label_encoding(labels:list[str], model, tokenizer)->list[torch.Tensor]:
    """Label encoding of labels
    
    Args:
        labels (list(str)): list of labels
        
    Returns:
        list(torch.Tensor): list of label encodings
            
            
    """
    encodings = {label:[] for label in labels}

    for label in labels:
        input = tokenizer(label, return_tensors = "pt", add_special_tokens = False)
        input.to(device)
        with torch.no_grad():
            outputs = model(**input, output_hidden_states = True)
        last_hidden_state = outputs["hidden_states"][-1]
        last_hidden_state = torch.mean(last_hidden_state, dim = 1)
        encodings[label] = last_hidden_state.to("cpu").squeeze()
        del outputs
        del input

    return encodings

In [6]:
labels = ["primary progressive multiple sclerosis", "secondary progressive multiple sclerosis",
          "relapsing remitting multiple sclerosis","not enough info"]

In [7]:
encoded_labels = label_encoding(labels, model, tokenizer)

In [47]:
# Load data
df_line = load_ms_data("line")
df_all = load_ms_data("all")
df_first_last = load_ms_data("all_first_line_last")

In [7]:
# Trying to make output more consistent by stopping on MS, https://github.com/huggingface/transformers/issues/26959
class EosListStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_sequence = [835, 2799, 4080, 29901]):
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
ms_stop = EosListStoppingCriteria(tokenizer("multiple sclerosis", add_special_tokens = False)["input_ids"])

In [23]:
generation_config = GenerationConfig(bos_token_id = 1,
                                     eos_token_id = 2,
                                     pad_token_id = 32000,
                                     use_cache = False,
                                     temperature=1,
                                     top_p=1,
                                     do_sample=False,
                                     output_hidden_states = True,
                                     return_dict_in_generate = True
                                    )
def single_round_inference(reports:list[str], 
                           model:AutoModelForCausalLM, 
                           tokenizer:AutoTokenizer, 
                           format_fun:Callable[str,str], 
                           output_hidden_states:bool = True,
                          max_new_tokens:int = 20)->pd.DataFrame:
    
    """ Single round inference for the MS extraction task
    
    Args:
        reports (list[str]): list of medical reports
        model (AutoModelForCausalLM): model
        tokenizer (AutoTokenizer): tokenizer
        format_fun (Callable[str,str]): function to convert input text to desired prompt format
        
    Returns:
        pd.DataFrame: results of inference
            
    """

    tokens = [tokenizer(format_fun(t), add_special_tokens = False) for t in reports]
    
    collate_fn = DataCollatorWithPadding(tokenizer, padding=True) #padding=True, 'max_length'

    dataloader = torch.utils.data.DataLoader(dataset=tokens, collate_fn=collate_fn, batch_size=1, shuffle = False) 

    generation_config.max_new_tokens = max_new_tokens

    if output_hidden_states:    
        generation_config.output_hidden_states = True
    else:
        output_hidden_states = False
    model.eval()

    results = []
    whole_prompt = []
    last_hidden_states = []
    input_lengths = [len(t["input_ids"]) for t in tokens]

    for batch in tqdm(dataloader):
        batch.to(device)
        with torch.no_grad():
            outputs = model.generate(
                **batch,
                generation_config=generation_config,
                #stopping_criteria=[ms_stop]
            )
        if output_hidden_states:
            for idx in range(len(outputs.sequences)):
                # Find the index of eos_token_id in generated tokens if it exists
                eos_index = torch.where(outputs.sequences[idx] == tokenizer.eos_token_id)[0]
                # If eos_token_id does not exist in generated tokens, set to -1
                eos_index = eos_index[-1] if eos_index.numel() > 0 else -1
    
                # Extract the last hidden states for all the tokens in the output sequence
                # outputs["hidden_states"][:eos_index] is a tuple of tuples of hidden states (one for each layer) for all the generated tokens in the output sequence, it has length of generated sequence
                response_last_hidden_states_tuples = [hidden_state[-1][idx,:,:] for hidden_state in outputs["hidden_states"][:eos_index]]
                mean_last_hidden_states = torch.mean(torch.cat(response_last_hidden_states_tuples), dim=0)
                last_hidden_states.append(mean_last_hidden_states.to("cpu"))
        else:
            last_hidden_states.append([None] * len(outputs.sequences))


        return_tokens = outputs["sequences"].to("cpu")
        batch_result = tokenizer.batch_decode(return_tokens, skip_special_tokens=True)
        whole_prompt.extend(batch_result)
        batch_result = [result.split("[/INST]")[-1].lower().strip() for result in batch_result]

        results.extend(batch_result)
        del outputs

        
    return {"report": reports, 
            "prediction": results, 
            "last_hidden_states": last_hidden_states, 
            "input_lengths":input_lengths,
            "whole_prompt": whole_prompt}

## Zero Shot

### Vanilla

Using the original prompt template of meta Llama2 creators. \<s>[INST]<\<SYS>>{system_prompt}<\</SYS>>{instruction}{input}[/INST]
You should set add special tokens to false for the tokenizer otherwise you will have double bos in the beginning of the prompt, if you state it. Gives more control.

@misc{touvron2023llama,
      title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, 
      author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
      year={2023},
      eprint={2307.09288},
      archivePrefix={arXiv},
      primaryClarbage.

This prompt template builds the foundation to all further strategies, otherwise the model's answers are kinda garbage.

Hidden states is of format hidden_states (tuple(tuple(torch.FloatTensor)), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of torch.FloatTensor of shape (batch_size, generated_length, hidden_size). I will try working with the last hidden state of the first generated token as this is where the model will start it's generation/prediction from..CL}
}

In [24]:
# Llama-2 chat template

def zero_shot_base(report:str)->str:
    """Zero-shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}{input}[/INST]"
    system_prompt =  ("\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                       "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
                        "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
                        "know the answer to a question, please don’t share false information.\n"
                        )
    instruction = ("Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
                    "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
                    "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
                    "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
                    "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
                    "\nHere is the medical report:\n"
                    )
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)

    return input

In [25]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, zero_shot_base)

100%|██████████| 2/2 [00:13<00:00,  6.93s/it]


### Instruction

Based on the paper of the creators of Llama2-MedTuned

@misc{rohanian2023exploring,
      title={Exploring the Effectiveness of Instruction Tuning in Biomedical Language Processing}, 
      author={Omid Rohanian and Mohammadmahdi Nouriborji and David A. Clifton},
      year={2023},
      eprint={2401.00579},
      archivePrefix={arXiv},
      primaryClass={cs

Formulating the task as an instruction is closer to the fine-tuning of the model..CL}
}

In [33]:
def zero_shot_instruction(report:str)->str:
    """Zero-shot instruction for the MS extraction task
    
    Args:
        report (str): medical report
        
        Returns:
            str: reformatted medical report with instruction
            
            """
    # Llama-2 chat template
    instruction_base_prompt = "<s>[INST]\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n[/INST]"
    task_instruction = ("Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
                        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
                        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
                        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
                        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
                        "Here is the medical report: "
                    )
    input = instruction_base_prompt.format(instruction = task_instruction, input =  report)

    return input

In [34]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, zero_shot_instruction)

100%|██████████| 2/2 [00:12<00:00,  6.23s/it]


## Few Shot

Original Paper suggesting this:

@misc{brown2020language,
      title={Language Models are Few-Shot Learners}, 
      author={Tom B. Brown and Benjamin Mann and Nick Ryder and Melanie Subbiah and Jared Kaplan and Prafulla Dhariwal and Arvind Neelakantan and Pranav Shyam and Girish Sastry and Amanda Askell and Sandhini Agarwal and Ariel Herbert-Voss and Gretchen Krueger and Tom Henighan and Rewon Child and Aditya Ramesh and Daniel M. Ziegler and Jeffrey Wu and Clemens Winter and Christopher Hesse and Mark Chen and Eric Sigler and Mateusz Litwin and Scott Gray and Benjamin Chess and Jack Clark and Christopher Berner and Sam McCandlish and Alec Radford and Ilya Sutskever and Dario Amodei},
      year={2020},
      eprint={2005.14165},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

### Vanilla

In [13]:
# Llama-2 chat template

def few_shot_base(report:str)->str:
    """Few Shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}Report:\n{input}\nDiagnosis:\n[/INST]"

    rrms = 'Schubförmig-remittierende Multiple Sklerose, EM 01/2013, ED 10/2015\nINDENT EDSS 05/2020: 2.0 [...]'
    spms = '1. Sekundär progrediente schubförmige Multiple Sklerose [...]'
    ppms = '1. Primär progrediente Multiple Sklerose, EM 1992, ED 1996, aktuell EDSS 7.0 [...]'
    no_ms = '[...] INDENT MRI 07/2014: Progrediente supratentorielle MS-Plaques mit Befund-Progredienz im Bereich der Radiatio optica beidseits. [...]'

    examples = [ppms, spms, rrms, no_ms]

    labels = ["primary progressive multiple sclerosis", 
              "secondary progressive multiple sclerosis",
              "relapsing remitting multiple sclerosis",
              "no multiple sclerosis"]
    
    system_prompt = (
    "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
    "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
    "know the answer to a question, please don’t share false information.\n"
    )

    instruction = (
       "Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
        "To help you with your task, here are a few excerpts from reports that indiciate what output you should produce:\n\n"
        )
    
    for example, label in zip(examples, labels):
        instruction += f"Report:\n{example}\nDiagnosis:\n{label}\n\n"
    
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)
    input + "Diagnosis:\n"

    return input

In [14]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, few_shot_base)

100%|██████████| 2/2 [00:12<00:00,  6.04s/it]


### Instruction

In [16]:
def few_shot_instruct(report:str)->str:
    """Few Shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]### Instruction:\n{instruction}### Input:\n{input}\n### Output:\n[/INST]"

    rrms = 'Schubförmig-remittierende Multiple Sklerose, EM 01/2013, ED 10/2015\nINDENT EDSS 05/2020: 2.0 [...]'
    spms = '1. Sekundär progrediente schubförmige Multiple Sklerose [...]'
    ppms = '1. Primär progrediente Multiple Sklerose, EM 1992, ED 1996, aktuell EDSS 7.0 [...]'
    no_ms = '[...] INDENT MRI 07/2014: Progrediente supratentorielle MS-Plaques mit Befund-Progredienz im Bereich der Radiatio optica beidseits. [...]'

    examples = [ppms, spms, rrms, no_ms]

    labels = ["primary progressive multiple sclerosis", 
              "secondary progressive multiple sclerosis",
              "relapsing remitting multiple sclerosis",
              "not enough info"]

    instruction = (
        "Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
        "To help you with your task, here are a few excerpts from reports that indiciate what output you should produce:\n\n"
        )
    
    for example, label in zip(examples, labels):
        instruction += f"### Input:\n{example}\n### Output:\n{label}\n\n"
    
    input = base_prompt.format(instruction = instruction, input =  report)

    return input

In [17]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, few_shot_instruct)

100%|██████████| 2/2 [00:10<00:00,  5.26s/it]


## 2 Steps

In [41]:
def two_steps_one(report: str)->str:
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}{input}[/INST]"
    system_prompt =  ("\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                       "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
                        "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
                        "know the answer to a question, please don’t share false information.\n"
                        )
    instruction = ("Your task is to summarize all relevant information pertaining to the multiple sclerosis diagnosis "
                    "from the provided German medical report. The German word for multiple sclerosis is: \"Multiple Sklerose\", "
                    "watch for this keyword and extract all the text around it, especially words before and after. "
                    "If the report contains no information regarding multiple sclerosis, "
                    "please respond with \"not enough info.\" "
                    "\nHere is the medical report:\n\n"
                   )
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)
    return input

def two_steps_two(chat_history: str)->str:
    base_prompt = "<s>[INST]\n\n{instruction}[/INST]"
    instruction = ("Given your summary of the medical report, which of the following is the most likely label for this report: "
                  "\"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\", "
                   "\"schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\". Your answer should have only consist of one of the mentioned labels."
                   )
    if not chat_history.endswith(tokenizer.eos_token):
        chat_history += tokenizer.eos_token
    input = chat_history + base_prompt.format(instruction = instruction)

    return input
                    

In [38]:
def multi_round_inference(reports:list[str], 
                           model:AutoModelForCausalLM, 
                           tokenizer:AutoTokenizer, 
                           format_fun1:Callable[str,str],
                          format_fun2:Callable[str,str],
                           output_hidden_states:bool = True,
                          max_new_tokens:int = 20)->pd.DataFrame:
    
    """Multi Round inference for the MS extraction task
    
    Args:
        reports (list[str]): list of medical reports
        model (AutoModelForCausalLM): model
        tokenizer (AutoTokenizer): tokenizer
        format_fun1 (Callable[str,str]): function to convert input text to desired prompt format
        format_fun2 (Callable[str,str]): function to convert chat history to desired prompt format
        output_hidden_states (bool); whether hidden states should be calculated. Defaults to True
        max_new_tokens (int): The number of tokens to be generated.
        
    Returns:
        pd.DataFrame: results of inference
            
    """

    output_round1 = single_round_inference(reports, model, tokenizer, format_fun1, output_hidden_states = False, max_new_tokens = 2)
    chat_history = output_round1["whole_prompt"]

    return single_round_inference(chat_history, model, tokenizer, format_fun2, output_hidden_states = output_hidden_states, max_new_tokens = max_new_tokens)

In [6]:
results = multi_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, two_steps_one, two_steps_two, max_new_tokens = 2)

NameError: name 'df_line' is not defined

# LLAMA MedTuned 7B

In [1]:
MODEL_NAME = "Llama2-MedTuned-7b"
model, tokenizer = load_model_and_tokenizer(model_name = MODEL_NAME,
                                            task_type = "clm",
                                            quantization = "4bit")
model.config.use_cache = False
check_gpu_memory()

NameError: name 'load_model_and_tokenizer' is not defined

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
def label_encoding(labels:list[str], model, tokenizer)->list[torch.Tensor]:
    """Label encoding of labels
    
    Args:
        labels (list(str)): list of labels
        
    Returns:
        list(torch.Tensor): list of label encodings
            
            
    """
    encodings = {label:[] for label in labels}

    for label in labels:
        input = tokenizer(label, return_tensors = "pt", add_special_tokens = False)
        input.to(device)
        with torch.no_grad():
            outputs = model(**input, output_hidden_states = True)
        last_hidden_state = outputs["hidden_states"][-1]
        last_hidden_state = torch.mean(last_hidden_state, dim = 1)
        encodings[label] = last_hidden_state.to("cpu").squeeze()
        del outputs
        del input

    return encodings

In [6]:
labels = ["primary progressive multiple sclerosis", "secondary progressive multiple sclerosis",
          "relapsing remitting multiple sclerosis","not enough info"]

In [7]:
encoded_labels = label_encoding(labels, model, tokenizer)

In [47]:
# Load data
df_line = load_ms_data("line")
df_all = load_ms_data("all")
df_first_last = load_ms_data("all_first_line_last")

In [7]:
# Trying to make output more consistent by stopping on MS, https://github.com/huggingface/transformers/issues/26959
class EosListStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_sequence = [835, 2799, 4080, 29901]):
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids
ms_stop = EosListStoppingCriteria(tokenizer("multiple sclerosis", add_special_tokens = False)["input_ids"])

In [23]:
generation_config = GenerationConfig(bos_token_id = 1,
                                     eos_token_id = 2,
                                     pad_token_id = 32000,
                                     use_cache = False,
                                     temperature=1,
                                     top_p=1,
                                     do_sample=False,
                                     output_hidden_states = True,
                                     return_dict_in_generate = True
                                    )
def single_round_inference(reports:list[str], 
                           model:AutoModelForCausalLM, 
                           tokenizer:AutoTokenizer, 
                           format_fun:Callable[str,str], 
                           output_hidden_states:bool = True,
                          max_new_tokens:int = 20)->pd.DataFrame:
    
    """ Single round inference for the MS extraction task
    
    Args:
        reports (list[str]): list of medical reports
        model (AutoModelForCausalLM): model
        tokenizer (AutoTokenizer): tokenizer
        format_fun (Callable[str,str]): function to convert input text to desired prompt format
        
    Returns:
        pd.DataFrame: results of inference
            
    """

    tokens = [tokenizer(format_fun(t), add_special_tokens = False) for t in reports]
    
    collate_fn = DataCollatorWithPadding(tokenizer, padding=True) #padding=True, 'max_length'

    dataloader = torch.utils.data.DataLoader(dataset=tokens, collate_fn=collate_fn, batch_size=1, shuffle = False) 

    generation_config.max_new_tokens = max_new_tokens

    if output_hidden_states:    
        generation_config.output_hidden_states = True
    else:
        output_hidden_states = False
    model.eval()

    results = []
    whole_prompt = []
    last_hidden_states = []
    input_lengths = [len(t["input_ids"]) for t in tokens]

    for batch in tqdm(dataloader):
        batch.to(device)
        with torch.no_grad():
            outputs = model.generate(
                **batch,
                generation_config=generation_config,
                #stopping_criteria=[ms_stop]
            )
        if output_hidden_states:
            for idx in range(len(outputs.sequences)):
                # Find the index of eos_token_id in generated tokens if it exists
                eos_index = torch.where(outputs.sequences[idx] == tokenizer.eos_token_id)[0]
                # If eos_token_id does not exist in generated tokens, set to -1
                eos_index = eos_index[-1] if eos_index.numel() > 0 else -1
    
                # Extract the last hidden states for all the tokens in the output sequence
                # outputs["hidden_states"][:eos_index] is a tuple of tuples of hidden states (one for each layer) for all the generated tokens in the output sequence, it has length of generated sequence
                response_last_hidden_states_tuples = [hidden_state[-1][idx,:,:] for hidden_state in outputs["hidden_states"][:eos_index]]
                mean_last_hidden_states = torch.mean(torch.cat(response_last_hidden_states_tuples), dim=0)
                last_hidden_states.append(mean_last_hidden_states.to("cpu"))
        else:
            last_hidden_states.append([None] * len(outputs.sequences))


        return_tokens = outputs["sequences"].to("cpu")
        batch_result = tokenizer.batch_decode(return_tokens, skip_special_tokens=True)
        whole_prompt.extend(batch_result)
        batch_result = [result.split("[/INST]")[-1].lower().strip() for result in batch_result]

        results.extend(batch_result)
        del outputs

        
    return {"report": reports, 
            "prediction": results, 
            "last_hidden_states": last_hidden_states, 
            "input_lengths":input_lengths,
            "whole_prompt": whole_prompt}

## Zero Shot

### Vanilla

Using the original prompt template of meta Llama2 creators. \<s>[INST]<\<SYS>>{system_prompt}<\</SYS>>{instruction}{input}[/INST]
You should set add special tokens to false for the tokenizer otherwise you will have double bos in the beginning of the prompt, if you state it. Gives more control.

@misc{touvron2023llama,
      title={Llama 2: Open Foundation and Fine-Tuned Chat Models}, 
      author={Hugo Touvron and Louis Martin and Kevin Stone and Peter Albert and Amjad Almahairi and Yasmine Babaei and Nikolay Bashlykov and Soumya Batra and Prajjwal Bhargava and Shruti Bhosale and Dan Bikel and Lukas Blecher and Cristian Canton Ferrer and Moya Chen and Guillem Cucurull and David Esiobu and Jude Fernandes and Jeremy Fu and Wenyin Fu and Brian Fuller and Cynthia Gao and Vedanuj Goswami and Naman Goyal and Anthony Hartshorn and Saghar Hosseini and Rui Hou and Hakan Inan and Marcin Kardas and Viktor Kerkez and Madian Khabsa and Isabel Kloumann and Artem Korenev and Punit Singh Koura and Marie-Anne Lachaux and Thibaut Lavril and Jenya Lee and Diana Liskovich and Yinghai Lu and Yuning Mao and Xavier Martinet and Todor Mihaylov and Pushkar Mishra and Igor Molybog and Yixin Nie and Andrew Poulton and Jeremy Reizenstein and Rashi Rungta and Kalyan Saladi and Alan Schelten and Ruan Silva and Eric Michael Smith and Ranjan Subramanian and Xiaoqing Ellen Tan and Binh Tang and Ross Taylor and Adina Williams and Jian Xiang Kuan and Puxin Xu and Zheng Yan and Iliyan Zarov and Yuchen Zhang and Angela Fan and Melanie Kambadur and Sharan Narang and Aurelien Rodriguez and Robert Stojnic and Sergey Edunov and Thomas Scialom},
      year={2023},
      eprint={2307.09288},
      archivePrefix={arXiv},
      primaryClarbage.

This prompt template builds the foundation to all further strategies, otherwise the model's answers are kinda garbage.

Hidden states is of format hidden_states (tuple(tuple(torch.FloatTensor)), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) — Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of torch.FloatTensor of shape (batch_size, generated_length, hidden_size). I will try working with the last hidden state of the first generated token as this is where the model will start it's generation/prediction from..CL}
}

In [24]:
# Llama-2 chat template

def zero_shot_base(report:str)->str:
    """Zero-shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}{input}[/INST]"
    system_prompt =  ("\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                       "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
                        "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
                        "know the answer to a question, please don’t share false information.\n"
                        )
    instruction = ("Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
                    "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
                    "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
                    "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
                    "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
                    "\nHere is the medical report:\n"
                    )
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)

    return input

In [25]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, zero_shot_base)

100%|██████████| 2/2 [00:13<00:00,  6.93s/it]


### Instruction

Based on the paper of the creators of Llama2-MedTuned

@misc{rohanian2023exploring,
      title={Exploring the Effectiveness of Instruction Tuning in Biomedical Language Processing}, 
      author={Omid Rohanian and Mohammadmahdi Nouriborji and David A. Clifton},
      year={2023},
      eprint={2401.00579},
      archivePrefix={arXiv},
      primaryClass={cs

Formulating the task as an instruction is closer to the fine-tuning of the model..CL}
}

In [33]:
def zero_shot_instruction(report:str)->str:
    """Zero-shot instruction for the MS extraction task
    
    Args:
        report (str): medical report
        
        Returns:
            str: reformatted medical report with instruction
            
            """
    # Llama-2 chat template
    instruction_base_prompt = "<s>[INST]\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Output:\n[/INST]"
    task_instruction = ("Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
                        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
                        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
                        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
                        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
                        "Here is the medical report: "
                    )
    input = instruction_base_prompt.format(instruction = task_instruction, input =  report)

    return input

In [34]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, zero_shot_instruction)

100%|██████████| 2/2 [00:12<00:00,  6.23s/it]


## Few Shot

Original Paper suggesting this:

@misc{brown2020language,
      title={Language Models are Few-Shot Learners}, 
      author={Tom B. Brown and Benjamin Mann and Nick Ryder and Melanie Subbiah and Jared Kaplan and Prafulla Dhariwal and Arvind Neelakantan and Pranav Shyam and Girish Sastry and Amanda Askell and Sandhini Agarwal and Ariel Herbert-Voss and Gretchen Krueger and Tom Henighan and Rewon Child and Aditya Ramesh and Daniel M. Ziegler and Jeffrey Wu and Clemens Winter and Christopher Hesse and Mark Chen and Eric Sigler and Mateusz Litwin and Scott Gray and Benjamin Chess and Jack Clark and Christopher Berner and Sam McCandlish and Alec Radford and Ilya Sutskever and Dario Amodei},
      year={2020},
      eprint={2005.14165},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}

### Vanilla

In [13]:
# Llama-2 chat template

def few_shot_base(report:str)->str:
    """Few Shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}Report:\n{input}\nDiagnosis:\n[/INST]"

    rrms = 'Schubförmig-remittierende Multiple Sklerose, EM 01/2013, ED 10/2015\nINDENT EDSS 05/2020: 2.0 [...]'
    spms = '1. Sekundär progrediente schubförmige Multiple Sklerose [...]'
    ppms = '1. Primär progrediente Multiple Sklerose, EM 1992, ED 1996, aktuell EDSS 7.0 [...]'
    no_ms = '[...] INDENT MRI 07/2014: Progrediente supratentorielle MS-Plaques mit Befund-Progredienz im Bereich der Radiatio optica beidseits. [...]'

    examples = [ppms, spms, rrms, no_ms]

    labels = ["primary progressive multiple sclerosis", 
              "secondary progressive multiple sclerosis",
              "relapsing remitting multiple sclerosis",
              "no multiple sclerosis"]
    
    system_prompt = (
    "\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
    "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
    "know the answer to a question, please don’t share false information.\n"
    )

    instruction = (
       "Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
        "To help you with your task, here are a few excerpts from reports that indiciate what output you should produce:\n\n"
        )
    
    for example, label in zip(examples, labels):
        instruction += f"Report:\n{example}\nDiagnosis:\n{label}\n\n"
    
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)
    input + "Diagnosis:\n"

    return input

In [14]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, few_shot_base)

100%|██████████| 2/2 [00:12<00:00,  6.04s/it]


### Instruction

In [16]:
def few_shot_instruct(report:str)->str:
    """Few Shot base for the MS extraction task

    Args:
        report (str): medical report

    Returns:
        str: reformatted medical report with base

    """
    base_prompt = "<s>[INST]### Instruction:\n{instruction}### Input:\n{input}\n### Output:\n[/INST]"

    rrms = 'Schubförmig-remittierende Multiple Sklerose, EM 01/2013, ED 10/2015\nINDENT EDSS 05/2020: 2.0 [...]'
    spms = '1. Sekundär progrediente schubförmige Multiple Sklerose [...]'
    ppms = '1. Primär progrediente Multiple Sklerose, EM 1992, ED 1996, aktuell EDSS 7.0 [...]'
    no_ms = '[...] INDENT MRI 07/2014: Progrediente supratentorielle MS-Plaques mit Befund-Progredienz im Bereich der Radiatio optica beidseits. [...]'

    examples = [ppms, spms, rrms, no_ms]

    labels = ["primary progressive multiple sclerosis", 
              "secondary progressive multiple sclerosis",
              "relapsing remitting multiple sclerosis",
              "not enough info"]

    instruction = (
        "Your task is to extract the type of multiple Sclerosis (MS) stated in a German medical report. There are 3 types: "
        "primär progrediente Multiple Sklerose (PPMS), sekundär progrediente Multiple Sklerose (SPMS) and schubförmige Multiple Sklerose (RRMS)."
        "The type is provided in the text you just have to extract it. If you cannot match a type exactly answer with \"not enough info\"."
        "Your answer should solely consist of either \"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\" "
        "\schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\"."
        "To help you with your task, here are a few excerpts from reports that indiciate what output you should produce:\n\n"
        )
    
    for example, label in zip(examples, labels):
        instruction += f"### Input:\n{example}\n### Output:\n{label}\n\n"
    
    input = base_prompt.format(instruction = instruction, input =  report)

    return input

In [17]:
results = single_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, few_shot_instruct)

100%|██████████| 2/2 [00:10<00:00,  5.26s/it]


## 2 Steps

In [41]:
def two_steps_one(report: str)->str:
    base_prompt = "<s>[INST]<<SYS>>{system_prompt}<</SYS>>\n\n{instruction}{input}[/INST]"
    system_prompt =  ("\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
                      "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
                       "Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make "
                        "any sense, or is not factually coherent, explain why instead of answering something not correct. If you don’t "
                        "know the answer to a question, please don’t share false information.\n"
                        )
    instruction = ("Your task is to summarize all relevant information pertaining to the multiple sclerosis diagnosis "
                    "from the provided German medical report. The German word for multiple sclerosis is: \"Multiple Sklerose\", "
                    "watch for this keyword and extract all the text around it, especially words before and after. "
                    "If the report contains no information regarding multiple sclerosis, "
                    "please respond with \"not enough info.\" "
                    "\nHere is the medical report:\n\n"
                   )
    input = base_prompt.format(system_prompt = system_prompt, instruction = instruction, input =  report)
    return input

def two_steps_two(chat_history: str)->str:
    base_prompt = "<s>[INST]\n\n{instruction}[/INST]"
    instruction = ("Given your summary of the medical report, which of the following is the most likely label for this report: "
                  "\"primär progrediente Multiple Sklerose (PPMS)\", \"sekundär progrediente Multiple Sklerose (SPMS)\", "
                   "\"schubförmige Multiple Sklerose (RRMS)\", or \"not enough info\". Your answer should have only consist of one of the mentioned labels."
                   )
    if not chat_history.endswith(tokenizer.eos_token):
        chat_history += tokenizer.eos_token
    input = chat_history + base_prompt.format(instruction = instruction)

    return input
                    

In [38]:
def multi_round_inference(reports:list[str], 
                           model:AutoModelForCausalLM, 
                           tokenizer:AutoTokenizer, 
                           format_fun1:Callable[str,str],
                          format_fun2:Callable[str,str],
                           output_hidden_states:bool = True,
                          max_new_tokens:int = 20)->pd.DataFrame:
    
    """Multi Round inference for the MS extraction task
    
    Args:
        reports (list[str]): list of medical reports
        model (AutoModelForCausalLM): model
        tokenizer (AutoTokenizer): tokenizer
        format_fun1 (Callable[str,str]): function to convert input text to desired prompt format
        format_fun2 (Callable[str,str]): function to convert chat history to desired prompt format
        output_hidden_states (bool); whether hidden states should be calculated. Defaults to True
        max_new_tokens (int): The number of tokens to be generated.
        
    Returns:
        pd.DataFrame: results of inference
            
    """

    output_round1 = single_round_inference(reports, model, tokenizer, format_fun1, output_hidden_states = False, max_new_tokens = 2)
    chat_history = output_round1["whole_prompt"]

    return single_round_inference(chat_history, model, tokenizer, format_fun2, output_hidden_states = output_hidden_states, max_new_tokens = max_new_tokens)

In [6]:
results = multi_round_inference(df_line["train"].filter(lambda e: e["labels"] != 3).select(range(2))["text"], model, tokenizer, two_steps_one, two_steps_two, max_new_tokens = 2)

NameError: name 'df_line' is not defined

# Leo Mistral 7B

In [3]:
model_name = "leo-mistral-hessianai-7b-chat"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# For these models running with CPU mem of 10GB not enough, with 30GB it works, maybe try 15GB should be enough to fit the largest shard which is about 
# 9.5 GB

In [2]:
# model_name = "leo-hessianai-7b"

In [5]:
# Low precision config
print("Memory before Model is loaded:\n")
check_gpu_memory()
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained(paths.MODEL_PATH/model_name, 
                                             device_map="auto", 
                                             quantization_config = bnb_config, 
                                            # attn_implementation="flash_attention_2"
                                            )
print("Memory after Model is loaded:\n")
check_gpu_memory()

Memory before Model is loaded:

GPU 0: NVIDIA GeForce RTX 2080 Ti
   Total Memory: 10.75 GB
   Free Memory: 10.20 GB
   Allocated Memory : 0.00 GB
   Reserved Memory : 0.00 GB


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Memory after Model is loaded:

GPU 0: NVIDIA GeForce RTX 2080 Ti
   Total Memory: 10.75 GB
   Free Memory: 5.69 GB
   Allocated Memory : 4.35 GB
   Reserved Memory : 4.51 GB


In [6]:
# For mistral
tokenizer = AutoTokenizer.from_pretrained(
    paths.MODEL_PATH/model_name,
    padding_side="left",
    truncation_side = "left",
    add_eos_token=True)

tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Load data
data_files = {"train": "ms-diag_clean_train.csv", "validation": "ms-diag_clean_val.csv", "test": "ms-diag_clean_test.csv"}
df = load_dataset(os.path.join(paths.DATA_PATH_PREPROCESSED,'ms-diag'), data_files = data_files)
#df = df.map(preprocess, remove_columns=["rid", "date", "text"])

In [43]:
ppms_example = df["train"].filter(lambda e: e["labels"] == "primary_progressive_multiple_sclerosis")[0]
spms_example = df["train"].filter(lambda e: e["labels"] == "secondary_progressive_multiple_sclerosis")[0]
rrms_example = df["train"].filter(lambda e: e["labels"] == "relapsing_remitting_multiple_sclerosis")[0]

Filter:   0%|          | 0/123 [00:00<?, ? examples/s]

Filter:   0%|          | 0/123 [00:00<?, ? examples/s]

In [74]:
ppms_example["text"]

'V.a. primär progrediente Multiple Sklerose, EM 08/2016, ED 10/2018, EDSS 4.5 INDENT aktuell: klinisch: nicht aktiv, radiologisch: unklar, Progression: ja (nach Lublin 2013) INDENT Verlauf:  INDENT 08/2016: Schwäche und Trauma Fuss links mit Fuss, Trauma mit Bimalleolarluxationsfraktur OSG links, postoperativ progrediente Zunahme der Schwäche des linken Fusses INDENT 02/2019: Zunahme der Schwäche und Steifigkeit des linken Beines, seitdem weiter progredienter Verlauf INDENT 10/2020: leichte Schwäche und Steifigkeitsgefühl des rechten Oberarmes INDENT klinisch: INDENT diagnostisch: laborchemisch: INDENT LP vom 27.11.2018 (Spital Bülach): 1 Zelle/ul, Protein normal, keine Schrankenstörung, OKB positiv Bildgebend: INDENT MR BWS-LWS 10/2018 (Spital Bülach): Mehrere T2w hyperintense Signalalteraltionen im posterioren und lateralen Funiculus rechts und singulär im lateralen Funiculus links des thorakalen Myelons, es ergibt sich der Verdacht auf demyelinisierende Plaques. Keine Schrankenstöru

In [168]:
system_prompt = """Dies ist eine Unterhaltung zwischen einem intelligenten, hilfsbereitem KI-Assistenten und einem Nutzer.
Der Assistent gibt ausführliche, hilfreiche und ehrliche Antworten."""
few_shot_prompt = f"""<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{prompt1}<|im_end|>
<|im_start|>assistant
{reply1}<|im_end|>
<|im_start|>user
{prompt2}<|im_end|>
<|im_start|>assistant
{reply2}<|im_end|>
<|im_start|>user
{prompt3}<|im_end|>
<|im_start|>assistant
{reply3}<|im_end|>
<|im_start|>user
{prompt4}<|im_end|>
<|im_start|>assistant
"""
zero_shot_prompt = f"""
<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{prompt1}<|im_end|>
<|im_start|>assistant
"""

one_shot_prompt = f"""<|im_start|>system
{system_prompt}<|im_end|>
<|im_start|>user
{prompt1}<|im_end|>
<|im_start|>assistant
{reply1}<|im_end|>
<|im_start|>user
{prompt2}<|im_end|>
<|im_start|>assistant
"""

prompt1 = "Was ist die MS Diagnose in diesem Text: " + ppms_example["text"][:80]
reply1 = "Primär progrediente Multiple Sklerose"
prompt2 = "Was ist die MS Diagnose in diesem Text: " + spms_example["text"][:80]
reply2 = "Sekundär progrediente Multiple Sklerose"
prompt3 = "Was ist die MS Diagnose in diesem Text: " + rrms_example["text"][:80]
reply3 = "Schubförmig remittierende Multiple Sklerose"
prompt4 = "Was ist die MS Diagnose in diesem Text: " + df["train"]["text"][5][:80]
def format_few_shot(system_prompt, prompt1, reply1, prompt2, reply2, prompt3, reply3, prompt4):
    input = few_shot_prompt.format(system_prompt,
                               prompt1,
                              reply1,
                              prompt2,
                              reply2,
                              prompt3,
                              reply3,
                              prompt4)
    return tokenizer(input, return_tensors = "pt")

def format_one_shot(system_prompt, prompt1, reply1, prompt2):
    input = one_shot_prompt.format(system_prompt,
                               prompt1,
                              reply1,
                              prompt2,
                              )
    print(len(input))
    print(input)
    return tokenizer(input, return_tensors = "pt")

def format_zero_shot(system_prompt, prompt1):
    input = zero_shot_prompt.format(system_prompt, prompt1)
    return tokenizer(input, return_tensors = "pt")

In [176]:
encoded_labels = tokenizer(["Primär progrediente Multiple Sklerose", "Sekundär progrediente Multiple Sklerose", "Schubförmig remittierende Multiple Sklerose"], add_special_tokens=False)["input_ids"]
encoded_bad_words = tokenizer(["user"], add_special_tokens = False)["input_ids"]

In [177]:
DisjunctiveConstraint(encoded_labels)

<transformers.generation.beam_constraints.DisjunctiveConstraint at 0x1466cec758e0>

In [158]:
[len(input) for input in encoded_labels]

[11, 13, 14]

In [184]:
prompt_encoded = format_few_shot(system_prompt, prompt1, reply1, prompt2, reply2, prompt3, reply3, prompt4)
# prompt_encoded = format_zero_shot(system_prompt, prompt1)
# prompt_encoded = format_one_shot(system_prompt, prompt1, reply1, prompt2)
return_tokens = model.generate(**prompt_encoded, max_new_tokens=50, temperature = 0, bad_words_ids = encoded_bad_words, num_beams = 2)
print(tokenizer.batch_decode(return_tokens, skip_special_tokens=True))

['system\nDies ist eine Unterhaltung zwischen einem intelligenten, hilfsbereitem KI-Assistenten und einem Nutzer.\nDer Assistent gibt ausführliche, hilfreiche und ehrliche Antworten. \n user\nWas ist die MS Diagnose in diesem Text: V.a. primär progrediente Multiple Sklerose, EM 08/2016, ED 10/2018, EDSS 4.5 IND \n assistant\nPrimär progrediente Multiple Sklerose \n user\nWas ist die MS Diagnose in diesem Text: Multiple Sklerose mit sekundär progredientem Verlauf seit ca. 2004 (EM 1983, ED  \n assistant\nSekundär progrediente Multiple Sklerose \n user\nWas ist die MS Diagnose in diesem Text: Schubförmige Multiple Sklerose (EM 09/2015, ED 11/2015), EDSS 0,0 Anamnestisch I \n assistant\nSchubförmig remittierende Multiple Sklerose \n user\nWas ist die MS Diagnose in diesem Text: Primär progrediente Multiple Sklerose, EM ca. 2010, ED 06/2016  INDENT EDSS 07/2 \n assistant\n assistant\nWähle A, B, C oder D als deine Lösung.\\n\\nEine kürzlich durchgeführte Studie mit 10.000 Personen ergab, d

In [10]:
tokens = [tokenizer(get_classification_llama(t)) for t in df["train"]["text"]]

# Default collate function 
collate_fn = DataCollatorWithPadding(tokenizer, padding=True) #padding=True, 'max_length'

dataloader = torch.utils.data.DataLoader(dataset=tokens, collate_fn=collate_fn, batch_size=2, shuffle = False) 

It seems as though reserved memory is extremely high when using beam search. If I have longer input sequences this will lead to out of memory issues. I will try to set number of tokens to a lower number and check if beam search works then. I truncate the text directly because if I truncate after the prompt insertion I will loose the end of the prompt.

In [None]:
outputs = []
print("Memory Consumption before loop\n")
check_gpu_memory()
for idx, batch in enumerate(dataloader):
    
    torch.cuda.empty_cache()
    gc.collect()

    print("Memory Consumption before Batch: ", idx)
    check_gpu_memory()
    
    input_ids = batch["input_ids"].to("cuda")
    attention_mask = batch["attention_mask"].to("cuda")
    with torch.inference_mode():
        generated_ids = model.generate(input_ids = input_ids, attention_mask = attention_mask, max_new_tokens=20, num_beams=1, do_sample=True, temperature = 0.9, num_return_sequences = 1, top_p = 0.6).to("cpu")
    outputs.append(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
    break
    outputs

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Memory Consumption before loop

GPU 0: NVIDIA GeForce RTX 2080 Ti
   Total Memory: 10.75 GB
   Free Memory: 5.92 GB
   Allocated Memory : 4.26 GB
   Reserved Memory : 4.27 GB
Memory Consumption before Batch:  0
GPU 0: NVIDIA GeForce RTX 2080 Ti
   Total Memory: 10.75 GB
   Free Memory: 5.92 GB
   Allocated Memory : 3.77 GB
   Reserved Memory : 4.27 GB


In [76]:
check_gpu_memory()
torch.cuda.empty_cache()
gc.collect()
check_gpu_memory()

GPU 0: Tesla V100-SXM2-32GB
   Total Memory: 31.74 GB
   Free Memory: 22.73 GB
   Allocated Memory : 3.70 GB
   Reserved Memory : 8.04 GB
GPU 0: Tesla V100-SXM2-32GB
   Total Memory: 31.74 GB
   Free Memory: 26.85 GB
   Allocated Memory : 3.70 GB
   Reserved Memory : 3.91 GB


In [77]:
from itertools import chain
outputs = list(chain.from_iterable(outputs))
pd.Series(outputs).to_csv(paths.RESULTS_PATH/'ms_diag-llama2-chat_zero_shot-shortened300_beam2.csv')

In [79]:
results = [out.split("\nBased on the information provided in the text, the most likely diagnosis for the patient is:")[1] for out in outputs]

In [86]:
set(df["train"]["labels"])

{'primary_progressive_multiple_sclerosis',
 'relapsing_remitting_multiple_sclerosis',
 'secondary_progressive_multiple_sclerosis'}

In [87]:
# Dictionary to map keywords to labels
keyword_label_mapping = {
    "RRMS": 'relapsing_remitting_multiple_sclerosis',
    "SPMS": 'secondary_progressive_multiple_sclerosis',
    "PPMS": 'primary_progressive_multiple_sclerosis',
}

# Function to assign labels based on text content
def assign_label(text):
    for keyword, label in keyword_label_mapping.items():
        if keyword in text:
            return label
    return "unknown"  # Default label if no keyword is found

# Assign labels to each text in the list
labels = [assign_label(text) for text in results]

In [99]:
correct = 0
for i in range(len(labels)):
    if labels[i] == df["train"]["labels"][i]:
        correct += 1
correct/len(labels)

0.6016260162601627