In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset

from model.GCN_GRU import GCN_GRU_singleGCN
from utils import *

In [2]:
shop_col = 'stonc_6_label'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

embedding_dim = 64
entity_dim = 8
epochs = 1000
learning_rate = 0.0001
batch_size = 1000

weights_path = 'GCN_GRU_Link_Prediction_without_newf'

In [3]:
data_path = './data'

chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log_without_shop.csv'

chid_path = os.path.join(data_path, chid_dict_file)
cdtx_path = os.path.join(data_path, cdtx_file)
cust_path = os.path.join(data_path, cust_file)

In [4]:
df_cdtx, df_cust, n_users, n_shops = read_sample_files(cdtx_path,
                                                       cust_path,
                                                       chid_path,
                                                       shop_col)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start maping encodeing...
Finish !!


In [5]:
list_months = sorted(df_cdtx.csmdt.unique())

ignore_cols = ['chid', 'data_dt']
category_cols = [f'category_{i+1}' for i in range(6)]
numeric_cols = list(set(df_cust.columns) - set(category_cols) - set(ignore_cols))

emb_dims = list(df_cust[category_cols].nunique())

input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [6]:
edge_dict = {}
edge_weights_dict = {}
for month in tqdm(list_months):
    edges = df_cdtx[df_cdtx.csmdt==month].groupby(['chid', shop_col]).objam.sum()
    edge_pairs = np.stack([np.array(i) for i in edges.index]).T
    edge_pairs = torch.LongTensor(edge_pairs)
    
    edge_weights = np.log(edges.values+1)
    edge_weights = torch.Tensor(edge_weights)
    
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat([edge_pairs, edge_pairs[[1,0],:]], -1)
        edge_weights = edge_weights.repeat(2)
    
    edge_dict[month] = edge_pairs
    edge_weights_dict[month] = edge_weights

100%|██████████| 25/25 [00:18<00:00,  1.32it/s]


In [7]:
# neg_edges_dict = {}
# for i in tqdm(list(edge_dict.values())):
#     neg_edges_dict[i] = sample_neg_edges(i, n_users+n_shops, n_users)

neg_edges_dict = np.load('neg_edges', allow_pickle=True)

In [8]:
model = GAE(GCN_GRU_singleGCN(input_dim, embedding_dim, entity_dim, emb_dims)).to(device)

x_num = []
x_cat = []
for i in list_months:
    cust_num_features = df_cust[df_cust.data_dt==i][numeric_cols].to_numpy()
    cust_num_features = torch.Tensor(cust_num_features)
    cust_cat_features = df_cust[df_cust.data_dt==i][category_cols].to_numpy()
    cust_cat_features = torch.LongTensor(cust_cat_features)
    
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1]).long()
    
    x_num.append(torch.cat([cust_num_features, shop_num_features], 0).to(device))
    x_cat.append(torch.cat([cust_cat_features, shop_cat_features], 0).to(device))

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
pos_edge_index = [i.to(device) for i in edge_dict.values()]
neg_edges_index = [i.to(device) for i in neg_edges_dict.values()]
pos_edge_weigts = [i.to(device) for i in edge_weights_dict.values()]

In [9]:
def train():
    model.train()
    loss_ = 0
    for i in tqdm(range(10)):
        optimizer.zero_grad()
        z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
        loss = model.recon_loss(z, pos_edge_index[i+12], neg_edges_index[i+12])
        loss_ += loss.item()
        loss.backward()
        optimizer.step()
    return loss_/10

In [10]:
def test(pos_edge_index, neg_edge_index):
    model.eval()
    total_auc = 0
    total_ap = 0
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            auc, ap = model.test(z, pos_edge_index[i+12], neg_edges_index[i])
            
            total_auc += auc
            total_ap += ap
            
    return total_auc/2, total_ap/2

In [None]:
for epoch in range(epochs):
    loss = train()
    
    auc, ap = test(pos_edge_index, neg_edges_index)
    print('Epoch: {:03d}, Train Loss:{:.4f}, AUC: {:.4f}, AP: {:.4f}'.format(epoch+1, loss, auc, ap))

100%|██████████| 10/10 [00:03<00:00,  2.71it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 001, Train Loss:2.2981, AUC: 0.5820, AP: 0.7145


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 002, Train Loss:1.4209, AUC: 0.7949, AP: 0.8723


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 003, Train Loss:1.1078, AUC: 0.9685, AP: 0.9792


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 004, Train Loss:0.9186, AUC: 0.9836, AP: 0.9895


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 005, Train Loss:0.8153, AUC: 0.9795, AP: 0.9867


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 006, Train Loss:0.7411, AUC: 0.9795, AP: 0.9866


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 007, Train Loss:0.6866, AUC: 0.9837, AP: 0.9894


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 008, Train Loss:0.6422, AUC: 0.9865, AP: 0.9912


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 009, Train Loss:0.6006, AUC: 0.9895, AP: 0.9931


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 010, Train Loss:0.5601, AUC: 0.9921, AP: 0.9947


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 011, Train Loss:0.5204, AUC: 0.9939, AP: 0.9959


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 012, Train Loss:0.4804, AUC: 0.9953, AP: 0.9969


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 013, Train Loss:0.4423, AUC: 0.9964, AP: 0.9975


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 014, Train Loss:0.4054, AUC: 0.9971, AP: 0.9980


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 015, Train Loss:0.3696, AUC: 0.9977, AP: 0.9984


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 016, Train Loss:0.3354, AUC: 0.9982, AP: 0.9987


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 017, Train Loss:0.3032, AUC: 0.9985, AP: 0.9990


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 018, Train Loss:0.2733, AUC: 0.9988, AP: 0.9992


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 019, Train Loss:0.2458, AUC: 0.9990, AP: 0.9993


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 020, Train Loss:0.2206, AUC: 0.9992, AP: 0.9995


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 021, Train Loss:0.1976, AUC: 0.9994, AP: 0.9996


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 022, Train Loss:0.1778, AUC: 0.9995, AP: 0.9996


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 023, Train Loss:0.1610, AUC: 0.9995, AP: 0.9997


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 024, Train Loss:0.1460, AUC: 0.9996, AP: 0.9997


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 025, Train Loss:0.1334, AUC: 0.9997, AP: 0.9998


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 026, Train Loss:0.1229, AUC: 0.9997, AP: 0.9998


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 027, Train Loss:0.1137, AUC: 0.9997, AP: 0.9998


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 028, Train Loss:0.1057, AUC: 0.9998, AP: 0.9998


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 029, Train Loss:0.0985, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 030, Train Loss:0.0921, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 031, Train Loss:0.0866, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 032, Train Loss:0.0814, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 033, Train Loss:0.0765, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 034, Train Loss:0.0716, AUC: 0.9998, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 035, Train Loss:0.0676, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 036, Train Loss:0.0633, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 037, Train Loss:0.0593, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 038, Train Loss:0.0552, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 039, Train Loss:0.0522, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 040, Train Loss:0.0495, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 041, Train Loss:0.0477, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.88it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 042, Train Loss:0.0455, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 043, Train Loss:0.0441, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 044, Train Loss:0.0421, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 045, Train Loss:0.0409, AUC: 0.9999, AP: 0.9999


100%|██████████| 10/10 [00:03<00:00,  2.88it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 046, Train Loss:0.0390, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 047, Train Loss:0.0381, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 048, Train Loss:0.0362, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 049, Train Loss:0.0358, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 050, Train Loss:0.0336, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 051, Train Loss:0.0338, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 052, Train Loss:0.0312, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.86it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 053, Train Loss:0.0323, AUC: 0.9999, AP: 1.0000


100%|██████████| 10/10 [00:03<00:00,  2.87it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 054, Train Loss:0.0288, AUC: 0.9999, AP: 1.0000


 80%|████████  | 8/10 [00:02<00:00,  2.86it/s]

In [None]:
torch.save(model.encoder.state_dict(), weights_path)