In [2]:
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the ShenLab/MentalChat16K dataset via `datasets.load_dataset`.
data = load_dataset('ShenLab/MentalChat16K')

In [4]:
# Hold out 10% of the 'train' split as a test set; the fixed seed keeps the split reproducible.
# NOTE(review): this rebinds `data`, so re-running only this cell would presumably split the
# already-reduced 'train' portion again — re-run from the load cell instead. TODO confirm.
data = data['train'].train_test_split(test_size=0.1, seed=3407)

In [5]:
# Keep only the held-out portion; the cells shown here work exclusively on this split.
test_data = data['test']

In [22]:
from openai import AsyncOpenAI
import asyncio
import random
import logging
import json

# Simplified client setup
def get_client():
    """
    Build the async client used to talk to the OpenAI-compatible endpoint.

    The endpoint URL and API key can be overridden through the
    ``VLLM_BASE_URL`` and ``VLLM_API_KEY`` environment variables, so
    credentials do not have to be edited into the notebook. When a variable
    is unset, the value that was originally hardcoded here is used, which
    keeps existing behaviour unchanged.

    Returns:
        client (AsyncOpenAI): client for the configured endpoint
    """
    import os  # local import keeps this definition self-contained

    # NOTE(review): the fallback key is still stored in the notebook; rotate it
    # and rely on the environment variable if the endpoint is not meant to be public.
    return AsyncOpenAI(
        base_url=os.environ.get(
            "VLLM_BASE_URL",
            "https://mihirathale98--vllm-app-serve.modal.run/v1",
        ),
        api_key=os.environ.get("VLLM_API_KEY", "super-secret-key"),
    )

async def _make_openai_request(
    client,
    model: str,
    messages: list,
    temperature: float = 1.0,
    max_retries: int = 3,
    initial_delay: float = 10,
    exponential_base: float = 2,
    jitter: bool = True,
):
    """
    Make a request to the OpenAI API with retry logic
    """
    if not messages:
        return {}
    
    num_retries = 0
    delay = initial_delay
    
    for _ in range(max_retries + 1):  # +1 to allow for initial attempt
        try:
            return await client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature,
                max_tokens=2048
            )
        except Exception as e:
            logging.warning(f"Error: {e}")
            num_retries += 1
            
            # Check if max retries has been reached
            if num_retries > max_retries:
                logging.error(f"Maximum number of retries ({max_retries}) exceeded.")
                return {"choices": [{"message": {"content": ""}}]}
            
            # Increment the delay with exponential backoff
            delay *= exponential_base * (1 + jitter * random.random())
            logging.info(f"Retrying in {delay:.2f} seconds (attempt {num_retries}/{max_retries})")
            await asyncio.sleep(delay)
    
    # Fallback if loop exits without returning
    return {"choices": [{"message": {"content": ""}}]}

async def generate_responses(
    prompts: list[str],
    model: str = 'unsloth/Llama-3.1-8B-Instruct',
    temperature: float = 0.0,
):
    """
    Generate responses for a list of prompts concurrently.

    One request is issued per prompt and all of them are awaited together.
    The client created for the call is closed before returning, so calling
    this function once per batch does not accumulate open clients.

    Args:
        prompts: List of prompts to generate responses for. An empty prompt
            is sent with an empty message list, which the request helper
            answers with ``{}`` without contacting the endpoint.
        model: Model to use for generation
        temperature: Temperature for generation

    Returns:
        List of plain dicts, one per prompt and in the same order. Responses
        that are already dicts are passed through unchanged; any other
        response object is converted through its JSON dump.
    """
    client = get_client()
    try:
        # Build one pending request per prompt; nothing runs until gather()
        pending = [
            _make_openai_request(
                client,
                model=model,
                messages=[{"role": "user", "content": prompt}] if prompt else [],
                temperature=temperature,
            )
            for prompt in prompts
        ]

        # Execute all requests in parallel, preserving prompt order
        responses = await asyncio.gather(*pending)
    finally:
        # Release the client even when a request raised or was cancelled
        await client.close()

    return [
        response if isinstance(response, dict)
        else json.loads(response.model_dump_json())
        for response in responses
    ]

# Example usage
async def main():
    """Send three sample prompts through generate_responses and print each reply."""
    prompts = [
        "Write a short poem about the ocean",
        "Explain quantum computing in simple terms",
        "Give me three ideas for a dinner recipe"
    ]

    results = await generate_responses(prompts)

    separator = "-" * 50
    for number, result in enumerate(results, start=1):
        # A missing or empty "choices" entry is printed as an empty response
        choices = result.get("choices")
        content = choices[0]["message"]["content"] if choices else ""
        print(f"Prompt {number}: {prompts[number - 1]}")
        print(f"Response: {content}")
        print(separator)


In [7]:
def get_prompt(samples):
    """
    Wrap each entry of ``samples['input']`` in the counselling prompt template.

    Args:
        samples: Mapping with an ``'input'`` key holding an iterable of
            patient descriptions (a batch, when used with a batched map).

    Returns:
        dict with a single ``"prompt"`` key: one formatted prompt per input,
        in the same order.
    """
    template = (
        "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n"
        "\n"
        "### Instruction:\n"
        "You are a helpful mental health counselling assistant, please answer the mental health questions based on the patient's description.\n"
        "The assistant gives helpful, comprehensive, and appropriate answers to the user's questions. Provide a clear and concise answer to the user's problem.\n"
        "\n"
        "### Input:\n"
        "{sample}\n"
        "\n"
        "### Response:\n"
    )
    return {"prompt": [template.format(sample=text) for text in samples['input']]}

# Build the "prompt" entries for the test split, processing 8 rows per get_prompt call
test_data = test_data.map(
    get_prompt,
    batched=True,
    batch_size=8,
)

Map: 100%|██████████| 1609/1609 [00:00<00:00, 7352.27 examples/s]


In [None]:
# Keep the first 500 rows. The cells below treat the sliced result as a plain,
# JSON-serialisable mapping of column name -> values, which cannot itself be
# sliced, so the guard makes re-running this cell a no-op instead of a TypeError.
if not isinstance(test_data, dict):
    test_data = test_data[:500]

In [13]:
# Save the selected test subset as JSON in the working directory
with open("test_data.json", "w") as out_file:
    out_file.write(json.dumps(test_data))

In [18]:
# Pull out the prompt strings built by get_prompt; these are what the generation loop sends.
prompts = test_data["prompt"]

In [19]:
# Sanity check: number of prompts queued for generation.
len(prompts)

500

In [23]:
from tqdm import tqdm
batch_size = 32

all_responses = []
for i in tqdm(range(0, len(prompts), batch_size), total=len(prompts)//batch_size):
    batch = prompts[i:i+batch_size]
    responses = await generate_responses(batch)
    all_responses.extend(responses)

Exception ignored in: <function tqdm.__del__ at 0x7f8c73f41b80>
Traceback (most recent call last):
  File "/home/mihirathale/anaconda3/envs/myenv/lib/python3.9/site-packages/tqdm/std.py", line 1148, in __del__
    self.close()
  File "/home/mihirathale/anaconda3/envs/myenv/lib/python3.9/site-packages/tqdm/notebook.py", line 279, in close
    self.disp(bar_style='danger', check_delay=False)
AttributeError: 'tqdm_notebook' object has no attribute 'disp'
16it [03:41, 13.82s/it]


In [24]:
# Save every collected response dict as JSON in the working directory
with open("responses_non_finetuned.json", "w") as out_file:
    out_file.write(json.dumps(all_responses))