In [10]:
from transformers import DistilBertForSequenceClassification,DistilBertTokenizer
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, TensorDataset,DataLoader
from sklearn.model_selection import train_test_split
import torch
from helper_functions import get_paths_for_en_episodes
from helper_functions import dialogue_json_to_pandas

In [None]:
device = torch.device('cpu')

In [6]:
# Load DistilBert Classification model from transformer library
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased',num_labels=1).to(device=device)
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Load saved best BERT model op CPU (model was trained on GPU in Colab)
model.load_state_dict(torch.load('DistilBERT_best_model.pt', map_location=device))
model.eval()

I0307 16:05:15.470422 4446334464 configuration_utils.py:256] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json from cache at /Users/Bart/.cache/torch/transformers/a41e817d5c0743e29e86ff85edc8c257e61bc8d88e4271bb1b243b6e7614c633.8949e27aafafa845a18d98a0e3a88bc2d248bbc32a1b75947366664658f23b1c
I0307 16:05:15.473039 4446334464 configuration_utils.py:292] Model config DistilBertConfig {
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": null,
  "dim": 768,
  "do_sample": false,
  "dropout": 0.1,
  "eos_token_ids": null,
  "finetuning_task": null,
  "hidden_dim": 3072,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 1

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
       

In [20]:
all_en_episodes_paths = []
for subset_number in [0, 1, 2]:
    all_en_episodes_paths.append(get_paths_for_en_episodes(subset_number)[1])

# flatten the list
all_en_episodes_paths = [item for sublist in all_en_episodes_paths for item in sublist]
print(len(all_en_episodes_paths))

100%|██████████| 36/36 [00:15<00:00,  2.12it/s]
100%|██████████| 36/36 [00:17<00:00,  1.97it/s]
100%|██████████| 36/36 [00:19<00:00,  1.56it/s]

39857





In [21]:
# filter out the episodes that where used in the training and validation set
train_val_eps = ['podcast_data_no_audio/podcasts-transcripts/0/E/show_0e2tMqHNabAf1lJUF2Nakg/1Y6InCAx7VPhMB67HHEpVE.json', 'podcast_data_no_audio/podcasts-transcripts/0/F/show_0f2P0fH4EwuEtXKpXIt7Ui/5Km5wY535jnqy9glmwjuV8.json', 'podcast_data_no_audio/podcasts-transcripts/0/6/show_06DN2th96dYmRtHEQFNKTo/0lmQa6w0BR8e5TyzF5wjdN.json', 'podcast_data_no_audio/podcasts-transcripts/0/L/show_0L5kSg5frpqKFnQFasMcGG/3oYP9Ukre8PhDHG4TBRKn8.json', 'podcast_data_no_audio/podcasts-transcripts/0/C/show_0CSrUveOqf2QM7fgrdIkVy/0sP7Z8pgO0QP1cWeavTCnB.json', 'podcast_data_no_audio/podcasts-transcripts/0/T/show_0tN8aZWZ5GKbkc0gzUUFDP/7j7G7kaPgxH8kspdXW3HiA.json', 'podcast_data_no_audio/podcasts-transcripts/0/V/show_0vG9O03AYlTclAHNxfFlDI/4vmVc5gXNBdMkUydhi7WcR.json', 'podcast_data_no_audio/podcasts-transcripts/0/N/show_0NpVLVfKd8mtKlBjmEj0vu/16bv16BRw8oV8AUH5JaicY.json', 'podcast_data_no_audio/podcasts-transcripts/0/G/show_0gflUCrpF0H9uXuWtMQWLx/1Zk1xrDmSpwuCtdWHF8dMn.json', 'podcast_data_no_audio/podcasts-transcripts/0/L/show_0LaYlRViq9hVwwGykQLHd3/4q9et9b1xeAJcN3FnUq1vN.json', 'podcast_data_no_audio/podcasts-transcripts/0/2/show_02qeRiltSbNgNczt4wjP6q/72paNmckHAMP0B8dj7qgAy.json', 'podcast_data_no_audio/podcasts-transcripts/0/M/show_0m4KhpeNnmuvPlCJj5l0oV/61BB4ODbjq0RQM2mQOrZPO.json', 'podcast_data_no_audio/podcasts-transcripts/0/I/show_0I22C9iyvVT3M6DEILuD9F/1O7Q39FeikjQofDo2QpzT4.json', 'podcast_data_no_audio/podcasts-transcripts/0/L/show_0L04op9D76TOfmzm7yOf9T/1zWtKaYeyBd6k7MkOb67V4.json', 'podcast_data_no_audio/podcasts-transcripts/0/N/show_0NGeePSmXrnyU4k6EX4wNv/2i84Vw9tZDWrD10bt2uvR2.json', 'podcast_data_no_audio/podcasts-transcripts/0/I/show_0I22C9iyvVT3M6DEILuD9F/1pef0MZpUV05KlJGqEPjTG.json', 'podcast_data_no_audio/podcasts-transcripts/0/0/show_00kBfZGbf0p8LKaPxpvkbi/0oVZ7OHBw3R6op5FioXQrQ.json', 'podcast_data_no_audio/podcasts-transcripts/0/9/show_09CgCaGCGxbeaeMVviFKxw/00u97YwLndEB0KSNoaIA22.json', 'podcast_data_no_audio/podcasts-transcripts/0/M/show_0MTCY7tw7AKad94BlV25Lh/6XnXwwno1h7vHIgzy9Pwyg.json', 'podcast_data_no_audio/podcasts-transcripts/0/R/show_0RCLNMBkrHcVMruEONenxD/25V7rh0ypzAwdKD1UWBTse.json', 'podcast_data_no_audio/podcasts-transcripts/0/A/show_0AQnbBbrcnOEVvpyEt2hDg/3o4nDbWg5WImMkjo6eWANS.json', 'podcast_data_no_audio/podcasts-transcripts/0/O/show_0Ow12LWRnZAjrC9CdzcVXf/0DpZyyqz4j6PDJbAKbcE0L.json', 'podcast_data_no_audio/podcasts-transcripts/0/N/show_0nooPN4bqy4mXjxzcw6Q8x/5tqBWNeJFqluMscMPNx3Jl.json', 'podcast_data_no_audio/podcasts-transcripts/0/E/show_0E2L8zPYhApYkmWWFef7aK/7KuzIPpOo3ce9JoapOyx7x.json', 'podcast_data_no_audio/podcasts-transcripts/0/4/show_04oioSRpSb6NwO8L6SOODX/2BqGPeR8EtXRQAtOk1FnZu.json', 'podcast_data_no_audio/podcasts-transcripts/1/W/show_1W8XRepty6Lw4UtSKQypoW/5UXTpnjCqRoXQrakHUIixK.json', 'podcast_data_no_audio/podcasts-transcripts/1/F/show_1F8eCztQpRKBXGfXn60hfA/0sjKTXOdiq21FWfURkNhK9.json', 'podcast_data_no_audio/podcasts-transcripts/1/N/show_1NuPhncm1111kQjUQEgVjB/5aW8y7qZYy7fVMuOZTstFO.json', 'podcast_data_no_audio/podcasts-transcripts/1/K/show_1K2Fzlro0Lmp9pGS7Ak00D/4kwQvhhqChdY7wmWr0IEC0.json', 'podcast_data_no_audio/podcasts-transcripts/1/Z/show_1ZKHflAANeLU5C4XZ1Aa2O/78mQuY226OGypfZTE0eXEF.json', 'podcast_data_no_audio/podcasts-transcripts/1/M/show_1MOX9JsJxj4qMVkECiO4dF/1j7r7kFe2lw0OqaaUbufa6.json', 'podcast_data_no_audio/podcasts-transcripts/1/I/show_1IJslH3oyMzNDjlGyb1D15/2rpuMACJeBr3mU5QAspY0i.json', 'podcast_data_no_audio/podcasts-transcripts/1/F/show_1FcrPHfZW2YPygOev9mp6G/03NWT0unEHJf0qwSEcPF8U.json', 'podcast_data_no_audio/podcasts-transcripts/1/W/show_1wkrwgQJo7I6wEC4RwU8x2/38m9SUCzPKWurLqoqKp1x7.json', 'podcast_data_no_audio/podcasts-transcripts/1/T/show_1T69Xe0EJ4n0gOO4RD9qv0/6wtwVAYvYUpWDMGk4apvyC.json', 'podcast_data_no_audio/podcasts-transcripts/1/H/show_1hAEAWtQFOaIYTMRZlMcdg/6zYGnMcEPqVyuryRUr4MvF.json', 'podcast_data_no_audio/podcasts-transcripts/1/U/show_1UO3SVMDFdE9hUkvi9G6QS/6YaAfVdDepKf7iBflevWMV.json', 'podcast_data_no_audio/podcasts-transcripts/1/W/show_1w5HzKD00MO1PpzEJJo3Vn/0HPK9fGvi4FrMEf2vzqyo0.json', 'podcast_data_no_audio/podcasts-transcripts/1/C/show_1cTKtWwQ2BXvK5Q0wfPiQ5/67vPsEwE5Vp4r218jLX3rl.json', 'podcast_data_no_audio/podcasts-transcripts/1/Y/show_1yBJMbywkEEhIGKkvwoIVg/24mGzaqBlP6taAtMwiECSB.json', 'podcast_data_no_audio/podcasts-transcripts/1/F/show_1FQS4jnubbMxz5OvVlsxwX/6o4ojoy8AL1CLW1IHoqOTJ.json', 'podcast_data_no_audio/podcasts-transcripts/1/S/show_1SXD1U55jqbK9HHoPvdbsw/2EN1wmxv3M259CwSkVi2Rv.json', 'podcast_data_no_audio/podcasts-transcripts/1/Z/show_1z7SoufmYZoz2Gqvel1IzO/29RGXrRgPCAcUcRlR0R31Y.json', 'podcast_data_no_audio/podcasts-transcripts/1/L/show_1lbzRaT4n1Rxx04QatZo9Y/1Jikd0xspLVluTaYJWftMX.json', 'podcast_data_no_audio/podcasts-transcripts/1/Q/show_1qZ5TK5ghLJnNeWP0NvSKp/1H2uvhjwMtPCoxFL6KFBOE.json', 'podcast_data_no_audio/podcasts-transcripts/1/I/show_1iHjo84YtAIWa2GaYwowiq/31jD07gEHkjZ4kEttHP2bl.json', 'podcast_data_no_audio/podcasts-transcripts/1/Z/show_1Z9gEuGf562fInSsMoLqyu/5ifXdhAl7qPXDQpUrzgpaJ.json', 'podcast_data_no_audio/podcasts-transcripts/1/H/show_1hygb4nGhNhlLn4pBnN00j/5XXWZGEAANXkIIQ7X2U2R7.json', 'podcast_data_no_audio/podcasts-transcripts/1/Y/show_1y0ZxlG9I4t9YImbff00I2/3UHX2Ax6B6VHwwhKXUHbkw.json', 'podcast_data_no_audio/podcasts-transcripts/1/J/show_1J0yZFJ0lUUvpMb8ZJRbXy/3qnjfTYDjoV5lRCDusjzSI.json', 'podcast_data_no_audio/podcasts-transcripts/2/G/show_2gs5645b3F1d2End3KXBp4/0iSMYT4A1ULCUCY8STewvY.json', 'podcast_data_no_audio/podcasts-transcripts/2/9/show_29Qk08pTL6B9LPRejivbS6/4VvILwzXpbTcWt9YMoyFUZ.json', 'podcast_data_no_audio/podcasts-transcripts/2/I/show_2iHte4DYbwL2mlhgysmb7l/11ybYZlrlm3lmkWX1OXl3p.json', 'podcast_data_no_audio/podcasts-transcripts/2/T/show_2tetA7Ub1xxSLm0oHA1gmV/6VuAZ7nxVzGDcK7Ie8jXyJ.json', 'podcast_data_no_audio/podcasts-transcripts/2/Q/show_2qY4nOu5zaZ9CMTlB6XBj5/6y7bLRK9VKOJiDhZaZM9Sy.json', 'podcast_data_no_audio/podcasts-transcripts/2/I/show_2iaGIA0ODxmgHSmKXYopRX/0GfpjRHpEhQK1bwWkoNbza.json', 'podcast_data_no_audio/podcasts-transcripts/2/1/show_21ASCcEXgUlbFSmoqjroZm/6D59qfOpamnVRDElnuTocj.json', 'podcast_data_no_audio/podcasts-transcripts/2/W/show_2W1Hy8xOcSi0b8ppdQ52qU/21iECAkP64WkYOc9kIusyI.json', 'podcast_data_no_audio/podcasts-transcripts/2/8/show_28KKqFWNBw6kk5aBFu2viN/6WphEO0vKaCcJeVsYa7Kng.json', 'podcast_data_no_audio/podcasts-transcripts/2/K/show_2kl1GlTTgSwgobmxgGEaBX/4RO6YTUkTTryRb3uNXL4l0.json', 'podcast_data_no_audio/podcasts-transcripts/2/0/show_20zWNMtAU8S3rir62g1r0Q/1T3FUP68I7oYVnuOyM44uE.json', 'podcast_data_no_audio/podcasts-transcripts/2/M/show_2M8KJDwxT0zzwXEJhDZYCL/06mLFp9wQFnIQd79Qjg4jq.json', 'podcast_data_no_audio/podcasts-transcripts/2/C/show_2C0AgUOt4eCULjFjb3mynN/4jaXWbdotGutHNxKKjMOWs.json', 'podcast_data_no_audio/podcasts-transcripts/2/H/show_2hq8BvOX4DocQpqkm20XLu/6kz63TvGtLKe5dI9qr8sqm.json', 'podcast_data_no_audio/podcasts-transcripts/2/A/show_2AkW5V4H6xAh8IXJU0jHUm/2PHi40upTwPGW5btJugKuK.json', 'podcast_data_no_audio/podcasts-transcripts/2/H/show_2HIFmqNqJkR2SADAcG4Fpq/6EUGCLZqWARW8TcNzezI0e.json', 'podcast_data_no_audio/podcasts-transcripts/2/E/show_2eXdry3liXk0Z0BfTCVKff/4TssEgbtKfFBhH2NVpkhD2.json', 'podcast_data_no_audio/podcasts-transcripts/2/Q/show_2qy2KehR0K2FGZqsEsU2CY/59S06bau9DJEEx1f1dC4y7.json', 'podcast_data_no_audio/podcasts-transcripts/2/O/show_2OVM4aQOEgo8uvuUcqJSei/7o7HaMZf49jS2yyXYUdSCc.json', 'podcast_data_no_audio/podcasts-transcripts/2/J/show_2jW0aO9MqYFndIcIAhFukV/2Tt47VISLTvm8gTyhFjyb8.json', 'podcast_data_no_audio/podcasts-transcripts/2/4/show_24aqN472kMGKAhdJbIK59L/7sNutOB7XH6jihysx0tfiT.json', 'podcast_data_no_audio/podcasts-transcripts/2/A/show_2AsvwIbKhe8yV3ePi8nGY0/0ULidGb0REL65hEkd7zwLo.json', 'podcast_data_no_audio/podcasts-transcripts/2/1/show_21ASCcEXgUlbFSmoqjroZm/6FSUZ1PWVjCa7aHLifMPGa.json', 'podcast_data_no_audio/podcasts-transcripts/2/1/show_21ASCcEXgUlbFSmoqjroZm/2aZy5qlRmNbPDXIiV1hxQJ.json', 'podcast_data_no_audio/podcasts-transcripts/2/9/show_29DGC8r0eQZWtzPs10jASp/5qHQFfGcYnLWUPCOQt5PfM.json']

all_en_episodes_paths_filtered = [x for x in all_en_episodes_paths if x not in train_val_eps]
print(len(all_en_episodes_paths_filtered))

39782


In [26]:
# convert episodes into dataframe
inference_dialogue = dialogue_json_to_pandas(all_en_episodes_paths_filtered[1])

In [27]:
# tokenize utterances in dialogue
tokenized_utterances = test_dialogue.text.apply(lambda x: tokenizer.encode(x,add_special_tokens=True))

In [32]:
# pad sentences for inference
max_len = max(map(len,tokenized_utterances))
padded_utterances = np.array([ i+[0]*(max_len-len(i))  for i in tokenized_utterances])
attention_masked_utterances = np.where(padded_utterances != 0,1,0)

In [31]:
# Convert dialogue into tensors and dataset
X_inference = torch.tensor(padded_utterances)
X_inference_attention = torch.tensor(attention_masked_utterances)

inference_dataset = TensorDataset(X_inference, X_inference_attention)
inference_loader = DataLoader(inference_dataset, batch_size=16, shuffle=False)

In [36]:
 def sigmoid(x):
    return 1 / (1 + np.exp(-x)) 

In [None]:
preds = np.zeros([len(inference_dataset), 1])
for i, (x_batch, x_mask) in enumerate(inference_loader):
    print(i)
    outputs = model(x_batch.to(device),attention_mask=x_mask.to(device))

    y_pred = sigmoid(outputs[0].detach().cpu().numpy())

    preds[i*16:(i+1)*16, :] = y_pred


pred_labels = []
for p in preds:
    if p > 0.5:
        pred_labels.append(1)
    else:
        pred_labels.append(0)
        
print(pred_labels)

0
1
2
3
4
5
6
7
