In [1]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

from model.GCN_GRU import GCN_GRU, Decoder
from utils import *

In [28]:
shop_col = 'stonc_6_label'
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

embedding_dim = 64
entity_dim = 8
epochs = 1000
learning_rate = 0.001
batch_size = 1000

weights_path = 'GCN_GRU_Link_Prediction'

In [3]:
data_path = './data'

chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log_without_shop.csv'

chid_path = os.path.join(data_path, chid_dict_file)
cdtx_path = os.path.join(data_path, cdtx_file)
cust_path = os.path.join(data_path, cust_file)

In [4]:
df_cdtx, df_cust, n_users, n_shops = read_sample_files(cdtx_path,
                                                       cust_path,
                                                       chid_path,
                                                       shop_col)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start maping encodeing...
Finish !!


In [10]:
list_months = sorted(df_cdtx.csmdt.unique())
ma = df_cdtx.groupby(['chid', 'csmdt']).objam.count().max()

ignore_cols = ['chid', 'data_dt']
category_cols = [f'category_{i+1}' for i in range(6)]
numeric_cols = list(set(df_cust.columns) - set(category_cols) - set(ignore_cols))

emb_dims = list(df_cust[category_cols].nunique())

input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [9]:
edge_dict = {}
edge_weights_dict = {}
for month in tqdm(list_months):
    edges = df_cdtx[df_cdtx.csmdt==month].groupby(['chid', shop_col]).objam.sum()
    edge_pairs = np.stack([np.array(i) for i in edges.index]).T
    edge_pairs = torch.LongTensor(edge_pairs)
    
    edge_weights = np.log(edges.values+1)
    edge_weights = torch.Tensor(edge_weights)
    
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat([edge_pairs, edge_pairs[[1,0],:]], -1)
        edge_weights = edge_weights.repeat(2)
    
    edge_dict[month] = edge_pairs
    edge_weights_dict[month] = edge_weights

100%|██████████| 25/25 [00:21<00:00,  1.19it/s]


In [11]:
temp_cdtx = df_cdtx.groupby(['csmdt', 'chid']).count()

df_cdtx_count = pd.DataFrame(list(map(list, temp_cdtx.index)), columns=['data_dt','chid' ])
df_cdtx_count['count'] = temp_cdtx.objam.values

df_y = pd.DataFrame({'chid':df_cust.chid, 'data_dt':df_cust.data_dt})
df_y = df_y.merge(df_cdtx_count,
                  how='left',
                  left_on=['chid', 'data_dt'],
                  right_on=['chid', 'data_dt']).fillna(0)

y_scaler = MinMaxScaler()
df_y['count'] = y_scaler.fit_transform(df_y[['count']])

In [12]:
model = GAE(GCN_GRU(input_dim, embedding_dim, entity_dim, emb_dims), Decoder(embedding_dim, 1)).to(device)

x_num = []
x_cat = []
y = []
for i in list_months:
    cust_num_features = df_cust[df_cust.data_dt==i][numeric_cols].to_numpy()
    cust_num_features = torch.Tensor(cust_num_features)
    cust_cat_features = df_cust[df_cust.data_dt==i][category_cols].to_numpy()
    cust_cat_features = torch.LongTensor(cust_cat_features)
    
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1]).long()
    
    temp_y = df_y[df_y.data_dt==i][['count']].to_numpy()
    x_num.append(torch.cat([cust_num_features, shop_num_features], 0).to(device))
    x_cat.append(torch.cat([cust_cat_features, shop_cat_features], 0).to(device))
    y.append(torch.Tensor(temp_y).to(device))
    
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
pos_edge_index = [i.to(device) for i in edge_dict.values()]
pos_edge_weigts = [i.to(device) for i in edge_weights_dict.values()]

In [29]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [13]:
model.encoder.load_state_dict(torch.load(weights_path))
model.eval()

criterion = torch.nn.MSELoss()

In [30]:
def train():
    model.train()
    train_output = np.array([])
    train_y = np.array([])
    for i in tqdm(range(10)):
        
        train_dataset = TensorDataset(y[i+12])
        train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=batch_size)
        
        for j, true_y in enumerate(train_loader):
            optimizer.zero_grad()
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[j*batch_size:(j+1)*batch_size])
            loss = criterion(output, true_y[0])
            train_output = np.concatenate([train_output, output.cpu().detach().numpy().reshape(-1)])
            train_y = np.concatenate([train_y, true_y[0].cpu().detach().numpy().reshape(-1)])
        
            loss.backward()
            optimizer.step()
        
            
    return loss/10, train_output, train_y

In [31]:
def test():
    model.eval()
    test_output = np.array([])
    test_y = np.array([])
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[:n_users])
            test_output = np.concatenate([test_output, output.cpu().detach().numpy().reshape(-1)])
            test_y = np.concatenate([test_y, y[i+12].cpu().detach().numpy().reshape(-1)])
            
    return test_output, test_y

In [27]:
for epoch in range(epochs):
    loss, train_output, train_y  = train()
    
    test_output, test_y = test()
    
    train_RMSE = mean_squared_error(train_output*ma, train_y*ma, squared=False)
    test_RMSE = mean_squared_error(test_output*ma, test_y*ma, squared=False)
    
    train_MAE = mean_absolute_error(train_output*ma, train_y*ma)
    test_MAE = mean_absolute_error(test_output*ma, test_y*ma)
    
    print(f'epoch:{epoch+1}\ntrain loss:{train_RMSE},test loss:{test_RMSE}\
    \ntrain MAE(mean):{train_MAE},test MAE(mean):{test_MAE}')
    

100%|██████████| 10/10 [02:05<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:1
train loss:11.648646227250044,test loss:16.946317691194288    
train MAE(mean):6.83003429597457,test MAE(mean):6.438409212848208


100%|██████████| 10/10 [02:05<00:00, 12.59s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:2
train loss:10.34275565127834,test loss:16.431955521573624    
train MAE(mean):5.513244531307457,test MAE(mean):5.487134947807025


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:3
train loss:9.750953488761429,test loss:16.18783474804775    
train MAE(mean):4.846725250562114,test MAE(mean):4.948956429732638


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:4
train loss:9.432450708839184,test loss:16.057675371098743    
train MAE(mean):4.450686481381176,test MAE(mean):4.607430673660683


100%|██████████| 10/10 [02:05<00:00, 12.52s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:5
train loss:9.242426795498785,test loss:15.937344615960228    
train MAE(mean):4.197609081170162,test MAE(mean):4.366094538023994


100%|██████████| 10/10 [02:04<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:6
train loss:9.124709675198295,test loss:15.8646530882119    
train MAE(mean):4.040115816384196,test MAE(mean):4.275738317358811


100%|██████████| 10/10 [02:05<00:00, 12.58s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:7
train loss:9.031311686356815,test loss:15.820259350489833    
train MAE(mean):3.9085403072984386,test MAE(mean):4.081699595335657


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:8
train loss:8.955310950394814,test loss:15.80973780175405    
train MAE(mean):3.790455284480054,test MAE(mean):3.962258344578948


100%|██████████| 10/10 [02:05<00:00, 12.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:9
train loss:8.901948628762424,test loss:15.80914139732043    
train MAE(mean):3.706838469485395,test MAE(mean):3.926177224308401


100%|██████████| 10/10 [02:05<00:00, 12.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:10
train loss:8.860661400506228,test loss:15.794951820824695    
train MAE(mean):3.639505651445618,test MAE(mean):3.8789895629319364


100%|██████████| 10/10 [02:05<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:11
train loss:8.825200131576379,test loss:15.77937059062488    
train MAE(mean):3.583417381818721,test MAE(mean):3.8291531788622493


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:12
train loss:8.798357213125371,test loss:15.765559756505048    
train MAE(mean):3.5385987902820055,test MAE(mean):3.7965738707963284


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:13
train loss:8.773598416863942,test loss:15.75493612270905    
train MAE(mean):3.497618348638212,test MAE(mean):3.7612187829432915


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:14
train loss:8.753875226818078,test loss:15.741724540486633    
train MAE(mean):3.465711737644691,test MAE(mean):3.7310334345402567


100%|██████████| 10/10 [02:05<00:00, 12.57s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:15
train loss:8.735727861018844,test loss:15.731189487027192    
train MAE(mean):3.4384235090451947,test MAE(mean):3.705632501527257


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:16
train loss:8.719622064355741,test loss:15.720239613429127    
train MAE(mean):3.41449417528533,test MAE(mean):3.685983410431314


100%|██████████| 10/10 [02:05<00:00, 12.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:17
train loss:8.705535958716805,test loss:15.711764635174092    
train MAE(mean):3.391713066253176,test MAE(mean):3.6654692448911907


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:18
train loss:8.692276285907115,test loss:15.699457728285928    
train MAE(mean):3.3729527090243137,test MAE(mean):3.6440900367145055


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:19
train loss:8.679848856103165,test loss:15.694699957869435    
train MAE(mean):3.354416757794056,test MAE(mean):3.620259435217744


100%|██████████| 10/10 [02:05<00:00, 12.52s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:20
train loss:8.668003730951032,test loss:15.679480354712169    
train MAE(mean):3.3378886165594106,test MAE(mean):3.596410153443292


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:21
train loss:8.659081840659253,test loss:15.66983453279592    
train MAE(mean):3.3248024794254714,test MAE(mean):3.577147020524321


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:22
train loss:8.650081072203845,test loss:15.664378850075222    
train MAE(mean):3.3142215158062576,test MAE(mean):3.5618418875311315


100%|██████████| 10/10 [02:05<00:00, 12.59s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:23
train loss:8.641492537356777,test loss:15.649539895759897    
train MAE(mean):3.300303124111386,test MAE(mean):3.5409011089350098


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:24
train loss:8.636977451206224,test loss:15.637843809025503    
train MAE(mean):3.2918514815624933,test MAE(mean):3.5338339759832156


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:25
train loss:8.627850900199062,test loss:15.617456105381812    
train MAE(mean):3.2823938586133345,test MAE(mean):3.495674709372176


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:26
train loss:8.62246096651255,test loss:15.602527543543728    
train MAE(mean):3.2783358086613807,test MAE(mean):3.4877243953971004


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:27
train loss:8.620253546019377,test loss:15.588440487391193    
train MAE(mean):3.280158451513192,test MAE(mean):3.525075317158755


100%|██████████| 10/10 [02:05<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:28
train loss:8.62115276471443,test loss:15.584503728264075    
train MAE(mean):3.286884260911433,test MAE(mean):3.58917033951655


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:29
train loss:8.612588503945538,test loss:15.57897537804162    
train MAE(mean):3.2702999647358935,test MAE(mean):3.5378094053335207


100%|██████████| 10/10 [02:05<00:00, 12.56s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:30
train loss:8.602321795615607,test loss:15.575550999872226    
train MAE(mean):3.2524148576881755,test MAE(mean):3.5213822896759397


  0%|          | 0/10 [00:11<?, ?it/s]


KeyboardInterrupt: 

In [32]:
for epoch in range(epochs):
    loss, train_output, train_y  = train()
    
    test_output, test_y = test()
    
    train_RMSE = mean_squared_error(train_output*ma, train_y*ma, squared=False)
    test_RMSE = mean_squared_error(test_output*ma, test_y*ma, squared=False)
    
    train_MAE = mean_absolute_error(train_output*ma, train_y*ma)
    test_MAE = mean_absolute_error(test_output*ma, test_y*ma)
    
    print(f'epoch:{epoch+1}\ntrain loss:{train_RMSE},test loss:{test_RMSE}\
    \ntrain MAE(mean):{train_MAE},test MAE(mean):{test_MAE}')
    

100%|██████████| 10/10 [02:02<00:00, 12.26s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:1
train loss:14.316886125658321,test loss:15.72129677653311    
train MAE(mean):5.711408838745013,test MAE(mean):3.7261303053996153


100%|██████████| 10/10 [02:02<00:00, 12.25s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:2
train loss:8.684745085327988,test loss:15.645626224464534    
train MAE(mean):3.3524675252033584,test MAE(mean):3.921499982777694


100%|██████████| 10/10 [02:02<00:00, 12.23s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:3
train loss:8.725475904576461,test loss:16.665666067066955    
train MAE(mean):3.4652809319987874,test MAE(mean):7.394149254927402


100%|██████████| 10/10 [02:02<00:00, 12.23s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:4
train loss:8.781046243663768,test loss:16.682187081215325    
train MAE(mean):3.5790155487974453,test MAE(mean):7.434341958706854


100%|██████████| 10/10 [02:02<00:00, 12.26s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:5
train loss:8.733743802418836,test loss:16.49743289297601    
train MAE(mean):3.4973764664376805,test MAE(mean):6.9940795610826


100%|██████████| 10/10 [02:02<00:00, 12.27s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:6
train loss:8.713782531694132,test loss:15.627515791361136    
train MAE(mean):3.460601090174191,test MAE(mean):3.4072855171816796


 70%|███████   | 7/10 [01:30<00:38, 12.93s/it]


KeyboardInterrupt: 