In [None]:
import os
import json
import numpy as np
from sklearn.metrics import roc_auc_score
from collections import defaultdict

import json
import pandas as pd
import os

from tqdm import tqdm

# NOTE(review): this assignment overwrites any OPENAI_API_KEY already present
# in the process environment with the placeholder below. Substitute a real key
# locally only, and never commit it.
os.environ["OPENAI_API_KEY"] = "api key here"

from openai import OpenAI
# Module-level client used by is_instruction_flagged(). Presumably it picks up
# the key from the OPENAI_API_KEY variable set above — confirm against the SDK.
client = OpenAI()


def is_instruction_flagged(instruction: str) -> bool:
    """
    Report whether OpenAI's Moderation API flags the instruction.

    The instruction is sent to the "omni-moderation-latest" model through the
    module-level ``client`` and the ``flagged`` field of the first result is
    returned unchanged.

    Args:
        instruction: Text to submit for moderation.

    Returns:
        True if the moderation result is flagged (the instruction should be
        treated as unsafe), False if it is not flagged.
    """
    response = client.moderations.create(
        model="omni-moderation-latest",
        input=instruction,
    )

    # A single input is submitted, so only the first result is inspected.
    return response.results[0].flagged

def softmax(x):
    """Return the softmax of *x*, normalised over all of its elements.

    The global maximum is subtracted before exponentiating so that large
    inputs do not overflow; this shift does not change the result.
    """
    shifted = x - np.max(x)
    exponentials = np.exp(shifted)
    normaliser = exponentials.sum()
    return exponentials / normaliser

# Class names in a fixed order.
classes = [
    "advancing",
    "addressing",
    "ambiguous",
    "none",
]

# Column-oriented accumulator: one independent, initially empty list per
# column, in the order the columns should appear.
final_dataset = {
    column: []
    for column in ("instruction", "split", "label", "logits")
}

# Build the final dataset: for every split, keep the instructions whose
# highest-logit class equals the recorded true label, rank them per label by
# the logit of that label, and retain the top ones that pass moderation.

# Epoch number embedded in the logits file names; identical for every split.
epoch = 7
# Maximum number of top-ranked correct instructions moderated per label.
top_k = 150

for split in ["train", "val", "test"]:
    # NOTE(review): the split name appears twice in the file name — confirm
    # this matches how the logits files were written.
    with open(
        f"../data/logits/wildchat_sample_logits_{split}_{epoch}_{split}.json",
        "r",
        encoding="utf-8",
    ) as f:
        data = json.load(f)

    # label -> list of (instruction, logits) pairs for correct predictions.
    correct_instructions = defaultdict(list)

    for instruction, curr_out in data.items():
        # logits is a mapping from class name to score.
        logits = curr_out["logits"]
        # The predicted class is the key holding the largest logit.
        max_key = max(logits, key=logits.get)

        true_label = curr_out["true_labels"]
        if max_key == true_label:
            correct_instructions[true_label].append((instruction, logits))

    for k in correct_instructions:
        print(k)
        print(len(correct_instructions[k]))

        # Highest logit for the true class first.
        correct_instructions[k].sort(key=lambda x: x[1][k], reverse=True)

        top_instructions = correct_instructions[k][:top_k]

        # tqdm takes its total from the list length, so the progress bar is
        # accurate even when a label has fewer than top_k instructions.
        for instruction, logit in tqdm(top_instructions):

            # Skip anything the moderation endpoint flags.
            if is_instruction_flagged(instruction):
                continue

            final_dataset["instruction"].append(instruction)
            final_dataset["split"].append(split)
            final_dataset["label"].append(k)
            final_dataset["logits"].append(logit)

df = pd.DataFrame(final_dataset)



In [None]:
# Count rows for each (split, label) combination; as the last expression of a
# notebook cell the resulting Series is displayed.
df[["split", "label"]].value_counts()

In [10]:
# Persist the assembled DataFrame as a pickle in the current working directory.
df.to_pickle("./final_dataset.pkl")