In [5]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.preprocessing import MinMaxScaler

from model.GCN_GRU import GCN_GRU, Decoder
from utils import *

In [2]:
# Column of the transaction table that identifies the "shop" side of each
# customer-shop edge (it is passed to read_sample_files and used as a groupby key).
shop_col = 'stonc_6_label'

# Run on the second CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

# Model sizes: node-embedding width and per-categorical-feature embedding width.
embedding_dim, entity_dim = 64, 8

# Optimisation settings.
epochs, batch_size = 30, 1000
learning_rate = 0.0001

# File holding the pretrained encoder weights that are loaded before fine-tuning.
weights_path = 'GCN_GRU_Link_Prediction'

In [3]:
# Root directory for every input file used by this notebook.
data_path = './data'

# File locations relative to `data_path`.
chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log.csv'

# Resolve the three relative names against the data root in one pass.
chid_path, cdtx_path, cust_path = (
    os.path.join(data_path, relative_name)
    for relative_name in (chid_dict_file, cdtx_file, cust_file)
)

In [4]:
# Load the transaction table, the customer-feature table and the node counts.
# NOTE(review): read_sample_files is presumably supplied by `from utils import *`;
# its exact contract (index encoding of chid / shop ids) is not visible here.
(df_cdtx, df_cust, n_users, n_shops) = read_sample_files(
    cdtx_path, cust_path, chid_path, shop_col
)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start mapping encoding...
Finish !!


In [6]:
# Distinct months present in the transaction table, in ascending order.
list_months = sorted(df_cdtx.csmdt.unique())

# Columns of df_cust that are identifiers rather than model inputs.
ignore_cols = ['chid', 'data_dt']
# The six categorical customer attributes (category_1 ... category_6).
category_cols = [f'category_{i+1}' for i in range(6)]
# Every remaining column is treated as a numeric feature. The list is derived
# from df_cust's own column order instead of a set difference: iterating a set
# of strings yields a different order in each Python process (hash
# randomisation), which would silently permute the numeric feature axis between
# sessions and relative to the pretrained encoder weights.
# NOTE(review): the notebook that produced `weights_path` must use the same
# column order for the loaded weights to line up with these features.
numeric_cols = [
    col for col in df_cust.columns
    if col not in category_cols and col not in ignore_cols
]

# Cardinality of each categorical column, in `category_cols` order.
emb_dims = list(df_cust[category_cols].nunique())

# Encoder input width: one `entity_dim`-wide embedding per categorical column
# plus one scalar per numeric column.
input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [7]:
# Per-month graph structure: month -> (2, n_edges) index tensor and
# month -> (n_edges,) weight tensor.
edge_dict, edge_weights_dict = {}, {}

for month in tqdm(list_months):
    # Total `objam` per (chid, shop) pair within this month; one edge per pair.
    edges = df_cdtx.loc[df_cdtx.csmdt == month].groupby(['chid', shop_col]).objam.sum()

    # The grouped index holds (chid, shop) tuples; stack them and transpose so
    # that row 0 is the chid endpoint and row 1 is the shop endpoint.
    edge_pairs = torch.LongTensor(np.stack([np.array(pair) for pair in edges.index]).T)

    # Edge weight is log(amount + 1).
    edge_weights = torch.Tensor(np.log(edges.values + 1))

    # Add the reversed copy of every edge (with the same weight) unless the
    # edge list is already symmetric.
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat((edge_pairs, edge_pairs[[1, 0], :]), dim=-1)
        edge_weights = torch.cat((edge_weights, edge_weights))

    edge_dict[month], edge_weights_dict[month] = edge_pairs, edge_weights

100%|██████████| 25/25 [00:19<00:00,  1.30it/s]


In [8]:
# For every (month, customer) count how many distinct `shop_col` values appear
# in the transactions: the index of the 3-key groupby lists each observed
# (month, customer, shop) triple once, and the second groupby counts them.
df_cdtx_count = (
    pd.DataFrame(
        list(df_cdtx.groupby(['csmdt', 'chid', shop_col]).count().index),
        columns=['data_dt', 'chid', 'kind'],
    )
    .groupby(['data_dt', 'chid'])
    .count()
)

# Align the counts with the rows of df_cust; customer-months without any
# transaction get a count of 0.
df_y = (
    pd.DataFrame({'chid': df_cust.chid, 'data_dt': df_cust.data_dt})
    .merge(
        df_cdtx_count,
        how='left',
        left_on=['chid', 'data_dt'],
        right_on=['chid', 'data_dt'],
    )
    .fillna(0)
)

# Largest raw count, kept so that scaled values can be mapped back later.
ma = df_y['kind'].to_numpy().max()

# Regression target: the counts min-max scaled into [0, 1], as an (n, 1) array.
y_scaler = MinMaxScaler()
kind_y = y_scaler.fit_transform(df_y['kind'].to_numpy().reshape(-1, 1))

In [9]:
# Graph auto-encoder wrapper: GCN_GRU is the encoder, Decoder maps each node
# embedding to a single output value.
model = GAE(
    GCN_GRU(input_dim, embedding_dim, entity_dim, emb_dims),
    Decoder(embedding_dim, 1),
).to(device)

# One entry per month: numeric node features, categorical node features, target.
x_num, x_cat, y = [], [], []
for i in list_months:
    # Customer rows belonging to month `i`.
    cust_num_features = torch.Tensor(
        df_cust.loc[df_cust.data_dt == i, numeric_cols].to_numpy()
    )
    cust_cat_features = torch.LongTensor(
        df_cust.loc[df_cust.data_dt == i, category_cols].to_numpy()
    )

    # Shop nodes carry no attributes: all-zero rows with matching widths.
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1], dtype=torch.long)

    # Node feature matrices list customer rows first, then shop rows.
    x_num.append(torch.cat((cust_num_features, shop_num_features), dim=0).to(device))
    x_cat.append(torch.cat((cust_cat_features, shop_cat_features), dim=0).to(device))

    # Scaled targets for the same customer rows.
    temp_y = kind_y[df_cust.data_dt == i]
    y.append(torch.Tensor(temp_y).to(device))

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Monthly edge tensors moved to the training device, in dict insertion order.
pos_edge_index = [edge_index.to(device) for edge_index in edge_dict.values()]
pos_edge_weigts = [edge_weight.to(device) for edge_weight in edge_weights_dict.values()]

In [10]:
# Initialise the encoder from the pretrained weights stored at `weights_path`.
# map_location=device remaps the saved tensors onto the device selected above,
# so a checkpoint written on a GPU can still be loaded when that GPU (or any
# GPU) is not available in this session.
model.encoder.load_state_dict(torch.load(weights_path, map_location=device))
model.eval()

# Mean-squared-error loss on the min-max-scaled target.
criterion = torch.nn.MSELoss()

In [11]:
def train(n_windows=10, window=12):
    """Run one training epoch over sliding windows of consecutive months.

    For window index ``i`` the months ``i .. i+window-1`` are encoded and the
    target is ``y[i+window]`` (the month right after the window). The target is
    consumed in mini-batches of ``batch_size`` rows; for every batch the whole
    window is re-encoded, the matching rows of the embedding are decoded, and
    one optimizer step is taken.

    Args:
        n_windows: number of window start positions to train on (default 10).
        window: number of months per input window (default 12).

    Returns:
        A 3-tuple ``(mean_loss, train_output, train_y)``:
        ``mean_loss`` is the average of the per-batch losses as a detached
        0-dim tensor (``0.0`` if no batch was processed); ``train_output`` and
        ``train_y`` are flat float64 arrays holding every prediction and its
        target in processing order.
    """
    model.train()

    # Seeded with an empty float64 array so the concatenated result keeps the
    # float64 dtype and concatenation also works when no batch is processed.
    collected_outputs = [np.array([])]
    collected_targets = [np.array([])]
    total_loss = 0.0
    n_batches = 0

    for i in tqdm(range(n_windows)):

        train_dataset = TensorDataset(y[i+window])
        # shuffle=False keeps batch j aligned with rows j*batch_size ... of z.
        train_loader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=batch_size)

        for j, true_y in enumerate(train_loader):
            optimizer.zero_grad()
            # Re-encode on every batch because the weights changed in the previous step.
            z = model.encode(x_cat[i:i+window], x_num[i:i+window],
                             pos_edge_index[i:i+window], pos_edge_weigts[i:i+window])
            output = model.decode(z[j*batch_size:(j+1)*batch_size])
            loss = criterion(output, true_y[0])
            collected_outputs.append(output.detach().cpu().numpy().reshape(-1))
            collected_targets.append(true_y[0].detach().cpu().numpy().reshape(-1))

            # NOTE(review): retain_graph=True is kept from the original. It looks
            # unnecessary since a fresh graph is built per batch - confirm that
            # GCN_GRU does not carry state between calls before removing it.
            loss.backward(retain_graph=True)
            optimizer.step()

            # Accumulate a detached copy so the autograd graph is not kept alive.
            total_loss = total_loss + loss.detach()
            n_batches += 1

    mean_loss = total_loss / max(n_batches, 1)
    return mean_loss, np.concatenate(collected_outputs), np.concatenate(collected_targets)

In [12]:
def test():
    model.eval()
    test_output = np.array([])
    test_y = np.array([])
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[:n_users])
            test_output = np.concatenate([test_output, output.cpu().detach().numpy().reshape(-1)])
            test_y = np.concatenate([test_y, y[i+12].cpu().detach().numpy().reshape(-1)])
            
    return test_output, test_y

In [None]:
# Fine-tune for `epochs` passes; after each pass report RMSE and MAE on the
# training windows and on the held-out windows, in original (unscaled) units.
for epoch in range(epochs):
    loss, train_output, train_y = train()

    test_output, test_y = test()

    # Undo the min-max scaling with the fitted scaler. This equals multiplying
    # by the maximum only when the fitted minimum is 0, so use the scaler to
    # stay correct in the general case.
    train_pred, train_true, test_pred, test_true = (
        y_scaler.inverse_transform(values.reshape(-1, 1)).reshape(-1)
        for values in (train_output, train_y, test_output, test_y)
    )

    # RMSE as sqrt(MSE): the `squared=False` argument of mean_squared_error is
    # deprecated and removed in recent scikit-learn releases.
    train_RMSE = np.sqrt(mean_squared_error(train_true, train_pred))
    test_RMSE = np.sqrt(mean_squared_error(test_true, test_pred))

    train_MAE = mean_absolute_error(train_true, train_pred)
    test_MAE = mean_absolute_error(test_true, test_pred)

    # The values labelled "loss" below are the RMSE figures computed above.
    print(f'epoch:{epoch+1}\ntrain loss:{train_RMSE},test loss:{test_RMSE}\
    \ntrain MAE(mean):{train_MAE},test MAE(mean):{test_MAE}')
    

100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:2
train loss:2.165236200424559,test loss:2.2496532584648867    
train MAE(mean):1.3666648233055472,test MAE(mean):1.422824231762588


100%|██████████| 10/10 [02:04<00:00, 12.47s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:3
train loss:2.140335429403321,test loss:2.2270054513271687    
train MAE(mean):1.3468143146381377,test MAE(mean):1.4026663377338648


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:4
train loss:2.1244635817875714,test loss:2.212710249339348    
train MAE(mean):1.334296991835773,test MAE(mean):1.3896375727665424


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:5
train loss:2.112836665059525,test loss:2.202797299825402    
train MAE(mean):1.3253864231076837,test MAE(mean):1.376225566983819


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:6
train loss:2.1043388720184604,test loss:2.196684135665813    
train MAE(mean):1.3192394706183672,test MAE(mean):1.3719727956426144


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:7
train loss:2.0982323564236935,test loss:2.1916982191494947    
train MAE(mean):1.314814547093153,test MAE(mean):1.3626246392422914


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:8
train loss:2.093411646721246,test loss:2.1882139678115413    
train MAE(mean):1.3112367295167446,test MAE(mean):1.3581751673260332


100%|██████████| 10/10 [02:03<00:00, 12.39s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:9
train loss:2.0893339507065027,test loss:2.185763064876355    
train MAE(mean):1.308298210579455,test MAE(mean):1.3501053224426507


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:10
train loss:2.085646506025793,test loss:2.184273448310858    
train MAE(mean):1.3055799370622634,test MAE(mean):1.34364918451041


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:11
train loss:2.0824000126176183,test loss:2.1832107627789483    
train MAE(mean):1.3031624523657561,test MAE(mean):1.340680758164227


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:12
train loss:2.0793809278488604,test loss:2.1819506732877247    
train MAE(mean):1.3009779825968741,test MAE(mean):1.3389627171206475


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:13
train loss:2.076811406714045,test loss:2.1805046503589542    
train MAE(mean):1.29925112134552,test MAE(mean):1.3361206945332884


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:14
train loss:2.074390149333967,test loss:2.1797490616273327    
train MAE(mean):1.2975865796735286,test MAE(mean):1.334640366742909


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:15
train loss:2.0721968625696214,test loss:2.179208984814971    
train MAE(mean):1.2961712490513324,test MAE(mean):1.3332998918277026


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:16
train loss:2.070161973325176,test loss:2.178460326802473    
train MAE(mean):1.2949586885204316,test MAE(mean):1.3330144106769561


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:17
train loss:2.0682116821387924,test loss:2.177353201883164    
train MAE(mean):1.2936145561970471,test MAE(mean):1.333752314352393


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:18
train loss:2.066538995800328,test loss:2.176972357561679    
train MAE(mean):1.2926680391278267,test MAE(mean):1.3315290336933732


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:19
train loss:2.064893941609664,test loss:2.1763834123287458    
train MAE(mean):1.2916466643792988,test MAE(mean):1.3324084880134464


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:20
train loss:2.063296411580356,test loss:2.1761857760021988    
train MAE(mean):1.2907112068276405,test MAE(mean):1.330968615938723


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:21
train loss:2.0616230600953265,test loss:2.175883200263164    
train MAE(mean):1.2896033580592274,test MAE(mean):1.3315745136818289


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:22
train loss:2.0601380299580816,test loss:2.175665939089415    
train MAE(mean):1.2887130517289043,test MAE(mean):1.3303089641198516


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:23
train loss:2.0586590931440476,test loss:2.1749039310397187    
train MAE(mean):1.2878844432435035,test MAE(mean):1.3310437585124373


100%|██████████| 10/10 [02:04<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:24
train loss:2.057256117968778,test loss:2.174931295779662    
train MAE(mean):1.287111682761669,test MAE(mean):1.3294705575075745


100%|██████████| 10/10 [02:04<00:00, 12.47s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:25
train loss:2.0558153149261105,test loss:2.1743939370313883    
train MAE(mean):1.2862804949017168,test MAE(mean):1.3317075154557825


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:26
train loss:2.0545565442876357,test loss:2.1741578604904146    
train MAE(mean):1.2856328630245328,test MAE(mean):1.3303628515943884


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:27
train loss:2.0531549775739006,test loss:2.1740015874640086    
train MAE(mean):1.2847823487115502,test MAE(mean):1.3287882047438622


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:28
train loss:2.051683148530071,test loss:2.1736654079903066    
train MAE(mean):1.2839750688434244,test MAE(mean):1.3299669802862406


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:29
train loss:2.05020528438481,test loss:2.1737989813197705    
train MAE(mean):1.2831270416631102,test MAE(mean):1.3308264930287004


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:30
train loss:2.048867171895058,test loss:2.1739186471544065    
train MAE(mean):1.2824602314453721,test MAE(mean):1.3301762676715851


100%|██████████| 10/10 [02:05<00:00, 12.52s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:31
train loss:2.0476438360915563,test loss:2.173391536376811    
train MAE(mean):1.2817913801102043,test MAE(mean):1.330807831721306


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:32
train loss:2.046421697523303,test loss:2.173897924627217    
train MAE(mean):1.2811630863103867,test MAE(mean):1.331748006171584


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:33
train loss:2.0451225807177797,test loss:2.173921777601218    
train MAE(mean):1.2805004275202752,test MAE(mean):1.3353900231939555


100%|██████████| 10/10 [02:04<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:34
train loss:2.0436942270564624,test loss:2.1742357882775014    
train MAE(mean):1.2797365885788203,test MAE(mean):1.3339557101780175


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:35
train loss:2.0425332935209197,test loss:2.175143421418408    
train MAE(mean):1.2792139898661374,test MAE(mean):1.3345427015662192


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:36
train loss:2.041250774479377,test loss:2.175405568471724    
train MAE(mean):1.27852486447227,test MAE(mean):1.335325807442665


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:37
train loss:2.039864975280989,test loss:2.1760177484946057    
train MAE(mean):1.2778672131528854,test MAE(mean):1.3332119046780466


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:38
train loss:2.038632908348798,test loss:2.1756304377320967    
train MAE(mean):1.277310899066627,test MAE(mean):1.3347474093174934


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:39
train loss:2.037425488823874,test loss:2.176090367450554    
train MAE(mean):1.2767371868943571,test MAE(mean):1.3397424489605427


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

epoch:40
train loss:2.036250801431457,test loss:2.1761852274527835    
train MAE(mean):1.2761759279590845,test MAE(mean):1.340155569331348
