<a href="https://colab.research.google.com/github/joshchang1112/bert_gnn_arxiv/blob/master/pytorch/fine_tuned_bert_gnn_ogbn_arxiv_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-class Classification with Fine-tuned BERT & GNN (PyTorch)

## Overview

BERT is one of the most powerful neural network models in NLP, and graph neural networks (GNNs) are among the most popular models for graph-structured data. We therefore want to bring the strengths of BERT to a citation network and measure how much BERT can improve GNN models.

In this Colab notebook, we explore how to fine-tune BERT, how to encode node features with the fine-tuned model, and how to train GNN models on the citation graph and those node features.

The steps are as follows:

1.  Download `ogbn-arxiv` dataset from [Open Graph Benchmark](https://ogb.stanford.edu/).
2.  Fine-tune BERT on the arXiv dataset and select the checkpoint with the best validation accuracy as the encoder.
3.  Encode the node features using fine-tuned BERT.
4.  Train and evaluate the `GNN` model.

## Setup

Install the **Open Graph Benchmark**, **pytorch_geometric**, and **transformers (Hugging Face)** packages.

In [1]:
# Install package
!pip install ogb
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install torch-geometric
!pip install transformers


## Dependencies and imports

In [2]:
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
import transformers
import numpy as np
import pandas as pd
import csv
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset
from ogb.nodeproppred import Evaluator
from tqdm.notebook import tqdm

print("Pytorch Version: ",  torch.__version__)
if torch.cuda.is_available():
  print("GPU {} is available!".format(torch.cuda.current_device()))
else:
  print("Only CPU is available!")

Pytorch Version:  1.7.0+cu101
GPU 0 is available!


## Ogbn-arxiv dataset
The `ogbn-arxiv` dataset is a directed graph representing the citation network between all Computer Science (CS) arXiv papers indexed by the Microsoft Academic Graph (MAG). Each node is an arXiv paper, and each directed edge indicates that one paper cites another. The dataset also provides a mapping from MAG paper IDs to the raw texts of titles and abstracts.

The task involved is document classification where the goal is to categorize each paper into one of 40 subject areas of arXiv CS papers. In other words, this is a multi-class classification problem with 40 classes.

### Features
Each paper comes with a 128-dimensional feature vector obtained by averaging the embeddings of words in its title and abstract. The embeddings of individual words are computed by running the skip-gram model over the MAG corpus. 
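
To make the averaging step concrete, here is a minimal sketch with toy 4-dimensional vectors. The vocabulary, the vector values, and the choice to skip out-of-vocabulary words are assumptions for illustration; the real OGB features come from 128-dimensional skip-gram embeddings.

```python
from statistics import fmean

# Toy word vectors (made up); the real embeddings are 128-dimensional.
word_vectors = {
    "graph":   [1.0, 0.0, 2.0, 0.0],
    "neural":  [0.0, 1.0, 0.0, 2.0],
    "network": [2.0, 2.0, 1.0, 1.0],
}

def average_embedding(tokens, vectors):
    """Average the vectors of all in-vocabulary tokens, dimension by dimension."""
    known = [vectors[t] for t in tokens if t in vectors]
    if not known:
        raise ValueError("no token has an embedding")
    # zip(*known) groups the values of each dimension across all tokens.
    return [fmean(dim) for dim in zip(*known)]

paper_text = "graph neural network for citation graph"
feature = average_embedding(paper_text.split(), word_vectors)
print(feature)
```

Word order is lost in such an average, which is one reason a contextual encoder like BERT may produce more informative node features.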

In this Colab notebook, we also compare results obtained with the features provided by OGB against results obtained with features encoded by our fine-tuned BERT.

### Download the ogbn-arxiv dataset

In [3]:
# Downloads ogbn_arxiv dataset and its raw titles and abstracts
from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset
import torch_geometric.transforms as T
!wget https://snap.stanford.edu/ogb/data/misc/ogbn_arxiv/titleabs.tsv.gz
!gunzip titleabs.tsv.gz
dataset = PygNodePropPredDataset(name='ogbn-arxiv', 
                                 transform=T.ToSparseTensor())



## Global variables

In [4]:
### File path
TRAIN_ID_PATH = 'dataset/ogbn_arxiv/split/time/train.csv.gz'
VALID_ID_PATH = 'dataset/ogbn_arxiv/split/time/valid.csv.gz'
TEST_ID_PATH = 'dataset/ogbn_arxiv/split/time/test.csv.gz'
LABEL_PATH = 'dataset/ogbn_arxiv/raw/node-label.csv.gz'
NODE2PAPER_PATH = 'dataset/ogbn_arxiv/mapping/nodeidx2paperid.csv.gz'
RAW_DATA_PATH = 'titleabs.tsv'


## Fine-tuned BERT

### Load data & preprocess

In [5]:
raw_data = pd.read_csv(RAW_DATA_PATH, sep='\t', header=None)
raw_data.columns = ['Id', 'Title', 'Abstract']
raw_data.iloc[0, 0] = 200971
raw_data = raw_data.drop(len(raw_data)-1)

node2paper = pd.read_csv(NODE2PAPER_PATH)
train_idx = pd.read_csv(TRAIN_ID_PATH, header=None)
val_idx = pd.read_csv(VALID_ID_PATH, header=None)
test_idx = pd.read_csv(TEST_ID_PATH, header=None)
label = pd.read_csv(LABEL_PATH, header=None)

train_idx = train_idx.iloc[:, 0].tolist()
val_idx = val_idx.iloc[:, 0].tolist()
test_idx = test_idx.iloc[:, 0].tolist()
label = label.iloc[:, 0].tolist()

paper2node_dict = {}
node2paper_dict = {}

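# Column 0 holds the node index and column 1 the MAG paper ID.
for node_idx, paper_id in tqdm(node2paper.itertuples(index=False), total=len(node2paper)):
    paper2node_dict[int(paper_id)] = int(node_idx)
    node2paper_dict[int(node_idx)] = int(paper_id)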
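
The role of the two lookup tables can be sketched with toy data. The IDs and titles below are made up; the point is that the raw text file is keyed by paper ID, while labels, splits, and graph nodes are keyed by node index.

```python
# Toy stand-in for the node-index-to-paper-ID mapping file (IDs are made up).
rows = [(0, 5001), (1, 5002), (2, 5003)]

paper2node = {paper_id: node_idx for node_idx, paper_id in rows}
node2paper = {node_idx: paper_id for node_idx, paper_id in rows}

# Raw text rows are keyed by paper ID; papers outside the graph are skipped.
raw_rows = [(5002, "Title B"), (7777, "Not in graph"), (5001, "Title A")]
texts_by_node = {
    paper2node[paper_id]: title
    for paper_id, title in raw_rows
    if paper_id in paper2node
}
print(texts_by_node)
```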


### Tokenize and split dataset

---



In [6]:
from transformers import BertTokenizer

train = []
val = []
test = []
# Sets give O(1) membership tests; the split lists hold tens of thousands of IDs.
train_ids, val_ids, test_ids = set(train_idx), set(val_idx), set(test_idx)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
for i, row in tqdm(raw_data.iterrows()):
    paper_id = int(row['Id'])
    # Skip papers that are not part of the ogbn-arxiv graph.
    if paper_id not in paper2node_dict:
        continue
    node_id = paper2node_dict[paper_id]
    processed = {}
    # Title and abstract are concatenated directly and mapped to WordPiece IDs.
    tokens = tokenizer.tokenize(text=row['Title'] + row['Abstract'])
    processed['context'] = tokenizer.convert_tokens_to_ids(tokens)
    processed['length'] = len(processed['context'])
    processed['id'] = node_id
    processed['label'] = label[node_id]

    if node_id in train_ids:
        train.append(processed)
    elif node_id in val_ids:
        val.append(processed)
    elif node_id in test_ids:
        test.append(processed)
    else:
        print("NOT MATCH!!!!!")
        break


### Hyperparameters

We will use an instance of `Args` to hold the hyperparameters and constants used to train and evaluate BERT. Each one is briefly described below:


-   **num_classes**: The number of target classes (40).

-   **max_seq_length**: The maximum number of tokens kept per document during training and evaluation.

-   **train_epochs**: The number of training epochs.

-   **batch_size**: Batch size used for training and evaluation.

-   **learning_rate**: Learning rate in training.

-   **dropout_rate**: Controls the rate of dropout following each layer.

-   **eval_steps**: The number of batches to process before evaluation.
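
As a rough illustration of how these settings interact, the sketch below works out when validation runs. It assumes the values set in `Args` (batch size 8, 2 epochs, evaluation every 4000 batches) and a training split of 90,941 papers, which is the size of OGB's published split for `ogbn-arxiv`.

```python
import math

num_train_examples = 90_941   # assumed: size of the OGB training split
batch_size = 8                # args.batch_size
train_epochs = 2              # args.train_epochs
eval_steps = 4000             # args.eval_steps

batches_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = batches_per_epoch * train_epochs
# Validation runs whenever the global step counter is a multiple of eval_steps.
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

print(batches_per_epoch)
print(total_steps)
print(eval_points)
```

Under these assumptions the weights after the last evaluation point are never validated, so the saved checkpoint comes from one of the listed steps rather than from the end of training.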

In [7]:
class Args(object):
  """Hyperparameters used for training BERT."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 40
    self.max_seq_length = 500
    ### training parameters
    self.train_epochs = 2
    self.batch_size = 8
    self.learning_rate = 2e-5
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = 4000

args = Args()

### Make dataset

In [8]:
from torch.utils.data import Dataset, DataLoader

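def pad_to_len(arr, padded_len, padding=0):
    """Return a copy of `arr` padded or truncated to exactly `padded_len` items."""
    padded_len = int(padded_len)  # collate_fn may pass a 0-dim tensor
    new_arr = list(arr)  # copy, so the cached example is never mutated
    if len(new_arr) < padded_len:
        new_arr.extend([padding] * (padded_len - len(new_arr)))
    else:
        new_arr = new_arr[:padded_len]
    return new_arr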

class CitationDataset(Dataset):

    def __init__(self, data, max_seq_len, padding=0):
        self.data = data
        self.max_seq_len = max_seq_len
        self.padding = padding

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        data = dict(self.data[index])
        if len(data['context']) > self.max_seq_len:
            data['context'] = data['context'][:self.max_seq_len]
        return data

    def collate_fn(self, datas):
        batch = {}
        batch['length'] = torch.LongTensor([data['length'] for data in datas])
        padded_len = min(self.max_seq_len, max(batch['length']))
        batch['context'] = torch.tensor(
            [pad_to_len(data['context'], padded_len, self.padding)
             for data in datas]
        )
        batch['label'] = torch.LongTensor([data['label'] for data in datas])
        return batch

train_dataset = CitationDataset(train, max_seq_len=args.max_seq_length)
valid_dataset = CitationDataset(val, max_seq_len=args.max_seq_length)

train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, 
    collate_fn=train_dataset.collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False,
    collate_fn=valid_dataset.collate_fn)
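
The batching logic above can be sketched in plain Python: truncate each sequence to the maximum length, then right-pad to the longest sequence in the batch. The helper name `pad_batch` and the token IDs are hypothetical and only serve the illustration.

```python
def pad_batch(sequences, max_seq_len, padding=0):
    """Truncate to max_seq_len, then right-pad every sequence to the batch maximum."""
    truncated = [seq[:max_seq_len] for seq in sequences]
    padded_len = max(len(seq) for seq in truncated)
    # Build new lists so the stored examples are never modified in place.
    return [seq + [padding] * (padded_len - len(seq)) for seq in truncated]

batch = [[101, 7, 8], [5], [9, 9, 9, 9, 9, 9]]
padded = pad_batch(batch, max_seq_len=4)
print(padded)
```

Padding to the batch maximum instead of always padding to `max_seq_len` keeps batches of short abstracts small, which can save a noticeable amount of compute.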

### Metrics

In [9]:
from sklearn.metrics import accuracy_score
class Metrics:
    def __init__(self):
        self.name = 'Metric Name'

    def reset(self):
        pass

    def update(self, predicts, batch):
        pass

    def get_score(self):
        pass

class Accuracy(Metrics):
    """Running classification accuracy accumulated over batches."""
    def __init__(self):
        self.n = 0
        self.name = 'Accuracy'
        self.match = 0

    def reset(self):
        self.n = 0
        self.match = 0

    def update(self, predicts, label):
        """
        Args:
            predicts (FloatTensor): class scores with size (batch, num_classes).
            label (LongTensor): ground-truth class indices with size (batch,).
        """
        predicts, label = predicts.cpu(), label.cpu()
        batch_size = predicts.size(0)
        _, y_pred = torch.max(predicts, dim=1)
        # normalize=False returns the number of correct predictions.
        self.match += accuracy_score(label, y_pred, normalize=False)
        self.n += batch_size

    def print_score(self):
        acc = self.match / self.n
        return '{:.4f}'.format(acc)
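
The reason for accumulating counts rather than averaging per-batch accuracies can be shown with a small plain-Python sketch. The class name `RunningAccuracy` and the scores are made up for illustration.

```python
class RunningAccuracy:
    """Accumulate correct predictions over batches (hypothetical helper)."""
    def __init__(self):
        self.match = 0
        self.n = 0

    def update(self, scores, labels):
        # scores: one list of per-class scores per example; prediction = argmax.
        for row, label in zip(scores, labels):
            predicted = max(range(len(row)), key=row.__getitem__)
            self.match += int(predicted == label)
            self.n += 1

    def value(self):
        return self.match / self.n

acc = RunningAccuracy()
acc.update([[0.1, 0.9], [0.8, 0.2]], [1, 1])                  # 1 of 2 correct
acc.update([[0.3, 0.7], [0.6, 0.4], [0.2, 0.8]], [1, 0, 1])   # 3 of 3 correct
print(f"{acc.value():.4f}")
```

Averaging the two batch accuracies would give 0.75 here, while the running count gives 4 correct out of 5. The difference matters whenever the last batch is smaller than the others.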

### Train and Evaluate BERT model on multi-classification task

In [10]:
def run_iter(batch, model, device, training):
    """Run one forward pass and return the classification logits."""
    context, context_lens = batch['context'].to(device), batch['length'].to(device)
    max_context_len = context.size(1)
    # Attention mask: 1.0 for real tokens, 0.0 for padding. Position p of row j
    # is real when p < context_lens[j]; truncated rows are therefore all ones.
    positions = torch.arange(max_context_len, device=device).unsqueeze(0)
    padding_mask = (positions < context_lens.unsqueeze(1)).float()
    # Track gradients only while training.
    with torch.set_grad_enabled(training):
        logits = model(context, attention_mask=padding_mask)[0]
    return logits
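
The attention mask built in `run_iter` can be sketched without tensors. The helper name `build_attention_mask` is hypothetical; the lengths are the untruncated token counts, so a length larger than the batch width yields a row of ones.

```python
def build_attention_mask(lengths, padded_len):
    """Return one 0/1 row per sequence: 1 for real tokens, 0 for padding."""
    mask = []
    for length in lengths:
        real = min(length, padded_len)  # truncated sequences are fully real
        mask.append([1] * real + [0] * (padded_len - real))
    return mask

mask = build_attention_mask(lengths=[3, 1, 6], padded_len=4)
for row in mask:
    print(row)
```

The mask tells BERT's self-attention to ignore padded positions, so the result for a short abstract should not depend on how much padding its batch happens to need.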

In [11]:
def training(train_loader, valid_loader, model, optimizer, epochs, eval_steps, device):
    train_metrics = Accuracy()
    best_valid_acc = 0
    total_iter = 0
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(epochs):
        train_trange = tqdm(enumerate(train_loader), total=len(train_loader), desc='training')
        train_loss = 0
        train_metrics.reset()
        for i, batch in train_trange:
            model.train()
            prob = run_iter(batch, model, device, training=True)
            answer = batch['label'].to(device)
            loss = criterion(prob, answer)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_iter += 1
            train_loss += loss.item()
            train_metrics.update(prob, answer)
            train_trange.set_postfix(loss= train_loss/(i+1),
                                     **{train_metrics.name: train_metrics.print_score()})
            
            if total_iter % eval_steps == 0:
                valid_acc = testing(valid_loader, model, device, valid=True)
                if valid_acc > best_valid_acc:
                    best_valid_acc = valid_acc
                    torch.save(model, 'best_val.pkl')
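
Note that the value named `prob` in the training loop holds raw logits: `torch.nn.CrossEntropyLoss` applies log-softmax internally before taking the negative log-likelihood. The sketch below reproduces that computation in plain Python for a single example with made-up scores.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    shift = max(logits)
    log_norm = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return [z - log_norm for z in logits]

def cross_entropy(logits, label):
    """Negative log-probability of the true class, computed from raw logits."""
    return -log_softmax(logits)[label]

logits = [2.0, 1.0, 0.1]   # made-up scores for three classes
loss = cross_entropy(logits, label=0)
print(round(loss, 4))
```

The GNN models later in this notebook return `log_softmax` outputs and use `F.nll_loss`, which amounts to the same loss split into two steps.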

In [12]:
def testing(dataloader, model, device, valid):
    metrics = Accuracy()
    criterion = torch.nn.CrossEntropyLoss()
    trange = tqdm(enumerate(dataloader), total=len(dataloader), desc='validation' if valid else 'testing')
    model.eval()
    total_loss = 0
    metrics.reset()
    for k, batch in trange:
        model.eval()
        prob = run_iter(batch, model, device, training=False)
        answer = batch['label'].to(device)
        loss = criterion(prob, batch['label'].to(device))
        total_loss += loss.item()
        metrics.update(prob, answer)
        trange.set_postfix(loss= total_loss/(k+1),
                           **{metrics.name: metrics.print_score()})
    acc = metrics.match / metrics.n
    return acc

In [13]:
from transformers import BertForSequenceClassification
from torch.optim import Adam
device = torch.device('cuda:{}'.format(torch.cuda.current_device()) 
                       if torch.cuda.is_available() else 'cpu')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', 
                                                      num_labels=args.num_classes).to(device)
optimizer = Adam(model.parameters(), lr=args.learning_rate)
training(train_loader, valid_loader, model, optimizer, args.train_epochs, args.eval_steps, device)


### Predict test set using fine-tuned BERT model

In [14]:
test_dataset = CitationDataset(test, max_seq_len=args.max_seq_length)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                         collate_fn=test_dataset.collate_fn)
model = torch.load('best_val.pkl').to(device)
test_acc = testing(test_loader, model, device, valid=False)
print("Test Accuracy:{}".format(test_acc))

Test Accuracy:0.7237619077011707


In [15]:
!ls

best_val.pkl  dataset  sample_data  titleabs.tsv


## Encode node features using fine-tuned BERT

In [16]:
node_feats = torch.zeros((len(node2paper_dict), 768)).to(device)
model = torch.load('best_val.pkl').to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

### Freeze all fine-tuned BERT layers
freeze_layers = 12
for p in model.bert.embeddings.parameters():
    p.requires_grad = False
model.bert.embeddings.dropout.p = 0.0
for p in model.bert.pooler.parameters():
    p.requires_grad = False
for idx in range(freeze_layers):
    for p in model.bert.encoder.layer[idx].parameters():
        p.requires_grad = False
    model.bert.encoder.layer[idx].attention.self.dropout.p = 0.0
    model.bert.encoder.layer[idx].attention.output.dropout.p = 0.0
    model.bert.encoder.layer[idx].output.dropout.p = 0.0

model.eval()  # disable any remaining dropout
with torch.no_grad():  # inference only, no autograd graph needed
    for i, row in tqdm(raw_data.iterrows()):
        if int(row['Id']) not in paper2node_dict:
            continue
        context = row['Title'] + row['Abstract']
        # BERT accepts at most 512 positions, so longer inputs are truncated.
        tokenize_context = tokenizer.tokenize(context)[:512]
        context_id = tokenizer.convert_tokens_to_ids(tokenize_context)
        context_id = torch.LongTensor(context_id).unsqueeze(0).to(device)
        # Index 1 of the output is the pooled 768-dimensional representation.
        feat = model.bert(context_id)[1]
        node_id = paper2node_dict[int(row['Id'])]
        node_feats[node_id, :] = feat


## Graph Neural Networks (GNN)

### Load citation graph data

In [17]:
dataset = PygNodePropPredDataset(name='ogbn-arxiv',
                                     transform=T.ToSparseTensor())
data = dataset[0]
data.adj_t = data.adj_t.to_symmetric()
data = data.to(device)
split_idx = dataset.get_idx_split()
train_idx = split_idx['train'].to(device)
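
The call to `data.adj_t.to_symmetric()` turns the directed citation graph into an undirected one, so information can flow both from citing to cited papers and back. A plain-Python stand-in for that step, with a made-up edge list, might look like this:

```python
def to_symmetric(edges):
    """Add the reverse of every directed edge and drop duplicates."""
    return sorted(set(edges) | {(dst, src) for src, dst in edges})

# Directed citations: 0 cites 1, 2 cites 1, and papers 1 and 3 cite each other.
citations = [(0, 1), (2, 1), (1, 3), (3, 1)]
undirected = to_symmetric(citations)
print(undirected)
```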

### GNN Models

#### Graph Convolution Network (GCN)

In [18]:
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GCN, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels, cached=True))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(
                GCNConv(hidden_channels, hidden_channels, cached=True))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels, cached=True))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)
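
For intuition about what a single `GCNConv` layer aggregates, the sketch below applies the symmetric normalization with self-loops, where each neighbor contribution is scaled by 1/sqrt(deg(i) * deg(j)), to a three-node path graph with scalar features. It deliberately omits the learned weight matrix, bias, batch normalization, and dropout, so it illustrates the propagation rule only, not the layer as implemented in PyTorch Geometric.

```python
import math

def gcn_propagate(num_nodes, edges, x):
    """One symmetric-normalized aggregation step with self-loops (no weights)."""
    neighbors = {i: {i} for i in range(num_nodes)}   # self-loops
    for u, v in edges:                               # undirected edges
        neighbors[u].add(v)
        neighbors[v].add(u)
    deg = {i: len(nbrs) for i, nbrs in neighbors.items()}
    return [
        sum(x[j] / math.sqrt(deg[i] * deg[j]) for j in neighbors[i])
        for i in range(num_nodes)
    ]

# Path graph 0 - 1 - 2 with one scalar feature per node.
out = gcn_propagate(3, [(0, 1), (1, 2)], [1.0, 2.0, 3.0])
print([round(v, 4) for v in out])
```

Stacking several such layers lets each paper's representation depend on papers several citation hops away.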


#### Graph Attention Network (GAT)

In [19]:
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(GAT, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.convs.append(GATConv(in_channels, hidden_channels, heads=4))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels*4))
        for _ in range(num_layers - 2):
            self.convs.append(
                GATConv(hidden_channels*4, hidden_channels, heads=4))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels*4))
        self.convs.append(GATConv(hidden_channels*4, out_channels, heads=4, concat=False))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()

    def forward(self, x, adj_t):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, adj_t)
            x = self.bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t)
        return x.log_softmax(dim=-1)
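
Because every hidden `GATConv` layer concatenates its 4 attention heads, the hidden width is `hidden_channels * 4`, while the last layer uses `concat=False` and averages the heads. The sketch below tabulates the resulting layer shapes; the helper name `gat_layer_shapes` is hypothetical.

```python
def gat_layer_shapes(in_channels, hidden_channels, out_channels, num_layers, heads=4):
    """List (input_dim, output_dim) per layer for the GAT class in this notebook."""
    shapes = [(in_channels, hidden_channels * heads)]
    for _ in range(num_layers - 2):
        shapes.append((hidden_channels * heads, hidden_channels * heads))
    # The last layer averages its heads (concat=False), so no width multiplier.
    shapes.append((hidden_channels * heads, out_channels))
    return shapes

# OGB features: 128 inputs, 3 layers. BERT features: 768 inputs, 2 layers.
ogb_shapes = gat_layer_shapes(128, 256, 40, num_layers=3)
bert_shapes = gat_layer_shapes(768, 256, 40, num_layers=2)
print(ogb_shapes)
print(bert_shapes)
```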


### Metrics

In [20]:
class Logger(object):
    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        assert len(result) == 3
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 1].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Train: {result[:, 0].max():.2f}')
            print(f'Highest Valid: {result[:, 1].max():.2f}')
            print(f'  Final Train: {result[argmax, 0]:.2f}')
            print(f'   Final Test: {result[argmax, 2]:.2f}')
        else:
            result = 100 * torch.tensor(self.results)

            best_results = []
            for r in result:
                train1 = r[:, 0].max().item()
                valid = r[:, 1].max().item()
                train2 = r[r[:, 1].argmax(), 0].item()
                test = r[r[:, 1].argmax(), 2].item()
                best_results.append((train1, valid, train2, test))

            best_result = torch.tensor(best_results)

            print(f'All runs:')
            r = best_result[:, 0]
            print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 1]
            print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 2]
            print(f'  Final Train: {r.mean():.2f} ± {r.std():.2f}')
            r = best_result[:, 3]
            print(f'   Final Test: {r.mean():.2f} ± {r.std():.2f}')
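
The `Logger` reports "Final Test" as the test accuracy at the epoch with the highest validation accuracy, not the highest test accuracy seen during training. A minimal sketch with made-up per-epoch results:

```python
# Each tuple is (train_acc, valid_acc, test_acc) for one epoch; values are made up.
epochs = [
    (0.70, 0.68, 0.670),
    (0.75, 0.72, 0.705),
    (0.80, 0.71, 0.715),
]

# Model selection looks only at the validation column.
best_epoch = max(range(len(epochs)), key=lambda e: epochs[e][1])
highest_train = max(train for train, _, _ in epochs)
final_train, _, final_test = epochs[best_epoch]

print(best_epoch, highest_train, final_train, final_test)
```

In this toy run the last epoch has a higher test accuracy, but it is not reported because its validation accuracy is lower. Selecting on validation data keeps the test set out of the model selection.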


### Hyperparameters

In [21]:
class GNN_Args(object):
  """Hyperparameters used for training GNN."""
  def __init__(self):
    self.log_steps = 1
    self.num_layers = 3
    self.hidden_channels = 256
    self.dropout = 0.5
    self.epochs = 500
    self.lr = 5e-3
    self.runs = 10

args = GNN_Args()

### Train and Evaluate GNN model on word2vec node features

In [22]:
def train(model, data, train_idx, optimizer):
    model.train()

    optimizer.zero_grad()
    out = model(data.x, data.adj_t)[train_idx]
    loss = F.nll_loss(out, data.y.squeeze(1)[train_idx])
    loss.backward()
    optimizer.step()

    return loss.item()

@torch.no_grad()
def test(model, data, split_idx, evaluator):
    model.eval()

    out = model(data.x, data.adj_t)
    y_pred = out.argmax(dim=-1, keepdim=True)

    train_acc = evaluator.eval({
        'y_true': data.y[split_idx['train']],
        'y_pred': y_pred[split_idx['train']],
    })['acc']
    valid_acc = evaluator.eval({
        'y_true': data.y[split_idx['valid']],
        'y_pred': y_pred[split_idx['valid']],
    })['acc']
    test_acc = evaluator.eval({
        'y_true': data.y[split_idx['test']],
        'y_pred': y_pred[split_idx['test']],
    })['acc']

    return train_acc, valid_acc, test_acc

In [23]:
from ogb.nodeproppred import Evaluator
evaluator = Evaluator(name='ogbn-arxiv')
logger = Logger(args.runs, args)
model = GCN(data.num_features, args.hidden_channels,
            dataset.num_classes, args.num_layers,
            args.dropout).to(device)
# model = GAT(data.num_features, args.hidden_channels,
#             dataset.num_classes, args.num_layers,
#             args.dropout).to(device)

for run in range(args.runs):
  model.reset_parameters()
  optimizer = Adam(model.parameters(), lr=args.lr)
  for epoch in range(1, 1 + args.epochs):
    loss = train(model, data, train_idx, optimizer)
    result = test(model, data, split_idx, evaluator)
    logger.add_result(run, result)

    if epoch % args.log_steps == 0:
      train_acc, valid_acc, test_acc = result
      print(f'Run: {run + 1:02d}, '
            f'Epoch: {epoch:02d}, '
            f'Loss: {loss:.4f}, '
            f'Train: {100 * train_acc:.2f}%, '
            f'Valid: {100 * valid_acc:.2f}% '
            f'Test: {100 * test_acc:.2f}%')
  logger.print_statistics(run)
logger.print_statistics()


Streaming output truncated to the last 5000 lines.
Run: 01, Epoch: 56, Loss: 1.0187, Train: 71.19%, Valid: 70.62% Test: 70.29%
Run: 01, Epoch: 57, Loss: 1.0201, Train: 71.31%, Valid: 70.63% Test: 70.26%
Run: 01, Epoch: 58, Loss: 1.0160, Train: 71.36%, Valid: 70.63% Test: 70.12%
Run: 01, Epoch: 59, Loss: 1.0126, Train: 71.35%, Valid: 70.55% Test: 69.93%
Run: 01, Epoch: 60, Loss: 1.0099, Train: 71.41%, Valid: 70.48% Test: 69.78%
Run: 01, Epoch: 61, Loss: 1.0060, Train: 71.46%, Valid: 70.56% Test: 69.58%
Run: 01, Epoch: 62, Loss: 1.0047, Train: 71.56%, Valid: 70.63% Test: 69.63%
Run: 01, Epoch: 63, Loss: 1.0044, Train: 71.68%, Valid: 70.75% Test: 69.63%
Run: 01, Epoch: 64, Loss: 0.9998, Train: 71.77%, Valid: 71.03% Test: 70.02%
Run: 01, Epoch: 65, Loss: 0.9969, Train: 71.86%, Valid: 71.18% Test: 70.51%
Run: 01, Epoch: 66, Loss: 0.9938, Train: 71.75%, Valid: 71.03% Test: 70.64%
Run: 01, Epoch: 67, Loss: 0.9938, Train: 71.84%, Valid: 71.17% Test: 70.68%
Run: 01, Epoch: 68, Loss: 0.9919, Train: 71.99%, Vali

### Train and Evaluate GNN model on fine-tuned BERT node features

In [24]:
args.num_layers = 2
evaluator = Evaluator(name='ogbn-arxiv')
logger = Logger(args.runs, args)
data.x = node_feats
model = GCN(data.num_features, args.hidden_channels,
            dataset.num_classes, args.num_layers,
            args.dropout).to(device)
# model = GAT(data.num_features, args.hidden_channels,
#             dataset.num_classes, args.num_layers,
#             args.dropout).to(device)

for run in range(args.runs):
  model.reset_parameters()
  optimizer = Adam(model.parameters(), lr=args.lr)
  for epoch in range(1, 1 + args.epochs):
    loss = train(model, data, train_idx, optimizer)
    result = test(model, data, split_idx, evaluator)
    logger.add_result(run, result)

    if epoch % args.log_steps == 0:
      train_acc, valid_acc, test_acc = result
      print(f'Run: {run + 1:02d}, '
            f'Epoch: {epoch:02d}, '
            f'Loss: {loss:.4f}, '
            f'Train: {100 * train_acc:.2f}%, '
            f'Valid: {100 * valid_acc:.2f}% '
            f'Test: {100 * test_acc:.2f}%')
  logger.print_statistics(run)
logger.print_statistics()


Streaming output truncated to the last 5000 lines.
Run: 01, Epoch: 56, Loss: 0.7760, Train: 77.68%, Valid: 74.64% Test: 73.14%
Run: 01, Epoch: 57, Loss: 0.7742, Train: 77.70%, Valid: 74.67% Test: 73.11%
Run: 01, Epoch: 58, Loss: 0.7744, Train: 77.73%, Valid: 74.68% Test: 73.11%
Run: 01, Epoch: 59, Loss: 0.7732, Train: 77.73%, Valid: 74.75% Test: 73.17%
Run: 01, Epoch: 60, Loss: 0.7728, Train: 77.76%, Valid: 74.77% Test: 73.32%
Run: 01, Epoch: 61, Loss: 0.7716, Train: 77.77%, Valid: 74.78% Test: 73.40%
Run: 01, Epoch: 62, Loss: 0.7704, Train: 77.76%, Valid: 74.81% Test: 73.45%
Run: 01, Epoch: 63, Loss: 0.7694, Train: 77.78%, Valid: 74.86% Test: 73.49%
Run: 01, Epoch: 64, Loss: 0.7687, Train: 77.81%, Valid: 74.86% Test: 73.48%
Run: 01, Epoch: 65, Loss: 0.7675, Train: 77.83%, Valid: 74.86% Test: 73.44%
Run: 01, Epoch: 66, Loss: 0.7680, Train: 77.82%, Valid: 74.85% Test: 73.45%
Run: 01, Epoch: 67, Loss: 0.7665, Train: 77.84%, Valid: 74.89% Test: 73.44%
Run: 01, Epoch: 68, Loss: 0.7654, Train: 77.84%, Vali

The proposed GCN model, which uses fine-tuned BERT to encode node features, reaches **74.94%** accuracy. That is about 3 percentage points higher than both the GCN model trained on the original node features (72.09%) and the fine-tuned BERT classifier without citation graph data (72.13%).
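
The size of the improvement can be checked directly from the three accuracies quoted above. The sketch separates the absolute gain in percentage points from the relative gain.

```python
bert_gcn = 74.94       # GCN on fine-tuned BERT features
baseline_gcn = 72.09   # GCN on the OGB-provided features
bert_only = 72.13      # fine-tuned BERT classifier, no graph

gain_over_gcn = round(bert_gcn - baseline_gcn, 2)    # percentage points
gain_over_bert = round(bert_gcn - bert_only, 2)      # percentage points
relative_gain = round(100 * (bert_gcn - baseline_gcn) / baseline_gcn, 2)  # percent

print(gain_over_gcn, gain_over_bert, relative_gain)
```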

## Conclusion

We have demonstrated how to fine-tune BERT, encode node features with it, and train a GNN for multi-class classification on the [Open Graph Benchmark](https://ogb.stanford.edu/) - [ogbn-arxiv](https://ogb.stanford.edu/docs/nodeprop/#ogbn-arxiv) dataset. The results show that a pretrained transformer architecture (BERT) can bring its strength in natural language understanding to graph neural networks (GNNs). At the time of writing, the test accuracy of **our proposed method exceeds the first-place entry** on the [Open Graph Benchmark Leaderboard](https://ogb.stanford.edu/docs/leader_nodeprop/#ogbn-arxiv) (**74.94%** vs. 74.16%).

In this work, we use only the base methods and architectures for the transformer (BERT) and the graph neural network (GCN). We encourage users to experiment further by trying different architectures to encode the node features and by designing more advanced GNN training procedures for multi-class node classification.
