In [1]:
from enum import Enum

class SentimentLabel(Enum):
    """Binary sentiment classes, listed in ascending value order.

    NOTE(review): the values presumably mirror the integer ``label`` column of
    GLUE/SST-2 (0 = negative, 1 = positive) — confirm against the dataset card.
    """

    NEGATIVE = 0  # integer label 0
    POSITIVE = 1  # integer label 1

In [2]:
from pydantic import BaseModel, Field

# Target structure for the LLM's emotion annotation of a sentence.
# Intentionally no class docstring: pydantic puts a model's docstring into the
# generated JSON schema, which would change the schema text sent in the prompt.
class EmotionInfo(BaseModel):
    # Every score is a float constrained to [0, 1] (ge/le are emitted as
    # "minimum"/"maximum" in the JSON schema). Field order is the order of
    # "properties" in the schema, so keep it stable.
    arousal: float = Field(
        ge=0,
        le=1,
        description="Level of energy/activation in the emotion, from calm (0) to excited (1)",
    )
    valence: float = Field(
        ge=0,
        le=1,
        description="Pleasantness of the emotion, from negative (0) to positive (1)",
    )
    intensity: float = Field(
        ge=0,
        le=1,
        description="Overall strength of the emotional response, from weak (0) to strong (1)",
    )

In [3]:
import json
# Serialize the EmotionInfo JSON schema once, as a compact string that the
# prompt-building functions interpolate verbatim.
emotion_info_schema = json.dumps(
    EmotionInfo.model_json_schema()
)

In [4]:
!pip install python-dotenv -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [5]:
import anthropic
from anthropic import Anthropic
import os
import json
from dotenv import load_dotenv

# Populate os.environ from a local .env file if present (the API clients below
# read ANTHROPIC_API_KEY via os.getenv); variables already set in the real
# environment keep precedence.
load_dotenv(override=False)

def get_emotion_info(
        input_text: str,
        parse_error: str | None = None,
        previous_output: str | None = None,
        try_count: int = 0,
        max_retries: int = 3
    ):
    """Ask Claude to rate the emotional content of ``input_text`` as JSON.

    The prompt embeds the ``EmotionInfo`` JSON schema (``emotion_info_schema``).
    If the reply does not parse as JSON, the model is re-prompted with its
    previous reply and the parse error until ``try_count`` reaches
    ``max_retries``.

    Args:
        input_text: Text to analyze.
        parse_error: Parse error from an earlier attempt. When set, the prompt
            asks the model to correct ``previous_output``.
        previous_output: Raw model reply from the earlier attempt.
        try_count: Number of retries already spent.
        max_retries: Maximum number of retries after the first attempt.

    Returns:
        Whatever ``json.loads`` produced from the model reply. Only JSON syntax
        is checked; the object is NOT validated against ``EmotionInfo``.

    Raises:
        json.JSONDecodeError: If the reply is still not valid JSON once the
            retry budget is exhausted.
    """
    base_prompt = f"""Analyze the emotional content of this text and output a JSON object with the following schema:
    {emotion_info_schema}
    
    Only output valid JSON, nothing else.
    
    Text to analyze: {input_text}"""

    # One client for all attempts; the context manager closes its HTTP
    # connection pool when we are done instead of leaking it.
    with Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) as client:
        while True:
            prompt = base_prompt
            if parse_error:
                # Leading blank line keeps the correction request from being
                # glued onto the end of the analyzed text.
                prompt += (
                    "\n\nYou already outputted the following JSON, but it was invalid:\n"
                    f"{previous_output}\nValidation errors: {parse_error}\n"
                    "Please fix the errors and output a valid JSON."
                )

            message = client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=1024,
                messages=[{
                    "role": "user",
                    "content": prompt
                }]
            )
            output_text = message.content[0].text
            try:
                return json.loads(output_text)
            except json.JSONDecodeError as e:
                # Bounded retries: re-raise the last parse error once the budget is spent.
                if try_count >= max_retries:
                    raise
                parse_error, previous_output, try_count = str(e), output_text, try_count + 1


## Now at scale: annotating SST-2 sentences with emotion info in async batches

In [6]:
from datasets import load_dataset
# GLUE's SST-2 task from the Hugging Face Hub; `ds` is keyed by split name.
ds = load_dataset(path="nyu-mll/glue", name="sst2")

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Pretty-print the first training example to see which fields it carries.
print(json.dumps(obj=ds["train"][0], indent=4))

{
    "sentence": "hide new secretions from the parental units ",
    "label": 0,
    "idx": 0
}


In [8]:
async def get_emotion_info_async(
        input_text: str,
        parse_error: str | None = None,
        previous_output: str | None = None,
        try_count: int = 0,
        max_retries: int = 3
    ):
    """Async variant: ask Claude to rate the emotional content of ``input_text``.

    The prompt embeds the ``EmotionInfo`` JSON schema (``emotion_info_schema``).
    If the reply does not parse as JSON, the model is re-prompted with its
    previous reply and the parse error until ``try_count`` reaches
    ``max_retries``.

    Args:
        input_text: Text to analyze.
        parse_error: Parse error from an earlier attempt. When set, the prompt
            asks the model to correct ``previous_output``.
        previous_output: Raw model reply from the earlier attempt.
        try_count: Number of retries already spent.
        max_retries: Maximum number of retries after the first attempt.

    Returns:
        Whatever ``json.loads`` produced from the model reply. Only JSON syntax
        is checked; the object is NOT validated against ``EmotionInfo``.

    Raises:
        json.JSONDecodeError: If the reply is still not valid JSON once the
            retry budget is exhausted.
    """
    base_prompt = f"""Analyze the emotional content of this text and output a JSON object with the following schema:
    {emotion_info_schema}
    
    Only output valid JSON, nothing else.
    
    Text to analyze: {input_text}"""

    # One client shared by every attempt of this call; ``async with`` closes its
    # HTTP connection pool afterwards instead of leaking one client per request
    # (this function is called once per dataset row).
    async with anthropic.AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) as client:
        while True:
            prompt = base_prompt
            if parse_error:
                # Leading blank line keeps the correction request from being
                # glued onto the end of the analyzed text.
                prompt += (
                    "\n\nYou already outputted the following JSON, but it was invalid:\n"
                    f"{previous_output}\nValidation errors: {parse_error}\n"
                    "Please fix the errors and output a valid JSON."
                )

            message = await client.messages.create(
                model="claude-3-5-sonnet-20241022",
                max_tokens=1024,
                messages=[{
                    "role": "user",
                    "content": prompt
                }]
            )
            output_text = message.content[0].text
            try:
                return json.loads(output_text)
            except json.JSONDecodeError as e:
                # Re-raise the last parse error once the retry budget is spent.
                if try_count >= max_retries:
                    raise
                parse_error, previous_output, try_count = str(e), output_text, try_count + 1

In [9]:
# Destination for the per-batch JSON files; succeeds silently if it already exists.
os.makedirs(name="augmented_data", exist_ok=True)

In [11]:
import asyncio

async def process_batch(batch_start, batch_size, ds, max_concurrency: int = 2, split: str = "train"):
    """Annotate ``ds[split][batch_start : batch_start + batch_size]`` with emotion info.

    Each example's sentence is sent to ``get_emotion_info_async``, with at most
    ``max_concurrency`` requests in flight at once.

    Args:
        batch_start: Index of the first example to process.
        batch_size: Number of examples to process. The range is clipped to the
            end of the split instead of raising IndexError.
        ds: Mapping from split name to a sized, integer-indexable collection of
            examples with "idx", "sentence" and "label" fields (e.g. the
            Hugging Face ``DatasetDict`` returned by ``load_dataset``).
        max_concurrency: Maximum number of concurrent API calls.
        split: Which split of ``ds`` to read.

    Returns:
        A list of dicts, in input order, with keys "idx", "sentence",
        "emotion_info" and "label". "emotion_info" is a one-element list holding
        the parsed model reply; that wrapping is kept for compatibility with
        previously written batch files.

    Raises:
        ValueError: If ``max_concurrency`` is less than 1 (the semaphore would
            never let any request through).
    """
    if max_concurrency < 1:
        raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

    sem = asyncio.Semaphore(max_concurrency)
    examples = ds[split]
    batch_end = min(batch_start + batch_size, len(examples))

    async def process_item(item):
        # Hold the semaphore only for the API call itself.
        async with sem:
            emotion_info = await get_emotion_info_async(item["sentence"])
        return {
            "idx": item["idx"],
            "sentence": item["sentence"],
            "emotion_info": [emotion_info],
            "label": item["label"]
        }

    # gather preserves input order and returns a list.
    return await asyncio.gather(
        *(process_item(examples[i]) for i in range(batch_start, batch_end))
    )


In [12]:
BATCH_SIZE = 20
TOTAL_SIZE = 1000
BATCH_START = 40

for batch_start in range(BATCH_START, TOTAL_SIZE, BATCH_SIZE):
    augmented_batch = await process_batch(batch_start, BATCH_SIZE, ds)
    with open(f"augmented_data/batch_{batch_start}-{batch_start+BATCH_SIZE}.json", "w") as f:
        json.dump(augmented_batch, f)
    # Wait 60 seconds before processing next batch
    