In [None]:
%pip install datasets

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.3.2-py3-none-any.whl (485 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading x

In [None]:
import json
import os
import time
import pandas as pd
from tqdm import tqdm
import requests
from openai import OpenAI
from datasets import load_dataset
import random
import concurrent.futures

# Configuration: every tunable used by the cells below lives here.
# Directory that receives the per-batch JSON files and the final CSV/JSON.
OUTPUT_DIR = "summaries"
MODEL = "gpt-4o"  # Model name passed to client.chat.completions.create
MAX_EXAMPLES = 1000    # Upper bound on the number of conversations sampled for summarization
BATCH_SIZE = 20        # Conversations per batch; one intermediate JSON file is written per batch
MAX_WORKERS = 5  # Thread-pool size used within each batch


# The API key is read from Colab's user secrets (entry name 'OPENAI_KEY'), so this
# cell only runs inside Google Colab.
from google.colab import userdata
API_KEY = userdata.get('OPENAI_KEY')
# Outside Colab, obtain the key some other way (e.g. an environment variable)
# instead of hardcoding it in the notebook.

# Create output directory if it doesn't exist (no error when it is already there)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize the OpenAI client; it is used as a module-level global by the summarization function
client = OpenAI(api_key=API_KEY)

In [None]:
# Loader for the EmpatheticDialogue dataset
def load_empathetic_dialogue():
    """Load the "empathetic_dialogues" dataset with `load_dataset` and return it as-is."""
    return load_dataset("empathetic_dialogues")

# Function to extract and format conversations more efficiently
def extract_and_format_conversations(dataset_split, max_convs=1500):
    """Group utterance rows into conversations and render each one as prompt text.

    Args:
        dataset_split: Column-oriented data indexable by column name. The columns
            read are 'conv_id', 'context', 'prompt', 'speaker_idx', 'utterance'
            and 'utterance_idx'; each must be a sequence with one entry per row.
        max_convs: Maximum number of conversations to keep. When there are more
            unique conversation IDs than this, a random subset of that size is
            drawn with `random.sample`. `None` keeps every conversation.

    Returns:
        A list of dicts, one per selected conversation, in order of first
        appearance in the data. Each dict has the keys 'conv_id',
        'formatted_text', 'context' and 'prompt'. The context and prompt are
        taken from the first row seen for that conversation.
    """
    # Read every column exactly once. Indexing the container inside the row loop
    # repeats the lookup per row, which is wasteful for a dict and can be very
    # slow for containers that rebuild the column on each access.
    conv_ids = dataset_split['conv_id']
    contexts = dataset_split['context']
    prompts = dataset_split['prompt']
    speaker_idxs = dataset_split['speaker_idx']
    utterances = dataset_split['utterance']
    utterance_idxs = dataset_split['utterance_idx']

    # De-duplicate while keeping first-appearance order. A plain set() iterates
    # strings in an order that varies between interpreter runs, which would make
    # the sample below irreproducible even when `random` is seeded.
    all_conv_ids = list(dict.fromkeys(conv_ids))

    # Select a subset of conversation IDs
    if max_convs is not None and len(all_conv_ids) > max_convs:
        selected_ids = set(random.sample(all_conv_ids, max_convs))
    else:
        selected_ids = set(all_conv_ids)

    # Group by conversation ID in a single pass
    conversation_map = {}
    rows = zip(conv_ids, contexts, prompts, speaker_idxs, utterances, utterance_idxs)

    for conv_id, context, prompt, speaker_idx, utterance, utterance_idx in tqdm(rows, total=len(conv_ids)):
        if conv_id not in selected_ids:
            continue

        # The first row of a conversation supplies its context and prompt.
        conversation = conversation_map.setdefault(conv_id, {
            "conv_id": conv_id,
            "context": context,
            "utterances": [],
            "prompt": prompt
        })

        # Add the utterance to the conversation
        conversation["utterances"].append({
            "speaker_idx": speaker_idx,
            "utterance": utterance,
            "utterance_idx": utterance_idx
        })

    # Format conversations for the model
    formatted_conversations = []

    for conv_id, conv_data in conversation_map.items():
        # Sort utterances by index (stable, so ties keep their row order)
        conv_data["utterances"].sort(key=lambda x: x["utterance_idx"])

        # Header first, then one "Speaker N: text" line per utterance
        parts = [f"Context: {conv_data['context']}\n\nPrompt: {conv_data['prompt']}\n\nConversation:\n"]
        parts.extend(
            f"Speaker {utt['speaker_idx']}: {utt['utterance']}\n"
            for utt in conv_data["utterances"]
        )

        formatted_conversations.append({
            "conv_id": conv_id,
            "formatted_text": "".join(parts),
            "context": conv_data["context"],
            "prompt": conv_data["prompt"]
        })

    return formatted_conversations

# Function to generate summaries in parallel
def generate_summaries_parallel(conversations, batch_size=10, max_workers=4):
    """Summarize conversations with the chat model, one batch at a time.

    Each batch is fanned out over a thread pool. After the batch finishes, its
    successful results are written to
    `{OUTPUT_DIR}/summaries_batch_{k}.json` (k is the zero-based batch number)
    so partial progress survives an interrupted run.

    Args:
        conversations: List of dicts with the keys 'conv_id', 'context',
            'prompt' and 'formatted_text'.
        batch_size: Number of conversations submitted per batch.
        max_workers: Thread-pool size used within a batch.

    Returns:
        A list of result dicts with the keys 'conv_id', 'context', 'prompt',
        'conversation' and 'summary', in the same order as `conversations`.
        Conversations whose request failed are reported on stdout and omitted.
    """
    all_results = []

    def process_conversation(conv):
        """Request one summary; return its result dict, or None if the call failed."""
        try:
            prompt = f"""
Please summarize the following empathetic dialogue. Focus on capturing:
1. The emotional context and situation
2. Key points of the conversation
3. The empathetic responses provided

Keep the summary concise (2-3 sentences) while preserving the emotional essence.

{conv["formatted_text"]}

Summary:
"""
            response = client.chat.completions.create(
                model=MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=150,
                temperature=0.3,
            )

            summary = response.choices[0].message.content.strip()

            return {
                "conv_id": conv["conv_id"],
                "context": conv["context"],
                "prompt": conv["prompt"],
                "conversation": conv["formatted_text"],
                "summary": summary
            }
        except Exception as e:
            # Deliberate best effort: one failed request must not abort the run.
            print(f"Error generating summary for conversation {conv['conv_id']}: {e}")
            time.sleep(1)  # Brief pause on error
            return None

    total_batches = (len(conversations) + batch_size - 1) // batch_size

    # Process in batches to avoid overwhelming the API
    for batch_number, start in enumerate(range(0, len(conversations), batch_size)):
        batch = conversations[start:start + batch_size]
        print(f"Processing batch {batch_number + 1}/{total_batches}")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks, remembering submission order
            futures = [executor.submit(process_conversation, conv) for conv in batch]

            # as_completed only drives the progress bar; results are gathered
            # below in submission order so the output is deterministic.
            for _ in tqdm(concurrent.futures.as_completed(futures), total=len(batch), desc="Processing"):
                pass

        # Keep this batch's successes in their own list. Slicing the tail of
        # all_results by len(batch) would pull in results from earlier batches
        # whenever a request in this batch failed.
        batch_results = [
            result
            for result in (future.result() for future in futures)
            if result is not None
        ]
        all_results.extend(batch_results)

        # Save intermediate results
        with open(f"{OUTPUT_DIR}/summaries_batch_{batch_number}.json", "w") as batch_file:
            json.dump(batch_results, batch_file, indent=2)

        # Brief pause between batches
        if start + batch_size < len(conversations):
            time.sleep(1)

    return all_results



# Pipeline driver: load the dataset, sample and format conversations, summarize
# them through the API, then persist the results as CSV and JSON.
print("Loading EmpatheticDialogue dataset...")
dataset = load_empathetic_dialogue()

# Use the training split, limited to its first 15000 rows.
# NOTE(review): the cut is by row, not by conversation, so a conversation that
# straddles row 15000 would presumably be kept with only part of its
# utterances — confirm whether that matters for the summaries.
train_data = dataset["train"][:15000]

# Extract and format conversations in one step
# NOTE(review): the selection uses random.sample and no seed is set in the
# visible code, so the chosen conversations likely differ between runs.
print("Extracting and formatting conversations...")
all_conversations = extract_and_format_conversations(train_data, max_convs=MAX_EXAMPLES)
print(f"Extracted {len(all_conversations)} conversations")

# Generate summaries in parallel (this is the step that calls the API and
# writes one intermediate JSON file per batch into OUTPUT_DIR)
print("Generating summaries...")
all_summaries = generate_summaries_parallel(
    all_conversations,
    batch_size=BATCH_SIZE,
    max_workers=MAX_WORKERS
)

# Save all summaries to a CSV file (one row per successfully summarized conversation)
summaries_df = pd.DataFrame(all_summaries)
summaries_df.to_csv(f"{OUTPUT_DIR}/all_summaries.csv", index=False)

# Also save as JSON for backup
with open(f"{OUTPUT_DIR}/all_summaries.json", "w") as f:
    json.dump(all_summaries, f, indent=2)

# The count reported here covers successful summaries only; failed requests
# return None upstream and are not included in all_summaries.
print(f"Summarization complete. Generated {len(all_summaries)} summaries.")
print(f"Results saved to {OUTPUT_DIR}/all_summaries.csv")

Loading EmpatheticDialogue dataset...
Extracting and formatting conversations...


100%|██████████| 15000/15000 [00:00<00:00, 738745.95it/s]


Extracted 1000 conversations
Generating summaries...
Processing batch 1/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.05it/s]


Processing batch 2/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.90it/s]


Processing batch 3/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.41it/s]


Processing batch 4/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.35it/s]


Processing batch 5/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.93it/s]


Processing batch 6/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.30it/s]


Processing batch 7/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.34it/s]


Processing batch 8/50


Processing: 100%|██████████| 20/20 [00:13<00:00,  1.47it/s]


Processing batch 9/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.60it/s]


Processing batch 10/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.36it/s]


Processing batch 11/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.97it/s]


Processing batch 12/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.13it/s]


Processing batch 13/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.39it/s]


Processing batch 14/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.95it/s]


Processing batch 15/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.29it/s]


Processing batch 16/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.24it/s]


Processing batch 17/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.86it/s]


Processing batch 18/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.75it/s]


Processing batch 19/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.98it/s]


Processing batch 20/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.88it/s]


Processing batch 21/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.27it/s]


Processing batch 22/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.79it/s]


Processing batch 23/50


Processing: 100%|██████████| 20/20 [00:09<00:00,  2.09it/s]


Processing batch 24/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.57it/s]


Processing batch 25/50


Processing: 100%|██████████| 20/20 [00:10<00:00,  1.83it/s]


Processing batch 26/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.83it/s]


Processing batch 27/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.31it/s]


Processing batch 28/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.70it/s]


Processing batch 29/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.36it/s]


Processing batch 30/50


Processing: 100%|██████████| 20/20 [00:11<00:00,  1.82it/s]


Processing batch 31/50


Processing: 100%|██████████| 20/20 [00:08<00:00,  2.26it/s]


Processing batch 32/50


Processing: 100%|██████████| 20/20 [00:09<00:00,  2.08it/s]


Processing batch 33/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.19it/s]


Processing batch 34/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.93it/s]


Processing batch 35/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.10it/s]


Processing batch 36/50


Processing: 100%|██████████| 20/20 [00:07<00:00,  2.65it/s]


Processing batch 37/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  2.94it/s]


Processing batch 38/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.31it/s]


Processing batch 39/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.12it/s]


Processing batch 40/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.42it/s]


Processing batch 41/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.00it/s]


Processing batch 42/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.14it/s]


Processing batch 43/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.01it/s]


Processing batch 44/50


Processing: 100%|██████████| 20/20 [00:06<00:00,  3.06it/s]


Processing batch 45/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.99it/s]


Processing batch 46/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.35it/s]


Processing batch 47/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.55it/s]


Processing batch 48/50


Processing: 100%|██████████| 20/20 [00:05<00:00,  3.39it/s]


Processing batch 49/50


Processing: 100%|██████████| 20/20 [00:11<00:00,  1.80it/s]


Processing batch 50/50


Processing: 100%|██████████| 20/20 [00:09<00:00,  2.14it/s]

Summarization complete. Generated 1000 summaries.
Results saved to summaries/all_summaries.csv



