In [1]:
import sys
sys.path.append('../')  

from data.graph_sampler import GraphSampler
from data.load_data import read_graphfile
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
from utils.operations import *
from utils.graph_processing import *
from tqdm import tqdm
from models.gan import Generator, Discriminator
from models.encoder import Encoder
import torch.nn.functional as F
import os
from sklearn.metrics import roc_auc_score

In [2]:

DATADIR = '../datasets/'
DS = 'NCI1'
CHECKPOINT_DIR = '../checkpoints'
FEAT = 'deg'    # deg or default
BATCH_SIZE = 300

In [3]:
graphs = read_graphfile(DATADIR, DS, max_nodes=0)  

graphs_train, graphs_test = train_test_split(graphs, test_size=0.2, random_state=42)
graphs_train_labels = [graph.graph['label'] for graph in graphs_train]

# filtering out abnormal graphs for training
graphs_train = [graph for graph, label in zip(graphs_train, graphs_train_labels) if label == 0]

max_num_nodes = max([G.number_of_nodes() for G in graphs])
dataset_sampler_train = GraphSampler(graphs_train, features=FEAT, normalize=False, max_num_nodes=max_num_nodes)
dataset_sampler_test = GraphSampler(graphs_test, features=FEAT, normalize=False, max_num_nodes=max_num_nodes)

data_loader_train = DataLoader(dataset_sampler_train, shuffle=True, batch_size=BATCH_SIZE)
data_loader_test = DataLoader(dataset_sampler_test, shuffle=True, batch_size=BATCH_SIZE)

No node attributes


In [4]:
MAX_NODES = max_num_nodes
WGAN_EPOCHS = 200000
EPOCHS_DECAY = 100000
LR_UPDATE_STEP = 1000
PATIENCE = 200   # number of epochs to wait for improvement before stopping
RESUME_TRAINING = False

ENCODER_SAVE_ITERS = [50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 30000, 40000, 50000]
WGAN_SAVE_ITERS = [50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 30000, 40000, 50000, 100000, 150000]

G_LR = 2e-3
D_LR = 2e-3
BETA1 = 0.5
BETA2 = 0.99

G_CONV_DIM = [128,256,512]
D_CONV_DIM = [dataset_sampler_train.feat_dim, [128, 64]]     # updated later
D_AGGR_DIM = 128
D_LINEAR_DIM = [128, 64]
Z_DIM = 8

DROPOUT = 0.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
GUMBELL_TYPE = 'hard-gumbell'
LAMBDA_GP = 10  # penalty coefficient
N_CRITIC = 5    # number of D updates per each G update

In [5]:
G = Generator(G_CONV_DIM, Z_DIM, MAX_NODES, dataset_sampler_train.feat_dim, DROPOUT).to(DEVICE)
D = Discriminator(D_CONV_DIM, D_AGGR_DIM, D_LINEAR_DIM, DROPOUT).to(DEVICE)
g_optimizer = torch.optim.Adam(G.parameters(), G_LR, [BETA1, BETA2])
d_optimizer = torch.optim.Adam(D.parameters(), D_LR, [BETA1, BETA2])

if RESUME_TRAINING:
    START_EPOCH, _ = load_checkpoint(G, g_optimizer, os.path.join(CHECKPOINT_DIR, 'G_checkpoint_latest.pth'))
    _, _ = load_checkpoint(D, d_optimizer, os.path.join(CHECKPOINT_DIR, 'D_checkpoint_latest.pth'))
else:
    START_EPOCH = 0

WGAN Train Loop

In [6]:
G.train()
D.train()

best_gloss = np.inf
epochs_no_improve = 0

with tqdm(total=WGAN_EPOCHS, desc="WGAN-GP Training Progress") as pbar:

    avg_dloss = 0
    avg_gloss = 0

    for epoch in range(START_EPOCH, WGAN_EPOCHS):
        
        total_dloss = 0
        total_gloss = 0
        batch_count = 0

        # min-max game
        for batch in data_loader_train:
            
            adj = batch['adj'].float().clone().to(DEVICE)
            x = batch['feat'].float().clone().to(DEVICE)

            current_batch_size = adj.shape[0]
            z = sample_z(current_batch_size, Z_DIM).to(DEVICE)

            #===============#
            # discriminator #
            #===============#

            # real graphs loss
            real_logits, _ = D(x, adj)
            dloss_real = - torch.mean(real_logits)

            # generate graphs
            adj_logits, x_hat = G(z)
            adj_hat = process_adj(adj_logits, GUMBELL_TYPE)

            # fake graphs loss
            fake_logits, _ = D(x_hat, adj_hat)
            dloss_fake = torch.mean(fake_logits)

            # gradient penalty from WGAN-GP
            # small adaptation -> macro penalty computed as sum of generated nodes and edges micro penalties
            eps = torch.rand(adj.shape[0], 1, 1).to(DEVICE)
            x_int0 = (eps * x + (1. - eps) * x_hat).requires_grad_(True)    # nodes
            x_int1 = (eps * adj + (1. - eps) * adj_hat).requires_grad_(True)    # edges
            grad0, grad1 = D(x_int0, x_int1)
            dloss_gp = gradient_penalty(grad0, x_int0, DEVICE) + gradient_penalty(grad1, x_int1, DEVICE)

            dloss = dloss_fake + dloss_real + LAMBDA_GP * dloss_gp

            # backward and optimize
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            dloss.backward()
            d_optimizer.step()

            #===========#
            # generator #
            #===========#

            if epoch % N_CRITIC == 0:

                adj_logits, x_hat = G(z)
                adj_hat = process_adj(adj_logits, GUMBELL_TYPE)

                fake_logits, _ = D(x_hat, adj_hat)
                gloss = - torch.mean(fake_logits)

                # backward and optimize
                g_optimizer.zero_grad()
                d_optimizer.zero_grad()

                gloss.backward()
                g_optimizer.step()

            total_dloss += dloss.item()
            total_gloss += gloss.item()
            batch_count += 1

        avg_dloss = total_dloss / batch_count
        avg_gloss = total_gloss / batch_count

        pbar.set_postfix({'Generator Loss': avg_gloss, 'Discriminator Loss': avg_dloss})
        pbar.update(1)  # move progress bar forward

        # early stopping
        if avg_gloss < best_gloss:
            best_gloss = avg_gloss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= PATIENCE:
            print(f"Early stopping triggered after {epoch+1} epochs due to no improvement in generator loss.")
            break

        # decay learning rate
        if (epoch+1) % LR_UPDATE_STEP == 0 and (epoch+1) > (WGAN_EPOCHS - EPOCHS_DECAY):
            G_LR -= G_LR / float(EPOCHS_DECAY)
            D_LR -= D_LR / float(EPOCHS_DECAY)
            update_lr(g_optimizer, d_optimizer, G_LR, D_LR)

        # save checkpoint
        if (epoch + 1) in WGAN_SAVE_ITERS: 
            save_checkpoint(G, g_optimizer, epoch + 1, avg_gloss, os.path.join(CHECKPOINT_DIR, 'G_checkpoint_{}.pth'.format(epoch + 1)))
            save_checkpoint(D, d_optimizer, epoch + 1, avg_dloss, os.path.join(CHECKPOINT_DIR, 'D_checkpoint_{}.pth'.format(epoch + 1)))

    save_checkpoint(G, g_optimizer, WGAN_EPOCHS, avg_gloss, os.path.join(CHECKPOINT_DIR, 'G_checkpoint_final.pth'))
    save_checkpoint(D, d_optimizer, WGAN_EPOCHS, avg_dloss, os.path.join(CHECKPOINT_DIR, 'D_checkpoint_final.pth'))


WGAN-GP Training Progress:   0%|          | 14/200000 [00:17<68:42:55,  1.24s/it, Generator Loss=0.537, Discriminator Loss=89.3]   


KeyboardInterrupt: 

In [7]:
E_LINEAR_DIM = [128, 64]
E_LR = 2e-3
ENCODER_EPOCHS = 3000
KAPPA = 1.0

In [8]:
E = Encoder(dataset_sampler_train.feat_dim, E_LINEAR_DIM, Z_DIM, DROPOUT).to(DEVICE)
e_optimizer =torch.optim.Adam(E.parameters(), E_LR, [BETA1, BETA2])

Encoder Train Loop

In [11]:
G.eval()
D.eval()
E.train()

with tqdm(total=ENCODER_EPOCHS, desc="Encoder Training Progress") as pbar:

    avg_eloss = 0

    for epoch in range(ENCODER_EPOCHS):
        
        total_eloss = 0
        batch_count = 0

        # discriminator guided gzg_f approach (graph_real, z_hat, graph_fake)
        for batch in data_loader_train:
            
            adj = batch['adj'].float().clone().to(DEVICE)
            x = batch['feat'].float().clone().to(DEVICE)

            z_hat = E(x, adj)
            
            # weights for G and D remain fixed
            adj_logits, x_tilde = G(z_hat)
            adj_tilde = process_adj(adj_logits, GUMBELL_TYPE)

            _, real_emb = D(x, adj)
            _, fake_emb = D(x_tilde, adj_tilde)

            adj_loss = F.mse_loss(adj, adj_tilde)
            x_loss = F.mse_loss(x, x_tilde)
            guided_dloss = F.mse_loss(real_emb, fake_emb)

            eloss = adj_loss + x_loss + KAPPA*guided_dloss

            total_eloss += eloss.item()
            batch_count += 1

            # backward and optimize
            e_optimizer.zero_grad()
            eloss.backward()
            e_optimizer.step()

        avg_eloss = total_eloss / batch_count

        pbar.set_postfix({'Encoder Loss': avg_eloss})
        pbar.update(1)  # move progress bar forward

    save_checkpoint(E, e_optimizer, ENCODER_EPOCHS, avg_eloss, os.path.join(CHECKPOINT_DIR, 'E_checkpoint_final.pth'))

Encoder Training Progress:   1%|          | 16/3000 [00:05<15:32,  3.20it/s, Encoder Loss=15.8]


KeyboardInterrupt: 

In [None]:
G.eval()
E.eval()
D.eval()

all_anomaly_scores = []
all_labels = []

for batch in data_loader_test:
    
    adj = batch['adj'].float().clone().to(DEVICE)
    x = batch['feat'].float().clone().to(DEVICE)
    label = batch['label'].float().clone.to(DEVICE)

    with torch.no_grad():
        z_hat = E(x, adj)

        adj_logits, x_tilde = G(z_hat)
        adj_tilde = process_adj(adj_logits, GUMBELL_TYPE)
        
        _, real_emb = D(x, adj)
        _, fake_emb = D(x_tilde, adj_tilde)
        
        z_tilde = E(x_tilde, adj_tilde)

    # compute graph distance
    adj_loss_per_graph = F.mse_loss(adj, adj_tilde, reduction='none').mean(dim=(1, 2))  
    x_loss_per_graph = F.mse_loss(x, x_tilde, reduction='none').mean(dim=(1, 2))  
    guided_dloss_per_graph = F.mse_loss(real_emb, fake_emb, reduction='none').mean(dim=1)  
    graph_distance_per_graph = adj_loss_per_graph + x_loss_per_graph + KAPPA * guided_dloss_per_graph
    
    # compute z distance
    z_distance_per_graph = F.mse_loss(z_hat, z_tilde, reduction='none').mean(dim=1) 
    
    anomaly_score_per_graph = graph_distance_per_graph + z_distance_per_graph

    # store the scores and ground truth labels 
    all_anomaly_scores.extend(anomaly_score_per_graph.cpu().numpy())
    all_labels.extend(label.cpu().numpy())


THRESHOLD = np.percentile(all_anomaly_scores, 95)

predictions = [1 if score > THRESHOLD else 0 for score in all_anomaly_scores]  # 1: anomalous, 0: normal

auc_score = roc_auc_score(all_labels, predictions) 
print(f"AUC Score: {auc_score}")



IndentationError: unexpected indent (2122480592.py, line 18)