# Model train
---

## Import

In [14]:
import torch
from torch_geometric.datasets import Reddit, Amazon
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
from torch.functional import F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import NeighborSampler
from torch_geometric.data import Data
import torch.optim as optim

## Read and prepare the data

In [15]:
node_data = pd.read_parquet('data/amazon_product_data_word2vec.parquet')

In [16]:
data = torch.load('data/amazon_product_data.pt')
loader = DataLoader(data, batch_size=32, shuffle=True)

  data = torch.load('data/amazon_product_data.pt')


data Data(x=[863130, 100], edge_index=[2, 815222], y=[863130])
num nodes 863130
Num edges 815222
num features 100
is undirected False
is directed True


In [22]:
print("data", data)
print("num nodes", data.num_nodes)
print("Num edges", data.num_edges)
print("num node features", data.num_node_features)
print("is undirected", data.is_undirected())
print("is directed", data.is_directed())

data Data(x=[863130, 100], edge_index=[2, 815222], y=[863130])
num nodes 863130
Num edges 815222
num node features 100
is undirected False
is directed True


In [24]:
def create_data_split_masks(data, train_ratio=0.8, val_ratio=0.1):
  # Total number of nodes
  num_nodes = data.num_nodes

  # Randomly permute the node indices
  perm = torch.randperm(num_nodes)

  # Calculate split sizes
  train_size = int(train_ratio * num_nodes)
  val_size = int(val_ratio * num_nodes)
  test_size = num_nodes - train_size - val_size

  # Create masks for train, validation, and test
  train_mask = torch.zeros(num_nodes, dtype=torch.bool)
  val_mask = torch.zeros(num_nodes, dtype=torch.bool)
  test_mask = torch.zeros(num_nodes, dtype=torch.bool)

  # Assign masks
  train_mask[perm[:train_size]] = True
  val_mask[perm[train_size:train_size + val_size]] = True
  test_mask[perm[train_size + val_size:]] = True

  # Assign masks to the data object
  data.train_mask = train_mask
  data.val_mask = val_mask
  data.test_mask = test_mask
  
  return data

data = create_data_split_masks(data)
    

## Model training

In [17]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels)
        self.conv2 = GATConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)

In [13]:
def train_epoch(model, optimizer, loader, device):
    """Train the model for one epoch"""
    model.train()
    
    b_losses = np.empty(len(loader))
    b_accuracies = np.empty(len(loader))

    for batch in loader:
      batch.to(device)
      optimizer.zero_grad()
      out = model(batch.x, batch.edge_index)
      loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])
      loss.backward()
      optimizer.step()
      
      b_losses.append(loss.item())
      
      b_acc = (out[batch.train_mask].argmax(dim=1) == batch.y[batch.train_mask]).sum().item() / batch.train_mask.sum().item()
      b_accuracies.append(b_acc / batch.train_mask.sum().item())
      
    loss = np.mean(b_losses)
    acc = np.mean(b_accuracies)
    return loss, acc


@torch.no_grad()
def evaluate(model, data_loader, device, data):
    """Evaluate the model on the validation or test set."""
    model.eval()
    total_correct = 0

    for batch_size, n_id, adjs in data_loader:
        adjs = [adj.to(device) for adj in adjs]

        # Compute predictions for all nodes in the mini-batch
        x_input = data.x[n_id].to(device)
        out = model(x_input, adjs[0].edge_index)

        # Get predictions only for the first `batch_size` target nodes
        pred = out[:batch_size].max(dim=1)[1]

        # Compare with the actual labels of the target nodes
        total_correct += (pred ==
                          data.y[n_id[:batch_size]].to(device)).sum().item()

    # Calculate accuracy based on the number of target nodes in the entire set
    return total_correct / len(data_loader.dataset)

KeyError: 4