In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from tqdm.auto import trange

In [3]:
from gnnboundary import *

In [4]:
# Select the compute device once: CUDA when PyTorch reports it as available,
# CPU otherwise. The bare `device` on the last line echoes the choice.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
device

device(type='cpu')

# Motif

In [19]:
# Motif: build the dataset, move it to the selected device, and split it
# with k=10 (see `train_test_split` for what k controls).
motif = MotifDataset(seed=12345).to(device)
motif_train, motif_val = motif.train_test_split(k=10)

# Classifier whose input/output sizes come from the dataset's node and graph
# class lists.
motif_model = GCNClassifier(
    node_features=len(motif.NODE_CLS),
    num_classes=len(motif.GRAPH_CLS),
    hidden_channels=6,
    num_layers=3,
)
# Last expression in the cell, so the notebook also displays the model summary.
motif_model.to(device)

GCNClassifier(
  (conv): GCN(5, 6, num_layers=3)
  (drop): Dropout(p=0, inplace=False)
  (lin): Linear(12, 6, bias=True)
  (out): Linear(6, 4, bias=True)
)

In [12]:
# Train the Motif classifier. A checkpoint is written every time validation
# accuracy reaches a new best; training stops once `patience` consecutive
# epochs pass without an improvement.
patience = 20
best_val_acc = float('-inf')
early_stop_counter = 0

for epoch in trange(1):
    train_loss = motif_train.model_fit(motif_model, lr=0.001)
    train_metrics = motif_train.model_evaluate(motif_model)
    val_metrics = motif_val.model_evaluate(motif_model)
    # Optional per-epoch report (disabled):
    # print(f"Epoch: {epoch:03d}, "
    #       f"Train Loss: {train_loss:.4f}, "
    #       f"Train Acc: {train_metrics['acc']:.4f}, "
    #       f"Test Acc: {val_metrics['acc']:.4f}, "
    #       f"Train F1: {train_metrics['f1']}, "
    #       f"Test F1: {val_metrics['f1']}")

    if not (val_metrics['acc'] > best_val_acc):
        # No new best this epoch: one step closer to early stopping.
        early_stop_counter += 1
    else:
        # New best: remember it, reset the counter, and checkpoint the model.
        best_val_acc = val_metrics['acc']
        early_stop_counter = 0
        torch.save(
            dict(epoch=epoch,
                 model=motif_model.state_dict(),
                 val_acc=best_val_acc),
            'ckpts/ours/motif.pt',
        )

    if early_stop_counter >= patience:
        break

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# torch.save(motif_model.state_dict(), 'ckpts/motif.pt')

In [16]:
# motif_model.load_state_dict(torch.load('ckpts/ours/motif.pt')['model'])

  motif_model.load_state_dict(torch.load('ckpts/ours/motif.pt')['model'])


<All keys matched successfully>

# ENZYMES

In [4]:
# ENZYMES: build the dataset, move it to the selected device, and split it
# with k=10 (see `train_test_split` for what k controls).
enzymes = ENZYMESDataset(seed=12345).to(device)
enzymes_train, enzymes_val = enzymes.train_test_split(k=10)
# Classifier sized from the dataset's node and graph class lists. It is moved
# to the same device as the data (as the Motif cell does); otherwise the model
# stays on the CPU while the dataset may live on CUDA.
enzymes_model = GCNClassifier(node_features=len(enzymes.NODE_CLS),
                              num_classes=len(enzymes.GRAPH_CLS),
                              hidden_channels=32,
                              num_layers=3).to(device)

In [5]:
#enzymes_model.load_state_dict(torch.load('ckpts/enzymes.pt'))

<All keys matched successfully>

In [7]:
# Train the ENZYMES classifier. A checkpoint is written every time validation
# accuracy reaches a new best; training stops once `patience` consecutive
# epochs pass without an improvement.
patience = 30
best_val_acc = float('-inf')
early_stop_counter = 0

for epoch in range(4096):
    train_loss = enzymes_train.model_fit(enzymes_model, lr=0.0001)
    train_metrics = enzymes_train.model_evaluate(enzymes_model)
    val_metrics = enzymes_val.model_evaluate(enzymes_model)
    # Optional per-epoch report (disabled):
    # print(f"Epoch: {epoch:03d}, "
    #       f"Train Loss: {train_loss:.4f}, "
    #       f"Train Acc: {train_metrics['acc']:.4f}, "
    #       f"Test Acc: {val_metrics['acc']:.4f}, "
    #       f"Train F1: {train_metrics['f1']}, "
    #       f"Test F1: {val_metrics['f1']}")

    if val_metrics['acc'] > best_val_acc:
        best_val_acc = val_metrics['acc']
        early_stop_counter = 0
        torch.save({
            'epoch': epoch,
            # Save the model trained in this loop. The previous version
            # stored motif_model's weights here (copy-paste slip), so the
            # ENZYMES checkpoint contained the wrong network.
            'model': enzymes_model.state_dict(),
            'val_acc': best_val_acc
        }, 'ckpts/ours/enzymes.pt')
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        break

Epoch: 000, Train Loss: 0.8206, Train Acc: 0.7315, Test Acc: 0.5667, Train F1: {'EC1': 0.7159090638160706, 'EC2': 0.6107784509658813, 'EC3': 0.7914438247680664, 'EC4': 0.7745664715766907, 'EC5': 0.7900000214576721, 'EC6': 0.6892655491828918}, Test F1: {'EC1': 0.4761904776096344, 'EC2': 0.4444444477558136, 'EC3': 0.8333333134651184, 'EC4': 0.5600000023841858, 'EC5': 0.4444444477558136, 'EC6': 0.5714285969734192}
Epoch: 001, Train Loss: 0.8286, Train Acc: 0.7259, Test Acc: 0.5667, Train F1: {'EC1': 0.6708074808120728, 'EC2': 0.6741573214530945, 'EC3': 0.7755101919174194, 'EC4': 0.7748690843582153, 'EC5': 0.7597765326499939, 'EC6': 0.6857143044471741}, Test F1: {'EC1': 0.4444444477558136, 'EC2': 0.5714285969734192, 'EC3': 0.800000011920929, 'EC4': 0.5, 'EC5': 0.5333333611488342, 'EC6': 0.4615384638309479}
Epoch: 002, Train Loss: 0.8205, Train Acc: 0.7204, Test Acc: 0.5000, Train F1: {'EC1': 0.6946107745170593, 'EC2': 0.6289308071136475, 'EC3': 0.7932960987091064, 'EC4': 0.7885714173316956

In [8]:
#torch.save(enzymes_model.state_dict(), f"ckpts/enzymes.pt")

# COLLAB

In [None]:
# COLLAB: build the dataset, move it to the selected device, and split it
# with k=10 (see `train_test_split` for what k controls).
collab = CollabDataset(seed=12345).to(device)
collab_train, collab_val = collab.train_test_split(k=10)
# Classifier sized from the dataset's node and graph class lists. It is moved
# to the same device as the data (as the Motif cell does); otherwise the model
# stays on the CPU while the dataset may live on CUDA.
collab_model = GCNClassifier(node_features=len(collab.NODE_CLS),
                             num_classes=len(collab.GRAPH_CLS),
                             hidden_channels=64,
                             num_layers=5).to(device)

In [None]:
# Train the COLLAB classifier. A checkpoint is written every time validation
# accuracy reaches a new best; training stops once `patience` consecutive
# epochs pass without an improvement.
patience = 30
best_val_acc = float('-inf')
early_stop_counter = 0

for epoch in trange(1024):
    train_loss = collab_train.model_fit(collab_model, lr=0.001)
    train_metrics = collab_train.model_evaluate(collab_model)
    val_metrics = collab_val.model_evaluate(collab_model)
    # Optional per-epoch report (disabled):
    # print(f"Epoch: {epoch:03d}, "
    #       f"Train Loss: {train_loss:.4f}, "
    #       f"Train Acc: {train_metrics['acc']:.4f}, "
    #       f"Test Acc: {val_metrics['acc']:.4f}, "
    #       f"Train F1: {train_metrics['f1']}, "
    #       f"Test F1: {val_metrics['f1']}")

    if val_metrics['acc'] > best_val_acc:
        best_val_acc = val_metrics['acc']
        early_stop_counter = 0
        torch.save({
            'epoch': epoch,
            # Save the model trained in this loop. The previous version
            # stored motif_model's weights here (copy-paste slip), so the
            # COLLAB checkpoint contained the wrong network.
            'model': collab_model.state_dict(),
            'val_acc': best_val_acc
        }, 'ckpts/ours/collab.pt')
    else:
        early_stop_counter += 1

    if early_stop_counter >= patience:
        break

In [None]:
# torch.save(collab_model.state_dict(), f"ckpts/collab.pt")

In [None]:
# collab_model.load_state_dict(torch.load('ckpts/collab.pt'))