In [1]:
import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from sklearn.metrics import confusion_matrix, classification_report

In [5]:
def RNN(path_train,path_test, epoch_num):
    Train = pd.read_csv(path_train)
    Test = pd.read_csv(path_test)
    
    X_train = Train.iloc[:,:-1]
    Y_train = Train.iloc[:,-1]

    X_test = Test.iloc[:,:-1]
    Y_test = Test.iloc[:,-1]

    #binarize the target
    Y_train_binary = Y_train.apply(lambda x: 1 if x > 0 else 0)
    Y_test_binary = Y_test.apply(lambda x: 1 if x > 0 else 0)

    # CNN
    model = Sequential()
    model.add(Conv1D(filters=256, kernel_size=3, activation='relu', input_shape=(22,1)))
    model.add(Conv1D(filters=256, kernel_size=3, activation='relu'))
    model.add(MaxPooling1D(pool_size=2))

    model.add(Conv1D(filters=512, kernel_size=3, activation='relu'))
    model.add(MaxPooling1D(pool_size=2))

    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(2, activation='sigmoid'))

    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    
    model.fit(X_train, Y_train_binary, epochs=epoch_num, batch_size=32)
    Y_pred = model.predict(X_test).argmax(axis=1)
    
    cm = confusion_matrix(Y_pred, Y_test_binary)
    print(cm)
    print(classification_report(Y_test_binary, Y_pred))
    
    mismatch = [i for i, (a,b) in enumerate(zip(Y_pred, Y_test_binary)) if a != b]
    print(mismatch)

# Cleveland

In [6]:
path_train = '../traintestdata/cle_train.csv'
path_test = '../traintestdata/cle_test.csv'
RNN(path_train,path_test,19)

Epoch 1/19
Epoch 2/19
Epoch 3/19
Epoch 4/19
Epoch 5/19
Epoch 6/19
Epoch 7/19
Epoch 8/19
Epoch 9/19
Epoch 10/19
Epoch 11/19
Epoch 12/19
Epoch 13/19
Epoch 14/19
Epoch 15/19
Epoch 16/19
Epoch 17/19
Epoch 18/19
Epoch 19/19
[[40  8]
 [11 32]]
              precision    recall  f1-score   support

           0       0.83      0.78      0.81        51
           1       0.74      0.80      0.77        40

    accuracy                           0.79        91
   macro avg       0.79      0.79      0.79        91
weighted avg       0.79      0.79      0.79        91

[0, 7, 9, 10, 19, 21, 22, 28, 30, 32, 38, 52, 53, 60, 61, 67, 69, 70, 77]


# Virginia

In [8]:
path_train = '../traintestdata/vir_train.csv'
path_test = '../traintestdata/vir_test.csv'
RNN(path_train,path_test, 15)

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15
[[ 3  3]
 [10 44]]
              precision    recall  f1-score   support

           0       0.50      0.23      0.32        13
           1       0.81      0.94      0.87        47

    accuracy                           0.78        60
   macro avg       0.66      0.58      0.59        60
weighted avg       0.75      0.78      0.75        60

[3, 6, 14, 22, 28, 32, 34, 42, 48, 49, 56, 57, 58]


# Hungary

In [9]:
path_train = '../traintestdata/hun_train.csv'
path_test = '../traintestdata/hun_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[53  8]
 [ 4 24]]
              precision    recall  f1-score   support

           0       0.87      0.93      0.90        57
           1       0.86      0.75      0.80        32

    accuracy                           0.87        89
   macro avg       0.86      0.84      0.85        89
weighted avg       0.86      0.87      0.86        89

[1, 8, 12, 14, 23, 29, 42, 53, 54, 61, 64, 86]


# Switzerland

In [10]:
path_train = '../traintestdata/swi_train.csv'
path_test = '../traintestdata/swi_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[ 0  0]
 [ 2 35]]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         2
           1       0.95      1.00      0.97        35

    accuracy                           0.95        37
   macro avg       0.47      0.50      0.49        37
weighted avg       0.89      0.95      0.92        37

[20, 25]


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Europe

In [11]:
path_train = '../traintestdata/euro_train.csv'
path_test = '../traintestdata/euro_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[51 11]
 [ 8 56]]
              precision    recall  f1-score   support

           0       0.82      0.86      0.84        59
           1       0.88      0.84      0.85        67

    accuracy                           0.85       126
   macro avg       0.85      0.85      0.85       126
weighted avg       0.85      0.85      0.85       126

[1, 12, 14, 17, 23, 29, 33, 42, 46, 53, 54, 61, 64, 86, 101, 104, 109, 116, 117]


# Combined

In [12]:
path_train = '../traintestdata/com_train.csv'
path_test = '../traintestdata/com_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[101  24]
 [ 21 130]]
              precision    recall  f1-score   support

           0       0.81      0.83      0.82       122
           1       0.86      0.84      0.85       154

    accuracy                           0.84       276
   macro avg       0.83      0.84      0.84       276
weighted avg       0.84      0.84      0.84       276

[11, 38, 42, 52, 59, 63, 76, 79, 83, 88, 91, 104, 106, 116, 121, 122, 125, 127, 130, 135, 146, 148, 155, 158, 164, 167, 169, 173, 184, 188, 190, 195, 214, 217, 221, 222, 225, 228, 234, 247, 258, 259, 264, 274, 275]


# Testing models on different data

In [13]:
path_train = '../traintestdata/cle_train.csv'
path_test = '../traintestdata/euro_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[22 12]
 [37 55]]
              precision    recall  f1-score   support

           0       0.65      0.37      0.47        59
           1       0.60      0.82      0.69        67

    accuracy                           0.61       126
   macro avg       0.62      0.60      0.58       126
weighted avg       0.62      0.61      0.59       126

[0, 1, 3, 9, 10, 12, 13, 14, 16, 18, 20, 21, 22, 23, 26, 29, 31, 32, 35, 36, 37, 39, 40, 42, 43, 44, 45, 46, 47, 48, 50, 51, 54, 60, 62, 66, 71, 72, 75, 85, 87, 88, 90, 95, 104, 109, 111, 116, 117]


In [14]:
path_train = '../traintestdata/euro_train.csv'
path_test = '../traintestdata/cle_test.csv'
RNN(path_train,path_test, 20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
[[18  4]
 [33 36]]
              precision    recall  f1-score   support

           0       0.82      0.35      0.49        51
           1       0.52      0.90      0.66        40

    accuracy                           0.59        91
   macro avg       0.67      0.63      0.58        91
weighted avg       0.69      0.59      0.57        91

[0, 2, 3, 5, 7, 10, 11, 14, 19, 22, 23, 24, 25, 28, 31, 32, 33, 36, 38, 39, 46, 47, 48, 51, 60, 61, 62, 63, 64, 67, 69, 77, 78, 83, 87, 88, 90]
