## Experiment Setting

In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import pickle
import argparse

In [2]:
# Experiment configuration. The options mirror a command-line interface, but the
# argument list is passed explicitly to parse_args so the notebook does not read
# the kernel's own sys.argv.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str,
                    help='Dataset')
parser.add_argument('--epoch', type=int, default=40,
                    help='Training Epochs')
parser.add_argument('--node_dim', type=int, default=64,
                    help='Node dimension')
parser.add_argument('--num_channels', type=int, default=2,
                    help='number of channels')
parser.add_argument('--lr', type=float, default=0.005,
                    help='learning rate')
parser.add_argument('--weight_decay', type=float, default=0.001,
                    help='l2 reg')
parser.add_argument('--num_layers', type=int, default=3,
                    help='number of layer')
# NOTE: --norm and --adaptive_lr are string flags ('true' / 'false'), not booleans.
parser.add_argument('--norm', type=str, default='true',
                    help='normalization')
parser.add_argument('--adaptive_lr', type=str, default='false',
                    help='adaptive learning rate')

args = parser.parse_args(['--dataset','ACM','--num_layers','2','--adaptive_lr','true'])
print(args)

# Unpack the parsed options into module-level names used by the later cells.
epochs = args.epoch
node_dim = args.node_dim
num_channels = args.num_channels
lr = args.lr
weight_decay = args.weight_decay
num_layers = args.num_layers
norm = args.norm
adaptive_lr = args.adaptive_lr

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Namespace(adaptive_lr='true', dataset='ACM', epoch=40, lr=0.005, node_dim=64, norm='true', num_channels=2, num_layers=2, weight_decay=0.001)


## Load Dataset

In [3]:
## Load the pickled node features, edges and labels of the selected dataset.
## NOTE: pickle.load can execute arbitrary code -- only open trusted dataset files.
dataset_dir = 'data/'+args.dataset

with open(dataset_dir+'/node_features.pkl','rb') as f:
    node_features = pickle.load(f)

with open(dataset_dir+'/edges.pkl','rb') as f:
    edges = pickle.load(f)

with open(dataset_dir+'/labels.pkl','rb') as f:
    labels = pickle.load(f)

  edges = pickle.load(f)
  edges = pickle.load(f)


In [4]:
## get the number of nodes: the row count of the first edge type's matrix
## (presumably every edge type is a square num_nodes x num_nodes matrix -- TODO confirm)
num_nodes = edges[0].shape[0]

In [5]:
## Stack each edge type's dense adjacency matrix, followed by the identity matrix,
## along a new trailing axis: A has shape (num_nodes, num_nodes, len(edges) + 1).
## A single torch.stack avoids re-copying the growing tensor once per edge type,
## which is what repeated torch.cat calls inside a loop do.
adjacency_channels = [
    torch.from_numpy(edge.todense()).type(torch.FloatTensor)
    for edge in edges
]
adjacency_channels.append(torch.eye(num_nodes).type(torch.FloatTensor))

A = torch.stack(adjacency_channels, dim=-1).to(device)

In [6]:
## Move the node features to the target device as a float tensor.
## torch.as_tensor accepts both a NumPy array and an existing tensor, so this cell
## can be re-run safely (torch.from_numpy would raise on the second run because
## node_features has already been replaced by a tensor).
node_features = torch.as_tensor(node_features, dtype=torch.float, device=device)


def _label_split_tensors(split):
    """Convert one label split into ``(node_indices, targets)`` LongTensors on ``device``.

    Column 0 of ``split`` is used as the node index and column 1 as the class
    target (presumably ``split`` is a sequence of ``[node, label]`` pairs --
    verify against the dataset files).
    """
    pairs = np.array(split)
    nodes = torch.from_numpy(pairs[:, 0]).type(torch.LongTensor).to(device)
    targets = torch.from_numpy(pairs[:, 1]).type(torch.LongTensor).to(device)
    return nodes, targets


## creating training dataset
train_node, train_target = _label_split_tensors(labels[0])

## creating validation dataset
valid_node, valid_target = _label_split_tensors(labels[1])

## creating test dataset
test_node, test_target = _label_split_tensors(labels[2])

In [7]:
## Number of classes = largest training label + 1
## (assumes labels are 0-based and the training split contains the highest class id)
num_classes = train_target.max().item() + 1

## Model

In [8]:
class GTConv(nn.Module):
    """Soft selection over the channel axis of a stacked adjacency tensor.

    Holds one weight per (output channel, input channel) pair. In ``forward`` the
    weights are softmax-normalised over the input-channel axis and used to form a
    weighted sum of the input channels, giving one matrix per output channel.

    Args:
        input_channels (int): Number of channels in the input tensor (axis 1).
        output_channels (int): Number of weighted combinations to produce.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels

        # Shape (out, in, 1, 1) so the weights broadcast over the two trailing axes.
        self.weight = nn.Parameter(torch.empty(output_channels, input_channels, 1, 1))
        # No bias is created by this class; the attribute is kept for reset_parameters.
        self.bias = None

        self.reset_parameters()

    def reset_parameters(self):
        """Set every weight to 0.1, which makes the initial softmax uniform."""
        nn.init.constant_(self.weight, 0.1)
        if self.bias is None:
            return
        # Only reached if a bias has been attached after construction.
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -limit, limit)

    def forward(self, A):
        """Return the softmax-weighted sum of ``A`` over axis 1.

        ``A`` must broadcast against the (out, in, 1, 1) weight tensor; the
        input-channel axis is summed away.
        """
        mixing = F.softmax(self.weight, dim=1)
        return (A * mixing).sum(dim=1)

In [9]:
class GTLayer(nn.Module):
    """One layer that multiplies soft-selected adjacency matrices.

    A first layer (``first=True``) owns two GTConv modules and multiplies their
    outputs together. A later layer owns a single GTConv and right-multiplies the
    matrices handed in from the previous layer by its output.

    Args:
        input_channels (int): Passed through to GTConv.
        output_channels (int): Passed through to GTConv.
        first (bool): Whether this is the first layer of the stack.
    """

    def __init__(self, input_channels, output_channels, first=True):
        super().__init__()
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.first = first

        # Every layer has conv1; only the first layer needs a second selector.
        self.conv1 = GTConv(input_channels, output_channels)
        if self.first == True:
            self.conv2 = GTConv(input_channels, output_channels)

    def forward(self, A, H_normalized=None):
        """Return ``(H, W)``.

        ``H`` is the batched matrix product for this layer and ``W`` is a list of
        the detached softmax weights of each GTConv that was used.
        ``H_normalized`` is required when ``first`` is not True.
        """
        selected = self.conv1(A)
        if self.first == True:
            H = torch.bmm(selected, self.conv2(A))
            used_convs = (self.conv1, self.conv2)
        else:
            H = torch.bmm(H_normalized, selected)
            used_convs = (self.conv1,)

        # Detached copies of the attention weights, returned for inspection only.
        W = []
        for conv in used_convs:
            W.append(F.softmax(conv.weight, dim=1).detach())
        return H, W

In [10]:
class GTN(nn.Module):
    """Graph Transformer Network for node classification.

    A stack of ``num_layers`` GTLayer modules turns the stacked adjacency tensor
    into ``num_channels`` composite matrices. Each composite matrix drives one
    graph convolution over the node features; the per-channel results are
    concatenated and passed through two linear layers to produce class logits.

    Args:
        num_edge (int): Number of channels in the adjacency tensor.
        num_channels (int): Number of composite matrices produced per layer.
        w_in (int): Input node feature dimension.
        w_out (int): Hidden node embedding dimension.
        num_class (int): Number of output classes.
        num_layers (int): Number of GTLayer modules.
        norm: Stored as ``is_norm``; not read by any method of this class.
    """

    def __init__(self, num_edge, num_channels, w_in, w_out, num_class, num_layers, norm):
        super().__init__()
        self.num_edge = num_edge
        self.num_channels = num_channels
        self.w_in = w_in
        self.w_out = w_out
        self.num_class = num_class
        self.num_layers = num_layers
        # Kept for reference only: forward() always normalizes between layers.
        self.is_norm = norm

        # Only the very first layer is built with first=True.
        self.layers = nn.ModuleList(
            [GTLayer(num_edge, num_channels, first=(i == 0)) for i in range(num_layers)]
        )
        self.weight = nn.Parameter(torch.empty(w_in, w_out))
        # Registered and zero-initialised, but not applied in gcn_conv.
        self.bias = nn.Parameter(torch.empty(w_out))
        self.loss = nn.CrossEntropyLoss()
        self.linear1 = nn.Linear(self.w_out*self.num_channels, self.w_out)
        self.linear2 = nn.Linear(self.w_out, self.num_class)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the graph-convolution weight and zero the bias."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def gcn_conv(self, H, X):
        """Project ``X`` with ``self.weight`` and propagate it over ``H``.

        ``H`` is normalized with the identity added to its (zeroed) diagonal, and
        the transposed result is multiplied with the projected features.
        """
        H_normalized = self.norm(H, add=True)
        X = torch.mm(X, self.weight)
        return torch.mm(H_normalized.transpose(0,1), X)

    def normalization(self, H):
        """Apply ``norm`` to each of the first ``num_channels`` matrices of ``H``.

        Returns the normalized matrices stacked along a new leading axis.
        """
        return torch.stack([self.norm(H[i,:,:]) for i in range(self.num_channels)], dim=0)

    def norm(self, H, add=False):
        """Degree-normalize a square matrix after clearing its diagonal.

        The diagonal of ``H`` is zeroed (and replaced by ones when ``add`` is
        true). Each row of the transposed matrix is then divided by its row sum,
        with rows that sum to zero left as zeros, and the result is transposed
        back.

        Args:
            H (Tensor): Square 2-D matrix.
            add (bool): Whether to put ones on the diagonal before normalizing.

        Returns:
            Tensor: The normalized matrix, on the same device as ``H``.
        """
        H = H.transpose(0,1)
        # Build the identity once, on H's own device, instead of depending on a
        # module-level ``device`` global.
        eye = torch.eye(H.shape[0], device=H.device)
        H = H * (1 - eye)
        if add:
            H = H + eye
        deg = torch.sum(H, dim=1)
        deg_inv = deg.pow(-1)
        # Zero-degree rows give 1/0 = inf; treat their inverse degree as 0.
        deg_inv = deg_inv.masked_fill(deg_inv == float('inf'), 0)
        # Scaling row i by deg_inv[i] equals left-multiplying by diag(deg_inv),
        # without materialising and multiplying a dense diagonal matrix.
        H = deg_inv.unsqueeze(1) * H
        return H.t()

    def forward(self, A, X, target_x, target):
        """Run the network and compute the loss on the target nodes.

        Args:
            A (Tensor): 3-D adjacency tensor whose last axis indexes the channels.
            X (Tensor): Node feature matrix.
            target_x (Tensor): Indices of the nodes to classify.
            target (Tensor): Class targets for those nodes.

        Returns:
            tuple: ``(loss, y, Ws)`` -- the cross-entropy loss, the logits for
            ``target_x``, and the per-layer attention weights from each GTLayer.
        """
        # Move the channel axis forward and add a leading singleton axis.
        A = A.unsqueeze(0).permute(0,3,1,2)
        Ws = []
        for i in range(self.num_layers):
            if i == 0:
                H, W = self.layers[i](A)
            else:
                # Later layers consume the normalized output of the previous one.
                H = self.normalization(H)
                H, W = self.layers[i](A, H)
            Ws.append(W)

        # One graph convolution per channel, concatenated along the feature axis.
        X_ = torch.cat(
            [F.relu(self.gcn_conv(H[i], X)) for i in range(self.num_channels)],
            dim=1,
        )

        X_ = F.relu(self.linear1(X_))
        y = self.linear2(X_[target_x])
        loss = self.loss(y, target)
        return loss, y, Ws

## Metrics

In [11]:
def accuracy(pred, target):
    r"""Computes the fraction of predictions that match the targets.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
    :rtype: float
    """
    num_correct = pred.eq(target).sum().item()
    return num_correct / target.numel()



def true_positive(pred, target, num_classes):
    r"""Counts, for each class, the samples predicted as that class that
    also belong to it.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`LongTensor`
    """
    per_class = [
        ((pred == cls) & (target == cls)).sum()
        for cls in range(num_classes)
    ]
    return torch.tensor(per_class)



def true_negative(pred, target, num_classes):
    r"""Counts, for each class, the samples that neither are predicted as
    that class nor belong to it.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`LongTensor`
    """
    per_class = [
        ((pred != cls) & (target != cls)).sum()
        for cls in range(num_classes)
    ]
    return torch.tensor(per_class)



def false_positive(pred, target, num_classes):
    r"""Counts, for each class, the samples predicted as that class that
    actually belong to a different one.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`LongTensor`
    """
    per_class = [
        ((pred == cls) & (target != cls)).sum()
        for cls in range(num_classes)
    ]
    return torch.tensor(per_class)



def false_negative(pred, target, num_classes):
    r"""Counts, for each class, the samples belonging to that class that
    were predicted as a different one.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`LongTensor`
    """
    per_class = [
        ((pred != cls) & (target == cls)).sum()
        for cls in range(num_classes)
    ]
    return torch.tensor(per_class)



def precision(pred, target, num_classes):
    r"""Computes the per-class precision:
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FP}}`.
    Classes with no positive predictions (0/0) get a precision of 0.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    false_alarms = false_positive(pred, target, num_classes).to(torch.float)

    ratio = hits / (hits + false_alarms)
    # 0/0 yields NaN; report those classes as 0.
    return ratio.masked_fill(torch.isnan(ratio), 0)



def recall(pred, target, num_classes):
    r"""Computes the per-class recall:
    :math:`\frac{\mathrm{TP}}{\mathrm{TP}+\mathrm{FN}}`.
    Classes with no target samples (0/0) get a recall of 0.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`Tensor`
    """
    hits = true_positive(pred, target, num_classes).to(torch.float)
    misses = false_negative(pred, target, num_classes).to(torch.float)

    ratio = hits / (hits + misses)
    # 0/0 yields NaN; report those classes as 0.
    return ratio.masked_fill(torch.isnan(ratio), 0)



def f1_score(pred, target, num_classes):
    r"""Computes the per-class :math:`F_1` score:
    :math:`2 \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}
    {\mathrm{precision}+\mathrm{recall}}`.
    Classes whose precision and recall are both 0 get a score of 0.
    Args:
        pred (Tensor): The predictions.
        target (Tensor): The targets.
        num_classes (int): The number of classes.
    :rtype: :class:`Tensor`
    """
    p = precision(pred, target, num_classes)
    r = recall(pred, target, num_classes)

    harmonic = 2 * (p * r) / (p + r)
    # 0/0 yields NaN; report those classes as 0.
    return harmonic.masked_fill(torch.isnan(harmonic), 0)

## Training

In [12]:
# Build the model. num_edge is the size of A's last axis, i.e. the number of
# stacked adjacency channels (edge types plus the identity).
model = GTN(num_edge=A.shape[-1],
            num_channels=num_channels,
            w_in=node_features.shape[1],
            w_out=node_dim,
            num_class=num_classes,
            num_layers=num_layers,
            norm=norm).to(device)

# adaptive_lr is a string flag. With 'false' every parameter shares args.lr;
# otherwise the GT layers get their own group starting at lr=0.5 while the other
# listed groups use args.lr.
# NOTE(review): model.bias is not part of any parameter group in the adaptive branch.
if adaptive_lr == 'false':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
else:
    optimizer = torch.optim.Adam([{'params':model.weight},
                                  {'params':model.linear1.parameters()},
                                  {'params':model.linear2.parameters()},
                                  {"params":model.layers.parameters(), "lr":0.5}
                                  ], lr=args.lr, weight_decay=args.weight_decay)
    
# NOTE(review): criterion is not used below; the model computes its own loss in forward().
criterion = nn.CrossEntropyLoss()

# Metrics recorded at the epoch with the best validation macro-F1 so far.
best_train_loss = float("inf")
best_val_loss = float("inf")
best_test_loss = float("inf")        
best_train_f1 = 0
best_val_f1 = 0
best_test_f1 = 0

for i in range(epochs):
    # Decay any parameter group whose learning rate is still above 0.005 by a
    # factor of 0.9 at the start of every epoch.
    for param_group in optimizer.param_groups:
        if param_group['lr'] > 0.005:
            param_group['lr'] = param_group['lr'] * 0.9
            
    print('Epoch:',i+1)
    
    # One full-graph training step: forward on the training nodes, then
    # backpropagate and update. Macro-F1 is the mean of the per-class F1 scores.
    model.zero_grad()
    model.train()
    loss, y_train, Ws = model(A, node_features, train_node, train_target)
    train_f1 = torch.mean(f1_score(torch.argmax(y_train.detach(),dim=1), train_target, num_classes=num_classes)).cpu().numpy()
    print('Train_Loss: {}, Macro_F1: {}'.format(loss.detach().cpu().numpy(), train_f1))
    loss.backward()
    optimizer.step()
    
    # Evaluate the updated model on the validation and test nodes without
    # tracking gradients. Each call re-runs the full forward pass.
    model.eval()
    with torch.no_grad():
        val_loss, y_valid,_ = model.forward(A, node_features, valid_node, valid_target)
        val_f1 = torch.mean(f1_score(torch.argmax(y_valid.detach(),dim=1), valid_target, num_classes=num_classes)).cpu().numpy()
        print('Valid_Loss: {}, Macro_F1: {}'.format(val_loss.detach().cpu().numpy(), val_f1))
        
        test_loss, y_test, W = model.forward(A, node_features, test_node, test_target)
        test_f1 = torch.mean(f1_score(torch.argmax(y_test.detach(),dim=1), test_target, num_classes=num_classes)).cpu().numpy()
        print('Test_Loss: {}, Macro_F1: {}\n'.format(test_loss.detach().cpu().numpy(), test_f1))
        
    # Model selection is by validation macro-F1; the train and test metrics of
    # that same epoch are stored alongside it.
    if val_f1 > best_val_f1:
        best_train_loss = loss.detach().cpu().numpy()
        best_val_loss = val_loss.detach().cpu().numpy()
        best_test_loss = test_loss.detach().cpu().numpy()

        best_train_f1 = train_f1
        best_val_f1 = val_f1
        best_test_f1 = test_f1 
    
    # Printed after every epoch, not only at the end of training.
    print('---------------Best Results--------------------')
    print('Train_Loss: {}, Macro_F1: {}'.format(best_train_loss, best_train_f1))
    print('Valid_Loss: {}, Macro_F1: {}'.format(best_val_loss, best_val_f1))
    print('Test_Loss: {}, Macro_F1: {}\n'.format(best_test_loss, best_test_f1))

Epoch: 1
Train_Loss: 1.1012227535247803, Macro_F1: 0.2364683598279953
Valid_Loss: 1.0500946044921875, Macro_F1: 0.44932547211647034
Test_Loss: 1.0522735118865967, Macro_F1: 0.42141225934028625

---------------Best Results--------------------
Train_Loss: 1.1012227535247803, Macro_F1: 0.2364683598279953
Valid_Loss: 1.0500946044921875, Macro_F1: 0.44932547211647034
Test_Loss: 1.0522735118865967, Macro_F1: 0.42141225934028625

Epoch: 2
Train_Loss: 1.0467699766159058, Macro_F1: 0.47192850708961487
Valid_Loss: 0.9314001202583313, Macro_F1: 0.4733077585697174
Test_Loss: 0.9356404542922974, Macro_F1: 0.4496873915195465

---------------Best Results--------------------
Train_Loss: 1.0467699766159058, Macro_F1: 0.47192850708961487
Valid_Loss: 0.9314001202583313, Macro_F1: 0.4733077585697174
Test_Loss: 0.9356404542922974, Macro_F1: 0.4496873915195465

Epoch: 3
Train_Loss: 0.918605625629425, Macro_F1: 0.4781924784183502
Valid_Loss: 0.7829990386962891, Macro_F1: 0.8375387787818909
Test_Loss: 0.80195