In [8]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt
import torch_geometric.transforms as T
from torch_geometric.datasets import DBLP
import torch.nn.functional as F
from utils.graph_polluters import remove_features


import numpy as np
from tqdm import trange
from modules.heteroGNN import HeteroGNN
from utils.set_seed import set_seed
from copy import deepcopy

dataset = DBLP('./data/dblp', transform=T.Constant(node_types='conference'))
data = dataset[0]

In [20]:
data

HeteroData(
  author={
    x=[4057, 334],
    y=[4057],
    train_mask=[4057],
    val_mask=[4057],
    test_mask=[4057],
  },
  paper={ x=[14328, 4231] },
  term={ x=[7723, 50] },
  conference={
    num_nodes=20,
    x=[20, 1],
  },
  (author, to, paper)={ edge_index=[2, 19645] },
  (paper, to, author)={ edge_index=[2, 19645] },
  (paper, to, term)={ edge_index=[2, 85810] },
  (paper, to, conference)={ edge_index=[2, 14328] },
  (term, to, paper)={ edge_index=[2, 85810] },
  (conference, to, paper)={ edge_index=[2, 14328] }
)

## Initialize Autoencoder

In [9]:
class AE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dims):
        super().__init__()
        self.encoder = torch.nn.ModuleList()
        self.decoder = torch.nn.ModuleList()

        # Encoder
        for i in range(len(hidden_dims)):
            if i == 0:
                self.encoder.append(torch.nn.Linear(input_dim, hidden_dims[i]))
            else:
                self.encoder.append(torch.nn.Linear(hidden_dims[i-1], hidden_dims[i]))
            self.encoder.append(torch.nn.ReLU())
        
        # Decoder
        for i in reversed(range(len(hidden_dims))):
            if i == 0:
                self.encoder.append(torch.nn.Linear(hidden_dims[i], input_dim))
                self.encoder.append(torch.nn.Sigmoid())
            else:
                self.encoder.append(torch.nn.Linear(hidden_dims[i], hidden_dims[i-1]))
                self.encoder.append(torch.nn.ReLU())

        self.encoder = torch.nn.Sequential(*self.encoder)
        self.decoder = torch.nn.Sequential(*self.decoder)
            
 
    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

## Pretrain Autoencoder

In [10]:
def train_ae(node_type, ae_hidden_dims, epochs=30):
    sparse_threshold = 10
    not_sparse = torch.sum(data[node_type].x, 1).to(torch.int) > sparse_threshold

    base_data = data[node_type].x[not_sparse]
    half_data = torch.where(torch.rand_like(base_data) < 0.5, torch.zeros_like(base_data), base_data)

    loader = DataLoader(
        TensorDataset(base_data, half_data), 
        batch_size=64, shuffle=True, pin_memory=True)

    ae = AE(base_data.shape[-1], ae_hidden_dims)
    loss_function = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(ae.parameters(),
                                lr = 1e-1,
                                weight_decay = 1e-8)

    outputs = []

    average_losses, average_accuracies = [], []
    for epoch in trange(epochs):
        epoch_losses, epoch_accuracies = [], []
        for base, half in loader:
            reconstructed = ae(base)
            maxed = torch.max(half, reconstructed)
            loss = loss_function(maxed, base)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            accuracy = (torch.round(reconstructed) == base).float().mean()
            epoch_accuracies.append(accuracy) 
            # Storing the losses in a list for plotting
            epoch_losses.append(loss.item())

        average_losses.append(np.mean(epoch_losses))
        average_accuracies.append(np.mean(epoch_accuracies))
        outputs.append((epochs, base, reconstructed))
    return ae

In [11]:

author_ae_dims = [128, 64, 36, 18]
paper_ae_dims = [512,128,64,32]
author_ae = train_ae('author', ae_hidden_dims = author_ae_dims)
paper_ae = train_ae('paper', ae_hidden_dims = paper_ae_dims)

100%|██████████| 30/30 [00:01<00:00, 16.67it/s]
100%|██████████| 30/30 [00:05<00:00,  5.95it/s]


# AE GNN Code

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train(data, model, optimizer):
    model.train()
    optimizer.zero_grad()
    out, filtered = model(data.x_dict, data.edge_index_dict)
    mask = data['author'].train_mask
    loss = F.cross_entropy(out[mask], data['author'].y[mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(data, model):
    model.eval()
    pred, filtered = model(data.x_dict, data.edge_index_dict)
    pred = pred.argmax(dim=-1)

    accs = []
    for split in ['train_mask', 'val_mask', 'test_mask']:
        mask = data['author'][split]
        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()
        accs.append(float(acc))
    return accs

In [13]:
class AEGNN(torch.nn.Module):
    def __init__(self, ae, gnn):
        super().__init__()
        self.autoencoders = ae
        self.gnn = gnn
 
    def forward(self, x_dict, edge_index_dict):
        for node_type, autoencoder in self.autoencoders.items():
            filtered = autoencoder(x_dict[node_type])
            x_dict[node_type] = filtered
        return self.gnn(x_dict, edge_index_dict), filtered

## Encode-Decode-GNN Pipeline

In [15]:
set_seed()
dataset_copy = dataset.copy()
data_copy = dataset_copy[0]
data_copy = remove_features(data_copy, 0.5)

gnn = HeteroGNN(data_copy.metadata(), hidden_channels=10, out_channels=4, num_layers=2, target_node_type='author')
gnn = gnn.to(device)

autoencoders = dict()
autoencoders['author'] = AE(data_copy['author'].x.shape[-1], author_ae_dims).to(device)
autoencoders['author'].load_state_dict(author_ae.state_dict())
autoencoders['paper'] = AE(data_copy['paper'].x.shape[-1], paper_ae_dims).to(device)
autoencoders['paper'].load_state_dict(paper_ae.state_dict())

model = AEGNN(autoencoders, gnn)
data_copy, model = data_copy.to(device), model.to(device)



with torch.no_grad():  # Initialize lazy modules.
    out = model(data_copy.x_dict, data_copy.edge_index_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)

train_accs, val_accs, test_accs = [],[],[]
for epoch in range(1, 100):
    loss = train(data=data_copy, model=model, optimizer=optimizer)
    train_acc, val_acc, test_acc = test(data = data_copy, model=model)
    train_accs.append(train_acc)
    val_accs.append(val_acc)
    test_accs.append(test_acc)
best_epoch = max(enumerate(val_accs),key=lambda x: x[1])[0]
train_acc, val_acc, test_acc = train_accs[best_epoch], val_accs[best_epoch], test_accs[best_epoch]
print(f'End 2 End, Train: {train_acc:.4f}, '
        f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')

End 2 End, Train: 0.6900, Val: 0.5725, Test: 0.6334
