In [2]:
import os; os.chdir('/home/duxin/code/readability')
import numpy as np
import pandas as pd
print(os.getcwd())

/misc/home/duxin/code/readability


In [3]:
from argparse import Namespace
from lib.syntax_augmented_bert.utils.loader import (
    FeaturizedDataset,
    FeaturizedDataLoader,
)
from lib.syntax_augmented_bert import model
from lib.syntax_augmented_bert.utils import (
    constant,
)
from lib.syntax_augmented_bert.utils.utils import (
    OntoNotesSRLProcessor,
)
from lib.syntax_augmented_bert.opt import get_args
from lib.syntax_augmented_bert.main import (
    load_and_cache_examples,
    MODEL_CLASSES,
)

In [4]:
opt = get_args([
    '--model_type=syntax_bert_seq',
    '--model_name_or_path=bert-base-uncased',
    '--task_name=ontonotes_srl',
    '--data_dir=data/dep/',
    '--max_seq_length=512',
    '--per_gpu_eval_batch_size=32',
    '--output_dir=results/syntax/checkpoints/1/'
    '--save_steps=1000', 
    '--overwrite_output_dir',
    '--num_train_epochs=20',
    '--do_eval',
    '--do_train',
    '--evaluate_during_training',
    '--config_name_or_path=lib/syntax_augmented_bert/config/srl/bert-base/joint_fusion.json',
    '--per_gpu_train_batch_size=16',
    '--gradient_accumulation_steps=1',
    '--wordpiece_aligned_dep_graph',
    '--seed=40',
])

In [5]:
config_class, model_class, tokenizer_class = MODEL_CLASSES[opt.model_type]
tokenizer = tokenizer_class.from_pretrained('bert-base-cased')
# train_dataset = load_and_cache_examples(opt, task='ontonotes_srl', tokenizer=tokenizer, split='train')

In [6]:
label_map = constant.OntoNotes_SRL_LABEL_TO_ID
num_labels = len(label_map)

In [13]:
examples=OntoNotesSRLProcessor().get_train_examples('data/dep/')

In [25]:
examples[-1].text_a

'animals are made of many cells . they eat things and digest them inside . most animals can move . only animals have brains ( though not even all animals do ; jelly ##fish , for example , do not have brains ) . animals are found all over the earth . they dig in the ground , swim in the oceans , and fly in the sky . humans are a type of animal . so are dogs , cats , cows , horses , frogs , fish , and so on and on . animals can be divided into two main groups , ve ##rte ##brates and invertebrates . ve ##rte ##brates can be further divided into mammals , fish , birds , reptiles , and amp ##hi ##bians . invertebrates can be divided into art ##hr ##op ##ods ( like insects , spiders , and crabs ) , mollusk ##s , sponge ##s , several different kinds of worms , jelly ##fish — and quite a few other subgroup ##s . there are at least thirty kinds of invertebrates , compared to the five kinds of ve ##rte ##brates . ve ##rte ##brates have a backbone , while invertebrates do not .'

In [7]:
train_dataset = FeaturizedDataset(
    examples=OntoNotesSRLProcessor().get_train_examples('data/dep/'),
    opt=opt,
    tokenizer=tokenizer,
    label_map=label_map,
    cls_token_segment_id=0,
    pad_token_segment_id=0,
)

In [27]:
def load_data(path):
    df = pd.read_csv(path, index_col=0)
    df['excerpt'] = df['excerpt'].str.lower()
    df = df.loc[df['standard_error'] > 0]
    return df

def load_syntax_data(path):
    examples = OntoNotesSRLProcessor().get_unk_examples(path)
    return examples
data = load_data('data/commonlitreadabilityprize/train.csv')
syntax_data = load_syntax_data('data/dep/train.json')

In [8]:
config = model.SyntaxBertConfig.from_pretrained('lib/syntax_augmented_bert/config/srl/bert-base-uncased/joint_fusion.json',
                                                num_labels=num_labels,
                                                finetuning_task='ontonotes_srl')
syntbert = model.SyntaxBertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                                     config=config)

In [9]:
opt.device = 'cuda:0'
opt.task_name = 'readability'
train_dataloader = FeaturizedDataLoader(train_dataset, opt, batch_size=32)

In [10]:
it = iter(train_dataloader)
batch = next(it)

In [11]:
batch.keys()

odict_keys(['input_tokens', 'input_ids', 'wp_token_mask', 'token_type_ids', 'dep_head', 'dep_rel', 'wp_rows', 'align_sizes', 'seq_len', 'subj_pos', 'obj_pos', 'verb_index', 'linguistic_token_mask'])

In [34]:
syntbert.to('cuda:0')
bert = syntbert.bert

In [35]:
outputs = bert(input_ids=batch['input_ids'],
     token_type_ids=batch['token_type_ids'],
     wp_token_mask=batch['wp_token_mask'],
     dep_head=batch['dep_head'],
     dep_rel=batch['dep_rel'],
     wp_rows=batch['wp_rows'],
     align_sizes=batch['align_sizes'],
     seq_len=batch['seq_len'],
     subj_pos=batch['subj_pos'],
     obj_pos=batch['obj_pos'],
     linguistic_token_mask=batch['linguistic_token_mask'],
     output_hidden_states=True)

In [39]:
outputs[2]

(tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
          [ 0.4672, -0.1002,  0.1676,  ..., -0.1317,  1.4961,  0.0407],
          [ 1.1783, -0.4477, -0.6944,  ...,  0.6610, -0.8734,  0.2244],
          ...,
          [ 0.2824, -0.4747,  0.4477,  ..., -0.4223, -0.6601,  0.0719],
          [-0.0974, -0.4676,  0.7018,  ..., -0.6491, -0.3163,  0.0398],
          [ 0.5101, -0.5235,  0.2857,  ..., -0.3723, -0.8822, -0.0276]],
 
         [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
          [ 0.2556,  0.1523,  0.2624,  ...,  0.7252,  0.7743, -0.3568],
          [-0.8482,  0.2862, -1.0584,  ...,  0.3248,  0.4633, -0.2949],
          ...,
          [ 0.2759, -0.3649,  0.4321,  ..., -0.2785, -0.6606, -0.0218],
          [-0.6380,  0.1404,  0.8850,  ..., -1.1050,  0.0876, -0.2460],
          [ 0.5101, -0.5235,  0.2857,  ..., -0.3723, -0.8822, -0.0276]],
 
         [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
          [-0.6811, -0.0508,

In [16]:
import json

In [17]:
fr = open('data/dep/train.json', 'rt')

In [None]:
len(list(fr))

In [18]:
linej = json.loads(next(fr))
linej

{'sentence_id': 'c12129c31',
 'tokens': ['when',
  'the',
  'young',
  'people',
  'returned',
  'to',
  'the',
  'ballroom',
  ',',
  'it',
  'presented',
  'a',
  'decidedly',
  'changed',
  'appearance',
  '.',
  'instead',
  'of',
  'an',
  'interior',
  'scene',
  ',',
  'it',
  'was',
  'a',
  'winter',
  'landscape',
  '.',
  '\n',
  'the',
  'floor',
  'was',
  'covered',
  'with',
  'snow',
  '-',
  'white',
  'canvas',
  ',',
  'not',
  'laid',
  'on',
  'smoothly',
  ',',
  'but',
  'rumpled',
  'over',
  'bumps',
  'and',
  'hillocks',
  ',',
  'like',
  'a',
  'real',
  'snow',
  'field',
  '.',
  'the',
  'numerous',
  'palms',
  'and',
  'evergreens',
  'that',
  'had',
  'decorated',
  'the',
  'room',
  ',',
  'were',
  'powdered',
  'with',
  'flour',
  'and',
  'strewn',
  'with',
  'tufts',
  'of',
  'cotton',
  ',',
  'like',
  'snow',
  '.',
  'also',
  'diamond',
  'dust',
  'had',
  'been',
  'lightly',
  'sprinkled',
  'on',
  'them',
  ',',
  'and',
  'glitter

In [27]:
fr = open('data/synt/train.old.json', 'rt')
fw = open('data/synt/train.json', 'wt')

for line in fr:
    linej = json.loads(line)
    linej['ontonotes_deprel'] = [x.lower() for x in linej['dep_label']]
    del linej['dep_label']
    linej['ontonotes_head'] = linej['dep_head'].copy()
    del linej['dep_head']

    fw.write(json.dumps(linej)+'\n')

In [17]:
fr = open('data/synt/train.old.json', 'rt')
linej = json.loads(next(fr))

In [21]:
fr = open('data/synt/train.json', 'rt')
for i, line in enumerate(fr):
    linej = json.loads(line)

    for tok in linej['tags'][0]:
        # if tok not in constant.DEPREL_TO_ID:
        if tok not in constant.OntoNotes_SRL_LABEL_TO_ID:
            print(tok)

I-V
I-V
I-V


In [19]:
line.keys()

dict_keys(['sentence_id', 'tokens', 'pos_tags', 'verb_indicator', 'dep_head', 'dep_label', 'tags', 'metadata'])

In [29]:
line['dep_head']

[5,
 4,
 4,
 5,
 11,
 5,
 8,
 6,
 11,
 11,
 0,
 15,
 14,
 15,
 11,
 11,
 18,
 24,
 21,
 21,
 18,
 24,
 24,
 0,
 27,
 27,
 24,
 24,
 30,
 32,
 32,
 0,
 32,
 36,
 36,
 37,
 33,
 32,
 40,
 32,
 40,
 40,
 40,
 40,
 40,
 45,
 46,
 47,
 47,
 45,
 45,
 55,
 55,
 55,
 51,
 32,
 59,
 59,
 69,
 59,
 59,
 64,
 64,
 59,
 66,
 64,
 69,
 69,
 0,
 69,
 70,
 69,
 69,
 73,
 74,
 75,
 76,
 73,
 73,
 79,
 69,
 88,
 84,
 88,
 88,
 88,
 88,
 0,
 88,
 89,
 88,
 88,
 96,
 95,
 93,
 88,
 96,
 99,
 97,
 96,
 112,
 103,
 101,
 103,
 106,
 104,
 112,
 112,
 110,
 108,
 112,
 0,
 118,
 118,
 117,
 117,
 118,
 112,
 112,
 121,
 122,
 0,
 122,
 123,
 124,
 124,
 126,
 129,
 127,
 126,
 126,
 131,
 134,
 132,
 122,
 138,
 138,
 0,
 140,
 138,
 138,
 143,
 145,
 145,
 0,
 145,
 148,
 146,
 148,
 151,
 149,
 145,
 154,
 145,
 154,
 157,
 155,
 145,
 145,
 161,
 163,
 163,
 145,
 166,
 166,
 163,
 163,
 163,
 173,
 173,
 173,
 173,
 163,
 163,
 176,
 178,
 178,
 0,
 181,
 181,
 178,
 186,
 184,
 181,
 186,
 184,
 186,
