# G-Mixup can improve the performance of graph neural networks on various datasets

In [1]:
import random
import torch
import os.path as osp
from torch_geometric.datasets import TUDataset
from gmixup import prepare_dataset_onehot_y
from utils import stat_graph
import numpy as np
from graphon_estimator import largest_gap
from utils import split_class_graphs, align_graphs
from torch_geometric.loader import DataLoader
from gmixup import prepare_dataset_x
from utils import two_graphons_mixup
from models import GIN, GCN, DiffPoolNet, TopKNet, MinCutPoolNet
from gmixup import mixup_cross_entropy_loss

In [2]:
data_path = './'
dataset_names = ['REDDIT-BINARY', 'REDDIT-MULTI-5K', 'REDDIT-MULTI-12K']
models = ['GCN', 'GIN', 'MinCutPool', 'DiffPool', 'TopKPool']
epochs = 300
batch_size = 128
lr = 0.01
num_hidden = 64
seeds = [1314, 311098, 271296, 180562, 280466, 50832, 280433, 21022, 0, 546464]
no_test_runs = 10
lam_range = [0.1, 0.2]
aug_ratio = 0.2
aug_num = 10
augmentations = ['G-Mixup', 'Vanilla']

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running device: {device}')

Running device: cuda


In [4]:
def train(model, train_loader):
    model.train()
    loss_all = 0
    graph_all = 0
    correct = 0
    total = 0
    for data in train_loader:
        # print( "data.y", data.y )
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data.x, data.edge_index, data.batch)
        y = data.y.view(-1, num_classes)
        #print(y.size())
        #print(output.size())
        loss = mixup_cross_entropy_loss(output, y)
        loss.backward()
        loss_all += loss.item() * data.num_graphs
        graph_all += data.num_graphs
        optimizer.step()
        y = y.max(dim=1)[1]
        pred = output.max(dim=1)[1]
        correct += pred.eq(y).sum().item()
        total += data.num_graphs

    loss = loss_all / graph_all
    acc = correct / total
    return model, loss, acc

In [5]:
def test(model, loader):
    model.eval()
    correct = 0
    total = 0
    loss = 0
    for data in loader:
        data = data.to(device)
        output = model(data.x, data.edge_index, data.batch)
        pred = output.max(dim=1)[1]
        y = data.y.view(-1, num_classes)
        loss += mixup_cross_entropy_loss(output, y).item() * data.num_graphs
        y = y.max(dim=1)[1]
        correct += pred.eq(y).sum().item()
        total += data.num_graphs
    acc = correct / total
    loss = loss / total
    return acc, loss

In [None]:
from torch.optim.lr_scheduler import StepLR

for dataset_name in dataset_names:
    path = osp.join(data_path, dataset_name)
    dataset = TUDataset(path, name=dataset_name)
    dataset = list(dataset)
    for graph in dataset:
        graph.y = graph.y.view(-1)

    dataset = prepare_dataset_onehot_y(dataset)
    train_nums = int(len(dataset) * 0.7)
    train_val_nums = int(len(dataset) * 0.8)

    avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(dataset[:train_nums])
    graphon_size = int(median_num_nodes)
    print(f"Avg num nodes of training graphs: {avg_num_nodes}")
    print(f"Avg num edges of training graphs: {avg_num_edges}")
    print(f"Avg density of training graphs: {avg_density}")
    print(f"Median num edges of training graphs: {median_num_edges}")
    print(f"Median density of training graphs: {median_density}")
    for model_name in models:
        for seed in seeds:
            torch.manual_seed(seed)
            random.seed(seed)
            random.shuffle(dataset)
            for aug in augmentations:
                random.shuffle(dataset)
                if aug == 'G-Mixup':
                    class_graphs = split_class_graphs(dataset[:train_nums])
                    graphons = []
                    for label, graphs in class_graphs:
                        align_graphs_list, normalized_node_degrees, max_num, min_num = align_graphs(
                            graphs, padding=True, N=graphon_size)
                        graphon = largest_gap(align_graphs_list, k=graphon_size)
                        graphons.append((label, graphon))

                    num_sample = int(train_nums * aug_ratio / aug_num)
                    lam_list = np.random.uniform(low=lam_range[0], high=lam_range[1], size=(aug_num,))

                    random.seed(seed)
                    new_graph = []
                    for lam in lam_list:
                        two_graphons = random.sample(graphons, 2)
                        new_graph += two_graphons_mixup(two_graphons, la=lam, num_sample=num_sample)

                    new_dataset = new_graph + dataset
                    new_train_nums = train_nums + len(new_graph)
                    new_train_val_nums = train_val_nums + len(new_graph)
                else:
                    new_dataset = dataset
                    new_train_nums = train_nums
                    new_train_val_nums = train_val_nums

                dataset = prepare_dataset_x(new_dataset)

                num_features = new_dataset[0].x.shape[1]
                num_classes = new_dataset[0].y.shape[0]

                # avg_num_nodes, avg_num_edges, avg_density, median_num_nodes, median_num_edges, median_density = stat_graph(new_dataset[:new_train_nums])
                # print(f"Avg num nodes of new training graphs: {avg_num_nodes}")
                # print(f"Avg num edges of new training graphs: {avg_num_edges}")
                # print(f"Avg density of new training graphs: {avg_density}")
                # print(f"Median num edges of new training graphs: {median_num_edges}")
                # print(f"Median density of new training graphs: {median_density}")
                train_dataset = new_dataset[:new_train_nums]
                random.shuffle(train_dataset)
                val_dataset = new_dataset[new_train_nums:new_train_val_nums]
                test_dataset = new_dataset[new_train_val_nums:]

                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
                val_loader = DataLoader(val_dataset, batch_size=batch_size)
                test_loader = DataLoader(test_dataset, batch_size=batch_size)

                if model_name == "GIN":
                    model = GIN(num_features=num_features, num_classes=num_classes, num_hidden=num_hidden).to(device)
                elif model_name == "GCN":
                    model = GCN(in_channels=num_features, hidden_channels=num_hidden, out_channels=num_classes, num_layers=4).to(device)
                elif model_name == "TopKPool":
                    model = TopKNet(in_channels=num_features, hidden_channels=num_hidden, out_channels=num_classes).to(device)
                elif model_name == "DiffPool":
                    model = DiffPoolNet(in_channels=num_features, hidden_channels=num_hidden, out_channels=num_classes, max_nodes = median_num_nodes).to(device)
                elif model_name == "MinCutPool":
                    model = MinCutPoolNet(in_channels=num_features, hidden_channels=num_hidden, out_channels=num_classes, max_nodes = median_num_nodes).to(device)
                else:
                    model = None

                optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
                scheduler = StepLR(optimizer, step_size=100, gamma=0.5)

                max_val_acc = 0
                model_test_acc = 0
                model_test_loss = 0
                model_val_loss = 0
                best_epoch = 0
                train_losses = []
                val_losses = []
                test_losses = []
                for epoch in range(1, epochs):
                    model, train_loss, train_acc = train(model, train_loader)
                    val_acc, val_loss = test(model, val_loader)
                    test_acc, test_loss = test(model, test_loader)
                    scheduler.step()
                    train_losses.append(train_loss)
                    val_losses.append(val_loss)
                    test_losses.append(test_loss)
                    if val_acc > max_val_acc:
                        max_val_acc = val_acc
                        model_test_loss = test_loss
                        model_test_acc = test_acc
                        model_val_loss = val_loss
                        best_epoch = epoch
                    #if epoch%20==0:
                    #    print(
                    #        'Epoch: {:03d}, Train Loss: {:.6f}, Val Loss: {:.6f}, Test Loss: {:.6f}, Train acc: {: .6f}, Val Acc: {: .6f}, Test Acc: {: .6f}'.format(
                    #            epoch, train_loss, val_loss, test_loss, train_acc, val_acc, test_acc))

                with open('train_log.txt', 'a') as f:
                    f.write(f'Dataset: {dataset_name}, Model: {model_name}, Seed: {seed}, Aug: {aug}, Best epoch: {best_epoch}, Test acc: {model_test_acc}, Test loss: {model_test_loss}, Val acc: {max_val_acc}, Val loss: {model_val_loss}\n')
                if model_name == 'GCN':
                    with open('../results/losses.txt', 'a') as f:
                        f.write(f'{dataset_name}, {seed}, train, {train_losses}\n{dataset_name}, {seed}, val, {val_losses}\n{dataset_name}, {seed}, test, {test_losses}\n')
                print(f'Dataset: {dataset_name}, Model: {model_name}, Seed: {seed}, Aug: {aug}, Best epoch: {best_epoch}, Test acc: {model_test_acc}, Test loss: {model_test_loss}, Val acc: {max_val_acc}, Val loss: {model_val_loss}')

Avg num nodes of training graphs: 573.1435714285715
Avg num edges of training graphs: 662.7285714285714
Avg density of training graphs: 0.0020174781648775995
Median num edges of training graphs: 896.0
Median density of training graphs: 0.005845807052483408
Dataset: REDDIT-BINARY, Model: GCN, Seed: 1314, Aug: G-Mixup, Best epoch: 291, Test acc: 0.855, Test loss: 0.33758402824401856, Val acc: 0.89, Val loss: 0.31167617440223694
Dataset: REDDIT-BINARY, Model: GCN, Seed: 1314, Aug: Vanilla, Best epoch: 277, Test acc: 0.8661764705882353, Test loss: 0.357357855053509, Val acc: 0.915, Val loss: 0.3326137983798981
Dataset: REDDIT-BINARY, Model: GCN, Seed: 311098, Aug: G-Mixup, Best epoch: 122, Test acc: 0.8985294117647059, Test loss: 0.34148074633934916, Val acc: 0.955, Val loss: 0.2615780061483383
Dataset: REDDIT-BINARY, Model: GCN, Seed: 311098, Aug: Vanilla, Best epoch: 275, Test acc: 0.834375, Test loss: 0.4429880142211914, Val acc: 0.865, Val loss: 0.38562137603759766
Dataset: REDDIT-BINA