In [1]:
import numpy as np
import os
import warnings
warnings.filterwarnings("ignore")
  
from data_loader import load_mnist_data
from supervised_models import logit, xgb_model, mlp

%load_ext autoreload
%autoreload 2

from vime_self import vime_self
from vime_semi import vime_semi
from vime_utils import perf_metric

Using TensorFlow backend.


In [2]:
import dataset
import datetime
from datetime import timedelta
from parser_wco import get_parser_wco
import numpy as np 
import pandas as pd 
import torch
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import accuracy_score, roc_auc_score
from utils import *
# from pytorch_lightning import Trainer
# from torch_geometric.utils import to_undirected

## Load data

In [3]:
data = pd.read_csv('../../data/ndata.csv')

In [4]:
data.describe(include='all').to_csv('describe_ndata.csv', index=False)

In [5]:
data = dataset.Ndata(path='../../data/ndata.csv')
parser = get_parser_wco()
args = parser.parse_args(args=
                         ["--data","real-n", 
                          "--sampling","xgb",
                          "--train_from","20140101",
                          "--test_from","20170101",
                          "--test_length","365",
                          "--valid_length","90",
                          "--initial_inspection_rate", "5",
                          "--final_inspection_rate", "10",
                         ])

In [6]:
# args
seed = args.seed
epochs = args.epoch
dim = args.dim
lr = args.lr
weight_decay = args.l2
initial_inspection_rate = args.initial_inspection_rate
inspection_rate_option = args.inspection_plan
train_begin = args.train_from 
test_begin = args.test_from
test_length = args.test_length
valid_length = args.valid_length
chosen_data = args.data
numWeeks = args.numweeks
semi_supervised = args.semi_supervised
save = args.save
gpu_id = args.device

# Initial dataset split
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

# Initial dataset split
train_start_day = datetime.date(int(train_begin[:4]), int(train_begin[4:6]), int(train_begin[6:8]))
test_start_day = datetime.date(int(test_begin[:4]), int(test_begin[4:6]), int(test_begin[6:8]))
test_length = timedelta(days=test_length)    
test_end_day = test_start_day + test_length
valid_length = timedelta(days=valid_length)
valid_start_day = test_start_day - valid_length

# data
data.split(train_start_day, valid_start_day, test_start_day, test_end_day, valid_length, test_length, args)
data.featureEngineering()

Data size:
Train labeled: (54134, 52), Train unlabeled: (1028538, 52), Valid labeled: (70917, 52), Valid unlabeled: (0, 26), Test: (274808, 52)
Checking label distribution
Training: 0.05022795615481618
Validation: 0.035556788645191434
Testing: 0.025360899366070794


## VIME

In [7]:
# Experimental parameters
label_no = 1000  
model_sets = ['logit','xgboost','mlp']
  
# Hyper-parameters
p_m = 0.3
alpha = 2.0
K = 3
beta = 1.0
label_data_rate = 0.1

# Metric
metric = 'auc'
  
# Define output
results = np.zeros([len(model_sets)+2])  
results = {}

# Scale stuff
scalar = StandardScaler()
x_train = scalar.fit_transform(data.dftrainx_lab)
y_train = data.train_cls_label
y_train = y_train.reshape(-1,1)
x_test = scalar.fit_transform(data.dftestx)
y_test = data.test_cls_label
y_test = y_test.reshape(-1,1)
x_unlab = scalar.fit_transform(data.dftrainx_unlab)
x_valid = scalar.fit_transform(data.dfvalidx_lab)
y_valid = data.valid_cls_label

#---Supervised Models for VIME---#

# Logistic regression
y_test_hat = logit(x_train, y_train, x_test)
results[0] = roc_auc_score(y_test, y_test_hat[:,1])

# XGBoost
y_test_hat = xgb_model(x_train, y_train, x_test)    
results[1] = roc_auc_score(y_test, y_test_hat[:,1])   

# MLP
mlp_parameters = dict()
mlp_parameters['hidden_dim'] = 100
mlp_parameters['epochs'] = 10
mlp_parameters['activation'] = 'relu'
mlp_parameters['batch_size'] = 100
      
y_test_hat = mlp(x_train, y_train, x_test, mlp_parameters)
results[2] = roc_auc_score(y_test, y_test_hat[:,1])

# Report performance
for m_it in range(len(model_sets)):  
    
  model_name = model_sets[m_it]  
    
  print('Supervised Performance, Model Name: ' + model_name + 
        ', Performance: ' + str(results[m_it]))


#---Self - supervised VIME---#

# Train VIME-Self
vime_self_parameters = dict()
vime_self_parameters['batch_size'] = 128
vime_self_parameters['epochs'] = 10
vime_self_encoder = vime_self(x_unlab, p_m, alpha, vime_self_parameters)
  
# Save encoder
if not os.path.exists('save_model'):
  os.makedirs('save_model')

file_name = './save_model/encoder_model.h5'
  
vime_self_encoder.save(file_name)  
        
# Test VIME-Self
x_train_hat = vime_self_encoder.predict(x_train)
x_test_hat = vime_self_encoder.predict(x_test)
      
y_test_hat = mlp(x_train_hat, y_train, x_test_hat, mlp_parameters)
results[3] = roc_auc_score(y_test, y_test_hat[:,1])
    
print('VIME-Self Performance: ' + str(results[3]))

#---Semi Supervised VIME---#

# # Train VIME-Semi
# vime_semi_parameters = dict()
# vime_semi_parameters['hidden_dim'] = 100
# vime_semi_parameters['batch_size'] = 128
# vime_semi_parameters['iterations'] = 1000
# y_test_hat = vime_semi(x_train, y_train, x_unlab, x_test, 
#                        vime_semi_parameters, p_m, K, beta, file_name)

# # Test VIME
# results[4] = roc_auc_score(y_test, y_test_hat)
  
# print('VIME Performance: '+ str(results[4]))


#---Performance of VIME stuff---#

for m_it in range(len(model_sets)):  
    
  model_name = model_sets[m_it]  
    
  print('Supervised Performance, Model Name: ' + model_name + 
        ', Performance: ' + str(results[m_it]))
    
print('VIME-Self Performance: ' + str(results[m_it+1]))
  
# print('VIME Performance: '+ str(results[m_it+2]))

#---Modify the best thresh function---#

def find_best_threshold_local(y_pred_prob,x_list,y_test,best_thresh = None):
    '''
    dtype model: scikit-learn classifier model
    dtype x_list: list or array to predict the probability result
    dtype y_test: array of true labels
    
    Find the best probability threshold to separate probability to 0 and 1
    '''
    y_pred_prob = y_pred_prob
    threshold_list = np.arange(0.1,0.6,0.1)
    best_auc = 0.5    # 0.5 is random for AUC.
    
    if best_thresh ==None:
        for th in threshold_list:
            y_pred_label = (y_pred_prob > th)*1 
            try:
                auc_score = roc_auc_score(y_test,y_pred_prob)
            except ValueError:
                auc_score = 0.5
            if auc_score > best_auc:
                best_auc = auc_score
                best_thresh = th 
        return best_thresh, best_auc
    
    else:
        y_pred_label = (y_pred_prob > best_thresh)*1 
        best_auc = roc_auc_score(y_test,y_pred_label)
    print("AUC-score equals to:%.4f"%(best_auc))
    return best_auc

(48720, 2)

Supervised Performance, Model Name: logit, Performance: 0.5587301840914923
Supervised Performance, Model Name: xgboost, Performance: 0.5814185198482685
Supervised Performance, Model Name: mlp, Performance: 0.6299288374648435
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
(48720, 2)
VIME-Self Performance: 0.6658698876732867
Supervised Performance, Model Name: logit, Performance: 0.5587301840914923
Supervised Performance, Model Name: xgboost, Performance: 0.5814185198482685
Supervised Performance, Model Name: mlp, Performance: 0.6299288374648435
VIME-Self Performance: 0.6658698876732867


In [14]:
# Check performance 
best_thresh, best_auc = find_best_threshold_local(y_test_hat, x_valid, y_valid)
overall_f1,auc,pr, re, f, rev = metrics(y_test_hat[:,1], y_test.reshape(-1),data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 2749
Precision: 0.0313, Recall: 0.0127, Revenue: 0.0138
Checking top 2% suspicious transactions: 5497
Precision: 0.0367, Recall: 0.0297, Revenue: 0.0722
Checking top 5% suspicious transactions: 13741
Precision: 0.0339, Recall: 0.0686, Revenue: 0.1566
Checking top 10% suspicious transactions: 27481
Precision: 0.0274, Recall: 0.1108, Revenue: 0.2563


## Prepare DATA

In [31]:
from pygData_util import *

In [None]:
categories=["importer.id","HS6"]
gdata = GraphData(data,use_xgb=True, categories=categories, pos_weight = 50)

In [10]:
best_thresh, best_auc = find_best_threshold(gdata.xgb,data.dfvalidx_lab, data.valid_cls_label)
xgb_test_pred = gdata.xgb.predict_proba(data.dfvalidx_lab)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.valid_cls_label,data.valid_reg_label,best_thresh)
print("-"*50)
xgb_test_pred = gdata.xgb.predict_proba(data.dftestx)[:,-1]
overall_f1,auc,pr, re, f, rev = metrics(xgb_test_pred, data.test_cls_label,data.test_reg_label,best_thresh)

Checking top 1% suspicious transactions: 710
Precision: 0.2127, Recall: 0.0620, Revenue: 0.0370
Checking top 2% suspicious transactions: 1419
Precision: 0.1473, Recall: 0.0858, Revenue: 0.0460
Checking top 5% suspicious transactions: 3546
Precision: 0.0773, Recall: 0.1125, Revenue: 0.0595
Checking top 10% suspicious transactions: 7092
Precision: 0.0468, Recall: 0.1363, Revenue: 0.1279
--------------------------------------------------
Checking top 1% suspicious transactions: 2749
Precision: 0.1186, Recall: 0.0480, Revenue: 0.0928
Checking top 2% suspicious transactions: 5497
Precision: 0.0786, Recall: 0.0636, Revenue: 0.1278
Checking top 5% suspicious transactions: 13741
Precision: 0.0418, Recall: 0.0846, Revenue: 0.1699
Checking top 10% suspicious transactions: 27481
Precision: 0.0318, Recall: 0.1284, Revenue: 0.2528


In [11]:
stage = "train_lab"
trainLab_data = gdata.get_data(stage)
train_nodeidx = torch.tensor(gdata.get_AttNode(stage))
trainLab_data.node_idx = train_nodeidx

In [12]:
stage = "train_unlab"
unlab_data = gdata.get_data(stage)
unlab_nodeidx = torch.tensor(gdata.get_AttNode(stage))
unlab_data.node_idx = unlab_nodeidx

In [13]:
stage = "valid"
valid_data = gdata.get_data(stage)
valid_nodeidx = torch.tensor(gdata.get_AttNode(stage))
valid_data.node_idx = valid_nodeidx

In [14]:
stage = "test"
test_data = gdata.get_data(stage)
test_nodeidx = torch.tensor(gdata.get_AttNode(stage))
test_data.node_idx = test_nodeidx

In [15]:
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)

## New Sampler

In [16]:
from torch_cluster import random_walk
from torch_geometric.data import NeighborSampler 
from pytorch_lightning import LightningDataModule

In [17]:
class UnsupSampler(NeighborSampler):
    def sample(self, batch):
        batch = torch.tensor(batch)
        row, col, _ = self.adj_t.coo()

        # For each node in `batch`, we sample a direct neighbor (as positive
        # example) and a random node (as negative example):
        pos_batch = random_walk(row, col, batch, walk_length=1,
                                coalesced=False)[:, 1]

        neg_batch = torch.randint(0, self.adj_t.size(1), (batch.numel(), ),
                                  dtype=torch.long)

        batch = torch.cat([batch, pos_batch, neg_batch], dim=0)
        return super(UnsupSampler, self).sample(batch)

In [18]:
class Batch(NamedTuple):
    '''
    convert batch data for pytorch-lightning
    '''
    x: Tensor
    y: Tensor
    rev: Tensor
    adjs_t: NamedTuple
    def to(self, *args, **kwargs):
        return Batch(
            x=self.x.to(*args, **kwargs),
            y=self.y.to(*args, **kwargs),
            rev=self.rev.to(*args, **kwargs),
            adjs_t=[(adj_t.to(*args, **kwargs), eid.to(*args, **kwargs), size) for adj_t, eid, size in self.adjs_t],
        )


In [19]:
class UnsupData(LightningDataModule):
    def __init__(self,data,sizes, batch_size = 128):
        '''
        defining dataloader with NeighborSampler to extract k-hop subgraph.
        Args:
            data (Graphdata): graph data for the edges and node index
            sizes ([int]): The number of neighbors to sample for each node in each layer. 
                           If set to :obj:`sizes[l] = -1`, all neighbors are included
            batch_size (int): batch size for training
        '''
        super(UnsupData,self).__init__()
        self.data = data
        self.sizes = sizes
        self.valid_sizes = [-1 for i in self.sizes]
        self.batch_size = batch_size

    def train_dataloader(self):
        return UnsupSampler(self.data.train_edge, sizes=self.sizes,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=True,num_workers=8)
    
    def test_dataloader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.test_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)
    
    def label_loader(self):
        return UnsupSampler(self.data.test_edge, sizes=self.sizes,node_idx=self.data.train_idx,
                               batch_size=self.batch_size,transform=self.convert_batch,
                               shuffle=False,num_workers=8)

    def convert_batch(self, batch_size, n_id, adjs):
        return Batch(
            x=self.data.x[n_id],
            y=self.data.y[n_id[:batch_size]],
            rev = self.data.rev[n_id[:batch_size]],
            adjs_t=adjs,
        )

## Model

In [20]:
from models import MLP, GNNStack, UselessConv, Mish

In [21]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning import LightningModule, seed_everything
from torchtools.optim import RangerLars
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
class PretrainGNN(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.save_hyperparameters()
        self.input_dim = input_dim
        self.dim = hidden_dim*2
        self.numLayers = numLayers
        self.layers = [self.dim, self.dim//2] #* (numLayers+1)
        self.bn = nn.BatchNorm1d(self.dim)
        self.act = Mish()
        self.useXGB = useXGB
        
        # GNN layer
        if self.useXGB:
            self.initEmbedding = nn.Embedding(self.input_dim, self.dim, padding_idx=0)
        else:
            self.initEmbedding = MLP(self.input_dim, self.dim, Numlayer=2)
        self.initGNN = UselessConv()
        self.GNNs = GNNStack(self.layers,self.numLayers)

    def forward(self, x,adjs):
        # update node embedding
        leaf_emb = self.initEmbedding(x)
        if self.useXGB:
            leaf_emb = torch.sum(leaf_emb,dim=1) # summation over the trees
            leaf_emb = self.bn(leaf_emb)
            leaf_emb = self.act(leaf_emb)
        
        # first update 
        firstHop_neighbor = adjs[-1][0]
        leaf_emb = self.initGNN(leaf_emb,to_undirected(firstHop_neighbor))
        
        # GNN 
        embeddings = self.GNNs(leaf_emb, adjs)
        
        return embeddings[-1]
    
    def training_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        pos_loss = F.logsigmoid((out * pos_out).sum(-1)).mean()
        neg_loss = F.logsigmoid(-(out * neg_out).sum(-1)).mean()
        train_loss = -pos_loss - neg_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def test_step(self, batch, batch_idx: int):
        out = self(batch.x, batch.adjs_t)
        out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
        return out
    
    def test_epoch_end(self, val_step_outputs):
        val_step_outputs = torch.cat(val_step_outputs)
        val_step_outputs = val_step_outputs.cpu().detach().numpy()
        return {"log":{"predictions":val_step_outputs}}
    
    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.01, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.8)
        return [optimizer], [scheduler]

In [23]:
# model config
seed_everything(2345)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [50,20]
numLayers = len(sizes)
batch_size = 1024

model = PretrainGNN(input_dim, hidden_size, numLayers, useXGB=gdata.use_xgb)

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = UnsupData(stacked_data, sizes = sizes, batch_size=batch_size)
checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',    
    dirpath='./saved_model',
    filename='Tdata5-pretrain-{train_loss:.4f}',
    save_top_k=1,
    mode='min',
)
trainer = Trainer(gpus=[1], max_epochs=3,
#                   callbacks=[checkpoint_callback],
                 )
trainer.fit(model, train_dataloader=datamodule.train_dataloader(),
           )

Global seed set to 2345
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]

  | Name          | Type        | Params
----------------------------------------------
0 | bn            | BatchNorm1d | 128   
1 | act           | Mish        | 0     
2 | initEmbedding | Embedding   | 73.2 K
3 | initGNN       | UselessConv | 0     
4 | GNNs          | GNNStack    | 33.5 K
----------------------------------------------
106 K     Trainable params
0         Non-trainable params
106 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [24]:
trainer.test()

1

In [25]:
# trainer.save_checkpoint("./saved_model/pretrained.ckpt")

In [26]:
from tqdm import tqdm_notebook

In [27]:
test_embeddings = []
for batch in tqdm_notebook(datamodule.test_dataloader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    test_embeddings.append(out.cpu().detach().numpy())

HBox(children=(FloatProgress(value=0.0, max=269.0), HTML(value='')))




In [28]:
embeddings = []
for batch in tqdm_notebook(datamodule.label_loader()):
    batch = batch.to(model.device)
    out = model(batch.x, batch.adjs_t)
    out, pos_out, neg_out = out.split(out.size(0) // 3, dim=0)
    embeddings.append(out.cpu().detach().numpy())

HBox(children=(FloatProgress(value=0.0, max=53.0), HTML(value='')))




In [29]:
embeddings = np.concatenate(embeddings)
test_embeddings = np.concatenate(test_embeddings)

In [30]:
from sklearn.linear_model import LogisticRegression

In [31]:
lr = LogisticRegression(class_weight={1:50,0:1})
lr.fit(embeddings,data.train_cls_label)

LogisticRegression(C=1.0, class_weight={0: 1, 1: 50}, dual=False,
                   fit_intercept=True, intercept_scaling=1, l1_ratio=None,
                   max_iter=100, multi_class='warn', n_jobs=None, penalty='l2',
                   random_state=None, solver='warn', tol=0.0001, verbose=0,
                   warm_start=False)

In [32]:
y_prob = lr.predict_proba(test_embeddings)[:,1]

In [33]:
_ = metrics(y_prob, data.test_cls_label,data.test_reg_label,None)

Checking top 1% suspicious transactions: 2749
Precision: 0.0291, Recall: 0.0118, Revenue: 0.0156
Checking top 2% suspicious transactions: 5497
Precision: 0.0320, Recall: 0.0259, Revenue: 0.0663
Checking top 5% suspicious transactions: 13741
Precision: 0.0292, Recall: 0.0590, Revenue: 0.1169
Checking top 10% suspicious transactions: 27481
Precision: 0.0310, Recall: 0.1252, Revenue: 0.2015


## Fine tuning

In [34]:
from torch_geometric.nn import GATConv,TransformerConv
from pygData_util import *

In [35]:
class Predictor(LightningModule):
    def __init__(self,input_dim, hidden_dim, numLayers, useXGB=True):
        super().__init__()
        self.gnn_encoder = PretrainGNN(input_dim, hidden_size, numLayers, useXGB)
        self.dim = hidden_dim * 2
        
        # output
        self.clsLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.revLayer = nn.Linear(self.dim,1) #GATConv(self.dim,1)
        self.loss_func = nn.BCEWithLogitsLoss(pos_weight = torch.tensor([20])) #FocalLoss(logits=True)
        
    def load_fromPretrain(self,path):
        self.gnn_encoder.load_from_checkpoint(path)
        
    def loadGNN_state(self,model):
        self.gnn_encoder.load_state_dict(model.state_dict())

    def forward(self, x,adjs):
#         firstHop_neighbor = adjs[-1][0]
        # get node embedding from pre-trained model
        embedding = self.gnn_encoder(x,adjs)
        logit = self.clsLayer(embedding)
        revenue = self.revLayer(embedding)
        
        return logit, revenue
    
    def compute_CLS_loss(self,logit, label):
        logit = logit.flatten()
        loss = self.loss_func(logit,label)
        return loss
    
    def compute_REG_loss(self,pred_rev, rev):
        pred_rev = pred_rev.flatten()
        loss = F.mse_loss(pred_rev,rev)
        return loss 

    def training_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        train_loss = CLS_loss + 10 * REG_loss
        self.log('train_loss', train_loss)
        return train_loss
    
    def validation_step(self, batch, batch_idx: int):
        logits, revenues = self(batch.x, batch.adjs_t)
        CLS_loss = self.compute_CLS_loss(logits, batch.y)  
        REG_loss = self.compute_REG_loss(revenues, batch.rev)
        valid_loss = CLS_loss + 1 * REG_loss
        self.log('val_loss', valid_loss, on_step=True, on_epoch=True, sync_dist=True)
        return logits
    
    def validation_epoch_end(self, val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.valid_cls_label, self.data.valid_reg_label)
        f1_top = np.mean(f)
        self.log("F1-top",f1_top)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Val/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        return {"Val/F1-top":f1_top, "log":tensorboard_logs}
        
    def test_step(self,batch, batch_idx):
        return self.validation_step(batch, batch_idx)
    
    def test_epoch_end(self,val_step_outputs):
        predictions = torch.cat(val_step_outputs).detach().cpu().numpy().ravel()
        f,pr, re, rev = torch_metrics(predictions, self.data.test_cls_label, self.data.test_reg_label)
        f1_top = np.mean(f)
        performance = [*f, *pr, *re, *rev]
        name_performance = ["F1@1","F1@2","F1@5","F1@10","Pr@1","Pr@2","Pr@5","Pr@10",
                            "Re@1","Re@2","Re@5","Re@10","Rev@1","Rev@2","Rev@5","Rev@10"]
        name_performance = ["Test/"+i for i in name_performance]
        tensorboard_logs = dict(zip(name_performance,performance))
        
        return {"F1-top":f1_top, "log":tensorboard_logs}

    def configure_optimizers(self):
        optimizer = RangerLars(self.parameters(), lr=0.05, weight_decay=0.0001)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=0.99)
        return [optimizer], [scheduler]

In [36]:
# model config
seed_everything(5674)
input_dim = gdata.leaf_dim
hidden_size = 32
sizes = [-1,200]
numLayers = len(sizes)
batch_size = 512
pretrain_path = "./saved_model/pretrained.ckpt"

predictor = Predictor(input_dim, hidden_size, numLayers)
predictor.loadGNN_state(model)
# predictor.load_fromPretrain(pretrain_path)
predictor.data = data

# lightning config
stacked_data = StackData(trainLab_data,unlab_data,valid_data, test_data)
datamodule = CustomData(stacked_data, sizes = sizes, batch_size=batch_size)
logger = TensorBoardLogger("ssl_exp",name="Tdata")
logger.log_hyperparams(model.hparams, metrics={"F1-top":0})
checkpoint_callback = ModelCheckpoint(
    monitor='F1-top',    
    dirpath='./saved_model',
    filename='Analysis-Tdata5-{F1-top:.4f}',
    save_top_k=1,
    mode='max',
)
trainer = Trainer(gpus=[1], max_epochs=40,
                 num_sanity_val_steps=0,
                  check_val_every_n_epoch=1,
                  callbacks=[checkpoint_callback],
                 )
trainer.fit(predictor, datamodule=datamodule)

Global seed set to 5674
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5]

  | Name        | Type              | Params
--------------------------------------------------
0 | gnn_encoder | PretrainGNN       | 106 K 
1 | clsLayer    | Linear            | 65    
2 | revLayer    | Linear            | 65    
3 | loss_func   | BCEWithLogitsLoss | 0     
--------------------------------------------------
107 K     Trainable params
0         Non-trainable params
107 K     Total params


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 708
Precision: 0.1201, Recall: 0.0349, Revenue: 0.0216
Checking top 2% suspicious transactions: 1419
Precision: 0.1043, Recall: 0.0608, Revenue: 0.0357
Checking top 5% suspicious transactions: 3546
Precision: 0.0778, Recall: 0.1133, Revenue: 0.0832
Checking top 10% suspicious transactions: 7092
Precision: 0.0515, Recall: 0.1499, Revenue: 0.2148


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.1775, Recall: 0.0517, Revenue: 0.0212
Checking top 2% suspicious transactions: 1419
Precision: 0.1163, Recall: 0.0678, Revenue: 0.0404
Checking top 5% suspicious transactions: 3546
Precision: 0.0792, Recall: 0.1154, Revenue: 0.0966
Checking top 10% suspicious transactions: 7082
Precision: 0.0541, Recall: 0.1573, Revenue: 0.2201


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2225, Recall: 0.0649, Revenue: 0.0404
Checking top 2% suspicious transactions: 1418
Precision: 0.1396, Recall: 0.0813, Revenue: 0.0461
Checking top 5% suspicious transactions: 3546
Precision: 0.0818, Recall: 0.1191, Revenue: 0.0667
Checking top 10% suspicious transactions: 7091
Precision: 0.0518, Recall: 0.1507, Revenue: 0.1222


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2254, Recall: 0.0657, Revenue: 0.0422
Checking top 2% suspicious transactions: 1418
Precision: 0.1417, Recall: 0.0825, Revenue: 0.0465
Checking top 5% suspicious transactions: 3546
Precision: 0.0801, Recall: 0.1166, Revenue: 0.0806
Checking top 10% suspicious transactions: 7092
Precision: 0.0554, Recall: 0.1614, Revenue: 0.1420


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2296, Recall: 0.0669, Revenue: 0.0380
Checking top 2% suspicious transactions: 1419
Precision: 0.1452, Recall: 0.0846, Revenue: 0.0480
Checking top 5% suspicious transactions: 3546
Precision: 0.0792, Recall: 0.1154, Revenue: 0.0636
Checking top 10% suspicious transactions: 7092
Precision: 0.0495, Recall: 0.1441, Revenue: 0.1297


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.2680, Recall: 0.0780, Revenue: 0.0436
Checking top 2% suspicious transactions: 1419
Precision: 0.1586, Recall: 0.0924, Revenue: 0.0513
Checking top 5% suspicious transactions: 3546
Precision: 0.0823, Recall: 0.1199, Revenue: 0.0658
Checking top 10% suspicious transactions: 7092
Precision: 0.0479, Recall: 0.1396, Revenue: 0.1010


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2606, Recall: 0.0760, Revenue: 0.0443
Checking top 2% suspicious transactions: 1419
Precision: 0.1607, Recall: 0.0936, Revenue: 0.0508
Checking top 5% suspicious transactions: 3546
Precision: 0.0801, Recall: 0.1166, Revenue: 0.0623
Checking top 10% suspicious transactions: 7092
Precision: 0.0568, Recall: 0.1655, Revenue: 0.1175


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2606, Recall: 0.0760, Revenue: 0.0421
Checking top 2% suspicious transactions: 1419
Precision: 0.1628, Recall: 0.0949, Revenue: 0.0572
Checking top 5% suspicious transactions: 3546
Precision: 0.0829, Recall: 0.1207, Revenue: 0.0823
Checking top 10% suspicious transactions: 7092
Precision: 0.0549, Recall: 0.1598, Revenue: 0.1353


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2549, Recall: 0.0743, Revenue: 0.0405
Checking top 2% suspicious transactions: 1419
Precision: 0.1543, Recall: 0.0899, Revenue: 0.0514
Checking top 5% suspicious transactions: 3546
Precision: 0.0795, Recall: 0.1158, Revenue: 0.0765
Checking top 10% suspicious transactions: 7092
Precision: 0.0501, Recall: 0.1458, Revenue: 0.1378


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2775, Recall: 0.0809, Revenue: 0.0429
Checking top 2% suspicious transactions: 1419
Precision: 0.1649, Recall: 0.0961, Revenue: 0.0665
Checking top 5% suspicious transactions: 3545
Precision: 0.0832, Recall: 0.1211, Revenue: 0.0960
Checking top 10% suspicious transactions: 7092
Precision: 0.0620, Recall: 0.1807, Revenue: 0.1550


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2606, Recall: 0.0760, Revenue: 0.0452
Checking top 2% suspicious transactions: 1419
Precision: 0.1642, Recall: 0.0957, Revenue: 0.0518
Checking top 5% suspicious transactions: 3546
Precision: 0.0795, Recall: 0.1158, Revenue: 0.0641
Checking top 10% suspicious transactions: 7092
Precision: 0.0543, Recall: 0.1581, Revenue: 0.1402


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2549, Recall: 0.0743, Revenue: 0.0460
Checking top 2% suspicious transactions: 1419
Precision: 0.1663, Recall: 0.0969, Revenue: 0.0557
Checking top 5% suspicious transactions: 3546
Precision: 0.0801, Recall: 0.1166, Revenue: 0.0708
Checking top 10% suspicious transactions: 7092
Precision: 0.0477, Recall: 0.1388, Revenue: 0.1154


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2831, Recall: 0.0825, Revenue: 0.0420
Checking top 2% suspicious transactions: 1419
Precision: 0.1705, Recall: 0.0994, Revenue: 0.0538
Checking top 5% suspicious transactions: 3546
Precision: 0.0821, Recall: 0.1195, Revenue: 0.0638
Checking top 10% suspicious transactions: 7088
Precision: 0.0542, Recall: 0.1577, Revenue: 0.1101


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2718, Recall: 0.0793, Revenue: 0.0419
Checking top 2% suspicious transactions: 1419
Precision: 0.1670, Recall: 0.0973, Revenue: 0.0553
Checking top 5% suspicious transactions: 3545
Precision: 0.0787, Recall: 0.1146, Revenue: 0.0612
Checking top 10% suspicious transactions: 7092
Precision: 0.0505, Recall: 0.1470, Revenue: 0.1179


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2690, Recall: 0.0784, Revenue: 0.0405
Checking top 2% suspicious transactions: 1418
Precision: 0.1650, Recall: 0.0961, Revenue: 0.0556
Checking top 5% suspicious transactions: 3546
Precision: 0.0807, Recall: 0.1175, Revenue: 0.0623
Checking top 10% suspicious transactions: 7092
Precision: 0.0527, Recall: 0.1536, Revenue: 0.1254


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2803, Recall: 0.0817, Revenue: 0.0448
Checking top 2% suspicious transactions: 1419
Precision: 0.1691, Recall: 0.0986, Revenue: 0.0533
Checking top 5% suspicious transactions: 3546
Precision: 0.0812, Recall: 0.1183, Revenue: 0.0642
Checking top 10% suspicious transactions: 7092
Precision: 0.0682, Recall: 0.1988, Revenue: 0.1494


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2746, Recall: 0.0801, Revenue: 0.0402
Checking top 2% suspicious transactions: 1419
Precision: 0.1691, Recall: 0.0986, Revenue: 0.0527
Checking top 5% suspicious transactions: 3546
Precision: 0.0973, Recall: 0.1417, Revenue: 0.0683
Checking top 10% suspicious transactions: 7091
Precision: 0.0910, Recall: 0.2649, Revenue: 0.1809


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 706
Precision: 0.2904, Recall: 0.0842, Revenue: 0.0461
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0578
Checking top 5% suspicious transactions: 3546
Precision: 0.0846, Recall: 0.1232, Revenue: 0.0706
Checking top 10% suspicious transactions: 7092
Precision: 0.0589, Recall: 0.1717, Revenue: 0.1364


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2873, Recall: 0.0838, Revenue: 0.0411
Checking top 2% suspicious transactions: 1419
Precision: 0.1698, Recall: 0.0990, Revenue: 0.0568
Checking top 5% suspicious transactions: 3546
Precision: 0.0795, Recall: 0.1158, Revenue: 0.0680
Checking top 10% suspicious transactions: 7092
Precision: 0.0530, Recall: 0.1544, Revenue: 0.1141


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2915, Recall: 0.0850, Revenue: 0.0405
Checking top 2% suspicious transactions: 1419
Precision: 0.1670, Recall: 0.0973, Revenue: 0.0530
Checking top 5% suspicious transactions: 3546
Precision: 0.0821, Recall: 0.1195, Revenue: 0.0633
Checking top 10% suspicious transactions: 7092
Precision: 0.0561, Recall: 0.1634, Revenue: 0.1459


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2887, Recall: 0.0842, Revenue: 0.0433
Checking top 2% suspicious transactions: 1419
Precision: 0.1741, Recall: 0.1014, Revenue: 0.0580
Checking top 5% suspicious transactions: 3546
Precision: 0.0826, Recall: 0.1203, Revenue: 0.0672
Checking top 10% suspicious transactions: 7092
Precision: 0.0639, Recall: 0.1860, Revenue: 0.1574


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2732, Recall: 0.0797, Revenue: 0.0428
Checking top 2% suspicious transactions: 1419
Precision: 0.1741, Recall: 0.1014, Revenue: 0.0517
Checking top 5% suspicious transactions: 3546
Precision: 0.0849, Recall: 0.1236, Revenue: 0.0782
Checking top 10% suspicious transactions: 7092
Precision: 0.0553, Recall: 0.1610, Revenue: 0.1521


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2915, Recall: 0.0850, Revenue: 0.0429
Checking top 2% suspicious transactions: 1419
Precision: 0.1705, Recall: 0.0994, Revenue: 0.0520
Checking top 5% suspicious transactions: 3546
Precision: 0.0840, Recall: 0.1224, Revenue: 0.0645
Checking top 10% suspicious transactions: 7092
Precision: 0.0671, Recall: 0.1955, Revenue: 0.1603


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2845, Recall: 0.0830, Revenue: 0.0451
Checking top 2% suspicious transactions: 1419
Precision: 0.1677, Recall: 0.0977, Revenue: 0.0556
Checking top 5% suspicious transactions: 3546
Precision: 0.0854, Recall: 0.1244, Revenue: 0.0663
Checking top 10% suspicious transactions: 7092
Precision: 0.0650, Recall: 0.1893, Revenue: 0.1485


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 709
Precision: 0.2906, Recall: 0.0846, Revenue: 0.0424
Checking top 2% suspicious transactions: 1419
Precision: 0.1684, Recall: 0.0982, Revenue: 0.0568
Checking top 5% suspicious transactions: 3545
Precision: 0.0812, Recall: 0.1183, Revenue: 0.0638
Checking top 10% suspicious transactions: 7092
Precision: 0.0804, Recall: 0.2341, Revenue: 0.1805


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2859, Recall: 0.0834, Revenue: 0.0431
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0568
Checking top 5% suspicious transactions: 3546
Precision: 0.0826, Recall: 0.1203, Revenue: 0.0656
Checking top 10% suspicious transactions: 7089
Precision: 0.0698, Recall: 0.2033, Revenue: 0.1613


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2859, Recall: 0.0834, Revenue: 0.0417
Checking top 2% suspicious transactions: 1419
Precision: 0.1705, Recall: 0.0994, Revenue: 0.0570
Checking top 5% suspicious transactions: 3539
Precision: 0.0825, Recall: 0.1199, Revenue: 0.0665
Checking top 10% suspicious transactions: 7092
Precision: 0.0774, Recall: 0.2255, Revenue: 0.1816


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2972, Recall: 0.0867, Revenue: 0.0441
Checking top 2% suspicious transactions: 1419
Precision: 0.1734, Recall: 0.1010, Revenue: 0.0570
Checking top 5% suspicious transactions: 3546
Precision: 0.0863, Recall: 0.1257, Revenue: 0.0876
Checking top 10% suspicious transactions: 7092
Precision: 0.0774, Recall: 0.2255, Revenue: 0.1671


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2915, Recall: 0.0850, Revenue: 0.0432
Checking top 2% suspicious transactions: 1419
Precision: 0.1755, Recall: 0.1023, Revenue: 0.0579
Checking top 5% suspicious transactions: 3546
Precision: 0.1094, Recall: 0.1593, Revenue: 0.1059
Checking top 10% suspicious transactions: 7092
Precision: 0.1189, Recall: 0.3462, Revenue: 0.2320


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2901, Recall: 0.0846, Revenue: 0.0439
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0572
Checking top 5% suspicious transactions: 3546
Precision: 0.0959, Recall: 0.1396, Revenue: 0.0911
Checking top 10% suspicious transactions: 7092
Precision: 0.0808, Recall: 0.2353, Revenue: 0.1789


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2944, Recall: 0.0858, Revenue: 0.0437
Checking top 2% suspicious transactions: 1419
Precision: 0.1741, Recall: 0.1014, Revenue: 0.0574
Checking top 5% suspicious transactions: 3541
Precision: 0.0844, Recall: 0.1228, Revenue: 0.0679
Checking top 10% suspicious transactions: 7083
Precision: 0.0664, Recall: 0.1930, Revenue: 0.1512


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3085, Recall: 0.0899, Revenue: 0.0439
Checking top 2% suspicious transactions: 1419
Precision: 0.1741, Recall: 0.1014, Revenue: 0.0565
Checking top 5% suspicious transactions: 3546
Precision: 0.0857, Recall: 0.1248, Revenue: 0.0678
Checking top 10% suspicious transactions: 7092
Precision: 0.0774, Recall: 0.2255, Revenue: 0.1533


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2901, Recall: 0.0846, Revenue: 0.0422
Checking top 2% suspicious transactions: 1419
Precision: 0.1734, Recall: 0.1010, Revenue: 0.0568
Checking top 5% suspicious transactions: 3546
Precision: 0.0829, Recall: 0.1207, Revenue: 0.0679
Checking top 10% suspicious transactions: 7092
Precision: 0.0678, Recall: 0.1975, Revenue: 0.1555


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2887, Recall: 0.0842, Revenue: 0.0459
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0578
Checking top 5% suspicious transactions: 3546
Precision: 0.0956, Recall: 0.1392, Revenue: 0.0924
Checking top 10% suspicious transactions: 7092
Precision: 0.0997, Recall: 0.2903, Revenue: 0.2265


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 708
Precision: 0.2924, Recall: 0.0850, Revenue: 0.0474
Checking top 2% suspicious transactions: 1418
Precision: 0.1777, Recall: 0.1035, Revenue: 0.0579
Checking top 5% suspicious transactions: 3546
Precision: 0.0956, Recall: 0.1392, Revenue: 0.1011
Checking top 10% suspicious transactions: 7092
Precision: 0.1003, Recall: 0.2920, Revenue: 0.2060


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3042, Recall: 0.0887, Revenue: 0.0447
Checking top 2% suspicious transactions: 1419
Precision: 0.1776, Recall: 0.1035, Revenue: 0.0582
Checking top 5% suspicious transactions: 3546
Precision: 0.1111, Recall: 0.1618, Revenue: 0.0949
Checking top 10% suspicious transactions: 7091
Precision: 0.1104, Recall: 0.3216, Revenue: 0.2064


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2944, Recall: 0.0858, Revenue: 0.0537
Checking top 2% suspicious transactions: 1419
Precision: 0.1727, Recall: 0.1006, Revenue: 0.0579
Checking top 5% suspicious transactions: 3546
Precision: 0.1058, Recall: 0.1540, Revenue: 0.1062
Checking top 10% suspicious transactions: 7092
Precision: 0.0933, Recall: 0.2719, Revenue: 0.1960


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3014, Recall: 0.0879, Revenue: 0.0428
Checking top 2% suspicious transactions: 1419
Precision: 0.1720, Recall: 0.1002, Revenue: 0.0579
Checking top 5% suspicious transactions: 3546
Precision: 0.0914, Recall: 0.1331, Revenue: 0.0880
Checking top 10% suspicious transactions: 7092
Precision: 0.0883, Recall: 0.2571, Revenue: 0.1913


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.2958, Recall: 0.0862, Revenue: 0.0473
Checking top 2% suspicious transactions: 1419
Precision: 0.1783, Recall: 0.1039, Revenue: 0.0574
Checking top 5% suspicious transactions: 3546
Precision: 0.1038, Recall: 0.1511, Revenue: 0.0804
Checking top 10% suspicious transactions: 7092
Precision: 0.0938, Recall: 0.2731, Revenue: 0.1810


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Checking top 1% suspicious transactions: 710
Precision: 0.3042, Recall: 0.0887, Revenue: 0.0489
Checking top 2% suspicious transactions: 1418
Precision: 0.1791, Recall: 0.1043, Revenue: 0.0576
Checking top 5% suspicious transactions: 3546
Precision: 0.0987, Recall: 0.1437, Revenue: 0.0909
Checking top 10% suspicious transactions: 7092
Precision: 0.0877, Recall: 0.2554, Revenue: 0.1858



1

In [37]:
trainer.test()

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…

Checking top 1% suspicious transactions: 2749
Precision: 0.1648, Recall: 0.0666, Revenue: 0.0945
Checking top 2% suspicious transactions: 5497
Precision: 0.1128, Recall: 0.0912, Revenue: 0.1302
Checking top 5% suspicious transactions: 13741
Precision: 0.1095, Recall: 0.2214, Revenue: 0.2474
Checking top 10% suspicious transactions: 27479
Precision: 0.1170, Recall: 0.4729, Revenue: 0.4114

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'F1-top': 0.13246628519247977,
 'Test/F1@1': 0.09490886235072281,
 'Test/F1@10': 0.18753646866612209,
 'Test/F1@2': 0.10086220920774362,
 'Test/F1@5': 0.14655760054533062,
 'Test/Pr@1': 0.16478719534376138,
 'Test/Pr@10': 0.11696204374249426,
 'Test/Pr@2': 0.11278879388757504,
 'Test/Pr@5': 0.10952623535404993,
 'Test/Re@1': 0.06664705016919228,
 'Test/Re@10': 0.4728556716198323,
 'Test/Re@2': 0.09121671325584817,
 'Test/Re@5': 0.22142121524201855,
 'Test/Rev@1': 0.09453650021569975,
 'Test/Rev@

[{'Test/F1@1': 0.09490886235072281,
  'Test/F1@2': 0.10086220920774362,
  'Test/F1@5': 0.14655760054533062,
  'Test/F1@10': 0.18753646866612209,
  'Test/Pr@1': 0.16478719534376138,
  'Test/Pr@2': 0.11278879388757504,
  'Test/Pr@5': 0.10952623535404993,
  'Test/Pr@10': 0.11696204374249426,
  'Test/Re@1': 0.06664705016919228,
  'Test/Re@2': 0.09121671325584817,
  'Test/Re@5': 0.22142121524201855,
  'Test/Re@10': 0.4728556716198323,
  'Test/Rev@1': 0.09453650021569975,
  'Test/Rev@2': 0.13015026098892682,
  'Test/Rev@5': 0.2473522112718454,
  'Test/Rev@10': 0.41143432461167456,
  'F1-top': 0.13246628519247977,
  'val_loss_epoch': 3.207803726196289,
  'val_loss': 2.7595224380493164}]