# Model training
---

## Import

In [33]:
import torch
from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader
from torch.functional import F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.loader import NeighborSampler, NeighborLoader
from torch_geometric.data import Data
import torch.optim as optim
import seaborn as sns
from models import *
import tqdm
from umap import UMAP

np.random.seed(0)

## Read, understand and prepare the data

In [22]:
# node_data = pd.read_parquet('data/amazon_product_data_word2vec.parquet')
# Load the serialized graph object from disk.
data = torch.load('data/amazon_product_data_concat.pt')
# Number of classes = number of distinct label values present in `y`.
data.num_classes = data.y.unique().numel()

### Main info

In [23]:
# Collect the headline statistics of the graph first, then print one per line.
main_info = [
    ("data", data),
    ("num nodes", data.num_nodes),
    ("Num edges", data.num_edges),
    ("num node features", data.num_node_features),
    ("is undirected", data.is_undirected()),
    ("is directed", data.is_directed()),
    ("num edge features", data.num_edge_features),
    ('num classes', data.num_classes),
]
for info_label, info_value in main_info:
    print(info_label, info_value)

data Data(x=[729819, 1200], edge_index=[2, 680548], y=[729819], num_classes=10)
num nodes 729819
Num edges 680548
num node features 1200
is undirected False
is directed True
num edge features 0
num classes 10


In [24]:
# value_counts = node_data['main_category'].value_counts()

# # plot a bar chart of the main categories
# plt.figure(figsize=(15, 6))
# plt.bar(value_counts.index, value_counts.values)
# plt.xticks(rotation=90)
# plt.title('Main Category Distribution')
# plt.show()

### Visualization

In [32]:
def visualize(h, color, epoch=None, loss=None):
    """
    Project embeddings to 2-D with UMAP and show them as a scatter plot.

    :param h: Tensor of embeddings, one row per point; it is detached and moved
        to the CPU before being handed to UMAP
    :param color: Per-point values used to colour the scatter (mapped through the
        "Set2" colormap); may be a tensor on any device or any value matplotlib accepts
    :param epoch: Optional epoch number shown in the title
    :param loss: Optional loss value shown in the title with four decimals
    """
    z = UMAP(n_components=2).fit_transform(h.detach().cpu().numpy())

    # matplotlib cannot read tensors that live on an accelerator, so bring
    # tensor colours to host memory before plotting.
    if torch.is_tensor(color):
        color = color.detach().cpu().numpy()

    plt.figure(figsize=(5, 5))
    plt.xticks([])
    plt.yticks([])

    title_text = 'Epoch: {}'.format(epoch) if epoch is not None else ''
    loss_text = 'Loss: {:.4f}'.format(loss) if loss is not None else ''
    plt.title('Embedding visualization' + ' ' + title_text + ' ' + loss_text)
    plt.scatter(z[:, 0], z[:, 1], s=70, c=color, cmap="Set2")
    plt.show()

In [26]:
# from torch_geometric.utils import to_networkx

# G = to_networkx(data, to_undirected=True)
# visualize_graph(G, color=data.y)

### Preparation for model training

In [27]:
def create_masks(data, train_ratio, val_ratio):
    """
    Attach random boolean train/val/test node masks to ``data``.

    The node ids are shuffled with NumPy's global RNG, then cut into three
    consecutive slices: the first ``int(train_ratio * num_nodes)`` ids go to
    training, the next ``int(val_ratio * num_nodes)`` to validation, and the
    remainder to test.

    :param data: Object exposing ``num_nodes``; it receives the attributes
        ``train_mask``, ``val_mask`` and ``test_mask``
    :param train_ratio: Fraction of the nodes assigned to the training mask
    :param val_ratio: Fraction of the nodes assigned to the validation mask

    :return: None, ``data`` is modified in place
    """
    n = data.num_nodes
    order = list(range(n))
    np.random.shuffle(order)

    # Slice boundaries inside the shuffled ordering.
    n_train = int(train_ratio * n)
    n_val = int(val_ratio * n)
    first_test = n_train + n_val

    partitions = (
        ("train_mask", order[:n_train]),
        ("val_mask", order[n_train:first_test]),
        ("test_mask", order[first_test:]),
    )

    # Build every mask before touching `data`, then attach them.
    masks = {}
    for attr_name, members in partitions:
        mask = torch.zeros(n, dtype=torch.bool)
        mask[members] = True
        masks[attr_name] = mask

    for attr_name, mask in masks.items():
        setattr(data, attr_name, mask)


def train_test_split_graph(data: Data, train_ratio: float, val_ratio: float, batch_size: int, num_neighbors=None):
    """
    Split the graph data into train, validation, and test sets
    :param data: The graph data; train/val/test masks are attached to it in place
    :param train_ratio: The ratio of the training set
    :param val_ratio: The ratio of the validation set
    :param batch_size: The number of seed nodes per mini-batch
    :param num_neighbors: Neighbours sampled per hop, one entry per hop;
        defaults to ``[30, 30]``

    :return: The train, validation, and test neighbor loaders
    """
    if num_neighbors is None:
        num_neighbors = [30] * 2

    # Call the function to create masks
    create_masks(data, train_ratio, val_ratio)

    def create_neighbor_loader(mask, shuffle=False):
        # Seed-node indices of this split, taken from its boolean mask.
        indices = mask.nonzero(as_tuple=False).view(-1)
        return NeighborLoader(
            data,
            num_neighbors=num_neighbors,
            batch_size=batch_size,
            input_nodes=indices,
            shuffle=shuffle,
        )

    # Mask indices come out in ascending node order, so only the training
    # loader is shuffled; evaluation order does not matter.
    train_loader = create_neighbor_loader(data.train_mask, shuffle=True)
    val_loader = create_neighbor_loader(data.val_mask)
    test_loader = create_neighbor_loader(data.test_mask)

    return train_loader, val_loader, test_loader

## Model training

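### Models

The model classes used below (e.g. `GCN`) are not defined in this notebook; they are presumably provided by the local `models` module imported at the top.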

### Training functions

In [30]:
from sklearn.metrics import f1_score, balanced_accuracy_score
import numpy as np
import torch

# Training epoch
def train_epoch(model, optimizer, loss_fn, train_loader: NeighborLoader, device, print_every, epoch=0):
    """
    Run one optimisation pass over ``train_loader`` and plot the last batch's embeddings.

    :param model: Module called as ``model(x, edge_index)`` and returning ``(out, h)``
    :param optimizer: Optimizer stepped once per batch
    :param loss_fn: Callable ``loss_fn(out, y)`` returning a scalar loss tensor
    :param train_loader: Loader yielding sampled mini-batches
    :param device: Device each batch is moved to
    :param print_every: Currently unused here; kept so existing callers keep working
    :param epoch: Epoch number, only forwarded to the plot title

    :return: Tuple of per-batch means (loss, accuracy, weighted F1, balanced accuracy)
    :raises ValueError: If ``train_loader`` yields no batches
    """
    model.train()
    b_losses = []
    b_accuracies = []
    b_f1_scores = []
    b_balanced_accuracies = []
    h = None
    batch = None

    for batch in train_loader:
        optimizer.zero_grad()
        batch = batch.to(device)

        out, h = model(batch.x, batch.edge_index)

        # A NeighborLoader batch lists its seed nodes first, followed by the
        # sampled neighbours. Only the seed nodes belong to this split, so the
        # loss and metrics must ignore the neighbours (which may be val/test nodes).
        num_seed = getattr(batch, "batch_size", None)
        if num_seed is None:
            num_seed = out.size(0)
        seed_out = out[:num_seed]
        seed_y = batch.y[:num_seed]

        loss = loss_fn(seed_out, seed_y)
        loss.backward()
        optimizer.step()

        b_losses.append(loss.item())
        preds = seed_out.argmax(dim=1).cpu().numpy()
        labels = seed_y.cpu().numpy()
        b_acc = (preds == labels).sum() / len(labels)
        b_accuracies.append(b_acc)

        # Compute F1-score and balanced accuracy
        b_f1 = f1_score(labels, preds, average='weighted')
        b_bal_acc = balanced_accuracy_score(labels, preds)
        b_f1_scores.append(b_f1)
        b_balanced_accuracies.append(b_bal_acc)

    if not b_losses:
        raise ValueError("train_loader yielded no batches")

    # Labels are moved to the CPU because matplotlib cannot read device tensors.
    visualize(h, batch.y.detach().cpu(), epoch, loss=np.mean(b_losses))

    return (
        np.mean(b_losses),
        np.mean(b_accuracies),
        np.mean(b_f1_scores),
        np.mean(b_balanced_accuracies),
    )

# Validation
@torch.no_grad()
def validate(model, loss_fn, val_loader, device):
    """
    Evaluate ``model`` on every batch of ``val_loader`` without tracking gradients.

    :param model: Module called as ``model(x, edge_index)`` and returning ``(out, h)``
    :param loss_fn: Callable ``loss_fn(out, y)`` returning a scalar loss tensor
    :param val_loader: Loader yielding sampled mini-batches
    :param device: Device each batch is moved to

    :return: Tuple (mean per-batch loss, accuracy, weighted F1, balanced accuracy);
        the last three are computed over all evaluated nodes at once
    :raises ValueError: If ``val_loader`` yields no batches
    """
    model.eval()
    total_correct = 0
    total_loss = 0
    total_samples = 0
    num_batches = 0
    all_preds = []
    all_labels = []

    for batch in val_loader:
        batch = batch.to(device)
        out, _ = model(batch.x, batch.edge_index)

        # Score only the seed nodes, which a NeighborLoader batch lists first;
        # the sampled neighbours after them are not part of this split.
        num_seed = getattr(batch, "batch_size", None)
        if num_seed is None:
            num_seed = out.size(0)
        seed_out = out[:num_seed]
        seed_y = batch.y[:num_seed]

        loss = loss_fn(seed_out, seed_y)
        total_loss += loss.item()
        num_batches += 1
        preds = seed_out.argmax(dim=1).cpu().numpy()
        labels = seed_y.cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels)
        total_correct += (preds == labels).sum()
        total_samples += len(labels)

    if num_batches == 0 or total_samples == 0:
        raise ValueError("val_loader yielded no samples")

    val_acc = total_correct / total_samples
    val_f1 = f1_score(all_labels, all_preds, average='weighted')
    val_bal_acc = balanced_accuracy_score(all_labels, all_preds)

    # Count the batches actually seen instead of relying on len(val_loader).
    return total_loss / num_batches, val_acc, val_f1, val_bal_acc

# Training loop
def training_loop(model, optimizer, loss_fn, train_loader, val_loader, num_epochs, print_every):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each one.

    :param model: The model to train; it is moved to CUDA when available, else CPU
    :param optimizer: Optimizer over the model parameters
    :param loss_fn: Loss callable forwarded to the train/validation steps
    :param train_loader: Loader for the training batches
    :param val_loader: Loader for the validation batches
    :param num_epochs: Number of epochs to run
    :param print_every: Log the metrics every this many epochs (values below 1
        are treated as 1); the final epoch is always logged

    :return: Tuple (model, train_losses, train_accs, train_f1s, train_bal_accs,
        val_losses, val_accs, val_f1s, val_bal_accs), each list holding one entry per epoch
    """
    print("Starting training")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    train_losses, train_accs, train_f1s, train_bal_accs = [], [], [], []
    val_losses, val_accs, val_f1s, val_bal_accs = [], [], [], []

    # Guard against 0/None so the modulo below is always well defined.
    report_every = max(1, int(print_every or 1))

    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc, train_f1, train_bal_acc = train_epoch(
            model, optimizer, loss_fn, train_loader, device, print_every, epoch
        )
        val_loss, val_acc, val_f1, val_bal_acc = validate(model, loss_fn, val_loader, device)

        if epoch % report_every == 0 or epoch == num_epochs:
            print(
                f"Epoch {epoch}/{num_epochs}: "
                f"Train loss: {train_loss:.3f}, Train acc.: {train_acc:.3f}, Train F1: {train_f1:.3f}, Train Bal. Acc.: {train_bal_acc:.3f}, "
                f"Val. loss: {val_loss:.3f}, Val. acc.: {val_acc:.3f}, Val. F1: {val_f1:.3f}, Val. Bal. Acc.: {val_bal_acc:.3f}"
            )

        # History is recorded every epoch regardless of the logging cadence.
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        train_f1s.append(train_f1)
        train_bal_accs.append(train_bal_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        val_f1s.append(val_f1)
        val_bal_accs.append(val_bal_acc)

    return model, train_losses, train_accs, train_f1s, train_bal_accs, val_losses, val_accs, val_f1s, val_bal_accs

### Actual training

In [31]:
# Hyperparameters and set-up for the training run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
train_ratio = 0.8
val_ratio = 0.1

learning_rate = 0.01
weight_decay = 5e-4
batch_size = 64


train_loader, val_loader, test_loader = train_test_split_graph(data, train_ratio, val_ratio, batch_size)

# Initialize the model and optimizer
# model = GAT(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes, num_heads=8).to(device)
model = GCN(in_channels=data.num_node_features, hidden_channels=64, out_channels=data.num_classes).to(device)
# Pass the hyperparameters declared above; previously the learning rate was
# duplicated as a literal and weight_decay was never applied.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the model
import warnings
# NOTE(review): this silences every warning for the rest of the session.
warnings.filterwarnings("ignore")
# training_loop returns the model plus eight per-epoch metric histories.
(
    model,
    train_losses, train_accs, train_f1s, train_bal_accs,
    val_losses, val_accs, val_f1s, val_bal_accs,
) = training_loop(
  model, optimizer, loss_fn, train_loader, val_loader, num_epochs=10, print_every=1
)

Starting training


TypeError: visualize() got an unexpected keyword argument 'loss'

## Evaluation