In [1]:
from torch_geometric.data import HeteroData
from torch_geometric.transforms.random_node_split import RandomNodeSplit
import torch_geometric.transforms as T
from torch_geometric.nn import SAGEConv, to_hetero
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score
import torch
import matplotlib.pyplot as plt

In [2]:
data = torch.load('data/hetero_data.pt')
data = T.ToUndirected()(data)
data = T.NormalizeFeatures()(data)
split = RandomNodeSplit(num_test=0.33, num_val=0.2)

data = split(data)

In [3]:
data

HeteroData(
  incident={
    x=[141707, 51],
    y=[141707],
    train_mask=[141707],
    val_mask=[141707],
    test_mask=[141707],
  },
  support_org={ x=[141707, 17] },
  customer={ x=[141707, 16] },
  vendor={ x=[141707, 16] },
  (incident, assigned, support_org)={ edge_index=[2, 141707] },
  (incident, assigned, vendor)={ edge_index=[2, 141707] },
  (incident, reported, customer)={ edge_index=[2, 141707] },
  (support_org, rev_assigned, incident)={ edge_index=[2, 141707] },
  (vendor, rev_assigned, incident)={ edge_index=[2, 141707] },
  (customer, rev_reported, incident)={ edge_index=[2, 141707] }
)

In [4]:
class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels, aggr='sum')
        self.conv2 = SAGEConv((-1, -1), out_channels, aggr='sum')

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index).sigmoid()

        return x


gcn = GCN(hidden_channels=32, out_channels=2)
gcn = to_hetero(gcn, data.metadata(), aggr='sum')
optimizer = torch.optim.Adam(gcn.parameters(), lr=0.0001, weight_decay=5e-4)

In [5]:
def train(model, data, optimizer, loss, n_epochs=200):
    train_losses = []
    val_accs = []
    for epoch in tqdm(range(1, n_epochs + 1)):
        model.train()
        optimizer.zero_grad()
        out = model(data.x_dict, data.edge_index_dict)
        mask = data['incident'].train_mask
        loss_fn = F.cross_entropy(out['incident'][mask], data['incident'].y[mask].type(torch.LongTensor))
        loss_fn.backward(retain_graph=True)
        optimizer.step()

        val_acc = eval_node_classifier(model, data, data['incident'].val_mask)
        train_acc = accuracy_score(out['incident'][mask].argmax(dim=-1), data['incident'].y[mask].detach().numpy())
        if epoch % 20 == 0:
            print(f'\tEpoch: {epoch}, train_loss: {float(loss_fn):.5f}, train_acc: {train_acc:.5f}, val_acc: {val_acc:.5f}')

        train_losses.append(loss_fn)
        val_accs.append(val_acc)

    return train_losses, val_accs


def eval_node_classifier(model, data, mask):
    with torch.no_grad():
        model.eval()
        out = model(data.x_dict, data.edge_index_dict)['incident']
        pred = out.argmax(dim=-1)

    #     roc_score = roc_auc_score(pred, mask)
        acc_score = accuracy_score(pred, mask)

        return acc_score

In [6]:
train_losses, val_losses = train(gcn, data, optimizer, None, 100)

 20%|████████████████▏                                                                | 20/100 [00:49<03:25,  2.57s/it]

	Epoch: 20, train_loss: 0.69405, train_acc: 0.49483, val_acc: 0.79971


 40%|████████████████████████████████▍                                                | 40/100 [01:42<02:42,  2.71s/it]

	Epoch: 40, train_loss: 0.69365, train_acc: 0.49480, val_acc: 0.79971


 60%|████████████████████████████████████████████████▌                                | 60/100 [02:36<01:59,  2.99s/it]

	Epoch: 60, train_loss: 0.69336, train_acc: 0.49481, val_acc: 0.79973


 80%|████████████████████████████████████████████████████████████████▊                | 80/100 [03:38<01:11,  3.56s/it]

	Epoch: 80, train_loss: 0.69318, train_acc: 0.49484, val_acc: 0.79969


100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [04:31<00:00,  2.72s/it]

	Epoch: 100, train_loss: 0.69305, train_acc: 0.49789, val_acc: 0.78326





In [7]:
acc_score = eval_node_classifier(gcn, data, data['incident'].test_mask)
acc_score

0.6615410671314755