In [2]:
import os
import csv

import openai
import numpy as np
from tqdm import tqdm
from datasets import load_dataset

openai.api_key = os.environ.get("OPENAI_API_KEY")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset = load_dataset('glue', 'sst2')
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [5]:
dataset['train'][1]

{'sentence': 'contains no wit , only labored gags ', 'label': 0, 'idx': 1}

In [71]:
PROMPT = """
Classify the sentiment of the sentence into two classes: positive or negative.
Consider the overall tone and specific words used in the sentence.

Example1)
Sentence: A warm, funny, engaging film.
Sentiment: positive

Example2)
Sentence: A three-hour cinema master class.
Sentiment: negative

Example3)
Sentence: An utterly unconvincing plot.
Sentiment: negative

Example4)
Sentence: Brilliantly crafted and remarkably insightful.
Sentiment: positive

Consider the following sentence and classify its sentiment. Think about what words or phrases in the sentence guide your decision:

Sentence: {sentence}
Sentiment: 
""".strip()


def get_probs(sentence):
    try:
        messages = [
            {
                "role": "system",
                "content": "You are an helpful assistant."
            },
            {
                "role": "user",
                "content": PROMPT.format(sentence=sentence)
            }
        ]
        response = openai.ChatCompletion.create(
            model="gpt-4-turbo-preview",
            messages=messages,
            max_tokens=8,
            logprobs=True,
            top_logprobs=5,
            n=1,
            stop=None,
        )
        
        # print("content:", response["choices"][0]["message"]["content"])
        # print("logprobs:", response["choices"][0]["logprobs"]["content"][0]["logprob"])
        # print("top_logprobs:", response["choices"][0]["logprobs"]["content"][0]["top_logprobs"])

        token_prob_dict = {}
        for top_logprob in response["choices"][0]["logprobs"]["content"][0]["top_logprobs"]:
            token = top_logprob["token"].lower().strip()
            logprob = top_logprob["logprob"]
            if token not in token_prob_dict:
                token_prob_dict[token] = np.exp(logprob)
            else:
                token_prob_dict[token] += np.exp(logprob)

        prob_positive = token_prob_dict.get('positive', None)
        prob_negative = token_prob_dict.get('negative', None)

        if prob_negative is None and prob_positive is not None:
            prob_negative = 1 - prob_positive
        elif prob_negative is not None and prob_positive is None:
            prob_positive = 1 - prob_negative
        else:
            prob_positive = 0.5
            prob_negative = 0.5

        return prob_positive, prob_negative
    except Exception as e:
        print(f"Error processing sentence: {sentence}. Error: {e}")
        return None, None
    

prob_positive, prob_negative = get_probs("great movie")
prob_positive, prob_negative

(0.921758289743126, 0.07824171025687399)

In [72]:
max_num = 100

with open('ickd_sst2_probs_n100.csv', 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    writer.writerow(["index", "sentence", "label", "positive_prob", "negative_prob"])

    for i, instance in enumerate(tqdm(dataset['train'])):
        if i >= max_num:
            break
        index = instance['idx']
        sentence = instance['sentence']
        label = instance['label']
        prob_positive, prob_negative = get_probs(sentence)
        writer.writerow([index, sentence, label, prob_positive, prob_negative])

  0%|          | 100/67349 [02:25<27:06:32,  1.45s/it]
