<a href="https://colab.research.google.com/github/iitaejeong/behavior2vec/blob/main/graph2vec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import glob
import hashlib
import pandas as pd
import numpy as np
import networkx as nx
import psycopg2 as pg2

from tqdm import tqdm
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
import random
import gc
import matplotlib.pyplot as plt
import collections

# Open the PostgreSQL connection shared by every query in this notebook.
# NOTE(review): `graph_env` is not defined in the visible cells - presumably a
# dict of connection settings created elsewhere; confirm before running.
conn = pg2.connect(
    database=graph_env['database'],
    user=graph_env['user'],
    host=graph_env['host'],
    port=graph_env['port'],
)

# data loading

In [None]:
# data loading
# Fetch the distinct customer ids and sample one customer at random.
query = 'SELECT distinct(cust_pty_sbt_id) FROM tb_wed_com_prof'
cust_list = pd.read_sql(sql=query, con=conn)
cust_id = random.choice(cust_list['cust_pty_sbt_id'])

# The three source tables are each filtered with a substring (LIKE) match on
# the sampled id.
# NOTE(review): cust_id is interpolated directly into the SQL text; it comes
# from the database here, but use bound parameters if that ever changes.

# sgi node
query = f"SELECT * FROM tb_wed_sgi WHERE cust_pty_sbt_id LIKE '%{cust_id}%'"
sgidf = pd.read_sql(sql=query, con=conn)

# cdr node
query = f"SELECT * FROM tb_wed_cdr WHERE cust_pty_sbt_id LIKE '%{cust_id}%'"
cdrdf = pd.read_sql(sql=query, con=conn)

# s1ap node
query = f"SELECT * FROM tb_wed_s1ap WHERE cust_pty_sbt_id LIKE '%{cust_id}%'"
s1apdf = pd.read_sql(sql=query, con=conn)

# data preprocessing 

In [None]:
# Columns of the SGI table that the rest of the notebook works with.
sgi_feature = [
    'cust_pty_sbt_id', 'base_date', 'timezn_div_cd', 'svc_nm', 'byte_size',
]
# Independent copy so later column assignments do not touch the raw frame.
sgidf_ = sgidf.loc[:, sgi_feature].copy()

In [None]:
# Columns of the CDR table that the rest of the notebook works with.
# ('tpn114_ttsctg_cd' was considered but is currently left out.)
cdr_feature = [
    'cust_pty_sbt_id', 'base_date', 'timezn_div_cd',
    'tpn114_lctg_cd', 'tpn114_tsctg_nm',
    'rcv_tlk_time', 'snd_tlk_time',
]
# Independent copy so later column assignments do not touch the raw frame.
cdrdf_ = cdrdf.loc[:, cdr_feature].copy()

In [None]:
# Columns of the S1AP table that the rest of the notebook works with.
s1ap_feature = [
    'cust_pty_sbt_id', 'base_date', 'timezn_div_cd',
    'address', 'stay_time', 'mov_radis_dist', 'wday_eweek_div_nm',
]
# Independent copy so later column assignments do not touch the raw frame.
s1apdf_ = s1apdf.loc[:, s1ap_feature].copy()

In [None]:
# Reduce the time-zone code to its first character in every source frame.
for _frame in (sgidf_, cdrdf_, s1apdf_):
    _frame['timezn_div_cd'] = _frame['timezn_div_cd'].map(lambda code: code[:1])

# Bring base_date to a common string form across the three frames.
# sgi / cdr only need a string cast; s1ap dates are re-joined with dashes
# (presumably stored as 'YYYYMMDD' strings - confirm against the table).
sgidf_['base_date'] = sgidf_['base_date'].astype(str)
cdrdf_['base_date'] = cdrdf_['base_date'].astype(str)
s1apdf_['base_date'] = s1apdf_['base_date'].map(
    lambda raw: '-'.join((raw[:4], raw[4:6], raw[6:]))
)

## example 

In [None]:
# Show the first row of each prepared frame as a quick sanity check.
for _frame in (sgidf_, cdrdf_, s1apdf_):
    display(_frame.head(1))

# Ask the garbage collector to reclaim unreferenced objects before the heavier cells.
gc.collect()

## feature  

### sgi 

In [None]:
# Min-max scale byte_size into [0, 1]; this is the SGI 'rel_value' edge weight.
byte_min = sgidf_['byte_size'].min()
byte_span = sgidf_['byte_size'].max() - byte_min
if byte_span == 0:
    # All rows share one byte_size (e.g. a single record): the plain formula
    # would compute 0/0 and fill the column with NaN, so use 0.0 instead.
    sgidf_['rel_value'] = 0.0
else:
    sgidf_['rel_value'] = (sgidf_['byte_size'] - byte_min) / byte_span

# Behaviour node label: "<time-zone code>_<service name>".
sgidf_['sgi_behavior'] = sgidf_['timezn_div_cd'] + '_' + sgidf_['svc_nm']

sgidf_

In [None]:
# Display-only: shows rel_value shifted by 1; the result is not assigned back.
sgidf_['rel_value'] + 1

In [None]:
# rel_value_2: how unusual a service is within its time zone, measured as
# |log(count of the service in the zone / number of rows in the zone)|.
timezn_sgi = {}
for zone in sorted(sgidf_['timezn_div_cd'].unique()):
    counts = sgidf_.loc[sgidf_['timezn_div_cd'] == zone, 'svc_nm'].value_counts()
    # Fix the denominator before computing any weight. The previous loop
    # re-summed the dict while it was overwriting counts with log weights, so
    # every service after the first was divided by a corrupted total.
    total = counts.sum()
    timezn_sgi[zone] = {app: abs(np.log(cnt / total)) for app, cnt in counts.items()}

# Per zone, map every row's service name to its weight (row index is preserved).
timezndf_sgi = {}
for zone, weights in timezn_sgi.items():
    in_zone = sgidf_['timezn_div_cd'] == zone
    timezndf_sgi[zone] = pd.DataFrame(sgidf_.loc[in_zone, 'svc_nm'].map(weights))

# concat keys the frames by zone, giving a (zone, original row index) index;
# after reset_index the original row index is 'level_1', used to join back.
timezn_index = pd.concat(timezndf_sgi).reset_index().set_index('level_1')
# Both sides carry 'svc_nm', so the merge yields svc_nm_x (name) and svc_nm_y (weight).
sgidf_ = pd.merge(sgidf_, timezn_index, left_index=True, right_index=True, how='inner')
sgidf_ = sgidf_.rename(columns={'svc_nm_y': 'rel_value_2'})

### cdr 

In [None]:
# The ratio of outgoing to incoming talk time is used as the interest measure
# for the business category of the call.
# Both talk-time columns are shifted by 1 so the ratio never divides by zero
# (adding np.finfo(float).eps was the alternative that was considered).
talk_cols = ['rcv_tlk_time', 'snd_tlk_time']
cdrdf_[talk_cols] = cdrdf_[talk_cols] + 1

# rel_value: rounded send/receive ratio shifted by 100.
# NOTE(review): the +100 offset presumably keeps CDR weights in a range distinct
# from the other sources - confirm.
cdrdf_['rel_value'] = np.round(cdrdf_['snd_tlk_time'] / cdrdf_['rcv_tlk_time'], 2) + 100

# Behaviour node label: "<time-zone code>_<business category name>"
# (the category code 'tpn114_ttsctg_cd' was an earlier choice).
cdrdf_['cdr_behavior'] = cdrdf_['timezn_div_cd'] + '_' + cdrdf_['tpn114_tsctg_nm']

# CDR has no separate rarity weight, so rel_value_2 mirrors rel_value.
cdrdf_['rel_value_2'] = cdrdf_['rel_value'].copy()

### s1ap 

In [None]:
s1apdf_['stay_time'] = s1apdf_['stay_time'].apply(float)

# rel_value is derived from the stay time: min-max scaled, then shifted by 1000.
stay_min = s1apdf_['stay_time'].min()
stay_span = s1apdf_['stay_time'].max() - stay_min
if stay_span == 0:
    # Every stay_time is identical: avoid the 0/0 division that would yield NaN.
    s1apdf_['rel_value'] = 1000.0
else:
    s1apdf_['rel_value'] = (s1apdf_['stay_time'] - stay_min) / stay_span + 1000

# Behaviour node label: "<time-zone code>_<address>".
s1apdf_['s1ap_behavior'] = s1apdf_['timezn_div_cd'] + '_' + s1apdf_['address']

# rel_value_2 is derived from how often each place is visited in a time zone.
# e.g. a user who usually visits two districts in zone A and one day goes to a
# third one: the rare visit is treated as a special event and weighted higher.
timezntest = {}
for zone in sorted(s1apdf_['timezn_div_cd'].unique()):
    counts = s1apdf_.loc[s1apdf_['timezn_div_cd'] == zone, 'address'].value_counts()
    # Fix the denominator before computing any weight. The previous loop
    # re-summed the dict while it was overwriting counts with log weights, so
    # every address after the first was divided by a corrupted total.
    total = counts.sum()
    timezntest[zone] = {dong: abs(np.log(cnt / total)) for dong, cnt in counts.items()}

# Per zone, map every row's address to its weight (row index is preserved).
timezndf = {}
for zone, weights in timezntest.items():
    in_zone = s1apdf_['timezn_div_cd'] == zone
    timezndf[zone] = pd.DataFrame(s1apdf_.loc[in_zone, 'address'].map(weights))

# concat keys the frames by zone, giving a (zone, original row index) index;
# after reset_index the original row index is 'level_1', used to join back.
timezn_index = pd.concat(timezndf).reset_index().set_index('level_1')
# Both sides carry 'address', so the merge yields address_x (name) and address_y (weight).
s1apdf_ = pd.merge(s1apdf_, timezn_index, left_index=True, right_index=True, how='inner')
s1apdf_ = s1apdf_.rename(columns={'address_y': 'rel_value_2'})
s1apdf_['rel_value_2'] = s1apdf_['rel_value_2'] + 1000

#### TODO LIST ####

df 들간의 base date 가 불일치하여, 특정 데이터프레임 일자 기준으로 그래프를 생성해야 함.

각 데이터 프레임의 basedate nunique값을 참고하여, 

가장 높은 cnt를 지닌 data frame 추출 그리고 타 데이터들의 기준으로 만드는 함수 생성.


####

## network generation 

In [None]:
# version selects the edge-weight column used when the day graphs are built:
# 'generation' -> rel_value, any other value (e.g. 'anomaly') -> rel_value_2
version = 'anomaly'

In [None]:
# Edge-weight column: 'rel_value' for the 'generation' version, 'rel_value_2' otherwise.
sgi_edge_attr = 'rel_value' if version == 'generation' else 'rel_value_2'
sgi_edge_columns = ['cust_pty_sbt_id', 'sgi_behavior', sgi_edge_attr]

# One customer -> behaviour graph per SGI base_date.
# NOTE(review): from_pandas_edgelist is called without create_using, so the
# stored graphs are presumably plain nx.Graph objects even though the
# defaultdict factory is nx.MultiGraph - confirm this is intended.
daygraph_sgi = collections.defaultdict(nx.MultiGraph)
for day in tqdm(sorted(sgidf_['base_date'].unique())):
    day_edges = sgidf_.loc[sgidf_['base_date'] == day, sgi_edge_columns]
    G = nx.from_pandas_edgelist(
        df=day_edges,
        source='cust_pty_sbt_id',
        target='sgi_behavior',
        edge_attr=sgi_edge_attr,
    )
    daygraph_sgi[day] = G

In [None]:
# Edge-weight column: 'rel_value' for the 'generation' version, 'rel_value_2' otherwise.
s1ap_edge_attr = 'rel_value' if version == 'generation' else 'rel_value_2'
s1ap_edge_columns = ['cust_pty_sbt_id', 's1ap_behavior', s1ap_edge_attr]

# Only base_dates that also occur in the SGI data get an S1AP graph.
s1apxsgi_daylist = list(set(s1apdf_['base_date']).intersection(sgidf_['base_date'].unique()))

# Missing days fall back to an empty nx.Graph through the defaultdict factory.
daygraph_s1ap = collections.defaultdict(nx.Graph)
for day in tqdm(sorted(s1apxsgi_daylist)):
    day_edges = s1apdf_.loc[s1apdf_['base_date'] == day, s1ap_edge_columns]
    G = nx.from_pandas_edgelist(
        df=day_edges,
        source='cust_pty_sbt_id',
        target='s1ap_behavior',
        edge_attr=s1ap_edge_attr,
    )
    daygraph_s1ap[day] = G

In [None]:
# Edge-weight column: 'rel_value' for the 'generation' version, 'rel_value_2' otherwise.
cdr_edge_attr = 'rel_value' if version == 'generation' else 'rel_value_2'
cdr_edge_columns = ['cust_pty_sbt_id', 'cdr_behavior', cdr_edge_attr]

# Only base_dates that also occur in the SGI data get a CDR graph.
cdrxsgi_daylist = list(set(cdrdf_['base_date']).intersection(sgidf_['base_date'].unique()))

# Missing days fall back to an empty nx.Graph through the defaultdict factory.
daygraph_cdr = collections.defaultdict(nx.Graph)
for day in tqdm(sorted(cdrxsgi_daylist)):
    day_edges = cdrdf_.loc[cdrdf_['base_date'] == day, cdr_edge_columns]
    G = nx.from_pandas_edgelist(
        df=day_edges,
        source='cust_pty_sbt_id',
        target='cdr_behavior',
        edge_attr=cdr_edge_attr,
    )
    daygraph_cdr[day] = G

## network merge

In [None]:
from tqdm import tqdm

# Graphs are composed per date: the behaviours of one day from all three
# sources are merged into a single graph. SGI is the most frequent source, so
# its dates define the list of days to compose.
graphbox = collections.defaultdict(nx.MultiGraph)

# unique() keeps first-appearance order, so graphbox gets the same keys in the
# same order as iterating every row, without recomposing a day once per row.
daylist = sgidf_['base_date'].unique()
for day in tqdm(daylist):
    try:
        # Days absent from the CDR / S1AP dicts contribute an empty graph
        # through their defaultdict factories.
        graphbox[day] = nx.compose_all([daygraph_sgi[day],
                                        daygraph_cdr[day],
                                        daygraph_s1ap[day]])
    except Exception as err:
        # Report which day failed and why, then carry on with the other days
        # instead of silently abandoning the rest of the loop.
        print(day, repr(err))

# algorithm

## model architecture 

## feature extractor

In [None]:
class WeisfeilerLehmanMachine:
    '''
    Weisfeiler-Lehman relabelling of node features.

    Every round replaces each node's feature with the MD5 hex digest of its own
    feature joined with the sorted features of its neighbours. All labels seen
    (initial ones plus those of every round) accumulate in extracted_features.
    '''
    def __init__(self, graph, features, iterations):
        '''
        Store the inputs and immediately run all WL rounds.
        graph: object exposing nodes() and neighbors(node) (e.g. a networkx graph).
        features(dict): node -> initial feature; must contain every node of graph.
        iterations(int): number of WL rounds to run.
        '''
        self.graph = graph
        self.nodes = self.graph.nodes()
        self.features = features
        # The initial features, stringified, are the first "words".
        self.extracted_features = [str(value) for value in features.values()]
        self.iterations = iterations
        self.do_recursions()

    def do_a_recursion(self):
        '''
        Run one WL round over all nodes.
        :return: dict mapping node -> new hashed label.
        '''
        relabelled = {}
        for node in self.nodes:
            # Neighbour labels are sorted so the result does not depend on
            # the order in which neighbours are listed.
            neighbour_labels = sorted(
                str(self.features[neb]) for neb in self.graph.neighbors(node)
            )
            # The node's own label goes first, then the neighbours' labels,
            # e.g. an ego node combined with the traits of the nodes it touches.
            signature = '_'.join([str(self.features[node]), *neighbour_labels])
            # MD5 compresses the joined string into a fixed-length label.
            relabelled[node] = hashlib.md5(signature.encode()).hexdigest()
        self.extracted_features = self.extracted_features + list(relabelled.values())
        return relabelled

    def do_recursions(self):
        '''
        Run the configured number of WL rounds, feeding each round's labels
        into the next one.
        '''
        for _round in range(self.iterations):
            self.features = self.do_a_recursion()

##  graph2doc

In [None]:
# Algorithm code
def feature_extractor(day, rounds, feat='closeness', version='generation'):
    '''
    Turn the graph of one day into a TaggedDocument for Doc2Vec.

    day: key of the graph to process (passed to dataset_preprocess).
    rounds(int): number of Weisfeiler-Lehman iterations.
    feat, version: forwarded to dataset_preprocess; the defaults are the same
        as dataset_preprocess's own defaults, so existing two-argument calls
        keep selecting the same feature mode.
    :return: TaggedDocument whose words are the extracted WL features and
        whose only tag is the day.

    The features returned by dataset_preprocess are now the ones actually
    used. Previously a misspelled local name ('featuers') made the function
    read whatever global `features` happened to exist.
    '''
    G, features, day = dataset_preprocess(day, feat=feat, version=version)
    machine = WeisfeilerLehmanMachine(G, features, rounds)
    doc = TaggedDocument(words=machine.extracted_features, tags=[day])
    return doc

## standard for algorithm

In [None]:
def dataset_preprocess(day, feat='closeness', version='generation'):
    '''
    Look up the graph of one day and build its initial node features.

    day: key into the global graphbox.
    feat: 'closeness' -> closeness centrality per node;
          'vanilla'   -> the edge weight of each behaviour node, 0 for cust_id;
          anything else -> node degree (original graph2vec behaviour).
    version: only used by 'vanilla'; 'generation' reads the 'rel_value' edge
          attribute, 'anomaly' reads 'rel_value_2'. Any other value prints a
          hint and leaves only the cust_id entry in the features.
    :return: (G, features, day)
    '''
    G = graphbox[day]

    if feat == 'closeness':
        # NOTE(review): distance is always 'rel_value', regardless of version;
        # confirm that is intended for graphs that only carry 'rel_value_2'.
        features = dict(nx.closeness_centrality(G, distance='rel_value'))
    elif feat == 'vanilla':
        # Only the raw edge weight is used; the customer node itself gets 0.
        features = {cust_id: 0}
        edge_attr = {'generation': 'rel_value', 'anomaly': 'rel_value_2'}.get(version)
        if edge_attr is None:
            print('enter the version')
        else:
            # Single pass over the edges (the edge list used to be rebuilt
            # twice per iteration). The second endpoint of each edge is keyed,
            # presumably the behaviour node - confirm edge orientation.
            for _source, target, data in G.edges(data=True):
                features[target] = data[edge_attr]
    else:
        # original version of graph2vec
        features = dict(nx.degree(G))

    return G, features, day

### parameters 

In [None]:
# Hyper-parameters for the Weisfeiler-Lehman + Doc2Vec pipeline.
#   vector_size : number of embedding dimensions
#   workers     : number of worker threads (resources)
#   epochs      : number of training epochs
#   min_count   : minimal feature count to keep
#   alpha       : initial learning rate
#   sample      : down-sampling rate for frequent features
#   window      : context window size (fixed)
#   dm          : training algorithm; 1 -> distributed memory, otherwise
#                 distributed bag of words (fixed)
#   rounds      : number of Weisfeiler-Lehman iterations
parameters = dict(
    vector_size=1,
    workers=8,
    epochs=1,
    min_count=0,
    alpha=0.025,
    sample=0.0001,
    window=5,
    dm=0,
    rounds=2,
)

In [None]:
# The more WL rounds, the more features each document holds, so the cost grows.

from tqdm import tqdm

# One TaggedDocument per composed day graph, built from the 'vanilla'
# (edge-weight) features of the selected version. The document is built right
# here from the features just computed, so it no longer depends on a global
# `features` leaking into another function, and each graph is preprocessed
# only once.
document_collections = {}
for graph_day in tqdm(list(graphbox)):  # snapshot: graphbox is a defaultdict
    G, features, day = dataset_preprocess(graph_day, feat='vanilla', version=version)
    machine = WeisfeilerLehmanMachine(G, features, parameters['rounds'])
    document_collections[day] = TaggedDocument(words=machine.extracted_features, tags=[day])

# Report the document length of one random day as a sanity check. The day is
# drawn from the documents that exist, so the lookup cannot miss.
if document_collections:
    sample_day = random.choice(list(document_collections))
    print(f"{len(document_collections[sample_day].words)} is the {parameters['rounds']} round")

## doc2vec training

In [None]:
# Training corpus: one TaggedDocument per day, in insertion order.
doc = list(document_collections.values())

# Sanity check before training: every composed graph must have a document.
assert len(graphbox) == len(doc)

# NOTE(review): parameters['sample'] is defined but not forwarded to Doc2Vec -
# confirm whether the library default is intended.
doc2vec_kwargs = {
    name: parameters[name]
    for name in ('vector_size', 'dm', 'workers', 'min_count', 'window', 'epochs', 'alpha')
}
model = Doc2Vec(doc, **doc2vec_kwargs)

## chart to check the distribution of user behavior (vector)

In [None]:
def vector_compression(method='max', n_components=None, vectors=None):
    '''
    Reduce every document vector to a plottable value.

    method: 'max' / 'min' / 'mean' collapse each vector to one scalar;
        'tsne' / 'pca' pass the flattened values to the global `tsne` / `pca`
        objects. Any other value asks for a method on stdin and retries.
    n_components: accepted for compatibility; currently unused.
    vectors: mapping key -> numpy vector. Defaults to the global vectors_dict.
    :return: list of scalars for max/min/mean, otherwise the fit_transform result.

    NOTE(review): `tsne` and `pca` are not defined in the visible cells - they
    presumably have to be created elsewhere before those methods are used.
    '''
    if vectors is None:
        vectors = vectors_dict
    # Materialise once instead of rebuilding the value list for every element.
    values = list(vectors.values())

    if method == 'max':
        return [vec.max() for vec in values]
    if method == 'min':
        return [vec.min() for vec in values]
    if method == 'mean':
        return [vec.mean() for vec in values]
    if method == 'tsne':
        return tsne.fit_transform(np.array(values).reshape(-1, 1))
    if method == 'pca':
        return pca.fit_transform(np.array(values).reshape(-1, 1))

    # Unknown method: ask again and return the retry's result. It used to be
    # discarded, so this branch ended in an UnboundLocalError.
    print('plz input your method')
    return vector_compression(input(), n_components, vectors)

In [None]:
# Learned document vector of every day, in chronological (sorted-key) order.
vectors_dict = {}
for graph_day in tqdm(sorted(graphbox)):  # keys are sorted once, not per lookup
    try:
        vectors_dict[graph_day] = model[graph_day]
    except KeyError:
        # Say which day has no trained vector instead of silently stopping.
        print(f'no document vector for {graph_day}')

## x-axis handling so the dates are shown in monthly steps
# xticks
sgi_days = sorted(sgidf_.base_date.unique())
daylist_viz = list(pd.date_range(start=sgi_days[0],
                                 end=sgi_days[-1],
                                 freq='M').strftime('%Y-%m-%d'))

# Put the first plotted day in front of the monthly ticks (skipped when no
# vectors are available, which used to raise IndexError).
if vectors_dict:
    daylist_viz.insert(0, next(iter(vectors_dict)))

### vector max version

In [None]:
# Collapse each day's vector to its largest component and plot it over time.
vector = vector_compression(method='max')

plt.figure(figsize=(15, 8))
plt.plot(vectors_dict.keys(), vector, marker='o', color='red')
plt.title(f'{cust_id} behavior pattern of SGI ', fontsize=14)
plt.xlabel('date', fontsize=12)
plt.ylabel('vectors', fontsize=12)
plt.xticks(daylist_viz, rotation=45)  # tick dates prepared in daylist_viz

plt.grid(True)
plt.show()

### vector min version 

In [None]:
# Collapse each day's vector to its smallest component and plot it over time.
vector = vector_compression(method='min')

plt.figure(figsize=(15, 8))
plt.plot(vectors_dict.keys(), vector, marker='o', color='red')
plt.title(f'{cust_id} behavior pattern of SGI ', fontsize=14)
plt.xlabel('date', fontsize=12)
plt.ylabel('vectors', fontsize=12)
plt.xticks(daylist_viz, rotation=45)  # tick dates prepared in daylist_viz

plt.grid(True)
plt.show()

### vector mean version

In [None]:
# Collapse each day's vector to its mean component and plot it over time.
vector = vector_compression(method='mean')

plt.figure(figsize=(15, 8))
plt.plot(vectors_dict.keys(), vector, marker='o', color='red')
plt.title(f'{cust_id} behavior pattern of SGI ', fontsize=14)
plt.xlabel('date', fontsize=12)
plt.ylabel('vectors', fontsize=12)
plt.xticks(daylist_viz, rotation=45)  # tick dates prepared in daylist_viz

plt.grid(True)
plt.show()

# check difference 

In [None]:
# Absolute change of the compressed vector between consecutive days, keyed by
# (next day, previous day). `vector` is whatever the last vector_compression
# call produced.
vectorday = list(vectors_dict.keys())

# Pair each day with its successor directly. The previous index loop ran one
# step past the end and relied on a bare except to stop.
vectorminus = {
    (next_day, prev_day): abs(next_val - prev_val)
    for prev_day, next_day, prev_val, next_val
    in zip(vectorday, vectorday[1:], vector, vector[1:])
}

# Largest change first.
vectorminus_ = sorted(vectorminus.items(), key=lambda item: item[1], reverse=True)

## anomaly

In [None]:
# Positions in vectorminus_ (sorted by descending change):
# anomaly = 0 is the largest day-over-day change, normal = -1 the smallest.
anomaly , normal = 0 , -1
# Each entry is ((next day, previous day), change); unpack the day pair.
nextday , prevday = vectorminus_[anomaly][0]

In [None]:
# Edges (with their weights) of the later day of the selected pair.
print(nextday)
sorted(graphbox[nextday].edges(data=True))

In [None]:
# Edges (with their weights) of the earlier day of the selected pair.
print(prevday)
sorted(graphbox[prevday].edges(data=True))

In [None]:
# 861200 , 의료
# 861201 , 병원-내과의원
# 431016 , 목욕탕

## normal

In [None]:
# Day pair with the smallest change: the entry of vectorminus_ at index `normal`.
nextday , prevday = vectorminus_[normal][0]

In [None]:
# Edges (with their weights) of the later day of the selected pair.
print(nextday)
sorted(graphbox[nextday].edges(data=True))

In [None]:
# Edges (with their weights) of the earlier day of the selected pair.
print(prevday)
sorted(graphbox[prevday].edges(data=True))

In [None]:
# Profile rows of the sampled customer (substring match on the id).
query = f"SELECT * FROM tb_wed_com_prof WHERE cust_pty_sbt_id LIKE '%{cust_id}%'"
prof_ = pd.read_sql(sql=query, con=conn)

In [None]:
prof_

In [None]:
vectorminus_

# LAB

## normalization 

In [None]:
# hash.hexdigest() (summarised from the Python hashlib documentation):
# like digest(), but the digest is returned as a string of double length that
# contains only hexadecimal digits, so it can be exchanged safely in email or
# other non-binary environments.
# The loop below prints the MD5 hex digest of a few numeric strings.

import hashlib

text = ['100', '200', '300',
        '1000', '2000', '3000',
        '10000', '20000', '30000']

for num in text:
    hash_obj = hashlib.md5(num.encode())
    hash_result = hash_obj.hexdigest()
    print(f'num is {num} -->  hash_result is {hash_result}\n\n')

## relevance 1,2 버전 효용성 체크

In [None]:
# 20 SGI rows with the highest rel_value_2.
sgidf_.loc[:, ['timezn_div_cd', 'sgi_behavior', 'rel_value_2', 'rel_value']].sort_values(by='rel_value_2', ascending=False).head(20)

In [None]:
# 20 SGI rows with the lowest rel_value_2 (tail of the descending sort).
sgidf_.loc[:, ['timezn_div_cd', 'sgi_behavior', 'rel_value_2', 'rel_value']].sort_values(by='rel_value_2', ascending=False).tail(20)

### rel_value_2 at E  

In [None]:
# Time zone 'E' only: 20 rows with the highest rel_value_2.
sgidf_.loc[sgidf_['timezn_div_cd'] == 'E'].sort_values(by='rel_value_2', ascending=False).head(20)

In [None]:
# Time zone 'E' only: 20 rows with the lowest rel_value_2 (tail of the descending sort).
sgidf_.loc[sgidf_['timezn_div_cd'] == 'E'].sort_values(by='rel_value_2', ascending=False).tail(20)

In [None]:
# 20 S1AP rows with the lowest rel_value_2 ('address_x' is the address column
# as suffixed by the earlier merge).
s1apdf_.loc[:, ['rel_value', 'rel_value_2', 'address_x']].sort_values(by='rel_value_2').head(20)

### rel_value at E 

In [None]:
# Time zone 'E' only: 20 rows with the highest rel_value.
sgidf_.loc[sgidf_['timezn_div_cd'] == 'E'].sort_values(by='rel_value', ascending=False).head(20)

In [None]:
# Time zone 'E' only: 20 rows with the lowest rel_value (tail of the descending sort).
sgidf_.loc[sgidf_['timezn_div_cd'] == 'E'].sort_values(by='rel_value', ascending=False).tail(20)

## trouble-shooting

In [None]:
## trouble shooting code
# for itr in range(len(G.edges())):
#     features[list(G.edges(data=True))[itr][1]] = list(G.edges(data=True))[itr][2]['rel_value']
    
# dict_feature = collections.defaultdict(str)
# dict_feature[cust_id] = 0
# for itr in tqdm(range(len(G.edges()))):
#     dict_feature[list(G.edges(data=True))[itr][1]] = list(G.edges(data=True))[itr][2]['rel_value']

## todolist 

### network generation func

In [None]:
## TODOLIST function 만들기##
############################
# def network_gen(netdf, source , target , edge_attr , daylist=None):
#     daygrpah_netdf = collections.defaultdict(nx.Graph)
#     if netdf == sgidf_:
#         daylist = sgidf_['base_date'].unique()
#     else:
#         for day in tqdm(daylist):
#             G = nx.from_pandas_edgelist(df=netdf[netdf['base_date'] == day][[source,target,edge_attr]],\
#                                        source = source,
#                                        target = target,
#                                        edge_attr = [edge_attr])
#             daygraph_netdf[day] = G
#         return daygraph

# network_gen(netdf = sgidf_, source='cust_pty_sbt_id', target='sgi_behavior', edge_attr='sgi_rel_value', daylist = )

### datetime auto annotation

In [None]:
# ## 함수 보완 필요 datatype 에서 datetime 이 인식 error 발생.
# def timefeature(df ,
#                 timecolumn='timezn_div_cd',
#                 datecolumn='base_date'):
    
#     df[timecolumn] = df[timecolumn].apply(lambda x : x[:1])
#     if type(df[datecolumn][0]) != "<class 'datetime.date'>":
#         df[datecolumn].apply(lambda x : x[:4] + '-' + x[4:6] + '-' + x[6:])
#     else:
#         df[datecolumn] = df[datecolumn].map(lambda x : x.strftime('20%y-%m-%d'))
    
#     return df

In [None]:
# def dataset_preprocess(day, feat = 'closeness'):
#     '''
#     Function to read the graph and features
#     '''
#     import networkx as nx
#     G = graphbox[day]
    
#     if feat == 'closeness':
#         # closeness 
#         features = {k: v for k , v in nx.closeness_centrality(G, distance='rel_value').items()}
#     elif feat == 'vanilla': 
#         # only use the rel value 
#         features = {k : v for k , v in enumerate(list(nx.get_edge_attributes(G, 'rel_value').values()))}
#     else:
#         # original version of graph2vec
#         features = nx.degree(G)
#         features = {int(k) : v for k, v in enumerate(list(dict(features).values()))}
#         features = {int(k) : v for k, v in enumerate(list(dict(nx.degree(G)).values()))}
#         features = {k : v for k , v in dict(nx.degree(G)).items()}
    
#     return G, features, day

# lab2