In [None]:
%load_ext autoreload
%autoreload 2

import os

# Remember the directory the kernel started in exactly once. A bare relative
# os.chdir("../../") climbs two more levels every time the cell is re-run;
# anchoring to the recorded start directory makes this cell idempotent.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()

# Work from two levels above the notebook (presumably the project root, since
# later cells use paths relative to it).
os.chdir(os.path.join(NOTEBOOK_DIR, "../../"))
print(os.getcwd())

In [None]:
import pandas as pd
import numpy as np
import functools
import operator
import json
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.tensorboard import SummaryWriter
import torch_geometric
import torch_geometric.transforms as T
from torch_geometric.sampler import NegativeSampling
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_scipy_sparse_matrix

from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import StandardScaler

from models.gnn.sage import GraphSAGE
from scripts.train_graph import train_epoch, test

# Print tensors with two decimals and without scientific notation.
torch.set_printoptions(precision=2, sci_mode=False)
# Seed the PyTorch RNG. Only torch is seeded here; numpy and Python's `random`
# are not touched by this call.
torch.manual_seed(0)

In [None]:
# Prefer the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Load the pickled graph bundle; it is indexed below by 'train_data' and 'valid_data'.
# NOTE(review): unpickling can execute arbitrary code — only load graph.pkl
# from a trusted source.
graph_path = os.path.join("data/steam", 'graph.pkl')
with open(graph_path, "rb") as f:
    graph = pd.read_pickle(f)

train_data, valid_data = graph['train_data'], graph['valid_data']

# Shapes of the `x` attribute of the two node types, taken from the training split.
user_shape, app_shape = (train_data[node_type].x.shape for node_type in ('user', 'app'))

In [None]:
# Mean of the 'edge_label' tensor of the ('user', 'recommends', 'app') edge type
# in the first batch drawn from train_loader.
# NOTE(review): train_loader is only created in a later cell, so this cell
# fails when the notebook is run top to bottom on a fresh kernel.
next(iter(train_loader))[('user', 'recommends', 'app')]['edge_label'].to(torch.float).mean()

In [None]:
# NOTE(review): mp_matrix is not defined in any visible cell; this cell relies
# on state created elsewhere.
# Row sums of mp_matrix, converted to a plain ndarray.
arr = mp_matrix.sum(axis=1).getA()

# Distinct row-sum values and how often each occurs, as a share of the row count.
unique_values, counts = np.unique(arr, return_counts=True)
normalized_counts = counts / arr.shape[0]

# Map each distinct value to its normalized frequency and show it.
value_counts_normalized = dict(zip(unique_values, normalized_counts))
print(value_counts_normalized)

In [None]:
# real_cols = ['positive_ratio', 'user_reviews', 'price_final', 'price_original', 'discount']

# scaler = StandardScaler()
# app_features_norm = scaler.fit_transform(app_features[real_cols].numpy())

In [None]:
# Dataloader:
#  - user: x->attributes of sampled nodes, n_id->mapping of sampled nodes to ids from whole graph
#  - app: x->attributes of sampled nodes, n_id->mapping of sampled nodes to ids from whole graph
#  - (user recommends app): 
#      edge_index -> sampled edges with batch ids with neighbors
#      edge_label -> labels of edges which will be evaluated, size of batch size
#      e_id -> mapping of sampled edges to ids from whole graph, refers to edge_index
#      input_id -> mapping of sampled edges to ids from whole graph, refers to edge_label_index
#      edge_label_index -> edge index, ids of nodes in sampled graph which will be evaluated


# To validate edges, first take the sampled node ids from edge_label_index, then map them to whole-graph ids
# using the n_id of user and app, and finally check whether such an edge exists in the dataframe

In [None]:
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE encoder.

    The layers come from ``torch_geometric.nn``. The notebook's bare ``nn``
    name is ``torch.nn``, which has no ``SAGEConv``, so the fully qualified
    name is used here.

    Args:
        hidden_channels: Input feature size (for both source and target nodes)
            and output size of the first layer.
        out_channels: Output size of the second layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # The tuple form of in_channels gives source and target node sizes separately.
        self.conv1 = torch_geometric.nn.SAGEConv(
            (hidden_channels, hidden_channels), hidden_channels, normalize=True
        )
        self.conv2 = torch_geometric.nn.SAGEConv(
            (hidden_channels, hidden_channels), out_channels, normalize=False
        )

    def forward(self, x, edge_index):
        """Apply conv1 followed by ReLU, then conv2, and return the result."""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
    
class Classifier(torch.nn.Module):
    """Scores candidate (user, app) pairs with a dot product of their embeddings."""

    def forward(self, x_user, x_app, edge_label_index):
        """Return one score per column of ``edge_label_index``.

        Row 0 of ``edge_label_index`` indexes into ``x_user`` and row 1 into
        ``x_app``; the score is the sum over the last dimension of the
        element-wise product of the two selected embeddings.
        """
        user_rows = edge_label_index[0]
        app_rows = edge_label_index[1]
        pairwise = x_user[user_rows] * x_app[app_rows]
        return pairwise.sum(dim=-1)


class Model(torch.nn.Module):
    """Link-prediction model: learned node embeddings, a heterogeneous GNN, and a dot-product scorer.

    Args:
        entities: Indexable pair ``(user_store, app_store)``. Each item must
            expose an ``x`` tensor; ``x.shape[0]`` sizes the embedding tables
            and the app ``x.shape[1]`` sizes the input of the app linear layer.
        hidden_channels: Size of the embeddings fed to the GNN.
        out_channels: Size of the GNN output embeddings.
        metadata: Graph metadata passed to ``to_hetero``.
    """

    def __init__(self, entities, hidden_channels, out_channels, metadata):
        super().__init__()

        user_store = entities[0]
        app_store = entities[1]

        self.user_emb = torch.nn.Embedding(user_store.x.shape[0], hidden_channels)
        self.app_emb = torch.nn.Embedding(app_store.x.shape[0], hidden_channels)
        self.app_lin = torch.nn.Linear(app_store.x.shape[1], hidden_channels)

        # `to_hetero` lives in torch_geometric.nn; the notebook's bare `nn` is
        # torch.nn, which does not provide it.
        self.gnn = GNN(hidden_channels=hidden_channels, out_channels=out_channels)
        self.gnn = torch_geometric.nn.to_hetero(self.gnn, metadata=metadata, aggr='sum')

        self.clf = Classifier()

    def forward(self, batch):
        """Return one score per entry of the batch's ('user', 'recommends', 'app') ``edge_label_index``."""
        x_dict = self.evaluate(batch)
        pred = self.clf(
            x_dict["user"],
            x_dict["app"],
            batch['user', 'recommends', 'app'].edge_label_index,
        )
        return pred

    def evaluate(self, batch):
        """Return the GNN output embeddings for the batch as a dict keyed by node type.

        Users enter the GNN as an embedding looked up by ``n_id``; apps enter as
        the sum of their embedding (looked up by ``n_id``) and a linear
        projection of their ``x`` features.
        """
        x_dict = {
          "user": self.user_emb(batch['user'].n_id),
          "app": self.app_emb(batch['app'].n_id) + self.app_lin(batch['app'].x),
        }

        return self.gnn(x_dict, batch.edge_index_dict)

def xavier_init(m):
    """Initialise a linear layer in place; intended for use with ``Module.apply``.

    For ``torch.nn.Linear`` and ``torch_geometric.nn.dense.linear.Linear``
    modules, the weight is drawn with Xavier-normal initialisation (gain 1.41)
    and the bias, when present, is zeroed. Any other module is left untouched.
    """
    is_linear = isinstance(m, torch.nn.Linear) or isinstance(m, torch_geometric.nn.dense.linear.Linear)
    if not is_linear:
        return

    torch.nn.init.xavier_normal_(m.weight, gain=1.41)
    if m.bias is None:
        return
    torch.nn.init.zeros_(m.bias)
    
# Build the Model from the training split's user/app node stores and metadata.
user_store, app_store = train_data['user'], train_data['app']
model = Model(
    entities=(user_store, app_store),
    hidden_channels=32,
    out_channels=32,
    metadata=train_data.metadata(),
)
# Run xavier_init over every sub-module, then move the model to the selected device.
model = model.apply(xavier_init).to(device)

In [None]:
def _make_link_loader(split):
    """Build a LinkNeighborLoader over the ('user', 'recommends', 'app') edges of ``split``.

    Both loaders in this notebook share the same settings: neighbor fan-out
    [20, 15], negative sampling ratio 5.0, batches of 1024, shuffled, with the
    last incomplete batch dropped.
    """
    edge_type = ('user', 'recommends', 'app')
    edge_store = split[edge_type]
    return LinkNeighborLoader(
        data=split,
        num_neighbors=[20, 15],
        neg_sampling_ratio=5.0,
        edge_label_index=(edge_type, edge_store.edge_label_index),
        edge_label=edge_store.edge_label,
        batch_size=1024,
        shuffle=True,
        drop_last=True,
    )


train_loader = _make_link_loader(train_data)
valid_loader = _make_link_loader(valid_data)

In [None]:
# NOTE: this rebinds `model`, replacing the `Model` instance built in an earlier cell.
model = GraphSAGE(
    entities_shapes={"user": user_shape, "app": app_shape},
    hidden_channels=32,
    out_channels=32,
    metadata=train_data.metadata(),
)
model = model.to(device)

# Binary cross-entropy on raw logits, optimised with RMSprop.
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=1e-4, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()

In [None]:
n_epochs = 10
for epoch in tqdm(range(n_epochs)):
    # train_epoch / test are imported from scripts.train_graph; each returns a
    # (loss, roc_auc) pair that is reported per epoch below.
    train_loss, train_roc_auc = train_epoch(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        val_loader=valid_loader,
        device=device  # was `deviced`, an undefined name that raised NameError
    )
    test_loss, test_roc_auc = test(
        model=model,
        criterion=criterion,
        val_loader=valid_loader,
        device=device
    )
    print(f"""Epoch <{epoch}>\ntrain_loss: {train_loss} - train_roc_auc: {train_roc_auc}
test_loss: {test_loss} - test_roc_auc: {test_roc_auc}\n""")

In [None]:
# from datetime import datetime
# torch.save(model.state_dict(), f"models/{model.__class__.__name__}/{datetime.now().strftime('%Y%m%d%H%M%S')}.pth")

In [None]:
# Retrieval metrics used by the scratch cells below.
# NOTE(review): `retrieval` is never imported in the visible cells — presumably
# `from torchmetrics import retrieval`; confirm and add it to the import cell.
# NOTE(review): the names p2/r2 suggest k=2, but the precision metric is built
# with top_k=5 — confirm which cutoff is intended.
p2 = retrieval.RetrievalPrecision(top_k=5)
r2 = retrieval.RetrievalRecall(top_k=2)
ndcg = retrieval.RetrievalNormalizedDCG(top_k=7)

In [None]:
# Toy example: ten scores for one query, of which only the first two are relevant.
preds = torch.tensor([0.7, 0.8, 0.1, 0.2, 0.4, 0.6, 0.5, 0.9, 0.3, 0.15])
targets = torch.tensor([True, True] + [False] * 8)
indices = torch.zeros(10, dtype=torch.long)

# Drop positions 4, 5 and 7 and show what remains.
mask = torch.ones_like(preds, dtype=torch.bool)
mask[[4, 5, 7]] = False
for kept in (preds[mask], targets[mask]):
    print(kept)

In [None]:
# Toy example: five scores for one query with no relevant item at all.
# Note: this rebinds preds/targets/indices from the previous cell.
preds = torch.tensor([1.0, 0.85, 0.8, 0.7, 0.65])
targets = torch.zeros(5, dtype=torch.bool)
indices = torch.zeros(5, dtype=torch.long)

In [None]:
# Apply the p2 metric (RetrievalPrecision, top_k=5) to the current preds/targets, with `indices` passed as `indexes`.
p2(preds, targets, indexes=indices)

In [None]:
# Apply the r2 metric (RetrievalRecall, top_k=2) to the current preds/targets, with `indices` passed as `indexes`.
r2(preds, targets, indexes=indices)

In [None]:
# Apply the ndcg metric (RetrievalNormalizedDCG, top_k=7) to the current preds/targets, with `indices` passed as `indexes`.
ndcg(preds, targets, indexes=indices)

In [None]:
def dcg(rel, k=5):
    """Discounted cumulative gain of a ranked relevance list, cut off at ``k``.

    Uses the exponential-gain form: sum over positions i = 1..k of
    ``(2**rel[i-1] - 1) / log2(i + 1)``.

    Args:
        rel: Relevance scores in ranked order (best-ranked first).
        k: Number of leading positions to include. Defaults to 5, the cutoff
            that was previously hard-coded.

    Returns:
        The DCG value; 0.0 for an empty list. Lists shorter than ``k`` are
        scored over the positions they have instead of raising IndexError.
    """
    g = 0.
    for position, relevance in enumerate(rel[:k], start=1):
        g += (2 ** relevance - 1) / np.log2(position + 1)
    return g

In [None]:
# Hand-checked NDCG example: the observed ranking versus the ideal ordering of
# the same three relevant items.
rel = [1, 0, 1, 0, 1]
rel_idcg = [1, 1, 1, 0, 0]

dcg_observed = dcg(rel)
dcg_ideal = dcg(rel_idcg)
for value in (dcg_observed, dcg_ideal, dcg_observed / dcg_ideal):
    print(value)

In [None]:
# Scratch: the fixed scores [0.9, 0.7, 0.6] that also appear as the denominator term in the next cell.
torch.tensor([0.9, 0.7, 0.6])

In [None]:
# Ratio of two discounted sums over the first three positions: preds[i][:3]
# versus the fixed scores [0.9, 0.7, 0.6], each divided by log2(position + 1).
# NOTE(review): `i` is not assigned at top level in any visible cell, and the
# latest visible `preds` is 1-D — this cell presumably relied on earlier kernel state.
(preds[i][:3] / torch.log2(torch.arange(3)+2)).sum() / (torch.tensor([0.9, 0.7, 0.6]) / torch.log2(torch.arange(3)+2)).sum()

In [None]:
# log2 of [2, 3, 4]: the discount denominators used for the first three positions in the cell above.
torch.log2(torch.arange(3)+2)

In [None]:
# NOTE(review): y_true and y_pred are not defined in any visible cell; this
# cell relies on state created elsewhere.
# Move both tensors to host memory as numpy arrays; predictions are rounded
# before comparison (presumably they are probabilities in [0, 1] — confirm).
y_true_np = y_true.detach().cpu().numpy()
y_pred_np = y_pred.detach().cpu().numpy().round()

cm = confusion_matrix(y_true_np, y_pred_np)
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[False, True])
cm_display.plot()
plt.show()

In [None]:
def save_model(model, path):
    """Write the model's ``state_dict`` (not the whole module) to ``path``."""
    weights = model.state_dict()
    torch.save(weights, path)
    
def load_model(path):
    """Rebuild a ``Model`` and load the weights stored at ``path``.

    The model is constructed with the same arguments as the training-time
    instance in this notebook (the training split's user/app stores, 32 hidden
    and 32 output channels). ``entities`` is a required constructor argument
    and was previously omitted, which raised TypeError.

    Args:
        path: File previously written by ``save_model``.

    Returns:
        The model with loaded weights, moved to the notebook's ``device``.
    """
    model = Model(
        entities=(train_data['user'], train_data['app']),
        hidden_channels=32,
        out_channels=32,
        metadata=train_data.metadata(),
    )
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    model = model.to(device)
    return model

In [None]:
#save_model(model, "models/gnn_03.pth")

In [None]:
# Rebind `model` to a Model restored from models/gnn_03.pth.
# NOTE(review): the matching save_model call in the previous cell is commented
# out, so this file must already exist on disk.
model = load_model("models/gnn_03.pth")

In [None]:
# Print a summary of the model for the first training batch.
# `summary` is provided by torch_geometric.nn; the notebook's bare `nn` is
# torch.nn, which has no `summary` attribute.
print(torch_geometric.nn.summary(model, next(iter(train_loader)).to(device)))

In [None]:
# Display the module's repr as the cell output.
model