In [26]:
%load_ext autoreload
%autoreload 2

import argparse
import numpy as np
import pandas as pd
import os
import sys
sys.path.append('./..')
from pandarallel import pandarallel
from joblib import Parallel,delayed
pandarallel.initialize()
from tqdm import tqdm
import pickle
import torch
import multiprocessing as mp
from joblib.externals.loky import set_loky_pickler
from joblib import parallel_backend
from joblib import Parallel, delayed
from joblib import wrap_non_picklable_objects
from collections import OrderedDict, defaultdict
from itertools import combinations
from glob import glob
from utils import util
from typing import *

# ======================================== #

from model_AD_1 import AD_model_container, MEAD

# ======================================== #

# ---- Configuration handles ------------------------------------------------ #
# CONFIG is filled in by setup_config() from the YAML file at CONFIG_FILE.
CONFIG_FILE = './config.yaml'
CONFIG = None
DEVICE = torch.device("cpu")

# ---- Locations and dataset metadata (populated by setup_config()) --------- #
ad_model_data_dir = None
model_save_dir = None
domain_dims = None
num_entities = 0
id_col = 'PanjivaRecordID'  # column removed from a record before it is scored

# ---- Training hyper-parameters (defaults; setup_config() overwrites them) -- #
MEAD_emb_list = []
train_epochs = 100
learning_rate = 0.001
error_tol = 0.001

# ---- Caches keyed by embedding dimension ---------------------------------- #
model_dict = {}      # filled by read_models()
threshold_dict = {}  # filled by read_thresold_dict()

def get_domain_dims(data_loc):
    """Load the domain-dimension table of a dataset.

    Args:
        data_loc: Directory that contains ``data_dimensions.csv``.

    Returns:
        pandas.DataFrame holding the CSV contents, with the default integer
        index (no CSV column is promoted to the index).
    """
    # Two fixes relative to the previous version: `os.apth` was a typo for
    # `os.path`, and `read_csv` has no `index` keyword -- `index_col=None`
    # is the supported way to keep the default RangeIndex.
    dims_path = os.path.join(data_loc, 'data_dimensions.csv')
    return pd.read_csv(dims_path, index_col=None)

'''
Set up globals
'''
def setup_config(subDIR):
    """Populate the module-level configuration globals for one dataset.

    Reads the YAML file at ``CONFIG_FILE`` and derives, for the dataset
    sub-directory ``subDIR``: the AD-model data directory, the domain
    dimension table, the list of embedding sizes, the training
    hyper-parameters, the model save directory and the total entity count.

    Config keys read: ``DATA_LOC``, ``AD_model_data_subdir``, ``emb_dims``,
    ``train_epochs``, ``learning_rate``, ``batch_size``, ``error_tol``
    (optional) and ``model_save_dir``.

    Args:
        subDIR: Name of the dataset sub-directory; also stored in the
            global ``DIR``.

    Returns:
        None. All results are written to module globals.
    """
    global CONFIG
    global CONFIG_FILE
    global DIR, ad_model_data_dir, MEAD_emb_list, batch_size, error_tol, train_epochs, learning_rate, model_save_dir
    global domain_dims, num_entities

    # PyYAML is used below but is not among the imports visible in this file,
    # so it is imported here to avoid a NameError on `yaml`.
    import yaml

    DIR = subDIR
    with open(CONFIG_FILE, 'r') as fh:
        CONFIG = yaml.safe_load(fh)

    data_loc = CONFIG['DATA_LOC']
    ad_model_data_dir = os.path.join(
        data_loc,
        DIR,
        CONFIG['AD_model_data_subdir']
    )
    domain_dims = get_domain_dims(os.path.join(data_loc, subDIR))
    MEAD_emb_list = [int(_) for _ in CONFIG['emb_dims']]
    train_epochs = CONFIG['train_epochs']

    # The key used to be spelt 'learning_rat'. Prefer the correct spelling but
    # keep accepting the old one so an existing config file still loads.
    lr_key = 'learning_rate' if 'learning_rate' in CONFIG else 'learning_rat'
    learning_rate = CONFIG[lr_key]

    batch_size = CONFIG['batch_size']

    # Previously error_tol was (by copy-paste) assigned CONFIG['batch_size'].
    # NOTE(review): the key name 'error_tol' is assumed -- confirm against
    # config.yaml. When absent, the current module-level value is kept.
    error_tol = CONFIG.get('error_tol', error_tol)

    model_save_dir = os.path.join(CONFIG['model_save_dir'], DIR)
    num_entities = np.sum(domain_dims['dimension'].values)


'''
Procedure to train the models
'''
def train_AD_models():
    """Train and save one AD model for every size in ``MEAD_emb_list``.

    Each embedding size gets its own ``AD_model_container``; the containers
    are trained through joblib with ``n_jobs=3`` and saved to
    ``model_save_dir``. All settings are read from the module globals that
    ``setup_config`` fills in.

    Returns:
        None.
    """

    def _fit_and_save(emb_dim, num_entities):
        # The container is given the *paths* of the positive / negative
        # training arrays, not the loaded arrays.
        pos_path = os.path.join(ad_model_data_dir, 'train_x_pos.npy')
        neg_path = os.path.join(ad_model_data_dir, 'train_x_neg.npy')

        container = AD_model_container(
            emb_dim,
            num_entities,
            device=DEVICE,
            lr=learning_rate,
        )
        container.train(
            pos_path,
            neg_path,
            batch_size=batch_size,
            epochs=train_epochs,
            error_tol=error_tol,
        )
        container.save_model(model_save_dir)

    # Only module globals are read here, so no `global` statements are needed.
    jobs = (delayed(_fit_and_save)(dim, num_entities) for dim in MEAD_emb_list)
    Parallel(n_jobs=3)(jobs)
    
    

'''
Stores the pth percentile values for the likelihood scores of training samples.
'''
def calculate_thresholds(percentile_cutoff:List = [2,5,10]):
    """Compute and persist percentile thresholds of the training scores.

    For every embedding size in ``MEAD_emb_list`` the saved model is loaded
    (first match, in sorted order, of ``**<emb_dim>_**.pth`` under
    ``model_save_dir``), the positive training samples are scored, and the
    requested percentiles of those scores are recorded.

    The result ``{emb_dim: {percentile: value}}`` is pickled to
    ``<model_save_dir>/threshold_dict_<p1>-<p2>-....pkl``.

    Args:
        percentile_cutoff: Percentiles to compute. The default list is only
            read, never mutated.

    Returns:
        None.

    Raises:
        IndexError: If no saved model file matches an embedding size.
    """
    global model_save_dir
    global MEAD_emb_list
    global num_entities
    global ad_model_data_dir

    # NOTE(review): score_samples() is handed the path of the .npy file, the
    # same convention as the train() call -- confirm the container loads it.
    train_x_pos = os.path.join(ad_model_data_dir, 'train_x_pos.npy')

    dict_embDim_thresholdValue = {}
    for emb_dim in MEAD_emb_list:
        pattern = os.path.join(model_save_dir, '**{}_**.pth'.format(emb_dim))
        model_file_path = sorted(glob(pattern))[0]
        ad_obj = AD_model_container(
            emb_dim,
            num_entities
        )
        ad_obj.load_model(model_file_path)

        # Flatten once; every percentile is taken over the same array.
        scores = np.array(ad_obj.score_samples(train_x_pos)).reshape(-1)
        dict_embDim_thresholdValue[emb_dim] = {
            p: np.percentile(scores, p) for p in percentile_cutoff
        }

    # Fix: the file name must be built with str.format (".format"), not by
    # passing the builtin format() result as an extra path component.
    cutoff_tag = '-'.join([str(_) for _ in percentile_cutoff])
    threshold_save_file = os.path.join(
        model_save_dir, 'threshold_dict_{}.pkl'.format(cutoff_tag)
    )
    # Fix: persist the threshold dictionary itself, not the file path string.
    with open(threshold_save_file, 'wb') as fh:
        pickle.dump(dict_embDim_thresholdValue, fh, pickle.HIGHEST_PROTOCOL)


'''
Call before test mode
'''

def read_thresold_dict(percentile_cutoff:List = [2,5,10]):
    """Load a pickled threshold dictionary into the global ``threshold_dict``.

    The file read is ``<model_save_dir>/threshold_dict_<p1>-<p2>-....pkl``,
    where the suffix is the ``-``-joined ``percentile_cutoff`` values.

    Args:
        percentile_cutoff: Percentiles identifying which threshold file to
            load. The default list is only read, never mutated.

    Returns:
        None. The loaded object replaces the global ``threshold_dict``.
    """
    global threshold_dict
    # Fix: use str.format (".format") to fill in the file name; the previous
    # ", format(...)" added the suffix as a separate path component.
    cutoff_tag = '-'.join([str(_) for _ in percentile_cutoff])
    threshold_save_file = os.path.join(
        model_save_dir, 'threshold_dict_{}.pkl'.format(cutoff_tag)
    )
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # threshold files produced by this project.
    with open(threshold_save_file, 'rb') as fh:
        threshold_dict = pickle.load(fh)

def read_models():
    """Load the saved model for each embedding size into ``model_dict``.

    The global ``model_dict`` is reset to an empty dict and then filled with
    one ``AD_model_container`` per entry of ``MEAD_emb_list``. For each size
    the first file (in sorted order) matching ``**<emb_dim>_**.pth`` under
    ``model_save_dir`` is loaded.

    Returns:
        None.
    """
    global model_save_dir, model_dict
    global MEAD_emb_list
    global num_entities
    global ad_model_data_dir, DIR

    model_dict = {}
    for dim in MEAD_emb_list:
        pattern = os.path.join(model_save_dir, '**{}_**.pth'.format(dim))
        candidates = sorted(glob(pattern))
        first_match = candidates[0]

        container = AD_model_container(dim, num_entities)
        container.load_model(first_match)
        model_dict[dim] = container

'''
This is an external facing function
Input : a single row of pandas dataframe
The record is not yet serialized; it is a single row of a dataframe
'''
def score_new_sample(record : pd.DataFrame):
    """Score one raw record with every loaded model.

    The record is converted with ``util.convert_to_serializedID_format``,
    the ``id_col`` column is removed if present, and the first row is scored
    by each model in ``model_dict``. ``read_models`` and
    ``read_thresold_dict`` must have populated ``model_dict`` and
    ``threshold_dict`` beforehand.

    Args:
        record: Single-row dataframe holding the un-serialized record.

    Returns:
        dict mapping ``emb_dim -> {percentile: (threshold_value, score)}``,
        where ``score`` is the first element of the model's prediction.
    """
    global DIR, id_col
    global model_dict
    global threshold_dict
    # perform serialization
    serialized_record = util.convert_to_serializedID_format(
        record,
        DIR
    )
    # The id column is optional; only a missing column is tolerated here
    # (previously a bare `except:` hid every kind of error).
    try:
        del serialized_record[id_col]
    except KeyError:
        pass

    # First row, reshaped into a (1, n_columns) batch.
    x_values = serialized_record.values[0].reshape([1, -1])

    result = {}
    for emb_dim in MEAD_emb_list:
        score = model_dict[emb_dim].predict(x_values)
        # Fix: the thresholds are keyed by `emb_dim`; the previous lookup
        # used an undefined name `emb`.
        cutoff_perc_values = threshold_dict[emb_dim]
        result[emb_dim] = {
            perc: (v, score[0]) for perc, v in cutoff_perc_values.items()
        }
    return result

# ======================================================================================
if __name__ == '__main__':
    # Script entry point: configure for the chosen dataset, train one model
    # per embedding size, then store the score thresholds.
    # Fixes: the guard compared against 'main' (never true), the arguments
    # were never parsed, and `ars` was an undefined name.
    parser = argparse.ArgumentParser(description='Generate anomalies')
    parser.add_argument('--dir', type = str, help='Which dataset ? us__import{1,2...}' )
    # parse_known_args ignores unrelated argv entries (for example those a
    # notebook kernel is started with) instead of exiting on them.
    args, _unused_args = parser.parse_known_args()
    if args.dir is None:
        # No dataset selected: do nothing except show how to invoke the script.
        parser.print_usage()
    else:
        DIR = args.dir
        setup_config(DIR)
        train_AD_models()
        calculate_thresholds()
    
    
# ======================================================================================
'''
Calling externally after initialization
'''
# 
# setup_config(DIR)
# read_thresold_dict()
# read_models()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
INFO: Pandarallel will run on 40 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.
