In [640]:
import pandas as pd
import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [641]:
cols_to_drop = ["country", "date", "title"]

In [642]:
incident_data = pd.read_csv("data/incidents_train.csv", index_col="Unnamed: 0")
incident_data["date"] = pd.to_datetime(incident_data[["year", "month", "day"]])
incident_data.drop(labels=["year", "month", "day", *cols_to_drop], axis=1, inplace=True)
incident_data.rename({"text": "recall"}, axis=1, inplace=True)

In [643]:
incident_data

Unnamed: 0,recall,hazard-category,product-category,hazard,product
0,Case Number: 024-94 \n Date Opene...,biological,"meat, egg and dairy products",listeria monocytogenes,smoked sausage
1,Case Number: 033-94 \n Date Opene...,biological,"meat, egg and dairy products",listeria spp,sausage
2,Case Number: 014-94 \n Date Opene...,biological,"meat, egg and dairy products",listeria monocytogenes,ham slices
3,Case Number: 009-94 \n Date Opene...,foreign bodies,"meat, egg and dairy products",plastic fragment,thermal processed pork meat
4,Case Number: 001-94 \n Date Opene...,foreign bodies,"meat, egg and dairy products",plastic fragment,chicken breast
...,...,...,...,...,...
5979,Imported biscuit may contain allergen (peanuts...,allergens,cereals and bakery products,peanuts and products thereof,biscuits
5980,023-2022\n\n \n High - Class I\n\n Produc...,fraud,prepared dishes and snacks,inspection issues,pizza
5981,"FRESNO, Calif. – July 28, 2022 – Lyons Magnus ...",biological,non-alcoholic beverages,cronobacter spp,non-alcoholic beverages
5982,025-2022\n\n \n High - Class I\n\n Misbra...,allergens,"meat, egg and dairy products",eggs and products thereof,frozen beef products


In [644]:
keys = ['hazard-category', 'product-category', 'hazard', 'product']

In [645]:

encoder_dict = {key: LabelEncoder() for key in keys}

In [646]:
encoder_dict

{'hazard-category': LabelEncoder(),
 'product-category': LabelEncoder(),
 'hazard': LabelEncoder(),
 'product': LabelEncoder()}

In [647]:
for column, encoder in encoder_dict.items():

    incident_data[column] = encoder.fit_transform(incident_data[column])

In [648]:
incident_data

Unnamed: 0,recall,hazard-category,product-category,hazard,product
0,Case Number: 024-94 \n Date Opene...,1,13,55,858
1,Case Number: 033-94 \n Date Opene...,1,13,56,825
2,Case Number: 014-94 \n Date Opene...,1,13,55,511
3,Case Number: 009-94 \n Date Opene...,4,13,90,933
4,Case Number: 001-94 \n Date Opene...,4,13,90,168
...,...,...,...,...,...
5979,Imported biscuit may contain allergen (peanuts...,0,1,85,73
5980,023-2022\n\n \n High - Class I\n\n Produc...,5,18,52,712
5981,"FRESNO, Calif. – July 28, 2022 – Lyons Magnus ...",1,14,27,628
5982,025-2022\n\n \n High - Class I\n\n Misbra...,0,13,34,397


In [649]:
df = incident_data

In [650]:
def get_data(df: pd.DataFrame, label: str):
    
    edge_index = torch.tensor([df['product'].values, df['hazard'].values], dtype=torch.long)
    
    x_product = OneHotEncoder().fit_transform(df[['product']]).toarray()
    x_hazard = OneHotEncoder().fit_transform(df[['hazard']]).toarray()
    x = np.concat((x_product, x_hazard), axis=1)
    x = torch.tensor(x, dtype=torch.float)
    
    y = torch.tensor(df[label].values, dtype=torch.long)
    
    data = Data(x=x, edge_index=edge_index, y=y)
    
    num_nodes = df.shape[0]
    try:
        train_mask, test_mask = train_test_split(torch.arange(num_nodes), stratify=df[label], test_size=0.2)
    except:
        train_mask, test_mask = train_test_split(torch.arange(num_nodes), test_size=0.2)
    
    data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    data.train_mask[train_mask] = True
    data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    data.test_mask[test_mask] = True
    
    return data

In [651]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

In [652]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [653]:
def train_model(model, optimizer, criterion, data, epochs = 500):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}')

    return model

In [654]:
def eval_model(model, data, label, encoder_dict):
     model.eval()
     with torch.inference_mode():
         _, pred = model(data.x, data.edge_index).max(dim=1)
     test_mask = data.test_mask.cpu().numpy()
     preds = pred.detach().cpu().numpy()[test_mask]
     test = data.y.detach().cpu().numpy()[test_mask]
     
     preds = encoder_dict[label].inverse_transform(preds)
     test = encoder_dict[label].inverse_transform(test)
     
     return classification_report(y_true=test, y_pred=preds, zero_division=0.0)

In [655]:
def train_and_eval_model_for_label(label: str, encoder_dict: dict):
     
     data = get_data(df, label=label)
     data.to(DEVICE)
     
     model = GCN(in_channels=data.x.shape[1], hidden_channels=64, out_channels=data.y.max().item()+1).to(DEVICE)
     criterion = torch.nn.NLLLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

     model = train_model(model, optimizer, criterion, data)

     cr = eval_model(model, data, label, encoder_dict)

     return cr

In [656]:
cr1 = train_and_eval_model_for_label(label='product-category', encoder_dict=encoder_dict)

Epoch 0, Loss: 3.0873870849609375
Epoch 10, Loss: 1.7286486625671387
Epoch 20, Loss: 0.610405683517456
Epoch 30, Loss: 0.20117074251174927
Epoch 40, Loss: 0.11017559468746185
Epoch 50, Loss: 0.07157884538173676
Epoch 60, Loss: 0.048905692994594574
Epoch 70, Loss: 0.03568638488650322
Epoch 80, Loss: 0.02823181264102459
Epoch 90, Loss: 0.023311177268624306
Epoch 100, Loss: 0.019820956513285637
Epoch 110, Loss: 0.01722358725965023
Epoch 120, Loss: 0.015191130340099335
Epoch 130, Loss: 0.013546226546168327
Epoch 140, Loss: 0.012189716100692749
Epoch 150, Loss: 0.011051258072257042
Epoch 160, Loss: 0.010079236701130867
Epoch 170, Loss: 0.0092439204454422
Epoch 180, Loss: 0.008514043875038624
Epoch 190, Loss: 0.007869589142501354
Epoch 200, Loss: 0.007295954506844282
Epoch 210, Loss: 0.0067827291786670685
Epoch 220, Loss: 0.0063194097019732
Epoch 230, Loss: 0.00590191874653101
Epoch 240, Loss: 0.005527099594473839
Epoch 250, Loss: 0.005183225963264704
Epoch 260, Loss: 0.004867378156632185
Ep

In [657]:
print(cr1)

                                                   precision    recall  f1-score   support

                              alcoholic beverages       1.00      1.00      1.00        12
                      cereals and bakery products       0.94      0.93      0.93       134
     cocoa and cocoa preparations, coffee and tea       0.86      0.88      0.87        42
                                    confectionery       0.89      0.91      0.90        34
dietetic foods, food supplements, fortified foods       1.00      0.88      0.94        26
                                    fats and oils       0.50      0.50      0.50         4
                                   feed materials       0.50      1.00      0.67         1
                   food additives and flavourings       1.00      0.50      0.67         2
                           food contact materials       1.00      1.00      1.00         1
                            fruits and vegetables       0.87      0.90      0.88       10

In [658]:
cr2 = train_and_eval_model_for_label(label='hazard-category', encoder_dict=encoder_dict)

Epoch 0, Loss: 2.30242919921875
Epoch 10, Loss: 0.9111738204956055
Epoch 20, Loss: 0.29292863607406616
Epoch 30, Loss: 0.13944773375988007
Epoch 40, Loss: 0.08134221285581589
Epoch 50, Loss: 0.05502404645085335
Epoch 60, Loss: 0.03934647515416145
Epoch 70, Loss: 0.030038928613066673
Epoch 80, Loss: 0.024147534742951393
Epoch 90, Loss: 0.02010015957057476
Epoch 100, Loss: 0.01719551533460617
Epoch 110, Loss: 0.01504252478480339
Epoch 120, Loss: 0.01339271105825901
Epoch 130, Loss: 0.012086529284715652
Epoch 140, Loss: 0.011028091423213482
Epoch 150, Loss: 0.01015007309615612
Epoch 160, Loss: 0.009410050697624683
Epoch 170, Loss: 0.00877738930284977
Epoch 180, Loss: 0.008229044266045094
Epoch 190, Loss: 0.007749361917376518
Epoch 200, Loss: 0.007326396182179451
Epoch 210, Loss: 0.006950210314244032
Epoch 220, Loss: 0.0066133346408605576
Epoch 230, Loss: 0.006310381460934877
Epoch 240, Loss: 0.006036828272044659
Epoch 250, Loss: 0.005787266418337822
Epoch 260, Loss: 0.005558534059673548
E

In [659]:
print(cr2)

                                precision    recall  f1-score   support

                     allergens       0.97      0.97      0.97       371
                    biological       0.94      0.96      0.95       348
                      chemical       0.87      0.81      0.84        57
food additives and flavourings       0.67      0.40      0.50         5
                foreign bodies       0.90      0.91      0.91       112
                         fraud       0.93      0.92      0.93        74
                     migration       1.00      1.00      1.00         1
          organoleptic aspects       1.00      0.73      0.84        11
                  other hazard       0.93      0.93      0.93        27
              packaging defect       0.83      0.91      0.87        11

                      accuracy                           0.94      1017
                     macro avg       0.90      0.85      0.87      1017
                  weighted avg       0.94      0.94      0.94 

In [660]:
cr3 = train_and_eval_model_for_label(label='hazard', encoder_dict=encoder_dict)

Epoch 0, Loss: 4.854064464569092
Epoch 10, Loss: 3.1646342277526855
Epoch 20, Loss: 1.5111761093139648
Epoch 30, Loss: 0.6531173586845398
Epoch 40, Loss: 0.34581923484802246
Epoch 50, Loss: 0.19198760390281677
Epoch 60, Loss: 0.10558470338582993
Epoch 70, Loss: 0.058682188391685486
Epoch 80, Loss: 0.03595777601003647
Epoch 90, Loss: 0.025993401184678078
Epoch 100, Loss: 0.020827792584896088
Epoch 110, Loss: 0.01749522052705288
Epoch 120, Loss: 0.015094856731593609
Epoch 130, Loss: 0.013304515741765499
Epoch 140, Loss: 0.011912209913134575
Epoch 150, Loss: 0.01080928836017847
Epoch 160, Loss: 0.009915181435644627
Epoch 170, Loss: 0.009175784885883331
Epoch 180, Loss: 0.008551981300115585
Epoch 190, Loss: 0.008023249916732311
Epoch 200, Loss: 0.007566946092993021
Epoch 210, Loss: 0.007168992422521114
Epoch 220, Loss: 0.006818288471549749
Epoch 230, Loss: 0.0065064020454883575
Epoch 240, Loss: 0.006227107718586922
Epoch 250, Loss: 0.005974981468170881
Epoch 260, Loss: 0.005745864473283291

In [661]:
print(cr3)

                                                   precision    recall  f1-score   support

                                        Aflatoxin       1.00      0.50      0.67         2
                                   abnormal smell       0.00      0.00      0.00         1
                                  alcohol content       1.00      1.00      1.00         1
                                        alkaloids       1.00      1.00      1.00         1
                                        allergens       1.00      0.67      0.80         3
                                           almond       1.00      1.00      1.00        13
             altered organoleptic characteristics       1.00      1.00      1.00         1
                           antibiotics, vet drugs       1.00      1.00      1.00         1
                                    bacillus spp.       0.75      1.00      0.86         3
                             bad smell / off odor       0.00      0.00      0.00         

In [662]:
cr4 = train_and_eval_model_for_label(label='product', encoder_dict=encoder_dict)

Epoch 0, Loss: 6.9296674728393555
Epoch 10, Loss: 5.685149669647217
Epoch 20, Loss: 4.25963830947876
Epoch 30, Loss: 2.4765918254852295
Epoch 40, Loss: 1.035162329673767
Epoch 50, Loss: 0.46802642941474915
Epoch 60, Loss: 0.16344209015369415
Epoch 70, Loss: 0.047426674515008926
Epoch 80, Loss: 0.02750089392066002
Epoch 90, Loss: 0.019803084433078766
Epoch 100, Loss: 0.015532148070633411
Epoch 110, Loss: 0.012973341159522533
Epoch 120, Loss: 0.011277452111244202
Epoch 130, Loss: 0.010043990798294544
Epoch 140, Loss: 0.009079455398023129
Epoch 150, Loss: 0.008283978328108788
Epoch 160, Loss: 0.007607229519635439
Epoch 170, Loss: 0.007020610384643078
Epoch 180, Loss: 0.006505331955850124
Epoch 190, Loss: 0.006049669347703457
Epoch 200, Loss: 0.005643607582896948
Epoch 210, Loss: 0.005279153119772673
Epoch 220, Loss: 0.004950209055095911
Epoch 230, Loss: 0.0046532368287444115
Epoch 240, Loss: 0.0043837884441018105
Epoch 250, Loss: 0.004139381926506758
Epoch 260, Loss: 0.003917389083653688


In [663]:
print(cr4)

                                                              precision    recall  f1-score   support

                                      Catfishes (freshwater)       1.00      1.00      1.00         1
                                       Fishes not identified       0.67      1.00      0.80         6
                         Precooked cooked pork meat products       0.20      1.00      0.33         1
                                               Veggie Burger       1.00      1.00      1.00         2
                                             alfalfa sprouts       1.00      1.00      1.00         2
                                                       algae       0.67      1.00      0.80         2
                                       all purpose seasoning       1.00      1.00      1.00         1
                                              almond kernels       0.00      0.00      0.00         1
                                             almond products       1.00      1.00