In [None]:
import _base_path
import pickle
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm 

from resources.embedding import Embedding, EmbeddingTfIdf
from resources.evaluator import EvaluatorConformalSimple, EvaluatorMaxK
from resources.data_io import load_mappings, load_data

from TrainerClassic import TrainerClassic

In [None]:
DATA      = 'incidents'   # dataset folder under ../data/
TEXTS     = 'title'       # column holding the input texts
LABEL     = 'product'     # column holding the class labels
MODEL     = 'tfidf-lr'    # model folder under ../models/ and ../results/

N_SAMPLES = 2             # few-shot examples retrieved per class (samples_per_class)
MIN_K     = 1             # passed as min_k to the conformal evaluator (presumably the minimum set size)
P         = .40           # alpha of the conformal sets used for the prompts -- must be one of ALPHAS

K         = [5, 10]       # set sizes for max-k prediction; the prompts use the 5- and 10-sets
ALPHAS    = [.05, .15, .40]  # alphas for conformal prediction

# Fail fast: the prompt-building cells read the prediction columns by name
# (f'{P:.2f}', '5' and '10'), long after the expensive prediction step.
assert f'{P:.2f}' in {f'{alpha:.2f}' for alpha in ALPHAS}, f'P={P} must be one of ALPHAS={ALPHAS}'
assert {5, 10} <= set(K), f'K={K} must contain 5 and 10 (used for the max-5/max-10 prompts)'

In [None]:
# Class names indexed by class index (label ids are mapped to names through it).
splits_dir = f"../data/{DATA}/splits/"
mappings = load_mappings(splits_dir, LABEL)
mappings

array(['Aquatic invertebrates', 'Catfishes (freshwater)',
       'Dried pork meat', ..., 'yoghurt',
       'yoghurt-like soya-based products containing bacteria cultures',
       'yogurt raisins'], dtype=object)

# Predict sets:

## Option 1: Using Model

In [None]:
def predict(load_base_classifier, n_splits: int = 5):
    """Predict max-k and conformal prediction sets for the test set of every CV split.

    For split ``i`` the base classifier returned by ``load_base_classifier(i)`` is
    wrapped in an ``EvaluatorMaxK`` and an ``EvaluatorConformalSimple``; the
    conformal evaluator is calibrated on the split's calibration data.

    Note: the "Option 2" cell below defines another ``predict`` that replaces this
    one once it has been run.

    Args:
        load_base_classifier: callable taking the split index and returning the
            base evaluator (it must provide ``tokenizer``, ``predict`` and
            ``last_texts``).
        n_splits: number of CV splits to process (indices ``0 .. n_splits-1``).

    Returns:
        ``(ps_max_k, ps_conf)`` -- two lists with one DataFrame per split. Each frame
        has ``texts`` and ``labels`` columns; ``ps_max_k`` adds one column per ``k``
        in ``K`` (named ``f'{k:d}'``) and ``ps_conf`` one column per ``alpha`` in
        ``ALPHAS`` (named ``f'{alpha:.2f}'``) holding the predicted sets.
    """
    ps_max_k, ps_conf = [], []

    for i in range(n_splits):
        # load evaluators:
        evaluator_base = load_base_classifier(i)
        evaluator_maxk = EvaluatorMaxK(evaluator_base)
        evaluator_conf = EvaluatorConformalSimple(evaluator_base)

        # load calibration and test data (the training part is not needed here):
        _, data_calib, data_test = load_data(
            f"../data/{DATA}/splits/",
            TEXTS,
            LABEL,
            i,
            evaluator_base.tokenizer,
            add_texts=True
        )

        # predict class probabilities on the test set:
        probs = evaluator_base.predict(data_test)
        y_pred = probs['probabilities']

        # Columns shared by both result frames. `last_texts` is read right after
        # predicting the test set and before calibrate(), which may run the base
        # model again on the calibration set (TODO confirm).
        shared = pd.DataFrame()
        shared['texts']  = evaluator_base.last_texts
        shared['labels'] = list(probs['labels'])

        # calibrate conformal prediction:
        evaluator_conf.calibrate(data_calib)

        # predict conformal sets:
        conf = shared.copy()
        for alpha in ALPHAS:
            conf[f'{alpha:.2f}'] = evaluator_conf.predict(alpha, y_pred=y_pred, min_k=MIN_K)['predictions']

        # predict max-k sets:
        max_k = shared.copy()
        for k in K:
            max_k[f'{k:d}'] = evaluator_maxk.predict(k, y_pred=y_pred)['predictions']

        ps_conf.append(conf)
        ps_max_k.append(max_k)

    return ps_max_k, ps_conf

In [None]:
def _load_classic_classifier(i):
    """Load the saved ``MODEL`` trainer of CV split ``i`` (scores normalized with 'sum')."""
    model_dir = f'../models/{MODEL}/{MODEL}-{LABEL}-{i:d}/'
    return TrainerClassic.load(dir=model_dir, normalize_fcn='sum')

# Option 1 results; the Option 2 cell assigns the same names again.
ps_classic_max_k, ps_classic_conf = predict(_load_classic_classifier)

## Option 2: Using Saved Predictions

In [None]:
def predict(model_name:str, n_splits: int = 5):
    """Predict max-k and conformal prediction sets from saved model predictions.

    Redefines ``predict`` from Option 1. Instead of running a model, the class
    probabilities are read from pickles in ``../results/{model_name}/``:
    ``{model_name}-{LABEL}-calib-{i}.pickle`` (calibration set) and
    ``{model_name}-{LABEL}-{i}.pickle`` (test set) for every split ``i``. Each
    pickle must provide ``'probabilities'`` and ``'labels'``.

    Args:
        model_name: name of the model folder under ``../results/``.
        n_splits: number of CV splits to process (indices ``0 .. n_splits-1``).

    Returns:
        ``(ps_max_k, ps_conf)`` -- two lists with one DataFrame per split, each with
        a ``labels`` column plus one column per ``k`` in ``K`` (``f'{k:d}'``) or per
        ``alpha`` in ``ALPHAS`` (``f'{alpha:.2f}'``) respectively.
    """
    def _load_pickle(path):
        # NOTE: pickle.load can execute arbitrary code -- only use on trusted result files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    results_dir = f'../results/{model_name}'
    ps_max_k, ps_conf = [], []

    for i in range(n_splits):
        # evaluators without a base model; they work on the saved probabilities:
        evaluator_maxk = EvaluatorMaxK(normalize_fcn='sum')
        evaluator_conf = EvaluatorConformalSimple(normalize_fcn='sum')

        # calibrate conformal prediction on the saved calibration predictions:
        predictions_calib = _load_pickle(f'{results_dir}/{model_name}-{LABEL}-calib-{i:d}.pickle')
        evaluator_conf.calibrate(
            y_pred=predictions_calib['probabilities'],
            y_true=predictions_calib['labels']
        )

        # saved test predictions:
        predictions_test = _load_pickle(f'{results_dir}/{model_name}-{LABEL}-{i:d}.pickle')
        y_pred = predictions_test['probabilities']
        labels = list(predictions_test['labels'])

        # predict conformal sets:
        conf = pd.DataFrame()
        conf['labels'] = labels
        for alpha in ALPHAS:
            conf[f'{alpha:.2f}'] = evaluator_conf.predict(alpha, y_pred=y_pred, min_k=MIN_K)['predictions']

        # predict max-k sets:
        max_k = pd.DataFrame()
        max_k['labels'] = labels
        for k in K:
            max_k[f'{k:d}'] = evaluator_maxk.predict(k, y_pred=y_pred)['predictions']

        ps_conf.append(conf)
        ps_max_k.append(max_k)

    return ps_max_k, ps_conf

In [None]:
# Option 2 results (saved predictions of MODEL); replaces the Option 1 results if both ran.
ps_classic_max_k, ps_classic_conf = predict(model_name=MODEL)

## Best traditional classifier:

# Create prompts:

In [None]:
# Translation table mapping every double/typographic quotation mark to "'".
_QUOTE_TABLE = str.maketrans(dict.fromkeys('"„“”«»‚‘’‹›', "'"))

def replace_qm(text):
    """Return ``text`` with all double and typographic quotation marks replaced by ``'``.

    The prompts wrap every sample in ``"..."``, so quotes inside a sample are
    normalized to a plain apostrophe.
    """
    return text.translate(_QUOTE_TABLE)

In [None]:
from typing import List, Dict

def get_samples(text:str, data_train:List[Dict[str, List]], embedding:'Embedding', samples_per_class:int, labels=None):
    """Retrieve, from every class, the training samples most similar to ``text``.

    For each class index ``i`` the ``samples_per_class`` training texts with the
    highest ``embedding.cosine_similarity`` to ``text`` are selected. All selected
    samples are then sorted by similarity, most similar first (ties keep class order).

    Args:
        text: query text.
        data_train: one dict per class index with parallel lists ``'texts'`` (raw
            texts) and ``'input_ids'`` (their tokenized form).
        embedding: embedding providing ``tokenizer`` and ``cosine_similarity``.
        samples_per_class: number of samples to keep per class.
        labels: class names indexed by class index; defaults to the global
            ``mappings``.

    Returns:
        List of ``(class_index, sample_text, class_name, similarity)`` tuples sorted
        by decreasing similarity.
    """
    if labels is None:
        labels = mappings

    # the query is the same for every class -- tokenize it only once:
    query_ids = embedding.tokenizer(text, return_offsets_mapping=False)['input_ids']

    samples = []
    for i in range(len(labels)):
        # extract texts of class:
        texts = data_train[i]

        # Similarity of every training text of class i to the query. The query is
        # the first input, so row 0 is dropped and column 0 kept (assumes
        # cosine_similarity returns the pairwise similarity matrix of its inputs).
        similarity = embedding.cosine_similarity([query_ids] + texts['input_ids'])[1:, 0]

        # keep the `samples_per_class` most similar texts of this class:
        for j in np.argsort(similarity)[::-1][:samples_per_class]:
            samples.append((i, texts['texts'][j], labels[i], similarity[j]))

    # most similar samples first:
    samples.sort(key=lambda e: e[3], reverse=True)

    return samples

In [None]:
def create_prompt(text, samples, labels_reduced=None, total_samples=None):
    """Build the few-shot classification prompt for ``text``.

    Args:
        text: text to classify; the prompt ends with ``"<text>" -> ``.
        samples: ``(class_index, sample_text, class_name, similarity)`` tuples in the
            order they should appear by default (``get_samples`` sorts them by
            similarity). The sequence is not modified.
        labels_reduced: optional ordered collection of class indices (e.g. a
            prediction set). If given, only samples of these classes are kept, ordered
            by the position of their class in ``labels_reduced``; samples of the same
            class keep their relative order.
        total_samples: optional cap on the number of examples in the prompt.

    Returns:
        The prompt string.
    """
    if labels_reduced is not None:
        # Position of each class in the prediction set (first occurrence wins, like
        # list.index). A dict makes the filter and sort key O(1) per sample instead
        # of O(len(labels_reduced)).
        rank = {}
        for position, label in enumerate(labels_reduced):
            rank.setdefault(label, position)

        # Keep only samples of the reduced classes, ordered by class rank. sorted()
        # is stable, so filtering first gives the same order as sorting everything first.
        selected = sorted((s for s in samples if s[0] in rank), key=lambda s: rank[s[0]])
    else:
        selected = list(samples)

    if total_samples is not None:
        selected = selected[:total_samples]

    # create context: header line followed by one '"<sample>" -> <class>' line per example
    lines = [f'We are looking for food {LABEL.split("-")[0]}s in texts. Here are some labelled examples sorted from most probable to least probable:']
    lines.extend(f'"{replace_qm(x)}" -> {y}' for _, x, y, _ in selected)
    context = '\n'.join(lines)

    return f'Context start:\n{context}\nContext end:\nPlease predict the correct class for the following sample:\n"{replace_qm(text)}" -> '

In [None]:
# Column layout of the prompt table: CV split, one column per prompt variant,
# and the gold class name.
prompt_columns = (
    ['cv_split', 'prompt_all']
    + ['prompt_sim-5', 'prompt_sim-10', 'prompt_sim-20']   # top-n most similar examples
    + ['prompt_max-5', 'prompt_max-10']                    # examples from the max-k prediction sets
    + [f'prompt_conformal_{P*100:.0f}%', 'label']          # examples from the conformal set at alpha P
)
prompts = pd.DataFrame(columns=prompt_columns)

In [None]:
# Build one row of prompt variants per test sample. Rows are collected in a list and
# turned into a DataFrame once at the end. Growing the frame with .loc is quadratic,
# and it appended duplicate rows whenever this cell was re-run.
prompt_rows = []

for split in [0]:  # TODO: use range(5) to cover all CV splits
    # load data (local, trusted pickle files):
    with open(f"../data/{DATA}/splits/split_{LABEL.split('-')[0]}_{split:d}.pickle", "rb") as f:
        data = pickle.load(f)

    # load embedding:
    embedding = EmbeddingTfIdf.load(f'../models/{MODEL}/{MODEL}-{LABEL}-{split:d}/')

    # The prediction sets must be aligned row by row with the test set.
    # np.array_equal also fails cleanly on a length mismatch.
    assert np.array_equal(data['test'][LABEL].values, ps_classic_conf[split]['labels'].values), \
        f'split {split}: test labels do not match the predicted sets'

    texts_train  = data['train'][TEXTS].values
    labels_train = data['train'][LABEL].values

    # per-class training texts and their token ids, indexed by class index:
    data_train = []
    for label in range(len(mappings)):
        texts = list(texts_train[labels_train == label])

        data_train.append({
            'texts':     texts,
            'input_ids': [embedding.tokenizer(
                text,
                return_offsets_mapping=False
            )['input_ids'] for text in texts]
        })

    test_texts = data['test'][TEXTS].values
    for y_true, y_conf, y_max5, y_max10, text in tqdm(
            zip(
                ps_classic_conf[split]['labels'],
                ps_classic_conf[split][f'{P:.2f}'],
                ps_classic_max_k[split]['5'],
                ps_classic_max_k[split]['10'],
                test_texts,
            ),
            total=len(test_texts),  # zip has no len(); without this tqdm shows no progress bar
            desc=f'split {split}',
        ):
        samples = get_samples(text, data_train, embedding=embedding, samples_per_class=N_SAMPLES)

        # prediction sets hold (class_index, score) pairs; only the class indices are used
        prompt_rows.append((
            split,
            create_prompt(text, samples),                                          # prompt_all
            create_prompt(text, samples, total_samples=5),                         # prompt_sim-5
            create_prompt(text, samples, total_samples=10),                        # prompt_sim-10
            create_prompt(text, samples, total_samples=20),                        # prompt_sim-20
            create_prompt(text, samples, labels_reduced=[c for c, _ in y_max5]),   # prompt_max-5
            create_prompt(text, samples, labels_reduced=[c for c, _ in y_max10]),  # prompt_max-10
            create_prompt(text, samples, labels_reduced=[c for c, _ in y_conf]),   # prompt_conformal
            mappings[y_true],                                                      # label
        ))

# rebuild the table from scratch so re-running this cell is idempotent
prompts = pd.DataFrame(prompt_rows, columns=prompts.columns)

0it [00:00, ?it/s]

In [None]:
# Persist the prompt table; pandas writes the DataFrame index as the first column.
csv_path = f'prompts_new_{LABEL}_{N_SAMPLES}-shot.csv'
prompts.to_csv(csv_path)

In [None]:
# Display the prompt table (the notebook view truncates long frames and cells).
prompts

Unnamed: 0,cv_split,prompt_all,prompt_sim-5,prompt_sim-10,prompt_sim-20,prompt_max-5,prompt_max-10,prompt_conformal_35%,label
0,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,other types of meat
1,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Not classified pork meat
2,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,prepared dish
3,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,ground beef meat
4,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,chicken breast
...,...,...,...,...,...,...,...,...,...
1505,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,muesli
1506,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,cheese
1507,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,tapioca chips
1508,0,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,Context start:\nWe are looking for food produc...,mung bean sprouts


In [14]:
# Print one example conformal prompt (row label 1) followed by its gold class name.
conformal_column = f'prompt_conformal_{P*100:.0f}%'
print(prompts.loc[1, conformal_column], prompts.loc[1, 'label'])

Context start:
We are looking for food hazards in texts. Here are some labelled examples sorted from most probable to least probable:
"Whole Foods brand Pot Pies recalled due to undeclared milk" -> milk and products thereof
"California Firm Recalls Pork Pie Products Due to Misbranding and Undeclared Allergens" -> milk and products thereof
"Great American Deli Issues Allergy Alert On Undeclared Egg And Soy In Premium Chicken Salad Wheatberry Sandwich" -> soybeans and products thereof
"michigan brand, inc. recalls products due to misbranding and an undeclared allergen" -> soybeans and products thereof
"Bake My Day brand Chicken Pot Pie recalled due to undeclared egg" -> eggs and products thereof
"Perdue Foods LLC Recalls Chicken Products due to Misbranding and Undeclared Allergens" -> eggs and products thereof
"bon appetizers, llc recalls products due to misbranding and an undeclared allergen" -> peanuts and products thereof
"Georgia Firm Recalls Chicken and Beef Products Due to Misbrand