# In [132]:  (notebook cell marker left over from the Jupyter export)
import os
import sys
sys.path.append('./../..')
sys.path.append('./..')
import numpy as np
import pandas as pd
sys.path.append('./../..')
from redisStore import redisUtil
from pathlib import Path
import plotly
from plotly import express as px
import os
import plotly.io as pio
import pickle
from DB_Ingestion.sqlite_engine import sqlite
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from sqlalchemy import create_engine
from plotly import io

# ---- Module-level state shared by the functions below ----

# Store object used for all redis reads and ingestion calls
# (described as a singleton by the original author).
redis_obj = redisUtil.redisStore

# Data root directory and the sub-directory inside it; set by initialize().
DATA_LOC = None
subDIR = None

# Embedding dimension used when building redis keys; set by initialize().
EMB_DIM = None

# Directory that generated HTML files are written to; set by initialize().
htmlSaveDir = None

# Connection object handed to pandas.read_sql.
# NOTE(review): no code visible in this file assigns it - confirm where the
# connection is meant to be created.
SQL_Conn = None



# =============================
#  This needs to be called to ingest data into redis before any plot is requested.
# =============================
def initialize(
    _DATA_LOC,
    _subDIR,
    mp2v_emb_dir = './../records2graph/saved_model_data',
    emb_dim = 64,
    _htmlSaveDir = None
):
    """
    Set the module-level configuration and load the data into redis.

    The arguments are stored in the module globals DATA_LOC, subDIR, EMB_DIM
    and htmlSaveDir, the HTML output directory is created when missing, and
    ``redis_obj`` is then asked to ingest the record data followed by the
    MP2V embeddings.

    :param _DATA_LOC: data location forwarded to both redis ingestion calls
    :param _subDIR: sub-directory forwarded to both redis ingestion calls
    :param mp2v_emb_dir: directory forwarded to ``ingest_MP2V_embeddings``
    :param emb_dim: embedding dimension; stored in EMB_DIM and forwarded to
        ``ingest_MP2V_embeddings``
    :param _htmlSaveDir: directory for generated HTML files; './htmlCache'
        is used when this is None
    :return: None
    """
    global DATA_LOC, subDIR, EMB_DIM, htmlSaveDir

    DATA_LOC, subDIR, EMB_DIM = _DATA_LOC, _subDIR, emb_dim

    htmlSaveDir = './htmlCache' if _htmlSaveDir is None else _htmlSaveDir
    Path(htmlSaveDir).mkdir(parents=True, exist_ok=True)

    # NOTE(review): SQL_Conn is left untouched by this function and nothing
    # visible in this file assigns it - confirm where the connection is set up.
    redis_obj.ingest_record_data(DATA_LOC=DATA_LOC, subDIR=subDIR)
    redis_obj.ingest_MP2V_embeddings(DATA_LOC, subDIR, mp2v_emb_dir, emb_dim=emb_dim)


# --------------------
# Helper function
# --------------------
def get_comparison_vecotors(record_row, domain, entity_list):
    """
    Fetch the embedding vectors needed to compare one record entity with others.

    Every vector is read from ``redis_obj`` under the key
    ``mp2v_<EMB_DIM>_<domain>_<entity id>``.

    :param record_row: mapping for the target record; ``record_row[domain]``
        is the entity id of the record in this domain
    :param domain: name of the domain (column) the entities belong to
    :param entity_list: entity ids whose vectors are used for comparison
    :return: tuple ``(domain, target, everything)`` where ``target`` is an
        array holding only the record's own entity vector and ``everything``
        holds the vectors of ``entity_list`` (in order) followed by the
        record's own entity vector
    """
    key_template = 'mp2v_{}_{}_{}'
    dim = EMB_DIM

    # Vectors of the comparison entities first, in the order they were given.
    all_vectors = [
        redis_obj.fetch_np(key_template.format(dim, domain, ent_id))
        for ent_id in entity_list
    ]

    # The record's own entity is fetched last and appended to the same list.
    record_entity_vec = redis_obj.fetch_np(
        key_template.format(dim, domain, record_row[domain])
    )
    all_vectors.append(record_entity_vec)

    return (domain, np.array([record_entity_vec]), np.array(all_vectors))

    

# Domains for which no embedding subplot is drawn.
_EXCLUDED_PLOT_DOMAINS = ('ConsigneePanjivaID', 'ShipperPanjivaID')


def _related_record_queries(record_vals):
    """
    Build the SQL statements used to look up records related to the target.

    :param record_vals: mapping of column name -> original (un-encoded) value
        of the target record
    :return: list of SQL strings, each selecting PanjivaRecordID; the list
        order is the order in which the caller runs them
    """
    select = 'select PanjivaRecordID from Records where '
    # Named fields tie every column to its own value, so a column can no
    # longer be compared against the value of a different column.
    # NOTE(review): values are interpolated straight into the SQL text. They
    # come from the locally stored value dictionary, but a value containing a
    # double quote would break the statement - consider bound parameters once
    # the parameter style of the SQL_Conn driver is confirmed.
    conditions = (
        'ConsigneePanjivaID={ConsigneePanjivaID} and ShipperPanjivaID={ShipperPanjivaID}',
        'ConsigneePanjivaID={ConsigneePanjivaID} or ShipperPanjivaID={ShipperPanjivaID}',
        'ShipperPanjivaID={ShipperPanjivaID} and ShipmentOrigin="{ShipmentOrigin}"',
        'ConsigneePanjivaID={ConsigneePanjivaID} and ShipmentDestination="{ShipmentDestination}"',
        (
            '(ShipmentOrigin="{ShipmentOrigin}" and PortOfLading="{PortOfLading}") '
            'or (ShipmentDestination="{ShipmentDestination}" and PortOfUnlading="{PortOfUnlading}")'
        ),
        (
            '(ShipmentOrigin="{ShipmentOrigin}" and HSCode={HSCode}) '
            'or (ShipmentDestination="{ShipmentDestination}" and HSCode={HSCode})'
        ),
        'PortOfLading="{PortOfLading}" and HSCode={HSCode} and PortOfUnlading="{PortOfUnlading}"',
    )
    return [(select + cond).format_map(record_vals) for cond in conditions]


def _collect_related_records(query_list, reference_df, sql_conn, max_rows, id_col='PanjivaRecordID'):
    """
    Run the queries in order and gather the matching rows of ``reference_df``.

    Rows are de-duplicated on ``id_col``. Querying stops as soon as
    ``max_rows`` rows have been gathered, and the result is then cut to
    exactly ``max_rows`` rows.

    :return: DataFrame with the gathered rows (possibly fewer than max_rows)
    """
    collected = None
    for query in query_list:
        id_frame = pd.read_sql(query, con=sql_conn, index_col=None)
        matched_ids = id_frame[id_col].values.tolist()
        matches = reference_df.loc[reference_df[id_col].isin(matched_ids)]
        if collected is None:
            collected = matches.copy()
        else:
            # DataFrame.append no longer exists in pandas 2.x; concat is the
            # supported equivalent.
            collected = pd.concat([collected, matches], ignore_index=True)
        collected = collected.drop_duplicates(subset=[id_col])
        if len(collected) >= max_rows:
            return collected.head(max_rows)
    return collected


def _build_comparison_figure(vectors_dict, n_cols=3):
    """
    Draw one subplot per domain from the first two embedding components.

    Each subplot shows the comparison entities as a scatter plus a density
    contour, and the record's own entity as a labelled marker.

    :param vectors_dict: domain -> (target vector array, all vectors array)
    :param n_cols: number of subplot columns
    :return: the plotly figure
    """
    # At least two rows, as in the original fixed 2x3 layout; more rows are
    # added when there are more domains than that grid can hold.
    n_rows = max(2, (len(vectors_dict) + n_cols - 1) // n_cols)
    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=list(vectors_dict.keys()))

    for idx, (target_entity_vec, _vectors) in enumerate(vectors_dict.values()):
        row, col = divmod(idx, n_cols)
        row += 1
        col += 1

        fig.add_trace(
            go.Scatter(
                x=_vectors[:, 0],
                y=_vectors[:, 1],
                mode='markers',
                marker=dict(color='rgba(10,0,225,0.5)', size=6)
            ),
            row=row, col=col
        )
        fig.add_trace(
            go.Histogram2dContour(
                x=_vectors[:, 0],
                y=_vectors[:, 1],
                colorscale='Greens',
                opacity=0.45,
                showlegend=False,
                # Hide the colorbar on the trace itself: the layout-level
                # coloraxis settings below only reach traces that are bound
                # to a shared coloraxis, which these are not.
                showscale=False
            ),
            row=row, col=col
        )
        fig.add_trace(
            go.Scatter(
                x=target_entity_vec[:, 0],
                y=target_entity_vec[:, 1],
                mode="markers+text",
                marker=dict(color='rgba(230,0,10,0.85)', size=10),
                text=["Record Entity"],
            ),
            row=row, col=col
        )

    fig.update_traces(showlegend=False)
    fig.update_layout(xaxis_showgrid=False, yaxis_showgrid=False)
    fig.update(layout_coloraxis_showscale=False)
    fig.update_coloraxes({'showscale': False})
    return fig


def get_stackedComparisonPlots(
    record_id,
    min_count = 1000,
    return_type=1
):
    """
    Plot the embeddings of a record's entities against those of related records.

    Related records are found by running a series of SQL queries through
    ``SQL_Conn`` and keeping the matching rows of the train/test CSV files
    found under ``DATA_LOC/subDIR``. For every domain in ``domain_dims.pkl``
    except the consignee and shipper ids, the entity vectors are fetched via
    ``get_comparison_vecotors`` and drawn in their own subplot.

    :param record_id: id of the target record; passed to
        ``redis_obj.fetch_data`` and used in the output file name
    :param min_count: number of related records at which collection stops;
        the result is cut to this many rows
    :param return_type: 1 returns the figure as an HTML fragment string,
        2 writes the HTML fragment to ``htmlSaveDir`` and returns its path
    :return: HTML string, file path, or None for any other ``return_type``
    :raises RuntimeError: if ``initialize()`` has not been called or
        ``SQL_Conn`` has not been set
    """
    if DATA_LOC is None or subDIR is None:
        raise RuntimeError(
            'initialize() must be called before get_stackedComparisonPlots()'
        )
    if SQL_Conn is None:
        raise RuntimeError(
            'SQL_Conn is not set; assign a database connection to SQL_Conn '
            'before requesting plots'
        )

    data_dir = os.path.join(DATA_LOC, subDIR)

    # NOTE: pickle.load can execute arbitrary code; these files are expected
    # to be locally generated artefacts, never untrusted input.
    with open(os.path.join(data_dir, 'col_val2id_dict.pkl'), 'rb') as fh:
        column_values2id = pickle.load(fh)
    with open(os.path.join(data_dir, 'domain_dims.pkl'), 'rb') as fh:
        domain_dims = pickle.load(fh)

    # The stored record holds encoded ids; map them back to the original
    # values, which is what the SQL table is queried with.
    record_row = redis_obj.fetch_data(record_id)
    record_row_w_vals = {}
    for _domain, _value2id in column_values2id.items():
        _id2value = {v: k for k, v in _value2id.items()}
        record_row_w_vals[_domain] = _id2value[record_row[_domain]]

    reference_df = pd.concat(
        [
            pd.read_csv(os.path.join(data_dir, 'train_data.csv'), index_col=None),
            pd.read_csv(os.path.join(data_dir, 'test_data.csv'), index_col=None),
        ],
        ignore_index=True
    )

    data = _collect_related_records(
        _related_record_queries(record_row_w_vals),
        reference_df,
        SQL_Conn,
        min_count
    )

    vectors_dict = {}
    for domain in domain_dims.keys():
        if domain in _EXCLUDED_PLOT_DOMAINS:
            continue
        entity_list = data[domain].values.tolist()
        _, target_vec, all_vecs = get_comparison_vecotors(record_row, domain, entity_list)
        vectors_dict[domain] = (target_vec, all_vecs)

    fig = _build_comparison_figure(vectors_dict)

    if return_type == 1:
        return io.to_html(fig, include_plotlyjs='cdn', include_mathjax='cdn', full_html=False)
    if return_type == 2:
        fpath = os.path.join(htmlSaveDir, 'stackedComaprison_{}.html'.format(record_id))
        fig.write_html(fpath, include_plotlyjs='cdn', include_mathjax='cdn', full_html=False)
        return fpath
    return None
        
        
# ===========================
'''
initialize(
    _DATA_LOC = './../../generated_data_v1/us_import',
    _subDIR = '01_2016',
    mp2v_emb_dir = './../../records2graph/saved_model_data',
    emb_dim = 64,
    _htmlSaveDir = './htmlCache'
)

get_stackedComparisonPlots(
    record_id, 
    min_count = 500
)


'''
