In [1]:
from GNNProject.utils import *
from GNNProject.dataset import *
from GNNProject.model import *
from GNNProject.classifier import *

In [2]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
# Device handle that is later handed to the classifier.
device = torch.device(dev)

## Generate a synthetic dataset

In [3]:
# Experiment settings shared by the dataset-generation and classifier cells.

# --- problem size ---
n_classes = 3            # class count given to both create_syn and Classifier
n_features = 100         # feature count given to both create_syn and Classifier
n_char_features = 10     # forwarded to create_syn; presumably the number of
                         # class-characteristic features -- TODO confirm

# --- sample sizes ---
n_obs_train = 500        # training observations requested from create_syn
n_obs_test = 2000        # test observations requested from create_syn.
                         # NOTE(review): the confusion matrices printed by the
                         # evaluation cells have 2000 samples in every row, so
                         # this looks like a per-class count -- confirm.

# --- training ---
epochs = 20              # passed to Classifier.fit for every model

In [4]:
# Build the synthetic experiment tagged 'EXP1': generate the samples, derive a
# graph, then wrap the train and test splits in mini-batch loaders.
dataset = Dataset(tag='EXP1')

# Synthetic data generation. Sizes come from the configuration cell; the
# remaining arguments select the generator variant.
dataset.create_syn(
    model='BA',            # presumably a Barabasi-Albert graph model -- TODO confirm
    syn_method="sign",
    noise=[.2, .2],
    n_classes=n_classes,
    n_features=n_features,
    n_char_features=n_char_features,
    n_obs_train=n_obs_train,
    n_obs_test=n_obs_test,
)

# Graph construction step; the meaning of `alphas` is defined by
# Dataset.create_graph and is not visible from this notebook.
dataset.create_graph(alphas=0.5)

# One loader per split, train first and test second, both with the same
# settings (use_true_graph=True, mini-batches of 16).
train_dataloader, test_dataloader = (
    dataset._dataloader(split, use_true_graph=True, batch_size=16)
    for split in ('train', 'test')
)

## Fit and evaluate a GNN

In [5]:
# GraphSAGE classifier: a single GNN hidden layer of width 8 and no
# fully-connected hidden layers, trained on the selected device.
clf = Classifier(
    classifier='GraphSAGE',
    # problem dimensions (from the configuration cell)
    n_features=n_features,
    n_classes=n_classes,
    # architecture
    n_hidden_GNN=[8],
    n_hidden_FC=[],
    K=2,
    # regularisation
    dropout_GNN=0.2,
    dropout_FC=0.2,
    # optimisation
    lr=.001,
    momentum=.9,
    # bookkeeping
    log_dir="runs/GNN_TrueGraph",
    device=device,
)

# Train for `epochs` epochs (the test loader is also handed to fit), then
# print the evaluation metrics on the test split.
clf.fit(
    train_dataloader,
    epochs=epochs,
    test_dataloader=test_dataloader,
    verbose=True,
)
# Bind the return value to `_` so the cell does not echo it.
_ = clf.eval(test_dataloader, verbose=True)

[1] loss: 1.087
[3] loss: 0.616
[5] loss: 0.488
[7] loss: 0.461
[9] loss: 0.439
[11] loss: 0.449
[13] loss: 0.404
[15] loss: 0.381
[17] loss: 0.379
[19] loss: 0.314
Accuracy: 0.814
Confusion Matrix:
 [[1737  122  141]
 [ 256 1545  199]
 [ 243  157 1600]]
Precision: 0.816
Recall: 0.814
f1_score: 0.813


## Fit and evaluate an MLP

In [6]:
# MLP baseline: no GNN layers and one fully-connected hidden layer of width 40,
# trained and evaluated on the same loaders as the GraphSAGE model.
clf = Classifier(n_features=n_features,
        n_classes=n_classes,
        n_hidden_GNN=[],
        n_hidden_FC=[40],
        dropout_FC=0.2,
        classifier='MLP',
        lr=.001,
        momentum=.9,
        log_dir="runs/MLP",
        # Pass the device explicitly, as the GraphSAGE cell does, so both models
        # run on the same hardware rather than on Classifier's default device.
        device=device)

# Train for `epochs` epochs (the test loader is also handed to fit), then
# print the evaluation metrics on the test split.
clf.fit(train_dataloader, epochs=epochs, test_dataloader=test_dataloader, verbose=True)
# Bind the return value to `_` so the cell does not echo it.
_ = clf.eval(test_dataloader, verbose=True)

[1] loss: 1.111
[3] loss: 1.011
[5] loss: 0.872
[7] loss: 0.729
[9] loss: 0.619
[11] loss: 0.533
[13] loss: 0.484
[15] loss: 0.396
[17] loss: 0.344
[19] loss: 0.328
Accuracy: 0.456
Confusion Matrix:
 [[959 532 509]
 [601 872 527]
 [591 504 905]]
Precision: 0.456
Recall: 0.456
f1_score: 0.456
