In [32]:
!pip install -q evaluate 
!pip install -q rouge-score

In [7]:
import pandas as pd
import re 
import evaluate

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

from evaluate_model.generate_predictions import run_generation

In [8]:
# Load the model
tokenizer = AutoTokenizer.from_pretrained("jeffreykthomas/llama-mental-health")
model = AutoModelForCausalLM.from_pretrained("jeffreykthomas/llama-mental-health")

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [26]:
# Testing prompt
user_token = '<user>'
agent_token = '<agent>'
prompt = user_token+' Ive been feeling down for weeks now, but I guess its just my dysthymia acting up again. Ill try to keep a positive attitude, but its hard some days. '+agent_token
inputs = tokenizer(prompt, return_tensors="pt")

# Generate
generate_ids = model.generate(inputs.input_ids, max_new_tokens=100)
output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [27]:
output

"<user> Ive been feeling down for weeks now, but I guess its just my dysthymia acting up again. Ill try to keep a positive attitude, but its hard some days. <agent> I'm sorry to hear that you're feeling down. It's normal to feel down. It's normal to feel down because we all feel down at times. It's normal to feel down because we all have a part to play in our lives. It's normal to feel down because we all have a part to play in our communities. It's normal to feel down because we all have a part to play in our communities. It's normal"

In [15]:
#load all csv into a single dataframe
import glob

files = glob.glob('data/datasets/RogueEval/*.csv')
dfs = [pd.read_csv(file) for file in files]
df = pd.concat(dfs, ignore_index = True)

In [18]:
df['text'].iloc[0]

"\n     [INST] <<SYS>> As an expert in mental health, respond with advice to the user message\n    <</SYS>> I'm just feeling so overwhelmed and exhausted all the time. I can't seem to shake this constant feeling of sadness and anxiety. I know I should be happy with my new baby, but I can't help but feel like I'm failing as a mother. I don't know what to do.\n    [/INST]\n    \nI'm really sorry to hear that you're feeling this way. It's important to remember that it's completely normal to feel overwhelmed and exhausted as a new parent, and it's okay to ask for help. Here are some things you can try to help alleviate your feelings of sadness and anxiety:\n\n1. Reach out to your healthcare provider: It's important to let your healthcare provider know how you're feeling. They can help determine if what you're experiencing is related to postpartum depression or anxiety, and provide you with appropriate treatment options.\n2. Practice self-care: Make sure you're taking care of yourself physi

In [129]:
# Removes the system instruction from the text
def remove_sys_instr(text):
    text = text.strip()
    text = text.replace("\n", "")
    clean = re.sub(r'<<SYS>>(.*?)<</SYS>>', '', text)
    return clean

# Extracts the instruction as input
def extract_input(text):
    start = '[INST]'
    end = '[/INST]'
    idx1 = text.index(start)
    idx2 = text.index(end)
    input = text[idx1 + len(start) + 1: idx2]
    return input.strip()

# Extracts the output as reference
def extract_reference(text):
    reference = text.replace(extract_input(text), "")
    reference = reference.replace('[INST]', "")
    reference = reference.replace('[/INST]', "")
    return reference.strip()

def calculate_predictions(text):
    user_token = '<user> '
    agent_token = ' <agent>'
    text = user_token + text + agent_token
    inputs = tokenizer(text, return_tensors = "pt")
    # Generate
    generate_ids = model.generate(inputs.input_ids, max_new_tokens=100)
    output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    return output

def clean_predictions(text):
    #Remove User Input
    text = re.sub(r'<user>(.*?)<agent>', "", text)
    return text

In [116]:
# Seperate Inputs and References as texts
df['text'] = df['text'].apply(lambda x: remove_sys_instr(x))
df['Input'] = df['text'].apply(lambda x: extract_input(x))
df['Reference'] = df['text'].apply(lambda x: extract_reference(x))
df.head()

Unnamed: 0,text,Input,Reference
0,[INST] I can't decide what to wear today. I n...,I can't decide what to wear today. I need you ...,"As an expert in mental health, I would like to..."
1,"[INST] Please help me with this task, I canno...","Please help me with this task, I cannot do it ...",Thank you for reaching out for support. It tak...
2,"[INST] Hey there, I'm really struggling with ...","Hey there, I'm really struggling with my addic...","Hello there! As an expert in mental health, I'..."
3,"[INST] Hi, I'm struggling with my addiction a...","Hi, I'm struggling with my addiction and feel ...",the following is a response to the user messag...
4,"[INST] Hi, I'm really struggling with my addi...","Hi, I'm really struggling with my addiction ri...",Sorry to hear that you're struggling with addi...


In [None]:
# Generate Predictions based on Inputs
import warnings
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    df['Predictions'] = df['Input'].apply(lambda x: calculate_predictions(x))
    
df.head()

In [131]:
df['Predictions'] = df['Predictions'].apply(lambda x: clean_predictions(x))

In [139]:
df.Predictions[0]

" I'm sorry to hear that you are struggling with this. I'm glad you are reaching out for help. I'm glad you are willing to try on clothes that you may not have been able to find on your own. I think it would be helpful to have a conversation with you about what you are looking for and what you are willing to do to find it. I think it would be helpful to have a conversation with you about what you are looking for and what you are willing"

In [132]:
df.to_csv('predictions.csv')

In [13]:
df = pd.read_csv('data/datasets/Predictions/predictions.csv')
# sample 100 rows
df = df.sample(100)
# Reset index
df.reset_index(drop=True, inplace=True)

In [14]:
# Generate Predictions based on Inputs with a reference model
reference_model = 'Qwen/Qwen1.5-0.5B-Chat'
reference_predictions = run_generation(df['Input'].tolist(), num_batches=25, batch_size=4, model_path=reference_model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100%|██████████| 25/25 [01:44<00:00,  4.18s/it]


In [15]:
rouge = evaluate.load('rouge')

In [26]:
reference_results = rouge.compute(predictions=reference_predictions, references=df['Reference'])

In [27]:
reference_results = [reference_results]
results_df = pd.DataFrame(reference_results)
results_df['model'] = 'Qwen1.5-0.5B-Chat'
results_df.set_index('model', inplace=True)
results_df

Unnamed: 0_level_0,rouge1,rouge2,rougeL,rougeLsum
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Qwen1.5-0.5B-Chat,0.35571,0.107566,0.173733,0.218531


In [22]:
fine_tune_model = 'jeffreykthomas/llama-mental-health'
fine_tune_predictions = run_generation(df['Input'].tolist(), num_batches=100, batch_size=1, model_path=fine_tune_model)

100%|██████████| 100/100 [08:34<00:00,  5.15s/it]


In [28]:
# Add results of non distilled model to the dataframe
results = rouge.compute(predictions=fine_tune_predictions, references=df['Reference'])

In [29]:
new_data = pd.DataFrame([results]) 
new_data.index = ['llama-mental-health']  
results_df = pd.concat([results_df, new_data], axis=0) 
results_df

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
Qwen1.5-0.5B-Chat,0.35571,0.107566,0.173733,0.218531
llama-mental-health,0.249158,0.037047,0.118727,0.118554


In [30]:
# Generate predictions with the distilled model
distilled_model_path = 'jeffreykthomas/llama-mental-health-distill'

In [31]:
distilled_predictions = run_generation(df['Input'].tolist(), num_batches=100, batch_size=1, model_path=distilled_model_path)

  0%|          | 0/100 [00:00<?, ?it/s]

Truncated response detected, retrying...


  3%|▎         | 3/100 [00:05<02:57,  1.83s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


  7%|▋         | 7/100 [00:18<04:52,  3.14s/it]

Truncated response detected, retrying...


  8%|▊         | 8/100 [00:19<03:38,  2.38s/it]

Truncated response detected, retrying...


 16%|█▌        | 16/100 [00:32<01:43,  1.23s/it]

Truncated response detected, retrying...


 21%|██        | 21/100 [00:38<01:27,  1.11s/it]

Truncated response detected, retrying...


 28%|██▊       | 28/100 [00:51<01:55,  1.60s/it]

Truncated response detected, retrying...


 32%|███▏      | 32/100 [00:58<02:07,  1.88s/it]

Truncated response detected, retrying...


 35%|███▌      | 35/100 [01:04<02:15,  2.09s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...


 38%|███▊      | 38/100 [01:13<02:53,  2.79s/it]

Truncated response detected, retrying...


 40%|████      | 40/100 [01:16<02:17,  2.28s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 49%|████▉     | 49/100 [01:30<01:42,  2.00s/it]

Truncated response detected, retrying...


 51%|█████     | 51/100 [01:34<01:28,  1.81s/it]

Truncated response detected, retrying...


 54%|█████▍    | 54/100 [01:35<00:43,  1.07it/s]

Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...


 55%|█████▌    | 55/100 [01:37<00:49,  1.10s/it]

Truncated response detected, retrying...


 56%|█████▌    | 56/100 [01:38<00:54,  1.24s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 59%|█████▉    | 59/100 [01:41<00:38,  1.07it/s]

Truncated response detected, retrying...


 60%|██████    | 60/100 [01:42<00:44,  1.11s/it]

Truncated response detected, retrying...


 62%|██████▏   | 62/100 [01:46<00:54,  1.42s/it]

Truncated response detected, retrying...


 67%|██████▋   | 67/100 [01:49<00:23,  1.40it/s]

Truncated response detected, retrying...
Truncated response detected, retrying...


 70%|███████   | 70/100 [01:53<00:40,  1.34s/it]

Truncated response detected, retrying...


 71%|███████   | 71/100 [01:55<00:39,  1.36s/it]

Truncated response detected, retrying...


 75%|███████▌  | 75/100 [02:03<00:50,  2.01s/it]

Truncated response detected, retrying...


 80%|████████  | 80/100 [02:13<00:53,  2.67s/it]

Truncated response detected, retrying...


 82%|████████▏ | 82/100 [02:17<00:41,  2.28s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 83%|████████▎ | 83/100 [02:17<00:30,  1.80s/it]

Truncated response detected, retrying...


 84%|████████▍ | 84/100 [02:19<00:27,  1.71s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...


 86%|████████▌ | 86/100 [02:24<00:27,  1.96s/it]

Truncated response detected, retrying...


 87%|████████▋ | 87/100 [02:25<00:23,  1.77s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 90%|█████████ | 90/100 [02:27<00:10,  1.07s/it]

Truncated response detected, retrying...


 96%|█████████▌| 96/100 [02:34<00:05,  1.45s/it]

Truncated response detected, retrying...


100%|██████████| 100/100 [02:38<00:00,  1.59s/it]


In [32]:
distill_results = rouge.compute(predictions=distilled_predictions, references=df['Reference'])

In [33]:
new_data = pd.DataFrame([distill_results]) 
new_data.index = ['llama-mental-health-distill']  
results_df = pd.concat([results_df, new_data], axis=0) 
results_df

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
Qwen1.5-0.5B-Chat,0.35571,0.107566,0.173733,0.218531
llama-mental-health,0.249158,0.037047,0.118727,0.118554
llama-mental-health-distill,0.253921,0.045544,0.122479,0.122491


In [34]:
distilled_v2 = 'jeffreykthomas/llama-distill-v2'

In [35]:
v2_predictions = run_generation(df['Input'].tolist(), num_batches=100, batch_size=1, model_path=distilled_model_path)

  6%|▌         | 6/100 [00:09<03:06,  1.99s/it]

Truncated response detected, retrying...


 17%|█▋        | 17/100 [00:23<02:03,  1.48s/it]

Truncated response detected, retrying...


 23%|██▎       | 23/100 [00:32<01:13,  1.05it/s]

Truncated response detected, retrying...


 24%|██▍       | 24/100 [00:33<01:07,  1.12it/s]

Truncated response detected, retrying...


 25%|██▌       | 25/100 [00:34<01:08,  1.09it/s]

Truncated response detected, retrying...


 26%|██▌       | 26/100 [00:34<01:01,  1.19it/s]

Truncated response detected, retrying...


 38%|███▊      | 38/100 [00:55<01:44,  1.69s/it]

Truncated response detected, retrying...


 46%|████▌     | 46/100 [01:02<00:55,  1.03s/it]

Truncated response detected, retrying...


 47%|████▋     | 47/100 [01:03<00:51,  1.03it/s]

Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...


 50%|█████     | 50/100 [01:06<00:48,  1.04it/s]

Truncated response detected, retrying...
Truncated response detected, retrying...


 51%|█████     | 51/100 [01:08<00:56,  1.15s/it]

Truncated response detected, retrying...


 52%|█████▏    | 52/100 [01:09<00:54,  1.14s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 57%|█████▋    | 57/100 [01:20<01:44,  2.44s/it]

Truncated response detected, retrying...


 58%|█████▊    | 58/100 [01:22<01:29,  2.13s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 64%|██████▍   | 64/100 [01:29<00:44,  1.25s/it]

Truncated response detected, retrying...


 65%|██████▌   | 65/100 [01:31<00:47,  1.36s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 67%|██████▋   | 67/100 [01:31<00:27,  1.21it/s]

Truncated response detected, retrying...


 69%|██████▉   | 69/100 [01:37<00:55,  1.79s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 79%|███████▉  | 79/100 [01:52<00:35,  1.67s/it]

Truncated response detected, retrying...


 86%|████████▌ | 86/100 [02:01<00:24,  1.75s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...


 87%|████████▋ | 87/100 [02:03<00:21,  1.65s/it]

Truncated response detected, retrying...


 92%|█████████▏| 92/100 [02:07<00:07,  1.13it/s]

Truncated response detected, retrying...


 94%|█████████▍| 94/100 [02:10<00:07,  1.27s/it]

Truncated response detected, retrying...
Truncated response detected, retrying...
Truncated response detected, retrying...


 96%|█████████▌| 96/100 [02:12<00:04,  1.07s/it]

Truncated response detected, retrying...


100%|██████████| 100/100 [02:17<00:00,  1.37s/it]


In [36]:
v2_results = rouge.compute(predictions=v2_predictions, references=df['Reference'])

In [37]:
new_data = pd.DataFrame([v2_results])
new_data.index = ['llama-distill-v2']
results_df = pd.concat([results_df, new_data], axis=0)
results_df

Unnamed: 0,rouge1,rouge2,rougeL,rougeLsum
Qwen1.5-0.5B-Chat,0.35571,0.107566,0.173733,0.218531
llama-mental-health,0.249158,0.037047,0.118727,0.118554
llama-mental-health-distill,0.253921,0.045544,0.122479,0.122491
llama-distill-v2,0.219288,0.037153,0.106502,0.106749


In [38]:
for i in range(10):
    print(f"Input: {df['Input'][i]}")
    print("\n")
    print(f"Fine Tune Predictions: {fine_tune_predictions[i]}")
    print("\n")
    print(f"Distilled Predictions: {distilled_predictions[i]}")
    print("\n\n")
    print(f"v2 Predictions: {v2_predictions[i]}")

Input: Hey, guys, I'm feeling really down today... I don't know why, but I just can't seem to shake this feeling off... *sigh* I'm so tired of feeling this way... 😔


Fine Tune Predictions: Well, you don't have to feel alone. It's okay to feel sad, too. It's perfectly normal to feel depressed. You're not alone in feeling that way.


Distilled Predictions: Hi there! 💕✨️ I'm glad you're reaching out for help. You can do so much in your life that can make it even more challenging. Feeling depressed is a tough thing to feel, but if you’ve been working on coping with it by learning new skills or practicing something you like, maybe you’ll start to see some improvement. So, you may want to talk to someone about it because it sounds like it’s a big part of your life.



v2 Predictions: You are absolutely right! There is nothing wrong with you if you feel this way. Depression and depression is a common experience for people who have experienced it. Feeling down is an excellent sign that there 

In [17]:
df_responses = pd.DataFrame({'Input': df['Input'], 'Fine Tune Predictions': fine_tune_predictions, 'Distilled Predictions': distilled_predictions})
df_responses.to_csv('data/datasets/RogueEval/llama-mental-health-responses.csv')

In [2]:
import firebase_admin
from firebase_admin import credentials, firestore
import pandas as pd

In [3]:
df_responses = pd.read_csv('data/datasets/RogueEval/llama-mental-health-responses.csv')

In [4]:
df_responses.head()

Unnamed: 0.1,Unnamed: 0,Input,Fine Tune Predictions,Distilled Predictions
0,0,"I don't need your help, I can do everything on...",Your accusation is completely false and your e...,Hi there! Thank you for your question. It soun...
1,1,I can't believe you did that again. How could ...,Hello! I am a Christian woman of many faiths. ...,Hello. Thank you for your question. What you s...
2,2,I just can't seem to find the motivation to do...,What makes you feel that way?What are your fee...,"Hello, and thank you for your question. I know..."
3,3,I know you're plotting against me. Don't think...,It sounds like you are having a hard time acce...,"Hello there, and thank you for your question. ..."
4,4,"Ugh, I can't stop picking at this spot on my f...",Hi! I get the feeling you're still dealing wit...,Hello! This is certainly something that you ma...


In [3]:
cred = credentials.Certificate('rlhf_pipe/jt-designs-79-firebase-adminsdk-r3inm-6192f78083.json')
firebase_admin.initialize_app(cred)
db = firestore.client()

# set previous prompts to not current
prompts_ref = db.collection('prompts')

In [ ]:
# set previous prompts to not current
prompts = prompts_ref.where('current', '==', True).stream()
for prompt in prompts:
	prompt.reference.update({
		'current': False
	})

# add new prompts
batch = db.batch()
for example in df_responses.to_dict(orient='records'):
	doc_ref = db.collection('prompts').document()
	batch.set(doc_ref, {
		'current': True,
		'prompt': example['Input'],
		'responses': [example['Fine Tune Predictions'], example['Distilled Predictions']],
		'votes': [],
		'id': doc_ref.id
	})

batch.commit()

In [8]:
prompts = prompts_ref.where('current', '==', True).stream()
# print the votes for all prompts
total_votes = 0
vote_count = 0
for prompt in prompts:
    if len(prompt.to_dict()['votes']) > 0:
        for vote in prompt.to_dict()['votes']:
            total_votes += 1
            vote_count += vote
            
print(f"Total Votes: {total_votes}")
print(f"Total Vote Count: {vote_count}")
print(f"Average Vote: {vote_count/total_votes}")

Total Votes: 87
Total Vote Count: 60
Average Vote: 0.6896551724137931
