In [1]:
import json
# Local MNLI sample. Later cells iterate it and index data[0], reading the
# "premise", "hypothesis" and "label" keys of each record.
# JSON text is UTF-8 by spec; pin the encoding so non-ASCII premises decode the
# same way regardless of the OS locale.
with open("data.json", "r", encoding="utf-8") as f:
    data = json.load(f)

In [2]:
from enum import Enum

class MNLILabel(Enum):
    """Readable names for the integer MNLI labels 0, 1 and 2 (in that order)."""

    ENTAILMENT, NEUTRAL, CONTRADICTION = range(3)

In [3]:
# Peek at the first record only, to sanity-check the fields and the label mapping.
for example in data[:1]:
    print(example["premise"], example["hypothesis"], MNLILabel(example["label"]), sep="\n")

Conceptually cream skimming has two basic dimensions - product and geography.
Product and geography are what make cream skimming work.
MNLILabel.NEUTRAL


In [4]:
import anthropic
from anthropic import Anthropic
import os
from dotenv import load_dotenv

load_dotenv()

client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

# One-off request: ask the model to classify the first premise/hypothesis pair.
first_pair = data[0]
nli_prompt = (
    "Here is a premise and hypothesis pair. Tell me if the relationship between them is entailment, neutral, or contradiction:\n\n"
    f"Premise: {first_pair['premise']}\n"
    f"Hypothesis: {first_pair['hypothesis']}"
)
message = client.messages.create(
    model="claude-3-5-sonnet-20241022",
    max_tokens=1024,
    messages=[{"role": "user", "content": nli_prompt}],
)

In [5]:
# Show the text of the first content block of the model's reply.
reply_text = message.content[0].text
print(reply_text)

Let me analyze this carefully:

The relationship between the premise and hypothesis is NEUTRAL.

Here's why:
- The premise states that cream skimming has two basic dimensions: product and geography
- The hypothesis makes a stronger claim, stating that these two elements "make cream skimming work"
- While the premise identifies these as dimensions of cream skimming, it doesn't make any claims about whether these dimensions are what make it successful or functional
- The hypothesis goes beyond the information given in the premise by making a claim about causation/functionality

Since we can't determine from the premise alone whether product and geography are what make cream skimming work (as opposed to just being descriptive dimensions), we cannot say the premise entails the hypothesis, nor does it contradict it. Therefore, the relationship is neutral.


## Extract emotion info from text

In [6]:
from pydantic import BaseModel, Field

class EmotionInfo(BaseModel):
    # Emotion annotation the LLM is asked to return; its JSON schema is embedded in the prompts.
    # NOTE: intentionally no class docstring — pydantic would copy it into the schema's
    # "description", which would change the schema text sent to the model.
    # Every score is bounded to [0, 1] by ge/le; the descriptions are part of the schema too.
    arousal: float = Field(ge=0, le=1, description="Level of energy/activation in the emotion, from calm (0) to excited (1)")
    valence: float = Field(ge=0, le=1, description="Pleasantness of the emotion, from negative (0) to positive (1)")
    intensity: float = Field(ge=0, le=1, description="Overall strength of the emotional response, from weak (0) to strong (1)")

In [7]:
# Serialize the EmotionInfo JSON Schema once; the annotators embed this string in their prompts.
emotion_info_schema = json.dumps(
    obj=EmotionInfo.model_json_schema(),
)

In [8]:
import anthropic
from anthropic import Anthropic
import os
import json
from dotenv import load_dotenv

# Populate os.environ from a local .env file (python-dotenv) so ANTHROPIC_API_KEY can be read via os.getenv.
load_dotenv()

def get_emotion_info(
        input_text: str, 
        parse_error: str | None = None, 
        previous_output: str | None = None,
        try_count: int = 0,
        max_retries: int = 3
    ):
    client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    
    prompt = f"""Analyze the emotional content of this text and output a JSON object with the following schema:
    {emotion_info_schema}
    
    Only output valid JSON, nothing else.
    
    Text to analyze: {input_text}"""

    if parse_error:
        prompt += f"You already outputted the following JSON, but it was invalid:\n{previous_output}\nValidation errors: {parse_error}\nPlease fix the errors and output a valid JSON."

    message = client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=1024,
        messages=[{
            "role": "user", 
            "content": prompt
        }]
    )
    try:
        response_json = json.loads(message.content[0].text)
        return response_json
    except json.JSONDecodeError as e:
        return get_emotion_info(input_text, str(e), message.content[0].text, try_count + 1, max_retries)


In [37]:

augmented_data = []
for example in data:
    # Annotate premise first, then hypothesis (one blocking request each).
    premise_emotions = get_emotion_info(example["premise"])
    hypothesis_emotions = get_emotion_info(example["hypothesis"])
    augmented_data.append(
        dict(
            premise=example["premise"],
            hypothesis=example["hypothesis"],
            emotion_info_premise=premise_emotions,
            emotion_info_hypothesis=hypothesis_emotions,
            label=example["label"],
        )
    )

In [38]:
# Persist the sync-annotated sample; `with` guarantees the file is flushed and closed.
with open("augmented_data.json", "w") as f:
    json.dump(augmented_data, f)

## Now at scale: annotate MNLI training examples with async, batched requests

In [9]:
from datasets import load_dataset
# GLUE MNLI from the Hugging Face Hub; later cells read rows from ds["train"].
ds = load_dataset(path="nyu-mll/glue", name="mnli")

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Inspect the schema of one MNLI training row.
first_train_row = ds["train"][0]
print(json.dumps(first_train_row, indent=4))

{
    "premise": "Conceptually cream skimming has two basic dimensions - product and geography.",
    "hypothesis": "Product and geography are what make cream skimming work. ",
    "label": 1,
    "idx": 0
}


In [11]:
async def get_emotion_info_async(
        input_text: str, 
        parse_error: str | None = None, 
        previous_output: str | None = None,
        try_count: int = 0,
        max_retries: int = 3
    ):
    client = anthropic.AsyncAnthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
    
    prompt = f"""Analyze the emotional content of this text and output a JSON object with the following schema:
    {emotion_info_schema}
    
    Only output valid JSON, nothing else.
    
    Text to analyze: {input_text}"""

    if parse_error:
        prompt += f"You already outputted the following JSON, but it was invalid:\n{previous_output}\nValidation errors: {parse_error}\nPlease fix the errors and output a valid JSON."

    message = await client.messages.create(
        model="claude-3-5-sonnet-20241022",
        max_tokens=1024,
        messages=[{
            "role": "user", 
            "content": prompt
        }]
    )
    try:
        response_json = json.loads(message.content[0].text)
        return response_json
    except json.JSONDecodeError as e:
        if try_count >= max_retries:
            raise e
        return await get_emotion_info_async(input_text, str(e), message.content[0].text, try_count + 1, max_retries)

In [12]:
# Directory that receives the per-batch JSON files (no error if it already exists).
AUGMENTED_DIR = "augmented_data"
os.makedirs(AUGMENTED_DIR, exist_ok=True)

In [13]:
import asyncio

async def process_batch(batch_start, batch_size, ds, max_concurrency=5):
    """Annotate a contiguous slice of ``ds["train"]`` with emotion info.

    Args:
        batch_start: Index of the first training row to process.
        batch_size: Number of rows requested. The slice is clamped to the end of
            ``ds["train"]``, so a trailing partial batch yields fewer rows instead
            of raising ``IndexError``.
        ds: Mapping whose ``"train"`` split supports ``len()`` and integer indexing;
            each row must have ``idx``, ``premise``, ``hypothesis`` and ``label``.
        max_concurrency: Maximum number of rows processed at once. Each row fires
            its premise and hypothesis requests concurrently, so up to
            ``2 * max_concurrency`` API calls may be in flight.

    Returns:
        A list of augmented rows in dataset order.
    """
    sem = asyncio.Semaphore(max_concurrency)
    train_split = ds["train"]
    batch_end = min(batch_start + batch_size, len(train_split))

    async def process_item(item):
        async with sem:
            emotion_info_premise, emotion_info_hypothesis = await asyncio.gather(
                get_emotion_info_async(item["premise"]),
                get_emotion_info_async(item["hypothesis"])
            )
        return {
            "idx": item["idx"],
            "premise": {
                "text": item["premise"],
                "emotion_info": emotion_info_premise
            },
            "hypothesis": {
                "text": item["hypothesis"],
                "emotion_info": emotion_info_hypothesis
            },
            "label": item["label"]
        }

    # gather preserves argument order, so results line up with dataset indices.
    return await asyncio.gather(
        *(process_item(train_split[i]) for i in range(batch_start, batch_end))
    )


In [16]:
BATCH_SIZE = 20
TOTAL_SIZE = 1000
BATCH_START = 20

for batch_start in range(BATCH_START, TOTAL_SIZE, BATCH_SIZE):
    augmented_batch = await process_batch(batch_start, BATCH_SIZE, ds)
    with open(f"augmented_data/batch_{batch_start}-{batch_start+BATCH_SIZE}.json", "w") as f:
        json.dump(augmented_batch, f)
    # Wait 60 seconds before processing next batch
    await asyncio.sleep(60)