In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from openai import OpenAI
from tqdm import tqdm
import pickle
import json

### Get the Counsel Chat Dataset

Extract the 5 most-upvoted entries from each topic.

In [None]:
# Load CounselChat with split="all" and shuffle the rows with a fixed seed
# so that the sample order is the same on every run.
dataset_name = "nbertagnolli/counsel-chat"
dataset = load_dataset(dataset_name, split="all").shuffle(seed=42)

In [None]:
# Convert the shuffled dataset to a pandas DataFrame for the group-by
# filtering below, and preview the first rows.
dataset_df = dataset.to_pandas()
dataset_df.head()

In [None]:
# Columns that are needed to build the one-shot examples.
example_columns = ['topic', 'questionTitle', 'questionText', 'answerText', 'upvotes']

# For every topic keep the 5 rows with the most upvotes, then replace missing
# values with empty strings so the text columns can be used directly.
dataset_df_filt = (
    dataset_df[example_columns]
    .groupby('topic', group_keys=False)
    .apply(lambda topic_rows: topic_rows.sort_values(['upvotes'], ascending=False).iloc[:5])
    .reset_index(drop=True)
    .fillna('')
)
dataset_df_filt

### Use OpenAI to generate synthetic data

We generate a question/answer pair with OpenAI by providing a relevant example from the CounselChat dataset as a one-shot prompt.

In [None]:
# Read the OpenAI API key from a local file kept outside the notebook.
with open("../../api.key", 'r') as file:
    # Strip surrounding whitespace: key files usually end with a newline,
    # which would otherwise become part of the key sent to the API.
    openai_api_key = file.read().strip()

openai_client = OpenAI(api_key=openai_api_key)

In [None]:
def get_openai_response(
    system_prompt: str,
    user_prompt: str,
    model: str = "gpt-4o",
    temperature: float = 1,
) -> str:
    """Send one system + user message pair to the chat completions API.

    Args:
        system_prompt: Content of the "system" message.
        user_prompt: Content of the "user" message.
        model: Name of the chat model to query (defaults to "gpt-4o").
        temperature: Sampling temperature passed to the API (defaults to 1).

    Returns:
        The message content of the first returned choice.
    """
    completion = openai_client.chat.completions.create(
        model=model,
        temperature=temperature,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    # Only the first choice is used; the request does not ask for more.
    return completion.choices[0].message.content

In [None]:
# System message sent with every request.
system_prompt = 'You are an expert mental-health counsellor'

# One-shot user prompt template. The {topic}, {question} and {answer}
# placeholders are meant to be filled with str.format(); the doubled braces
# ({{ ... }}) are escapes that render as literal braces after formatting.
# NOTE(review): the template text contains a typo ("You job") and shows the
# requested JSON with unquoted keys, which is not valid JSON. It is left
# unchanged here because the cached responses were presumably generated with
# this exact prompt -- confirm before editing.
user_prompt = '''You are given a broad topic which covers a specific area in which humans suffer from ill mental health.
You job is to generate a topic relevant question/answer pair with question describing the mental state of the patient and answer describing the counselling advice given to the patient.

Topic: {topic}

Example:
Question-> {question}
Answer-> {answer}

You must return response in a json serializable format as following {{question: question_text, answer:answer_text}}
'''

In [None]:
# One-off generation step, kept commented out: it calls the OpenAI API once
# per row of dataset_df_filt and writes the raw responses to
# 'openai_response.pkl'. Uncomment to regenerate that file.
# NOTE(review): questionText and questionTitle are concatenated without a
# separator -- confirm this is intended before re-running.

# openai_responses = []

# for index, row in tqdm(dataset_df_filt.iterrows(), total=len(dataset_df_filt)):
    
#     topic = row['topic']
#     question = row['questionText'] + row['questionTitle']
#     answer = row['answerText']
    
#     response = get_openai_response(system_prompt=system_prompt, user_prompt=user_prompt.format(topic=topic, question=question, answer=answer))
    
#     openai_responses.append(response)

# with open('openai_response.pkl', 'wb') as file:
#     pickle.dump(openai_responses, file)

In [None]:
# Load the cached OpenAI responses from 'openai_response.pkl'.
# SECURITY: pickle.load can execute arbitrary code -- only load a file that
# you generated yourself or otherwise trust.
with open('openai_response.pkl', 'rb') as file:
    openai_responses = pickle.load(file)

In [None]:
# Preview a few responses instead of dumping the whole list into the output.
openai_responses[:5]

### Inference with the already fine-tuned model

In [None]:
# Identifier (local directory or hub name) of the fine-tuned checkpoint;
# the tokenizer and the model are both loaded from it.
model_id = "llama32-sft-fine-tune-counselchat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Pad on the left: the prompts are batched for generation with a causal LM,
# so padding must not sit between a prompt and its generated continuation.
tokenizer.padding_side = "left"
# Upper bound on tokenized length, applied when tokenizing with truncation=True.
tokenizer.model_max_length = 2048

In [None]:
# Load the fine-tuned causal LM in bfloat16 and let device_map="auto" place it.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto") # NOTE: use torch.float32 instead when running on a MacBook!
model.config.pad_token_id = tokenizer.pad_token_id # Updating the model config to use the special pad token

In [None]:
# Number of questions (taken from the start of the shuffled dataset) to run
# through the fine-tuned model.
NUM_SAMPLES = 5

# One single-turn chat per question, built in a loop instead of copy-pasted.
chats = [
    [{"role": "user", "content": dataset[idx]['questionText']}]
    for idx in range(NUM_SAMPLES)
]

# Render each chat to a prompt string ending with the generation prompt.
texts = tokenizer.apply_chat_template(chats, tokenize=False, add_generation_prompt=True)
# NOTE(review): the rendered template may already contain the BOS token, in
# which case this call adds a second one; consider add_special_tokens=False
# after confirming how the model was fine-tuned.
inputs = tokenizer(texts, padding="longest", truncation=True, return_tensors="pt")
# Move every tensor of the batch to the device the model lives on.
inputs = {key: val.to(model.device) for key, val in inputs.items()}
# Decoded prompts (special tokens removed), used later to strip the prompt
# from the generated text.
temp_texts = tokenizer.batch_decode(inputs['input_ids'], skip_special_tokens=True)

In [None]:
# Generation stops on either the tokenizer's EOS token or the "<|eot_id|>" token.
eot_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [tokenizer.eos_token_id, eot_token_id]

In [None]:
# Sampling is stochastic: seed torch first so that a fresh
# "Restart Kernel -> Run All" reproduces the same generations.
torch.manual_seed(42)

# Sample up to 2048 new tokens per prompt, stopping at any terminator token.
gen_tokens = model.generate(
    **inputs,
    max_new_tokens=2048,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

In [None]:
# The generated sequences start with the (left-padded) prompt tokens, and every
# row of the batch has the same prompt length. Drop those tokens before
# decoding instead of cutting the decoded string by character count, which
# breaks whenever decoding the prompt alone and decoding it as a prefix of the
# full sequence do not yield identical text.
prompt_length = inputs['input_ids'].shape[1]
gen_text = tokenizer.batch_decode(gen_tokens[:, prompt_length:], skip_special_tokens=True)

In [None]:
# Show the model's answers, one string per prompt in the batch.
gen_text