In [1]:
import logging
from tqdm.auto import tqdm, trange
import os
import random
import numpy as np
from glob import glob
from typing import List, Dict
import pandas as pd
from nltk.parse.corenlp import CoreNLPParser, CoreNLPDependencyParser
import inspect
from itertools import cycle
import ast
import itertools
from pathlib import Path
from datetime import datetime
import collections

In [2]:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, AutoTokenizer
from transformers import BertForMaskedLM, BertConfig, PreTrainedModel, AutoModel, AutoModelForSequenceClassification
from transformers import DataCollatorForLanguageModeling, DataCollatorWithPadding
from transformers import Trainer, TrainingArguments
from transformers import IntervalStrategy
from transformers import AutoTokenizer
from transformers.models.bert.modeling_bert import BertModel
from transformers import get_scheduler
import transformers


from torch.utils.data import Dataset, DataLoader, RandomSampler
import torch

import wandb
import datasets
from datasets import Dataset
from sklearn import preprocessing
import evaluate

from sklearn.metrics import mean_squared_error, accuracy_score, precision_recall_fscore_support
from sklearn.metrics import top_k_accuracy_score

# import matplotlib.pyplot as plt

In [3]:
from typing import Optional

from torch import Tensor

from torch_geometric.utils import scatter

In [4]:
import torch_geometric as pyg
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader


In [5]:
# library versions the results below were produced with (torch, torch_geometric, transformers)
print(*(module.__version__ for module in (torch, pyg, transformers)))

1.13.1 2.2.0 4.26.0


In [6]:
# run on the GPU when torch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [7]:
def set_seed(seed):
    """Seed Python's `random`, NumPy's global RNG and torch with the same value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

In [8]:
def global_min_pool(x: Tensor, batch: Optional[Tensor],
                    size: Optional[int] = None) -> Tensor:
    r"""Returns batch-wise graph-level-outputs by taking the channel-wise
    minimum across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` its output is computed by

    .. math::
        \mathbf{r}_i = \mathrm{min}_{n=1}^{N_i} \, \mathbf{x}_n.

    Min counterpart of :func:`torch_geometric.nn.global_max_pool`.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each element to a specific example. If :obj:`None`, all nodes are
            treated as a single graph.
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: per-graph channel-wise minima; a single row (kept
        dimension) when :obj:`batch` is :obj:`None`.
    """
    dim = -1 if x.dim() == 1 else -2

    if batch is None:
        # single graph: reduce with min (previously this branch used max by mistake)
        return x.min(dim=dim, keepdim=x.dim() <= 2)[0]
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=dim, dim_size=size, reduce='min')


In [9]:
# columns that are stored as Python-literal strings in the processed CSVs and
# are parsed with ast.literal_eval right after loading
cols_to_eval = [
    'edge_indexs',
    'hetoro_edges',
    'pos_seqs',
    'upos_seqs',
    'num_syllables',
    'alignments',
]

In [10]:
def preprocess_author_ids(df, col='author'):
    """Re-encode the labels in ``df[col]`` as dense ids 0..K-1 (in sorted label order).

    ``df`` is modified in place and also returned.

    Args:
        df: DataFrame holding the label column.
        col: name of the label column (default ``'author'``).

    Raises:
        AssertionError: if ``col`` is not a column of ``df``.
    """
    # NOTE(review): a later cell in this notebook defines another function with
    # the same name, which shadows this one once that cell has run.
    assert col in df, f'no column named {col} found in df'

    # use the requested column (this previously read df['author'] unconditionally)
    unique_author = sorted(df[col].unique())
    mapping = {author: i for i, author in enumerate(unique_author)}
    df[col] = df[col].map(mapping)

    return df

In [11]:
def freeze_model(model, freeze_bert):
    '''
    Freeze (part of) ``model.bert`` in place and return ``model``.

    freeze_bert:
        True          -> freeze every parameter of ``model.bert``.
        False / None  -> leave the model untouched.
        integer k     -> freeze the embeddings plus ``encoder.layer[:k]``
                         (Python slice semantics, so negative k also works).
    Integers may be Python ints, NumPy integer scalars or single-element
    integer torch tensors.

    Raises:
        TypeError: if ``freeze_bert`` is of any other type.
    '''
    # identity checks on purpose: bool is a subclass of int, so `== True`
    # would also match 1, and isinstance(False, int) is True
    if freeze_bert is True:
        for param in model.bert.parameters():
            param.requires_grad = False
        return model
    if freeze_bert is False or freeze_bert is None:
        return model

    # torch.int32 / torch.int64 are dtype objects, not types, so they cannot be
    # used with isinstance; convert integer tensors to a Python int instead
    integer_dtypes = (torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8)
    if torch.is_tensor(freeze_bert) and freeze_bert.numel() == 1 and freeze_bert.dtype in integer_dtypes:
        freeze_bert = int(freeze_bert.item())

    if not isinstance(freeze_bert, (int, np.integer)):
        raise TypeError(f'freeze_bert must be a bool, None or an integer, got {type(freeze_bert).__name__}')

    num_frozen = int(freeze_bert)
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False
    for layer in model.bert.encoder.layer[:num_frozen]:
        for param in layer.parameters():
            param.requires_grad = False
    return model

In [12]:
def nested_to(dic, device):
    """Move every value of ``dic`` to ``device`` in place and return the same dict."""
    for key in list(dic):
        dic[key] = dic[key].to(device)
    return dic

In [13]:
# English-specific dependency relations: https://universaldependencies.org/en/dep/
# The table below is pasted from that page. Cells starting with '↳' are
# subtypes (e.g. nsubj:pass); only the base relation names are kept.
s = '''nsubj 	csubj
↳nsubj:pass 	↳csubj:pass
↳nsubj:outer 	↳csubj:outer
obj 	ccomp 	xcomp
iobj
obl 	advcl 	advmod
↳obl:npmod 	↳advcl:relcl
↳obl:tmod
vocative 	aux 	mark
discourse 	↳aux:pass
expl 	cop
nummod 	acl 	amod
  	↳acl:relcl
appos 	  	det
  	  	↳det:predet
nmod 	  	 
↳nmod:npmod
↳nmod:tmod
↳nmod:poss
compound 	flat
↳compound:prt 	↳flat:foreign
fixed 	goeswith
conj 	cc
  	↳cc:preconj
list 	parataxis 	orphan
dislocated 		reparandum
root 	punct 	dep'''
all_relations = []
for line in s.split('\n'):
    # split() copes with any mix of tabs / spaces / non-breaking spaces in the paste
    for r in line.split():
        # skip subtypes cell by cell; skipping whole lines that contain a '↳'
        # would also drop base relations sharing that line (e.g. 'discourse')
        if r.startswith('↳'):
            continue
        rel = r.split(':')[0]
        if rel not in all_relations:
            all_relations.append(rel)
# the parser output uses upper-case 'ROOT' rather than UD's 'root'
if 'root' in all_relations:
    all_relations.remove('root')
# relations that are not provided by the table above
for extra in ('ROOT', 'case'):
    if extra not in all_relations:
        all_relations.append(extra)
all_relations = sorted(all_relations)

In [14]:
# dependency relation -> integer id, following the sorted order of all_relations
relation2id = {rel: idx for idx, rel in enumerate(all_relations)}
# extra edge type for self loops: one past the largest relation id, so it can never
# collide with a real relation (36 for the 36 relations above). Keep any code that
# hard-codes this id in sync.
relation2id['self'] = len(all_relations)

In [15]:
def clean_text(text):
    """Normalise raw text before BERT tokenisation.

    Drops ``&amp;``, restores a few apostrophe-less contractions and splits
    ``n't`` / ``N'T`` into a separate token (CoreNLP and BERT tokenise
    ``xxxn't`` differently). Rules are applied in order, so later rules see the
    output of earlier ones (e.g. "dont" -> "don't" -> "do n't").
    """
    rules = (
        ('&amp;', ''),
        ("dont", "don't"),
        ("doesnt", "doesn't"),
        ("wont", "will n't"),
        ("n\'t", " n\'t"),
        ("N\'T", " N\'T"),
        ("cannot", "can not"),
    )
    for old, new in rules:
        text = text.replace(old, new)
    return text

In [16]:
def get_loader(df, add_syllables=False, col='pos_seqs', limit=None, batch_size=32, shuffle=True, max_length=128):
    """Build a PyG DataLoader with one graph per row of ``df``.

    For a row with E entries in ``edge_indexs`` the graph has E + 2 nodes: node 0
    is the CLS token and node E + 1 the SEP token. Edges are the row's
    ``edge_indexs`` (typed by ``hetoro_edges``) plus one self loop on CLS and one
    on SEP, typed ``relation2id['self']``.

    Args:
        df: frame whose list columns are already parsed (see ``cols_to_eval``);
            reads edge_indexs, hetoro_edges, text, author, alignments, ``col``,
            plus num_syllables (if ``add_syllables``) and doc_id (if present).
        add_syllables: attach per-node syllable counts (17 for CLS / SEP).
        col: column whose token sequence is space-joined into ``data.pos``.
        limit: if given, use ``limit`` randomly chosen rows instead of all rows.
        batch_size, shuffle: forwarded to the DataLoader.
        max_length: graphs with ``>= max_length - 1`` edges are dropped.

    Returns:
        A torch_geometric DataLoader over the kept graphs.
    """
    if limit is not None:
        dfnew = df.sample(frac=1).reset_index(drop=True)[:limit]
    else:
        dfnew = df
    self_loop_type = relation2id['self']
    data_list = []
    count = 0
    for i in trange(len(dfnew), leave=False):
        # read from dfnew, not df: otherwise `limit` ignores the shuffle above
        # and always takes the first rows of df
        curr = dfnew.iloc[i]
        num_edges = len(curr['edge_indexs'])
        data = Data()
        data.edge_index = torch.cat([torch.tensor([[0], [0]]),  # self loop of the CLS token
                                     torch.tensor(curr['edge_indexs']).T,
                                     # self loop of the SEP token: for batching, if data.x is missing,
                                     # edge_index is used to infer the batch and an isolated node
                                     # (the SEP in this case) would mess it all up
                                     torch.tensor([[num_edges + 1], [num_edges + 1]])],
                                    dim=1)

        # the self-loop type is only used for the CLS and SEP tokens
        data.edge_type_ids = torch.tensor([self_loop_type]
                                          + [relation2id[t.split(':')[0]] for t in curr['hetoro_edges']]
                                          + [self_loop_type])
        if data.edge_index.shape[1] >= max_length - 1:
            count += 1
            continue

        data.text = clean_text(curr['text'])
        data.pos = ' '.join(curr[col])
        data.y = torch.tensor([curr['author']])
        if add_syllables:
            # NOTE(review): 17 is used as the CLS/SEP syllable id, but SemSynGNN's
            # comment says real words can also have 17 syllables.
            data.num_syllables = torch.tensor([17] + curr['num_syllables'] + [17])

        if 'doc_id' in curr:
            data.doc_id = torch.tensor([curr['doc_id']])

        data.num_nodes = len(data.edge_type_ids)
        data.alignments = curr['alignments']
        data_list.append(data)
    print(f'{count} data dropped because of exceeding max_length {max_length}')
    loader = DataLoader(data_list, batch_size=batch_size, shuffle=shuffle)
    return loader


In [17]:
def preprocess_author_ids(df, col='author'):
    """Shift the integer ids in ``df[col]`` so that the smallest one becomes 0.

    Each id ``v`` becomes ``v - min``; gaps between ids are preserved (unlike a
    dense re-encoding). ``df`` is modified in place and also returned. ``col``
    defaults to ``'author'`` so existing one-argument calls behave as before.

    Raises:
        AssertionError: if ``col`` is not a column of ``df``.
    """
    assert col in df, f'no column named "{col}" found in df'
    if df[col].empty:
        # max()/min() would be NaN and range() would fail
        return df

    max_id, min_id = df[col].max(), df[col].min()
    mapping = {i + min_id: i for i in range(max_id - min_id + 1)}
    df[col] = df[col].map(mapping)

    return df

In [18]:
from dataclasses import dataclass
@dataclass
class myGNNoutput:
    """Result bundle returned by ``SemSynGNN.forward``.

    Fields are tensors, or None where a caller chooses not to fill them.
    """
    loss: Optional[torch.Tensor]   # cross-entropy loss of the batch
    logit: Optional[torch.Tensor]  # classifier output, one row per graph
    emb: Optional[torch.Tensor]    # graph representation that was fed to the classifier

In [19]:
# hidden size of the pretrained POS BERT -> local checkpoint directory
hiddensize2checkpoint = dict([
    (64, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_22/checkpoint-95000/"),
    (48, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_16/checkpoint-95000/"),
    (32, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_10/checkpoint-145000/"),
    (16, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_5/checkpoint-95000/"),
])

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv, PDNConv, global_mean_pool, global_max_pool

# `gnntype` string accepted by SemSynGNN -> PyG convolution class
GNNtype2layer = dict(GATConv=GATConv, GATv2Conv=GATv2Conv, TransformerConv=TransformerConv, PDNConv=PDNConv)

class SemSynGNN(torch.nn.Module):
    """Graph classifier with word-BERT + POS-BERT node features.

    Node features are the concatenation of

    * a word embedding: hidden states of the ``checkpoint`` BERT, mean-pooled
      from subword tokens into nodes via ``alignments``, and
    * a POS embedding: hidden states of the ``pos_checkpoint`` BERT over the
      space-joined tag string (assumes its tokenizer yields exactly one token
      per node, CLS/SEP included), optionally plus a syllable-count embedding.

    ``num_layers`` residual blocks ``x = relu(GNN(LayerNorm(x), edge_attr)) + x``
    follow, where ``edge_attr`` are learned edge-type embeddings. Then comes a
    graph readout and a linear classifier.

    Args:
        num_layers: number of GNN blocks.
        num_classes: number of output classes.
        num_dep_type: size of the edge-type embedding table.
        heads: attention heads per GNN layer; must divide the node feature size.
        hidden_dim: stored only, not used by the model.
        dep_emb_dim: size of the edge-type embeddings.
        add_self_loops: stored only, not passed to the GNN layers.
        gnntype: key of ``GNNtype2layer`` selecting the convolution class.
        add_syllables: if truthy, add a syllable embedding to the POS embedding.
        pos_checkpoint: local directory of the POS BERT and its tokenizer.
        checkpoint: word-level BERT checkpoint (tokenizer loaded from local files).
        max_length: truncation length for the POS tokenizer.
        dropout: dropout after every GNN block and before the classifier.
        readout: ``'pool'`` (concat of mean/max/min node pooling) or ``'cls'``
            (first node of each graph); determines the classifier input size.

    Raises:
        ValueError: for an unknown ``readout`` or when ``heads`` does not divide
            the node feature size.
    """

    # readout -> multiple of the node feature size seen by the classifier
    _READOUT_WIDTH = {'pool': 3, 'cls': 1}

    def __init__(self, 
                 num_layers, 
                 num_classes, 
                 num_dep_type, 
                 heads, 
                 hidden_dim, 
                 dep_emb_dim=32, 
                 add_self_loops=False, 
                 gnntype='GATConv', 
                 add_syllables=None,
                 pos_checkpoint="/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_22/checkpoint-95000/",
                 checkpoint='bert-base-uncased',
                 max_length=256,
                 dropout=0.1,
                 readout='pool'):
        
        super().__init__()
        if readout not in self._READOUT_WIDTH:
            raise ValueError(f"readout must be one of {sorted(self._READOUT_WIDTH)}, got {readout!r}")

        # submodules are left on the CPU here; move the whole model with .to(...)
        self.pos_checkpoint = pos_checkpoint
        self.pos_tokenizer = AutoTokenizer.from_pretrained(self.pos_checkpoint, local_files_only=True)
        self.pos_bert = BertModel.from_pretrained(self.pos_checkpoint, local_files_only=True, add_pooling_layer=False)
        
        self.checkpoint = checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint, local_files_only=True)
        self.bert = BertModel.from_pretrained(self.checkpoint, add_pooling_layer=False)
        
        self.max_length = max_length
        
        self.num_layers = num_layers
        self.num_classes = num_classes
        # read the embedding sizes from the loaded models instead of hard-coding
        # 64 / 768, so POS BERTs of other hidden sizes also work
        self.pos_emb_dim = self.pos_bert.config.hidden_size
        self.word_emb_dim = self.bert.config.hidden_size
        self.heads = heads
        self.hidden_dim = hidden_dim  # not used by the model
        self.dep_emb_dim = dep_emb_dim
        self.add_syllables = add_syllables
        self.dropout = dropout
        self.readout = readout
        
        self.gnn_dim = self.pos_emb_dim + self.word_emb_dim
        if self.gnn_dim % self.heads != 0:
            # each layer outputs heads * (gnn_dim // heads) features, which must
            # equal gnn_dim for the residual connection
            raise ValueError(f"heads={self.heads} must divide the node feature size {self.gnn_dim}")
        
        if add_syllables:
            self.num_syllables = 18 # the longest word has 17 syllables
            self.syllable_emb_layer = nn.Embedding(self.num_syllables, self.pos_emb_dim)
            
        self.GNNlayer = GNNtype2layer[gnntype]
        
        self.add_self_loops = add_self_loops  # stored only
        self.dep_emb_layer = nn.Embedding(num_dep_type, self.dep_emb_dim)
        
        self.gnns = nn.ModuleList()
        for i in range(self.num_layers):
            self.gnns.append(self.GNNlayer(self.gnn_dim, self.gnn_dim//self.heads, heads=self.heads, edge_dim=self.dep_emb_dim, beta=True))
        
        self.layernorms = nn.ModuleList()
        for i in range(self.num_layers):
            self.layernorms.append(nn.LayerNorm(self.gnn_dim))
            
        self.classifier = nn.Linear(self._READOUT_WIDTH[readout] * self.gnn_dim, self.num_classes)
        self.lossfn = nn.CrossEntropyLoss()
        
    def forward(self, text, pos, alignments, edge_index, edge_type_ids, batch, y, ptr, num_syllable=None, readout=None):
        """Classify a PyG batch of graphs.

        Args:
            text: list of raw texts, one per graph.
            pos: list of space-joined tag strings, one per graph.
            alignments: per graph, a list giving for every subword token (CLS/SEP
                excluded) the index of the word it belongs to, as implied by the
                index_reduce below.
            edge_index, edge_type_ids, batch, ptr: PyG batch tensors;
                edge_type_ids index the edge-type embedding table.
            y: class labels used for the loss.
            num_syllable: per-node syllable ids; required when add_syllables.
            readout: None to use the readout chosen at construction, or that same
                value; anything else raises ValueError (classifier size is fixed).

        Returns:
            myGNNoutput with the loss, per-graph logits and per-graph readout.
        """
        if readout is None:
            readout = self.readout
        if readout != self.readout:
            raise ValueError(f"readout={readout!r} does not match the classifier built for readout={self.readout!r}")
        if self.add_syllables and num_syllable is None:
            raise ValueError("num_syllable is required when the model was built with add_syllables")
        dev = self.classifier.weight.device

        # word embeddings: run BERT once for the whole batch, then merge subwords
        # into word nodes with a per-graph mean
        word_tokens = self.tokenizer(list(text), padding=True, truncation=True, return_tensors='pt').to(dev)
        hidden = self.bert(**word_tokens).last_hidden_state
        token_mask = word_tokens['attention_mask'].bool()
        word_embs = []
        for states, mask, al in zip(hidden, token_mask, alignments):
            word_emb = states[mask]  # drop padding -> same rows as an unpadded single-text call
            al = list(al)
            # CLS -> node 0, word w -> node w + 1, SEP -> last node
            node_idx = torch.tensor([-1] + al + [al[-1] + 1], dtype=torch.long, device=dev) + 1
            zero_emb = word_emb.new_zeros(int(node_idx[-1]) + 1, word_emb.shape[-1])
            word_embs.append(zero_emb.index_reduce(0, node_idx, word_emb, 'mean', include_self=False))
        word_embs = torch.cat(word_embs, dim=0)
        
        # pos embeddings: drop padded tokens, then flatten to one row per node
        pos_tokens = self.pos_tokenizer(pos, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt').to(dev)
        pos_emb = self.pos_bert(**pos_tokens).last_hidden_state
        pos_emb = pos_emb[pos_tokens['attention_mask'].bool()]
        if self.add_syllables:
            pos_emb = pos_emb + self.syllable_emb_layer(num_syllable)
        
        x = torch.cat([word_embs, pos_emb], dim=1)
        
        edge_attr = self.dep_emb_layer(edge_type_ids)
        
        # pre-LayerNorm residual GNN blocks
        for i in range(self.num_layers):
            x = self.gnns[i](self.layernorms[i](x), edge_index, edge_attr=edge_attr).relu() + x
            x = F.dropout(x, p=self.dropout, training=self.training)
        
        if readout == 'pool':
            x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch), global_min_pool(x, batch)], dim=1)
        else:  # 'cls': the first node of every graph
            x = x[ptr[:-1], :]
            
        x = F.dropout(x, p=self.dropout, training=self.training)

        logit = self.classifier(x)
        loss = self.lossfn(logit, y)
        return myGNNoutput(loss=loss, logit=logit, emb=x)

In [None]:
max_length = 256

epochs = 100
warmup_ratio = 0.15
freeze_bert = 11 # 10, 11
batch_size = 64
pos_bert_dim = 64
pos_checkpoint = hiddensize2checkpoint[pos_bert_dim]
num_classes = 50  # number of authors in CCAT50

LIMIT = [None]
NUM_LAYERS = [4]
LR = [1e-3]
HEADS = [4] 
READOUT = ['pool']
GNNTYPE = ['TransformerConv'] # 'GATConv', 'GATv2Conv', 'TransformerConv'
HIDDEN_DIM = [(768+64)//4] # not used 
DEP_EMB_DIM = [64]
NUM_SENT = [1,2,3]
ADD_SELF_LOOPS = [False]
ADD_SYLLABLES = [True]
REPEAT = list(range(1))


def load_split(num_sent_per_text, split):
    """Read one processed CCAT50 split and parse its literal-encoded columns."""
    frame = pd.read_csv(f'../../data/CCAT50/processed/author_all_sent_{num_sent_per_text}_{split}.csv')
    for col in cols_to_eval:
        frame[col] = frame[col].apply(ast.literal_eval)
    return frame


def run_model(model, data, add_syllables, readout):
    """Run the model on a PyG batch, passing syllable ids only when they are used."""
    return model(data.text, data.pos, data.alignments, data.edge_index, data.edge_type_ids,
                 data.batch, data.y, data.ptr,
                 data.num_syllables if add_syllables else None, readout=readout)


# materialise the grid once so it can be both counted and iterated
ARGS = list(itertools.product(LIMIT, NUM_LAYERS, LR, HEADS, READOUT, GNNTYPE, HIDDEN_DIM, DEP_EMB_DIM, NUM_SENT, ADD_SELF_LOOPS, ADD_SYLLABLES, REPEAT))
num_runs = len(ARGS)
run_pbar = trange(num_runs, leave=False)

skip_runs = -1  # runs with i_run <= skip_runs are skipped (e.g. to resume a sweep)
for i_run, args in enumerate(ARGS):

    if i_run <= skip_runs:
        run_pbar.update(1)
        continue
    limit, num_layers, lr, heads, readout, gnntype, hidden_dim, dep_emb_dim, num_sent_per_text, add_self_loops, add_syllables, repeat = args
    
    seed = int(datetime.now().timestamp())
    set_seed(seed)
    
    # split 0 is used for training, split 1 for validation
    df = load_split(num_sent_per_text, 0)
    df_val = load_split(num_sent_per_text, 1)
    val_docid2index = {doc_id: i for i, doc_id in enumerate(df_val['doc_id'].unique())}
    
    # order does not matter for the validation metrics, so it is not shuffled
    valid_loader = get_loader(df_val, add_syllables=add_syllables, max_length=max_length, batch_size=batch_size, shuffle=False)
    num_valid_steps = len(valid_loader)
    train_loader = get_loader(df, limit=limit, add_syllables=add_syllables, max_length=max_length, batch_size=batch_size)
    num_training_steps = len(train_loader)
    
    model = SemSynGNN(num_layers=num_layers,
                      num_classes=num_classes,
                      num_dep_type=len(relation2id),
                      heads=heads,
                      hidden_dim=hidden_dim,
                      dep_emb_dim=dep_emb_dim,
                      add_self_loops=add_self_loops,
                      gnntype=gnntype,
                      add_syllables=add_syllables,
                      pos_checkpoint=pos_checkpoint,
                      )
    
    model = model.to(device)
    model = freeze_model(model, freeze_bert)    
    
    # the word BERT is fine-tuned with a small lr (only its unfrozen parameters);
    # every other child module uses the swept lr
    para = []
    for name, module in model.named_children():
        if name == 'bert':
            para.append({"params": [p for p in module.parameters() if p.requires_grad], 'lr': 5e-5})
        else:
            para.append({"params": module.parameters(), 'lr': lr})

    optimizer = torch.optim.Adam(para)
    
    scheduler = get_scheduler("linear",
                              optimizer=optimizer,
                              num_warmup_steps=int(warmup_ratio*epochs*num_training_steps),
                              num_training_steps=epochs*num_training_steps)
    
    wconfig = {
        'seed': seed,
        'num_sent_per_text': num_sent_per_text,
        'limit': limit,
        'num_layers': num_layers,
        'lr': lr,
        'heads': heads,
        'readout': readout,
        'GNNtype': gnntype,
        'add_self_loops': add_self_loops,
        'add_syllables': add_syllables,
        'pos_checkpoint': pos_checkpoint,
        'dep_emb_dim': dep_emb_dim,
        'freeze_bert': freeze_bert,
    }
    
    run = wandb.init(project="SemSynGNN (all authors, bert unfrozen)", 
                     entity="fsu-dsc-cil", 
                     dir='/scratch/data_jz17d/wandb_tmp/', 
                     config=wconfig,
                     name=f'run_{i_run}',
                     reinit=True,
                     settings=wandb.Settings(start_method='thread'))
    
    # loaded once per run; compute() resets it for the next epoch
    metric = evaluate.load('/home/jz17d/Desktop/metrics/accuracy')
    all_labels = np.arange(num_classes)
    best_evaluation = collections.defaultdict(float)
    pbar = trange(epochs*num_training_steps, leave=False)
    for i_epoch in range(epochs):
        model.train()
        for data in train_loader:
            data.to(device)
            optimizer.zero_grad()
            output = run_model(model, data, add_syllables, readout)
            output.loss.backward()
            optimizer.step()
            scheduler.step()
            pbar.update(1)

        model.eval()
        # per-document counts of predicted authors over that document's samples
        doc_score = 1e-8*np.ones((len(val_docid2index), num_classes))
        doc_true = np.zeros(len(val_docid2index), dtype=np.int64)
        with torch.no_grad():  # no autograd graphs during evaluation
            for data in valid_loader:
                data.to(device)
                output = run_model(model, data, add_syllables, readout)
                pred = output.logit.argmax(dim=-1).cpu().numpy()
                metric.add_batch(predictions=pred, references=data.y.cpu().numpy())

                doc_id = np.vectorize(val_docid2index.get)(data.doc_id.cpu().numpy())
                # np.add.at counts repeated (doc, pred) pairs within one batch;
                # doc_score[doc_id, pred] += 1 would count each such pair once
                np.add.at(doc_score, (doc_id, pred), 1)
                doc_true[doc_id] = data.y.cpu().numpy()
        
        # logging; `labels` keeps doc-level top-k valid even if some author
        # does not occur among the validation documents
        evaluation = metric.compute()
        for k in range(1, 6):
            evaluation[f'doc_acc@{k}'] = top_k_accuracy_score(doc_true, doc_score, k=k, labels=all_labels)
        wandb.log(evaluation, step=pbar.n)
        
        # logging best
        for key in evaluation:
            best_evaluation[f'best_{key}'] = max(best_evaluation[f'best_{key}'], evaluation[key])
        wandb.log(best_evaluation, step=pbar.n)
    
    run.finish()
    run_pbar.update(1)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/8207 [00:00<?, ?it/s]

0 data dropped because of exceeding max_length 256


  0%|          | 0/32937 [00:00<?, ?it/s]

Some weights of the model checkpoint at /scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_22/checkpoint-95000/ were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


1 data dropped because of exceeding max_length 256


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'bert.pooler.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
[34m[1mwandb[0m: Currently logged in as: [33mcpuyyp[0m ([33mfsu-dsc-cil[0m). Use [1m`wandb lo

  0%|          | 0/51500 [00:00<?, ?it/s]

In [None]:
# Close the current wandb run by hand; presumably for when the sweep cell above
# was interrupted before reaching its own run.finish().
run.finish()

In [None]:
def count_grad_parameters(model):
    """Number of scalar parameters of ``model`` with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def count_all_parameters(model):
    """Number of scalar parameters of ``model``, trainable or not."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# (trainable, total) parameter counts of the whole model
count_grad_parameters(model),count_all_parameters(model)

(14810354, 109526258)

In [None]:
# (trainable, total) parameter counts of the word-level BERT only
count_grad_parameters(model.bert), count_all_parameters(model.bert)

(14175744, 108891648)

In [None]:
# trainable parameters in the stack of GNN layers
count_grad_parameters(model.gnns)

559104