In [2]:
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable


from operator import itemgetter
import math

## Layers

### Inter Relation Aggregator using GNN



In [3]:
class InterAgg_leaky(nn.Module):
    """
    Inter-relation aggregator of the CARE-GNN layer.

    For a batch of nodes it scores the batch nodes and their 1-hop neighbors with a
    linear label predictor, lets the three intra-relation aggregators filter and
    aggregate the neighbors of each relation, combines the per-relation embeddings
    with ``threshold_inter_agg`` and, in training mode, updates the filtering
    thresholds through ``RLModule``.
    """

    def __init__(self, features, feature_dim,
                 embed_dim, adj_lists, intraggs,
                 inter='GNN', step_size=0.02, cuda=True):
        """
        Initialize the inter-relation aggregator
        :param features: the input node features or embeddings for all nodes
        :param feature_dim: the input dimension
        :param embed_dim: the output dimension
        :param adj_lists: a list of adjacency lists for each single-relation graph
        :param intraggs: the intra-relation aggregators used by each single-relation graph
        :param inter: the aggregator type; only 'GNN' is implemented here, any other
                      value makes ``forward`` raise NotImplementedError
        :param step_size: the RL action step size
        :param cuda: whether to use GPU
        """
        super(InterAgg_leaky, self).__init__()

        self.features = features
        self.dropout = 0.6
        self.adj_lists = adj_lists
        self.intra_agg1 = intraggs[0]
        self.intra_agg2 = intraggs[1]
        self.intra_agg3 = intraggs[2]
        self.embed_dim = embed_dim
        self.feat_dim = feature_dim
        self.inter = inter
        self.step_size = step_size
        # NOTE: this boolean attribute shadows nn.Module.cuda() on the instance
        self.cuda = cuda
        self.intra_agg1.cuda = cuda
        self.intra_agg2.cuda = cuda
        self.intra_agg3.cuda = cuda

        # RL condition flag
        self.RL = True

        # number of batches for current epoch, assigned during training
        self.batch_num = 0

        # initial filtering thresholds
        self.thresholds = [0.5, 0.5, 0.5]

        # the activation function used by attention mechanism
        self.leakyrelu = nn.LeakyReLU(0.2)

        # parameter used to transform node embeddings before inter-relation aggregation
        self.weight = nn.Parameter(torch.FloatTensor(self.feat_dim, self.embed_dim))
        init.xavier_uniform_(self.weight)

        # weight parameter for each relation used by CARE-Weight
        self.alpha = nn.Parameter(torch.FloatTensor(self.embed_dim, 3))
        init.xavier_uniform_(self.alpha)

        # parameters used by attention layer
        self.a = nn.Parameter(torch.FloatTensor(2 * self.embed_dim, 1))
        init.xavier_uniform_(self.a)

        # label predictor for similarity measure
        self.label_clf = nn.Linear(self.feat_dim, 2)

        # initialize the parameter logs
        self.weights_log = []
        self.thresholds_log = [self.thresholds]
        self.relation_score_log = []

    def _index_tensor(self, node_ids):
        """Build a LongTensor of node ids, moved to the GPU when ``self.cuda`` is set."""
        index = torch.LongTensor([int(node_id) for node_id in node_ids])
        return index.cuda() if self.cuda else index

    @staticmethod
    def _gather_scores(batch_scores, id_mapping, node_ids):
        """
        Select the rows of ``batch_scores`` that belong to ``node_ids``.

        The result is always 2-D (len(node_ids), 2), also for a single id; the previous
        ``itemgetter(*ids)`` lookup returned a scalar index in that case and collapsed
        the result to 1-D.
        """
        rows = torch.tensor([id_mapping[node_id] for node_id in node_ids],
                            dtype=torch.long, device=batch_scores.device)
        return batch_scores[rows, :]

    def forward(self, nodes, labels, train_flag=True):
        """
        :param nodes: a list of batch node ids
        :param labels: a list of batch node labels, only used by the RLModule
        :param train_flag: indicates whether in training or testing mode
        :return combined: the embeddings of a batch of input node features
        :return center_scores: the label-aware scores of batch nodes
        :raises NotImplementedError: if ``self.inter`` is not 'GNN'
        """

        # Only the CARE-GNN (threshold) aggregator exists in this notebook. Fail early with
        # a clear message instead of hitting an unbound ``combined`` at the end.
        if self.inter != 'GNN':
            raise NotImplementedError(
                f"inter-relation aggregator '{self.inter}' is not implemented; only 'GNN' is available")

        batch_nodes = [int(node) for node in nodes]

        # extract 1-hop neighbor ids from adj lists of each single-relation graph
        to_neighs = [[set(adj_list[node]) for node in batch_nodes] for adj_list in self.adj_lists]

        # find unique nodes and their neighbors used in current batch
        unique_nodes = set(batch_nodes)
        for relation_neighs in to_neighs:
            unique_nodes.update(*relation_neighs)
        unique_nodes = list(unique_nodes)

        # calculate label-aware scores
        batch_features = self.features(self._index_tensor(unique_nodes))
        batch_scores = self.label_clf(batch_features)
        id_mapping = {node_id: row for row, node_id in enumerate(unique_nodes)}

        # the label-aware scores for current batch of nodes
        center_scores = self._gather_scores(batch_scores, id_mapping, batch_nodes)

        # get neighbor node id list for each batch node and relation
        r1_list = [list(to_neigh) for to_neigh in to_neighs[0]]
        r2_list = [list(to_neigh) for to_neigh in to_neighs[1]]
        r3_list = [list(to_neigh) for to_neigh in to_neighs[2]]

        # assign label-aware scores to neighbor nodes for each batch node and relation
        r1_scores = [self._gather_scores(batch_scores, id_mapping, to_neigh).view(-1, 2) for to_neigh in r1_list]
        r2_scores = [self._gather_scores(batch_scores, id_mapping, to_neigh).view(-1, 2) for to_neigh in r2_list]
        r3_scores = [self._gather_scores(batch_scores, id_mapping, to_neigh).view(-1, 2) for to_neigh in r3_list]

        # count the number of neighbors kept for aggregation for each batch node and relation
        r1_sample_num_list = [math.ceil(len(neighs) * self.thresholds[0]) for neighs in r1_list]
        r2_sample_num_list = [math.ceil(len(neighs) * self.thresholds[1]) for neighs in r2_list]
        r3_sample_num_list = [math.ceil(len(neighs) * self.thresholds[2]) for neighs in r3_list]

        # intra-aggregation steps for each relation
        # Eq. (8) in the paper
        r1_feats, r1_scores = self.intra_agg1.forward(nodes, r1_list, center_scores, r1_scores, r1_sample_num_list)
        r2_feats, r2_scores = self.intra_agg2.forward(nodes, r2_list, center_scores, r2_scores, r2_sample_num_list)
        r3_feats, r3_scores = self.intra_agg3.forward(nodes, r3_list, center_scores, r3_scores, r3_sample_num_list)

        # concat the intra-aggregated embeddings from each relation
        neigh_feats = torch.cat((r1_feats, r2_feats, r3_feats), dim=0)

        # get features or embeddings for batch nodes (index follows self.cuda for any sequence type)
        self_feats = self.features(self._index_tensor(batch_nodes))

        # number of nodes in a batch
        n = len(batch_nodes)

        # inter-relation aggregation step, Eq. (9) in the paper
        # CARE-GNN inter-relation aggregator: thresholds act as relation weights
        combined = threshold_inter_agg(len(self.adj_lists), self_feats, neigh_feats, self.embed_dim,
                                       self.weight, self.thresholds, n, self.cuda)

        # the reinforcement learning module
        if self.RL and train_flag:
            relation_scores, rewards, thresholds, stop_flag = RLModule([r1_scores, r2_scores, r3_scores],
                                                                       self.relation_score_log, labels, self.thresholds,
                                                                       self.batch_num, self.step_size)
            self.thresholds = thresholds
            self.RL = stop_flag
            self.relation_score_log.append(relation_scores)
            self.thresholds_log.append(self.thresholds)

        return combined, center_scores

### Intra-Relation Aggregator with LeakyReLU (Proposed)



In [4]:
class IntraAgg_leaky(nn.Module):
    """
    Intra-relation aggregator: filters the neighbors of each batch node under one
    relation and mean-aggregates the features of the kept neighbors, followed by a
    LeakyReLU activation (negative slope 0.2).
    """

    def __init__(self, features, feat_dim, cuda=False):
        """
        Initialize the intra-relation aggregator
        :param features: the input node features or embeddings for all nodes
        :param feat_dim: the input dimension
        :param cuda: whether to use GPU
        """
        super(IntraAgg_leaky, self).__init__()

        self.features = features
        # NOTE: this boolean attribute shadows nn.Module.cuda() on the instance
        self.cuda = cuda
        self.feat_dim = feat_dim

    def forward(self, nodes, to_neighs_list, batch_scores, neigh_scores, sample_list):
        """
        Code partially from https://github.com/williamleif/graphsage-simple/
        :param nodes: list of nodes in a batch (not read by this method)
        :param to_neighs_list: neighbor node id list for each batch node in one relation
        :param batch_scores: the label-aware scores of batch nodes
        :param neigh_scores: the label-aware scores 1-hop neighbors each batch node in one relation
        :param sample_list: the number of neighbors kept for each batch node in one relation
        :return to_feats: the aggregated embeddings of batch nodes neighbors in one relation;
                          a node with no kept neighbor gets an all-zero row
        :return samp_scores: the average neighbor distances for each relation after filtering
        """

        # filter neighbors under given relation
        samp_neighs, samp_scores = filter_neighs_ada_threshold(batch_scores, neigh_scores, to_neighs_list, sample_list)

        # find the unique nodes among the filtered neighbors
        unique_nodes_list = list(set().union(*samp_neighs))
        unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}

        # intra-relation aggregation only with sampled neighbors:
        # mask[i, j] == 1 iff unique node j is a kept neighbor of batch node i
        mask = torch.zeros(len(samp_neighs), len(unique_nodes_list))
        column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
        row_indices = [i for i, samp_neigh in enumerate(samp_neighs) for _ in samp_neigh]
        if row_indices:
            mask[row_indices, column_indices] = 1
        if self.cuda:
            mask = mask.cuda()

        # row-normalize to take the mean; clamp so that a row without neighbors divides
        # by 1 (giving zeros) instead of 0 (giving NaN)
        num_neigh = mask.sum(1, keepdim=True).clamp(min=1)
        mask = mask.div(num_neigh)

        index = torch.LongTensor(unique_nodes_list)
        if self.cuda:
            index = index.cuda()
        embed_matrix = self.features(index)

        to_feats = mask.mm(embed_matrix)
        # Proposed LeakyReLU activation (replaces ReLU), Eq. (8)
        to_feats = F.leaky_relu(to_feats, negative_slope=0.2)
        return to_feats, samp_scores


### Reinforcement Learning Module 

In [5]:
def RLModule(scores, scores_log, labels, thresholds, batch_num, step_size):
    """
    The reinforcement learning module.
    It updates the neighbor filtering threshold for each relation based
    on the average neighbor distances between two consecutive epochs.
    :param scores: the neighbor nodes label-aware scores for each relation
    :param scores_log: a list stores the relation average distances for each batch
    :param labels: the batch node labels (tensor) used to select positive nodes
    :param thresholds: the current neighbor filtering thresholds for each relation
    :param batch_num: number of batches in an epoch; if it is not positive the
                      thresholds are left unchanged
    :param step_size: the RL action step size
    :return relation_scores: the relation average distances for current batch
                             (0.0 for a relation when the batch has no positive neighbors)
    :return rewards: the reward for given thresholds in current epoch
    :return new_thresholds: the new filtering thresholds updated according to the rewards
    :return stop_flag: the RL terminal condition flag (always True: no terminal
                       condition is implemented)
    """

    relation_scores = []
    stop_flag = True

    # only compute the average neighbor distances for positive nodes
    pos_index = [i[0] for i in (labels == 1).nonzero().tolist()]

    # compute average neighbor distances for each relation
    for score in scores:
        # explicit indexing instead of itemgetter(*pos_index): works for zero, one or many positives
        pos_scores = [score[i] for i in pos_index]
        neigh_count = sum(1 if isinstance(s, float) else len(s) for s in pos_scores)
        pos_sum = [s if isinstance(s, float) else sum(s) for s in pos_scores]
        # a batch without positive nodes (or neighbors) has no distance to average
        relation_scores.append(sum(pos_sum) / neigh_count if neigh_count else 0.0)

    # batch_num <= 0 means the caller has not set the epoch size yet; the modulo below
    # would otherwise raise ZeroDivisionError
    if batch_num <= 0 or len(scores_log) % batch_num != 0 or len(scores_log) < 2 * batch_num:
        # do not call RL module within the epoch or within the first two epochs
        rewards = [0, 0, 0]
        new_thresholds = thresholds
    else:
        # update thresholds according to average scores in last epoch
        # Eq.(5) in the paper
        previous_epoch_scores = [sum(s) / batch_num for s in zip(*scores_log[-2 * batch_num:-batch_num])]
        current_epoch_scores = [sum(s) / batch_num for s in zip(*scores_log[-batch_num:])]

        # compute reward for each relation and update the thresholds according to reward
        # Eq. (6) in the paper
        rewards = [1 if previous_epoch_scores[i] - s >= 0 else -1 for i, s in enumerate(current_epoch_scores)]
        new_thresholds = [thresholds[i] + step_size if r == 1 else thresholds[i] - step_size for i, r in enumerate(rewards)]

        # avoid overflow
        new_thresholds = [0.999 if i > 1 else i for i in new_thresholds]
        new_thresholds = [0.001 if i < 0 else i for i in new_thresholds]

        print(f'epoch scores: {current_epoch_scores}')
        print(f'rewards: {rewards}')
        print(f'thresholds: {new_thresholds}')

    return relation_scores, rewards, new_thresholds, stop_flag

### Filter neighbors from label predictor

In [6]:
def filter_neighs_ada_threshold(center_scores, neigh_scores, neighs_list, sample_list):
    """
    Keep, for every batch node, the neighbors whose label-aware score is closest
    to the node's own score (adaptive-threshold top-p filtering).
    :param center_scores: the label-aware scores of batch nodes
    :param neigh_scores: the label-aware scores 1-hop neighbors each batch node in one relation
    :param neighs_list: neighbor node id list for each batch node in one relation
    :param sample_list: the number of neighbors kept for each batch node in one relation
    :return samp_neighs: one set of kept neighbor ids per batch node
    :return samp_scores: one list of distances of the kept neighbors per batch node
    """

    samp_neighs, samp_scores = [], []

    for pos in range(len(center_scores)):
        # only the first score column is compared
        neigh_col = neigh_scores[pos][:, 0].view(-1, 1)
        center_col = center_scores[pos][0].repeat(neigh_col.size()[0], 1)
        candidates = neighs_list[pos]
        keep = sample_list[pos]

        # compute the L1-distance of batch nodes and their neighbors
        # Eq. (2) in paper
        distances = torch.abs(center_col - neigh_col).squeeze()
        ranked_distances, ranking = torch.sort(distances, dim=0, descending=False)
        ranking = ranking.tolist()

        # top-p sampling according to distance ranking and thresholds
        # Section 3.3.1 in paper
        if len(neigh_scores[pos]) <= keep + 1:
            # too few neighbors to filter: keep all of them
            kept_ids = candidates
            kept_distances = distances.tolist()
            # a single neighbor squeezes to a 0-dim tensor, whose tolist() is a bare float
            if isinstance(kept_distances, float):
                kept_distances = [kept_distances]
        else:
            kept_ids = [candidates[j] for j in ranking[:keep]]
            kept_distances = ranked_distances.tolist()[:keep]

        samp_neighs.append(set(kept_ids))
        samp_scores.append(kept_distances)

    return samp_neighs, samp_scores

### CARE-GNN inter-relation aggregator using LeakyReLU (Proposed)

In [7]:
def threshold_inter_agg(num_relations, self_feats, neigh_feats, embed_dim, weight, threshold, n, cuda):
    """
    CARE-GNN inter-relation aggregator
    Eq. (9) in the paper, using LeakyReLU instead of ReLU
    :param num_relations: number of relations in the graph
    :param self_feats: batch nodes features or embeddings
    :param neigh_feats: intra-relation aggregated neighbor embeddings for each relation,
                        stacked along dim 0 in blocks of n rows per relation
    :param embed_dim: the dimension of output embedding
    :param weight: parameter used to transform node embeddings before inter-relation aggregation
    :param threshold: the neighbor filtering thresholds used as aggregating weights
    :param n: number of nodes in a batch
    :param cuda: whether use GPU
    :return: inter-relation aggregated node embeddings
    """

    # project the batch nodes and the stacked per-relation neighbor embeddings
    projected_self = self_feats.mm(weight)
    projected_neigh = neigh_feats.mm(weight)

    # accumulator for the threshold-weighted sum over relations
    aggregated = torch.zeros(size=(n, embed_dim))
    if cuda:
        aggregated = aggregated.cuda()

    # relation rel occupies rows [rel * n, (rel + 1) * n) of the stacked matrix
    for rel in range(num_relations):
        first_row = rel * n
        relation_block = projected_neigh[first_row:first_row + n, :]
        aggregated.add_(relation_block * threshold[rel])

    # sum aggregated neighbor embedding and batch node embedding,
    # then apply the proposed LeakyReLU activation (Eq. 9)
    return F.leaky_relu(projected_self + aggregated, negative_slope=0.2)

# Model 

In [8]:
class MultiLayerCARE_Leaky(nn.Module):
    """
    Multi-layer CARE-GNN: one CARE-GNN inter-relation layer followed by
    ``num_layers - 1`` Linear+ReLU transform layers and a linear classifier.
    A residual connection is added around every transform layer except the first.
    """
    def __init__(self, num_classes, features, feat_dim, embed_dim, adj_lists,
                 num_layers=2, lambda_1=2.0, step_size=0.02, cuda=False):
        """
        :param num_classes: number of output classes
        :param features: the input node features or embeddings for all nodes
        :param feat_dim: the input feature dimension
        :param embed_dim: the embedding dimension
        :param adj_lists: a list of adjacency lists, one per relation
        :param num_layers: total depth; ``num_layers - 1`` extra transform layers are created
        :param lambda_1: weight of the label-predictor (similarity) loss
        :param step_size: the RL action step size passed to the aggregator
        :param cuda: whether to use GPU
        """
        super(MultiLayerCARE_Leaky, self).__init__()

        self.num_layers = num_layers
        self.lambda_1 = lambda_1
        self.xent = nn.CrossEntropyLoss()

        # Core CARE-GNN components
        intra1 = IntraAgg_leaky(features, feat_dim, cuda=cuda)
        intra2 = IntraAgg_leaky(features, feat_dim, cuda=cuda)
        intra3 = IntraAgg_leaky(features, feat_dim, cuda=cuda)

        self.inter1 = InterAgg_leaky(features, feat_dim, embed_dim, adj_lists,
                              [intra1, intra2, intra3],
                              inter='GNN', step_size=step_size, cuda=cuda)

        # Disable RL to prevent modulo by zero error
        # (the aggregator's batch_num starts at 0 and RLModule takes a modulo by it)
        self.inter1.RL = False

        ##### NOVELTY PROPOSED #####
        # Novelty: Add additional transformation layers
        self.transforms = nn.ModuleList()
        for i in range(num_layers - 1):
            self.transforms.append(nn.Linear(embed_dim, embed_dim))
        #############################
        # Final classifier
        self.weight = nn.Parameter(torch.FloatTensor(embed_dim, num_classes))
        init.xavier_uniform_(self.weight)

    def forward(self, nodes, labels, train_flag=True):
        """
        :param nodes: a list of batch node ids
        :param labels: batch node labels, forwarded to the aggregator
        :param train_flag: indicates whether in training or testing mode
        :return scores: class scores (logits) of the batch nodes
        :return label_scores: label-aware scores from the aggregator's label predictor
        """
        # First layer (CARE-GNN without RL issues)
        embeddings, label_scores = self.inter1(nodes, labels, train_flag)

        # NOVELTY: apply a sequence of Linear+ReLU layers; every layer after the
        # first adds its input back (residual connection, like in ResNet)
        prev_embeddings = embeddings
        for i, transform in enumerate(self.transforms):
            embeddings = F.relu(transform(embeddings))
            if i > 0:
                embeddings = embeddings + prev_embeddings
            prev_embeddings = embeddings

        # Final prediction
        scores = torch.mm(embeddings, self.weight)
        return scores, label_scores

    def to_prob(self, nodes, labels, train_flag=True):
        """
        :return gnn_prob: softmax class probabilities of the GNN classifier
        :return label_prob: softmax class probabilities of the label predictor
        """
        gnn_scores, label_scores = self.forward(nodes, labels, train_flag)
        gnn_prob = F.softmax(gnn_scores, dim=1)
        label_prob = F.softmax(label_scores, dim=1)
        return gnn_prob, label_prob

    def loss(self, nodes, labels, train_flag=True):
        """
        :return: GNN cross-entropy loss plus ``lambda_1`` times the label-predictor
                 cross-entropy loss
        """
        gnn_scores, label_scores = self.forward(nodes, labels, train_flag)

        # Flatten to 1-D. squeeze() would turn a one-node batch into a 0-dim target,
        # which does not match the (1, num_classes) scores.
        targets = labels.reshape(-1)

        # GNN loss
        gnn_loss = self.xent(gnn_scores, targets)

        # Similarity loss (same as baseline)
        label_loss = self.xent(label_scores, targets)

        # Combined loss
        final_loss = gnn_loss + self.lambda_1 * label_loss
        return final_loss


# Experiment

### Data Preparation

In [9]:
import numpy as np
import torch
import torch.nn as nn
from scipy.io import loadmat
import pickle
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import networkx as nx
import random
from matplotlib.patches import Patch

from torch.nn import init
import torch.nn.functional as F
from torch.autograd import Variable
from operator import itemgetter
import math

# Import your existing modules
from model import OneLayerCARE
from layers import IntraAgg, InterAgg
from utils import normalize, pos_neg_split, undersample, test_care


# Seed NumPy, PyTorch and Python's random module so runs are reproducible
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)


In [12]:
# Display the installed NumPy version (the recorded output below documents the run environment)
np.__version__

'1.24.4'

In [None]:
print("=== Amazon Fraud Detection Dataset ===")

# Read the MATLAB file once; labels become a flat vector and the sparse
# feature matrix is converted to a dense ndarray
data_file = loadmat('data/Amazon.mat')
labels = data_file['label'].flatten()
feat_data = data_file['features'].toarray()


In [None]:


# Load the preprocessed adjacency lists (one pickle file per relation)
def _load_adj_lists(path):
    """Unpickle one relation's adjacency lists from ``path`` (trusted local file)."""
    with open(path, 'rb') as f:
        return pickle.load(f)


# NOTE(review): the relation names below follow this notebook's labels; confirm them
# against the dataset description (upu / usu / uvu).
relation1 = _load_adj_lists('data/amz_upu_adjlists.pickle')  # User-Product-User
relation2 = _load_adj_lists('data/amz_usu_adjlists.pickle')  # User-Service-User
relation3 = _load_adj_lists('data/amz_uvu_adjlists.pickle')  # User-View-User

# Homogeneous graph as stored in the .mat file under the 'homo' key
homo = data_file['homo']

print(f"Dataset Statistics:")
print(f"Total nodes: {len(labels)}")
print(f"Fraudulent users: {np.sum(labels == 1)} ({np.mean(labels==1)*100:.2f}%)")
print(f"Feature dimensions: {feat_data.shape}")

print(f"\nRelation Statistics:")
print(f"User-Product-User nodes: {len(relation1)}")
print(f"User-Service-User nodes: {len(relation2)}")
print(f"User-View-User nodes: {len(relation3)}")

# Feature preprocessing: normalize the raw features and expose them as a
# frozen (non-trainable) embedding table indexed by node id
print(f"\n=== Feature Preprocessing ===")
feat_data_normalized = normalize(feat_data)
num_nodes, num_feats = feat_data.shape
features = nn.Embedding(num_nodes, num_feats)
features.weight = nn.Parameter(torch.FloatTensor(feat_data_normalized), requires_grad=False)

# One adjacency list per relation, in the order expected by the aggregators
adj_lists = [relation1, relation2, relation3]
print("✅ Features normalized and prepared")
print("✅ Adjacency lists organized")

# # Visualize class distribution
# plt.figure(figsize=(8, 4))
# plt.bar(['Legitimate', 'Fraudulent'], [np.sum(labels == 0), np.sum(labels == 1)])
# plt.title('Class Distribution in Amazon Dataset')
# plt.ylabel('Number of Users')
# plt.show()

## Create Model

### CARE-GNN Base Model (paper)

In [None]:


# Model hyperparameters
embed_dim = 64      # output embedding dimension
step_size = 0.02    # RL action step size
lambda_1 = 2.0      # weight of the similarity (label predictor) loss
batch_size = 256

print(f"Model Configuration:")
print(f"- Embedding dimension: {embed_dim}")
print(f"- RL step size: {step_size}")
print(f"- Similarity loss weight: {lambda_1}")
print(f"- Batch size: {batch_size}")

# Build CARE-GNN model (features and adj_lists come from the data preparation cell):
# one intra-relation aggregator per relation, combined by one inter-relation aggregator
input_dim = feat_data.shape[1]
intra1, intra2, intra3 = (IntraAgg(features, input_dim, cuda=False) for _ in range(3))

inter = InterAgg(
    features, input_dim, embed_dim, adj_lists,
    [intra1, intra2, intra3],
    inter='GNN', step_size=step_size, cuda=False,
)

model = OneLayerCARE(num_classes=2, inter1=inter, lambda_1=lambda_1)

total_params = sum(p.numel() for p in model.parameters())
print(f"✅ CARE-GNN base model built successfully!")
print(f"Total parameters: {total_params:,}")

### Proposed Multi-Layer CARE-GNN with LeakyReLU and residual connections between embedding layers

In [None]:

# 2-layer variant: the CARE-GNN layer plus one extra Linear transform layer
model2_config = dict(
    num_classes=2,
    features=features,
    feat_dim=feat_data.shape[1],
    embed_dim=embed_dim,
    adj_lists=adj_lists,
    num_layers=2,
    lambda_1=lambda_1,
    step_size=step_size,
    cuda=False,
)
model2_leaky = MultiLayerCARE_Leaky(**model2_config)

In [None]:
# Sanity check: print the class of the 2-layer model instance
print(type(model2_leaky))

In [None]:
# 3-layer variant: the CARE-GNN layer plus two extra Linear transform layers
model3_config = dict(
    num_classes=2,
    features=features,
    feat_dim=feat_data.shape[1],
    embed_dim=embed_dim,
    adj_lists=adj_lists,
    num_layers=3,  # 3 layers
    lambda_1=lambda_1,
    step_size=step_size,
    cuda=False,
)
model3_leaky = MultiLayerCARE_Leaky(**model3_config)

In [None]:
# Sanity check: print the class of the 3-layer model instance
print(type(model3_leaky))

# Utility Functions for Training and Evaluation

# Train Model

In [None]:


print("=== Training Configuration ===")

# Create train/test split for Amazon dataset
# Amazon: the first 3305 nodes are unlabeled, so only the remaining nodes are split
NUM_UNLABELED = 3305
labeled_indices = list(range(NUM_UNLABELED, len(labels)))
labeled_labels = labels[NUM_UNLABELED:]

# Stratified split: 40% of the labeled nodes for training, 60% for testing
split_parts = train_test_split(
    labeled_indices,
    labeled_labels,
    stratify=labeled_labels,
    test_size=0.6,
    random_state=42,
)
idx_train, idx_test, y_train, y_test = split_parts

print(f"Training samples: {len(idx_train)}")
print(f"Test samples: {len(idx_test)}")

# Split positive and negative samples for undersampling
train_pos, train_neg = pos_neg_split(idx_train, y_train)
print(f"Training - Positive: {len(train_pos)}, Negative: {len(train_neg)}")

In [None]:
def train_care_model(model, model_name, train_pos, train_neg, labels, 
                      learning_rate=0.01, weight_decay=1e-3, momentum=0.9, num_epochs=15, 
                      batch_size=256, scale=1, verbose=True, plot_loss=True):
    """
    General training function for CARE-GNN models (including multi-layer variants)
    
    Args:
        model: The CARE model to train (model2_leaky, model3_leaky, etc.); must
            provide ``parameters()``, ``train()`` and ``loss(nodes, labels, train_flag)``
        model_name: String name for the model (for logging and plotting)
        train_pos: List of positive training sample indices
        train_neg: List of negative training sample indices
        labels: Array of all node labels
        learning_rate: Learning rate for SGD optimizer
        weight_decay: Weight decay for SGD optimizer
        momentum: Momentum for SGD optimizer
        num_epochs: Number of training epochs
        batch_size: Batch size for training (must be positive)
        scale: Scale factor for undersampling (1 means balanced)
        verbose: Whether to print training progress
        plot_loss: Whether to plot training loss curve
    
    Returns:
        train_losses: List of average training losses, one per epoch
        trained_model: The trained model (the same object as ``model``)

    Raises:
        ValueError: If ``batch_size`` is not positive or undersampling yields no
            training samples (previously a ZeroDivisionError when averaging the loss).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    if verbose:
        print(f"=== Training {model_name} ===")
    
    # Setup SGD optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, 
                               weight_decay=weight_decay, momentum=momentum)
    
    # Track losses
    train_losses = []
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        
        # Undersample negative samples for balanced training (re-drawn every epoch)
        sampled_idx_train = undersample(train_pos, train_neg, scale=scale)
        random.shuffle(sampled_idx_train)

        num_samples = len(sampled_idx_train)
        if num_samples == 0:
            raise ValueError("undersampling returned no training samples; "
                             "check train_pos / train_neg")
        
        # Calculate number of batches (ceiling division)
        num_batches = (num_samples + batch_size - 1) // batch_size
        
        # Set batch number for RL module (if model has inter1 with RL)
        if hasattr(model, 'inter1') and hasattr(model.inter1, 'batch_num'):
            model.inter1.batch_num = num_batches
        
        epoch_loss = 0.0
        
        # Batch training
        for start_idx in range(0, num_samples, batch_size):
            batch_nodes = sampled_idx_train[start_idx:start_idx + batch_size]
            batch_labels = torch.LongTensor(labels[batch_nodes])
            
            optimizer.zero_grad()
            
            # Forward pass
            loss = model.loss(batch_nodes, batch_labels, train_flag=True)
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
        
        # Calculate average loss for epoch
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)
        
        # Print progress
        if verbose and epoch % 3 == 0:
            print(f"Epoch {epoch:2d}: Loss = {avg_loss:.4f}")
    
    if verbose:
        print(f"✅ {model_name} training completed!")
    
    # Plot training loss
    if plot_loss:
        plt.figure(figsize=(8, 4))
        plt.plot(train_losses, label=model_name)
        plt.title(f'{model_name} Training Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.show()
    
    return train_losses, model

### Training the Proposed Multi-Layer CARE-GNN Models (LeakyReLU)

In [None]:
# Both models are trained on the same data with identical SGD settings
sgd_train_kwargs = dict(
    train_pos=train_pos,
    train_neg=train_neg,
    labels=labels,
    learning_rate=0.01,
    weight_decay=1e-3,
    momentum=0.9,        # SGD momentum
    num_epochs=15,
    batch_size=256,
)

# Train 2-layer model with SGD
losses_2layer, trained_model2 = train_care_model(
    model=model2_leaky,
    model_name="Multi-Layer CARE (2 layers + LeakyReLU)",
    **sgd_train_kwargs,
)

# Train 3-layer model with SGD
losses_3layer, trained_model3 = train_care_model(
    model=model3_leaky,
    model_name="Multi-Layer CARE (3 layers + LeakyReLU)",
    **sgd_train_kwargs,
)

In [None]:
# Two panels side by side: per-epoch loss curves (left) and final-epoch loss (right)
fig, (curve_ax, final_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Training loss comparison
curve_ax.plot(losses_2layer, label='2-Layer CARE + LeakyReLU', linewidth=2, color='blue')
curve_ax.plot(losses_3layer, label='3-Layer CARE + LeakyReLU', linewidth=2, color='red')
curve_ax.set_title('Training Loss Comparison (SGD)')
curve_ax.set_xlabel('Epoch')
curve_ax.set_ylabel('Loss')
curve_ax.legend()
curve_ax.grid(True, alpha=0.3)

# Final loss comparison bar chart
models = ['2-Layer', '3-Layer']
final_losses = [losses_2layer[-1], losses_3layer[-1]]
colors = ['blue', 'red']
final_ax.bar(models, final_losses, color=colors, alpha=0.7)
final_ax.set_title('Final Training Loss')
final_ax.set_ylabel('Loss')
final_ax.grid(axis='y', alpha=0.3)

# Add value labels on bars (slightly above each bar top)
for i, v in enumerate(final_losses):
    final_ax.text(i, v + 0.001, f'{v:.4f}', ha='center', va='bottom')

fig.tight_layout()
plt.show()

In [None]:
# Evaluate 2-layer model on the held-out test nodes
print("\n=== 2-Layer CARE Model Evaluation ===")
metrics_2l = test_care(idx_test, y_test, trained_model2, batch_size)
gnn_auc_2l, label_auc_2l, gnn_recall_2l, label_recall_2l = metrics_2l

print(f"2-Layer Model Results:")
print(f"  AUC:    {gnn_auc_2l:.4f}")
print(f"  Recall: {gnn_recall_2l:.4f}")

# Evaluate 3-layer model on the same test nodes
print("\n=== 3-Layer CARE Model Evaluation ===")
metrics_3l = test_care(idx_test, y_test, trained_model3, batch_size)
gnn_auc_3l, label_auc_3l, gnn_recall_3l, label_recall_3l = metrics_3l

print(f"3-Layer Model Results:")
print(f"  AUC:    {gnn_auc_3l:.4f}")
print(f"  Recall: {gnn_recall_3l:.4f}")

# Manual comparison plot: GNN AUC (left) and GNN recall (right), both on a 0-1 axis
fig, (auc_ax, recall_ax) = plt.subplots(1, 2, figsize=(12, 5))

# AUC comparison
models = ['2-Layer', '3-Layer']
aucs = [gnn_auc_2l, gnn_auc_3l]
auc_ax.bar(models, aucs, color=['skyblue', 'lightcoral'])
auc_ax.set_title('AUC Comparison')
auc_ax.set_ylabel('AUC Score')
auc_ax.set_ylim([0, 1])

# Recall comparison
recalls = [gnn_recall_2l, gnn_recall_3l]
recall_ax.bar(models, recalls, color=['orange', 'lightgreen'])
recall_ax.set_title('Recall Comparison')
recall_ax.set_ylabel('Recall Score')
recall_ax.set_ylim([0, 1])

fig.tight_layout()
plt.show()