In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

from model.GCN_GRU import SingleGCN_GRU, Decoder
from utils import *

In [2]:
shop_col = 'stonc_6_label'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

embedding_dim = 64
entity_dim = 8
epochs = 1000
learning_rate = 0.01
batch_size = 1000


weights_path = 'SingleGCN_GRU_Link_Prediction'

In [3]:
data_path = './data'

chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log_without_shop.csv'

chid_path = os.path.join(data_path, chid_dict_file)
cdtx_path = os.path.join(data_path, cdtx_file)
cust_path = os.path.join(data_path, cust_file)

In [4]:
df_cdtx, df_cust, n_users, n_shops = read_sample_files(cdtx_path,
                                                       cust_path,
                                                       chid_path,
                                                       shop_col)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start maping encodeing...
Finish !!


In [None]:
list_months = sorted(df_cdtx.csmdt.unique())
ma = np.log(df_cdtx.groupby(['chid', 'csmdt']).objam.sum().max()+1)

ignore_cols = ['chid', 'data_dt']
category_cols = [f'category_{i+1}' for i in range(6)]
numeric_cols = list(set(df_cust.columns) - set(category_cols) - set(ignore_cols))

emb_dims = list(df_cust[category_cols].nunique())

input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [6]:
edge_dict = {}
edge_weights_dict = {}
for month in tqdm(list_months):
    edges = df_cdtx[df_cdtx.csmdt==month].groupby(['chid', shop_col]).objam.sum()
    edge_pairs = np.stack([np.array(i) for i in edges.index]).T
    edge_pairs = torch.LongTensor(edge_pairs)
    
    edge_weights = np.log(edges.values+1)
    edge_weights = torch.Tensor(edge_weights)/edge_weights.max()
    
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat([edge_pairs, edge_pairs[[1,0],:]], -1)
        edge_weights = edge_weights.repeat(2)
    
    edge_dict[month] = edge_pairs
    edge_weights_dict[month] = edge_weights

100%|██████████| 25/25 [00:18<00:00,  1.35it/s]


In [7]:
model = GAE(SingleGCN_GRU(input_dim, embedding_dim, entity_dim, emb_dims), Decoder(embedding_dim, 1)).to(device)

x_num = []
x_cat = []
y = []
for i in list_months:
    cust_num_features = df_cust[df_cust.data_dt==i][numeric_cols].to_numpy()
    cust_num_features = torch.Tensor(cust_num_features)
    cust_cat_features = df_cust[df_cust.data_dt==i][category_cols].to_numpy()
    cust_cat_features = torch.LongTensor(cust_cat_features)
    
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1]).long()
    
    temp_y = df_cust[df_cust.data_dt==i][['objam']].to_numpy()
    x_num.append(torch.cat([cust_num_features, shop_num_features], 0).to(device))
    x_cat.append(torch.cat([cust_cat_features, shop_cat_features], 0).to(device))
    y.append(torch.Tensor(temp_y*ma).to(device))
    
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
pos_edge_index = [i.to(device) for i in edge_dict.values()]
pos_edge_weigts = [i.to(device) for i in edge_weights_dict.values()]

In [8]:
model.encoder.load_state_dict(torch.load(weights_path))
model.eval()

criterion = torch.nn.MSELoss()

In [9]:
def train():
    model.train()
    train_output = np.array([])
    train_y = np.array([])
    for i in tqdm(range(10)):
        
        train_dataset = TensorDataset(y[i+12])
        train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=batch_size)
        
        for j, true_y in enumerate(train_loader):
            optimizer.zero_grad()
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[j*batch_size:(j+1)*batch_size])
            loss = criterion(output, true_y[0])
            train_output = np.concatenate([train_output, output.cpu().detach().numpy().reshape(-1)])
            train_y = np.concatenate([train_y, true_y[0].cpu().detach().numpy().reshape(-1)])
        
            loss.backward()
            optimizer.step()
        
            
    return loss/10, train_output, train_y

In [10]:
def test():
    model.eval()
    test_output = np.array([])
    test_y = np.array([])
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[:n_users])
            test_output = np.concatenate([test_output, output.cpu().detach().numpy().reshape(-1)])
            test_y = np.concatenate([test_y, y[i+12].cpu().detach().numpy().reshape(-1)])
            
    return test_output, test_y

In [11]:
for epoch in range(epochs):
    loss, train_output, train_y  = train()
    
    test_output, test_y = test()
    
    train_RMSE = mean_squared_error(np.e**(train_output), np.e**(train_y), squared=False)
    test_RMSE = mean_squared_error(np.e**(test_output), np.e**(test_y), squared=False)
    
    train_MAE = mean_absolute_error(np.e**(train_output), np.e**(train_y))
    test_MAE = mean_absolute_error(np.e**(test_output), np.e**(test_y))
    
    print(f'epoch:{epoch+1}\ntrain loss:{train_RMSE:.0f},test loss:{test_RMSE:.0f}\ntrain MAE(mean):{train_MAE:.0f},test MAE(mean):{test_MAE:.0f}')
    

100%|██████████| 10/10 [01:58<00:00, 11.90s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:1
train loss:1015595,test loss:608233
train MAE(mean):68858,test MAE(mean):70850


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:2
train loss:1013893,test loss:605378
train MAE(mean):65116,test MAE(mean):69562


100%|██████████| 10/10 [01:58<00:00, 11.89s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:3
train loss:1013343,test loss:607318
train MAE(mean):64671,test MAE(mean):69549


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:4
train loss:1013181,test loss:605292
train MAE(mean):64399,test MAE(mean):70379


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:5
train loss:1012862,test loss:605547
train MAE(mean):64129,test MAE(mean):68973


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:6
train loss:1012605,test loss:603040
train MAE(mean):63821,test MAE(mean):69483


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:7
train loss:1012556,test loss:603311
train MAE(mean):63719,test MAE(mean):69328


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:8
train loss:1012465,test loss:602717
train MAE(mean):63516,test MAE(mean):68929


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:9
train loss:1012338,test loss:603697
train MAE(mean):63561,test MAE(mean):68786


100%|██████████| 10/10 [01:59<00:00, 11.91s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:10
train loss:1012236,test loss:603811
train MAE(mean):63404,test MAE(mean):68232


100%|██████████| 10/10 [01:57<00:00, 11.78s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:11
train loss:1011835,test loss:604174
train MAE(mean):63177,test MAE(mean):68611


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:12
train loss:1011912,test loss:602185
train MAE(mean):63094,test MAE(mean):67757


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:13
train loss:1013581,test loss:605098
train MAE(mean):63141,test MAE(mean):68506


100%|██████████| 10/10 [01:58<00:00, 11.87s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:14
train loss:1011739,test loss:604267
train MAE(mean):62917,test MAE(mean):68102


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:15
train loss:1011645,test loss:606139
train MAE(mean):62869,test MAE(mean):68776


100%|██████████| 10/10 [01:58<00:00, 11.82s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:16
train loss:1011593,test loss:603809
train MAE(mean):62765,test MAE(mean):67423


100%|██████████| 10/10 [01:58<00:00, 11.89s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:17
train loss:1011562,test loss:604400
train MAE(mean):62776,test MAE(mean):68031


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:18
train loss:1011550,test loss:604093
train MAE(mean):62757,test MAE(mean):68011


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:19
train loss:1011711,test loss:605944
train MAE(mean):62897,test MAE(mean):69008


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:20
train loss:1011690,test loss:599324
train MAE(mean):62805,test MAE(mean):67186


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:21
train loss:1011401,test loss:603471
train MAE(mean):62703,test MAE(mean):67506


100%|██████████| 10/10 [01:58<00:00, 11.87s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:22
train loss:1011383,test loss:605058
train MAE(mean):62634,test MAE(mean):67974


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:23
train loss:1011562,test loss:605313
train MAE(mean):62652,test MAE(mean):68545


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:24
train loss:1011629,test loss:602456
train MAE(mean):62649,test MAE(mean):67973


100%|██████████| 10/10 [01:58<00:00, 11.80s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:25
train loss:1011444,test loss:602392
train MAE(mean):62480,test MAE(mean):67185


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:26
train loss:1011567,test loss:602531
train MAE(mean):62490,test MAE(mean):67333


100%|██████████| 10/10 [01:58<00:00, 11.82s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:27
train loss:1011595,test loss:604661
train MAE(mean):62561,test MAE(mean):67528


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:28
train loss:1011171,test loss:604385
train MAE(mean):62390,test MAE(mean):67652


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:29
train loss:1011416,test loss:600462
train MAE(mean):62490,test MAE(mean):66370


100%|██████████| 10/10 [01:58<00:00, 11.82s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:30
train loss:1011121,test loss:603277
train MAE(mean):62384,test MAE(mean):67476


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:31
train loss:1011267,test loss:601368
train MAE(mean):62383,test MAE(mean):66683


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:32
train loss:1011212,test loss:602447
train MAE(mean):62489,test MAE(mean):66806


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:33
train loss:1011209,test loss:603679
train MAE(mean):62270,test MAE(mean):67284


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:34
train loss:1011207,test loss:605248
train MAE(mean):62343,test MAE(mean):68085


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:35
train loss:1011551,test loss:602684
train MAE(mean):62621,test MAE(mean):66893


100%|██████████| 10/10 [01:58<00:00, 11.87s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:36
train loss:1011501,test loss:602653
train MAE(mean):62475,test MAE(mean):66877


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:37
train loss:1011035,test loss:601019
train MAE(mean):62238,test MAE(mean):66757


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:38
train loss:1011106,test loss:602458
train MAE(mean):62317,test MAE(mean):66872


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:39
train loss:1011226,test loss:604046
train MAE(mean):62298,test MAE(mean):67478


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:40
train loss:1011293,test loss:597491
train MAE(mean):62318,test MAE(mean):67415


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:41
train loss:1011207,test loss:604722
train MAE(mean):62454,test MAE(mean):67687


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:42
train loss:1012837,test loss:607764
train MAE(mean):63747,test MAE(mean):69533


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:43
train loss:1012279,test loss:603444
train MAE(mean):62860,test MAE(mean):67036


100%|██████████| 10/10 [01:58<00:00, 11.81s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:44
train loss:1011685,test loss:602011
train MAE(mean):62356,test MAE(mean):66615


100%|██████████| 10/10 [01:58<00:00, 11.89s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:45
train loss:1011391,test loss:601925
train MAE(mean):62269,test MAE(mean):66421


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:46
train loss:1011702,test loss:603712
train MAE(mean):62304,test MAE(mean):67567


100%|██████████| 10/10 [01:58<00:00, 11.87s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:47
train loss:1011272,test loss:603353
train MAE(mean):62182,test MAE(mean):67318


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:48
train loss:1011264,test loss:606202
train MAE(mean):62304,test MAE(mean):69022


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:49
train loss:1011282,test loss:604648
train MAE(mean):62286,test MAE(mean):68028


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:50
train loss:1011397,test loss:602772
train MAE(mean):62332,test MAE(mean):67189


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:51
train loss:1011295,test loss:602079
train MAE(mean):62330,test MAE(mean):66746


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:52
train loss:1010996,test loss:604483
train MAE(mean):62125,test MAE(mean):67775


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:53
train loss:1010758,test loss:604793
train MAE(mean):62155,test MAE(mean):67447


100%|██████████| 10/10 [01:58<00:00, 11.83s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:54
train loss:1010994,test loss:604675
train MAE(mean):62312,test MAE(mean):67987


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:55
train loss:1018098,test loss:604798
train MAE(mean):62428,test MAE(mean):67614


100%|██████████| 10/10 [01:58<00:00, 11.82s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:56
train loss:1011141,test loss:604277
train MAE(mean):62218,test MAE(mean):67434


100%|██████████| 10/10 [01:58<00:00, 11.88s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:57
train loss:1010990,test loss:602556
train MAE(mean):62161,test MAE(mean):67482


100%|██████████| 10/10 [01:58<00:00, 11.86s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:58
train loss:1010994,test loss:604086
train MAE(mean):62168,test MAE(mean):67592


100%|██████████| 10/10 [01:59<00:00, 11.91s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:59
train loss:1010876,test loss:602235
train MAE(mean):62185,test MAE(mean):66789


100%|██████████| 10/10 [01:58<00:00, 11.82s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:60
train loss:1010801,test loss:604733
train MAE(mean):62106,test MAE(mean):67929


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:61
train loss:1010894,test loss:602011
train MAE(mean):62121,test MAE(mean):66687


100%|██████████| 10/10 [01:58<00:00, 11.84s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:62
train loss:1010557,test loss:602102
train MAE(mean):61988,test MAE(mean):66511


100%|██████████| 10/10 [01:58<00:00, 11.85s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:63
train loss:1010518,test loss:604022
train MAE(mean):61965,test MAE(mean):67902


 10%|█         | 1/10 [00:14<02:14, 14.96s/it]


KeyboardInterrupt: 