In [1]:
import _base_path
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold , train_test_split
from resources.data_io import save_data, save_mappings
from resources.spans import SpanCollection

# Settings:

In [2]:
# Notebook-wide configuration:
K = 5                  # number of cross-validation splits
SEED = 42              # random seed
DATASET = 'incidents'  # dataset name; used as directory and file-name prefix

# Load Data:

## Load "incidents":

In [3]:
# Load the raw incidents table:
data = pd.read_csv(f"{DATASET}/{DATASET}_final.csv")

# Pipe-separated label columns -> lists of strings.
# Missing values are replaced by '' first, so they end up as [''].
for column in ('product', 'product_category', 'hazard', 'hazard_category'):
    data[column] = [value.split('|') for value in data[column].fillna('')]

# Serialized span columns (product, hazard and supplier spans) -> SpanCollection objects:
for column in ('product_spans', 'hazard_spans', 'supplier_title', 'supplier_text'):
    data[column] = [SpanCollection.parse(value) for value in data[column].fillna('')]

# Fill missing countries.
# Assign the result instead of calling `fillna(..., inplace=True)` on the column
# selection: the chained in-place form is deprecated and does not modify `data`
# when pandas Copy-on-Write is enabled.
data['country'] = data['country'].fillna('na')

print(f"N = {len(data):d}")
data.head(5)

N = 7619


Unnamed: 0.1,Unnamed: 0,year,month,day,url,title,text,product,product_category,product_spans,hazard,hazard_category,hazard_spans,supplier_title,supplier_text,language,country
0,0,2015,5,26,https://www.fda.gov/Safety/Recalls/ArchiveReca...,2015 - House of Spices (India) Inc. Issues Ale...,"April 23, 2015 – Flushing, NY – House of Spic...",[dried apricots],[fruits and vegetables],"(slice(98, 140, None), slice(429, 476, None), ...",[undeclared sulphite],[allergens],"(slice(167, 250, None))","(slice(7, 35, None))","(slice(33, 48, None))",en,us
1,1,2022,5,25,https://www.fda.gov/safety/recalls-market-with...,Supplier J.M. Smucker Co.’s Jif Recall Prompts...,"(Miami, FL – May 24, 2022) - J.M. Smucker Co.’...",[peanuts],"[nuts, nut products and seeds]","(slice(452, 518, None), slice(533, 560, None),...",[salmonella],[biological],"(slice(258, 290, None), slice(936, 969, None),...","(slice(53, 62, None))","(slice(0, 10, None), slice(129, 138, None), sl...",en,us
2,2,2020,6,2,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 2 June 2020) Not to consume a bat...,*(Updated on 2 June 2020) Not to consume a bat...,[apple juice],[non-alcoholic beverages],"(slice(212, 261, None), slice(1429, 1466, None))",[patulin],[chemical],"(slice(49, 114, None), slice(827, 863, None), ...",(),"(slice(354, 366, None), slice(581, 593, None),...",en,hk
3,3,2022,7,5,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 5 July 2022) Not to consume smoke...,*(Updated on 5 July 2022) Not to consume smoke...,[chilled smoked salmon],[fish and fish products],"(slice(26, 73, None), slice(244, 294, None), s...",[listeria monocytogenes],[biological],"(slice(74, 134, None), slice(859, 909, None), ...",(),"(slice(484, 492, None), slice(1197, 1218, None))",en,hk
4,4,2021,3,20,http://www.fsis.usda.gov/recalls-alerts/avanza...,"Avanza Pasta, LLC Recalls Beef and Poultry Pro...",0009-2021\r\n\r\n \r\n High\r\n\r\n Produ...,[pasta products],[other food product / mixed],"(slice(661, 667, None), slice(750, 791, None),...",[inspection issues],[fraud],"(slice(40, 69, None), slice(329, 410, None), s...","(slice(0, 17, None))","(slice(294, 311, None), slice(654, 672, None),...",en,us


# Vectorize labels:

In [4]:
def label2one_hot(label_name:str, frame:'pd.DataFrame | None'=None) -> np.ndarray:
    """Replace a column of label lists by one-hot (multi-hot) tuples, in place.

    Args:
        label_name: Name of the column to encode. Every cell of the column must
            be an iterable of labels (e.g. a list of strings).
        frame: DataFrame to modify. Defaults to the notebook-global `data`.

    Returns:
        Sorted array of the unique labels found in the column. Position `i` of
        each tuple written to the column corresponds to entry `i` of this array.
    """
    if frame is None:
        frame = data

    labels = frame[label_name].values

    # extract unique values (np.unique already returns them sorted):
    unique_labels = np.unique(np.concatenate(labels))

    # create label-to-integer mapping:
    label2index = {l:i for i,l in enumerate(unique_labels)}

    # replace strings with integers (one row per sample, one column per class):
    labels_one_hot = np.zeros((len(labels), len(unique_labels)), dtype=int)
    for i,ls in enumerate(labels):
        for l in ls:
            labels_one_hot[i,label2index[l]] = 1

    # tuples keep each encoded row in a single DataFrame cell:
    frame[label_name] = [tuple(row) for row in labels_one_hot]
    return unique_labels

In [5]:
# Encode the "product" column in place; `products` holds the sorted class names:
products = label2one_hot("product")
products

array(['adobo seasoning', 'after dinner mints', 'alcoholic beverages',
       ...,
       'yoghurt-like soya-based products containing bacteria cultures',
       'yogurt raisins', 'zomi and palm oil'], dtype='<U70')

In [6]:
# Encode the "hazard" column in place; `hazards` holds the sorted class names:
hazards = label2one_hot("hazard")
hazards

array(['2-chloroethanol', '2-chloroethanol and ethylene oxide',
       'abnormal colour', 'abnormal smell',
       'absence of expiry/use by dates', 'absence of labelling',
       'addition', 'adulteration (ema)', 'adverse reaction',
       'aeromonas hydrophila', 'aflatoxin', 'alcohol content', 'algae',
       'aliphatic hydrocarbons', 'alkaloids', 'allergens',
       'allergic reaction', 'almond',
       'altered organoleptic characteristics', 'aluminium', 'amygdalin',
       'animal matter', 'anthraquinone', 'antibiotics, vet drugs',
       'appearance', 'apple stems', 'arsenic', 'atropine',
       'attempt to illegally import', 'azinphos-methyl',
       'bacillus cereus', 'bacillus cytotoxicus', 'bacillus spp.',
       'bad smell / off odor', 'barley', 'benzo(a)pyrene',
       'biocontaminants', 'biological', 'biotoxins (other)',
       'bone fragment', 'botulinum toxin', 'brazil nut', 'breakage',
       'bromate', 'bse', 'bulging packaging',
       'bursting possibility of bottle 

In [7]:
# Encode the "product_category" column in place; `productCategories` holds the sorted class names:
productCategories = label2one_hot("product_category")
productCategories

array(['alcoholic beverages', 'bivalve molluscs and products therefor',
       'cephalopods and products thereof', 'cereals and bakery products',
       'cocoa and cocoa preparations, coffee and tea', 'confectionery',
       'crustaceans and products thereof',
       'dietetic foods, food supplements, fortified foods',
       'eggs and egg products', 'fats and oils', 'feed additives',
       'feed materials', 'fish and fish products',
       'food additives and flavourings', 'food contact materials',
       'fruits and vegetables', 'herbs and spices',
       'honey and royal jelly', 'ices and desserts',
       'meat and meat products (other than poultry)',
       'milk and milk products', 'non-alcoholic beverages',
       'nuts, nut products and seeds', 'other food product / mixed',
       'pet feed', 'poultry meat and poultry meat products',
       'prepared dishes and snacks',
       'soups, broths, sauces and condiments', 'sugars and syrups'],
      dtype='<U49')

In [8]:
# Encode the "hazard_category" column in place; `hazardCategories` holds the sorted class names:
hazardCategories = label2one_hot("hazard_category")
hazardCategories

array(['allergens', 'biological', 'chemical',
       'food additives and flavourings', 'food contact materials',
       'foreign bodies', 'fraud', 'migration', 'organoleptic aspects',
       'other hazard', 'packaging defect'], dtype='<U30')

In [9]:
# save mappings:
# Each array lists the class names of one encoded column; entry i names
# position i of that column's one-hot tuples.
# NOTE(review): presumably save_mappings persists these under DATASET - confirm in resources.data_io.
save_mappings(DATASET, products, hazards, productCategories, hazardCategories)

# Create K-Fold splits

In [10]:
# Preview the frame after encoding: the four label columns now hold one-hot tuples.
data.head(5)

Unnamed: 0.1,Unnamed: 0,year,month,day,url,title,text,product,product_category,product_spans,hazard,hazard_category,hazard_spans,supplier_title,supplier_text,language,country
0,0,2015,5,26,https://www.fda.gov/Safety/Recalls/ArchiveReca...,2015 - House of Spices (India) Inc. Issues Ale...,"April 23, 2015 – Flushing, NY – House of Spic...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(slice(98, 140, None), slice(429, 476, None), ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)","(slice(167, 250, None))","(slice(7, 35, None))","(slice(33, 48, None))",en,us
1,1,2022,5,25,https://www.fda.gov/safety/recalls-market-with...,Supplier J.M. Smucker Co.’s Jif Recall Prompts...,"(Miami, FL – May 24, 2022) - J.M. Smucker Co.’...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(slice(452, 518, None), slice(533, 560, None),...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0)","(slice(258, 290, None), slice(936, 969, None),...","(slice(53, 62, None))","(slice(0, 10, None), slice(129, 138, None), sl...",en,us
2,2,2020,6,2,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 2 June 2020) Not to consume a bat...,*(Updated on 2 June 2020) Not to consume a bat...,"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(slice(212, 261, None), slice(1429, 1466, None))","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0)","(slice(49, 114, None), slice(827, 863, None), ...",(),"(slice(354, 366, None), slice(581, 593, None),...",en,hk
3,3,2022,7,5,http://www.cfs.gov.hk/english/whatsnew/whatsne...,*(Updated on 5 July 2022) Not to consume smoke...,*(Updated on 5 July 2022) Not to consume smoke...,"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","(slice(26, 73, None), slice(244, 294, None), s...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0)","(slice(74, 134, None), slice(859, 909, None), ...",(),"(slice(484, 492, None), slice(1197, 1218, None))",en,hk
4,4,2021,3,20,http://www.fsis.usda.gov/recalls-alerts/avanza...,"Avanza Pasta, LLC Recalls Beef and Poultry Pro...",0009-2021\r\n\r\n \r\n High\r\n\r\n Produ...,"(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(slice(661, 667, None), slice(750, 791, None),...","(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","(0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0)","(slice(40, 69, None), slice(329, 410, None), s...","(slice(0, 17, None))","(slice(294, 311, None), slice(654, 672, None),...",en,us


In [11]:
kf = StratifiedKFold(n_splits=K, shuffle=True, random_state=SEED)

# Seeded generator so the stratification labels (and thus the splits) are reproducible:
rng = np.random.default_rng(SEED)

# drop hazards with fewer than K samples:
labels = np.array(list(data['hazard_category'].values))  # one row of 0/1 flags per sample
drop = labels.sum(axis=0) < K

# Keep only samples that carry at least one sufficiently frequent class.
# The same positional boolean mask filters `labels` and `data`, so both stay
# aligned regardless of how many rows are removed or what index `data` has.
keep = labels[:, ~drop].any(axis=1)
labels = labels[keep]
data = data[keep]
assert len(data) == len(labels), "data and labels must stay aligned"

print(f'Dropped {drop.sum():d} classes with n_samples < {K:d}.')

# randomly select a class for stratification in multiclass instances
# (only classes that were not dropped are candidates; np.flatnonzero yields
# ALL active class indices of a row, so the choice is genuinely random):
candidates = labels.astype(bool) & ~drop
labels = np.array([rng.choice(np.flatnonzero(row)) for row in candidates])

# create the output directory once; exist_ok allows re-running the cell:
os.makedirs(f'{DATASET}/splits/', exist_ok=True)

for i, (i_train, i_test) in enumerate(kf.split(range(len(data)), labels)):
    # split off validation set (10% of the training fold, stratified):
    i_train, i_valid = train_test_split(
        i_train,
        test_size=.1,
        stratify=labels[i_train],
        shuffle=True,
        random_state=SEED
    )

    # save the train/validation/test portions of fold i:
    save_data(f'{DATASET}/splits/', i, data.iloc[i_train], data.iloc[i_valid], data.iloc[i_test])

Dropped 1 classes with n_samples < 5.
