In [1]:
from datasets import load_dataset, Dataset
import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange
import random
from transformers import DataCollatorWithPadding

from transformers import AutoTokenizer


# Load the `mrda` configuration of the `silicone` dataset (served from the
# local Hugging Face cache when it is already downloaded).
ds_raw = load_dataset("silicone", "mrda")

  from .autonotebook import tqdm as notebook_tqdm
Found cached dataset silicone (/home/mms9355/.cache/huggingface/datasets/silicone/mrda/1.0.0/af617406c94e3f78da85f7ea74ebfbd3f297a9665cb54adbae305b03bc4442a5)
100%|██████████| 3/3 [00:00<00:00, 16.28it/s]


In [21]:
# Mean utterance length, in characters, over the training split.
# (The previous version also evaluated `c/len(i)`, which divided by the number
# of columns in the last row and was never displayed; it has been dropped.)
train_split = ds_raw['train']
total_chars = sum(len(row['Utterance']) for row in train_split)
total_chars / len(train_split)

37.11226665713639

In [3]:
# NOTE(review): BertConfig and BertModel are not referenced by any visible cell.
from transformers import BertConfig, BertModel
# Load the sequence-classification checkpoint used for every evaluation cell
# below (presumably the MRDA fine-tune, judging by the path — confirm).
model = BertForSequenceClassification.from_pretrained("../models/model_mrda_v2_t1.model")
# Alternative checkpoint, kept for reference:
# model = BertForSequenceClassification.from_pretrained("model__v1_t3.model")

In [7]:
# Show the available splits with their columns and row counts.
ds_raw

DatasetDict({
    train: Dataset({
        features: ['Utterance_ID', 'Dialogue_Act', 'Channel_ID', 'Speaker', 'Dialogue_ID', 'Utterance', 'Label', 'Idx'],
        num_rows: 83943
    })
    validation: Dataset({
        features: ['Utterance_ID', 'Dialogue_Act', 'Channel_ID', 'Speaker', 'Dialogue_ID', 'Utterance', 'Label', 'Idx'],
        num_rows: 9815
    })
    test: Dataset({
        features: ['Utterance_ID', 'Dialogue_Act', 'Channel_ID', 'Speaker', 'Dialogue_ID', 'Utterance', 'Label', 'Idx'],
        num_rows: 15470
    })
})

In [4]:
# Human-readable names of the dialogue-act classes; list position is the class id.
# NOTE(review): assumed to follow the index order of the dataset's `Label`
# feature — confirm against ds_raw['train'].features['Label'].names.
labels = ["statement", "declarative question", "backchannel", "follow-me", "question"]
# Forward (id -> name) and reverse (name -> id) lookups derived from `labels`.
id2label = dict(enumerate(labels))
label2id = {name: class_id for class_id, name in id2label.items()}


In [8]:
# Tokenize every utterance and attach a one-hot dialogue-act target.

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def dataprep(samples, max_length=32):
    """Encode one dataset row for BERT and add a one-hot label vector.

    Args:
        samples: A single row (dict-like) with at least the keys
            'Utterance' (text) and 'Label' (integer class id).
        max_length: Number of tokens every utterance is truncated or padded to.

    Returns:
        The same row with three added keys: 'input_ids', 'attention_masks'
        and 'labels' (a float vector of length len(labels) with a single 1
        at the position given by 'Label').
    """
    # Calling the tokenizer directly replaces the deprecated `encode_plus`.
    # return_tensors='pt' is kept because later cells pass each row's
    # `input_ids` / `attention_masks` straight into the model.
    encoding = tokenizer(
        samples['Utterance'],
        add_special_tokens=True,
        max_length=max_length,
        return_attention_mask=True,
        return_tensors='pt',
        truncation=True,
        padding="max_length",
    )
    samples['input_ids'] = encoding['input_ids']
    # Stored under the plural key 'attention_masks'; downstream cells read that name.
    samples['attention_masks'] = encoding['attention_mask']

    one_hot = np.zeros(len(labels))
    one_hot[samples['Label']] = 1
    samples['labels'] = one_hot

    return samples

# Apply dataprep row by row to every split. The output format is NOT changed
# here; `encoded.set_format('torch')` is called in a later cell, after the
# label counts have been computed on plain Python lists.
encoded = ds_raw.map(dataprep)

                                                                   

In [9]:
# Count how many test examples carry each gold label.
# The label column is materialised once (instead of once per class) and each
# row is compared against the matching one-hot reference row in a single
# vectorised step; a row is counted only when it equals that reference exactly.
test_onehots = np.asarray(list(encoded['test']['labels'])).reshape(-1, len(labels))
onehot_reference = np.eye(len(labels))
for class_idx, class_name in enumerate(labels):
    n_matches = int(np.all(test_onehots == onehot_reference[class_idx], axis=1).sum())
    print(class_name + ": " + str(n_matches))


statement: 8864
declarative question: 2246
backchannel: 1961
follow-me: 1317
question: 1082


In [10]:
# From here on, indexing `encoded` returns torch tensors. This changes
# `encoded` itself, so cells that expect plain lists must run before this one.
encoded.set_format('torch')

In [12]:
# Peek at one encoded test example. Only the last expression of a cell is
# displayed, so the utterance is printed explicitly instead of being dropped.
sample = encoded['test'][5]
print(sample['Utterance'])
sample['labels']

tensor(0)

In [13]:
# Evaluate the model on the test split one example at a time.
# `labeled` records the predicted class id of every example for later cells;
# `accuracy` is the fraction of examples whose prediction equals the gold label.
corrects = 0
labeled = np.zeros(len(encoded['test']))
with torch.no_grad():  # inference only: no autograd graph is needed
    for i, e in enumerate(encoded['test']):
        out = model(e['input_ids'], token_type_ids=None, attention_mask=e['attention_masks'])
        predicted = out.logits.detach().cpu().numpy().argmax()
        labeled[i] = predicted
        # The gold label is the position of the 1 in the one-hot `labels` vector.
        if np.where(e['labels'] == 1)[0][0] == predicted:
            corrects += 1
accuracy = corrects / len(encoded['test'])
accuracy

0.8933419521654816

In [15]:
# Spot-check test example 100: print the utterance, the model's guess and the gold label.
# NOTE(review): originally described as a classification "that doesn't work",
# but the recorded output shows the guess matching the label — the description
# may be stale for the currently loaded checkpoint.
e = encoded['test'][100]
print(e['Utterance'])
with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(e['input_ids'], token_type_ids=None, attention_mask=e['attention_masks'])
logits = out.logits.detach().cpu().numpy()
print("Guess: " + labels[logits.argmax()])
print("Actual: " + labels[np.where(e['labels'] == 1)[0][0]])

and um - um - next to some - some more or less bureaucratic uh - stuff with the - the data collection she's also the wizard in the data collection .
Guess: statement
Actual: statement


In [16]:
# Spot-check test example 122: print the utterance, the model's guess and the gold label.
# NOTE(review): originally described as a "flawed classification", but the
# recorded output shows the guess matching the gold label. Confirm whether the
# intended point is a questionable gold label or whether the note is stale.
e = encoded['test'][122]
print(e['Utterance'])
with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(e['input_ids'], token_type_ids=None, attention_mask=e['attention_masks'])
logits = out.logits.detach().cpu().numpy()
print("Guess: " + labels[logits.argmax()])
print("Actual: " + labels[np.where(e['labels'] == 1)[0][0]])

okay | um - why don't we get started on that subject anyways ?
Guess: statement
Actual: statement


In [17]:
# Spot-check test example 998 (correctly classified example #1):
# print the utterance, the model's guess and the gold label.
e = encoded['test'][998]
print(e['Utterance'])
with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(e['input_ids'], token_type_ids=None, attention_mask=e['attention_masks'])
logits = out.logits.detach().cpu().numpy()
print("Guess: " + labels[logits.argmax()])
print("Actual: " + labels[np.where(e['labels'] == 1)[0][0]])

what's also nice and for a- - i- - for me in my mind .
Guess: statement
Actual: statement


In [18]:
# Spot-check test example 775 (correctly classified example #2):
# print the utterance, the model's guess and the gold label.
e = encoded['test'][775]
print(e['Utterance'])
with torch.no_grad():  # inference only: no autograd graph is needed
    out = model(e['input_ids'], token_type_ids=None, attention_mask=e['attention_masks'])
logits = out.logits.detach().cpu().numpy()
print("Guess: " + labels[logits.argmax()])
print("Actual: " + labels[np.where(e['labels'] == 1)[0][0]])

rad !
Guess: statement
Actual: statement


In [60]:
# For every class, report how many test utterances the model assigned to it
# and show one randomly chosen utterance from that group ("__NONE__" if empty).
# NOTE(review): the frame is still called `swda_df` although it holds the MRDA
# test split; the name is kept in case other cells refer to it.
swda_df = encoded['test'].to_pandas()
swda_df['labels_pred'] = labeled
for i in range(len(labels)):
    predicted_utts = swda_df.loc[swda_df['labels_pred'] == i, 'Utterance']
    content = "__NONE__"
    if len(predicted_utts) > 0:
        # Draw a single random position. The previous code drew one index,
        # ignored it, and drew a second one behind an always-true check.
        content = predicted_utts.iloc[random.randint(0, len(predicted_utts) - 1)]

    print(labels[i] + ": " + str(len(predicted_utts)) + " : " + content)

statement: 8344 : so you could start pulling back .
declarative question: 1999 : th- ==
backchannel: 2780 : right .
follow-me: 1264 : uh ==
question: 1083 : full open ?
