In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from tqdm.auto import trange

In [4]:
from gnnboundary import *

# Reddit

In [16]:
# Reddit experiment setup: a seeded dataset, its train/validation split, and
# a GCN classifier sized from the dataset's node-class and graph-class lists.
# NOTE(review): the exact meaning of k=10 is defined by train_test_split in
# gnnboundary — confirm there.
reddit = RedditDataset(seed=69)
reddit_train, reddit_val = reddit.train_test_split(k=10)
reddit_model = GCNClassifier(
    node_features=len(reddit.NODE_CLS),
    num_classes=len(reddit.GRAPH_CLS),
    hidden_channels=64,
    num_layers=5,
)

In [None]:
# Train the Reddit classifier for 100 epochs.
#   * Every time the validation accuracy beats the best value seen so far, the
#     model weights are checkpointed together with the epoch and the accuracy.
#   * After `patience` consecutive epochs without a new best, the learning
#     rate is halved and the stall counter starts over.
patience, counter = 10, 0
lr = 0.004
best_val_acc = float('-inf')  # guarantees the first epoch is checkpointed

for epoch in range(100):
    train_loss = reddit_train.model_fit(reddit_model, lr=lr)
    train_metrics = reddit_train.model_evaluate(reddit_model)
    val_metrics = reddit_val.model_evaluate(reddit_model)

    # The value logged as "Test Acc" is the accuracy on the held-out split.
    print(
        f"Epoch: {epoch:03d}, "
        f"Train Loss: {train_loss:.4f}, "
        f"Test Acc: {val_metrics['acc']:.4f}, "
    )

    if not val_metrics['acc'] > best_val_acc:
        # No new best this epoch.
        counter += 1
    else:
        best_val_acc = val_metrics['acc']
        counter = 0
        torch.save(
            dict(
                epoch=epoch,
                model=reddit_model.state_dict(),
                val_acc=best_val_acc,
            ),
            'ckpts/ours/reddit.pt',
        )

    if counter >= patience:
        # Stalled for `patience` epochs: decay the learning rate.
        lr /= 2
        counter = 0

In [None]:
# Read the Reddit checkpoint back and display the epoch it was written at
# together with the validation accuracy stored alongside the weights.
a = torch.load('ckpts/ours/reddit.pt')
tuple(a[field] for field in ('epoch', 'val_acc'))

In [None]:
# MultiReddit experiment setup: a seeded dataset, its train/validation split,
# and a GCN classifier sized from the dataset's node-class and graph-class
# lists.
# NOTE(review): the exact meaning of k=10 is defined by train_test_split in
# gnnboundary — confirm there.
reddit2 = MultiRedditDataset(seed=69)
reddit2_train, reddit2_val = reddit2.train_test_split(k=10)
reddit2_model = GCNClassifier(
    node_features=len(reddit2.NODE_CLS),
    num_classes=len(reddit2.GRAPH_CLS),
    hidden_channels=64,
    num_layers=5,
)

In [None]:
# Train the MultiReddit classifier (reddit2_model) for 100 epochs.
#   * Every time the validation accuracy beats the best value seen so far, the
#     model weights are checkpointed together with the epoch and the accuracy.
#   * After `patience` consecutive epochs without a new best, the learning
#     rate is halved and the stall counter starts over.
#
# This cell previously trained and evaluated reddit_model on the single-Reddit
# split and wrote to 'ckpts/ours/reddit.pt', so reddit2_* were never used and
# the Reddit checkpoint was overwritten. It now targets the reddit2_* objects
# and its own checkpoint file.
patience = 10
best_val_acc = float('-inf')  # guarantees the first epoch is checkpointed
counter = 0
lr = 0.004

for epoch in range(100):
    train_loss = reddit2_train.model_fit(reddit2_model, lr=lr)
    train_metrics = reddit2_train.model_evaluate(reddit2_model)
    val_metrics = reddit2_val.model_evaluate(reddit2_model)

    # The value logged as "Test Acc" is the accuracy on the held-out split.
    print(f"Epoch: {epoch:03d}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Test Acc: {val_metrics['acc']:.4f}, "
          f"Lr: {lr:.4f}, "
         )

    if val_metrics['acc'] > best_val_acc:
        best_val_acc = val_metrics['acc']
        counter = 0
        # Separate file so the single-Reddit checkpoint is left intact.
        torch.save({
            'epoch': epoch,
            'model': reddit2_model.state_dict(),
            'val_acc': best_val_acc
        }, 'ckpts/ours/reddit2.pt')
    else:
        counter += 1

    if counter >= patience:
        # Stalled for `patience` epochs: decay the learning rate.
        lr = lr / 2
        counter = 0

# Motif

In [None]:
# Motif experiment setup: a seeded dataset, its train/validation split, and a
# GCN classifier sized from the dataset's node-class and graph-class lists.
# The split and the model are defined here (they used to be commented out)
# because the training and checkpoint-loading cells below reference
# motif_train, motif_val and motif_model and would otherwise raise NameError
# on a fresh kernel.
# NOTE(review): hidden_channels=6 / num_layers=3 are taken from the previously
# commented-out code — confirm they match the weights in 'ckpts/motif.pt'.
motif = MotifDataset(seed=12345)
motif_train, motif_val = motif.train_test_split(k=10)
motif_model = GCNClassifier(node_features=len(motif.NODE_CLS),
                            num_classes=len(motif.GRAPH_CLS),
                            hidden_channels=6,
                            num_layers=3)

In [None]:
# Fit the Motif classifier for 128 epochs at a fixed learning rate. After each
# epoch the model is evaluated on both splits and one summary line is printed:
# loss, accuracy and F1 (the 'acc' and 'f1' entries of model_evaluate's result).
for epoch in trange(128):
    train_loss = motif_train.model_fit(motif_model, lr=0.001)
    train_metrics = motif_train.model_evaluate(motif_model)
    val_metrics = motif_val.model_evaluate(motif_model)
    print(", ".join([
        f"Epoch: {epoch:03d}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_metrics['acc']:.4f}",
        f"Test Acc: {val_metrics['acc']:.4f}",
        f"Train F1: {train_metrics['f1']}",
        f"Test F1: {val_metrics['f1']}",
    ]))

In [None]:
# torch.save(motif_model.state_dict(), 'ckpts/motif.pt')

In [None]:
# Restore the Motif weights saved at 'ckpts/motif.pt' into motif_model.
# Both the checkpoint file and motif_model must already exist.
motif_model.load_state_dict(
    torch.load('ckpts/motif.pt'),
)

# ENZYMES

In [4]:
# ENZYMES experiment setup: a seeded dataset, its train/validation split, and
# a GCN classifier sized from the dataset's node-class and graph-class lists.
# NOTE(review): the exact meaning of k=10 is defined by train_test_split in
# gnnboundary — confirm there.
enzymes = ENZYMESDataset(seed=12345)
enzymes_train, enzymes_val = enzymes.train_test_split(k=10)
enzymes_model = GCNClassifier(
    node_features=len(enzymes.NODE_CLS),
    num_classes=len(enzymes.GRAPH_CLS),
    hidden_channels=32,
    num_layers=3,
)

In [None]:
# Restore the ENZYMES weights saved at 'ckpts/enzymes.pt' into enzymes_model.
# The checkpoint file must already exist on disk.
enzymes_model.load_state_dict(
    torch.load('ckpts/enzymes.pt'),
)

In [None]:
# Fit the ENZYMES classifier for 4096 epochs at a fixed learning rate. After
# each epoch the model is evaluated on both splits and one summary line is
# printed: loss, accuracy and F1 ('acc' / 'f1' entries of model_evaluate).
for epoch in range(4096):
    train_loss = enzymes_train.model_fit(enzymes_model, lr=0.0001)
    train_metrics = enzymes_train.model_evaluate(enzymes_model)
    val_metrics = enzymes_val.model_evaluate(enzymes_model)
    print(
        f"Epoch: {epoch:03d}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_metrics['acc']:.4f}",
        f"Test Acc: {val_metrics['acc']:.4f}",
        f"Train F1: {train_metrics['f1']}",
        f"Test F1: {val_metrics['f1']}",
        sep=", ",
    )

In [8]:
# Persist the current ENZYMES weights; this overwrites any existing file at
# the same path.
torch.save(
    enzymes_model.state_dict(),
    "ckpts/enzymes.pt",
)

# COLLAB

In [None]:
# COLLAB experiment setup: a seeded dataset, its train/validation split, and a
# GCN classifier sized from the dataset's node-class and graph-class lists.
# NOTE(review): the exact meaning of k=10 is defined by train_test_split in
# gnnboundary — confirm there.
collab = CollabDataset(seed=12345)
collab_train, collab_val = collab.train_test_split(k=10)
collab_model = GCNClassifier(
    node_features=len(collab.NODE_CLS),
    num_classes=len(collab.GRAPH_CLS),
    hidden_channels=64,
    num_layers=5,
)

In [None]:
# Fit the COLLAB classifier for 1024 epochs at a fixed learning rate. After
# each epoch the model is evaluated on both splits and one summary line is
# printed: loss, accuracy and F1 ('acc' / 'f1' entries of model_evaluate).
for epoch in trange(1024):
    train_loss = collab_train.model_fit(collab_model, lr=0.001)
    train_metrics = collab_train.model_evaluate(collab_model)
    val_metrics = collab_val.model_evaluate(collab_model)
    print(", ".join((
        f"Epoch: {epoch:03d}",
        f"Train Loss: {train_loss:.4f}",
        f"Train Acc: {train_metrics['acc']:.4f}",
        f"Test Acc: {val_metrics['acc']:.4f}",
        f"Train F1: {train_metrics['f1']}",
        f"Test F1: {val_metrics['f1']}",
    )))

In [None]:
# torch.save(collab_model.state_dict(), f"ckpts/collab.pt")

In [None]:
# Restore the COLLAB weights saved at 'ckpts/collab.pt' into collab_model.
# The checkpoint file must already exist on disk.
collab_model.load_state_dict(
    torch.load('ckpts/collab.pt'),
)