In [None]:
import _base_path
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold , train_test_split
from resources.data_io import save_data, save_mappings
from resources.spans import SpanCollection

# Settings:

In [2]:
# cross-validation:
K = 5      # number of stratified cross-validation folds
SEED = 42  # random seed shared by every shuffling / splitting step

# data location:
DATASET = 'incidents'
SAVE_DIR = f'{DATASET}/splits/'

# Load Data:

## Load "incidents":

In [3]:
# load incidents:
data = pd.read_csv(f"{DATASET}/{DATASET}_final.csv", index_col=0)

# fill nan-values with a single DataFrame-level call. The former per-column
# `data[col].fillna(..., inplace=True)` is chained assignment: pandas deprecates
# it and, under copy-on-write, it silently stops updating `data`.
fill_values = {
    'product':  '', 'product-raw':  '', 'product-category':  '',
    'hazard':   '', 'hazard-raw':   '', 'hazard-category':   '',
    # NOTE(review): 'na' is also a real ISO code (country: Namibia, language:
    # Nauru) — confirm it cannot clash with genuine values in the data.
    'country':  'na', 'language': 'na',
}
missing_columns = set(fill_values) - set(data.columns)
if missing_columns:  # DataFrame.fillna(dict) would silently skip these
    raise KeyError(f"missing columns: {sorted(missing_columns)}")
data = data.fillna(fill_values)

# parse spans; span columns absent from the CSV become empty collections.
# A fresh SpanCollection is created per row so rows never share one instance.
for col in ['product-title', 'product-text', 'hazard-title', 'hazard-text', 'supplier-title', 'supplier-text']:
    if col in data.columns:
        data[col] = [SpanCollection.parse(item) for item in data[col].fillna('')]
    else:
        data[col] = [SpanCollection() for _ in range(len(data))]

print(f"N = {len(data):d}")
data.head(5)

N = 7548


Unnamed: 0,year,month,day,url,title,text,product,product-raw,product-category,product-title,product-text,hazard,hazard-raw,hazard-category,hazard-title,hazard-text,supplier-title,supplier-text,language,country
0,2015,5,26,https://www.fda.gov/Safety/Recalls/ArchiveReca...,2015 - House of Spices (India) Inc. Issues Ale...,"April 23, 2015 – Flushing, NY – House of Spic...",apricots,dried apricots,fruits and vegetables,(),"(slice(108, 120, None), slice(442, 454, None),...",sulphur dioxide and sulphites,undeclared sulphite,allergens,"(slice(7, 34, None))","(slice(33, 51, None), slice(98, 100, None), sl...","(slice(7, 34, None))","(slice(33, 51, None), slice(98, 100, None), sl...",en,us
1,2022,5,25,https://www.fda.gov/safety/recalls-market-with...,Supplier J.M. Smucker Co.’s Jif Recall Prompts...,"(Miami, FL – May 24, 2022) - J.M. Smucker Co.’...",peanuts,peanuts,"nuts, nut products and seeds",(),"(slice(229, 242, None), slice(470, 483, None),...",pathogen,salmonella,biological,"(slice(47, 62, None))","(slice(1, 10, None), slice(123, 138, None), sl...","(slice(47, 62, None))","(slice(1, 10, None), slice(123, 138, None), sl...",en,us
2,2020,6,2,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 2 June 2020) Not to consume a bat...,*(Updated on 2 June 2020) Not to consume a bat...,fruit and vegetable juices,apple juice,non-alcoholic beverages,"(slice(52, 71, None))","(slice(52, 71, None), slice(222, 241, None), s...",mycotoxin,patulin,chemical,(),"(slice(354, 366, None), slice(581, 593, None),...",(),"(slice(354, 366, None), slice(581, 593, None),...",en,hk
3,2022,7,5,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 5 July 2022) Not to consume smoke...,*(Updated on 5 July 2022) Not to consume smoke...,fish and fish products,chilled smoked salmon,seafood,"(slice(41, 54, None))","(slice(41, 54, None), slice(256, 277, None), s...",pathogen,listeria monocytogenes,biological,(),"(slice(177, 185, None), slice(203, 207, None),...",(),"(slice(177, 185, None), slice(203, 207, None),...",en,hk
4,2021,3,20,http://www.fsis.usda.gov/recalls-alerts/avanza...,"Avanza Pasta, LLC Recalls Beef and Poultry Pro...",0009-2021\n\n \n High\n\n Produced Withou...,pasta products,pasta products,other food product / mixed,"(slice(7, 12, None), slice(43, 51, None))","(slice(271, 276, None), slice(307, 315, None),...",smuggling,inspection issues,fraud,"(slice(0, 17, None))","(slice(264, 281, None), slice(617, 634, None),...","(slice(0, 17, None))","(slice(264, 281, None), slice(617, 634, None),...",en,us


# Vectorize labels:

In [4]:
# sorted class names per label column; a label's integer code is its position here
unique_labels = {}

for column in ('product', 'hazard', 'product-category', 'hazard-category'):
    # np.unique already returns the distinct values in sorted order:
    classes = np.unique(data[column].values)
    unique_labels[column] = classes

    # encode every label string as its index in `classes`:
    code_of = {label: code for code, label in enumerate(classes)}
    data[column] = data[column].map(code_of)

## Save mappings:

In [5]:
os.makedirs(SAVE_DIR, exist_ok=True)

# class arrays are passed positionally in this order:
# product, hazard, product-category, hazard-category
mapping_keys = ('product', 'hazard', 'product-category', 'hazard-category')
save_mappings(SAVE_DIR, *[unique_labels[key] for key in mapping_keys])

# Create K-Fold splits

In [6]:
# preview again: product/hazard and their categories now hold integer class codes
data.head(5)

Unnamed: 0,year,month,day,url,title,text,product,product-raw,product-category,product-title,product-text,hazard,hazard-raw,hazard-category,hazard-title,hazard-text,supplier-title,supplier-text,language,country
0,2015,5,26,https://www.fda.gov/Safety/Recalls/ArchiveReca...,2015 - House of Spices (India) Inc. Issues Ale...,"April 23, 2015 – Flushing, NY – House of Spic...",15,dried apricots,10,(),"(slice(108, 120, None), slice(442, 454, None),...",101,undeclared sulphite,0,"(slice(7, 34, None))","(slice(33, 51, None), slice(98, 100, None), sl...","(slice(7, 34, None))","(slice(33, 51, None), slice(98, 100, None), sl...",en,us
1,2022,5,25,https://www.fda.gov/safety/recalls-market-with...,Supplier J.M. Smucker Co.’s Jif Recall Prompts...,"(Miami, FL – May 24, 2022) - J.M. Smucker Co.’...",433,peanuts,16,(),"(slice(229, 242, None), slice(470, 483, None),...",74,salmonella,1,"(slice(47, 62, None))","(slice(1, 10, None), slice(123, 138, None), sl...","(slice(47, 62, None))","(slice(1, 10, None), slice(123, 138, None), sl...",en,us
2,2020,6,2,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 2 June 2020) Not to consume a bat...,*(Updated on 2 June 2020) Not to consume a bat...,250,apple juice,15,"(slice(52, 71, None))","(slice(52, 71, None), slice(222, 241, None), s...",58,patulin,2,(),"(slice(354, 366, None), slice(581, 593, None),...",(),"(slice(354, 366, None), slice(581, 593, None),...",en,hk
3,2022,7,5,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 5 July 2022) Not to consume smoke...,*(Updated on 5 July 2022) Not to consume smoke...,206,chilled smoked salmon,20,"(slice(41, 54, None))","(slice(41, 54, None), slice(256, 277, None), s...",74,listeria monocytogenes,1,(),"(slice(177, 185, None), slice(203, 207, None),...",(),"(slice(177, 185, None), slice(203, 207, None),...",en,hk
4,2021,3,20,http://www.fsis.usda.gov/recalls-alerts/avanza...,"Avanza Pasta, LLC Recalls Beef and Poultry Pro...",0009-2021\n\n \n High\n\n Produced Withou...,426,pasta products,17,"(slice(7, 12, None), slice(43, 51, None))","(slice(271, 276, None), slice(307, 315, None),...",91,inspection issues,6,"(slice(0, 17, None))","(slice(264, 281, None), slice(617, 634, None),...","(slice(0, 17, None))","(slice(264, 281, None), slice(617, 634, None),...",en,us


In [7]:
def filter_by_support(rows, column, min_support, frame=None, labels=None):
    """Drop rows whose class in `column` occurs fewer than `min_support` times.

    Support is counted among `rows` only, so a class kept on the full data set
    can still be dropped inside a single fold.

    Args:
        rows:        index labels (e.g. ``data.index``) to filter.
        column:      name of an integer-encoded label column.
        min_support: minimum number of samples a class needs to be kept.
        frame:       DataFrame holding `column`; defaults to the global ``data``.
        labels:      class names indexed by integer code; defaults to
                     ``unique_labels[column]``.

    Returns:
        The subset of `rows` (original order) whose class is kept. Each dropped
        class is reported on stdout; classes with no samples among `rows` are
        not reported, since there is nothing to drop.
    """
    if frame is None:
        frame = data
    if labels is None:
        labels = unique_labels[column]

    selected = frame.loc[rows, column]
    codes = selected.to_numpy()
    counts = selected.value_counts()  # one pass over the rows instead of one per class

    dropped = []
    for i_label, label in enumerate(labels):
        n_samples = int(counts.get(i_label, 0))
        # skip absent classes (e.g. already removed by an earlier filter):
        if 0 < n_samples < min_support:
            dropped.append(i_label)
            print(f'{column.upper()}: dropped class "{label}" with n_samples = {n_samples} < {min_support:d}.')

    return rows[~np.isin(codes, dropped)]

In [None]:
# One task per annotation type: (fine label, coarse label used for
# stratification, columns to save). Each task drops the other task's columns.
tasks = [
    ('product', 'product-category', [col for col in data.columns if not col.startswith('hazard')]),
    ('hazard',  'hazard-category',  [col for col in data.columns if not col.startswith('product')]),
]

for label_fine, label_coarse, columns in tasks:
    # keep only coarse classes with at least K samples, so each remaining
    # class can be represented in every one of the K test folds:
    i_filtered = filter_by_support(data.index, label_coarse, K)
    y_coarse = data.loc[i_filtered, label_coarse].values

    kf = StratifiedKFold(n_splits=K, shuffle=True, random_state=SEED)
    for split, (pos_train, pos_test) in enumerate(kf.split(i_filtered, y_coarse)):
        # the stratified train/valid split below needs >= 2 samples per class:
        i_train = filter_by_support(i_filtered[pos_train], label_coarse, 2)
        i_test = i_filtered[pos_test]

        # hold out 10% of the training rows as a stratified validation set:
        i_train, i_valid = train_test_split(
            i_train,
            test_size=.1,
            shuffle=True,
            random_state=SEED,
            stratify=data.loc[i_train, label_coarse].values,
        )

        # save the train / valid / test frames of this fold:
        train_frame = data.loc[i_train, columns]
        valid_frame = data.loc[i_valid, columns]
        test_frame = data.loc[i_test, columns]
        save_data(SAVE_DIR, split, label_fine, train_frame, valid_frame, test_frame)