In [57]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from pipeline import build_preprocessing_pipeline
from sklearn.metrics import accuracy_score, f1_score, make_scorer, precision_score, recall_score
from sklearn.feature_selection import RFECV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from scipy.stats import uniform, randint
from scipy.stats import uniform, randint
from modelos import XGBWithThreshold
from metricas import custom_fbeta
from ipaddress import ip_address
from sklearn.preprocessing import LabelEncoder

import xgboost as xgb

pd.set_option('display.max_columns', 200)

In [58]:
df = pd.read_csv('data/train_test_network.csv')

# Separação: Dev-Teste

In [59]:
X = df.iloc[:, :-2]
y = df.iloc[:, -2:]

#X = X.drop(columns=['src_ip', 'src_port', 'dst_ip', 'dst_port'])

le = LabelEncoder()
y.type = le.fit_transform(y.type)

Xdev, Xtest, ydev, ytest = train_test_split(X,
                                            y,
                                            test_size=0.2,
                                            stratify=y,
                                            random_state=42)

Xdev = Xdev.reset_index(drop=True)
Xtest = Xtest.reset_index(drop=True)
ydev = ydev.reset_index(drop=True)
ytest = ytest.reset_index(drop=True)

In [60]:
Xdev.head(5)

Unnamed: 0,ts,src_ip,src_port,dst_ip,dst_port,proto,service,duration,src_bytes,dst_bytes,conn_state,missed_bytes,src_pkts,src_ip_bytes,dst_pkts,dst_ip_bytes,dns_query,dns_qclass,dns_qtype,dns_rcode,dns_AA,dns_RD,dns_RA,dns_rejected,ssl_version,ssl_cipher,ssl_resumed,ssl_established,ssl_subject,ssl_issuer,http_trans_depth,http_method,http_uri,http_version,http_request_body_len,http_response_body_len,http_status_code,http_user_agent,http_orig_mime_types,http_resp_mime_types,weird_name,weird_addl,weird_notice
0,1554271099,192.168.1.152,34296,192.168.1.152,10502,tcp,-,0.0,0,0,OTH,0,0,0,0,0,-,0,0,0,-,-,-,-,-,-,-,-,-,-,-,-,-,-,0,0,0,-,-,-,-,-,-
1,1554320087,192.168.1.152,41266,192.168.1.190,53,udp,dns,0.275332,0,298,SHR,0,0,0,2,354,-,0,0,0,-,-,-,-,-,-,-,-,-,-,-,-,-,-,0,0,0,-,-,-,-,-,-
2,1556203768,192.168.1.30,47508,192.168.1.184,443,tcp,-,60.934442,0,0,S3,0,3,164,2,112,-,0,0,0,-,-,-,-,-,-,-,-,-,-,-,-,-,-,0,0,0,-,-,-,-,-,-
3,1556172230,192.168.1.31,47876,176.28.50.165,80,tcp,http,1.315227,271,2177,SF,0,6,591,6,2497,-,0,0,0,-,-,-,-,-,-,-,-,-,-,-,-,-,-,0,0,0,-,-,-,-,-,-
4,1556249250,192.168.1.190,25861,203.119.86.101,53,udp,dns,0.031347,47,426,SF,0,1,75,1,454,104.3.in-addr.arpa,1,43,0,F,F,F,F,-,-,-,-,-,-,-,-,-,-,0,0,0,-,-,-,-,-,-


# Pré-processamento

In [61]:
df.dst_ip.apply(ip_address).apply(lambda x: x.is_private)

0          True
1          True
2          True
3          True
4          True
          ...  
461038    False
461039    False
461040     True
461041    False
461042    False
Name: dst_ip, Length: 461043, dtype: bool

In [62]:
# http_response_body_len nao e exatamente categorica, mas se considerarmos todas
# as entradas infrequentes como uma coisa so, ela vira categorica

# Textuais que podem ser consideradas categoricas:
#ssl_subject, ssl_issuer, dns_query
#Sinonimo para features textuais - features descritivas
features_textuais = ['http_user_agent', 'http_uri', 'ssl_subject', 'ssl_issuer', 'dns_query']
features_categoricas = ['weird_notice', 'weird_addl', 'weird_name', 'http_resp_mime_types', 'http_orig_mime_types', 'http_status_code', 'http_version',
                        'http_method', 'http_trans_depth', 'ssl_established',
                        'ssl_resumed', 'ssl_cipher', 'ssl_version', 'dns_rejected', 'dns_RA', 'dns_RD', 'dns_AA', 'dns_rcode', 'dns_qtype', 'dns_qclass',
                        'service', 'proto', 'conn_state']
features_numericas = ['duration', 'dst_pkts', 'src_ip_bytes', 'dst_ip_bytes', 'src_bytes', 'http_response_body_len', 'dst_bytes',
                     'missed_bytes', 'src_pkts', 'http_request_body_len']

features_ip = ['src_ip', 'dst_ip']

features_port = ['src_port', 'dst_port']

pipeline = build_preprocessing_pipeline(features_numericas, features_categoricas, features_textuais, features_ip, features_port)

Xdev_pre = pipeline.fit_transform(Xdev)
Xtest_pre = pipeline.transform(Xtest)



In [63]:
Xdev_pre.head(5)

Unnamed: 0,src_ip_host,src_ip_broadcast,src_ip_ipv6,src_ip_privado,src_ip_multicast,dst_ip_host,dst_ip_broadcast,dst_ip_ipv6,dst_ip_privado,dst_ip_multicast,src_port_well_known,src_port_registered,src_port_dynamic,dst_port_well_known,dst_port_registered,dst_port_dynamic,duration,dst_pkts,src_ip_bytes,dst_ip_bytes,src_bytes,http_response_body_len,dst_bytes,missed_bytes,src_pkts,http_request_body_len,weird_notice_F,weird_addl_-,weird_addl_43,weird_addl_46,weird_addl_48,weird_name_-,weird_name_DNS_RR_unknown_type,weird_name_TCP_ack_underflow_or_misorder,weird_name_above_hole_data_without_any_acks,weird_name_active_connection_reuse,weird_name_bad_TCP_checksum,weird_name_bad_UDP_checksum,weird_name_connection_originator_SYN_ack,weird_name_data_before_established,weird_name_dnp3_corrupt_header_checksum,weird_name_inappropriate_FIN,weird_name_possible_split_routing,http_resp_mime_types_-,http_resp_mime_types_application/ocsp-response,http_resp_mime_types_application/vnd.ms-cab-compressed,http_resp_mime_types_application/x-debian-package,http_resp_mime_types_application/xml,http_resp_mime_types_image/jpeg,http_resp_mime_types_image/png,http_resp_mime_types_text/html,http_resp_mime_types_text/json,http_resp_mime_types_text/plain,http_orig_mime_types_-,http_orig_mime_types_application/soap+xml,http_orig_mime_types_application/xml,http_status_code_0,http_status_code_101,http_status_code_200,http_status_code_206,http_status_code_302,http_status_code_304,http_status_code_403,http_status_code_404,http_version_1.1,http_method_-,http_method_GET,http_method_HEAD,http_method_POST,http_trans_depth_-,http_trans_depth_1,http_trans_depth_10,http_trans_depth_2,http_trans_depth_3,http_trans_depth_4,http_trans_depth_5,http_trans_depth_6,http_trans_depth_7,http_trans_depth_8,http_trans_depth_9,ssl_established_-,ssl_established_F,ssl_established_T,ssl_resumed_-,ssl_resumed_F,ssl_resumed_T,ssl_cipher_-,ssl_cipher_TLS_AES_128_GCM_SHA256,ssl_cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,ssl_cipher_TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,ssl_cipher_TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,ssl_version_-,ssl_version_TLSv10,ssl_version_TLSv12,ssl_version_TLSv13,dns_rejected_-,dns_rejected_F,dns_rejected_T,dns_RA_-,dns_RA_F,dns_RA_T,dns_RD_-,dns_RD_F,dns_RD_T,dns_AA_-,dns_AA_F,dns_AA_T,dns_rcode_0,dns_rcode_1,dns_rcode_2,dns_rcode_3,dns_rcode_5,dns_qtype_0,dns_qtype_1,dns_qtype_2,dns_qtype_6,dns_qtype_12,dns_qtype_16,dns_qtype_28,dns_qtype_32,dns_qtype_33,dns_qtype_43,dns_qtype_48,dns_qtype_255,dns_qclass_0,dns_qclass_1,dns_qclass_32769,service_-,service_dce_rpc,service_dhcp,service_dns,service_ftp,service_gssapi,service_http,service_smb,service_smb;gssapi,service_ssl,proto_icmp,proto_tcp,proto_udp,conn_state_OTH,conn_state_REJ,conn_state_RSTO,conn_state_RSTOS0,conn_state_RSTR,conn_state_RSTRH,conn_state_S0,conn_state_S1,conn_state_S2,conn_state_S3,conn_state_SF,conn_state_SH,conn_state_SHR,http_user_agent_infrequent_sklearn,http_uri_infrequent_sklearn,ssl_subject_infrequent_sklearn,ssl_issuer_infrequent_sklearn,dns_query_infrequent_sklearn
0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,-0.018011,-0.012042,-0.011436,-0.011456,-0.010701,-0.002913,-0.010022,-0.004735,-0.0134,-0.004549,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,-0.017442,-0.008912,-0.011436,-0.009413,-0.010701,-0.002913,-0.009999,-0.004735,-0.0134,-0.004549,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.107869,-0.008912,-0.010385,-0.01081,-0.010701,-0.002913,-0.010022,-0.004735,-0.007835,-0.004549,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,-0.015294,-0.002651,-0.007648,0.002959,-0.010679,-0.002913,-0.009854,-0.004735,-0.00227,-0.004549,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,-0.017947,-0.010477,-0.010955,-0.008835,-0.010697,-0.002913,-0.009989,-0.004735,-0.011545,-0.004549,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


# Otimização de Hiperparâmetros

In [64]:
if False:
    estimator = XGBWithThreshold()
    #estimator = xgb.XGBClassifier()

    param_distributions = {
        'max_depth': randint(2, 40),
        'n_estimators': randint(10, 200),
        'learning_rate': uniform(1e-3, 10),  # Valores contínuos entre 0.001 e 0.3
        'reg_lambda': np.logspace(-4, 1, 100),  # Valores contínuos para regularização L2
        'threshold':uniform(0.5, 0.5),
        'random_state': [42]  # Valor fixo
    }

    rscv = RandomizedSearchCV(
        estimator=estimator,
        param_distributions=param_distributions,
        n_iter=100,  # Número de combinações aleatórias a serem testadas
        scoring=make_scorer(custom_fbeta, greater_is_better=True),
        cv=5,  # Stratified K-Fold Cross Validation
        verbose=3,
        return_train_score=True,
        random_state=42,  # Garante reprodutibilidade
        n_jobs=10,
        error_score='raise'
    )

    # Treinar o modelo com RandomizedSearchCV
    rscv.fit(Xdev_pre, ydev.label)

    # Exibir os melhores hiperparâmetros encontrados
    print("Melhores hiperparâmetros:", rscv.best_params_)
    print("Melhor score de validação cruzada:", rscv.best_score_)

In [65]:
#Melhores hiperparametros com possibilidade de modelos maiores
#best_params = {'learning_rate': 0.5157875124998935, 'max_depth': 38, 'n_estimators': 138, 'random_state': 42, 'reg_lambda': 2.395618906669724, 'threshold': 0.6448948720912231}
#best_params = rscv.best_params_
best_params = {'learning_rate': 0.23162425041415757, 'max_depth': 28, 'n_estimators': 68, 'random_state': 42, 'reg_lambda': 0.011768119524349979, 'threshold': 0.5233328316068078}

# Seleção de Features

In [66]:
if False:
    estimator = XGBWithThreshold(**best_params)

    rfecv = RFECV(
        estimator=estimator, 
        step=1,  # Número de features removidas por vez
        cv=5,  # Validação cruzada estratificada
        scoring=make_scorer(custom_fbeta, greater_is_better=True),  # Métrica usada para avaliação
        n_jobs=1,  # Usar todos os núcleos disponíveis
        verbose=3
    )

    # 6. Executar a seleção de features no conjunto de treino
    rfecv.fit(Xdev_pre, ydev.label)

    # 7. Analisar os resultados
    print(f"Número ótimo de features selecionadas: {rfecv.n_features_}")
    print("Features selecionadas:", Xdev_pre.columns[rfecv.support_])

In [67]:
features_selecionadas = ['src_ip_ipv6', 'src_ip_privado', 'dst_ip_broadcast', 'dst_ip_ipv6',
       'dst_ip_privado', 'dst_ip_multicast', 'src_port_well_known',
       'src_port_registered', 'src_port_dynamic', 'dst_port_well_known',
       'dst_port_registered', 'dst_port_dynamic', 'duration', 'dst_pkts',
       'src_ip_bytes', 'dst_ip_bytes', 'src_bytes', 'dst_bytes',
       'missed_bytes', 'src_pkts', 'weird_notice_F', 'http_status_code_0',
       'ssl_established_-', 'ssl_resumed_T', 'dns_rejected_-',
       'dns_rejected_F', 'dns_rejected_T', 'dns_RA_F', 'dns_RA_T', 'dns_RD_F',
       'dns_RD_T', 'dns_AA_F', 'dns_rcode_0', 'dns_rcode_2', 'dns_rcode_3',
       'dns_qtype_0', 'dns_qtype_1', 'dns_qtype_12', 'dns_qtype_28',
       'dns_qtype_33', 'dns_qclass_1', 'service_-', 'service_dns',
       'service_ftp', 'service_http', 'service_ssl', 'proto_icmp', 'proto_tcp',
       'proto_udp', 'conn_state_OTH', 'conn_state_REJ', 'conn_state_RSTO',
       'conn_state_RSTOS0', 'conn_state_RSTR', 'conn_state_RSTRH',
       'conn_state_S0', 'conn_state_S1', 'conn_state_S2', 'conn_state_S3',
       'conn_state_SF', 'conn_state_SH', 'conn_state_SHR',
       'dns_query_infrequent_sklearn']

# Treinamento

In [68]:
model = XGBWithThreshold(**best_params)

Xtrain, Xval, ytrain, yval = train_test_split(Xdev_pre[features_selecionadas], ydev, test_size=0.2, random_state=42)

model.fit(Xtrain, ytrain.label)

In [69]:
ytest

Unnamed: 0,label,type
0,1,2
1,0,5
2,1,1
3,0,5
4,0,5
...,...,...
92204,1,0
92205,0,5
92206,0,5
92207,0,5


In [79]:
print('Acurácia:', accuracy_score(ytest.label, model.predict(Xtest_pre[features_selecionadas])))
print('F1-score:', f1_score(ytest.label, model.predict(Xtest_pre[features_selecionadas])))
print('Fbeta-score:', custom_fbeta(ytest.label, model.predict(Xtest_pre[features_selecionadas])))
print('Precision:', precision_score(ytest.label, model.predict(Xtest_pre[features_selecionadas])))
print('Recall:', recall_score(ytest.label, model.predict(Xtest_pre[features_selecionadas])))

Acurácia: 0.9955644242969776
F1-score: 0.993659602833801
Fbeta-score: 0.9939287525204618
Precision: 0.9922905443061489
Recall: 0.9950324443478531


In [71]:
from sklearn.preprocessing import LabelEncoder

classes = sorted(ytest.type.unique())
metrics = {'Precision':[], 'Recall':[], 'Accuracy':[], 'F1score':[]}
for c in classes:
    mask_c = ytest.type == c
    pred = model.predict(Xtest_pre[features_selecionadas])[mask_c]
    real = ytest.label[mask_c]

    metrics['Precision'].append(precision_score(real, pred))
    metrics['Recall'].append(recall_score(real, pred))
    metrics['Accuracy'].append(accuracy_score(real, pred))
    metrics['F1score'].append(f1_score(real, pred))

pd.DataFrame(metrics, index=le.classes_)

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,Precision,Recall,Accuracy,F1score
backdoor,1.0,1.0,1.0,1.0
ddos,1.0,0.996,0.996,0.997996
dos,1.0,0.998,0.998,0.998999
injection,1.0,0.9995,0.9995,0.99975
mitm,1.0,0.947368,0.947368,0.972973
normal,0.0,0.0,0.99585,0.0
password,1.0,0.99975,0.99975,0.999875
ransomware,1.0,0.9715,0.9715,0.985544
scanning,1.0,0.99925,0.99925,0.999625
xss,1.0,0.99875,0.99875,0.999375


In [78]:
mask_c = ytest.type == 1
pred = model.predict(Xtest_pre[features_selecionadas])[mask_c]
real = ytest.label[mask_c]
recall_score(real, pred)

0.996

In [48]:
precision_score(real, pred)

1.0

Precisão = TP/(TP + FP)

In [83]:
ytest_pred = model.predict(Xtest_pre[features_selecionadas])
ytest_real = ytest.label
TP = (ytest_real[ytest_real == 1] == ytest_pred[ytest_real == 1]).sum()
TN = (ytest_real[ytest_real == 0] == ytest_pred[ytest_real == 0]).sum()
FP = (ytest_real[ytest_real == 0] != ytest_pred[ytest_real == 0]).sum()
FN = (ytest_real[ytest_real == 1] != ytest_pred[ytest_real == 1]).sum()

In [89]:
len(ytest_real)

92209

In [92]:
print('TP:', TP)
print('TN:', TN)
print('FP:', FP)
print('FN:', FN)

TP: 32049
TN: 59751
FP: 249
FN: 160
