In [1]:
# self defined library
from data_utils import keystoint, get_seg_loader # other unneeded definitions: MyData, MyDataset, SegmentBatchCollater, SegmentDataLoader, 

In [2]:
from tqdm.auto import trange, tqdm
from dataclasses import dataclass, field
from typing import List, Optional, Sequence, Union

import collections
import itertools
from datetime import datetime
import json
import os
import pandas as pd

import wandb
import numpy as np
import random

import torch_geometric as pyg
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv, PDNConv, global_mean_pool, global_max_pool
from torch_geometric.utils import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data

import transformers
from transformers import get_scheduler, AutoTokenizer
from transformers.models.bert.modeling_bert import BertModel

import evaluate
from sklearn.metrics import top_k_accuracy_score

In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# device = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transformers.__version__, pyg.__version__, torch.__version__,device

('4.26.0', '2.2.0', '1.13.1', device(type='cuda'))

# definitions

In [4]:
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

In [5]:
@dataclass
class MyConfig:
    # dataset configs
    modelname: str
    dataset: str
    num_classes: int 
    segment_length: int = 'doc'

    # model archtectures
    bert_emb_dim: int = 768
    pos_emb_dim: int = 64 # this is determined by the pretrained pos bert
    dep_emb_dim: int = 32 # edge attribute dim
    bert_pos_fuse: str = 'after hier' # how embeddings fuse togethor. 'token', 'before hier', or 'after hier'
    hidden_dim: int = field(init=False)
    # these 4 parameters below are not changable
    num_dep_type: int = 37 # this is determined by dependency2id
    # max syllable for common word is 17. However, the longest word in English is a protein that has 189819 letters! 
    # That must have much more syllables. Truncate to 32 to simplify.
    max_num_syllables: int = 32  
    max_sentence_num: int = 64 
    # zipf frequency bins. 0-8 stands for its frequency between 10**(x-1) and 10**x. 9 is for punctuations. 10 is for CLS and SEP
    num_freq_type: int = 11

    num_layers: int = 4
    heads: int = 4
    num_hierarchy: int = 1

    add_self_loops: bool = False
    add_syllables: bool = True
    add_word_freq: bool = True
    add_dep: bool = True
    add_sentence_order: bool = True # only if num_hierarchy > 0
    
    # training configs
    max_length: int = 256 # this is only for pos tokenizer, bert tokenizer use default 512
    dropout: float = 0.1
    batch_size: int = 16
    epochs: int = 100
    warmup_ratio: float = 0.15
    lr: float = 2e-3
    save: bool = False
    save_location: str = field(init=False)

    # pretrained checkpoints
    pos_checkpoint: str = field(init=False)
    bert_checkpoint: str = 'bert-base-uncased'
    freeze_bert: int = 10 # how many bert layers to freeze

    def __post_init__(self):
        assert self.pos_emb_dim % self.heads==0, 'make sure pos_emb_dim is dividable to heads'
        self.hidden_dim = self.pos_emb_dim//self.heads

        pos_emb_dim2pos_checkpoint = {64: "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_22/checkpoint-95000/",
                                    48: "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_16/checkpoint-95000/",
                                    32: "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_10/checkpoint-145000/",
                                    16: "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_5/checkpoint-95000/",}
        self.pos_checkpoint = pos_emb_dim2pos_checkpoint[self.pos_emb_dim]

        self.save_location = None if not self.save else f'/scratch/data_jz17d/result/{self.modelname} {self.dataset}'

In [6]:
@dataclass
class myGNNoutput:
    loss: None
    logit: None
    emb: None


In [7]:
# GNNtype2layer = {'GATConv':GATConv, 'GATv2Conv':GATv2Conv, 'TransformerConv':TransformerConv, 'PDNConv':PDNConv}

class MyGNNBlock(torch.nn.Module):
    def __init__(self, 
                 in_channels,
                 out_channels,
                 heads,
                 dropout=0.1,
                 dropout_position='last',
                 **kwargs):
        super().__init__()

        self.layernorm = nn.LayerNorm(in_channels)
        self.gnnlayer = TransformerConv(in_channels=in_channels, out_channels=out_channels, heads=heads, beta=True, **kwargs)
        self.dropout = dropout
        self.dropout_position = dropout_position

    def forward(self, x, edge_index, edge_attr=None):
        if self.dropout_position=='last':
            x = x + self.gnnlayer(self.layernorm(x), edge_index, edge_attr).relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
        elif self.dropout_position=='first':
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + self.gnnlayer(self.layernorm(x), edge_index, edge_attr).relu()
        return x

In [8]:
class HierGNN(torch.nn.Module):
    def __init__(self, myconfig):
        super().__init__()
        self.num_classes = myconfig.num_classes
    
        # model archtectures
        self.bert_emb_dim = myconfig.bert_emb_dim
        self.pos_emb_dim = myconfig.pos_emb_dim
        self.dep_emb_dim = myconfig.dep_emb_dim
        self.hidden_dim = myconfig.hidden_dim
        self.bert_pos_fuse = myconfig.bert_pos_fuse
        
        # these 4 parameters are not changable
        self.num_dep_type = myconfig.num_dep_type
        self.max_num_syllables = myconfig.max_num_syllables
        self.max_sentence_num = myconfig.max_sentence_num
        self.num_freq_type = myconfig.num_freq_type

        self.num_layers = myconfig.num_layers
        self.heads = myconfig.heads
        self.num_hierarchy = myconfig.num_hierarchy
        
        assert self.bert_pos_fuse.endswith('hier') != self.num_hierarchy > 0, 'hierarchy settings mismatch'

        self.add_self_loops = myconfig.add_self_loops
        self.add_syllables = myconfig.add_syllables
        self.add_word_freq = myconfig.add_word_freq
        self.add_dep = myconfig.add_dep
        self.add_sentence_order = myconfig.add_sentence_order

        # model misc
        self.max_length = myconfig.max_length # this is for pos tokenizer
        self.dropout = myconfig.dropout

        # pretrained checkpoints
        self.pos_checkpoint = myconfig.pos_checkpoint
        self.bert_checkpoint = myconfig.bert_checkpoint

        # loading pretrained models
        self.bert_tokenizer = AutoTokenizer.from_pretrained(self.bert_checkpoint)
        self.bert = BertModel.from_pretrained(self.bert_checkpoint, add_pooling_layer = False).to(device)
        self.pos_tokenizer = AutoTokenizer.from_pretrained(self.pos_checkpoint, local_files_only=True)
        self.pos_bert = BertModel.from_pretrained(self.pos_checkpoint, local_files_only=True, add_pooling_layer = False).to(device)
        
        # embedding layers
        if self.add_syllables:
            # the longest word in the world has 17 syllables. However, if either processing error or people speak like that, such there is no space between words, error will arise.
            self.syllable_emb_layer = nn.Embedding(self.max_num_syllables, self.pos_emb_dim)
        if self.add_word_freq:
            self.freq_emb_layer = nn.Embedding(self.num_freq_type, self.pos_emb_dim)
        if self.add_dep:
            self.dep_emb_layer = nn.Embedding(self.num_dep_type, self.dep_emb_dim)
        
        # determine dimension size for gnn layers
        if self.bert_pos_fuse == 'token':
            self.gnn_dim = self.bert_emb_dim + self.pos_emb_dim
        else: 
            self.gnn_dim = self.pos_emb_dim
        self.gnn_hidden = self.gnn_dim//self.heads 

        # gnns within sentences
        self.gnns = nn.ModuleList()
        for i in range(self.num_layers):
            if self.add_dep:
                self.gnns.append(MyGNNBlock(self.gnn_dim, self.gnn_hidden, heads = self.heads, add_self_loops=self.add_self_loops, dropout=self.dropout, edge_dim=self.dep_emb_dim))
            else:
                self.gnns.append(MyGNNBlock(self.gnn_dim, self.gnn_hidden, heads = self.heads, add_self_loops=self.add_self_loops, dropout=self.dropout))

        if self.num_hierarchy:
            # for sentence order
            # the longest text has 104 lines. how to deal with super long text?
            if self.add_sentence_order:
                self.sentence_position_emb_layer = nn.Embedding(self.max_sentence_num, 2*self.gnn_dim) 

            # determine dimension size for hier layers
            if self.bert_pos_fuse == 'token':
                self.hier_dim = 2*(self.bert_emb_dim + self.pos_emb_dim)
            elif self.bert_pos_fuse == 'before hier':
                self.hier_dim = self.bert_emb_dim + 2*self.pos_emb_dim
            elif self.bert_pos_fuse == 'after hier':
                self.hier_dim = 2*self.pos_emb_dim
            self.hier_hidden = self.hier_dim//self.heads 

            # hierarchical layer
            self.hierarchy_gnns = nn.ModuleList()
            for i in range(self.num_hierarchy):
                self.hierarchy_gnns.append(MyGNNBlock(self.hier_dim, self.hier_hidden, heads = self.heads, add_self_loops=self.add_self_loops, dropout=self.dropout))
            
        # determine dimension size for cls layer
        if self.bert_pos_fuse == 'token' and self.num_hierarchy:
            self.cls_dim = 4*(self.bert_emb_dim + self.pos_emb_dim)
        elif self.bert_pos_fuse.endswith('hier'):
            self.cls_dim = 2*self.bert_emb_dim + 4*self.pos_emb_dim 
        elif not self.num_hierarchy:
            self.cls_dim = 2*(self.bert_emb_dim + self.pos_emb_dim)

        self.classifier = nn.Linear(self.cls_dim, self.num_classes)
        self.lossfn = nn.CrossEntropyLoss()
        
    def forward(self, text, pos, alignments, edge_index, edge_type_ids, batch, ptr, y, segment_ids, num_syllable, word_freq):
        # get pos embeddings, reshape and squeeze the dimension 0 to match pyg batching fashion
        # x.shape = (sum of #sentence, max_length, pos_emb_dim)
        tokens = self.pos_tokenizer(pos, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt').to(device)
        x = self.pos_bert(**tokens).last_hidden_state
        # reshape! drop padded tokens!
        # x.shape = (sum of #token, pos_emb_dim)
        x = x.masked_select(tokens.attention_mask.ge(0.5).unsqueeze(2)).reshape((-1,self.pos_emb_dim))
        
        # add syllables embedding to pos embeddings
        if self.add_syllables:
            x = x + self.syllable_emb_layer(torch.clip(num_syllable, max=self.max_num_syllables-1)) # clip to make sure no error

        # add freq embedding to pos embeddings
        if self.add_word_freq:
            x = x + self.freq_emb_layer(word_freq)

        # get bert embeddings, reshape and squeeze the dimension 0 to match pyg batching fashion
        # bert_x.shape = (sum of #sentence, max_length, bert_emb_dim)
        bert_tokens = self.bert_tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt').to(device)
        bert_x = self.bert(**bert_tokens).last_hidden_state
        
        if self.bert_pos_fuse == 'token':
            bert_x = bert_x.masked_select(bert_tokens.attention_mask.ge(0.5).unsqueeze(2)).reshape((-1,self.bert_emb_dim))
            zero_emb = torch.zeros(alignments[-1]+1, self.bert_emb_dim).to(device)
            bert_x = zero_emb.index_reduce(0, alignments, bert_x, 'mean', include_self=False)
            x = torch.concat([bert_x, x], axis=1) 
        elif self.bert_pos_fuse.endswith('hier'):
            bert_x = bert_x[:,0,:] # CLS token

        # get edge embeddings
        if self.add_dep:
            edge_attr = self.dep_emb_layer(edge_type_ids)

        # graph conv
        for i in range(self.num_layers):
            x = x + self.gnns[i](x, edge_index, edge_attr=edge_attr).relu() if self.add_dep else self.gnns[i](x, edge_index).relu()
        
        if self.num_hierarchy:
            # readout to get sentence embeddings
            # x.shape = (#sentence, pos_emb_dim*2)
            non_zero_i, non_zero_j = tokens.attention_mask.nonzero(as_tuple=True)
            # the input batch is segment level batch indices. Need sentence level batch indices here
            sent_batch = (((torch.arange(len(pos)).to(device)+1).unsqueeze(1)*tokens.attention_mask)[non_zero_i, non_zero_j] - 1)
            x = torch.cat([global_mean_pool(x, sent_batch), global_max_pool(x, sent_batch)], axis=1)

            if self.bert_pos_fuse == 'before hier':
                x = torch.cat([bert_x, x], axis=1)
                
            # calculate edge_index between sentences from the same paragraph
            edges_among_sentences = torch.LongTensor().to(device)
            if self.add_sentence_order: 
                sentence_id = torch.LongTensor()
            for i in range(segment_ids.max().item()+1):
                idx = (segment_ids==i).nonzero().long().squeeze(1)  # select all sentence id belong to current segment
                edge_x, edge_y = torch.meshgrid(idx, idx)
                edge = torch.vstack([edge_x.flatten(), edge_y.flatten()])
                edges_among_sentences = torch.cat([edges_among_sentences, edge], axis = 1)
                if self.add_sentence_order: 
                    sentence_id = torch.cat([sentence_id, torch.arange(len(idx), dtype=torch.long)])
            
            # add sentence position
            if self.add_sentence_order:
                sentence_id = sentence_id.to(device)
                x = x + self.sentence_position_emb_layer(torch.clip(sentence_id, max=self.max_sentence_num-1))

            if self.bert_pos_fuse == 'after hier':
                x = torch.cat([bert_x, x], axis=1)

            # hierarchical layers
            for i in range(self.num_hierarchy):
                x = x + self.hierarchy_gnns[i](x, edges_among_sentences).relu()

            # readout to get segment/doc embeddings
            # x.shape = (#segment/#doc, pos_emb_dim*4)
            x = torch.cat([global_mean_pool(x, segment_ids), global_max_pool(x, segment_ids)], axis=1)

        else: 
            # readout to get segment/doc embeddings
            # x.shape = (#segment/#doc, pos_emb_dim*2)
            x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], axis=1)

        # prepare logits and output
        x = F.dropout(x, p=self.dropout, training=self.training)
        logit = self.classifier(x)
        loss = self.lossfn(logit, y)
        return myGNNoutput(loss=loss, logit=logit, emb=x)

In [9]:
def save_model(model_folder, model, optimizer, scheduler):
    torch.save(model.state_dict(), f"{model_folder}/pytorch_model.bin")
    torch.save(optimizer.state_dict(), f"{model_folder}/optimizer.pt")
    torch.save(scheduler.state_dict(), f"{model_folder}/scheduler.pt")

In [10]:
def freeze_model(model, freeze_bert):
    '''
    if freeze_bert is True, freeze all layer. 
    if freeze_bert is a positive integer, freeze the bottom {freeze_bert} attention layers
    negative integer should also work
    '''
    if freeze_bert is True: # == True is wrong!!!
        for param in model.bert.parameters():
            param.requires_grad = False
    elif freeze_bert is False: # isinstance(False, int) returns True!
        return model
    elif isinstance(freeze_bert, (int, np.int32, np.int64, torch.int32, torch.int64)):
        for param in model.bert.embeddings.parameters():
            param.requires_grad = False  
        for layer in model.bert.encoder.layer[:freeze_bert]: 
            for param in layer.parameters():
                param.requires_grad = False  
    return model

# run

In [11]:
def get_configs(dataset, num_classes, exclude_keys=['repeat'], **kwargs):
    '''
    If want to try different settings, give a list. Otherwise, just a number/str.
    '''
    keys = []
    values = []
    direct_kwargs = {}
    for k,v in kwargs.items():
        if k.lower() not in exclude_keys:
            assert k.lower() in MyConfig.__dict__['__annotations__'], f"{k} doesn't match any MyConfig option"
        if isinstance(v, list):
            keys.append(k)
            values.append(v)
        else:
            direct_kwargs[k]=v

    CONFIGS = itertools.product(*values)
    config_lists = []
    for raw_config in CONFIGS:
        myconfig = MyConfig(dataset=dataset, num_classes=num_classes, **direct_kwargs)
        for k,v in zip(keys, raw_config):
            if k.lower() not in exclude_keys:
                myconfig.__dict__[k.lower()] = v
        myconfig.__post_init__()
        config_lists.append(myconfig)
    return config_lists


In [12]:
modelname = 'Bert + POS'
dataset='ccat50'
num_classes=50

scratch_data_dir = '/scratch/data_jz17d/data'
dataset_dir = f'{scratch_data_dir}/{dataset}'

config_lists = get_configs(modelname=modelname,
                           dataset=dataset, 
                           num_classes=num_classes,
                           num_hierarchy=[1],
                           segment_length=[2, 3, 4, 'doc'],
                           add_sentence_order=True,
                           bert_pos_fuse=['token', 'before hier', 'after hier'],
                           save=True,
                           )

skip_runs = -1
######################## in most cases, no need to edit the section below ##########################
run_pbar = trange(len(config_lists), leave=False)
for i_run, myconfig in enumerate(config_lists):

    if i_run <= skip_runs:
        run_pbar.update(1)
        continue
    
    seed = int(datetime.now().timestamp())
    set_seed(seed)
    
    # load necessary files and dataset
    doc_true = np.load(f'{dataset_dir}/doc_true.npy')
    with open(f'{dataset_dir}/test_docid2index.json') as f:
        test_docid2index = json.load(f, object_hook=keystoint)
    
    train_loader = get_seg_loader(dataset=dataset, segment_length=myconfig.segment_length, split='train', batch_size=myconfig.batch_size, shuffle=True, max_length=myconfig.max_length)
    num_training_steps = len(train_loader)
    test_loader = get_seg_loader(dataset=dataset, segment_length=myconfig.segment_length, split='test', batch_size=myconfig.batch_size, shuffle=True, max_length=myconfig.max_length)
    # num_test_steps = len(test_loader)
    
    # initialize model, optimizere, and lr scheduler
    model = HierGNN(myconfig)
    model = model.to(device)
    model = freeze_model(model, myconfig.freeze_bert)    

    optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad==True], lr=myconfig.lr)
    # paras, other_para = [], []
    # for name, module in model.named_children():
    #     if name == 'bert':
    #         paras.append({"params": [p for p in module.parameters() if p.requires_grad==True], 'lr': 5e-5})
    #     else:
    #         other_para.extend([p for p in module.parameters()])
    # paras.append({"params": other_para, 'lr': myconfig.lr})
    # optimizer = torch.optim.Adam(paras)

    scheduler = get_scheduler("linear",
                            optimizer=optimizer,
                            num_warmup_steps=int(myconfig.warmup_ratio*myconfig.epochs*num_training_steps),
                            num_training_steps=myconfig.epochs*num_training_steps)
    
    # start sync to wandb
    wconfig = {}
    wconfig['seed'] = seed
    wconfig.update(myconfig.__dict__)
    run = wandb.init(project=f"{modelname} {dataset}", 
                     entity="fsu-dsc-cil", 
                     dir='/scratch/data_jz17d/wandb_tmp/', 
                     config=wconfig,
                     name=f'run_{i_run}',
                     reinit=True,
                     settings=wandb.Settings(start_method='thread'))
    
    best_evaluation = collections.defaultdict(float)
    pbar = trange(myconfig.epochs*num_training_steps, leave=False)
    for i_epoch in range(myconfig.epochs):
        # train
        model.train()
        for batch in train_loader:
            batch = batch.to(device, non_blocking=True)
            optimizer.zero_grad(), 
            output = model(batch.text, batch.pos, batch.alignments, batch.edge_index, batch.edge_type_ids, batch.batch, batch.ptr, batch.y, batch.segment_ids, batch.num_syllables, batch.word_freqs)
            output.loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.update(1)

        # evaluate on test set
        model.eval()
        doc_score = 1e-8*np.ones((len(test_docid2index), myconfig.num_classes))
        metric = evaluate.load('/home/jz17d/Desktop/metrics/accuracy')
        for batch in test_loader:
            batch = batch.to(device, non_blocking=True)
            output = model(batch.text, batch.pos, batch.alignments, batch.edge_index, batch.edge_type_ids, batch.batch, batch.ptr, batch.y, batch.segment_ids, batch.num_syllables, batch.word_freqs)
            pred = output.logit.argmax(axis=-1).cpu().detach().numpy()
            metric.add_batch(predictions=pred, references=batch.y.cpu().numpy())
            doc_id = np.vectorize(test_docid2index.get)(batch.doc_id.cpu().detach().numpy()) 
            doc_score[doc_id,pred] += 1
        
        # logging current
        evaluation = metric.compute()
        for k in range(1, 6):
            evaluation.update({f'test_doc_acc@{k}': top_k_accuracy_score(doc_true, doc_score, k=k)})
        wandb.log(evaluation, step=pbar.n)
        
        # logging best
        for key in evaluation:
            best_evaluation[f'best_{key}'] = max(best_evaluation[f'best_{key}'], evaluation[key])
        wandb.log(best_evaluation, step=pbar.n)

    if myconfig.save:
        model_folder = f"{myconfig.save_location}/run_{i_run}"
        os.makedirs(model_folder, exist_ok = True) 
        save_model(model_folder, model, optimizer, scheduler)

    run.finish()
    run_pbar.update(1)

  0%|          | 0/12 [00:00<?, ?it/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'bert.pooler.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at /scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_

  0%|          | 0/258100 [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [None]:
# in case needed

# with open(f'{dataset_dir}/test_docid2index.json') as f:
#     test_docid2index = json.load(f, object_hook=keystoint)
# test_docid2index = {v:k for k,v in test_docid2index.items()}
# with open(f'{dataset_dir}/test_docid2index.json', 'w') as f:
#     json.dump(test_docid2index, f)