## Dataset

In [None]:
from Libs.Datasets.load_dataset import load_UPFD
from torch_geometric import seed_everything

# Fix the global random seed (torch_geometric's seed_everything) before any
# data loading or model construction so runs are reproducible.
seed_everything(3407)

K = 3  # number of partitions per graph; reused below as num_partitions
name = 'politifact'  # dataset name passed to load_UPFD (presumably a UPFD subset)
feature = 'spacy'  # node-feature variant passed to load_UPFD — TODO confirm options
batch_size = 32  # only affects the returned loaders; run_exp below iterates the datasets
# load_UPFD returns the three splits both as datasets and as DataLoaders.
train_dataset, val_dataset, test_dataset, train_loader, val_loader, test_loader = load_UPFD(
    name, feature, batch_size=batch_size, K=K)

In [None]:
# Inspect a single training graph and its precomputed partitions.
sample_id = 8
sample_g = train_dataset[sample_id]
# Bare trailing expression: Jupyter displays this tuple as the cell output.
sample_g, sample_g.partitions

## Model

In [None]:
import torch
import torch.nn.functional as F
from torch.nn import Linear, Dropout, MultiheadAttention
from torch_geometric.nn import global_mean_pool, GATv2Conv, SAGEConv, GCNConv, global_max_pool
import torch.nn as nn
from einops.layers.torch import Rearrange


class GNNLayer(nn.Module):
    """One message-passing step: convolution, optional activation, dropout.

    Args:
        conv: module (or callable) invoked as ``conv(x, edge_index)``.
        act: activation applied to the convolution output. Any falsy value
            (e.g. ``None``) skips the activation. Defaults to ``F.relu``.
        dropout: dropout probability applied as the last step.
    """

    def __init__(self, conv, act=F.relu, dropout=0.0):
        super().__init__()
        self.conv = conv
        self.act = act
        self.dropout = Dropout(dropout)

    def forward(self, x, edge_index):
        h = self.conv(x, edge_index)
        h = self.act(h) if self.act else h
        return self.dropout(h)


class GraphEncoder(torch.nn.Module):
    """Single SAGEConv layer over node features concatenated with a
    normalised, linearly projected positional encoding.

    Args:
        in_channels: width of the raw node features ``x``.
        hidden_channels: output width of the convolution.
        pe_dim: width of the positional encoding ``pe``.
        num_heads: accepted for signature compatibility; not used.
        dropout: dropout probability applied after the convolution.
    """

    def __init__(self, in_channels, hidden_channels, pe_dim, num_heads=8, dropout=0.0):
        super().__init__()
        # Creation order is kept so seeded parameter initialisation and the
        # state_dict layout stay the same.
        sage = SAGEConv(in_channels + pe_dim, hidden_channels)
        self.conv = GNNLayer(sage, dropout=dropout)
        self.pe_norm = nn.BatchNorm1d(pe_dim)
        self.pe_lin = Linear(pe_dim, pe_dim)
        # The same GNNLayer is also reachable through ``layers``; forward
        # iterates this list.
        self.layers = nn.ModuleList([self.conv])

    def forward(self, x, edge_index, pe):
        encoded_pe = self.pe_lin(self.pe_norm(pe))
        h = torch.cat((x, encoded_pe), dim=1)
        for gnn_layer in self.layers:
            h = gnn_layer(h, edge_index)
        return h


class GraphEncoderwithoutPE(torch.nn.Module):
    """Single SAGEConv layer over raw node features (no positional encoding).

    Args:
        in_channels: width of the node features ``x``.
        hidden_channels: output width of the convolution.
        num_heads: accepted for signature compatibility; not used.
        dropout: dropout probability applied after the convolution.
    """

    def __init__(self, in_channels, hidden_channels, num_heads=8, dropout=0.0):
        super().__init__()
        sage = SAGEConv(in_channels, hidden_channels)
        self.conv = GNNLayer(sage, dropout=dropout)
        # Same module registered under ``layers`` as well; forward walks it.
        self.layers = nn.ModuleList([self.conv])

    def forward(self, x, edge_index):
        h = x
        for gnn_layer in self.layers:
            h = gnn_layer(h, edge_index)
        return h


class Identity(nn.Module):
    """No-op module whose ``forward`` returns its argument object unchanged.

    Any constructor arguments are accepted and ignored, so it can stand in
    for layers such as ``BatchNorm1d``. ``reset_parameters`` exists for API
    parity and does nothing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, input):
        return input

    def reset_parameters(self):
        return None


class FeedForward(nn.Module):
    """Width-preserving projection ``dim -> dim`` followed by GELU.

    ``hidden_dim`` and ``dropout`` are accepted for signature compatibility
    but are currently unused: the block has no expansion layer and no dropout.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        stages = [nn.Linear(dim, dim), nn.GELU()]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class MixerBlock(nn.Module):
    """MLP-Mixer block over a 2-D ``(partitions, features)`` tensor.

    The block has two residual sub-layers:

    * token mixing: LayerNorm over features, transpose to
      ``(features, partitions)``, a FeedForward across partitions, then
      transpose back;
    * channel mixing: a FeedForward across features (no normalisation).
    """

    def __init__(self, num_features, num_partitions, token_dim, channel_dim, dropout=0.1):
        super().__init__()
        token_layers = [
            nn.LayerNorm(num_features),
            Rearrange('p d -> d p'),
            FeedForward(num_partitions, token_dim, dropout),
            Rearrange('d p -> p d'),
        ]
        self.token_mix = nn.Sequential(*token_layers)
        self.channel_mix = nn.Sequential(
            FeedForward(num_features, channel_dim, dropout))

    def forward(self, x):
        tokens_mixed = x + self.token_mix(x)
        return tokens_mixed + self.channel_mix(tokens_mixed)


class MLPMixer(nn.Module):
    """Stack of ``num_layers`` MixerBlocks, optionally followed by LayerNorm.

    Each block gets token width ``4 * hidden_channels`` and channel width
    ``hidden_channels // 2``. With ``num_layers=0`` and no final norm, the
    input object is returned unchanged.
    """

    def __init__(self,
                 hidden_channels,
                 num_partitions,
                 num_layers=3,
                 with_final_norm=True,
                 dropout=0.1):
        super().__init__()
        self.num_partitions = num_partitions
        self.with_final_norm = with_final_norm
        blocks = []
        for _ in range(num_layers):
            blocks.append(MixerBlock(hidden_channels, num_partitions,
                                     hidden_channels * 4, hidden_channels // 2,
                                     dropout=dropout))
        self.mixer_blocks = nn.ModuleList(blocks)
        if with_final_norm:
            self.layer_norm = nn.LayerNorm(hidden_channels)

    def forward(self, x):
        out = x
        for block in self.mixer_blocks:
            out = block(out)
        return self.layer_norm(out) if self.with_final_norm else out


class AttentionMixerBlock(nn.Module):
    """Mixer block with a self-attention sub-layer ahead of token/channel mixing.

    Works on a 2-D ``(partitions, features)`` tensor, which is the ``'p d'``
    layout the Rearrange patterns expect. There are three residual
    sub-layers:

    1. multi-head self-attention on the ``norm1``-normalised input;
    2. token mixing across partitions (no normalisation);
    3. channel mixing on the ``norm2``-normalised features.
    """

    def __init__(self, num_features, num_partitions, token_dim, channel_dim, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_partitions = num_partitions
        self.attention = MultiheadAttention(
            embed_dim=num_features, num_heads=num_heads, dropout=dropout, batch_first=True)
        self.token_mix = nn.Sequential(
            Rearrange('p d -> d p'),
            FeedForward(num_partitions, token_dim, dropout),
            Rearrange('d p -> p d'),
        )
        self.channel_mix = FeedForward(num_features, channel_dim, dropout)
        self.norm1 = nn.LayerNorm(num_features)
        self.norm2 = nn.LayerNorm(num_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Normalise once and reuse the result as query, key and value, rather
        # than running the same LayerNorm three times.
        h = self.norm1(x)
        attn_out, _ = self.attention(h, h, h)  # discard attention weights
        x = x + self.dropout(attn_out)
        # Token mixing across partitions.
        x = x + self.token_mix(x)
        # Channel mixing across features.
        x = x + self.channel_mix(self.norm2(x))
        return x


class MHAMixer(nn.Module):
    """Stack of ``num_layers`` AttentionMixerBlocks, optionally followed by LayerNorm.

    Each block gets token width ``4 * hidden_channels`` and channel width
    ``hidden_channels // 2``.

    Args:
        hidden_channels: feature width of the input tokens.
        num_partitions: number of tokens (partitions) per input.
        num_heads: attention heads per block.
        num_layers: number of stacked blocks.
        with_final_norm: apply a final ``nn.LayerNorm(hidden_channels)``.
        dropout: dropout probability passed to every block.
    """

    def __init__(self, hidden_channels, num_partitions, num_heads=8, num_layers=3, with_final_norm=True, dropout=0.1):
        super().__init__()
        self.num_partitions = num_partitions
        self.with_final_norm = with_final_norm
        self.mixer_blocks = nn.ModuleList([
            AttentionMixerBlock(hidden_channels, num_partitions,
                                hidden_channels * 4, hidden_channels // 2, num_heads, dropout)
            for _ in range(num_layers)])
        if with_final_norm:
            self.layer_norm = nn.LayerNorm(hidden_channels)

    def forward(self, x):
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        if self.with_final_norm:
            x = self.layer_norm(x)
        return x


class Classifier(nn.Module):
    """MLP head of ``num_layers`` linear layers mapping ``in_channels`` to ``out_channels``.

    Hidden layers keep width ``in_channels``. Every layer is followed by a
    norm (``BatchNorm1d`` when ``with_norm`` is set, otherwise a no-op
    ``Identity``). A ReLU follows every hidden layer. ``with_final_activation``
    controls only whether a ReLU also follows the last layer.

    Args:
        in_channels: input (and hidden) width.
        out_channels: output width of the last layer.
        num_layers: number of linear layers.
        with_final_activation: apply ReLU after the final layer too.
        with_norm: use BatchNorm1d after each layer; linear biases are then
            disabled because the norm supplies the shift.
    """

    def __init__(self, in_channels, out_channels, num_layers=2, with_final_activation=True, with_norm=False):
        super().__init__()
        hidden_channels = in_channels
        layers = []
        norms = []
        for i in range(num_layers):
            layer_in = in_channels if i == 0 else hidden_channels
            layer_out = out_channels if i == num_layers - 1 else hidden_channels
            layers.append(nn.Linear(layer_in, layer_out, bias=not with_norm))
            norms.append(nn.BatchNorm1d(layer_out) if with_norm else Identity())
        self.layers = nn.ModuleList(layers)
        self.norms = nn.ModuleList(norms)
        self.with_final_activation = with_final_activation

    def forward(self, x):
        last = len(self.layers) - 1
        for i, (layer, norm) in enumerate(zip(self.layers, self.norms)):
            x = norm(layer(x))
            # Hidden layers always get a non-linearity; otherwise the stack
            # collapses to one linear map. Only the last layer's ReLU is optional.
            if i < last or self.with_final_activation:
                x = F.relu(x)
        return x


class GraphMixerModel(torch.nn.Module):
    """Partition-level graph classifier: encode, pool, mix, average, classify.

    The model runs these steps:

    1. Each partition in ``batch.partitions`` goes through a shared
       GraphEncoder (node features plus ``random_walk_pe``).
    2. Each partition is max-pooled into one vector.
    3. The stacked partition vectors are mixed by a single-block MLPMixer.
    4. The mixed tokens are averaged over partitions and projected to
       ``out_channels`` log-probabilities.

    NOTE(review): ``squeeze(1)`` and the mixer's 2-D ``'p d'`` layout
    presumably assume each partition pools to one vector, i.e. ``batch`` is
    a single graph rather than a mini-batch. Confirm before feeding
    DataLoader batches.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_partitions, pe_dim):
        super().__init__()
        # Creation order is kept so seeded initialisation and state_dict
        # keys are unchanged.
        self.encoder = GraphEncoder(in_channels, hidden_channels, pe_dim)
        self.mixer = MLPMixer(hidden_channels, num_partitions,
                              num_layers=1, with_final_norm=False)
        # Alternative mixer: MHAMixer(hidden_channels, num_partitions, num_layers=1)
        self.pool = global_max_pool
        # Alternative head: Classifier(hidden_channels, out_channels)
        self.cls = Linear(hidden_channels, out_channels)

    def forward(self, batch):
        # One max-pooled embedding per partition.
        pooled = [
            self.pool(self.encoder(p.x, p.edge_index, p.random_walk_pe), p.batch)
            for p in batch.partitions
        ]
        tokens = torch.stack(pooled, dim=0).squeeze(1)
        # Mix across partitions, then mean over the partition axis.
        graph_repr = self.mixer(tokens).mean(dim=0)
        logits = self.cls(graph_repr)
        return F.log_softmax(logits, dim=-1)


class NaiveGNN(torch.nn.Module):
    """Baseline without partitioning: one SAGEConv, per-graph max pooling,
    then a linear head. Returns log-probabilities (log_softmax over dim 1).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv = SAGEConv(in_channels, hidden_channels)
        self.pool = global_max_pool
        self.cls = Linear(hidden_channels, out_channels)

    def forward(self, data):
        node_repr = self.conv(data.x, data.edge_index)
        graph_repr = self.pool(node_repr, data.batch)
        logits = self.cls(graph_repr)
        return F.log_softmax(logits, dim=1)

In [None]:
# Training configuration. Keep the statement order: building the model draws
# its initial weights from the RNG seeded in the dataset cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_partitions = K  # one mixer token per graph partition
in_channels = train_dataset[0].num_node_features
hidden_channels = 64
out_channels = 2  # number of output classes
# Positional-encoding width, read from the first partition of the first graph.
pe_dim = train_dataset[0].partitions[0].random_walk_pe.shape[1]
model = GraphMixerModel(in_channels=in_channels, out_channels=out_channels, hidden_channels=hidden_channels,
                        num_partitions=num_partitions, pe_dim=pe_dim).to(device)
# model = NaiveGNN(
#     in_channels=in_channels, out_channels=out_channels, hidden_channels=hidden_channels).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, weight_decay=1e-3)
# train() calls scheduler.step() once per epoch, so T_0 is measured in
# epochs here; with the 10-epoch run below no restart is reached.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=100, T_mult=2, eta_min=1e-6)
criterion = F.nll_loss  # both models return log_softmax outputs

In [None]:
# Report the selected device and the model architecture.
print('Use device:', device)
print(model)

## Training

In [None]:
from sklearn.metrics import accuracy_score, f1_score
import torch


def train(dataset, model, optimizer, scheduler, criterion, device):
    """Run one training epoch and return the mean per-step loss.

    Each item of ``dataset`` is moved to ``device``, passed through
    ``model``, and optimised with one ``optimizer`` step. The model output is
    reshaped to ``(-1, num_classes)``: a single-graph model's 1-D
    ``(num_classes,)`` output becomes one row, and a batched model's
    ``(B, num_classes)`` output keeps its rows. ``scheduler.step()`` is called
    once, after the epoch.

    Args:
        dataset: iterable of data objects with ``.to(device)`` and ``.y``.
        model: module mapping a data object to class log-probabilities.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch.
        criterion: loss called as ``criterion(outputs, data.y)``.
        device: target device for each data object.

    Returns:
        Mean loss over the processed items, or ``nan`` if ``dataset`` yielded
        nothing.
    """
    model.train()
    losses = []

    for data in dataset:
        data = data.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        # Add the batch dimension without hard-coding the class count.
        outputs = outputs.reshape(-1, outputs.size(-1))
        loss = criterion(outputs, data.y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    scheduler.step()

    if not losses:
        return float("nan")
    return sum(losses) / len(losses)


@torch.no_grad()  # evaluation only: no autograd graph is recorded
def test(dataset, model, device):
    """Return the binary F1 score of ``model``'s predictions over ``dataset``.

    Predictions are the argmax over the last output dimension. Predictions and
    labels are both flattened before accumulation, so a single-graph model
    (scalar prediction) and a batched model (one prediction per graph) keep
    ``y_pred`` aligned element-wise with ``y_true``.

    Args:
        dataset: iterable of data objects with ``.to(device)`` and ``.y``.
        model: module mapping a data object to class scores.
        device: target device for each data object.

    Returns:
        ``sklearn.metrics.f1_score(y_true, y_pred)`` (binary by default).
    """
    model.eval()
    y_pred = []
    y_true = []

    for data in dataset:
        data = data.to(device)
        preds = model(data).argmax(dim=-1)
        y_pred.extend(preds.reshape(-1).cpu().tolist())
        y_true.extend(data.y.reshape(-1).cpu().tolist())

    return f1_score(y_true, y_pred)


def run_exp(train_dataset, val_dataset, test_dataset, model, optimizer, scheduler, criterion, device, num_epochs=25, patience=10,
            ckpt_dir='./Output/Models'):
    """Train for up to ``num_epochs`` epochs with checkpointing and early stopping.

    Each epoch runs ``train`` and then evaluates F1 on the train, validation
    and test splits. Whenever validation F1 improves, the model's state_dict
    is saved to ``<ckpt_dir>/ckpt-<val_f1>.pt`` and the directory is created
    if it does not exist yet. Training stops once ``patience_counter``
    exceeds ``patience``, i.e. after ``patience + 1`` consecutive epochs
    without improvement. That epoch's metrics are printed before stopping.

    Returns:
        The best validation F1 reached (0.0 if it never improved).
    """
    import os

    best_val_f1 = 0.0  # best validation F1 so far
    patience_counter = 0  # consecutive epochs without improvement

    for epoch in range(1, num_epochs + 1):
        # Training phase.
        loss = train(train_dataset, model, optimizer,
                     scheduler, criterion, device)

        # Evaluation phase.
        train_f1 = test(train_dataset, model, device)
        val_f1 = test(val_dataset, model, device)
        test_f1 = test(test_dataset, model, device)

        # Checkpoint on validation improvement.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            os.makedirs(ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(ckpt_dir, f'ckpt-{val_f1:.4f}.pt'))
            print(f"Best model saved with F1: {val_f1:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1

        # Report this epoch before deciding whether to stop.
        print(f'Epoch: {epoch:03d} | Loss: {loss:.4f} | Train F1: {train_f1:.4f} | Val F1: {val_f1:.4f} | Test F1: {test_f1:.4f}')

        # Early stopping.
        if patience_counter > patience:
            print("Early stopping triggered")
            break

    return best_val_f1

In [None]:
# Launch training. The datasets (not the DataLoaders) are passed, so
# train()/test() iterate one graph at a time.
run_exp(train_dataset, val_dataset, test_dataset, model, optimizer,
        scheduler, criterion, device, num_epochs=10, patience=5)