In [23]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [4]:
!pip install --upgrade pip
!pip install transformers==4.33.2 sentencepiece accelerate

Collecting pip
  Downloading pip-23.3.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-23.3.1
Collecting transformers==4.33.2
  Downloading transformers-4.33.2-py3-none-any.whl.metadata (119 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.9/119.9 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.24.1-py3-none-any.whl.metadata (18 kB)
Collecting huggingface-hub<

In [5]:
import gc
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, PreTrainedModel, PreTrainedTokenizer
from typing import List, Optional

In [8]:
MODEL_PATH = "syzymon/long_llama_code_7b_instruct"
TOKENIZER_PATH = MODEL_PATH
# to reduce GPU memory usage we will use reduced precision
TORCH_DTYPE = torch.bfloat16

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(device)

cuda


In [9]:
# To fit most of the demo parts on a single Google Colab GPU we
# provide a basic unoptimized quantization code
# change to False to disable the quantization
QUANTIZED = True

In [10]:
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

# unoptimized quantization code for running with free Colab GPU
def load_and_qunatize_model(num_bit: int, model_path):
    print(f"!!!!!WARNING!!!!! The mode will be quantized to {num_bit} bits!\n"
          "This may affect the model performance!")

    !pip3 install huggingface_hub
    !pip3 install bitsandbytes
    !git clone https://github.com/CStanKonrad/long_llama.git
    !cp -r long_llama/src long_llama_code/
    from long_llama_code.modeling_longllama import LongLlamaForCausalLM
    from long_llama_code.configuration_longllama import LongLlamaConfig
    from transformers import AutoConfig
    from accelerate.utils import BnbQuantizationConfig
    from accelerate.utils import load_and_quantize_model
    from accelerate import init_empty_weights
    from huggingface_hub import snapshot_download, hf_hub_download


    cfg = LongLlamaConfig.from_pretrained(model_path)
    cfg.mem_attention_grouping = (1, 1024)
    with init_empty_weights():
        empty_model = LongLlamaForCausalLM(cfg)

    gc.collect()
    if num_bit == 8:
        weights_loc = hf_hub_download(repo_id=MODEL_PATH, filename="quantized/pytorch_model_8bit.bin")
        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold = 6)
    elif num_bit == 4:
        # May give out of RAM on Colab
        weights_loc = snapshot_download(MODEL_PATH) #MODEL_PATH
        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    else:
        raise ValueError(f"{num_bit} quantization not supported.")

    gc.collect()
    model = load_and_quantize_model(empty_model, weights_location=weights_loc, bnb_quantization_config=bnb_quantization_config, device_map="auto")
    model.eval()
    return model

if not QUANTIZED:
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=TORCH_DTYPE,
        device_map=device,
        trust_remote_code=True,
        # mem_attention_grouping is used
        # to trade speed for memory usage
        # for details, see the section Additional configuration
        # in the Github repository
        mem_attention_grouping=(1, 1024),
    )
    model.eval()
else:
    model = load_and_qunatize_model(8, MODEL_PATH)

(…)truct/resolve/main/tokenizer_config.json:   0%|          | 0.00/749 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

(…)_7b_instruct/resolve/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

(…)uct/resolve/main/special_tokens_map.json:   0%|          | 0.00/330 [00:00<?, ?B/s]

This may affect the model performance!
[0mCollecting bitsandbytes
  Downloading bitsandbytes-0.41.2.post2-py3-none-any.whl.metadata (9.8 kB)
Downloading bitsandbytes-0.41.2.post2-py3-none-any.whl (92.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.6/92.6 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.41.2.post2
[0mCloning into 'long_llama'...
remote: Enumerating objects: 302, done.[K
remote: Counting objects: 100% (126/126), done.[K
remote: Compressing objects: 100% (104/104), done.[K
remote: Total 302 (delta 47), reused 79 (delta 21), pack-reused 176[K
Receiving objects: 100% (302/302), 1.53 MiB | 24.51 MiB/s, done.
Resolving deltas: 100% (150/150), done.


(…)ode_7b_instruct/resolve/main/config.json:   0%|          | 0.00/1.10k [00:00<?, ?B/s]

pytorch_model_8bit.bin:   0%|          | 0.00/7.01G [00:00<?, ?B/s]

In [103]:
import urllib.request
import tempfile
import shutil
import os

@torch.no_grad()
def load_to_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str):
    tokenized_data = tokenizer(text, return_tensors="pt")
    input_ids = tokenized_data.input_ids
    input_ids = input_ids.to(model.device)
    torch.manual_seed(0)
    output = model(input_ids=input_ids)
    memory = output.past_key_values
    return memory


@torch.no_grad()
def generate_with_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, memory, prompt: str, temperature=0.2):
    tokenized_data = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_data.input_ids
    input_ids = input_ids.to(model.device)

    streamer = TextStreamer(tokenizer, skip_prompt=False)

    new_memory = memory

    stop = False
    while not stop:
        output = model(input_ids, past_key_values=new_memory, last_context_length=3072)
        new_memory = output.past_key_values
        assert len(output.logits.shape) == 3
        assert output.logits.shape[0] == 1
        last_logit = output.logits[[0], [-1], :]
        dist = torch.distributions.Categorical(logits=last_logit / temperature)
        next_token = dist.sample()
        if next_token[0] == tokenizer.eos_token_id:
            streamer.put(next_token[None, :])
            streamer.end()
            stop = True
        else:
            input_ids = next_token[None, :]
            streamer.put(input_ids)


PROMPT_PREFIX = (
f"\nPretend you are a psychiatrist. The user will give you disorder guidelines and a vignette, which is a hypothetical scenario.\n"
"Please review the patient's history and symptoms as detailed in the vignette. Refer to the guidelines provided for each condition - \n"
"Q1. Generalized Anxiety Disorder, Q2. Panic Disorder, Q3. Agoraphobia, Q4. Specific Phobia, Q5. Social Anxiety Disorder, Q6. Separation Anxiety Disorder, Q7. Selective Mutism, Q8. Other Anxiety and Fear-Related Disorder, Q9. Unspecified Anxiety and Fear-Related Disorder.\n\n"
)


def construct_question_prompt(question: str):
    prompt = (
        # f"\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.\n\n"
        "Answer the question below referencing the information from the text above.\n"
        f"Question: {question}\nAnswer: "
    )
    return prompt


def ask_model(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, memory, seed=0):
    tokenized_data = tokenizer(prompt, return_tensors="pt")
    input_ids = tokenized_data.input_ids
    input_ids = input_ids.to(model.device)

    torch.manual_seed(seed)
    generate_with_memory(model, tokenizer, memory, prompt)

In [104]:
try:
    del chatbot
except:
    pass
gc.collect()
torch.cuda.empty_cache()

### Questions about code
We download the instruction tuning files from the long_llama repository and ask the model questions about the implementation.
Each question is asked independently without updating the memory.

In [105]:
from google.colab import drive
drive.mount('/content/drive')

import pandas as pd

path_guideline_shuffled = "/content/drive/MyDrive/Capstone/Anxiety_Disorder_EF_ACG_BWN_BWO_shuffled.txt"
path_guideline = "/content/drive/MyDrive/Capstone/Anxiety_Disorder_EF_ACG_BWN_BWO.txt"
path_vignette = "/content/drive/MyDrive/Capstone/Anxiety_Disorder_Vignettes.csv"

with open(path_guideline_shuffled, 'r', encoding='utf-8') as f:
    guideline = f.read()

vignettes = pd.read_csv(path_vignette)
vignette_1A = vignettes.loc[0, 'Description']
vignette_2 = vignettes.loc[2, 'Description']
vignette_3A = vignettes.loc[3, 'Description']

In [106]:
instruct_dp_1 = guideline + vignette_1A
instruct_dp_2 = guideline + vignette_2

In [107]:
try:
    del fot_memory
except:
    pass
gc.collect()
torch.cuda.empty_cache()
fot_memory = load_to_memory(model, tokenizer, PROMPT_PREFIX + instruct_dp_2)

### Response to Vignette 2 - Panic Disorder

In [109]:
question = """
What disorder from the guidelines is the patient in the vignette most likely associated with? Explain it by referring to the symptoms of the disorder from the guideline and how they relate to the vignette."
"""

In [110]:
prompt = construct_question_prompt(question)
ask_model(model, tokenizer, prompt, fot_memory)

The patient in the vignette is most likely associated with Generalized Anxiety Disorder (GAD). The symptoms of GAD, such as excessive fear and anxiety, and the physiological sensations that accompany them, such as chest pain, dizziness, and tingling sensations, are consistent with the symptoms described by the patient in the vignette. Additionally, the patient's fear of having another episode and the mounting costs of her ER visits suggest a significant impact on her life and a need for treatment.</s>


In [113]:
prompt = construct_question_prompt("Could it not be a Panic Disorder? Reference the symptoms of Panic Disorder and explain why the patient does not have a panic disorder.")
ask_model(model, tokenizer, prompt, fot_memory)

The patient does not have a panic disorder because she does not have recurrent, unexpected, self-limited panic attacks that occur in multiple situations, and she does not have persistent fear of having a panic attack or the possible implications of panic attacks.</s>


Vignette 2 explanation: The individual’s panic attacks are unexpected (i.e., not clearly associated with a specific situation), and the individual’s primary concern is about future attacks.

In [116]:
prompt = construct_question_prompt("The patient's panic attacks are unexpected and the individual's primary concern is about future attacks. Given this and the guideline, does the patient still not have panic disorder? Why not?")
ask_model(model, tokenizer, prompt, fot_memory)

Yes, the patient still does not have panic disorder. The guideline states that panic attacks are unexpected, and the individual's primary concern is about future attacks. The patient's panic attacks are unexpected and the individual's primary concern is about future attacks, which meets the criteria for panic disorder.</s>


### Response to vignette 1 - Generalized Anxiety Disorder.

In [47]:
prompt = construct_question_prompt(question)
ask_model(model, tokenizer, prompt, fot_memory)

Based on the information provided in the vignette, the disorder that aligns with the clinical presentation is Generalized Anxiety Disorder (Q1). The symptoms described in the vignette, such as excessive worry and worry about a variety of things, are characteristic of GAD. The other disorders mentioned in the guidelines, such as Panic Disorder, Social Anxiety Disorder, Selective Mutism, and Unspecified Anxiety and Fear-Related Disorder, do not align with the clinical presentation described in the vignette.</s>
