In [1]:
from tqdm.auto import trange, tqdm
import pandas as pd
import ast
import itertools
from itertools import groupby

import wandb
import evaluate
from itertools import cycle
import numpy as np
import random
import time

from datetime import datetime
import collections
from sklearn.metrics import top_k_accuracy_score

In [2]:
import torch_geometric.transforms as T
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv, PDNConv, global_mean_pool, global_max_pool
import torch_geometric as pyg

import torch
import torch.nn as nn
import torch.nn.functional as F


import transformers
from transformers import get_scheduler
from transformers import AutoTokenizer
from transformers.models.bert.modeling_bert import BertModel


In [3]:
from typing import Optional

from torch import Tensor

from torch_geometric.utils import scatter

In [4]:
# Display the installed transformers and PyTorch Geometric versions as the cell output.
transformers.__version__, pyg.__version__

('4.26.0', '2.2.0')

In [5]:
# Pick the compute device: CUDA when torch reports it as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [6]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed`` (in that order).
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

In [7]:
from dataclasses import dataclass
@dataclass
class myGNNoutput:
    """Bundle of values returned by ``HierGNN.forward``."""
    # Cross-entropy loss computed from `logit` and `doc_y`.
    loss: None
    # Classifier output, one row per document in the batch.
    logit: None
    # Pooled document embedding that is fed to the classifier.
    emb: None
    # Label per document, copied from the document's first sentence.
    doc_y: None

In [8]:
def global_min_pool(x: Tensor, batch: Optional[Tensor],
                    size: Optional[int] = None) -> Tensor:
    r"""Returns batch-wise graph-level-outputs by taking the channel-wise
    minimum across the node dimension, so that for a single graph
    :math:`\mathcal{G}_i` its output is computed by

    .. math::
        \mathbf{r}_i = \mathrm{min}_{n=1}^{N_i} \, \mathbf{x}_n.

    Counterpart of :func:`global_max_pool` / :func:`global_mean_pool` that
    reduces with ``min`` instead.

    Args:
        x (torch.Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (torch.Tensor, optional): The batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
            each element to a specific example. If :obj:`None`, all nodes are
            treated as belonging to a single graph.
        size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)

    Returns:
        torch.Tensor: The per-graph channel-wise minimum.
    """
    # Nodes live on the last axis for 1-D input, on the second-to-last otherwise.
    dim = -1 if x.dim() == 1 else -2

    if batch is None:
        # Single-graph case: reduce over every node with `min`
        # (the previous implementation returned the maximum here).
        return x.min(dim=dim, keepdim=x.dim() <= 2)[0]
    size = int(batch.max().item() + 1) if size is None else size
    return scatter(x, batch, dim=dim, dim_size=size, reduce='min')


In [9]:
# English-specific dependency relations: https://universaldependencies.org/en/dep/
s = '''nsubj 	csubj
↳nsubj:pass 	↳csubj:pass
↳nsubj:outer 	↳csubj:outer
obj 	ccomp 	xcomp
iobj
obl 	advcl 	advmod
↳obl:npmod 	↳advcl:relcl
↳obl:tmod
vocative 	aux 	mark
discourse 	↳aux:pass
expl 	cop
nummod 	acl 	amod
  	↳acl:relcl
appos 	  	det
  	  	↳det:predet
nmod 	  	 
↳nmod:npmod
↳nmod:tmod
↳nmod:poss
compound 	flat
↳compound:prt 	↳flat:foreign
fixed 	goeswith
conj 	cc
  	↳cc:preconj
list 	parataxis 	orphan
dislocated 		reparandum
root 	punct 	dep'''
# Turn the tab-separated table in `s` into a flat list of base relation names.
all_relations = []
for row in s.split('\n'):
    for cell in row.split('\t'):
        # Skip padding cells and subtype entries (marked with '↳'). Filtering per
        # cell rather than per row keeps base relations that share a row with a
        # subtype, e.g. `discourse` next to `↳aux:pass`.
        if cell.strip() == '' or '↳' in cell:
            continue
        all_relations.append(cell.split(':')[0].strip())

# The id lookup uses the upper-case spelling 'ROOT' instead of the table's 'root'.
if 'root' in all_relations:
    all_relations.remove('root')
    all_relations.append('ROOT')

# Relations that must be present even if the table above does not provide them.
for extra_relation in ('case', 'discourse'):
    if extra_relation not in all_relations:
        all_relations.append(extra_relation)

all_relations = sorted(all_relations)

In [10]:
# Map each relation name to its position in the sorted `all_relations` list.
relation2id = {relation: idx for idx, relation in enumerate(all_relations)}
# Extra edge type for self loops. It takes the next free id, so it can never
# collide with a real relation (this is 36 for the current 36-entry list).
relation2id['self'] = len(relation2id)

In [11]:
# Columns that the training cell parses with `ast.literal_eval` after reading the CSV files.
cols_to_eval = [
    'edge_indexs',
    'hetoro_edges',
    'pos_seqs',
    'upos_seqs',
    'num_syllables',
    'alignments',
]

In [12]:
from collections.abc import Mapping
from typing import List, Optional, Sequence, Union

import torch.utils.data
from torch.utils.data.dataloader import default_collate

from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData

class ParagraphBatchCollater:
    """Collate function for batches made of paragraphs (lists of samples).

    Each dataset item is a list of samples (one per sentence). The collater
    flattens the batch by one level and then merges the flat list the same
    way the PyG collater does.

    Args:
        follow_batch: Forwarded to ``Batch.from_data_list``.
        exclude_keys: Forwarded to ``Batch.from_data_list``.
    """
    def __init__(self, follow_batch, exclude_keys):
        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

    def __call__(self, batch):
        """Flatten a list of paragraphs and collate the resulting samples.

        ``batch`` is a list of lists of samples, e.g.::

            [[Data(), Data(), Data()],   # paragraph 1
             [Data(), Data()]]           # paragraph 2

        Raises:
            ValueError: If the flattened batch contains no samples.
            TypeError: If the sample type is not supported.
        """
        flat = [item for paragraph in batch for item in paragraph]
        if not flat:
            raise ValueError('ParagraphBatchCollater received a batch without any samples')
        return self._collate(flat)

    def _collate(self, batch):
        """Merge an already-flat list of samples, dispatching on the sample type.

        Nested containers recurse through this method rather than
        ``__call__``, so the paragraph flattening is applied exactly once.
        """
        elem = batch[0]
        if isinstance(elem, BaseData):
            return Batch.from_data_list(batch, self.follow_batch,
                                        self.exclude_keys)
        elif isinstance(elem, torch.Tensor):
            return default_collate(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, str):
            return batch
        elif isinstance(elem, Mapping):
            return {key: self._collate([data[key] for data in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            # namedtuple: collate field-wise and rebuild the same type
            return type(elem)(*(self._collate(list(s)) for s in zip(*batch)))
        elif isinstance(elem, Sequence) and not isinstance(elem, str):
            return [self._collate(list(s)) for s in zip(*batch)]

        raise TypeError(f'DataLoader found invalid type: {type(elem)}')

    def collate(self, batch):  # pragma: no cover
        # TODO Deprecated, remove soon.
        return self(batch)

In [13]:
# from pyg source code
class ParagraphDataLoader(torch.utils.data.DataLoader):
    r"""Data loader whose dataset items are paragraphs, i.e. lists of data
    objects, merged into one mini-batch by :class:`ParagraphBatchCollater`.

    Args:
        dataset (Dataset): The dataset from which to load the data.
        batch_size (int, optional): How many dataset items per batch to load.
            (default: :obj:`1`)
        shuffle (bool, optional): If set to :obj:`True`, the data will be
            reshuffled at every epoch. (default: :obj:`False`)
        follow_batch (List[str], optional): Creates assignment batch
            vectors for each key in the list. (default: :obj:`None`)
        exclude_keys (List[str], optional): Will exclude each key in the
            list. (default: :obj:`None`)
        **kwargs (optional): Additional arguments of
            :class:`torch.utils.data.DataLoader`. A ``collate_fn`` entry is
            discarded because the paragraph collater is always used.
    """
    def __init__(
        self,
        dataset: Union[Dataset, Sequence[BaseData]],
        batch_size: int = 1,
        shuffle: bool = False,
        follow_batch: Optional[List[str]] = None,
        exclude_keys: Optional[List[str]] = None,
        **kwargs,
    ):
        # The collate function is fixed below, so drop any caller-supplied one.
        kwargs.pop('collate_fn', None)

        # Save for PyTorch Lightning < 1.6:
        self.follow_batch, self.exclude_keys = follow_batch, exclude_keys

        paragraph_collater = ParagraphBatchCollater(follow_batch, exclude_keys)
        super().__init__(dataset, batch_size, shuffle,
                         collate_fn=paragraph_collater, **kwargs)

In [14]:
def get_loader(df, add_syllables=False, col='pos_seqs', limit=None, batch_size=32, shuffle=True, max_length=128):
    """Build a ``ParagraphDataLoader`` with one list of sentence graphs per row of ``df``.

    Every row contributes one ``Data`` object per entry of its ``'edge_indexs'``
    list. Each graph gets a self loop on node 0 (the CLS token) in front of the
    row's edges and a self loop on the last node (the SEP token) after them.

    Args:
        df: DataFrame whose rows hold the list-valued columns ``'edge_indexs'``,
            ``'hetoro_edges'`` and ``col``, plus ``'author'``. ``'num_syllables'``
            is read when ``add_syllables`` is true, ``'doc_id'`` when present.
        add_syllables: If true, attach ``num_syllables`` to every graph.
        col: Column whose token lists are joined with spaces into ``data.text``.
        limit: If not ``None``, shuffle the rows and keep only the first ``limit``.
        batch_size: Number of rows (not sentences) per mini-batch.
        shuffle: Whether the loader reshuffles the rows every epoch.
        max_length: Graphs with ``max_length - 1`` or more edges (including the
            two self loops) are dropped.

    Returns:
        A ``ParagraphDataLoader`` over the per-row lists of graphs.
    """
    if limit is not None:
        dfnew = df.sample(frac=1).reset_index(drop=True)[:limit]
    else:
        dfnew = df

    self_loop_id = relation2id['self']  # edge type of the CLS/SEP self loops
    # Value placed at the CLS/SEP positions of `num_syllables`.
    # NOTE(review): presumably a padding id; HierGNN sizes its syllable embedding to 18 entries.
    boundary_syllables = 17

    data_lists = []
    count = 0

    for i in trange(len(dfnew), leave=False):
        # Read from `dfnew` so that the shuffled `limit` subset is actually used
        # (indexing `df` here silently ignored the sampling above).
        curr = dfnew.iloc[i]
        data_list = []
        for j in range(len(curr['edge_indexs'])):
            sent_edges = curr['edge_indexs'][j]
            sep_node = len(sent_edges) + 1
            data = Data()
            data.edge_index = torch.cat([torch.tensor([[0], [0]]),  # self loop of the CLS token
                                         torch.tensor(sent_edges).T,
                                         # for batching purposes: if data.x is missing, edge_index is used to
                                         # infer the batch, and an isolated node (the SEP in this case) would
                                         # break that, so SEP gets a self loop too
                                         torch.tensor([[sep_node], [sep_node]])],
                                        dim=1)
            data.edge_type_ids = torch.tensor(
                [self_loop_id]
                + [relation2id[t.split(':')[0]] for t in curr['hetoro_edges'][j]]
                + [self_loop_id])
            if data.edge_index.shape[1] >= max_length-1:
                count += 1
                continue

            data.text = ' '.join(curr[col][j])
            data.y = torch.tensor([curr['author']])
            if add_syllables:
                data.num_syllables = torch.tensor(
                    [boundary_syllables] + curr['num_syllables'][j] + [boundary_syllables])

            if 'doc_id' in curr:
                data.doc_id = torch.tensor([curr['doc_id']])

            # one node per edge entry (the two self loops included)
            data.num_nodes = len(data.edge_type_ids)
            data_list.append(data)
        data_lists.append(data_list)

    print(f'{count} data dropped because of exceeding max_length {max_length}')
    loader = ParagraphDataLoader(data_lists, batch_size=batch_size, shuffle=shuffle)
    return loader


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv, PDNConv, global_mean_pool

# Lookup from the `gnntype` string accepted by HierGNN to the PyG convolution class.
GNNtype2layer = dict(
    GATConv=GATConv,
    GATv2Conv=GATv2Conv,
    TransformerConv=TransformerConv,
    PDNConv=PDNConv,
)

class HierGNN(torch.nn.Module):
    """Hierarchical GNN classifier over sentences grouped into documents.

    Pipeline implemented by ``forward``:

    1. The sentences are tokenized and encoded with the BERT model loaded
       from ``checkpoint``; padded positions are dropped, so every remaining
       token becomes one graph node.
    2. ``num_layers`` graph convolutions (type chosen by ``gnntype``) run over
       the per-sentence graphs, using learned edge-type embeddings as edge
       attributes.
    3. Node features are read out into one vector per sentence.
    4. Two ``TransformerConv`` layers run over a graph that fully connects
       (self loops included) the sentences sharing a ``doc_id``.
    5. Sentence vectors are pooled (mean, max, min) per document and classified.

    Args:
        num_layers: Number of sentence-level graph convolution layers.
        num_classes: Number of output classes.
        num_dep_type: Size of the edge-type embedding table.
        heads: Number of attention heads in every convolution.
        hidden_dim: Stored but not used by the model.
        num_hier_layers: Stored but not used; two hierarchical layers are
            always built.
        dep_emb_dim: Dimension of the edge-type embeddings.
        add_self_loops: Forwarded to the sentence-level convolutions.
        gnntype: Key into ``GNNtype2layer``.
        add_syllables: If truthy, a syllable-count embedding is added to the
            token embeddings.
        checkpoint: Local directory holding the tokenizer and BERT weights.
        max_length: Truncation length used by the tokenizer.
        dropout: Dropout probability applied after each layer.
    """
    def __init__(self,
                 num_layers,
                 num_classes,
                 num_dep_type,
                 heads,
                 hidden_dim,
                 num_hier_layers=1,
                 dep_emb_dim=32,
                 add_self_loops=False,
                 gnntype='GATConv',
                 add_syllables=None,
                 checkpoint='/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_256_pos_mlm_0_recovered/',
                 max_length=256,
                 dropout=0.1):

        super().__init__()
        self.checkpoint = checkpoint
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint, local_files_only=True)
        self.bert = BertModel.from_pretrained(self.checkpoint, local_files_only=True, add_pooling_layer = False).to(device)
        self.num_layers = num_layers
        self.num_hier_layers = num_hier_layers
        self.num_classes = num_classes
        # The node feature width is whatever the loaded BERT produces. Reading it
        # from the config (instead of hard-coding 64) keeps the model usable with
        # checkpoints of other hidden sizes.
        self.pos_emb_dim = self.bert.config.hidden_size
        self.heads = heads
        self.hidden_dim = hidden_dim
        self.dep_emb_dim = dep_emb_dim
        self.add_syllables = add_syllables
        self.dropout=dropout

        if add_syllables:
            self.num_syllables = 18 # the longest word has 17 syllables
            self.syllable_emb_layer = nn.Embedding(self.num_syllables, self.pos_emb_dim)

        self.GNNlayer = GNNtype2layer[gnntype]

        self.add_self_loops = add_self_loops
        self.dep_emb_layer = nn.Embedding(num_dep_type, self.dep_emb_dim)

        # sentence-level graph convolutions; the heads are concatenated back to pos_emb_dim
        self.gnns = nn.ModuleList()
        for i in range(self.num_layers):
            self.gnns.append(self.GNNlayer(self.pos_emb_dim, self.pos_emb_dim//self.heads, heads = self.heads, add_self_loops=self.add_self_loops, edge_dim=self.dep_emb_dim, beta=True))

        # layernorms applied to the input of each gnn above
        self.layernorms = nn.ModuleList()
        for i in range(self.num_layers):
            self.layernorms.append(nn.LayerNorm(self.pos_emb_dim))

        # hierarchical (sentence-to-sentence) layers; the width is 3*pos_emb_dim
        # because the 'pool' readout concatenates mean, max and min pooling
        self.transformer_layer1 = TransformerConv(3*self.pos_emb_dim, 3*self.pos_emb_dim//self.heads, heads=self.heads, beta=True)
        self.layer_norm1 = nn.LayerNorm(3*self.pos_emb_dim)

        self.transformer_layer2 = TransformerConv(3*self.pos_emb_dim, 3*self.pos_emb_dim//self.heads, heads=self.heads, beta=True)
        self.layer_norm2 = nn.LayerNorm(3*self.pos_emb_dim)

        # document readout concatenates three poolings of the 3*pos_emb_dim sentence vectors
        self.classifier = nn.Linear(9*self.pos_emb_dim, self.num_classes)
        self.lossfn = nn.CrossEntropyLoss()

    def forward(self, text, edge_index, edge_type_ids, batch, y, doc_id, ptr, num_syllable=None, readout='pool'):
        """Classify the documents contained in one mini-batch.

        Args:
            text: List of sentence strings, one per graph in the batch.
            edge_index: Edge indices of the batched sentence graphs.
            edge_type_ids: Edge-type id per edge, looked up in the edge embedding.
            batch: Graph assignment per node (used for pooling).
            y: Label per sentence.
            doc_id: Document id per sentence; sentences with equal ids form one
                document.
            ptr: Node offsets of the graphs; used by the ``'cls'`` readout.
            num_syllable: Syllable id per node; required when the model was
                built with ``add_syllables``.
            readout: ``'pool'`` (mean/max/min pooling of the nodes) or ``'cls'``
                (first node of every graph).

        Returns:
            ``myGNNoutput`` with the loss, the logits, the document embeddings
            and the per-document labels.

        Raises:
            ValueError: If ``readout`` is unknown, or if ``num_syllable`` is
                missing while syllable embeddings are enabled.
        """
        if readout not in ('pool', 'cls'):
            raise ValueError(f"readout must be 'pool' or 'cls', got {readout!r}")
        if self.add_syllables and num_syllable is None:
            raise ValueError('num_syllable is required when the model was built with add_syllables')

        tokens = self.tokenizer(text, padding=True, truncation=True, max_length=self.max_length, return_tensors='pt').to(device)
        x = self.bert(**tokens).last_hidden_state
        # reshape! drop padded tokens, leaving one row per real token
        x = x.masked_select(tokens.attention_mask.ge(0.5).unsqueeze(2)).reshape((-1,self.pos_emb_dim))

        if self.add_syllables:
            x = x + self.syllable_emb_layer(num_syllable)

        edge_attr = self.dep_emb_layer(edge_type_ids)
        for i, (gnn, norm) in enumerate(zip(self.gnns, self.layernorms)):
            h = gnn(norm(x), edge_index, edge_attr=edge_attr).relu()
            # residual connection on every layer except the first
            x = h if i == 0 else h + x
            x = F.dropout(x, p=self.dropout, training=self.training)

        # readout to get sentence embeddings
        if readout == 'pool':
            x = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch), global_min_pool(x, batch)], dim=1)
        else:
            # NOTE(review): this yields pos_emb_dim channels per sentence, while the
            # hierarchical layers below are built for 3*pos_emb_dim, so the 'cls'
            # path looks unusable as written; confirm before relying on it.
            x = x[ptr[:-1],:]

        # Fully connect (self loops included) the sentences of each document.
        # Documents are taken as runs of equal consecutive doc_id values.
        doc_y = torch.zeros(doc_id.unique().shape[0], device=device).long()
        doc_assignments = []
        edge_blocks = []
        for ii, key in enumerate(key for key, _ in groupby(doc_id.tolist())):
            idx = (doc_id == key).nonzero().long().squeeze(1)
            doc_y[ii] = y[idx[0]]  # the document label is the label of its first sentence
            n_sent = idx.numel()
            doc_assignments.append(torch.full((n_sent,), ii, dtype=torch.long, device=device))
            # all ordered pairs (source, target) of sentences within the document
            edge_blocks.append(torch.stack([idx.repeat_interleave(n_sent), idx.repeat(n_sent)]))
        edges_among_sentences = torch.cat(edge_blocks, dim=1)
        batch_doc = torch.cat(doc_assignments)

        x = self.transformer_layer1(self.layer_norm1(x), edges_among_sentences).relu() + x
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.transformer_layer2(self.layer_norm2(x), edges_among_sentences).relu() + x
        x = F.dropout(x, p=self.dropout, training=self.training)

        # readout to get doc embeddings
        x = torch.cat([global_mean_pool(x, batch_doc), global_max_pool(x, batch_doc), global_min_pool(x, batch_doc)], dim=1)

        x = F.dropout(x, p=self.dropout, training=self.training)

        logit = self.classifier(x)
        loss = self.lossfn(logit, doc_y)
        return myGNNoutput(loss=loss, logit=logit, emb=x, doc_y=doc_y)

In [16]:
# Checkpoint directory of the POS-MLM BERT model, keyed by hidden size.
hiddensize2checkpoint = dict([
    (64, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_22/checkpoint-95000/"),
    (48, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_16/checkpoint-95000/"),
    (32, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_10/checkpoint-145000/"),
    (16, "/scratch/data_jz17d/result/pos_mlm_corenlp/retrained_all_pos_mlm_5/checkpoint-95000/"),
])

In [None]:
# ---- Experiment configuration ----
max_length = 256

epochs = 200
warmup_ratio = 0.15
monitering_metric = 'accuracy'


pos_hidden_dim = 64
checkpoint = hiddensize2checkpoint[pos_hidden_dim]

# Hyper-parameter grid: one training run per element of the cartesian product.
LIMIT = [None]
NUM_LAYERS = [4]
LR = [1e-3]
HEADS = [4]
READOUT = ['pool']
GNNTYPE = ['TransformerConv'] # 'GATConv', 'GATv2Conv',
HIDDEN_DIM = [64] # not used
DEP_EMB_DIM = [64]
NUM_SENT = [3,2,1]
ADD_SELF_LOOPS = [False]
ADD_SYLLABLES = [True, False]
REPEAT = list(range(1))

# Materialise the grid once so it can be both counted and iterated.
ARGS = list(itertools.product(LIMIT, NUM_LAYERS, LR, HEADS, READOUT, GNNTYPE, HIDDEN_DIM, DEP_EMB_DIM, NUM_SENT, ADD_SELF_LOOPS, ADD_SYLLABLES, REPEAT))
num_runs = len(ARGS)
run_pbar = trange(num_runs, leave=False)

# Runs with index <= skip_runs are skipped (e.g. to resume an interrupted sweep).
skip_runs = -1


def _load_split(path):
    """Read one processed CSV split.

    Returns the sentence-level frame (list-like columns parsed with
    ``ast.literal_eval``) and the frame aggregated to one row per ``doc_id``.
    """
    sentences = pd.read_csv(path)
    for col in cols_to_eval:
        sentences[col] = sentences[col].apply(ast.literal_eval)
    docs = sentences.groupby('doc_id').agg({'text': list,
                                            'author': lambda x: x.iloc[0],
                                            'edge_indexs': list,
                                            'hetoro_edges': list,
                                            'pos_seqs': list,
                                            'upos_seqs': list,
                                            'num_syllables': list}).reset_index()
    return sentences, docs


def _run_model(model, data, add_syllables, readout):
    """Run the model on one batch, passing syllable ids only when they are enabled."""
    num_syllable = data.num_syllables if add_syllables else None
    return model(data.text, data.edge_index, data.edge_type_ids, data.batch, data.y,
                 data.doc_id, data.ptr, num_syllable, readout=readout)


for i_run, args in enumerate(ARGS):

    if i_run <= skip_runs:
        run_pbar.update(1)
        continue
    limit, num_layers, lr, heads, readout, gnntype, hidden_dim, dep_emb_dim, num_sent_per_text, add_self_loops, add_syllables, repeat = args

    # Every run gets a fresh time-based seed; it is logged to W&B below.
    seed = int(datetime.now().timestamp())
    set_seed(seed)

    file = f'../../data/CCAT50/processed/author_all_sent_{num_sent_per_text}_0.csv'
    df, df_doc_train = _load_split(file)

    file = f'../../data/CCAT50/processed/author_all_sent_{num_sent_per_text}_1.csv'
    df_val, df_doc_val = _load_split(file)

    val_docid2index = {doc_id:i for i,doc_id in enumerate(df_val['doc_id'].unique())}

    # `limit` is part of the logged config, so it is applied to the training set.
    train_loader = get_loader(df_doc_train, add_syllables=add_syllables, limit=limit, batch_size=4, shuffle=True, max_length=max_length)
    num_training_steps = len(train_loader)
    # Accuracy does not depend on the batch order, so validation is not shuffled.
    valid_loader = get_loader(df_doc_val, add_syllables=add_syllables, batch_size=4, shuffle=False, max_length=max_length)
    num_valid_steps = len(valid_loader)

    model = HierGNN(num_layers=num_layers,
                    num_classes=50,
                    num_dep_type=len(relation2id),
                    heads=heads,
                    hidden_dim=hidden_dim,
                    dep_emb_dim=dep_emb_dim,
                    add_self_loops=add_self_loops,
                    gnntype=gnntype,
                    add_syllables=add_syllables,
                    checkpoint=checkpoint,
                    max_length=max_length,
                   )

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Linear warmup over the first `warmup_ratio` of all steps, then linear decay.
    scheduler = get_scheduler("linear",
                            optimizer=optimizer,
                            num_warmup_steps=int(warmup_ratio*epochs*num_training_steps),
                            num_training_steps=epochs*num_training_steps)

    wconfig = {}
    wconfig['seed'] = seed
    wconfig['num_sent_per_text'] = num_sent_per_text
    wconfig['limit'] = limit
    wconfig['num_layers'] = num_layers
    wconfig['lr'] = lr
    wconfig['heads'] = heads
    wconfig['readout'] = readout
    wconfig['GNNtype'] = gnntype
    wconfig['add_self_loops'] = add_self_loops
    wconfig['add_syllables'] = add_syllables
    wconfig['pooling_method'] = 'mean+max+min'
    wconfig['checkpoint'] = checkpoint

    run = wandb.init(project="hierarchical POS GNN",
                     entity="fsu-dsc-cil",
                     dir='/scratch/data_jz17d/wandb_tmp/',
                     config=wconfig,
                     name=f'run_{i_run}',
                     reinit=True,
                     settings=wandb.Settings(start_method='thread'))

    # Loaded once per run; add_batch/compute are then called once per epoch.
    metric = evaluate.load('/home/jz17d/Desktop/metrics/accuracy')
    best_evaluation = collections.defaultdict(float)
    pbar = trange(epochs*num_training_steps, leave=False)
    try:
        for i_epoch in range(epochs):
            model.train()
            for data in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                output = _run_model(model, data, add_syllables, readout)
                output.loss.backward()
                optimizer.step()
                scheduler.step()
                pbar.update(1)

            model.eval()
            # No gradients are needed for validation; skipping them saves memory and time.
            with torch.no_grad():
                for data in valid_loader:
                    data = data.to(device)
                    output = _run_model(model, data, add_syllables, readout)
                    metric.add_batch(predictions=output.logit.argmax(axis=-1).cpu().numpy(),
                                     references=output.doc_y.cpu().numpy())

            # logging
            evaluation = metric.compute()
            wandb.log(evaluation, step=pbar.n)

            # logging best
            for key in evaluation:
                best_evaluation[f'best_{key}'] = max(best_evaluation[f'best_{key}'], evaluation[key])
            wandb.log(best_evaluation, step=pbar.n)
    finally:
        # Close the progress bar and the W&B run even if training fails or is interrupted.
        pbar.close()
        run.finish()
    run_pbar.update(1)

  0%|          | 0/6 [00:00<?, ?it/s]

In [20]:
# Finish the W&B run currently referenced by `run`.
# NOTE(review): presumably executed by hand after the training cell above was
# interrupted before reaching its own `run.finish()` -- confirm whether it is still needed.
run.finish()