In [1]:
#Credit: Benjamin Chang, University of Toronto, https://github.com/lolzballs
import torch
import torch.autograd
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import numpy as np

In [2]:
class GCN(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super(GCN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ',' \
               + str(self.out_features) + ')'
    
    def reset_parameters(self):
        stdv = 1. / self.weight.size(1) ** 1/2
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
    
    # H, feature matrix
    # A, precomputed adj matrix
    def forward(self, H, A):
        n = torch.mm(A, torch.mm(H, self.weight))
        if self.bias is not None:
            return n + self.bias
        else:
            return n


In [3]:
# n-layer GCN Network
class Net(nn.Module):
    def __init__(self, in_features, body_features, out_features, n_layers, activation, bias=False):
        super(Net, self).__init__()
        assert(n_layers >= 2)
        self.activation = activation
        
        self.head = GCN(in_features, body_features, bias)
        self.layers = nn.ModuleList()
        for i in range(n_layers - 2):
            self.layers.append(GCN(body_features, body_features, bias))
        self.tail = GCN(body_features, out_features, bias)

    def forward(self, x, A):
        x = self.activation(self.head(x, A))
        for layer in self.layers:
            x = self.activation(layer(x, A))
        x = self.tail(x, A)
        return x

In [4]:
def create_A(data):
    adj = torch.eye(data.num_nodes, data.num_nodes)
    adj[data.edge_index[0,:], data.edge_index[1,:]] += 1
    deg = adj.sum(dim=1) ** (-1/2)
    D = torch.diag(deg)
    return D.mm(adj).mm(D)

# def create_A(data, mask=slice(None)):
#     adj = torch.eye(data.num_nodes, data.num_nodes)
#     adj[data.edge_index[0,:], data.edge_index[1,:]] += 1

#     masked = adj[mask,:][:,mask]
#     deg = masked.sum(dim=1) ** (-0.5)
#     D = torch.diag(deg)
#     return D.mm(masked).mm(D)

In [5]:
def masked_accuracy(pred, labels, mask):
    return (pred.argmax(dim=1) == labels)[mask].sum().item() / mask.sum().item()

In [6]:
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [7]:
print(data)

Data(edge_index=[2, 10556], test_mask=[2708], train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])


In [8]:
A = create_A(data)

train_labels = torch.where(data.train_mask, data.y, torch.tensor(-100))
val_labels = torch.where(data.val_mask, data.y, torch.tensor(-100))
test_labels = torch.where(data.test_mask, data.y, torch.tensor(-100))

print('training samples: ', data.train_mask.sum().item())
print('validation samples: ', data.val_mask.sum().item())
print('test samples: ', data.test_mask.sum().item())

training samples:  140
validation samples:  500
test samples:  1000


In [15]:
print(train_labels.shape)
print(data.train_mask.shape)
print(data.x.shape)

torch.Size([2708])
torch.Size([2708])
torch.Size([2708, 1433])


In [122]:
parameters = {
    'features': dataset.num_features,
    'body': 64,
    'classes': dataset.num_classes,
    'num_epochs': 2000,
    'learning_rate': 1e-2,
    'weight_decay': 5e-3
}
print(parameters)

{'features': 1433, 'body': 64, 'classes': 7, 'num_epochs': 2000, 'learning_rate': 0.01, 'weight_decay': 0.005}


In [123]:
torch.manual_seed(0)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

num_epochs = parameters['num_epochs']
A = A.to(device)
model = Net(parameters['features'], parameters['body'], parameters['classes'], 4, F.relu, bias=True).to(device)
input_features = data.x.to(device)
train_labels = train_labels.to(device)
val_labels = val_labels.to(device)
test_labels = test_labels.to(device)
train_mask = data.train_mask.to(device)
val_mask = data.val_mask.to(device)
test_mask = data.test_mask.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=parameters['learning_rate'], weight_decay=parameters['weight_decay'])
criterion = nn.CrossEntropyLoss()


for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    train_pred = model(input_features, A)
    train_loss = criterion(train_pred, train_labels)
    train_loss.backward()
    optimizer.step()
    train_acc = masked_accuracy(train_pred, train_labels, train_mask)
    
    model.eval()
    val_pred = model(input_features, A)
    val_loss = criterion(val_pred, val_labels)
    val_acc = masked_accuracy(val_pred, val_labels, val_mask)
    print("{}: \ttrain loss {}\tacc {:2f}\tval loss {}\tacc {:2f}".format(epoch, train_loss, train_acc, val_loss, val_acc))

1: 	train loss 1.9471932649612427	acc 0.142857	val loss 1.9459331035614014	acc 0.072000
2: 	train loss 1.9465835094451904	acc 0.142857	val loss 1.9463955163955688	acc 0.072000
3: 	train loss 1.945975661277771	acc 0.142857	val loss 1.9473267793655396	acc 0.114000
4: 	train loss 1.9454419612884521	acc 0.135714	val loss 1.9484270811080933	acc 0.122000
5: 	train loss 1.9449105262756348	acc 0.135714	val loss 1.9494330883026123	acc 0.154000
6: 	train loss 1.9443680047988892	acc 0.164286	val loss 1.950026512145996	acc 0.162000
7: 	train loss 1.9438023567199707	acc 0.142857	val loss 1.9501606225967407	acc 0.162000
8: 	train loss 1.9432144165039062	acc 0.142857	val loss 1.9497679471969604	acc 0.122000
9: 	train loss 1.942678451538086	acc 0.142857	val loss 1.9489516019821167	acc 0.122000
10: 	train loss 1.9422454833984375	acc 0.142857	val loss 1.948244571685791	acc 0.122000
11: 	train loss 1.9418559074401855	acc 0.142857	val loss 1.9478813409805298	acc 0.122000
12: 	train loss 1.9414736032485962

108: 	train loss 1.922305941581726	acc 0.171429	val loss 1.9389420747756958	acc 0.162000
109: 	train loss 1.922268271446228	acc 0.150000	val loss 1.9396083354949951	acc 0.166000
110: 	train loss 1.9222261905670166	acc 0.171429	val loss 1.93868088722229	acc 0.162000
111: 	train loss 1.9221888780593872	acc 0.150000	val loss 1.9396535158157349	acc 0.166000
112: 	train loss 1.9221512079238892	acc 0.178571	val loss 1.9389125108718872	acc 0.164000
113: 	train loss 1.9221110343933105	acc 0.157143	val loss 1.9390881061553955	acc 0.168000
114: 	train loss 1.9220715761184692	acc 0.171429	val loss 1.9394906759262085	acc 0.166000
115: 	train loss 1.922031283378601	acc 0.171429	val loss 1.9386820793151855	acc 0.166000
116: 	train loss 1.9220038652420044	acc 0.142857	val loss 1.939483404159546	acc 0.164000
117: 	train loss 1.9219681024551392	acc 0.171429	val loss 1.9390493631362915	acc 0.162000
118: 	train loss 1.9219368696212769	acc 0.164286	val loss 1.9388737678527832	acc 0.170000
119: 	train loss

212: 	train loss 1.9210989475250244	acc 0.171429	val loss 1.939563274383545	acc 0.162000
213: 	train loss 1.9209657907485962	acc 0.157143	val loss 1.9392752647399902	acc 0.166000
214: 	train loss 1.9209741353988647	acc 0.178571	val loss 1.9375724792480469	acc 0.164000
215: 	train loss 1.9210702180862427	acc 0.171429	val loss 1.940535306930542	acc 0.166000
216: 	train loss 1.9209344387054443	acc 0.171429	val loss 1.9374321699142456	acc 0.180000
217: 	train loss 1.9210612773895264	acc 0.164286	val loss 1.939461350440979	acc 0.160000
218: 	train loss 1.9209721088409424	acc 0.157143	val loss 1.939320683479309	acc 0.162000
219: 	train loss 1.9209684133529663	acc 0.178571	val loss 1.9375821352005005	acc 0.178000
220: 	train loss 1.9210479259490967	acc 0.164286	val loss 1.9402620792388916	acc 0.166000
221: 	train loss 1.9209429025650024	acc 0.178571	val loss 1.9380437135696411	acc 0.164000
222: 	train loss 1.9210237264633179	acc 0.178571	val loss 1.9387348890304565	acc 0.158000
223: 	train lo

317: 	train loss 1.920892596244812	acc 0.150000	val loss 1.9373570680618286	acc 0.164000
318: 	train loss 1.9210447072982788	acc 0.178571	val loss 1.9399652481079102	acc 0.160000
319: 	train loss 1.9209058284759521	acc 0.142857	val loss 1.9384664297103882	acc 0.162000
320: 	train loss 1.9209620952606201	acc 0.157143	val loss 1.9384260177612305	acc 0.160000
321: 	train loss 1.9209766387939453	acc 0.178571	val loss 1.9398232698440552	acc 0.160000
322: 	train loss 1.920906662940979	acc 0.150000	val loss 1.937751054763794	acc 0.162000
323: 	train loss 1.9210189580917358	acc 0.178571	val loss 1.9396271705627441	acc 0.162000
324: 	train loss 1.9209132194519043	acc 0.150000	val loss 1.9386521577835083	acc 0.164000
325: 	train loss 1.9209610223770142	acc 0.157143	val loss 1.938441514968872	acc 0.162000
326: 	train loss 1.9209694862365723	acc 0.164286	val loss 1.9396058320999146	acc 0.162000
327: 	train loss 1.9209147691726685	acc 0.150000	val loss 1.9380675554275513	acc 0.160000
328: 	train lo

422: 	train loss 1.9209413528442383	acc 0.164286	val loss 1.9387538433074951	acc 0.164000
423: 	train loss 1.9209413528442383	acc 0.157143	val loss 1.9388843774795532	acc 0.162000
424: 	train loss 1.9209415912628174	acc 0.150000	val loss 1.9388327598571777	acc 0.164000
425: 	train loss 1.9209409952163696	acc 0.164286	val loss 1.938792109489441	acc 0.164000
426: 	train loss 1.9209418296813965	acc 0.150000	val loss 1.9389045238494873	acc 0.164000
427: 	train loss 1.9209405183792114	acc 0.164286	val loss 1.938762903213501	acc 0.164000
428: 	train loss 1.9209423065185547	acc 0.157143	val loss 1.9388878345489502	acc 0.166000
429: 	train loss 1.9209405183792114	acc 0.150000	val loss 1.9388115406036377	acc 0.164000
430: 	train loss 1.9209423065185547	acc 0.157143	val loss 1.9388306140899658	acc 0.162000
431: 	train loss 1.9209400415420532	acc 0.150000	val loss 1.9388535022735596	acc 0.164000
432: 	train loss 1.9209433794021606	acc 0.164286	val loss 1.9388208389282227	acc 0.166000
433: 	train 

527: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.9388771057128906	acc 0.162000
528: 	train loss 1.9209392070770264	acc 0.150000	val loss 1.9388078451156616	acc 0.164000
529: 	train loss 1.9209420680999756	acc 0.157143	val loss 1.9388312101364136	acc 0.164000
530: 	train loss 1.9209407567977905	acc 0.157143	val loss 1.9388638734817505	acc 0.164000
531: 	train loss 1.9209398031234741	acc 0.150000	val loss 1.9387974739074707	acc 0.164000
532: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.938858151435852	acc 0.164000
533: 	train loss 1.9209392070770264	acc 0.150000	val loss 1.9388357400894165	acc 0.164000
534: 	train loss 1.9209420680999756	acc 0.157143	val loss 1.9388116598129272	acc 0.164000
535: 	train loss 1.9209409952163696	acc 0.157143	val loss 1.9388666152954102	acc 0.164000
536: 	train loss 1.9209400415420532	acc 0.157143	val loss 1.9388080835342407	acc 0.164000
537: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.9388476610183716	acc 0.164000
538: 	train

632: 	train loss 1.9209444522857666	acc 0.157143	val loss 1.938900113105774	acc 0.162000
633: 	train loss 1.9209381341934204	acc 0.150000	val loss 1.9387753009796143	acc 0.164000
634: 	train loss 1.9209457635879517	acc 0.164286	val loss 1.9388881921768188	acc 0.162000
635: 	train loss 1.920937180519104	acc 0.150000	val loss 1.9387835264205933	acc 0.164000
636: 	train loss 1.9209468364715576	acc 0.157143	val loss 1.9388959407806396	acc 0.164000
637: 	train loss 1.9209357500076294	acc 0.157143	val loss 1.938751220703125	acc 0.164000
638: 	train loss 1.9209487438201904	acc 0.150000	val loss 1.9389538764953613	acc 0.164000
639: 	train loss 1.9209328889846802	acc 0.164286	val loss 1.9386727809906006	acc 0.162000
640: 	train loss 1.9209520816802979	acc 0.150000	val loss 1.9390454292297363	acc 0.164000
641: 	train loss 1.9209297895431519	acc 0.164286	val loss 1.9385775327682495	acc 0.162000
642: 	train loss 1.920957326889038	acc 0.150000	val loss 1.9391409158706665	acc 0.164000
643: 	train lo

737: 	train loss 1.9209418296813965	acc 0.157143	val loss 1.938840627670288	acc 0.166000
738: 	train loss 1.9209402799606323	acc 0.150000	val loss 1.9388233423233032	acc 0.164000
739: 	train loss 1.9209423065185547	acc 0.157143	val loss 1.9388508796691895	acc 0.166000
740: 	train loss 1.9209398031234741	acc 0.150000	val loss 1.9388158321380615	acc 0.164000
741: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.9388554096221924	acc 0.166000
742: 	train loss 1.9209396839141846	acc 0.150000	val loss 1.9388127326965332	acc 0.164000
743: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.9388575553894043	acc 0.164000
744: 	train loss 1.9209394454956055	acc 0.150000	val loss 1.938811182975769	acc 0.164000
745: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.938859462738037	acc 0.164000
746: 	train loss 1.9209394454956055	acc 0.157143	val loss 1.938808798789978	acc 0.164000
747: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.9388625621795654	acc 0.164000
748: 	train lo

844: 	train loss 1.9209433794021606	acc 0.157143	val loss 1.9388806819915771	acc 0.164000
845: 	train loss 1.920937418937683	acc 0.164286	val loss 1.938788652420044	acc 0.166000
846: 	train loss 1.9209446907043457	acc 0.157143	val loss 1.9388762712478638	acc 0.164000
847: 	train loss 1.9209376573562622	acc 0.164286	val loss 1.9387989044189453	acc 0.164000
848: 	train loss 1.9209436178207397	acc 0.157143	val loss 1.9388628005981445	acc 0.164000
849: 	train loss 1.9209392070770264	acc 0.157143	val loss 1.9388154745101929	acc 0.164000
850: 	train loss 1.9209411144256592	acc 0.157143	val loss 1.9388434886932373	acc 0.166000
851: 	train loss 1.9209415912628174	acc 0.150000	val loss 1.9388368129730225	acc 0.164000
852: 	train loss 1.9209394454956055	acc 0.157143	val loss 1.938822865486145	acc 0.164000
853: 	train loss 1.9209426641464233	acc 0.150000	val loss 1.9388527870178223	acc 0.164000
854: 	train loss 1.9209389686584473	acc 0.164286	val loss 1.938815712928772	acc 0.164000
855: 	train lo

948: 	train loss 1.9208978414535522	acc 0.142857	val loss 1.9375368356704712	acc 0.164000
949: 	train loss 1.9210249185562134	acc 0.178571	val loss 1.93954598903656	acc 0.160000
950: 	train loss 1.920926570892334	acc 0.150000	val loss 1.9388877153396606	acc 0.164000
951: 	train loss 1.9209486246109009	acc 0.164286	val loss 1.9381427764892578	acc 0.162000
952: 	train loss 1.9209891557693481	acc 0.164286	val loss 1.9399651288986206	acc 0.158000
953: 	train loss 1.9209128618240356	acc 0.150000	val loss 1.9376481771469116	acc 0.164000
954: 	train loss 1.9210199117660522	acc 0.178571	val loss 1.9398095607757568	acc 0.160000
955: 	train loss 1.9209152460098267	acc 0.150000	val loss 1.9383306503295898	acc 0.160000
956: 	train loss 1.9209810495376587	acc 0.164286	val loss 1.9388214349746704	acc 0.164000
957: 	train loss 1.9209502935409546	acc 0.164286	val loss 1.939300298690796	acc 0.160000
958: 	train loss 1.9209333658218384	acc 0.157143	val loss 1.9381085634231567	acc 0.160000
959: 	train lo

1055: 	train loss 1.9209413528442383	acc 0.178571	val loss 1.9388056993484497	acc 0.160000
1056: 	train loss 1.9209474325180054	acc 0.171429	val loss 1.9387760162353516	acc 0.160000
1057: 	train loss 1.9209418296813965	acc 0.164286	val loss 1.9389350414276123	acc 0.158000
1058: 	train loss 1.9209470748901367	acc 0.164286	val loss 1.9387476444244385	acc 0.162000
1059: 	train loss 1.9209411144256592	acc 0.164286	val loss 1.9388610124588013	acc 0.162000
1060: 	train loss 1.9209450483322144	acc 0.157143	val loss 1.93888258934021	acc 0.158000
1061: 	train loss 1.9209415912628174	acc 0.150000	val loss 1.9387258291244507	acc 0.160000
1062: 	train loss 1.9209448099136353	acc 0.164286	val loss 1.9389628171920776	acc 0.158000
1063: 	train loss 1.9209424257278442	acc 0.164286	val loss 1.9387319087982178	acc 0.160000
1064: 	train loss 1.9209439754486084	acc 0.164286	val loss 1.938878059387207	acc 0.162000
1065: 	train loss 1.9209415912628174	acc 0.150000	val loss 1.9388560056686401	acc 0.164000
10

1162: 	train loss 1.920937180519104	acc 0.157143	val loss 1.9387725591659546	acc 0.164000
1163: 	train loss 1.9209442138671875	acc 0.157143	val loss 1.938881516456604	acc 0.164000
1164: 	train loss 1.9209383726119995	acc 0.150000	val loss 1.9388030767440796	acc 0.164000
1165: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.9388513565063477	acc 0.164000
1166: 	train loss 1.9209398031234741	acc 0.157143	val loss 1.938829779624939	acc 0.164000
1167: 	train loss 1.9209407567977905	acc 0.157143	val loss 1.9388278722763062	acc 0.164000
1168: 	train loss 1.9209415912628174	acc 0.157143	val loss 1.9388543367385864	acc 0.164000
1169: 	train loss 1.9209396839141846	acc 0.150000	val loss 1.9388024806976318	acc 0.164000
1170: 	train loss 1.9209429025650024	acc 0.157143	val loss 1.9388760328292847	acc 0.166000
1171: 	train loss 1.920938491821289	acc 0.150000	val loss 1.9387925863265991	acc 0.164000
1172: 	train loss 1.9209429025650024	acc 0.157143	val loss 1.938871145248413	acc 0.164000
1173

1267: 	train loss 1.9209492206573486	acc 0.171429	val loss 1.9380812644958496	acc 0.164000
1268: 	train loss 1.920958161354065	acc 0.157143	val loss 1.9395020008087158	acc 0.158000
1269: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.938616156578064	acc 0.164000
1270: 	train loss 1.920953631401062	acc 0.157143	val loss 1.9385416507720947	acc 0.164000
1271: 	train loss 1.9209418296813965	acc 0.150000	val loss 1.9394012689590454	acc 0.158000
1272: 	train loss 1.9209489822387695	acc 0.171429	val loss 1.9383716583251953	acc 0.178000
1273: 	train loss 1.9209479093551636	acc 0.150000	val loss 1.9389421939849854	acc 0.166000
1274: 	train loss 1.9209383726119995	acc 0.157143	val loss 1.9391025304794312	acc 0.160000
1275: 	train loss 1.9209487438201904	acc 0.171429	val loss 1.9384411573410034	acc 0.178000
1276: 	train loss 1.9209398031234741	acc 0.150000	val loss 1.939096212387085	acc 0.160000
1277: 	train loss 1.9209455251693726	acc 0.171429	val loss 1.9388703107833862	acc 0.162000
127

1372: 	train loss 1.9209396839141846	acc 0.157143	val loss 1.9388090372085571	acc 0.164000
1373: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.9388600587844849	acc 0.164000
1374: 	train loss 1.9209396839141846	acc 0.157143	val loss 1.9388110637664795	acc 0.164000
1375: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.9388554096221924	acc 0.164000
1376: 	train loss 1.9209398031234741	acc 0.157143	val loss 1.9388169050216675	acc 0.164000
1377: 	train loss 1.9209424257278442	acc 0.157143	val loss 1.938849687576294	acc 0.164000
1378: 	train loss 1.9209405183792114	acc 0.157143	val loss 1.9388221502304077	acc 0.164000
1379: 	train loss 1.9209418296813965	acc 0.157143	val loss 1.938844084739685	acc 0.164000
1380: 	train loss 1.9209407567977905	acc 0.157143	val loss 1.9388303756713867	acc 0.164000
1381: 	train loss 1.9209415912628174	acc 0.157143	val loss 1.9388316869735718	acc 0.164000
1382: 	train loss 1.9209418296813965	acc 0.157143	val loss 1.9388459920883179	acc 0.164000
1

1478: 	train loss 1.9209381341934204	acc 0.150000	val loss 1.9387308359146118	acc 0.164000
1479: 	train loss 1.920951008796692	acc 0.171429	val loss 1.939051628112793	acc 0.162000
1480: 	train loss 1.920928716659546	acc 0.157143	val loss 1.9385762214660645	acc 0.162000
1481: 	train loss 1.9209550619125366	acc 0.157143	val loss 1.939090609550476	acc 0.162000
1482: 	train loss 1.920929193496704	acc 0.157143	val loss 1.9386446475982666	acc 0.168000
1483: 	train loss 1.9209481477737427	acc 0.150000	val loss 1.9389375448226929	acc 0.162000
1484: 	train loss 1.9209378957748413	acc 0.164286	val loss 1.938834309577942	acc 0.162000
1485: 	train loss 1.9209381341934204	acc 0.150000	val loss 1.9387530088424683	acc 0.166000
1486: 	train loss 1.9209465980529785	acc 0.157143	val loss 1.938971996307373	acc 0.162000
1487: 	train loss 1.9209318161010742	acc 0.150000	val loss 1.9386799335479736	acc 0.166000
1488: 	train loss 1.9209500551223755	acc 0.157143	val loss 1.9389744997024536	acc 0.162000
1489: 

1583: 	train loss 1.9209437370300293	acc 0.150000	val loss 1.938999891281128	acc 0.164000
1584: 	train loss 1.920938491821289	acc 0.164286	val loss 1.9386281967163086	acc 0.162000
1585: 	train loss 1.9209452867507935	acc 0.150000	val loss 1.939094066619873	acc 0.160000
1586: 	train loss 1.920937418937683	acc 0.178571	val loss 1.938509225845337	acc 0.158000
1587: 	train loss 1.9209486246109009	acc 0.150000	val loss 1.9392447471618652	acc 0.160000
1588: 	train loss 1.9209357500076294	acc 0.171429	val loss 1.9383255243301392	acc 0.160000
1589: 	train loss 1.920953631401062	acc 0.150000	val loss 1.9394662380218506	acc 0.162000
1590: 	train loss 1.920935034751892	acc 0.185714	val loss 1.9380748271942139	acc 0.158000
1591: 	train loss 1.9209626913070679	acc 0.157143	val loss 1.939732313156128	acc 0.160000
1592: 	train loss 1.920935034751892	acc 0.185714	val loss 1.937840461730957	acc 0.160000
1593: 	train loss 1.9209762811660767	acc 0.164286	val loss 1.9398771524429321	acc 0.162000
1594: 	tr

1689: 	train loss 1.9209431409835815	acc 0.157143	val loss 1.938866376876831	acc 0.164000
1690: 	train loss 1.9209396839141846	acc 0.157143	val loss 1.9388062953948975	acc 0.164000
1691: 	train loss 1.9209426641464233	acc 0.157143	val loss 1.9388577938079834	acc 0.164000
1692: 	train loss 1.9209402799606323	acc 0.157143	val loss 1.938818335533142	acc 0.164000
1693: 	train loss 1.9209420680999756	acc 0.157143	val loss 1.9388447999954224	acc 0.164000
1694: 	train loss 1.9209409952163696	acc 0.157143	val loss 1.9388290643692017	acc 0.164000
1695: 	train loss 1.9209413528442383	acc 0.157143	val loss 1.9388371706008911	acc 0.164000
1696: 	train loss 1.9209409952163696	acc 0.157143	val loss 1.9388344287872314	acc 0.164000
1697: 	train loss 1.9209413528442383	acc 0.157143	val loss 1.9388333559036255	acc 0.164000
1698: 	train loss 1.9209411144256592	acc 0.157143	val loss 1.9388384819030762	acc 0.164000
1699: 	train loss 1.9209409952163696	acc 0.157143	val loss 1.9388294219970703	acc 0.164000
1

1796: 	train loss 0.7850291728973389	acc 0.678571	val loss 1.7498022317886353	acc 0.428000
1797: 	train loss 0.8363560438156128	acc 0.657143	val loss 1.5522099733352661	acc 0.442000
1798: 	train loss 0.7002217173576355	acc 0.707143	val loss 1.4805734157562256	acc 0.420000
1799: 	train loss 0.6583678722381592	acc 0.707143	val loss 1.620752215385437	acc 0.384000
1800: 	train loss 0.6846782565116882	acc 0.642857	val loss 1.3400845527648926	acc 0.420000
1801: 	train loss 0.5784755349159241	acc 0.721429	val loss 1.2768789529800415	acc 0.478000
1802: 	train loss 0.5614075064659119	acc 0.728571	val loss 1.2571698427200317	acc 0.638000
1803: 	train loss 0.5250402688980103	acc 0.892857	val loss 1.26291024684906	acc 0.638000
1804: 	train loss 0.460785448551178	acc 0.864286	val loss 1.1709113121032715	acc 0.648000
1805: 	train loss 0.38128504157066345	acc 0.857143	val loss 1.1454107761383057	acc 0.628000
1806: 	train loss 0.31692254543304443	acc 0.842857	val loss 1.159679889678955	acc 0.626000
18

1901: 	train loss 0.027763497084379196	acc 1.000000	val loss 0.8939884305000305	acc 0.784000
1902: 	train loss 0.027807816863059998	acc 1.000000	val loss 0.9024693965911865	acc 0.784000
1903: 	train loss 0.02783001773059368	acc 1.000000	val loss 0.9041697382926941	acc 0.784000
1904: 	train loss 0.027899429202079773	acc 1.000000	val loss 0.894679605960846	acc 0.788000
1905: 	train loss 0.027760833501815796	acc 1.000000	val loss 0.8878391981124878	acc 0.788000
1906: 	train loss 0.027755821123719215	acc 1.000000	val loss 0.8916568756103516	acc 0.788000
1907: 	train loss 0.027640972286462784	acc 1.000000	val loss 0.8963243365287781	acc 0.786000
1908: 	train loss 0.02765490859746933	acc 1.000000	val loss 0.8907273411750793	acc 0.786000
1909: 	train loss 0.027465688064694405	acc 1.000000	val loss 0.8855894804000854	acc 0.786000
1910: 	train loss 0.027398880571126938	acc 1.000000	val loss 0.8902792930603027	acc 0.786000
1911: 	train loss 0.027265677228569984	acc 1.000000	val loss 0.8941858410

In [124]:
model.eval()
test_pred = model(input_features, A)
test_acc = masked_accuracy(test_pred, test_labels, test_mask)
print(test_acc)

0.789


In [44]:
(A==0).sum().item() / torch.numel(A)

0.9981912556264169

In [19]:
test_labels.shape

torch.Size([2708])

In [14]:
labels = (data.y[data.train_mask])

In [15]:
data.train_mask.sum().item()

140

In [15]:
torch.cuda.get_device_name()

'Device 67df'

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.2236, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0')
