# imports

In [1]:
import pickle
from srgnn_pl import SRGNN_model, SRGNN_Map_Dataset, SRGNN_sampler
from utils import fake_parser
import os

from torch.utils.data import DataLoader

import numpy as np

from tqdm import tqdm

In [2]:
import yaml
from sklearn.mixture import GaussianMixture

In [3]:
from math import ceil

In [4]:
import pytorch_lightning as pl
import torch
from utils import split_validation

# data & models loading

In [5]:
# Name of the Weights & Biases run directory. It is used below to locate the run's
# config.yaml, and its last '-'-separated token is used to locate the checkpoint folder.
run_id='run-20240316_165704-5z65o3op'


In [6]:
# Rebuild the option object from the configuration stored with the W&B run.
with open(f"./wandb/{run_id}/files/config.yaml", "r") as stream:
    config = yaml.safe_load(stream)

# Build the default options once (instead of once per key) to learn which
# option names the parser accepts.
known_options = set(fake_parser().__dict__)

# Keep only the entries the parser knows about and unwrap each one from its
# {'value': ...} wrapper; every other top-level key of the YAML file is dropped.
config = {
    name: entry['value']
    for name, entry in config.items()
    if name in known_options
}

opt = fake_parser(**config)
print(opt.__dict__)

{'dataset': 'yoochoose_custom', 'batchSize': 128, 'hiddenSize': 64, 'epoch': 60, 'lr': 0.001, 'lr_dc': 0.1, 'lr_dc_step': 3, 'l2': 1e-05, 'step': 3, 'patience': 10, 'nonhybrid': False, 'validation': True, 'valid_portion': 0.125, 'pretrained_embedings': True, 'unfreeze_epoch': 2}


In [7]:
torch.set_float32_matmul_precision('medium')

# Checkpoints live in a folder named after the last '-'-separated token of the run id.
checkpoint_dir = f"./GNN_master/{run_id.split('-')[-1]}/checkpoints/"

# os.listdir returns entries in arbitrary order, so sort the listing to make the
# choice of checkpoint deterministic when the folder holds more than one file.
checkpoint_files = sorted(os.listdir(checkpoint_dir))
if not checkpoint_files:
    raise FileNotFoundError(f"no checkpoint found in {checkpoint_dir}")

model = SRGNN_model.load_from_checkpoint(
    os.path.join(checkpoint_dir, checkpoint_files[0]),
    opt=opt,
)

In [8]:
# Load the pickled training data; the context manager closes the file handle
# (the bare `pickle.load(open(...))` form leaves it open).
# NOTE(review): pickle executes arbitrary code on load; only use trusted dataset files.
with open('../datasets/' + opt.dataset + '/train.txt', 'rb') as train_file:
    train_data = pickle.load(train_file)

# Value of n_node for each known dataset name; any other name falls back to 310.
N_NODE_BY_DATASET = {
    'diginetica': 43098,
    'yoochoose1_64': 37484,
    'yoochoose1_4': 37484,
    'yoochoose_custom': 28583,
    'yoochoose_custom_augmented': 27809,
    'yoochoose_custom_augmented_5050': 27807,
}
n_node = N_NODE_BY_DATASET.get(opt.dataset, 310)

In [9]:
# Only the second element returned by split_validation is used from here on;
# the full data and the first element are released right away.
train_split, valid_data = split_validation(train_data, opt.valid_portion)
del train_data, train_split

# Despite their names, `train_dataset` / `train_dataloader` wrap `valid_data`.
train_dataset = SRGNN_Map_Dataset(valid_data, shuffle=False)

# Shuffling is disabled on both the dataset and the sampler, and no batch is dropped.
session_sampler = SRGNN_sampler(train_dataset, opt.batchSize, shuffle=False, drop_last=False)
train_dataloader = DataLoader(
    train_dataset,
    num_workers=os.cpu_count(),
    sampler=session_sampler,
)

data masking start


100%|██████████| 64/64 [00:05<00:00, 11.94it/s]


done masking


In [10]:
# Size reported by the dataset; the embedding loop below uses it to size its progress bar.
train_dataset.length

369965

# get session embeddings

In [11]:
# Collect the model's session embedding for every batch yielded by the dataloader.
session_emb = []
full_sessions = []  # never filled in this section; kept in case later cells refer to it

# Prefer the GPU but fall back to the CPU so the cell also runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # inference only: switch off training-time behaviour such as dropout

# Round up: a final partial batch is still yielded, so floor division
# under-reports the progress-bar total by one.
n_batches = ceil(train_dataset.length / opt.batchSize)

with torch.no_grad():  # no gradients are needed, which saves memory and time
    for batch in tqdm(train_dataloader, total=n_batches):
        batch = [b.to(device) for b in batch]
        session_emb.append(model.get_session_embeddings(batch).cpu().numpy())

# Stack the per-batch arrays along the first axis into one array.
session_emb = np.concatenate(session_emb)

2891it [00:27, 104.82it/s]                          


In [12]:
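# One-off GMM fitting step, kept commented out for reference; the next cell loads a
# previously pickled mixture instead. Note that this snippet saves to `gmm_train_*`,
# whereas the loading cell reads `gmm_val_*`.
#gm=GaussianMixture(n_components=32, n_init=1, init_params='k-means++')

#session_labels=gm.fit_predict(session_emb)

#with open(f'./GMMs/gmm_train_{gm.n_components}_{gm.init_params}_{opt.hiddenSize}_{opt.dataset}.gmm', 'wb') as gmm_file:
#    pickle.dump(gm, gmm_file)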

In [14]:
# Load the previously fitted mixture model from disk.
# NOTE(review): pickle executes arbitrary code on load; only use trusted files.
gmm_path = f'./GMMs/gmm_val_32_k-means++_{opt.hiddenSize}_{opt.dataset}.gmm'
with open(gmm_path, 'rb') as gmm_file:
    gm = pickle.load(gmm_file)

# Predict the component of every embedding in chunks of `opt.batchSize` rows,
# then join the per-chunk label arrays into a single array.
n_embeddings = session_emb.shape[0]
session_labels = np.concatenate([
    gm.predict(session_emb[start:start + opt.batchSize])
    for start in tqdm(range(0, n_embeddings, opt.batchSize))
])

100%|██████████| 2891/2891 [00:03<00:00, 865.94it/s]


In [15]:
# Distinct labels that were actually assigned, and how many embeddings received each one.
np.unique(session_labels, return_counts=True)

(array([ 0,  2,  3,  4,  5,  7,  8,  9, 11, 12, 13, 14, 16, 17, 18, 19, 20,
        21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]),
 array([15262,  3400, 12620, 16258, 33931, 16182, 24881,  9763, 10742,
         9150, 10000, 18495, 10551,  5619, 22822, 11449,  9393,  6975,
        11772, 14395, 17025, 14156,  9581, 12228, 13646, 10186,  8755,
        10728]))

In [14]:
from sklearn.manifold import TSNE
import plotly.graph_objects as go

In [15]:
# Project the session embeddings to 2-D. t-SNE is stochastic, so fix the seed to
# make the layout reproducible across kernel restarts.
tsne = TSNE(n_components=2, random_state=0)
tsne_session_embeddings = tsne.fit_transform(session_emb)

fig = go.Figure()

# One scatter trace per assigned label so each one can be toggled in the legend.
for label in np.unique(session_labels):
    label_embedding = tsne_session_embeddings[session_labels == label]
    fig.add_trace(go.Scatter(x=label_embedding[:, 0], y=label_embedding[:, 1],
                             name=str(label), mode='markers'))

fig.update_layout(title='TSNE reduced session embeddings with GM',
                  margin=dict(l=40, r=40, t=40, b=40),
                  width=1000, height=800)

# Create the output folder if it is missing so the export does not fail on a fresh checkout.
os.makedirs('./images', exist_ok=True)
fig.write_html(f'./images/all_train_sessions_{opt.dataset}_{opt.hiddenSize}.html')


In [16]:
# The loader and the dataset are not used again in the cells below.
del train_dataloader, train_dataset

# From here on `train_data` refers to `valid_data` (the alternative of re-reading
# '../datasets/<dataset>/train.txt' with pickle is intentionally not used).
train_data = valid_data

In [19]:
# Write one pickle per assigned label, holding the (sessions, targets) that fell into it.
# The run suffix is computed outside the f-string: reusing the same quote character
# inside an f-string expression is a SyntaxError before Python 3.12.
run_suffix = run_id.split('-')[-1]
split_dir = f'../datasets/{opt.dataset}/gm_train_splits_{opt.hiddenSize}_{run_suffix}'
os.makedirs(split_dir, exist_ok=True)  # the folder may not exist on a fresh run

# NOTE(review): assumes session_labels[i] belongs to train_data[0][i] / train_data[1][i],
# i.e. that the embeddings were produced in the original data order; confirm.
sessions, targets = train_data[0], train_data[1]

for cluster in tqdm(np.unique(session_labels)):
    idxs = np.flatnonzero(session_labels == cluster)
    cluster_sessions = [sessions[i] for i in idxs]
    cluster_targets = [targets[i] for i in idxs]
    with open(f'{split_dir}/train_{cluster}.txt', 'wb') as cluster_file:
        pickle.dump((cluster_sessions, cluster_targets), cluster_file)
        

100%|██████████| 28/28 [00:00<00:00, 57.01it/s]
