In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch_geometric
import os
import torch
from tqdm import tqdm
import numpy as np

In [None]:
from graph_datasets import linear_mapping

In [None]:
from torch_geometric.data import HeteroData

# One-off conversion of the raw graphs into the layout used for training:
# every node type except 'aggregator' is passed through
# linear_mapping(..., 1024) and stored as the features of a single 'clients'
# node type, while the aggregator features, the labels and all edge indices
# are copied over unchanged.
#
# WARNING: each file is overwritten in place, so this cell is NOT idempotent;
# running it again on already-converted files feeds the converted layout back
# into linear_mapping.
# NOTE(review): `[10:]` skips the first ten entries returned by os.listdir,
# whose order is arbitrary -- presumably those files were converted by an
# earlier, interrupted run; confirm before re-running.
# NOTE(review): torch.load unpickles the file; only use on trusted data.
raw_files = os.listdir('/mnt/d/graph_dataset/raw')[10:]
for m_name in tqdm(raw_files):
    g = torch.load(f'/mnt/d/graph_dataset/raw/{m_name}')

    # Split the per-type feature dict into the aggregator entry and the rest.
    model_dict = g.x_dict
    aggr = model_dict.pop('aggregator')
    features = linear_mapping(model_dict, 1024)

    # Rebuild the graph with two node types: 'aggregator' and 'clients'.
    data = HeteroData(aggregator={'x': aggr}, y=g.y)
    data['clients'].x = features
    for k, w in g.edge_index_dict.items():
        data[k].edge_index = w

    torch.save(data, f'/mnt/d/graph_dataset/raw/{m_name}')

In [3]:
from graph_datasets import HeteroGraphDataset

In [4]:
# Build the dataset rooted at /mnt/d/graph_dataset/. In the recorded run this
# triggered a processing pass over 2000 items (see the output below).
dataset = HeteroGraphDataset('/mnt/d/graph_dataset/')

Processing...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [00:10<00:00, 189.57it/s]
Done!


In [5]:
from torch.utils.data import DataLoader
from graph_utils import my_hetero_collate, separate

In [None]:
# Sanity check: the loaders below slice the dataset at indices 1600/1800/2000.
len(dataset)

In [6]:
# Positional 1600 / 200 / 200 train / validation / test split.
# Only the training loader reshuffles; evaluation loaders keep a fixed order
# and use a larger batch size.
# NOTE(review): the split is by index, with no shuffling of the dataset
# itself -- confirm the stored order is not correlated with the labels.
train_loader = DataLoader(
    dataset[:1600],
    batch_size=16,
    shuffle=True,
    collate_fn=my_hetero_collate,
)
val_loader = DataLoader(
    dataset[1600:1800],
    batch_size=32,
    shuffle=False,
    collate_fn=my_hetero_collate,
)
test_loader = DataLoader(
    dataset[1800:2000],
    batch_size=32,
    shuffle=False,
    collate_fn=my_hetero_collate,
)

In [7]:
from hetero_gnn import HeteroConv, HeteroGNN, HeteroGNNHomofeatures

In [15]:
# Alternative architecture, kept for reference but not used here:
# model = HeteroGNN(feature_in_channels=128,
#                  aggr_in_channels=1,
#                  hidden_channels=128,
#                  out_channels=1,
#                  num_layers=3,
#                  feature_encode='mean').cuda()

# Model used for the experiment. feature_in_channels=1024 is the same width
# that is passed to linear_mapping in the conversion cell; out_channels=1
# yields a single logit, matching the BCE-with-logits criterion.
model = HeteroGNNHomofeatures(
    feature_in_channels=1024,
    aggr_in_channels=1,
    hidden_channels=128,
    out_channels=1,
    num_layers=3,
)
model = model.cuda()

In [18]:
# Binary cross-entropy on raw logits (the sigmoid is applied inside the loss),
# which is why predictions elsewhere are obtained by thresholding outputs at 0.
criterion = torch.nn.BCEWithLogitsLoss()

In [19]:
@torch.no_grad()
def validation(dataloader, model):
    """Compute the binary accuracy of ``model`` over every batch of ``dataloader``.

    The model is called as ``model(data.x_dict, data.edge_index_dict)`` and its
    outputs are treated as logits: a value strictly greater than 0 is predicted
    as class 1, anything else as class 0. Predictions are compared element-wise
    with ``data.y``.

    The function does not change the model's train/eval mode; callers are
    expected to call ``model.eval()`` beforehand.

    Args:
        dataloader: Iterable of batches exposing ``.cuda()``, ``.x_dict``,
            ``.edge_index_dict`` and ``.y``.
        model: Callable returning logits with the same shape as ``data.y``.

    Returns:
        A 0-dim tensor holding ``correct / total`` over all label elements.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    corrects = 0
    counts = 0
    for data in dataloader:
        data = data.cuda()
        outputs = model(data.x_dict, data.edge_index_dict)
        # Gradients are disabled by the decorator, so no detach is needed.
        preds = (outputs > 0.).to(torch.float)
        corrects += (preds == data.y).sum()
        # Count every label element: the sum above runs over all elements, so
        # dividing by y.shape[0] alone would overstate accuracy whenever y has
        # more than one column.
        counts += data.y.numel()

    return corrects / counts

In [24]:
# Number of passes over train_loader performed in each training run.
epochs = 15

In [25]:
# Final test accuracy of each run is appended here by the training cell.
# Re-run this cell before re-running the training cell; otherwise results
# from earlier executions stay in the list and skew the summary statistics.
test_accs = []

In [26]:
# Number of independent training runs (parameters are reset before each one).
runs = 5

In [27]:
# Train `runs` independent models for `epochs` epochs each and record the
# test accuracy reached at the end of every run.
# NOTE: `test_accs` is only appended to here; re-executing this cell without
# resetting the list first mixes results from different executions.
for _ in range(runs):
    # Every run starts from freshly initialised weights and optimizer state.
    model.reset_parameters()
    optimizer = torch.optim.Adam(model.parameters(), lr=1.e-3)
    pbar = tqdm(range(epochs))
    for epoch in pbar:
        losses = 0.
        counts = 0
        corrects = 0
        model.train()
        for data in train_loader:
            data = data.cuda()
            optimizer.zero_grad()

            outputs = model(data.x_dict, data.edge_index_dict)
            loss = criterion(outputs, data.y)
            loss.backward()
            optimizer.step()

            # The criterion returns a per-batch mean; weight it by the number
            # of labels so the epoch value is a mean over labels, not batches.
            n_labels = data.y.numel()
            losses += loss.item() * n_labels
            counts += n_labels
            # Logit > 0 is the positive class (sigmoid(logit) > 0.5).
            preds = (outputs > 0.).detach().to(torch.float)
            corrects += (preds == data.y).sum()

        losses /= counts
        train_acc = corrects / counts

        model.eval()
        val_acc = validation(val_loader, model)

        # Format the accuracies as plain numbers: passing the tensors straight
        # to tqdm renders them as "tensor(..., device='cuda:0')".
        pbar.set_postfix({'loss': losses,
                          'train_acc': f'{float(train_acc):.4f}',
                          'val_acc': f'{float(val_acc):.4f}'})

    # The model state after the last epoch is the one that gets tested.
    model.eval()
    test_acc = validation(test_loader, model)
    print(f'test acc: {test_acc}')
    test_accs.append(float(test_acc))

100%|█████████████████████████████████████████████████████| 15/15 [00:17<00:00,  1.16s/it, loss=1.8e-5, train_acc=tensor(1., device='cuda:0'), val_acc=tensor(0.9989, device='cuda:0')]


test acc: 0.9995999932289124


100%|████████████████████████████████████████████████████| 15/15 [00:17<00:00,  1.16s/it, loss=1.62e-5, train_acc=tensor(1., device='cuda:0'), val_acc=tensor(0.9986, device='cuda:0')]


test acc: 0.9990999698638916


100%|████████████████████████████████████████████████████| 15/15 [00:16<00:00,  1.11s/it, loss=2.53e-5, train_acc=tensor(1., device='cuda:0'), val_acc=tensor(0.9989, device='cuda:0')]


test acc: 0.9984999895095825


100%|████████████████████████████████████████████████████| 15/15 [00:17<00:00,  1.17s/it, loss=1.69e-5, train_acc=tensor(1., device='cuda:0'), val_acc=tensor(0.9994, device='cuda:0')]


test acc: 0.9994999766349792


100%|████████████████████████████████████████████████████| 15/15 [00:17<00:00,  1.15s/it, loss=3.13e-5, train_acc=tensor(1., device='cuda:0'), val_acc=tensor(0.9991, device='cuda:0')]

test acc: 0.9993999600410461





In [28]:
# Summarise the per-run test accuracies as mean ± standard deviation.
# np.std uses its default ddof=0 here, i.e. the population standard deviation.
acc_mean = np.mean(test_accs)
acc_std = np.std(test_accs)
print(f'{acc_mean} ± {acc_std}')

0.9992199778556824 ± 0.0003969860763201743
