In [22]:
# Imports
import random
import tarfile
import os.path
import json
import pickle
import ast
import copy
import re
import string
from typing import Union
import numpy as np
import sqlite3
from pathlib import Path
from tqdm import tqdm
from datetime import datetime, timedelta
import datetime as dt
from collections import defaultdict
from itertools import chain

# Pandas
import pandas as pd
pd.options.mode.chained_assignment = None  # Turn off the SettingWithCopyWarning
tqdm.pandas()
# NLP
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk import bigrams, trigrams
from nltk.tokenize import word_tokenize


# Plots
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.patches import Patch
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

from datasets import load_dataset, load_metric, Dataset, DatasetDict
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_curve, f1_score, accuracy_score, RocCurveDisplay


import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig, pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TokenClassificationPipeline

#### Train / Test DataFrames

In [121]:
# Load the pre-split train and test tables from CSV.
action_enrichment_df_train = pd.read_csv("/sise/home/ofirbenm/Wsc_ex1/action_enrichment_df_train.csv")
action_enrichment_df_test = pd.read_csv("/sise/home/ofirbenm/Wsc_ex1/action_enrichment_df_test.csv")
# Number of distinct values in the train `Action` column; the models cell passes
# this as `num_labels` when building the action classifier.
actions_number = action_enrichment_df_train.Action.nunique()
# Preview the first test rows (shown as the cell output).
action_enrichment_df_test.head()

Unnamed: 0,Text,Action,Action_number,Label
0,Test your mental toughness. Staffin turned it ...,No action,0,0
1,So far tonight champeney one of five for just ...,No action,0,0
2,Texas Tech is gonna stay unbeaten at home. Ohh...,No action,0,0
3,Crosswell inside over Wheeler it goes. Good move.,No action,0,0
4,Virginia is kind of used to playing log in the...,No action,0,0


#### Action to number

In [122]:
# Build the class-index -> action-name lookup from the original dataset file.
# Indices follow the order in which each action first appears in the `Action` column.
# NOTE(review): presumably this order matches the `Action_number` encoding the
# action model was trained with -- confirm against the training notebook.
file_path = "/sise/home/ofirbenm/Wsc_ex1/action_enrichment_ds_home_exercise_old.csv"
action_enrichment_df = pd.read_csv(file_path)
actions = action_enrichment_df["Action"].unique()
number_to_action = dict(enumerate(actions))

#### Models

In [19]:
# Inference configuration: everything below runs on this device.
device = 'cpu'
bert_model = "bert-base-uncased"


def _load_bert_classifier(weights_path, num_labels):
    """Build a BERT sequence classifier and load fine-tuned weights into it.

    Args:
        weights_path: Path to a state dict saved with ``torch.save``.
        num_labels: Size of the classification head.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=num_labels)
    # map_location remaps tensors that were saved on another device (e.g. CUDA)
    # onto `device`; without it, loading a GPU-saved checkpoint on a CPU-only
    # machine raises an error.
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Action classifier: one output class per distinct action in the train set.
action_model_path = "/sise/home/ofirbenm/Wsc_ex1/bert_action_model.pth"
action_model = _load_bert_classifier(action_model_path, actions_number)

# Validity classifier: binary head over (text, action) pairs.
validity_model_path = "/sise/home/ofirbenm/Wsc_ex1/bert_validity_model.pth"
validity_model = _load_bert_classifier(validity_model_path, 2)

tokenizer = AutoTokenizer.from_pretrained(bert_model)
# Threshold compared against the validity model's index-1 probability in `predict`.
opt_thresh_model = 0.392

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

In [104]:
def predict_action(text):
    """Return the index of the action class the action model scores highest.

    Args:
        text: A single transcript string.

    Returns:
        int: Arg-max index over the action model's logits for ``text``.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True)
    inputs = inputs.to(device)
    # Pure inference: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        outputs = action_model(**inputs)
    predicted_label = torch.argmax(outputs.logits).item()
    return predicted_label

def predict_label(text, action):
    """Score a (text, action) pair with the validity model.

    Args:
        text: Transcript string (first segment of the tokenized pair).
        action: Action name (second segment of the tokenized pair).

    Returns:
        list[float]: Softmax probabilities over the two validity classes
        for the single input pair.
    """
    inputs = tokenizer(text, action, padding="max_length", add_special_tokens=True, truncation=True, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Pure inference: disable autograd so no graph is built or kept in memory.
    with torch.no_grad():
        logits = validity_model(**inputs).logits
    # Batch of one -> return the probabilities of its only row.
    return F.softmax(logits, dim=1).tolist()[0]

def predict(text):
    """Predict the action for a transcript and whether that action is valid.

    Runs the action classifier first; if it predicts class 0 the validity
    model is skipped. Otherwise the (text, action) pair is scored and the
    index-1 probability is compared against ``opt_thresh_model``.
    A one-line summary of the outcome is printed.

    Args:
        text: A single transcript string.

    Returns:
        tuple: ``(action, label)`` where ``action`` is the predicted action
        name and ``label`` is 1 if the action was judged valid, else 0.
    """
    action_number = predict_action(text)
    action = number_to_action[action_number]

    # Class 0: nothing to validate, report and return early.
    if action_number == 0:
        print('There is no action in the  transcript.')
        return action, 0

    # Score the text that was passed in -- not a loop variable from the calling
    # cell -- so the function also works outside the evaluation loop.
    softmax = predict_label(text, action)
    label = 1 if softmax[1] > opt_thresh_model else 0
    if label == 1:
        print(f"The action is '{action}'.")
    else:
        print(f"The action '{action}' is not valid.")
    return action, label

In [106]:
# Evaluate the two-stage pipeline on the test table: a row counts as correct
# only when BOTH the predicted action name and the predicted label match the
# row's `Action` and `Label` columns.
correct_predictions = 0
for i, row in tqdm(action_enrichment_df_test.iterrows(), total=len(action_enrichment_df_test)):
    action, label = predict(row['Text'])
    if action == row['Action'] and label == row['Label']:
        correct_predictions += 1

# Accuracy = fully-correct rows / all test rows, printed as a percentage.
total_predictions = action_enrichment_df_test.shape[0]
accuracy = correct_predictions / total_predictions
print("Accuracy on Test set: {:.3f}%".format(accuracy * 100))

Accuracy on Test set: 86.076%
