In [1]:
import os
import copy
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import roc_auc_score, f1_score
from sklearn.preprocessing import MinMaxScaler

from model.GCN_GRU import GCN_GRU, Decoder
from utils import *

In [10]:
shop_col = 'stonc_6_label'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

embedding_dim = 64
entity_dim = 8
epochs = 100
learning_rate = 0.0001
batch_size = 1000

weights_path = 'GCN_GRU_Link_Prediction'

In [3]:
data_path = './data'

chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log.csv'

chid_path = os.path.join(data_path, chid_dict_file)
cdtx_path = os.path.join(data_path, cdtx_file)
cust_path = os.path.join(data_path, cust_file)

In [4]:
df_cdtx, df_cust, n_users, n_shops = read_sample_files(cdtx_path,
                                                       cust_path,
                                                       chid_path,
                                                       shop_col)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start mapping encoding...
Finish !!


In [5]:
shopping_y = (df_cust[['objam']].to_numpy() > 0 ).astype(int)

In [6]:
list_months = sorted(df_cdtx.csmdt.unique())

ignore_cols = ['chid', 'data_dt']
category_cols = [f'category_{i+1}' for i in range(6)]
numeric_cols = list(set(df_cust.columns) - set(category_cols) - set(ignore_cols))

emb_dims = list(df_cust[category_cols].nunique())

input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [7]:
edge_dict = {}
edge_weights_dict = {}
for month in tqdm(list_months):
    edges = df_cdtx[df_cdtx.csmdt==month].groupby(['chid', shop_col]).objam.sum()
    edge_pairs = np.stack([np.array(i) for i in edges.index]).T
    edge_pairs = torch.LongTensor(edge_pairs)
    
    edge_weights = np.log(edges.values+1)
    edge_weights = torch.Tensor(edge_weights)
    
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat([edge_pairs, edge_pairs[[1,0],:]], -1)
        edge_weights = edge_weights.repeat(2)
    
    edge_dict[month] = edge_pairs
    edge_weights_dict[month] = edge_weights

100%|██████████| 25/25 [00:18<00:00,  1.34it/s]


In [8]:
model = GAE(GCN_GRU(input_dim, embedding_dim, entity_dim, emb_dims), Decoder(embedding_dim, 1)).to(device)

x_num = []
x_cat = []
y = []
for i in list_months:
    cust_num_features = df_cust[df_cust.data_dt==i][numeric_cols].to_numpy()
    cust_num_features = torch.Tensor(cust_num_features)
    cust_cat_features = df_cust[df_cust.data_dt==i][category_cols].to_numpy()
    cust_cat_features = torch.LongTensor(cust_cat_features)
    
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1]).long()
    
    temp_y = shopping_y[df_cust.data_dt==i]
    x_num.append(torch.cat([cust_num_features, shop_num_features], 0).to(device))
    x_cat.append(torch.cat([cust_cat_features, shop_cat_features], 0).to(device))
    y.append(torch.Tensor(temp_y).to(device))
    
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
pos_edge_index = [i.to(device) for i in edge_dict.values()]
pos_edge_weigts = [i.to(device) for i in edge_weights_dict.values()]

In [11]:
model.encoder.load_state_dict(torch.load(weights_path))
model.eval()

criterion = torch.nn.BCEWithLogitsLoss()

In [12]:
def train():
    model.train()
    train_output = np.array([])
    train_y = np.array([])
    for i in tqdm(range(10)):
        
        train_dataset = TensorDataset(y[i+12])
        train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=batch_size)
        
        for j, true_y in enumerate(train_loader):
            optimizer.zero_grad()
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[j*batch_size:(j+1)*batch_size])
            loss = criterion(output, true_y[0].float())
            train_output = np.concatenate([train_output, output.cpu().detach().numpy().reshape(-1)])
            train_y = np.concatenate([train_y, true_y[0].cpu().detach().numpy().reshape(-1)])
        
            loss.backward(retain_graph=True)
            optimizer.step()
        
            
    return loss/10, train_output, train_y

In [13]:
def test():
    model.eval()
    test_output = np.array([])
    test_y = np.array([])
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[:n_users])
            test_output = np.concatenate([test_output, output.cpu().detach().numpy().reshape(-1)])
            test_y = np.concatenate([test_y, y[i+12].cpu().detach().numpy().reshape(-1)])
            
    return test_output, test_y

In [14]:
for epoch in range(epochs):
    loss, train_output, train_y  = train()
    
    test_output, test_y = test()
    
    train_AUC = roc_auc_score(train_y, train_output)
    test_AUC = roc_auc_score(test_y, test_output)
    train_F1 = f1_score(train_y, train_output>0.5)
    test_F1 = f1_score(test_y, test_output>0.5)
    
    print(f'epoch:{epoch+1}\ntrain AUC:{train_AUC},test AUC:{test_AUC}')
    print(f'epoch:{epoch+1}\ntrain F1:{train_F1},test F1:{test_F1}')
    

100%|██████████| 10/10 [02:03<00:00, 12.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:1
train AUC:0.846231625575033,test AUC:0.9138842834730228
epoch:1
train F1:0.8238176872614231,test F1:0.8849011165438226


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:3
train AUC:0.9062230237172493,test AUC:0.9174353208661582
epoch:3
train F1:0.8755921083795178,test F1:0.8845214987353686


100%|██████████| 10/10 [02:02<00:00, 12.27s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:4
train AUC:0.9070277983957232,test AUC:0.9179871905229334
epoch:4
train F1:0.8758187352686874,test F1:0.8849545815327631


100%|██████████| 10/10 [02:02<00:00, 12.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:5
train AUC:0.9075918963141386,test AUC:0.9183771681222891
epoch:5
train F1:0.8759973463331417,test F1:0.8854669222126759


100%|██████████| 10/10 [02:02<00:00, 12.27s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:6
train AUC:0.9080448408515106,test AUC:0.9186854379032326
epoch:6
train F1:0.8761879655040145,test F1:0.8859802936148671


100%|██████████| 10/10 [02:02<00:00, 12.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:7
train AUC:0.9084267121519797,test AUC:0.9189445696115903
epoch:7
train F1:0.8764146458146035,test F1:0.8863884592548068


100%|██████████| 10/10 [02:02<00:00, 12.27s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:8
train AUC:0.9087638032088422,test AUC:0.9191682277602592
epoch:8
train F1:0.8766123072229289,test F1:0.8868672218926027


100%|██████████| 10/10 [02:02<00:00, 12.26s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:9
train AUC:0.909069449566519,test AUC:0.919364633110046
epoch:9
train F1:0.8768055105180474,test F1:0.8870861460514562


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:10
train AUC:0.9093512888066035,test AUC:0.9195455989146055
epoch:10
train F1:0.8770645406720803,test F1:0.8871012604488427


100%|██████████| 10/10 [02:02<00:00, 12.29s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:11
train AUC:0.909619501478698,test AUC:0.9197207132250096
epoch:11
train F1:0.8773568433170859,test F1:0.887309718969555


100%|██████████| 10/10 [02:02<00:00, 12.26s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:12
train AUC:0.9098874392973219,test AUC:0.9199030422448106
epoch:12
train F1:0.8776451445898973,test F1:0.8874285588835794


100%|██████████| 10/10 [02:02<00:00, 12.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:13
train AUC:0.9101635938298571,test AUC:0.9200907275235483
epoch:13
train F1:0.8778786032452129,test F1:0.8875000914752396


100%|██████████| 10/10 [02:02<00:00, 12.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:14
train AUC:0.9104538100035482,test AUC:0.9202930277545062
epoch:14
train F1:0.8781267465276442,test F1:0.8873859036576708


100%|██████████| 10/10 [02:03<00:00, 12.31s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:15
train AUC:0.9107647271925836,test AUC:0.9205130773499403
epoch:15
train F1:0.8783248376757774,test F1:0.8874343824814953


100%|██████████| 10/10 [02:02<00:00, 12.28s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:16
train AUC:0.911105652331118,test AUC:0.9207572925543327
epoch:16
train F1:0.8785601523097755,test F1:0.8874242990106694


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:17
train AUC:0.9114783201593131,test AUC:0.9210445090855048
epoch:17
train F1:0.8788289885913577,test F1:0.8876325502372723


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:18
train AUC:0.9118756459180312,test AUC:0.9213693311328206
epoch:18
train F1:0.8792820474941831,test F1:0.887915757407109


100%|██████████| 10/10 [02:03<00:00, 12.31s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:19
train AUC:0.9122694107474383,test AUC:0.9216795406249887
epoch:19
train F1:0.8798552918158662,test F1:0.888301649232481


100%|██████████| 10/10 [02:02<00:00, 12.27s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:20
train AUC:0.91262848343039,test AUC:0.9219560326529243
epoch:20
train F1:0.8804136371553147,test F1:0.8888059537661084


100%|██████████| 10/10 [02:02<00:00, 12.29s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:21
train AUC:0.9129437878484844,test AUC:0.9221772280491064
epoch:21
train F1:0.8808905796470559,test F1:0.889235249782305


100%|██████████| 10/10 [02:03<00:00, 12.31s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:22
train AUC:0.9132116720439227,test AUC:0.9223536346214003
epoch:22
train F1:0.8810857554364508,test F1:0.8897683877858574


100%|██████████| 10/10 [02:03<00:00, 12.31s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:23
train AUC:0.9134464468417751,test AUC:0.9224829221754395
epoch:23
train F1:0.8811914798153331,test F1:0.8898496653132887


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:24
train AUC:0.9136554825869322,test AUC:0.9225778253996221
epoch:24
train F1:0.8812118689228308,test F1:0.8899033572562534


100%|██████████| 10/10 [02:03<00:00, 12.31s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:25
train AUC:0.9138464224134868,test AUC:0.9226446109610943
epoch:25
train F1:0.8812461338817377,test F1:0.8898983913797266


100%|██████████| 10/10 [02:02<00:00, 12.30s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:26
train AUC:0.9140223077528435,test AUC:0.9226930982284435
epoch:26
train F1:0.881270191893121,test F1:0.8901258045640725


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:27
train AUC:0.9141907735775914,test AUC:0.9227260503073647
epoch:27
train F1:0.8812979652472078,test F1:0.8901776376893205


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:28
train AUC:0.9143505315646505,test AUC:0.9227504148733566
epoch:28
train F1:0.88132674431313,test F1:0.8901223462992621


  0%|          | 0/10 [00:06<?, ?it/s]


KeyboardInterrupt: 