# Querying Llama endpoints

In [None]:
# Enable the autoreload extension so that edits to imported project modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
import os
# Put the parent of this notebook's directory at the front of sys.path so the
# `src` package can be imported by the cells below.
# `__file__` is not defined inside a Jupyter kernel (referencing it raises
# NameError), so fall back to the current working directory, which is normally
# the folder the notebook lives in.
try:
    notebook_dir = os.path.dirname(__file__)
except NameError:
    notebook_dir = os.getcwd()
root_path = os.path.abspath(os.path.join(notebook_dir, ".."))
sys.path.insert(0, root_path)

In [None]:
from src.azure_config import azure_config

We pull in any positively/negatively labelled data that we want to test Llama against. We do not need to worry about whether the data is labelled train/test/validate, unless some of the data is also used in the prompt as few-shot examples.

In [None]:
from src.data_ingestion import data_ingestion

# Names of the labelled datasets to evaluate against: the first is used as the
# positive-label source and the second as the negative-label source.
pos_dataset_name, neg_dataset_name = (
    "safeguarding_184_Nov22_DanFinola",
    "published_3k_dg_devset",
)

## Llama general safeguarding

In [None]:
# Start the run object that the logging calls at the end of the notebook
# receive via `run=run`.
# NOTE(review): the keyword is spelled "expeiment_name" — confirm this matches
# the parameter name declared by azure_config.start_run before renaming it.
run = azure_config.start_run(
    expeiment_name="llama_example_run",
)

Set up the classifier. In this case it is a general safeguarding Llama classifier (i.e. Llama-7b with a safeguarding prompt).

In [None]:
from src.question_answering_approach import question_answering

# Build the "general" Llama classifier using the general safeguarding
# pre-prompt.
safeguarding_general_llama = question_answering.LlamaClassifier(
    classifier_type="general",
    pre_prompt_name="Llama-general-safeguarding",
)

Test by classifying a single review:

In [None]:
# Smoke test: classify one hand-written review; the call is the last
# expression so its result is displayed as the cell output.
example_review = "I have a plan to kill myself"
safeguarding_general_llama.classify_single_review(example_review)

Classify the chosen datasets. `train_test_val_label="all"` uses all of the data regardless of its train/test/validate label. `balance_data=True` balances the data across the classes by downsampling.

In [None]:
# Run the classifier over the chosen positive and negative datasets.
safeguarding_general_llama.classify_datasets(
    # Text column to classify and the column holding the labels.
    name_of_column_to_classify="Comment Text",
    y_column_name="label_multi",
    # One dataset per class; each argument takes a list of dataset names.
    positive_label_dataset_name_list=[pos_dataset_name],
    negative_label_dataset_name_list=[neg_dataset_name],
    # Use every row regardless of split label, and balance the classes.
    train_test_val_label="all",
    balance_data=True,
)

In [None]:
# Ask the classifier for its assessor (presumably this populates `.assessor`
# — confirm), then display the confusion matrix as the cell output.
safeguarding_general_llama.get_assessor()
general_assessor = safeguarding_general_llama.assessor
general_assessor.get_and_display_confusion_matrix()

In [None]:
# Log the classifier's attributes, the multiclass metrics and the confusion
# matrix against the run, then mark the run as complete.
safeguarding_general_llama.log_all_attributes(run=run)
general_assessor = safeguarding_general_llama.assessor
general_assessor.log_all_multiclass_metrics(run=run)
general_assessor.get_and_log_confusion_matrix(run=run)
run.complete()