# Template for Zero-Shot Classification with OpenAI Models

### Import all Modules

In [None]:
import os
import glob
import pprint
import time
import argparse
from ast import literal_eval

from sklearn.metrics import accuracy_score
from tqdm import tqdm

import wandb
import openai
import pandas as pd

from utils import (
    dataset_has_format_errors,
    write_jsonl,
)
from utils_src import task_num_to_task_name, dataset_num_to_dataset_name, plot_count_and_normalized_confusion_matrix, \
    task_to_display_labels, load_dataset_task_prompt_mappings

# Directory used as the anchor for project-relative paths.
# `__file__` is not defined when this code is executed as notebook cells, so
# fall back to the current working directory in that case.
try:
    module_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    module_dir = os.getcwd()

# Read the API key from the first line of the key file. The line is stripped
# because the trailing newline would otherwise be kept as part of the key.
with open('src/OpenAI_key.txt', encoding='utf-8') as f:
    openai.api_key = f.readline().strip()
if not openai.api_key:
    raise ValueError("'src/OpenAI_key.txt' does not contain an API key on its first line")

### Setup Arguments and Data

In [None]:
# Weights & Biases project that receives the run, the OpenAI chat model that is
# queried, and the maximum number of attempts per completion request.
WANDB_PROJECT_NAME = 'chatGPT_template_1'
MODEL_NAME = "gpt-3.5-turbo-0613"
COMPLETION_RETRIES = 10

In [None]:
# Run configuration (edit these values before executing the cells below)

# Task and dataset to run inference on
task = 1     # Choices: [1, 2, 3, 4, 5, 6]
dataset = 1  # Choices: [1, 2, 3, 4]

# Directory where the generated outputs are stored
output_dir = "../../data"

# Random seed to use
seed = 2019

# Directory containing the datasets
data_dir = "../../data"

# When True, single-letter labels are used instead of the full label names
not_use_full_labels = False

# Dataset-task mappings CSV, located one directory above module_dir
dataset_task_mappings_fp = os.path.normpath(
    os.path.join(module_dir, '..', "dataset_task_mappings.csv")
)

# Whether to rewrite the dataframe in OpenAI format
rewrite_df_in_openai = True

# Number of epochs (recorded in the wandb run configuration)
n_epochs = 3

# Name of the wandb run
run_name = "zeroshot_chatGPT_3.5_template"

# Sampling temperature passed to the completion request
temp = 0.0

# When True, the few-shot prompt is used instead of the zero-shot prompt
few_shot = False

# Number of trailing prompt lines that form the user prompt; the lines before
# them form the system prompt
system_user_division = 3

### Define Utility Functions

## Main Implementation

In [None]:
# Initialize the Weights and Biases run.
# When run_name is empty, a descriptive name is derived from the configuration.
# The fallback previously interpolated `sample_size`, which is never defined in
# this notebook and raised a NameError as soon as the fallback was needed.
default_run_name = (
    f'{MODEL_NAME}_ds_{dataset}_task_{int(task)}'
    f'_epochs_{n_epochs}'
    f'_full_label_names_{str(not not_use_full_labels)}'
    f'_temp_{temp}'
)

wandb.init(
    # set the wandb project where this run will be logged
    project=WANDB_PROJECT_NAME,
    name=run_name if run_name != '' else default_run_name,

    # track hyperparameters and run metadata
    config={
        "model": MODEL_NAME,
        "dataset": dataset_num_to_dataset_name[int(dataset)],
        "task": task_num_to_task_name[int(task)],
        "epochs": n_epochs,
        "temp": temp,
    },
)

### Load and Process Data

In [None]:
# Load the dataset-task mappings and the index used to look up the row for the
# selected dataset/task pair (the call was missing its closing parenthesis).
dataset_idx, dataset_task_mappings = load_dataset_task_prompt_mappings(
    dataset_num=dataset, task_num=task, dataset_task_mappings_fp=dataset_task_mappings_fp
)

In [None]:
# Get information specific to the dataset
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]

# Select the full prompt in this cell so that it does not rely on `prompt`
# having been defined elsewhere (few-shot or zero-shot, depending on few_shot).
prompt_column = 'few_shot_prompt' if few_shot else 'zero_shot_prompt'
prompt = dataset_task_mappings.loc[dataset_idx, prompt_column]

# Split the prompt: the last `system_user_division` lines become the user
# prompt format, everything before them becomes the system prompt.
# (This used to read `args.system_user_division`; no `args` object exists in
# this notebook, the value lives in the `system_user_division` variable.)
system_user_prompt_division_line = system_user_division
prompt_lines = prompt.split('\n')
system_prompt = '\n'.join(prompt_lines[:-system_user_prompt_division_line]).strip()
user_prompt_format = '\n'.join(prompt_lines[-system_user_prompt_division_line:]).strip()
print(user_prompt_format)


# Log the system prompt and user_prompt_format as files in wandb
prompts_artifact = wandb.Artifact('prompts', type='prompts')
with prompts_artifact.new_file('system_prompt.txt', mode='w', encoding='utf-8') as f:
    f.write(system_prompt)
with prompts_artifact.new_file('user_prompt_format.txt', mode='w', encoding='utf-8') as f:
    f.write(user_prompt_format)
wandb.run.log_artifact(prompts_artifact)

In [None]:
# Load the datasets for the selected dataset/task pair. The configuration lives
# in plain notebook variables; there is no `args` object in this notebook.
# NOTE(review): load_full_dataset is not among the imports visible in this
# file - confirm which module provides it.
datasets = load_full_dataset(data_dir=data_dir, dataset_num=dataset, task_num=task)

In [None]:
# Look up the label column and the prompt (few-shot or zero-shot) that belong
# to the selected dataset/task row
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]
prompt_column = 'few_shot_prompt' if few_shot else 'zero_shot_prompt'
prompt = dataset_task_mappings.loc[dataset_idx, prompt_column]

In [None]:
# Split the prompt: the last `system_user_division` lines become the user
# prompt format, everything before them becomes the system prompt.
# (The statements of this cell were indented without an enclosing block, which
# is an IndentationError; they now sit at top level.)
system_user_prompt_division_line = system_user_division
system_prompt = ('\n'.join(prompt.split('\n')[:-system_user_prompt_division_line])).strip()
user_prompt_format = ('\n'.join(prompt.split('\n')[-system_user_prompt_division_line:])).strip()
print(user_prompt_format)

# Log the system prompt and user_prompt_format as files in wandb
prompts_artifact = wandb.Artifact('prompts', type='prompts')
with prompts_artifact.new_file('system_prompt.txt', mode='w', encoding="utf-8") as f:
    f.write(system_prompt)
with prompts_artifact.new_file('user_prompt_format.txt', mode='w', encoding="utf-8") as f:
    f.write(user_prompt_format)
wandb.run.log_artifact(prompts_artifact)

In [None]:
# Load the datasets for the configured dataset/task pair
datasets = load_full_dataset(
    data_dir=data_dir,
    dataset_num=dataset,
    task_num=task,
)

In [None]:
# Preprocessed files go into a sub-folder named after the label style in use
labels_subdir = 'single_letter_labels' if not_use_full_labels else 'full_name_labels'
preprocessed_output_dir = os.path.join(output_dir, 'preprocessed', labels_subdir)


In [None]:
# Create the output directory once, before any per-dataset file is written
os.makedirs(preprocessed_output_dir, exist_ok=True)

# Add the completion label and the OpenAI chat-format columns to every dataset,
# validate the result and store it as CSV.
# NOTE(review): map_label_to_completion and create_training_example are not
# among the imports visible in this file - confirm which module provides them.
for df_name, df in datasets.items():
    print(df_name)
    print(df.head())
    df['completion_label'] = df[label_column].map(
        lambda label: map_label_to_completion(label=label, task_num=task,
                                              full_label=not not_use_full_labels)
    )
    df['openai_instance_format'] = df.apply(
        lambda row: create_training_example(
            system_prompt=system_prompt, user_prompt_format=user_prompt_format,
            user_prompt_text=row['text'],
            completion=row['completion_label']
        ),
        axis=1
    )
    # Drop the last message of each instance (presumably the completion) so the
    # remaining messages can be sent as the request
    df['openai_instance_without_completion'] = df['openai_instance_format'].map(lambda x: x['messages'][:-1])

    print(f'Check for errors {df_name} set: ')
    # Raise explicitly instead of using `assert`, which is skipped under `python -O`
    if dataset_has_format_errors(df['openai_instance_format'].tolist()):
        raise ValueError(f"Errors found in {df_name}")
    df.to_csv(os.path.join(preprocessed_output_dir, df_name + '.csv'), index=False)

In [None]:
# Create jsonl file and upload to OpenAI.
# upload_datasets_to_openai takes an `args` object, but no argument parser is
# run in this notebook, so the configuration variables are bundled into a
# Namespace that offers the same attribute access.
# NOTE(review): which attributes the helper actually reads is not visible
# here, and the helper is not among the visible imports - confirm both.
args = argparse.Namespace(
    task=task,
    dataset=dataset,
    output_dir=output_dir,
    seed=seed,
    data_dir=data_dir,
    not_use_full_labels=not_use_full_labels,
    dataset_task_mappings_fp=dataset_task_mappings_fp,
    rewrite_df_in_openai=rewrite_df_in_openai,
    n_epochs=n_epochs,
    run_name=run_name,
    temp=temp,
    few_shot=few_shot,
    system_user_division=system_user_division,
)
df_id_metadata = upload_datasets_to_openai(args, datasets)

In [None]:
# Build a short model identifier from the training-set name, e.g.
# 'ds_1__task_1_train_set' -> 'ds_1_t_1_trn'
train_set_name = f'ds_{dataset}__task_{task}_train_set'
model_name = (train_set_name.replace('__', '_')
                .replace('train_set', 'trn')
                .replace('task', 't')
                .replace('_single_letter_labels', '_sl'))

# The flag lives in the `few_shot` variable; there is no `args` object in this notebook
if few_shot:
    model_name += "_few_shot"

In [None]:
# Evaluate the model on the evaluation set and store the predictions
print("\n" + "#" * 50)
print("Getting predictions on the evaluation set")

# NOTE(review): the original cell used `eval_df` without defining it anywhere
# in the visible cells. The key below mirrors the
# 'ds_{dataset}__task_{task}_train_set' naming used for the training set -
# confirm that it matches the keys of `datasets`.
eval_set_name = f'ds_{dataset}__task_{task}_eval_set'
eval_df = datasets[eval_set_name]

predictions = []

for messages in tqdm(eval_df['openai_instance_without_completion'].tolist()):
    # Try the completion at most COMPLETION_RETRIES times
    response = None
    for attempt in range(1, COMPLETION_RETRIES + 1):
        try:
            response = openai.ChatCompletion.create(
                # Zero-shot: query the base model named in MODEL_NAME
                # (`full_model_name` and `args.temp` were never defined here)
                model=MODEL_NAME,
                messages=messages,
                temperature=temp,
                n=1,
            )
            break
        except Exception:
            print('Error getting predictions. Retrying...')
            if attempt == COMPLETION_RETRIES:
                print('Maximum amount of retires reached')
                # Re-raise the last error without a pointless final sleep
                raise
            time.sleep(5)
    predictions.append(response['choices'][0]['message']['content'])

# Add predictions to df
eval_df['prediction'] = predictions

In [None]:
# Store output
predictions_output_dir = os.path.join(output_dir, 'predictions',
                                        f'dataset_{dataset}_task_{task}')
os.makedirs(predictions_output_dir, exist_ok=True)

# The file name previously interpolated `fmodel_name`, which is never defined;
# `model_name` is the identifier built for this run. `eval_df` is written
# because it is the frame that received the 'prediction' column.
# NOTE(review): make sure `eval_df` is defined before this cell runs.
eval_df.to_csv(
    os.path.join(predictions_output_dir, f"{model_name}--{run_name}.csv"),
    index=False)


In [None]:
# Get performance metrics: compare the expected completion labels with the
# model predictions
y_true = eval_df['completion_label']
y_pred = eval_df['prediction']

# Pick the display labels that match the label style in use
label_type = 'short_name' if not_use_full_labels else 'full_name'
display_labels = task_to_display_labels[task][label_type]
labels = display_labels

cm_plot, classification_report, metrics = plot_count_and_normalized_confusion_matrix(
    y_true,
    y_pred,
    display_labels,
    labels,
    xticks_rotation='horizontal',
)

In [None]:
# Log every metric with its own wandb.log call
for name, value in metrics.items():
    wandb.log({name: value})

# Log the confusion matrix matplotlib figure
confusion_matrix_image = wandb.Image(cm_plot)
wandb.log({'confusion_matrix': confusion_matrix_image})

# Turn the classification report into a table, leaving out the 'accuracy'
# entry (presumably because it is a scalar rather than a per-label mapping)
report_without_accuracy = {
    key: value for key, value in classification_report.items() if key != 'accuracy'
}
classification_report = pd.DataFrame(report_without_accuracy).transpose().reset_index()
report_table = wandb.Table(dataframe=classification_report)
wandb.log({'classification_report': report_table})

# Store the same report as a text file inside a wandb artifact
classification_report_artifact = wandb.Artifact(
    f'classification_report_{model_name}', type='classification_report')
with classification_report_artifact.new_file('classification_report.txt', mode='w') as f:
    f.write(pprint.pformat(classification_report))
wandb.run.log_artifact(classification_report_artifact)

wandb.finish()