In [1]:
import _base_path
import pickle
import numpy as np
import pandas as pd
from tqdm.autonotebook import tqdm 

from resources.evaluator import EvaluatorConformalSimple, EvaluatorMaxK
from resources.models import DummyModel
from resources.data_io import load_mappings, load_data

from TrainerBOW import EvaluatorBOW

  from tqdm.autonotebook import tqdm


In [2]:
# Dataset folder below ../data/, text column and label column used throughout the notebook.
DATA     = 'incidents'
TEXTS    = 'title'
LABEL    = 'product'

# Max-k set sizes. The prompt-generation cell further down reads the '5', '10'
# and '15' columns of the max-k predictions, so all three sizes must be listed.
K        = [5, 10, 15]
# Error rates for which conformal prediction sets are computed.
EPSILON  = [.05, .20, .50]

In [3]:
# Class-name lookup for the LABEL column. Later cells index it by class position
# (`mappings[i]`) and by a boolean mask over the classes.
mappings = load_mappings(f"../data/{DATA}/splits/", LABEL)
mappings

array(['adobo seasoning', 'after dinner mints', 'alcoholic beverages',
       ...,
       'yoghurt-like soya-based products containing bacteria cultures',
       'yogurt raisins', 'zomi and palm oil'], dtype='<U70')

# Predict sets:

In [4]:
def predict(load_base_classifier, n_splits=5):
    """Compute max-k and conformal prediction sets for every cross-validation split.

    Parameters
    ----------
    load_base_classifier : callable
        Called as ``load_base_classifier(i)`` with the split index; must return
        a ``(evaluator, tokenizer)`` pair.
    n_splits : int, default 5
        Number of cross-validation splits to process (split indices
        ``0 .. n_splits - 1``).

    Returns
    -------
    (list of pandas.DataFrame, list of pandas.DataFrame)
        ``ps_max_k`` and ``ps_conf``, one frame per split. Both frames carry the
        columns ``'texts'`` and ``'labels'``; ``ps_max_k`` adds one column per
        entry of ``K`` (named ``f'{k:d}'``) and ``ps_conf`` one column per entry
        of ``EPSILON`` (named ``f'{epsilon:.2f}'``).
    """
    ps_max_k = [pd.DataFrame() for _ in range(n_splits)]
    ps_conf  = [pd.DataFrame() for _ in range(n_splits)]

    for i in range(n_splits):
        # load evaluators:
        evaluator_base, tokenizer = load_base_classifier(i)
        evaluator_maxk = EvaluatorMaxK(evaluator_base)
        evaluator_conf = EvaluatorConformalSimple(evaluator_base)

        # load data (the first element of the returned triple is not needed here):
        _, data_calib, data_test = load_data(
            f'../data/{DATA}/splits/',
            TEXTS,
            LABEL,
            i,
            tokenizer,
            add_texts=True
        )

        # predict class probabilities once; both set predictors reuse them:
        probs = evaluator_base.predict(data_test)
        y_pred = probs['predictions']
        labels = list(probs['labels'])

        # NOTE(review): `last_texts` is presumably filled by the predict call
        # above - confirm against the evaluator implementation.
        ps_max_k[i]['texts']  = evaluator_base.last_texts
        ps_max_k[i]['labels'] = labels

        ps_conf[i]['texts']  = evaluator_base.last_texts
        ps_conf[i]['labels'] = labels

        # calibrate conformal prediction:
        evaluator_conf.calibrate(data_calib)

        # predict conformal sets, one column per error rate:
        for epsilon in EPSILON:
            ps_conf[i][f'{epsilon:.2f}'] = evaluator_conf.predict(epsilon, y_pred=y_pred)['predictions']

        # predict max-k sets, one column per set size:
        for k in K:
            ps_max_k[i][f'{k:d}'] = evaluator_maxk.predict(k, y_pred=y_pred)['predictions']

    return ps_max_k, ps_conf

## Best traditional classifier:

In [None]:
# Prediction sets of the bag-of-words SVM: one saved model per CV split,
# loaded from ../models/bow-svm/ by split index.
ps_classic_max_k, ps_classic_conf = predict(
    load_base_classifier=lambda split: EvaluatorBOW.load(
        f'../models/bow-svm/bow-svm-{LABEL}-{split:d}/', 'sum', 2
    )
)

# Create prompts:

In [None]:
def replace_qm(text):
    """Return `text` with all double and typographic quotation marks turned into apostrophes."""
    quote_marks = ['"', '„', '“', '”', '«', '»', '‚', '‘', '’', '‹', '›']

    # a single translation pass replaces every listed mark with "'":
    return text.translate(str.maketrans({mark: "'" for mark in quote_marks}))

In [None]:
def get_samples(texts_train, labels_train, samples_per_class=2):
    """Draw random few-shot examples from the training data, class by class.

    For each of `samples_per_class` rounds the classes are visited in a fresh
    random order and one training text is drawn per class.

    Parameters
    ----------
    texts_train : numpy.ndarray
        Training texts; indexed with a boolean row mask.
    labels_train : numpy.ndarray
        Boolean matrix whose column ``i`` selects the training texts of class ``i``.
    samples_per_class : int, default 2
        Number of sampling rounds, i.e. the maximum number of examples per class.

    Returns
    -------
    list of tuple
        Unique ``(class_index, text, class_name)`` triples in drawing order.
        Classes without training texts are skipped, and a triple drawn twice is
        kept only once, so a class may end up with fewer than
        `samples_per_class` examples.
    """
    samples = []
    seen = set()  # O(1) duplicate check instead of scanning the result list
    n = len(mappings)
    for _ in range(samples_per_class):
        for i in np.random.choice(np.arange(n), size=n, replace=False):
            # get random sample from training data; np.random.choice raises
            # ValueError when the class has no training texts:
            try: text = np.random.choice(texts_train[labels_train[:, i]])
            except ValueError: continue

            # keep each (class index, text, class name) triple only once:
            sample = (i, text, mappings[i])
            if sample not in seen:
                seen.add(sample)
                samples.append(sample)

    return samples

In [1]:
def create_prompt(text, samples, labels_reduced=None, total_samples=None):
    """Build a few-shot classification prompt for `text`.

    Parameters
    ----------
    text : str
        The text to be classified; appended as the final, unlabelled example.
    samples : list of tuple
        ``(class_index, text, class_name)`` triples used as labelled examples.
    labels_reduced : list, optional
        Class indices of a prediction set. If given, only samples of these
        classes are kept, ordered by the position of their class in the list.
    total_samples : int, optional
        If given, at most this many samples (counted after filtering) are used.

    Returns
    -------
    (str, int)
        The prompt and the number of examples it contains.
    """
    has_prediction_set = labels_reduced is not None

    if has_prediction_set:
        # keep only samples whose class is in the prediction set ...
        selected = [sample for sample in samples if sample[0] in labels_reduced]
        # ... and order them by rank in the set (the sort is stable, so samples
        # of the same class keep their original relative order):
        selected = sorted(selected, key=lambda sample: labels_reduced.index(sample[0]))
    else:
        selected = list(samples)

    if total_samples is not None:
        selected = selected[:total_samples]

    # create context:
    header = f'We are looking for food {LABEL.split("_")[0]}s in texts. Here are some labelled examples'
    header += ' sorted from most probable to least probable:' if has_prediction_set else ':'

    lines = [header]
    lines.extend(f'"{replace_qm(x)}" -> {y}' for _, x, y in selected)
    context = '\n'.join(lines)

    prompt = f'Context start:\n{context}\nContext end:\nPlease predict the correct class for the following sample:\n"{replace_qm(text)}" -> '
    return prompt, len(selected)

In [None]:
N_SAMPLES = 2
P = .05

# Output columns: one max-k prompt column per configured set size in K, so the
# table always matches the max-k columns that predict() actually produced.
conformal_column = f'prompt_conformal_{P*100:.0f}%'
max_k_columns = [f'prompt_max-{k:d}' for k in K]
prompt_columns = ['cv_split', 'prompt_all', 'prompt_limited', conformal_column, *max_k_columns, 'label']

# Rows are collected in a list and turned into a DataFrame once at the end;
# appending to a DataFrame row by row copies the frame on every insert.
records = []
for i in range(5):
    # load data:
    # NOTE: pickle.load can execute arbitrary code - only open split files from a trusted source.
    with open(f"../data/{DATA}/splits/split_{i:d}.pickle", "rb") as f:
        data = pickle.load(f)

    texts_train = data['train'][TEXTS].values
    labels_train = np.array([np.array(l, dtype=bool) for l in data['train'][LABEL].values])

    conf_frame = ps_classic_conf[i]
    max_k_frame = ps_classic_max_k[i]

    # The unpacking order below must mirror the order of the zipped columns:
    # gold labels, conformal set, text, then one max-k set per entry of K.
    for y_true, y_conf, text, *y_max_ks in tqdm(zip(
            conf_frame['labels'],
            conf_frame[f'{P:.2f}'],
            conf_frame['texts'],
            *(max_k_frame[f'{k:d}'] for k in K)
        ), total=len(conf_frame), desc=f'split {i:d}'):
        # fresh random few-shot examples for every test text:
        samples = get_samples(texts_train, labels_train, samples_per_class=N_SAMPLES)

        b_all, _   = create_prompt(text, samples)
        b_conf, n  = create_prompt(text, samples, labels_reduced=[c for c,_ in y_conf])
        # baseline with the same number of examples as the conformal prompt,
        # but without filtering or reordering by the prediction set:
        b_lim, _   = create_prompt(text, samples, total_samples=n)

        record = {
            'cv_split':       i,
            'prompt_all':     b_all,
            'prompt_limited': b_lim,
            conformal_column: b_conf,
        }
        for column, y_max_k in zip(max_k_columns, y_max_ks):
            record[column], _ = create_prompt(text, samples, labels_reduced=[c for c,_ in y_max_k])
        record['label'] = ' | '.join(mappings[np.array(y_true, dtype=bool)])

        records.append(record)

prompts = pd.DataFrame(records, columns=prompt_columns)

In [None]:
# Persist the generated prompts; with to_csv's defaults the DataFrame index is
# written as the first (unnamed) CSV column.
prompts.to_csv(f'prompts_{LABEL}_{P*100:.0f}%.csv')

In [None]:
# Display the prompt table as the cell output.
prompts

Unnamed: 0,cv_split,prompt_all,prompt_limited,prompt_conformal_5%,label
0,0,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,mislabelled
1,0,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,pistachio nut
2,0,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,illegal import
3,0,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,allergens
4,0,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,plastic fragment
...,...,...,...,...,...
7613,4,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,listeria monocytogenes
7614,4,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,cereals containing gluten and products thereof
7615,4,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,dairy products
7616,4,,Context start:\nWe are looking for food hazard...,Context start:\nWe are looking for food hazard...,contaminated with aliphatic hydrocarbons


In [None]:
# Print one complete conformal prompt followed by its gold label (row with index label 1).
print(prompts.loc[1, f'prompt_conformal_{P*100:.0f}%'], prompts.loc[1, 'label'])

Context start:
We are looking for food hazards in texts. Here are some labelled examples sorted from most probable to least probable:
"Ciolo Foods Issues Allergy Alert for Undeclared Tree Nuts in" -> nuts
"New World brand Almond Butter - Smooth - Roasted recalled due to undeclared peanut and cashew" -> nuts
"Archives" -> e 621 - monosodium glutamate undeclared
"Recall Notification Report 064-2013" -> e 621 - monosodium glutamate undeclared
"Nestlé USA Inc. Recalls Frozen DiGiorno Crispy Pan Crust Pepperoni Pizza Due to Misbranding and Undeclared Allergens" -> soybeans and products thereof
"Russ Davis Wholesale Issues Allergy Alert on Undeclared Soy in Veggie Pizza" -> soybeans and products thereof
"Coco Joy Pure Coconut Milk" -> milk and products thereof
"Ducktrap River of Maine Recalls One Lot of Herring Center Cuts in Wine Sauce Due to Undeclared Dairy in Product" -> milk and products thereof
"Fieldsource Food Systems, Inc. Recalls Beef and Poultry Products Due to Misbranding and Und