In [11]:
import os
import numpy as np
import pandas as pd
from tqdm import tqdm, trange

import torch
from torch_geometric.nn import GAE
from torch_geometric.utils import is_undirected, to_undirected


from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import f1_score
from sklearn.preprocessing import MinMaxScaler

from model.GCN_GRU import GCN_GRU, Decoder
from utils import *

In [2]:
# --- Experiment configuration ------------------------------------------------
# Transaction column used as the shop-side key when edges are grouped.
shop_col = 'stonc_6_label'

# Run on CUDA device index 1 when CUDA is available, otherwise on the CPU.
device = torch.device('cuda:1') if torch.cuda.is_available() else torch.device('cpu')

# Model sizes.
embedding_dim, entity_dim = 64, 8

# Optimisation settings.
epochs, learning_rate, batch_size = 100, 0.001, 1000

# File from which the encoder weights are loaded later in the notebook.
weights_path = 'GCN_GRU_Link_Prediction'

In [3]:
# Root folder of the input data; the file names below are relative to it.
data_path = './data'

chid_dict_file = 'sample/sample_50k_idx_map.npy'
cdtx_file = 'sample/sample_50k_cdtx.csv'
cust_file = 'preprocessed/df_cust_log.csv'

# Resolve every relative file name against the data root in one pass.
chid_path, cdtx_path, cust_path = (
    os.path.join(data_path, relative_name)
    for relative_name in (chid_dict_file, cdtx_file, cust_file)
)

In [4]:
# Load the inputs through the project helper `read_sample_files` (from utils).
# It is given the transaction file, the customer file, the chid index map and
# the shop column, and returns four values that the rest of the notebook uses
# as: transaction frame, customer frame, number of users, number of shops.
# NOTE(review): the exact schema of the returned frames is defined in utils —
# confirm there.
df_cdtx, df_cust, n_users, n_shops = read_sample_files(
    cdtx_path,
    cust_path,
    chid_path,
    shop_col,
)

Start reading cdtx file...
Finish reading cdtx file !
Start reading cust file...
Finish reading cust file !
Start mapping encoding...
Finish !!


In [5]:
# Chronologically sorted list of the distinct `csmdt` values in the transactions.
list_months = sorted(df_cdtx.csmdt.unique())

# log(1 + largest per-(chid, csmdt) total of `objam`); used further down to
# undo the scaling of the customer table's `objam` column.
ma = np.log(df_cdtx.groupby(['chid', 'csmdt']).objam.sum().max()+1)

ignore_cols = ['chid', 'data_dt']
category_cols = [f'category_{i+1}' for i in range(6)]

# Every remaining customer column is treated as numeric.  The columns are kept
# in DataFrame order on purpose: building the list from a set difference makes
# the order depend on string hashing, which changes between interpreter
# processes, so the numeric feature layout fed to the model would not be
# reproducible across kernel restarts.
# NOTE(review): the notebook that produced the pretrained encoder weights must
# use the same column order — confirm.
non_numeric_cols = set(category_cols) | set(ignore_cols)
numeric_cols = [col for col in df_cust.columns if col not in non_numeric_cols]

# Number of distinct values per categorical column (passed to the model as
# embedding sizes).
emb_dims = list(df_cust[category_cols].nunique())

# Width of one node feature vector: `entity_dim` per categorical column plus
# one per numeric column.
input_dim = len(category_cols)*entity_dim + len(numeric_cols)

In [6]:
# For every month build the edge list (2 x E LongTensor of [chid, shop] pairs)
# and the matching edge weights, keyed by month.
edge_dict = {}
edge_weights_dict = {}
for month in tqdm(list_months):
    # Total `objam` per (chid, shop) pair within this month.
    edges = df_cdtx[df_cdtx.csmdt==month].groupby(['chid', shop_col]).objam.sum()

    # Read both index levels as whole arrays instead of looping over every
    # index tuple in Python (one small array allocation per edge).
    source_ids = edges.index.get_level_values(0).to_numpy()
    target_ids = edges.index.get_level_values(1).to_numpy()
    # Explicit int64 cast so the LongTensor constructor always gets the dtype
    # it expects, whatever integer width the index levels use.
    edge_pairs = torch.LongTensor(np.stack([source_ids, target_ids]).astype(np.int64))

    # Edge weight = log(1 + summed amount).
    edge_weights = torch.Tensor(np.log(edges.values+1))

    # Make the graph symmetric: when the edge list is not already undirected,
    # append the reversed copy of every edge and reuse the same weights.
    if not is_undirected(edge_pairs):
        edge_pairs = torch.cat([edge_pairs, edge_pairs[[1,0],:]], -1)
        edge_weights = edge_weights.repeat(2)

    edge_dict[month] = edge_pairs
    edge_weights_dict[month] = edge_weights

100%|██████████| 25/25 [00:18<00:00,  1.39it/s]


In [7]:
# Upper edges of the spending classes, applied to the de-normalised amount:
#   class 0: amount <= 0
#   class 1: (0, 1e4]   class 2: (1e4, 5e4]   class 3: (5e4, 1e5]
#   class 4: (1e5, 3e5] class 5: > 3e5
amount_bin_edges = [0, 1e4, 5e4, 1e5, 3e5]
n_classes = len(amount_bin_edges) + 1  # == 6

# NOTE(review): the second Decoder argument is presumably its output width
# (one logit per class) — confirm in model/GCN_GRU.py.
model = GAE(GCN_GRU(input_dim, embedding_dim, entity_dim, emb_dims), Decoder(embedding_dim, n_classes)).to(device)

# Per-month node features and labels, in `list_months` order.
x_num = []
x_cat = []
y = []
for month in list_months:
    # Filter the customer table once per month instead of once per use.
    cust_month = df_cust[df_cust.data_dt==month]

    cust_num_features = torch.Tensor(cust_month[numeric_cols].to_numpy())
    cust_cat_features = torch.LongTensor(cust_month[category_cols].to_numpy())

    # Shop nodes carry all-zero features and are stacked after the customer rows.
    shop_num_features = torch.zeros(n_shops, cust_num_features.shape[1])
    shop_cat_features = torch.zeros(n_shops, cust_cat_features.shape[1]).long()

    # Undo the scaling of `objam` (multiply by `ma`, exponentiate, subtract 1)
    # and bucket the result.  With right=True a value v gets class k when
    # edges[k-1] < v <= edges[k], which reproduces the original interval masks
    # and yields integer labels directly.
    amount = np.e ** (cust_month['objam'].to_numpy()*ma) -1
    labels = np.digitize(amount, amount_bin_edges, right=True)

    x_num.append(torch.cat([cust_num_features, shop_num_features], 0).to(device))
    x_cat.append(torch.cat([cust_cat_features, shop_cat_features], 0).to(device))
    y.append(torch.LongTensor(labels).to(device))

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Edge lists / weights in the same month order as x_num, x_cat and y.
# (`pos_edge_weigts` is misspelled, but train() and test() refer to it by this
# exact name, so it is kept.)
pos_edge_index = [edge_dict[month].to(device) for month in list_months]
pos_edge_weigts = [edge_weights_dict[month].to(device) for month in list_months]

In [13]:
# Load the stored encoder weights.  `map_location=device` remaps the saved
# tensors onto the device chosen in the config cell, so a checkpoint written
# on a GPU can still be loaded when the notebook falls back to the CPU (or to
# a different GPU index).
encoder_state = torch.load(weights_path, map_location=device)
model.encoder.load_state_dict(encoder_state)
model.eval()

# Loss applied to the decoder output and the integer class labels.
criterion = torch.nn.CrossEntropyLoss()

In [16]:
def train(n_windows=10, window=12):
    """Run one training epoch over sliding month windows.

    For every window start ``i`` in ``range(n_windows)`` the model encodes the
    features and edges of months ``i .. i + window - 1`` and the decoder is
    trained against the labels ``y[i + window]``, in consecutive mini-batches
    of the global ``batch_size``.  The encoder is re-run before every
    optimizer step.

    Args:
        n_windows: Number of window start positions to train on (default 10).
        window: Number of consecutive months fed to the encoder (default 12).

    Returns:
        A tuple ``(mean_loss, train_output, train_y)``:
        ``mean_loss`` is a detached 0-dim tensor holding the mean mini-batch
        loss of the epoch; ``train_output`` is a list with one numpy array of
        decoder outputs per mini-batch; ``train_y`` is a float64 numpy array
        with the matching labels, concatenated in the same order.

    Raises:
        ValueError: If no mini-batch was processed (nothing to average).
    """
    model.train()
    train_output = []
    label_chunks = []
    batch_losses = []
    for i in tqdm(range(n_windows)):
        targets = y[i + window]
        n_targets = targets.shape[0]

        for start in range(0, n_targets, batch_size):
            # Clamp to the number of labels: z may hold more rows than there
            # are labels (test() uses z[:n_users]), so an unclamped slice on a
            # partial last batch would return more rows than targets.
            stop = min(start + batch_size, n_targets)
            batch_y = targets[start:stop]

            optimizer.zero_grad()
            z = model.encode(x_cat[i:i+window], x_num[i:i+window],
                             pos_edge_index[i:i+window], pos_edge_weigts[i:i+window])
            output = model.decode(z[start:stop])
            loss = criterion(output, batch_y)

            # NOTE(review): retain_graph=True is kept from the original; since
            # the graph is rebuilt for every batch it is probably unnecessary
            # unless the encoder caches state between calls — confirm before
            # removing.
            loss.backward(retain_graph=True)
            optimizer.step()

            train_output.append(output.detach().cpu().numpy())
            label_chunks.append(batch_y.detach().cpu().numpy().reshape(-1))
            # Detach so the reported loss does not keep the autograd graph alive.
            batch_losses.append(loss.detach())

    if not batch_losses:
        raise ValueError('train() processed no mini-batches')

    # Concatenate once at the end; the leading empty array keeps the float64
    # dtype that callers received before.
    train_y = np.concatenate([np.array([])] + label_chunks)
    mean_loss = torch.stack(batch_losses).mean()
    return mean_loss, train_output, train_y

In [17]:
def test():
    model.eval()
    test_output = []
    test_y = np.array([])
    
    for i in range(10,12):

        with torch.no_grad():
            z = model.encode(x_cat[i:i+12], x_num[i:i+12], pos_edge_index[i:i+12], pos_edge_weigts[i:i+12])
            output = model.decode(z[:n_users])
            test_output.append(output.cpu().detach().numpy())
            test_y = np.concatenate([test_y, y[i+12].cpu().detach().numpy().reshape(-1)])
            
    return test_output, test_y

In [18]:
# Train for `epochs` epochs; after each one, evaluate and report the macro F1
# score on the training windows and on the held-out windows.
for epoch in range(epochs):
    loss, train_output, train_y = train()
    test_output, test_y = test()

    for split_name, split_labels, split_outputs in (
        ('train', train_y, train_output),
        ('test', test_y, test_output),
    ):
        # Predicted class = arg-max over the last axis of the stacked outputs.
        split_pred = np.concatenate(split_outputs, 0).argmax(-1)
        print(split_name, f1_score(split_labels, split_pred, average='macro'))

100%|██████████| 10/10 [02:05<00:00, 12.54s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.30485669054512554
test 0.3360582610839118


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3419451956118913
test 0.34917608987324184


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.35577524146969486
test 0.35349566844522484


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.36349748189542264
test 0.36270039239731094


100%|██████████| 10/10 [02:04<00:00, 12.47s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3687473946407445
test 0.36999262114154047


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3739216867921143
test 0.3800125297644466


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3787077442217481
test 0.3862914871221761


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3826489268027309
test 0.3915973461217403


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.38646458795016264
test 0.39381336304012526


100%|██████████| 10/10 [02:03<00:00, 12.37s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.38941236892170794
test 0.3952816764161735


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.392414289186926
test 0.39660031168893584


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3952947946105685
test 0.39749565782739343


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.3976346902073326
test 0.39723471662994897


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.400605810639205
test 0.3991069934370937


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.40303512810176284
test 0.39958522588652395


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4050419584353316
test 0.4006720764925751


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4072001004630941
test 0.4014156998363371


100%|██████████| 10/10 [02:05<00:00, 12.55s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.40892580002304874
test 0.40198719884075


100%|██████████| 10/10 [02:03<00:00, 12.37s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4107661573866273
test 0.40200589399837683


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4124677070582033
test 0.4016347052414779


100%|██████████| 10/10 [02:04<00:00, 12.50s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.41428963771039234
test 0.4031259031455021


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.41655316557637095
test 0.4035816010366391


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.41873283531574046
test 0.404821712081338


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.42072994335384045
test 0.40506400328554903


100%|██████████| 10/10 [02:04<00:00, 12.40s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.42275266531056704
test 0.4055971057849252


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4246438665990178
test 0.4071346541413212


100%|██████████| 10/10 [02:03<00:00, 12.40s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.42658860313566765
test 0.40873594685795345


100%|██████████| 10/10 [02:05<00:00, 12.52s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4282496417263982
test 0.40930293912334353


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.42948487703029575
test 0.4078336627496692


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4303316062573515
test 0.4002624263631635


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.43189440340889035
test 0.3867176246586385


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4338745678796258
test 0.389712597145153


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4362212518646322
test 0.39169102778168813


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.43819877675582547
test 0.3940686314251314


100%|██████████| 10/10 [02:04<00:00, 12.47s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4393624966594661
test 0.40298012755389995


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.44096485208342195
test 0.40190981588599156


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.44365021317136794
test 0.4022319127240643


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.44541984767845255
test 0.3980179277217175


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.44728614704648084
test 0.3982779121136131


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4482080803500619
test 0.39813095019246586


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.449978783113644
test 0.39915531621070954


100%|██████████| 10/10 [02:03<00:00, 12.40s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.45104729906749874
test 0.39906796677566675


100%|██████████| 10/10 [02:05<00:00, 12.53s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4525924764056492
test 0.3989264177031772


100%|██████████| 10/10 [02:04<00:00, 12.48s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4546737570853325
test 0.3992082088541365


100%|██████████| 10/10 [02:04<00:00, 12.43s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4564262421194584
test 0.4002360657937723


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4576588116438694
test 0.40089668905731735


100%|██████████| 10/10 [02:04<00:00, 12.42s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4584194501387608
test 0.4005223313783623


100%|██████████| 10/10 [02:04<00:00, 12.41s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.45952289749275915
test 0.4013621369662199


100%|██████████| 10/10 [02:05<00:00, 12.59s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.461039198579065
test 0.4004803743308977


100%|██████████| 10/10 [02:04<00:00, 12.49s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4619682322396381
test 0.4002629548675176


100%|██████████| 10/10 [02:04<00:00, 12.44s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4630808772783621
test 0.4012557075602969


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4627263009817655
test 0.4010036010840268


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.46416659580750425
test 0.39995367738692655


100%|██████████| 10/10 [02:04<00:00, 12.46s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4646281274510384
test 0.40013861996961264


100%|██████████| 10/10 [02:06<00:00, 12.67s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4659654304963056
test 0.39911307566773185


100%|██████████| 10/10 [02:03<00:00, 12.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4664916487078383
test 0.40040153919151616


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.46834107457735397
test 0.3974009809066083


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47003657241828534
test 0.399959677902692


100%|██████████| 10/10 [02:05<00:00, 12.57s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47200196197016875
test 0.398083654467603


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47315285196073
test 0.398313251426417


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47497561076697087
test 0.3995050168767189


100%|██████████| 10/10 [02:03<00:00, 12.34s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4758196909373404
test 0.3974259250947851


100%|██████████| 10/10 [02:03<00:00, 12.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.477167053449612
test 0.39567079820826817


100%|██████████| 10/10 [02:03<00:00, 12.39s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47707248405799746
test 0.39458017796351214


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4785337910865984
test 0.3933653641007803


100%|██████████| 10/10 [02:03<00:00, 12.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4776584956782497
test 0.3934938236454766


100%|██████████| 10/10 [02:03<00:00, 12.34s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47630897904717084
test 0.39537106075282463


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47836341246845654
test 0.39513534603253025


100%|██████████| 10/10 [02:05<00:00, 12.51s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.47931794904310104
test 0.39183869800212207


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4809710291644456
test 0.3967870497461397


100%|██████████| 10/10 [02:03<00:00, 12.37s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4817759863580835
test 0.3955991702376534


100%|██████████| 10/10 [02:03<00:00, 12.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4829608380338006
test 0.39170302980458804


100%|██████████| 10/10 [02:03<00:00, 12.35s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.48429227348299514
test 0.39566724931312036


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.48534994976165047
test 0.39629020413747224


100%|██████████| 10/10 [02:03<00:00, 12.32s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.48492548732055796
test 0.39350555246726365


100%|██████████| 10/10 [02:03<00:00, 12.38s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.48734066829305517
test 0.39568932439469506


100%|██████████| 10/10 [02:03<00:00, 12.36s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4884638006212448
test 0.3948641581083723


100%|██████████| 10/10 [02:03<00:00, 12.37s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

train 0.4888880939203451
test 0.3941913860920563


 80%|████████  | 8/10 [01:45<00:26, 13.23s/it]


KeyboardInterrupt: 