In [None]:
# Fetch the project sources and switch into the repository root so that the
# `src.*` modules imported further down are importable from the working directory.
!git clone https://github.com/leonardozilli/FrameTrigger.git
%cd FrameTrigger/

In [1]:
!pip install transformers[torch] datasets -q

In [None]:
from src.fn17 import load_dataset_hf, load_dataset_nltk
from src.train import train, test
from src.predict import predict_triggers
from src.process_data import prepare_data

In [2]:
# Experiment configuration shared by the data-preparation, training and
# evaluation cells below.
CHECKPOINT = "bert-base-cased"  # pretrained model identifier handed to prepare_data / train
TASK = "frames"                 # forwarded as `task=` to prepare_data, train and test

In [3]:
# Load the corpus through the NLTK-backed loader (the recorded output shows it
# fetching the `framenet_v17` NLTK package), then encode it for CHECKPOINT/TASK.
dataset = load_dataset_nltk()
tokenized_dataset = prepare_data(dataset, CHECKPOINT, task=TASK)
# Bare expression: display the resulting splits and their features.
tokenized_dataset

[nltk_data] Downloading package framenet_v17 to /home/leo/nltk_data...
[nltk_data]   Package framenet_v17 is already up-to-date!
100%|██████████| 107/107 [00:23<00:00,  4.58it/s]
100%|██████████| 107/107 [00:21<00:00,  5.02it/s]
100%|██████████| 107/107 [00:24<00:00,  4.30it/s]


Map:   0%|          | 0/3425 [00:00<?, ? examples/s]

Map:   0%|          | 0/328 [00:00<?, ? examples/s]

Map:   0%|          | 0/1354 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 3425
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 328
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1354
    })
})

In [4]:
# Sanity check: show the first encoded training example (input_ids,
# token_type_ids, attention_mask, labels).
# NOTE(review): the -100 entries in `labels` are presumably positions ignored by
# the loss (special / continuation tokens) — confirm in src.process_data.
tokenized_dataset['train'][0]

{'input_ids': [101,
  6167,
  112,
  188,
  1802,
  1607,
  3471,
  1114,
  1157,
  4734,
  1107,
  6210,
  2175,
  3002,
  125,
  117,
  1288,
  1201,
  2403,
  117,
  1133,
  1175,
  1125,
  1151,
  1769,
  7536,
  1303,
  1111,
  3944,
  117,
  1930,
  6159,
  1424,
  5813,
  117,
  1196,
  1115,
  119,
  102],
 'token_type_ids': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'labels': [-100,
  0,
  0,
  -100,
  0,
  0,
  873,
  0,
  0,
  1077,
  0,
  0,
  637,
  0,
  0,
  -100,
  -100,
  711,
  1137,
  0,
  0,
  456,
  0,
  456,
  801,
  673,
  0,
  0,
  711,
  0,
  662,
  177,
  -100,
  -100,
  0,
  1137,
  0,
  0,
  -100]}

## Train

In [4]:
# Training hyperparameters; all four are passed to `train(...)` in the next cell,
# and OUT_DIR is reused by the Test and Predict cells.
OUT_DIR = "./models"    # model output path
N_EPOCHS = 10           # value for `epochs=`
BATCH_SIZE = 64         # value for `batch_size=`
LEARNING_RATE = 2e-5    # value for `lr=`

In [None]:
# Launch training from CHECKPOINT on the tokenized splits, using the
# hyperparameters defined in the configuration cell above; OUT_DIR is given as
# the model output path.
train(
    pretrained_model=CHECKPOINT,
    task=TASK,
    dataset=tokenized_dataset,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LEARNING_RATE,
    model_output_path=OUT_DIR,
)

## Test

In [None]:
# Evaluate on the held-out test split. OUT_DIR is the same path given to
# `train(...)` as `model_output_path`, so this presumably loads the model saved
# there — confirm in src.train.
test(OUT_DIR, tokenized_dataset['test'], task=TASK)

## Predict

In [74]:
# Free-text example to run through the model stored under OUT_DIR.
sentence = "When the moon hits your eye, that's 'amore'."

# In the recorded output the returned string has a trailing '*' on some tokens
# (e.g. "When*", "hits*", "eye*") — presumably the predicted triggers.
predict_triggers(OUT_DIR, sentence)

"When* the moon hits* your eye* , that ' s ' amore ' ."

In [None]:
# Demo of displaCy's manual entity rendering: entities are supplied directly as
# a dict instead of coming from a spaCy pipeline.
from spacy import displacy

dic_ents = {
    "text": "My name is John Smith and I live in Paris",
    # `start`/`end` are character offsets into "text" (end exclusive):
    # 11-21 covers "John Smith", 36-41 covers "Paris".
    "ents": [
        {"start": 11, "end": 21, "label": "Person"},
        {"start": 36, "end": 41, "label": "Location"},
    ],
    "title": None
}

# manual=True tells displaCy the input is a pre-built dict, not a Doc.
displacy.render(dic_ents, manual=True, style="ent")


In [38]:
# Scratch notes left by the author: dataset -> IOB tags -> displacy
# (presumably a TODO to visualise a dataset sentence's IOB annotations with
# displaCy — confirm intent).
# Inspect one raw (pre-tokenization) training example and its annotation fields.
dataset['train'][269]

{'id': 1430,
 'sent_id': 1542,
 'tokens': ['Does',
  'Iran',
  'have',
  'the',
  'infrastructure',
  'necessary',
  'to',
  'produce',
  'nuclear',
  'weapons',
  '?'],
 'frame_tags': [0, 0, 112, 0, 134, 192, 0, 9, 662, 662, 0],
 'coarse_pos': ['VVZ',
  'NP',
  'VH',
  'dt',
  'nn',
  'jj',
  'to',
  'VV',
  'jj',
  'nns',
  'sent'],
 'fine_pos': ['NNP',
  'NNP',
  'VB',
  'DT',
  'NN',
  'JJ',
  'TO',
  'VB',
  'JJ',
  'NNS',
  '.'],
 'lemma': ['Does',
  'Iran',
  'have',
  'the',
  'infrastructure',
  'necessary',
  'to',
  'produce',
  'nuclear',
  'weapon',
  '?'],
 'LUs': ['_', '_', 'have.v', '_', '_', '_', '_', '_', '_', '_', '_'],
 'FEs': ['O\n',
  'S-Owner\n',
  'O\n',
  'B-Possession\n',
  'I-Possession\n',
  'I-Possession\n',
  'I-Possession\n',
  'I-Possession\n',
  'I-Possession\n',
  'I-Possession\n',
  'O\n']}