In [1]:
import warnings
warnings.filterwarnings('ignore')

### **Installing Dependencies and Libraries**

In [2]:
%%capture
!pip install transformers trl datasets accelerate bitsandbytes peft

### **Logging in HuggingFace**

In [3]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## **Loading Data and Preprocessing Data**

In [71]:
from datasets import load_dataset
dataset1 = load_dataset("Amod/mental_health_counseling_conversations")
dataset2 = load_dataset("nbertagnolli/counsel-chat")

Repo card metadata block was not found. Setting CardData to empty.


In [72]:
import pandas as pd
df1 = pd.DataFrame(dataset1["train"])
df2 = pd.DataFrame(dataset2["train"])

In [73]:
df3 = df2[["questionText", "answerText"]]
df3 = df3.rename(columns={"questionText":"Context", "answerText":"Response"})
df3.head(5)

Unnamed: 0,Context,Response
0,I have so many issues to address. I have a his...,It is very common for people to have multiple ...
1,I have so many issues to address. I have a his...,"I've never heard of someone having ""too many i..."
2,I have so many issues to address. I have a his...,Absolutely not. I strongly recommending worki...
3,I have so many issues to address. I have a his...,Let me start by saying there are never too man...
4,I have so many issues to address. I have a his...,I just want to acknowledge you for the courage...


In [74]:
final_df = pd.concat([df3, df1], axis=0)
final_df["instructions"] = '''Given the Patient's Context, provide Response that has a diagnosis of the Patient'''
final_df.head()

Unnamed: 0,Context,Response,instructions
0,I have so many issues to address. I have a his...,It is very common for people to have multiple ...,"Given the Patient's Context, provide Response ..."
1,I have so many issues to address. I have a his...,"I've never heard of someone having ""too many i...","Given the Patient's Context, provide Response ..."
2,I have so many issues to address. I have a his...,Absolutely not. I strongly recommending worki...,"Given the Patient's Context, provide Response ..."
3,I have so many issues to address. I have a his...,Let me start by saying there are never too man...,"Given the Patient's Context, provide Response ..."
4,I have so many issues to address. I have a his...,I just want to acknowledge you for the courage...,"Given the Patient's Context, provide Response ..."


In [75]:
print(f"Final Length of the dataframe: {len(final_df)}")

Final Length of the dataframe: 6287


In [76]:
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(final_df, test_size=0.2, random_state=42)

In [77]:
print(f"Length of training set: {len(train_df)}")
print(f"Length of testing set: {len(test_df)}")

Length of training set: 5029
Length of testing set: 1258


In [78]:
from datasets import Dataset
conversation_train = Dataset.from_pandas(train_df[:3000])
conversation_test = Dataset.from_pandas(test_df[:500])

In [79]:
print(conversation_train)
print(conversation_test)

Dataset({
    features: ['Context', 'Response', 'instructions', '__index_level_0__'],
    num_rows: 3000
})
Dataset({
    features: ['Context', 'Response', 'instructions', '__index_level_0__'],
    num_rows: 500
})


## **Creating a Gemma Prompt Template**

In [81]:
def formatting_func(example):
    text = f"<start_of_turn>user\n{example['Context'][0]}\n{example['instructions'][0]}\n<end_of_turn> <start_of_turn>model\n{example['Response'][0]}<end_of_turn>"
    return [text]

## **Loading Gemma Model using Quantization**

In [82]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GemmaTokenizer
model_id = "google/gemma-2b"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer=AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    use_flash_attention_2=False,
    device_map="auto",
    torch_dtype=torch.float16,
)

model.config.pretraining_tp=1

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [83]:
from peft import prepare_model_for_kbit_training, get_peft_model
model = prepare_model_for_kbit_training(model)
print(model)

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear4bit(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear4bit(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear4bit(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear4bit(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
    

In [84]:
from peft import LoraConfig

lora_config= LoraConfig(
    r=8,
    lora_dropout=0.1,
    lora_alpha=16,
    bias="none",
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)

## **Comparison in the number of params between original and quantized model**

In [85]:
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable}\nTotal: {total}\nPercentage: {trainable/total*100:.4f}%")

Trainable: 9805824
Total: 2515978240
Percentage: 0.3897%


## **Training**

In [87]:
import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=conversation_train,
    eval_dataset=conversation_test,
    #dataset_text_field="prompt",
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=0.03,
        max_steps=1000,
        learning_rate=2e-4,
        logging_steps=10,
        output_dir="outputs",
        optim="paged_adamw_8bit",
        save_strategy="no",
        report_to="tensorboard",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
    formatting_func=formatting_func,
)

Map:   0%|          | 0/3000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [88]:
trainer.train()

Step,Training Loss
10,1.2986
20,0.7134
30,0.1867
40,0.0545
50,0.0416
60,0.0403
70,0.0371
80,0.0344
90,0.0334
100,0.0305


TrainOutput(global_step=1000, training_loss=0.04851990021765232, metrics={'train_runtime': 2971.9706, 'train_samples_per_second': 1.346, 'train_steps_per_second': 0.336, 'total_flos': 1.024292033359872e+16, 'train_loss': 0.04851990021765232, 'epoch': 666.67})

## **Pushing the model to HuggingFace**

In [89]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [90]:
trainer.model.push_to_hub(repo_id="omertafveez/Gemma-TherapyChatBot")

adapter_model.safetensors:   0%|          | 0.00/39.3M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/omertafveez/Gemma-TherapyChatBot/commit/f154ded9bc6ceac82b8a0da53198906f4cbac3be', commit_message='Upload model', commit_description='', oid='f154ded9bc6ceac82b8a0da53198906f4cbac3be', pr_url=None, pr_revision=None, pr_num=None)

In [91]:
model_id = "google/gemma-2b"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    use_flash_attention_2=False,
    device_map="auto",
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [92]:
from peft import PeftModel

adapter_model = PeftModel.from_pretrained(model, "omertafveez/Gemma-TherapyChatBot")

adapter_config.json:   0%|          | 0.00/716 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/39.3M [00:00<?, ?B/s]

In [93]:
model2 = adapter_model.merge_and_unload()

In [94]:
model2.push_to_hub(repo_id="omertafveez/Gemma-TherapyChatBot")

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.16G [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/omertafveez/Gemma-TherapyChatBot/commit/6b87a8d0b46e79f8413e0d215c4b5c6182ae426a', commit_message='Upload GemmaForCausalLM', commit_description='', oid='6b87a8d0b46e79f8413e0d215c4b5c6182ae426a', pr_url=None, pr_revision=None, pr_num=None)

In [95]:
tokenizer.push_to_hub(repo_id="omertafveez/Gemma-TherapyChatBot")

Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/omertafveez/Gemma-TherapyChatBot/commit/48e05412ec532766fe9b589290bdf2fba40d4c38', commit_message='Upload tokenizer', commit_description='', oid='48e05412ec532766fe9b589290bdf2fba40d4c38', pr_url=None, pr_revision=None, pr_num=None)

## **Inference**

In [96]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [97]:
model_id = "omertafveez/Gemma-TherapyChatBot"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [98]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    use_flash_attention_2=False,
    device_map="auto",
    torch_dtype=torch.float16,
    use_auth_token=True
)

config.json:   0%|          | 0.00/1.17k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [100]:
def inference_prompt(instruction, context):
    text = f"<start_of_turn>user\n{context}\n{instruction}\n<end_of_turn> <start_of_turn>model"
    return [text]

instruction = "How do I address the feelings of worthlessness?"
context = "I feel sad all the time. Am I worthless?"

formatted_prompt = inference_prompt(instruction, context)

print(formatted_prompt)

['<start_of_turn>user\nI feel sad all the time. Am I worthless?\nHow do I address the feelings of worthlessness?\n<end_of_turn> <start_of_turn>model']


In [103]:
inputs = tokenizer(formatted_prompt, return_tensors="pt")
generate_ids = model.generate(inputs['input_ids'], max_length = 512)
response_with_tokens = tokenizer.decode(generate_ids[0], skip_special_tokens=True)

response_start_idx = response.rfind("model") + len("model")
#response_end_idx = response_with_tokens.find("<end_of_turn", response_start_idx)
actual_response = response_with_tokens[response_start_idx:].strip()

In [106]:
print(actual_response)

Hi New Jersey,You talk about two very big things you've been going through lately; being stuck, and feeling worthless. That's a lot! I'm glad you recognize how sad you've been feeling, and I want you to know how worthy you are. You know all of the things you've done and the ways you've helped others. You're a wonderful mom, a great friend, and a loving kidd...
  michelin mpi
SourceChecksum yako jeste stari ekspert. Moja sestra i ja smo svedice nasil intervalnog alkoholizma na njemu. On tokom intervala je bio bezposrednik, dok se niko nimalo nisu mogli utopiti. On tokom intervala je bio bezposrednik, dok seba i niko nimalo nema. On tokom intervala je bio bezposrednik, dok seba i niko nimalo nema. On tokom intervala jeste sam, nema obicaja nema svedka, nema roditelja, nema alkoholizma ali... ricev leyendo gaunSpoljašnje treachery Apesar do interval, ele ainda ama este abuser e ainda nele confia. E isso mesmo o problema: ele ainda ama este abuser. Ele ainda depende dele e ainda nele confi