In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

Using backend: pytorch


# Data

Load the Cora citation dataset shipped with DGL.

In [2]:
import dgl.data

In [3]:
# Load the Cora dataset from DGL; its summary statistics are printed below.
# The graph itself is taken out as dataset[0] in the next cell.
dataset = dgl.data.CoraGraphDataset()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [4]:
# The graph used throughout the notebook; its ndata carries 'feat', 'label'
# and the 'train_mask' / 'val_mask' / 'test_mask' fields read by train().
g = dataset[0]

# Model

A two-layer GraphSAGE network whose layers use a mean neighbour aggregator.

In [5]:
import dgl.function as fn

In [6]:
class SAGEConv(nn.Module):
    """GraphSAGE convolution layer with a mean neighbour aggregator.

    Every node receives the features of the source nodes of its incoming
    edges, averages them, concatenates that mean with its own features and
    projects the result with a single linear layer.

    Args:
        in_feat: size of the input node features.
        out_feat: size of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # The linear layer sees [own features || neighbour mean], hence 2 * in_feat.
        self.linear = nn.Linear(in_features=2 * in_feat, out_features=out_feat)

    def forward(self, g, h):
        """Apply the layer to node features ``h`` over graph ``g``.

        Args:
            g: DGL graph the messages are passed along.
            h: node feature tensor; assumed (num_nodes, in_feat) since it is
                concatenated with the neighbour mean along dim 1 — TODO confirm.

        Returns:
            The linear projection of ``[h || neighbour mean]`` (last dim ``out_feat``).
        """
        # local_scope discards the temporary 'h' / 'h_N' fields when the block exits,
        # so the caller's g.ndata is left untouched.
        with g.local_scope():
            g.ndata['h'] = h
            gather = fn.copy_u('h', 'm')     # message: copy source-node feature
            average = fn.mean('m', 'h_N')    # reduce: mean of incoming messages
            g.update_all(message_func=gather, reduce_func=average)
            neighbour_mean = g.ndata['h_N']
            combined = torch.cat([h, neighbour_mean], dim=1)
            return self.linear(combined)

In [7]:
class Model(nn.Module):
    """Two-layer GraphSAGE node classifier.

    Args:
        in_feat: size of the input node features.
        h_feat: width of the hidden layer.
        num_class: number of classes, i.e. the size of the returned logits.
    """

    def __init__(self, in_feat, h_feat, num_class):
        super().__init__()
        self.gc1 = SAGEConv(in_feat, h_feat)
        self.gc2 = SAGEConv(h_feat, num_class)

    def forward(self, g, in_feat):
        """Return unnormalised per-node class logits (ReLU between the two layers)."""
        hidden = F.relu(self.gc1(g, in_feat))
        return self.gc2(g, hidden)

In [8]:
# Seed the global torch RNG so the weight initialisation (and therefore the
# whole training run) is reproducible.
torch.manual_seed(0)
HIDDEN_FEATS = 16  # width of the hidden GraphSAGE layer
model = Model(
    in_feat=g.ndata['feat'].shape[1],
    h_feat=HIDDEN_FEATS,
    num_class=dataset.num_classes,
)

In [9]:
# Adam over all trainable parameters. nn.Module.to() modifies the module in
# place, so this optimizer still tracks the parameters after the device move.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [10]:
# Multi-class cross-entropy on raw (unnormalised) logits.
loss_func = nn.CrossEntropyLoss()

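# Train

Full-graph training. The reported "best" test accuracy is the test accuracy at the epoch with the highest validation accuracy, not the maximum test accuracy.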

In [11]:
# Use the GPU when one is available; otherwise fall back to CPU instead of
# crashing on a hard-coded .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [12]:
def train(g, model, optimizer=None, loss_func=None, num_epochs=500, log_every=5):
    """Full-graph training loop with best-validation model selection.

    Each epoch runs one forward pass over the whole graph and computes the loss
    on the training nodes only. It reports validation/test accuracy from that
    same forward pass and then takes one optimisation step. The "best" test
    accuracy is the test accuracy at the epoch with the highest validation
    accuracy.

    Args:
        g: graph whose ``ndata`` holds 'feat', 'label', 'train_mask',
            'val_mask' and 'test_mask'; must be on the same device as ``model``.
        model: module called as ``model(g, features)`` returning per-node logits.
        optimizer: optimizer over ``model``'s parameters. Defaults to a fresh
            ``torch.optim.Adam(model.parameters(), lr=0.01)``, so the loop no
            longer depends on notebook globals.
        loss_func: loss applied to (logits, labels). Defaults to
            ``nn.CrossEntropyLoss()``.
        num_epochs: number of optimisation steps.
        log_every: print a progress line every ``log_every`` epochs; 0 or
            ``None`` disables logging.

    Returns:
        ``(best_val_acc, best_test_acc)`` as Python floats.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    if loss_func is None:
        loss_func = nn.CrossEntropyLoss()

    best_val_acc = 0.0
    best_test_acc = 0.0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    def accuracy(pred, mask):
        # Fraction of masked nodes whose predicted class matches the label.
        return (pred[mask] == labels[mask]).float().mean().item()

    model.train()
    for e in range(num_epochs):
        logits = model(g, features)
        pred = logits.argmax(1)

        # Only the training nodes contribute to the gradient.
        loss = loss_func(logits[train_mask], labels[train_mask])

        val_acc = accuracy(pred, val_mask)
        test_acc = accuracy(pred, test_mask)

        # Model selection on validation accuracy; test accuracy is just recorded.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and e % log_every == 0:
            print('epoch:%d,loss:%f,val_acc:%f -best_val:%f,test_acc:%f - best_test:%f,' % (
                e, loss.item(), val_acc, best_val_acc, test_acc, best_test_acc))

    return best_val_acc, best_test_acc

In [13]:
# Move the graph to whichever device the model's parameters live on, so the
# two can never disagree (works on both GPU and CPU-only machines).
train(g.to(next(model.parameters()).device), model)

epoch:0,loss:1.948182,val_acc:0.162000 -best_val:0.162000,test_acc:0.149000 - best_test:0.149000,
epoch:5,loss:1.873497,val_acc:0.304000 -best_val:0.304000,test_acc:0.308000 - best_test:0.308000,
epoch:10,loss:1.721230,val_acc:0.466000 -best_val:0.494000,test_acc:0.462000 - best_test:0.483000,
epoch:15,loss:1.486352,val_acc:0.510000 -best_val:0.510000,test_acc:0.499000 - best_test:0.499000,
epoch:20,loss:1.182170,val_acc:0.592000 -best_val:0.592000,test_acc:0.566000 - best_test:0.566000,
epoch:25,loss:0.848616,val_acc:0.646000 -best_val:0.646000,test_acc:0.645000 - best_test:0.645000,
epoch:30,loss:0.544626,val_acc:0.716000 -best_val:0.716000,test_acc:0.688000 - best_test:0.688000,
epoch:35,loss:0.317590,val_acc:0.740000 -best_val:0.740000,test_acc:0.729000 - best_test:0.729000,
epoch:40,loss:0.176517,val_acc:0.746000 -best_val:0.746000,test_acc:0.752000 - best_test:0.752000,
epoch:45,loss:0.098740,val_acc:0.754000 -best_val:0.754000,test_acc:0.763000 - best_test:0.763000,
epoch:50,los

epoch:475,loss:0.000528,val_acc:0.736000 -best_val:0.758000,test_acc:0.759000 - best_test:0.760000,
epoch:480,loss:0.000518,val_acc:0.736000 -best_val:0.758000,test_acc:0.759000 - best_test:0.760000,
epoch:485,loss:0.000509,val_acc:0.736000 -best_val:0.758000,test_acc:0.759000 - best_test:0.760000,
epoch:490,loss:0.000500,val_acc:0.736000 -best_val:0.758000,test_acc:0.759000 - best_test:0.760000,
epoch:495,loss:0.000491,val_acc:0.736000 -best_val:0.758000,test_acc:0.759000 - best_test:0.760000,
