In [1]:
%load_ext autoreload
%autoreload 2

import os
os.chdir("../../")
print(os.getcwd())

C:\Users\Milosz\Desktop\python\thesis-recsys


In [2]:
import pandas as pd
import numpy as np
import functools
import operator
import json
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch_geometric
import torch_geometric.transforms as T
from torch_geometric import nn
from torch_geometric.sampler import NegativeSampling
from torch_geometric.loader import LinkNeighborLoader, NeighborLoader
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_scipy_sparse_matrix

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import StandardScaler

from src.utils import *

torch.set_printoptions(precision=2, sci_mode=False)
torch.manual_seed(715037601397000) # used to train gnn_02.pth

<torch._C.Generator at 0x28352e415b0>

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [4]:
app_meta = load_data_from_csv("data/graph_app_tags_stacked.csv")

In [5]:
app_features = np.zeros((1797, 425), dtype=np.float32)

for i, (_, x) in enumerate(app_meta['tags_id'].iteritems()):
    x = np.fromstring(x[1:-1], dtype=int, sep=',')
    app_features[i, x] = 1

  for i, (_, x) in enumerate(app_meta['tags_id'].iteritems()):


In [6]:
#app_features = load_data_from_csv("data/graph_app_features.csv")

In [7]:
# real_cols = ['positive_ratio', 'user_reviews', 'price_final', 'price_original', 'discount']

# scaler = StandardScaler()
# app_features_norm = scaler.fit_transform(app_features[real_cols].numpy())

In [8]:
def load_graph(df: pd.DataFrame, n_users: int, n_items: int) -> HeteroData:
    """
    Loads a graph data structure from a pandas DataFrame.
    
    Parameters:
        - df (pd.DataFrame): The input DataFrame containing the graph data.

    Returns:
        - HeteroData: A heterogeneous graph data object representing the input graph.

    Example:
        >>> import pandas as pd
        >>> df = pd.DataFrame({'user_id': [1, 2, 3], 'app_id': [4, 5, 6], 'is_recommended': [1, 0 ,1]})
        >>> graph = load_graph(df)
    """
    
    data = HeteroData()
    
    data['user'].x = torch.ones(n_users, 1)
    data['user'].n_id = torch.arange(n_users)
    
    data['app'].x = torch.from_numpy(app_features)
    data['app'].n_id = torch.arange(n_items)
    
    edge_index = torch.tensor([df['user_id'].values, df['app_id'].values])
    edge_label = torch.tensor(df['is_recommended'].values, dtype=torch.long)

    data['user', 'recommends', 'app'].edge_index = edge_index
    data['user', 'recommends', 'app'].edge_label = edge_label
    
    return data

In [9]:
def transform_graph(data: HeteroData) -> HeteroData:
    """
    Applies a transformation to a heterogeneous graph data object.

    Parameters:
        data: The input graph data object to be transformed.

    Returns:
        HeteroData: A new heterogeneous graph data object resulting from the transformation.

    Example:
        >>> transformed_data = transform_graph(data)
    """
    transform = T.Compose([T.ToUndirected()])
    return transform(data)

In [10]:
def init_edge_loader(data: HeteroData, **kwargs) -> NeighborLoader:
    """
    Initializes a neighbor loader for edge-based data in a heterogeneous graph.
    Firstly we sample `batch_size` edges and then sample at most `num_neighbors[0]`
    neighboring edges at first hop and at most `num_neighbors[1]` at second hop. 
    Value returned by next(iter(loader)) is a subgraph of `data` graph containing
    only sampled edges and congruent nodes.

    Args:
        data (HeteroData): The input heterogeneous graph data object.
        **kwargs: Additional keyword arguments for configuring the loader.

    Returns:
        NeighborLoader: A neighbor loader for the specified edge-based data.

    Example:
        >>> loader = init_edge_loader(data, num_neighbors=5, neg_sampl=0.2, bs=32, shuffle=True)
    """
    
    eli = data['user', 'recommends', 'app'].edge_label_index
    el = data['user', 'recommends', 'app'].edge_label
    
    loader = LinkNeighborLoader(
        data=data,
        num_neighbors=kwargs['num_neighbors'],
        neg_sampling_ratio=kwargs['neg_sampl'],
        edge_label_index=(('user', 'recommends', 'app'), eli),
        edge_label=el,
        batch_size=kwargs['bs'],
        shuffle=kwargs['shuffle'],
    )
    return loader

In [11]:
def get_sparse_adj_matr(data):
    # Extract sparse adjacency matrix with message passing edges
    mp_edges = data['user', 'recommends', 'app']['edge_index']
    mp_matrix = to_scipy_sparse_matrix(mp_edges, num_nodes=n_users).tocsr()

    # Extract sparse adjacency matrix with validation edges
    true_mask = data['user', 'recommends', 'app']['edge_label'].nonzero().flatten()
    val_edges = data['user', 'recommends', 'app']['edge_label_index'][:, true_mask]
    val_matrix = to_scipy_sparse_matrix(val_edges, num_nodes=n_users).tocsr()
    
    return mp_matrix, val_matrix

In [12]:
train_df = load_data_from_csv("data/graph_train.csv")
test_df = load_data_from_csv("data/graph_test.csv")

In [13]:
df = pd.concat([train_df, test_df])
n_users, n_items = df.user_id.nunique(), df.app_id.nunique()
data = load_graph(df, n_users, n_items)
data = transform_graph(data)

  edge_index = torch.tensor([df['user_id'].values, df['app_id'].values])


In [14]:
random_split = T.RandomLinkSplit(
    num_val=0.3,
    num_test=0.0,
    add_negative_train_samples=False,
    neg_sampling_ratio=2.0,
    disjoint_train_ratio=0.3,
    edge_types=('user', 'recommends', 'app'),
    rev_edge_types=('app', 'rev_recommends', 'user')
)
train_data, val_data, _ = random_split(data)

In [15]:
train_loader = init_edge_loader(train_data, num_neighbors=[20, 10], neg_sampl=2.0, bs=1024, shuffle=True, drop_last=True)
val_loader = init_edge_loader(val_data, num_neighbors=[20, 10], neg_sampl=2.0, bs=256, shuffle=False, drop_last=True)

In [16]:
mp_matrix, val_matrix = get_sparse_adj_matr(val_data)

In [17]:
# Dataloader:
#  - user: x->attributes of sampled nodes, n_id->mapping of sampled nodes to ids from whole graph
#  - app: x->attributes of sampled nodes, n_id->mapping of sampled nodes to ids from whole graph
#  - (user recommends app): 
#      edge_index -> sampled edges with batch ids with neighbors
#      edge_label -> labels of edges which will be evaluated, size of batch size
#      e_id -> mapping of sampled edges to ids from whole graph, refers to ?????
#      input_id -> mapping of sampled edges to ids from whole graph, refers to edge_label_index
#      edge_label_index -> edge index, ids of nodes in sampled graph which will be evaluated


# To validate nodes first get sampled nodes ids from edge_label_index, then map them to whole graph
# using n_ids of user and app and then check if such edge exists in dataframe

In [18]:
def train_fn(train_data: HeteroData, test_data: HeteroData):
    pass

In [19]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = nn.SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True)
        self.conv2 = nn.SAGEConv((hidden_channels, hidden_channels), out_channels, normalize=False)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
    
class Classifier(torch.nn.Module):
    def forward(self, x_user, x_app, edge_label_index):
        x_user = x_user[edge_label_index[0]]
        x_app = x_app[edge_label_index[1]]
        return (x_user * x_app).sum(dim=-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()
        
        self.user_emb = torch.nn.Embedding(n_users, hidden_channels)
        self.app_emb = torch.nn.Embedding(n_items, hidden_channels)
        self.app_lin = torch.nn.Linear(425, hidden_channels)
        
        self.gnn = GNN(hidden_channels=hidden_channels, out_channels=out_channels)
        self.gnn = nn.to_hetero(self.gnn, metadata=metadata, aggr='sum')
        
        self.clf = Classifier()
        
    def forward(self, batch):  
        x_dict = {
          "user": self.user_emb(batch['user'].n_id),
          "app": self.app_emb(batch['app'].n_id) + self.app_lin(batch['app'].x),
        } 
        
        x_dict = self.gnn(x_dict, batch.edge_index_dict)
        pred = self.clf(
            x_dict["user"],
            x_dict["app"],
            batch['user', 'recommends', 'app'].edge_label_index,
        )
        return pred
    
    def evaluate(self, batch):
        x_dict = {
          "user": self.user_emb(batch['user'].n_id),
          "app": self.app_emb(batch['app'].n_id) + self.app_lin(batch['app'].x),
        } 

        x_dict = self.gnn(x_dict, batch.edge_index_dict)

        return x_dict

def xavier_init(m):
    if isinstance(m, torch.nn.Linear) or isinstance(m, torch_geometric.nn.dense.linear.Linear):
        torch.nn.init.xavier_normal_(m.weight, gain=1.41)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    
model = Model(hidden_channels=32, out_channels=32, metadata=train_data.metadata())
model.apply(xavier_init)
model = model.to(device)

In [20]:
criterion = torch.nn.BCEWithLogitsLoss()
#optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)
#optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-1, momentum=0.9)
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=0.001, momentum=0.9)
#writer = SummaryWriter()

In [21]:
def train(n_epochs, print_loss=500):
    model.train()
    
    for epoch in range(n_epochs):
        running_loss = 0.
        for i_batch, batch in enumerate(tqdm(train_loader)):
            batch = batch.to(device)
            
            y_pred = model(batch)
            y_true = batch['user', 'recommends', 'app'].edge_label
            loss = criterion(y_pred, y_true.float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if not ((i_batch+1) % print_loss):
                last_loss = running_loss / print_loss
                #writer.add_scalar("Loss/train", last_loss, epoch*len(train_loader) + i_batch + 1)
                score = roc_auc_score(y_true.detach().cpu().numpy(), y_pred.detach().sigmoid().cpu().numpy())
                print(f"batch <{i_batch}> - loss: {last_loss} - roc_auc score: {score}")
                running_loss = 0.
                
                #test_loss, test_roc_auc = test()
                #print(f"\tTest loss: {test_loss} \t Test ROC AUC: {test_roc_auc}")
                
                print(f"Metrics: {evaluate_nn(model, mp_matrix, val_matrix, k=10)}")
            
        print(f"Epoch: {epoch}, Loss: {running_loss / len(train_loader):.4f}")

In [22]:
def precision_k(reco_relevance, k=10):
    # k = reco_relevance.shape[1]
    # return (reco_relevance.sum(axis=1) / k).mean()
    return reco_relevance.mean()

def mean_average_prec(reco_relevance):
    K = reco_relevance.shape[1]
    
    mean_ap = 0.0
    for k in range(1, K+1):
        mean_ap += prec_k(reco_relevance[:, :k]) # DODAC MNOŻNIK 1/0 GDY ITEM JEST RELEWANTNY!!!
    return mean_ap / K

def recall_k(reco_relevance, relevance, k=10):
    sum_relevant = relevance.sum(axis=1)
    return (reco_relevance.sum(axis=1) / sum_relevant).mean()

@torch.no_grad()
def generate_embeddings(model, data):
    model.eval()
    data = data.to(device)
    x_dict = model.evaluate(data)
    return x_dict

@torch.no_grad()
def recommend_k(user_emb, item_emb, past_interactions=None, k=10, user_batch_size=1000):
    def remove_past_interactions(prob, user_batch):
        id_x = np.repeat(np.arange(user_batch.shape[0]), np.diff(past_interactions[user_batch].indptr))
        id_y = past_interactions[user_batch].indices
        prob[id_x, id_y] = -torch.inf
        return prob
    
    recommended_batches = []
    user_batches = torch.arange(user_emb.shape[0]).split(user_batch_size)
    for user_batch in user_batches:
        prob = (user_emb[user_batch] @ item_emb.T).sigmoid()
        prob = remove_past_interactions(prob, user_batch)
        recommended_batches.append(prob.topk(k, 1)[1])
    
    recommendations = torch.cat(recommended_batches, 0)
    return recommendations

def recommendation_relevance(recommendations, ground_truth):
    """
    Computes the relevance matrix of recommended items based on ground truth data.

    This function takes a matrix of recommended items and a ground truth sparse matrix, and calculates
    binary relevance of recommended items for each user. The relevance is determined by
    comparing the recommended items with the actual items in the ground truth.

    Args:
        recommendations (numpy.ndarray): A 2D matrix of shape (n_users, k) where k is the number of 
            recommended items per user. Each row contains indices representing the recommended 
            items for a user.
        ground_truth (scipy.csr_matrix): A sparse matrix of shape (n_users, n_items). The matrix 
            contains binary values indicating whether an item is relevant (1) or not (0) for each user.

    Returns:
        numpy.matrix: A 2D matrix of shape (n_users, k) containing the relevance scores of the
        recommended items for each user.
        
    Raises:
        ValueError: If the dimensions of 'recommendations' and 'ground_truth' do not match or
            are incompatible for matrix operations.
    """
    n_users, n_items = ground_truth.shape
    k = recommendations.shape[1]
    
    if recommendations.shape[0] != n_users:
        raise ValueError("Number of users in 'recommendations' should match 'ground_truth'.")
    
    user_idx = np.repeat(np.arange(n_users), k)
    item_idx = recommendations.flatten()
    relevance = ground_truth[user_idx, item_idx].reshape((n_users, k))  # get values under arrays of indices 
                                                                        # (user_idx and item_idx) from ground truth
    relevance_mask = np.asarray((ground_truth.sum(axis=1) != 0)).ravel()
    
    return relevance, relevance_mask

def evaluate_nn(model, mp_matrix, val_matrix, k):
    x_emb = generate_embeddings(model, val_data)
    recommendations = recommend_k(x_emb['user'], x_emb['app'], past_interactions=mp_matrix, 
                                  k=10, user_batch_size=10000).cpu().numpy()
    reco_relevance, relevance_mask = recommendation_relevance(recommendations, val_matrix)
    
    prec_k = precision_k(reco_relevance[relevance_mask], k)
    return {f"precision@{k}": prec_k}

In [23]:
@torch.no_grad()
def test():
    model.eval()
    running_loss = 0.
    preds, ground_truths = [], []

    for i_batch, batch in enumerate(val_loader):
        batch = batch.to(device)
        y_pred = model(batch)
        y_true = batch['user', 'recommends', 'app'].edge_label
        loss = criterion(y_pred, y_true.float())
        
        preds.append(y_pred)
        ground_truths.append(y_true)

        running_loss += loss.item()
        
    pred = torch.cat(preds, dim=0).cpu().numpy()
    ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
    
    test_loss = running_loss / len(test_loader)
    test_score = roc_auc_score(ground_truth, pred)

    return test_loss, test_score

In [24]:
train(n_epochs=5, print_loss=100)

  5%|████▍                                                                                  | 99/1924 [01:05<18:29,  1.65it/s]

batch <99> - loss: 0.4053977537155151 - roc_auc score: 0.907811164855957


  5%|████▎                                                                               | 100/1924 [01:44<6:07:54, 12.10s/it]

Metrics: {'precision@10': 0.016886577}


 10%|████████▉                                                                             | 199/1924 [02:52<19:40,  1.46it/s]

batch <199> - loss: 0.3609964242577553 - roc_auc score: 0.9138555526733398


 10%|████████▋                                                                           | 200/1924 [03:34<6:12:31, 12.97s/it]

Metrics: {'precision@10': 0.018438745}


 16%|█████████████▎                                                                        | 299/1924 [04:43<18:24,  1.47it/s]

batch <299> - loss: 0.35242858618497847 - roc_auc score: 0.9105024337768555


 16%|█████████████                                                                       | 300/1924 [05:23<5:36:19, 12.43s/it]

Metrics: {'precision@10': 0.019079002}


 21%|█████████████████▊                                                                    | 399/1924 [06:28<16:19,  1.56it/s]

batch <399> - loss: 0.3458378195762634 - roc_auc score: 0.9144563674926758


 21%|█████████████████▍                                                                  | 400/1924 [07:06<4:58:50, 11.77s/it]

Metrics: {'precision@10': 0.019382125}


 26%|██████████████████████▎                                                               | 499/1924 [08:11<15:47,  1.50it/s]

batch <499> - loss: 0.345595056116581 - roc_auc score: 0.9209170341491699


 26%|█████████████████████▊                                                              | 500/1924 [08:50<4:43:04, 11.93s/it]

Metrics: {'precision@10': 0.019986866}


 31%|██████████████████████████▊                                                           | 599/1924 [09:56<14:26,  1.53it/s]

batch <599> - loss: 0.343881219625473 - roc_auc score: 0.9126307964324951


 31%|██████████████████████████▏                                                         | 600/1924 [10:34<4:22:01, 11.87s/it]

Metrics: {'precision@10': 0.020101314}


 36%|███████████████████████████████▏                                                      | 699/1924 [11:43<15:00,  1.36it/s]

batch <699> - loss: 0.34166058629751206 - roc_auc score: 0.9093236923217773


 36%|██████████████████████████████▌                                                     | 700/1924 [12:21<3:59:01, 11.72s/it]

Metrics: {'precision@10': 0.020660615}


 42%|███████████████████████████████████▋                                                  | 799/1924 [13:43<15:28,  1.21it/s]

batch <799> - loss: 0.3389466980099678 - roc_auc score: 0.9246869087219238


 42%|██████████████████████████████████▉                                                 | 800/1924 [14:14<3:02:54,  9.76s/it]

Metrics: {'precision@10': 0.020751778}


 47%|████████████████████████████████████████▏                                             | 899/1924 [15:37<14:20,  1.19it/s]

batch <899> - loss: 0.33873935252428056 - roc_auc score: 0.9161767959594727


 47%|███████████████████████████████████████▎                                            | 900/1924 [16:08<2:46:42,  9.77s/it]

Metrics: {'precision@10': 0.02085442}


 52%|████████████████████████████████████████████▋                                         | 999/1924 [17:33<13:36,  1.13it/s]

batch <999> - loss: 0.33743521839380264 - roc_auc score: 0.9171428680419922


 52%|███████████████████████████████████████████▏                                       | 1000/1924 [18:03<2:29:03,  9.68s/it]

Metrics: {'precision@10': 0.021156507}


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:30<12:02,  1.14it/s]

batch <1099> - loss: 0.33563955187797545 - roc_auc score: 0.920313835144043


 57%|███████████████████████████████████████████████▍                                   | 1100/1924 [20:01<2:13:39,  9.73s/it]

Metrics: {'precision@10': 0.021271143}


 62%|████████████████████████████████████████████████████▉                                | 1199/1924 [21:28<10:50,  1.11it/s]

batch <1199> - loss: 0.3348745614290237 - roc_auc score: 0.9276609420776367


 62%|███████████████████████████████████████████████████▊                               | 1200/1924 [21:58<1:53:28,  9.40s/it]

Metrics: {'precision@10': 0.02116742}


 68%|█████████████████████████████████████████████████████████▍                           | 1299/1924 [23:25<08:44,  1.19it/s]

batch <1299> - loss: 0.33492648005485537 - roc_auc score: 0.9166660308837891


 68%|████████████████████████████████████████████████████████                           | 1300/1924 [23:51<1:25:35,  8.23s/it]

Metrics: {'precision@10': 0.021495335}


 73%|█████████████████████████████████████████████████████████████▊                       | 1399/1924 [25:15<07:21,  1.19it/s]

batch <1399> - loss: 0.33591071784496307 - roc_auc score: 0.9225974082946777


 73%|████████████████████████████████████████████████████████████▍                      | 1400/1924 [25:40<1:11:25,  8.18s/it]

Metrics: {'precision@10': 0.021725923}


 78%|██████████████████████████████████████████████████████████████████▏                  | 1499/1924 [27:05<06:11,  1.14it/s]

batch <1499> - loss: 0.33532076865434646 - roc_auc score: 0.9238901138305664


 78%|████████████████████████████████████████████████████████████████▋                  | 1500/1924 [27:33<1:04:58,  9.19s/it]

Metrics: {'precision@10': 0.021575866}


 83%|██████████████████████████████████████████████████████████████████████▋              | 1599/1924 [28:59<04:35,  1.18it/s]

batch <1599> - loss: 0.33261684387922286 - roc_auc score: 0.9171504974365234


 83%|██████████████████████████████████████████████████████████████████████▋              | 1600/1924 [29:24<44:37,  8.26s/it]

Metrics: {'precision@10': 0.021920102}


 88%|███████████████████████████████████████████████████████████████████████████          | 1699/1924 [30:48<03:08,  1.19it/s]

batch <1699> - loss: 0.33181318432092666 - roc_auc score: 0.9228205680847168


 88%|███████████████████████████████████████████████████████████████████████████          | 1700/1924 [31:14<30:45,  8.24s/it]

Metrics: {'precision@10': 0.022020202}


 94%|███████████████████████████████████████████████████████████████████████████████▍     | 1799/1924 [32:37<01:43,  1.20it/s]

batch <1799> - loss: 0.3338363364338875 - roc_auc score: 0.9318656921386719


 94%|███████████████████████████████████████████████████████████████████████████████▌     | 1800/1924 [33:02<16:56,  8.20s/it]

Metrics: {'precision@10': 0.022008725}


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1899/1924 [34:26<00:20,  1.19it/s]

batch <1899> - loss: 0.3326935201883316 - roc_auc score: 0.9257903099060059


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1900/1924 [34:51<03:16,  8.18s/it]

Metrics: {'precision@10': 0.022259917}


100%|█████████████████████████████████████████████████████████████████████████████████████| 1924/1924 [35:11<00:00,  1.10s/it]


Epoch: 0, Loss: 0.0041


  5%|████▍                                                                                  | 99/1924 [01:23<25:37,  1.19it/s]

batch <99> - loss: 0.3286576551198959 - roc_auc score: 0.9202175140380859


  5%|████▎                                                                               | 100/1924 [01:48<4:08:47,  8.18s/it]

Metrics: {'precision@10': 0.022317259}


 10%|████████▉                                                                             | 199/1924 [03:11<24:05,  1.19it/s]

batch <199> - loss: 0.3280209505558014 - roc_auc score: 0.9214122295379639


 10%|████████▋                                                                           | 200/1924 [03:37<3:55:22,  8.19s/it]

Metrics: {'precision@10': 0.02237667}


 16%|█████████████▎                                                                        | 299/1924 [05:00<22:38,  1.20it/s]

batch <299> - loss: 0.32712480902671814 - roc_auc score: 0.9224553108215332


 16%|█████████████                                                                       | 300/1924 [05:25<3:41:21,  8.18s/it]

Metrics: {'precision@10': 0.022254743}


 21%|█████████████████▊                                                                    | 399/1924 [06:49<21:15,  1.20it/s]

batch <399> - loss: 0.32873392313718797 - roc_auc score: 0.9269752502441406


 21%|█████████████████▍                                                                  | 400/1924 [07:14<3:27:46,  8.18s/it]

Metrics: {'precision@10': 0.022002092}


 26%|██████████████████████▎                                                               | 499/1924 [08:37<20:07,  1.18it/s]

batch <499> - loss: 0.3272774884104729 - roc_auc score: 0.9214272499084473


 26%|█████████████████████▊                                                              | 500/1924 [09:03<3:14:32,  8.20s/it]

Metrics: {'precision@10': 0.022270642}


 31%|██████████████████████████▊                                                           | 599/1924 [10:26<18:44,  1.18it/s]

batch <599> - loss: 0.3250985437631607 - roc_auc score: 0.9296121597290039


 31%|██████████████████████████▏                                                         | 600/1924 [10:51<3:00:37,  8.19s/it]

Metrics: {'precision@10': 0.022442149}


 36%|███████████████████████████████▏                                                      | 699/1924 [12:14<17:10,  1.19it/s]

batch <699> - loss: 0.3264514690637588 - roc_auc score: 0.9313673973083496


 36%|██████████████████████████████▌                                                     | 700/1924 [12:40<2:46:45,  8.17s/it]

Metrics: {'precision@10': 0.02213846}


 42%|███████████████████████████████████▋                                                  | 799/1924 [14:03<15:43,  1.19it/s]

batch <799> - loss: 0.325390442609787 - roc_auc score: 0.9230737686157227


 42%|██████████████████████████████████▉                                                 | 800/1924 [14:28<2:33:14,  8.18s/it]

Metrics: {'precision@10': 0.022027822}


 47%|████████████████████████████████████████▏                                             | 899/1924 [15:52<14:25,  1.18it/s]

batch <899> - loss: 0.3273549848794937 - roc_auc score: 0.9266066551208496


 47%|███████████████████████████████████████▎                                            | 900/1924 [16:17<2:19:33,  8.18s/it]

Metrics: {'precision@10': 0.02238951}


 52%|████████████████████████████████████████████▋                                         | 999/1924 [17:40<13:31,  1.14it/s]

batch <999> - loss: 0.3259343746304512 - roc_auc score: 0.9218869209289551


 52%|███████████████████████████████████████████▏                                       | 1000/1924 [18:06<2:06:38,  8.22s/it]

Metrics: {'precision@10': 0.022336168}


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:29<11:37,  1.18it/s]

batch <1099> - loss: 0.3241207236051559 - roc_auc score: 0.9324145317077637


 57%|███████████████████████████████████████████████▍                                   | 1100/1924 [19:55<1:52:46,  8.21s/it]

Metrics: {'precision@10': 0.022401929}


 62%|████████████████████████████████████████████████████▉                                | 1199/1924 [21:18<10:07,  1.19it/s]

batch <1199> - loss: 0.3263606008887291 - roc_auc score: 0.9295954704284668


 62%|███████████████████████████████████████████████████▊                               | 1200/1924 [21:43<1:38:44,  8.18s/it]

Metrics: {'precision@10': 0.022313448}


 68%|█████████████████████████████████████████████████████████▍                           | 1299/1924 [23:07<08:42,  1.20it/s]

batch <1299> - loss: 0.32499513030052185 - roc_auc score: 0.9223456382751465


 68%|████████████████████████████████████████████████████████                           | 1300/1924 [23:32<1:25:10,  8.19s/it]

Metrics: {'precision@10': 0.022547705}


 73%|█████████████████████████████████████████████████████████████▊                       | 1399/1924 [24:55<07:24,  1.18it/s]

batch <1399> - loss: 0.3247631439566612 - roc_auc score: 0.928473949432373


 73%|████████████████████████████████████████████████████████████▍                      | 1400/1924 [25:21<1:11:34,  8.20s/it]

Metrics: {'precision@10': 0.022622922}


 78%|██████████████████████████████████████████████████████████████████▏                  | 1499/1924 [26:44<05:57,  1.19it/s]

batch <1499> - loss: 0.324041993021965 - roc_auc score: 0.9295649528503418


 78%|██████████████████████████████████████████████████████████████████▎                  | 1500/1924 [27:09<58:00,  8.21s/it]

Metrics: {'precision@10': 0.022637598}


 83%|██████████████████████████████████████████████████████████████████████▋              | 1599/1924 [28:32<04:33,  1.19it/s]

batch <1599> - loss: 0.32462930262088774 - roc_auc score: 0.9328832626342773


 83%|██████████████████████████████████████████████████████████████████████▋              | 1600/1924 [28:58<44:32,  8.25s/it]

Metrics: {'precision@10': 0.022452073}


 88%|███████████████████████████████████████████████████████████████████████████          | 1699/1924 [30:22<03:09,  1.19it/s]

batch <1699> - loss: 0.3235175430774689 - roc_auc score: 0.9284591674804688


 88%|███████████████████████████████████████████████████████████████████████████          | 1700/1924 [30:47<30:51,  8.26s/it]

Metrics: {'precision@10': 0.022608528}


 94%|███████████████████████████████████████████████████████████████████████████████▍     | 1799/1924 [32:10<01:45,  1.18it/s]

batch <1799> - loss: 0.32516882836818695 - roc_auc score: 0.9260659217834473


 94%|███████████████████████████████████████████████████████████████████████████████▌     | 1800/1924 [32:36<16:58,  8.21s/it]

Metrics: {'precision@10': 0.022714084}


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1899/1924 [33:59<00:21,  1.19it/s]

batch <1899> - loss: 0.3226244565844536 - roc_auc score: 0.9291925430297852


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1900/1924 [34:25<03:16,  8.19s/it]

Metrics: {'precision@10': 0.022515576}


100%|█████████████████████████████████████████████████████████████████████████████████████| 1924/1924 [34:45<00:00,  1.08s/it]


Epoch: 1, Loss: 0.0040


  5%|████▍                                                                                  | 99/1924 [01:23<25:38,  1.19it/s]

batch <99> - loss: 0.31041175991296766 - roc_auc score: 0.9366397857666016


  5%|████▎                                                                               | 100/1924 [01:48<4:09:14,  8.20s/it]

Metrics: {'precision@10': 0.022234187}


 10%|████████▉                                                                             | 199/1924 [03:11<24:04,  1.19it/s]

batch <199> - loss: 0.3075334995985031 - roc_auc score: 0.9304866790771484


 10%|████████▋                                                                           | 200/1924 [03:37<3:54:39,  8.17s/it]

Metrics: {'precision@10': 0.022221956}


 16%|█████████████▎                                                                        | 299/1924 [05:00<23:01,  1.18it/s]

batch <299> - loss: 0.30692731380462646 - roc_auc score: 0.9361896514892578


 16%|█████████████                                                                       | 300/1924 [05:25<3:41:26,  8.18s/it]

Metrics: {'precision@10': 0.022142693}


 21%|█████████████████▊                                                                    | 399/1924 [06:49<21:20,  1.19it/s]

batch <399> - loss: 0.30679073989391326 - roc_auc score: 0.938570499420166


 21%|█████████████████▍                                                                  | 400/1924 [07:14<3:27:49,  8.18s/it]

Metrics: {'precision@10': 0.022052942}


 26%|██████████████████████▎                                                               | 499/1924 [08:37<19:53,  1.19it/s]

batch <499> - loss: 0.30636575520038606 - roc_auc score: 0.9361019134521484


 26%|█████████████████████▊                                                              | 500/1924 [09:03<3:14:28,  8.19s/it]

Metrics: {'precision@10': 0.021748407}


 31%|██████████████████████████▊                                                           | 599/1924 [10:26<18:39,  1.18it/s]

batch <599> - loss: 0.30605661392211914 - roc_auc score: 0.9325647354125977


 31%|██████████████████████████▏                                                         | 600/1924 [10:51<3:00:36,  8.18s/it]

Metrics: {'precision@10': 0.021737024}


 36%|███████████████████████████████▏                                                      | 699/1924 [12:15<17:20,  1.18it/s]

batch <699> - loss: 0.3054946780204773 - roc_auc score: 0.9304752349853516


 36%|██████████████████████████████▌                                                     | 700/1924 [12:40<2:47:04,  8.19s/it]

Metrics: {'precision@10': 0.021757392}


 42%|███████████████████████████████████▋                                                  | 799/1924 [14:03<15:42,  1.19it/s]

batch <799> - loss: 0.30625747978687284 - roc_auc score: 0.9361236095428467


 42%|██████████████████████████████████▉                                                 | 800/1924 [14:29<2:33:32,  8.20s/it]

Metrics: {'precision@10': 0.021736506}


 47%|████████████████████████████████████████▏                                             | 899/1924 [15:52<14:21,  1.19it/s]

batch <899> - loss: 0.3068440079689026 - roc_auc score: 0.9332118034362793


 47%|███████████████████████████████████████▎                                            | 900/1924 [16:17<2:19:49,  8.19s/it]

Metrics: {'precision@10': 0.021758098}


 52%|████████████████████████████████████████████▋                                         | 999/1924 [17:41<12:53,  1.20it/s]

batch <999> - loss: 0.30572642594575883 - roc_auc score: 0.9331150054931641


 52%|███████████████████████████████████████████▏                                       | 1000/1924 [18:06<2:06:25,  8.21s/it]

Metrics: {'precision@10': 0.021797752}


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:30<11:37,  1.18it/s]

batch <1099> - loss: 0.30513442724943163 - roc_auc score: 0.9436101913452148


 57%|███████████████████████████████████████████████▍                                   | 1100/1924 [19:55<1:52:23,  8.18s/it]

Metrics: {'precision@10': 0.021682363}


 62%|████████████████████████████████████████████████████▉                                | 1199/1924 [21:18<10:09,  1.19it/s]

batch <1199> - loss: 0.3062672904133797 - roc_auc score: 0.9349403381347656


 62%|███████████████████████████████████████████████████▊                               | 1200/1924 [21:44<1:38:42,  8.18s/it]

Metrics: {'precision@10': 0.021424634}


 68%|█████████████████████████████████████████████████████████▍                           | 1299/1924 [23:07<08:44,  1.19it/s]

batch <1299> - loss: 0.3074741348624229 - roc_auc score: 0.9356951713562012


 68%|████████████████████████████████████████████████████████                           | 1300/1924 [23:32<1:25:12,  8.19s/it]

Metrics: {'precision@10': 0.02165447}


 73%|█████████████████████████████████████████████████████████████▊                       | 1399/1924 [24:55<07:23,  1.18it/s]

batch <1399> - loss: 0.305008085668087 - roc_auc score: 0.9361724853515625


 73%|████████████████████████████████████████████████████████████▍                      | 1400/1924 [25:21<1:11:30,  8.19s/it]

Metrics: {'precision@10': 0.021712188}


 78%|██████████████████████████████████████████████████████████████████▏                  | 1499/1924 [26:44<05:55,  1.20it/s]

batch <1499> - loss: 0.30601498574018476 - roc_auc score: 0.9257516860961914


 78%|██████████████████████████████████████████████████████████████████▎                  | 1500/1924 [27:09<57:50,  8.19s/it]

Metrics: {'precision@10': 0.021816708}


 83%|██████████████████████████████████████████████████████████████████████▋              | 1599/1924 [28:33<04:34,  1.19it/s]

batch <1599> - loss: 0.3053678160905838 - roc_auc score: 0.9318370819091797


 83%|██████████████████████████████████████████████████████████████████████▋              | 1600/1924 [28:58<44:25,  8.23s/it]

Metrics: {'precision@10': 0.021619424}


 88%|███████████████████████████████████████████████████████████████████████████          | 1699/1924 [30:21<03:09,  1.19it/s]

batch <1699> - loss: 0.30694333881139757 - roc_auc score: 0.9368181228637695


 88%|███████████████████████████████████████████████████████████████████████████          | 1700/1924 [30:47<30:35,  8.19s/it]

Metrics: {'precision@10': 0.021694407}


 94%|███████████████████████████████████████████████████████████████████████████████▍     | 1799/1924 [32:10<01:45,  1.19it/s]

batch <1799> - loss: 0.30645301401615144 - roc_auc score: 0.9364809989929199


 94%|███████████████████████████████████████████████████████████████████████████████▌     | 1800/1924 [32:35<16:55,  8.19s/it]

Metrics: {'precision@10': 0.021682693}


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1899/1924 [33:59<00:20,  1.20it/s]

batch <1899> - loss: 0.3056786611676216 - roc_auc score: 0.9347314834594727


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1900/1924 [34:24<03:16,  8.20s/it]

Metrics: {'precision@10': 0.021768164}


100%|█████████████████████████████████████████████████████████████████████████████████████| 1924/1924 [34:44<00:00,  1.08s/it]


Epoch: 2, Loss: 0.0039


  5%|████▍                                                                                  | 99/1924 [01:23<25:39,  1.19it/s]

batch <99> - loss: 0.2757242733240128 - roc_auc score: 0.9447751045227051


  5%|████▎                                                                               | 100/1924 [01:48<4:08:58,  8.19s/it]

Metrics: {'precision@10': 0.021054385}


 10%|████████▉                                                                             | 199/1924 [03:11<24:06,  1.19it/s]

batch <199> - loss: 0.2742547954618931 - roc_auc score: 0.9421992301940918


 10%|████████▋                                                                           | 200/1924 [03:37<3:55:09,  8.18s/it]

Metrics: {'precision@10': 0.020210586}


 16%|█████████████▎                                                                        | 299/1924 [05:00<22:55,  1.18it/s]

batch <299> - loss: 0.2730060437321663 - roc_auc score: 0.9449090957641602


 16%|█████████████                                                                       | 300/1924 [05:25<3:41:33,  8.19s/it]

Metrics: {'precision@10': 0.020206729}


 21%|█████████████████▊                                                                    | 399/1924 [06:49<21:18,  1.19it/s]

batch <399> - loss: 0.2734584724903107 - roc_auc score: 0.948028564453125


 21%|█████████████████▍                                                                  | 400/1924 [07:14<3:28:00,  8.19s/it]

Metrics: {'precision@10': 0.020201461}


 26%|██████████████████████▎                                                               | 499/1924 [08:37<19:57,  1.19it/s]

batch <499> - loss: 0.27447231650352477 - roc_auc score: 0.9447212219238281


 26%|█████████████████████▊                                                              | 500/1924 [09:03<3:14:40,  8.20s/it]

Metrics: {'precision@10': 0.01977373}


 31%|██████████████████████████▊                                                           | 599/1924 [10:26<18:31,  1.19it/s]

batch <599> - loss: 0.2749359369277954 - roc_auc score: 0.9417791366577148


 31%|██████████████████████████▏                                                         | 600/1924 [10:51<3:00:16,  8.17s/it]

Metrics: {'precision@10': 0.02007963}


 36%|███████████████████████████████▏                                                      | 699/1924 [12:15<17:22,  1.18it/s]

batch <699> - loss: 0.2743777219951153 - roc_auc score: 0.953704833984375


 36%|██████████████████████████████▌                                                     | 700/1924 [12:40<2:46:52,  8.18s/it]

Metrics: {'precision@10': 0.0198727}


 42%|███████████████████████████████████▋                                                  | 799/1924 [14:03<15:50,  1.18it/s]

batch <799> - loss: 0.2737113097310066 - roc_auc score: 0.9498081207275391


 42%|██████████████████████████████████▉                                                 | 800/1924 [14:29<2:33:14,  8.18s/it]

Metrics: {'precision@10': 0.019967675}


 47%|████████████████████████████████████████▏                                             | 899/1924 [15:52<14:18,  1.19it/s]

batch <899> - loss: 0.276831730902195 - roc_auc score: 0.9504861831665039


 47%|███████████████████████████████████████▎                                            | 900/1924 [16:17<2:19:34,  8.18s/it]

Metrics: {'precision@10': 0.019958878}


 52%|████████████████████████████████████████████▋                                         | 999/1924 [17:41<12:59,  1.19it/s]

batch <999> - loss: 0.27489261656999586 - roc_auc score: 0.9494762420654297


 52%|███████████████████████████████████████████▏                                       | 1000/1924 [18:06<2:06:15,  8.20s/it]

Metrics: {'precision@10': 0.019893821}


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:29<12:00,  1.14it/s]

batch <1099> - loss: 0.2744735115766525 - roc_auc score: 0.950343132019043


 57%|███████████████████████████████████████████████▍                                   | 1100/1924 [19:55<1:55:06,  8.38s/it]

Metrics: {'precision@10': 0.019861694}


 62%|████████████████████████████████████████████████████▉                                | 1199/1924 [21:18<10:08,  1.19it/s]

batch <1199> - loss: 0.27575318902730944 - roc_auc score: 0.9501352310180664


 62%|███████████████████████████████████████████████████▊                               | 1200/1924 [21:44<1:38:50,  8.19s/it]

Metrics: {'precision@10': 0.019877123}


 68%|█████████████████████████████████████████████████████████▍                           | 1299/1924 [23:07<08:47,  1.18it/s]

batch <1299> - loss: 0.27525952965021133 - roc_auc score: 0.9499516487121582


 68%|████████████████████████████████████████████████████████                           | 1300/1924 [23:32<1:25:07,  8.19s/it]

Metrics: {'precision@10': 0.019867338}


 73%|█████████████████████████████████████████████████████████████▊                       | 1399/1924 [24:55<07:22,  1.19it/s]

batch <1399> - loss: 0.2771095106005669 - roc_auc score: 0.9516983032226562


 73%|████████████████████████████████████████████████████████████▍                      | 1400/1924 [25:21<1:13:03,  8.37s/it]

Metrics: {'precision@10': 0.019670289}


 78%|██████████████████████████████████████████████████████████████████▏                  | 1499/1924 [26:44<05:57,  1.19it/s]

batch <1499> - loss: 0.2775083470344544 - roc_auc score: 0.9487404823303223


 78%|██████████████████████████████████████████████████████████████████▎                  | 1500/1924 [27:10<57:56,  8.20s/it]

Metrics: {'precision@10': 0.019686565}


 83%|██████████████████████████████████████████████████████████████████████▋              | 1599/1924 [28:33<04:34,  1.19it/s]

batch <1599> - loss: 0.2771294766664505 - roc_auc score: 0.9471845626831055


 83%|██████████████████████████████████████████████████████████████████████▋              | 1600/1924 [28:58<44:10,  8.18s/it]

Metrics: {'precision@10': 0.019784503}


 88%|███████████████████████████████████████████████████████████████████████████          | 1699/1924 [30:22<03:09,  1.19it/s]

batch <1699> - loss: 0.27724218636751174 - roc_auc score: 0.9469184875488281


 88%|███████████████████████████████████████████████████████████████████████████          | 1700/1924 [30:47<30:29,  8.17s/it]

Metrics: {'precision@10': 0.019693527}


 94%|███████████████████████████████████████████████████████████████████████████████▍     | 1799/1924 [32:10<01:45,  1.19it/s]

batch <1799> - loss: 0.27736544132232666 - roc_auc score: 0.9450736045837402


 94%|███████████████████████████████████████████████████████████████████████████████▌     | 1800/1924 [32:35<16:55,  8.19s/it]

Metrics: {'precision@10': 0.020046795}


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1899/1924 [33:59<00:21,  1.19it/s]

batch <1899> - loss: 0.2774539875984192 - roc_auc score: 0.9469079971313477


 99%|███████████████████████████████████████████████████████████████████████████████████▉ | 1900/1924 [34:24<03:16,  8.17s/it]

Metrics: {'precision@10': 0.019782856}


100%|█████████████████████████████████████████████████████████████████████████████████████| 1924/1924 [34:44<00:00,  1.08s/it]


Epoch: 3, Loss: 0.0034


  5%|████▍                                                                                  | 99/1924 [01:23<25:34,  1.19it/s]

batch <99> - loss: 0.23860626950860023 - roc_auc score: 0.9631361961364746


  5%|████▎                                                                               | 100/1924 [01:48<4:08:42,  8.18s/it]

Metrics: {'precision@10': 0.018279891}


 10%|████████▉                                                                             | 199/1924 [03:11<24:09,  1.19it/s]

batch <199> - loss: 0.23737020969390868 - roc_auc score: 0.9656662940979004


 10%|████████▋                                                                           | 200/1924 [03:37<3:55:24,  8.19s/it]

Metrics: {'precision@10': 0.017663814}


 16%|█████████████▎                                                                        | 299/1924 [05:00<22:46,  1.19it/s]

batch <299> - loss: 0.2378930266201496 - roc_auc score: 0.9631562232971191


 16%|█████████████                                                                       | 300/1924 [05:25<3:41:29,  8.18s/it]

Metrics: {'precision@10': 0.01805095}


 21%|█████████████████▊                                                                    | 399/1924 [06:48<21:27,  1.18it/s]

batch <399> - loss: 0.23922433614730834 - roc_auc score: 0.959355354309082


 21%|█████████████████▍                                                                  | 400/1924 [07:14<3:28:11,  8.20s/it]

Metrics: {'precision@10': 0.018091591}


 26%|██████████████████████▎                                                               | 499/1924 [08:37<19:52,  1.19it/s]

batch <499> - loss: 0.2389494889974594 - roc_auc score: 0.9662919044494629


 26%|█████████████████████▊                                                              | 500/1924 [09:02<3:14:17,  8.19s/it]

Metrics: {'precision@10': 0.018026395}


 31%|██████████████████████████▊                                                           | 599/1924 [10:26<18:26,  1.20it/s]

batch <599> - loss: 0.24185117214918136 - roc_auc score: 0.9594416618347168


 31%|██████████████████████████▏                                                         | 600/1924 [10:51<3:00:18,  8.17s/it]

Metrics: {'precision@10': 0.017955882}


 36%|███████████████████████████████▏                                                      | 699/1924 [12:14<17:08,  1.19it/s]

batch <699> - loss: 0.24073730915784836 - roc_auc score: 0.9562058448791504


 36%|██████████████████████████████▌                                                     | 700/1924 [12:39<2:46:53,  8.18s/it]

Metrics: {'precision@10': 0.017986082}


 42%|███████████████████████████████████▋                                                  | 799/1924 [14:02<15:58,  1.17it/s]

batch <799> - loss: 0.24010149165987968 - roc_auc score: 0.9582414627075195


 42%|██████████████████████████████████▉                                                 | 800/1924 [14:28<2:33:49,  8.21s/it]

Metrics: {'precision@10': 0.018027946}


 47%|████████████████████████████████████████▏                                             | 899/1924 [15:51<14:25,  1.18it/s]

batch <899> - loss: 0.24280113965272904 - roc_auc score: 0.9608922004699707


 47%|███████████████████████████████████████▎                                            | 900/1924 [16:16<2:19:38,  8.18s/it]

Metrics: {'precision@10': 0.017976392}


 52%|████████████████████████████████████████████▋                                         | 999/1924 [17:40<13:06,  1.18it/s]

batch <999> - loss: 0.24361703917384148 - roc_auc score: 0.9600071907043457


 52%|███████████████████████████████████████████▏                                       | 1000/1924 [18:05<2:05:55,  8.18s/it]

Metrics: {'precision@10': 0.017870741}


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:29<11:47,  1.17it/s]

batch <1099> - loss: 0.24217766240239144 - roc_auc score: 0.9645195007324219


 57%|████████████████████████████████████████████████▌                                    | 1099/1924 [19:49<14:52,  1.08s/it]


KeyboardInterrupt: 

In [None]:
p2 = retrieval.RetrievalPrecision(top_k=5)
r2 = retrieval.RetrievalRecall(top_k=2)
ndcg = retrieval.RetrievalNormalizedDCG(top_k=7)

In [None]:
preds = torch.tensor([0.7, 0.8, 0.1, 0.2, 0.4, 0.6, 0.5, 0.9, 0.3, 0.15])
targets = torch.tensor([True, True, False, False, False, False, False, False, False, False])
indices = torch.tensor([0,0,0,0,0,0,0,0,0,0])

mask = torch.ones(preds.shape, dtype=torch.bool)
mask[[4, 5, 7]] = False
print(preds[mask])
print(targets[mask])

In [None]:
preds = torch.tensor([1.0, 0.85, 0.8, 0.7, 0.65])
targets = torch.tensor([False, False, False, False, False])
indices = torch.tensor([0,0,0,0,0])

In [None]:
p2(preds, targets, indexes=indices)

In [None]:
r2(preds, targets, indexes=indices)

In [None]:
ndcg(preds, targets, indexes=indices)

In [None]:
def dcg(rel):
    g = 0.
    for i in range(1,6):
        g+= (2**rel[i-1] - 1)/np.log2(i+1)
    return g

In [None]:
rel = [1,0,1,0,1]
rel_idcg = [1,1,1,0,0]
print(dcg(rel))
print(dcg(rel_idcg))
print(dcg(rel)/dcg(rel_idcg))

In [None]:
torch.tensor([0.9, 0.7, 0.6])

In [None]:
(preds[i][:3] / torch.log2(torch.arange(3)+2)).sum() / (torch.tensor([0.9, 0.7, 0.6]) / torch.log2(torch.arange(3)+2)).sum()

In [None]:
torch.log2(torch.arange(3)+2)

In [None]:
cm = confusion_matrix(y_true.detach().cpu().numpy(), y_pred.detach().cpu().numpy().round())
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=[False, True])
cm_display.plot()
plt.show()

In [25]:
def save_model(model, path):
    torch.save(model.state_dict(), path)
    
def load_model(path):
    model = Model(hidden_channels=32, out_channels=32, metadata=train_data.metadata())
    model.load_state_dict(torch.load(path))
    model = model.to(device)
    return model

In [26]:
save_model(model, "models/gnn_03.pth")

In [None]:
model = load_model("models/gnn_02.pth")

In [None]:
print(nn.summary(model, next(iter(train_loader)).to(device)))

In [None]:
model