In [1]:
import json
import re
import string
import random
from itertools import islice
from pathlib import Path
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import spacy
import torch
from spacy.tokens import Span, Doc
from spacy.tokenizer import Tokenizer
from spacy import displacy
from tqdm import tqdm
from cytoolz import groupby
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
from seqeval.metrics import f1_score, precision_score, recall_score

from daseg import SwdaDataset, Call, FunctionalSegment, TransformerModel
from daseg.data import to_transformers_ner_dataset

%load_ext autoreload
%autoreload 2

In [2]:
dataset = SwdaDataset.from_path('deps/swda/swda')

In [None]:
call_ids = dataset.call_ids
calls = dataset.calls
call = dataset.calls[0]
call[:10]

In [None]:
texts_by_act = dataset.acts_with_examples()

In [None]:
len(texts_by_act.keys()), texts_by_act.keys()

In [None]:
acts = set(texts_by_act.keys())

In [None]:
original_acts = set(Path('/Users/pzelasko/jhu/da/swda-dialog-act-list').read_text().split('\n')[:-1])  # empty line

In [None]:
len(original_acts)

In [None]:
acts - original_acts

In [None]:
len(acts)

In [None]:
original_acts - acts

In [None]:
len(original_acts & acts)

In [None]:
pd.Series({act: len(texts) for act, texts in texts_by_act.items()}).sort_values().plot.barh(figsize=(10, 12), logx=True)

In [None]:
texts_by_act['Hedge']

## Number of turns distribution

In [None]:
sum(map(len, texts_by_act.values()))

In [None]:
pd.Series([len(call) for call in calls]).hist()

## Word length distribution

In [None]:
special_symbols = dataset.special_symbols()
len(special_symbols)

In [None]:
words_len_dist = pd.Series([sum(len(u.split()) for u, _, _, _ in call) for call in calls])

In [None]:
words_len_dist.hist()

In [None]:
to_transformers_ner_dataset(calls[1073], special_symbols)[:20]

In [None]:
if False:
    for split_name, split_dataset in dataset.train_dev_test_split().items():
        split_dataset.dump_for_transformers_ner(f'deps/transformers/examples/ner/{split_name}.txt.tmp')
else:
    print("DATASETS NOT WRITTEN TO DISK")

# Visualize

In [None]:
call[:20]

In [None]:
call.render(max_turns=20)

# Train the model / Predict

Refer to `run_da.sh` for this purpose. 

# Read model predictions

In [None]:
#preds_path = '/home/pzelasko/transformers/examples/ner/swda-xlmroberta-kosher-split-t43/test_predictions.txt'
preds_path = '/home/pzelasko/daseg/deps/transformers/examples/ner/xlnet-v1/test_predictions.txt'
calls = SwdaDataset.from_transformers_predictions(preds_path)

## Render model predictions

In [None]:
idx = 7

In [None]:
calls.calls[idx].render(max_turns=None)

# Inference

In [3]:
eval_dset = dataset.train_dev_test_split()['test']

In [4]:
#model_dir = 'deps/transformers/examples/ner/xlnet-v1/'
#model_dir = '/Users/pzelasko/jhu/da/xlnet-v1/'
#model_dir = '/Users/pzelasko/jhu/da/xlnet-t46-textnorm/'
model_dir = '/Users/pzelasko/jhu/da/longformer-t42-submission'

In [6]:
from daseg import TransformerModel
model = TransformerModel.from_path(model_dir)



In [7]:
results = model.predict(eval_dset, batch_size=1)

Token indices sequence length is longer than the specified maximum sequence length for this model (1510 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2899 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (954 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2962 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (3129 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Predicting dialog acts (batches of 1)',…



  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
results.keys()

In [None]:
for x in 'accuracy f1 precision recall'.split():
    print(results[x])

from seqeval.metrics import classification_report
print(classification_report(results['true_labels'], results['predictions']))

In [8]:
results['dataset'].calls[0].render()

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]