# imports

In [1]:
import pickle
from srgnn_pl import SRGNN_model, SRGNN_Map_Dataset, SRGNN_sampler
from utils import fake_parser
import os

from torch.utils.data import DataLoader

import numpy as np

from tqdm import tqdm

In [2]:
import yaml
from sklearn.mixture import GaussianMixture

In [3]:
from math import ceil

# data & models loading

In [4]:
# Name of the run folder under ./wandb/ whose logged config is read below;
# the part after the last '-' also names the checkpoint folder under ./GNN_master/.
run_id='run-20240213_043223-0zuvfc9x'

In [5]:
# Read the configuration that was logged for this run.
with open(f"./wandb/{run_id}/files/config.yaml", "r") as stream:
    config = yaml.safe_load(stream)

# Build the default options once instead of on every loop iteration.
known_options = fake_parser().__dict__.keys()

# Keep only the entries fake_parser understands and unwrap each one from its
# {'value': ...} wrapper. The filter runs before the lookup, so unrelated
# entries that are not wrapped this way are never indexed.
config = {
    name: entry['value']
    for name, entry in config.items()
    if name in known_options
}

opt = fake_parser(**config)
print(opt.__dict__)

{'dataset': 'yoochoose_custom', 'batchSize': 128, 'hiddenSize': 64, 'epoch': 60, 'lr': 0.001, 'lr_dc': 0.1, 'lr_dc_step': 3, 'l2': 1e-05, 'step': 3, 'patience': 6, 'nonhybrid': False, 'validation': True, 'valid_portion': 0.2, 'pretrained_embedings': True, 'unfreeze_epoch': 2}


In [6]:
# Checkpoints live in a folder named after the run id's suffix (text after the last '-').
checkpoint_dir = f"./GNN_master/{run_id.split('-')[-1]}/checkpoints/"

# os.listdir returns entries in arbitrary order; sorting makes the choice
# reproducible if the folder ever holds more than one file.
# NOTE(review): if several checkpoints exist, confirm the first one by name is the intended one.
checkpoint_name = sorted(os.listdir(checkpoint_dir))[0]

model = SRGNN_model.load_from_checkpoint(checkpoint_dir + checkpoint_name, opt=opt)

In [8]:
# Restore the pickled clustering model used later via gm.predict
# (presumably a fitted sklearn GaussianMixture — confirm against the script that wrote it).
# SECURITY: pickle.load executes arbitrary code; only load files you produced yourself.
gmm_path = './gmm_better_32_k-means++_64.gmm'
with open(gmm_path, 'rb') as gmm_file:
    gm = pickle.load(gmm_file)


In [9]:
# Load the pickled training split of the configured dataset. The context
# manager closes the file handle, which the bare open() call never did.
# SECURITY: pickle.load executes arbitrary code; only load trusted files.
with open('../datasets/' + opt.dataset + '/train.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)

# Per-dataset n_node value (presumably the item/node vocabulary size — confirm
# against the training script). Unknown dataset names fall back to 310.
n_node_by_dataset = {
    'diginetica': 43098,
    'yoochoose1_64': 37484,
    'yoochoose1_4': 37484,
    'yoochoose_custom': 28583,
    'yoochoose_custom_augmented': 27809,
    'yoochoose_custom_augmented_5050': 27807,
}
n_node = n_node_by_dataset.get(opt.dataset, 310)


In [10]:
# Wrap the raw training split in the map-style dataset, keeping the stored order.
train_dataset = SRGNN_Map_Dataset(train_data, shuffle=False)
# Drop this name; the raw split is read from disk again in a later cell.
del train_data

# The sampler is given the batch size and keeps every sample (drop_last=False),
# in order (shuffle=False), so embeddings come out aligned with the stored sessions.
ordered_sampler = SRGNN_sampler(train_dataset, opt.batchSize, shuffle=False, drop_last=False)
train_dataloader = DataLoader(
    train_dataset,
    num_workers=os.cpu_count(),
    sampler=ordered_sampler,
)

data masking start
data masking 1
data masking 2
data masking 3
done masking


# get session embeddings

In [11]:
# Compute one embedding per training session, batch by batch, in dataset order.
session_emb_batches = []
# NOTE(review): full_sessions is never filled in the visible cells — confirm it is still needed.
full_sessions = []

# The last partial batch is kept (drop_last=False), so the number of batches is
# the ceiling, not the floor, of length / batchSize. Using the floor made the
# progress bar run past its total.
n_batches = ceil(train_dataset.length / opt.batchSize)

for batch in tqdm(train_dataloader, total=n_batches):
    # Assumes the model already lives on the CUDA device — TODO confirm.
    batch = [b.to('cuda') for b in batch]
    batch_emb = model.get_session_embeddings(batch)
    session_emb_batches.append(batch_emb.cpu().detach().numpy())

# Stack the per-batch arrays into a single array with one row per session.
session_emb = np.concatenate(session_emb_batches)

23123it [02:11, 175.35it/s]                           


In [12]:
# Store the embeddings in the same per-dataset folder that receives the
# per-cluster train_<cluster>.txt files, using opt.dataset instead of a
# hard-coded dataset name so both outputs always land together.
embeddings_dir = f'../datasets/{opt.dataset}/gm_all_splits_{opt.hiddenSize}'
# np.save does not create missing parent folders.
os.makedirs(embeddings_dir, exist_ok=True)
np.save(f'{embeddings_dir}/session_embeddings.npy', session_emb)

In [13]:
# Assign every session embedding to a mixture component, processing the
# embeddings in consecutive chunks of opt.batchSize rows. The result is a
# list holding one label array per chunk.
chunk_starts = range(0, session_emb.shape[0], opt.batchSize)
session_labels = [
    gm.predict(session_emb[start: start + opt.batchSize])
    for start in tqdm(chunk_starts)
]

100%|██████████| 23123/23123 [00:26<00:00, 860.25it/s]


In [14]:
# Join the per-chunk label arrays into one array whose i-th entry is the
# cluster label of the i-th row of session_emb.
# NOTE: this rebinds session_labels from a list of arrays to a single array,
# so the cell is not meant to be re-run without re-running the previous one.
session_labels=np.concatenate(session_labels)

In [15]:
# Show the distinct cluster labels and how many sessions fell into each.
# np.unique only reports labels that actually occur, so a mixture component
# that received no session is absent from the result.
np.unique(session_labels, return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 array([ 53614, 103865, 267685,   6932,  50273, 115312, 109676, 430829,
        110049, 115298, 123759,  15358,  99744,  75390, 144375,  40777,
         36418,  39866, 197015,    960, 174476,  39363, 120382,   3374,
          9112,  43901,  90290,  55582,  77412,  40844, 167788]))

In [16]:
# The loader and dataset are no longer needed once the labels are computed.
del train_dataloader
del train_dataset

# train_data was deleted after the dataset was built, so read the raw split
# again. The context manager closes the file handle, which the bare open()
# call never did.
# SECURITY: pickle.load executes arbitrary code; only load trusted files.
with open('../datasets/' + opt.dataset + '/train.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)

In [17]:
# Write one pickled (sessions, targets) pair per cluster, containing only the
# training samples whose embedding was assigned to that cluster.
split_dir = f'../datasets/{opt.dataset}/gm_all_splits_{opt.hiddenSize}'
# open(..., 'wb') does not create missing parent folders.
os.makedirs(split_dir, exist_ok=True)

for cluster in tqdm(np.unique(session_labels)):
    # Positions of the samples labelled with this cluster, in ascending order.
    member_idxs = np.flatnonzero(session_labels == cluster)
    # train_data[0] and train_data[1] are indexed in parallel by sample position.
    cluster_sessions = [train_data[0][i] for i in member_idxs]
    cluster_targets = [train_data[1][i] for i in member_idxs]
    with open(f'{split_dir}/train_{cluster}.txt', 'wb') as cluster_file:
        pickle.dump((cluster_sessions, cluster_targets), cluster_file)
        

100%|██████████| 31/31 [00:01<00:00, 19.05it/s]
