### 1. Load & Prepare Dataset

In [1]:
%load_ext autoreload
%autoreload 2

from chartbot_config import *

In [2]:
from datasets import load_from_disk

# Load the preprocessed DatasetDict (train / validation / test splits) saved at
# `processed_data_path`, which is provided by chartbot_config.
df = load_from_disk(dataset_path=processed_data_path)
df

  from .autonotebook import tqdm as notebook_tqdm


DatasetDict({
    train: Dataset({
        features: ['index', 'Domain', 'Sub domain', 'Intent', 'Answer Format', 'value', 'labels'],
        num_rows: 9766
    })
    validation: Dataset({
        features: ['index', 'Domain', 'Sub domain', 'Intent', 'Answer Format', 'value', 'labels'],
        num_rows: 3256
    })
    test: Dataset({
        features: ['index', 'Domain', 'Sub domain', 'Intent', 'Answer Format', 'value', 'labels'],
        num_rows: 3256
    })
})

In [3]:
from transformers import AutoTokenizer, DataCollatorWithPadding

# `checkpoint` (pretrained model name or path) is provided by chartbot_config.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(examples):
    """Tokenize a batch of examples' "value" text, truncating to the model's max length."""
    return tokenizer(examples["value"], truncation=True)


# Only the held-out test split is tokenized here; it is the one used for inference.
tokenized_dataset = df["test"].map(tokenize_function, batched=True)

# Pads each batch dynamically to its longest sequence and returns TF tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")

tf_test_dataset = tokenized_dataset.to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["labels"],
    # Keep row order so predictions line up with `tokenized_dataset` rows later.
    shuffle=False,
    collate_fn=data_collator,
    batch_size=1,
)
print("======================================")
print("Test set:", tf_test_dataset)

100%|██████████| 4/4 [00:00<00:00, 29.67ba/s]


Trainset: <PrefetchDataset shapes: ({input_ids: (None, None), token_type_ids: (None, None), attention_mask: (None, None)}, (None, None)), types: ({input_ids: tf.int64, token_type_ids: tf.int64, attention_mask: tf.int64}, tf.int64)>


### 2. Load & Prepare Model

In [4]:
from transformers import TFAutoModelForSequenceClassification
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam

# Each row's `labels` is a per-class vector, so its length is the number of classes.
# Index a single row rather than materialising the whole "labels" column.
num_labels = len(df["train"][0]["labels"])
model = TFAutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=num_labels,
    problem_type="multi_label_classification",
)
# from_pretrained reports the classifier head as "newly initialized"; that warning
# is expected because the fine-tuned weights are loaded over it here.
model.load_weights(model_path)

# `predict` does not need a compiled model; compiling keeps `evaluate`/`fit` usable.
# The classifier emits raw logits, hence from_logits=True.
opt = Adam()
loss = BinaryCrossentropy(from_logits=True)

model.compile(
    optimizer=opt,
    loss=loss,
    # NOTE(review): Keras resolves the string "accuracy" from the label/output
    # shapes; for multi-column labels this is presumably categorical (argmax)
    # accuracy rather than per-label binary accuracy. Confirm this is intended.
    metrics=["accuracy"],
)

model.summary()

All model checkpoint layers were used when initializing TFBertForSequenceClassification.

Some layers of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model: "tf_bert_for_sequence_classification"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
bert (TFBertMainLayer)       multiple                  109482240 
_________________________________________________________________
dropout_37 (Dropout)         multiple                  0         
_________________________________________________________________
classifier (Dense)           multiple                  56906     
Total params: 109,539,146
Trainable params: 109,539,146
Non-trainable params: 0
_________________________________________________________________


### 3. Model Inference, Evaluation & Export

In [5]:
import numpy as np

# Raw classifier logits, one row per test example (in tf_test_dataset order).
preds = model.predict(tf_test_dataset)["logits"]

# A logit above 0 corresponds to a sigmoid probability above 0.5 for that label.
LOGIT_THRESHOLD = 0.0
positive = preds > LOGIT_THRESHOLD

# Predict a class only when exactly one label fires. Otherwise emit the sentinel
# index `preds.shape[1]` (one past the last real class), meaning "no single class".
# argmax on a boolean row returns the index of its (only) True entry.
unknown_class = preds.shape[1]
class_preds = np.where(positive.sum(axis=1) == 1, positive.argmax(axis=1), unknown_class)

In [6]:
import pandas as pd

# Ground truth: index of the largest entry in each row's label vector.
# NOTE(review): argmax returns the first maximum, so a row with no positive label
# would map to class 0. This assumes every test row has exactly one positive label.
class_true = np.argmax(tokenized_dataset["labels"], axis=1)
# np.asarray makes the comparison element-wise whether class_preds is a list or an array.
# Sentinel "no single class" predictions never equal a true class, so they count as wrong.
accuracy = (np.asarray(class_preds) == class_true).mean()
print("The accuracy of test set:", accuracy)

# Metadata kept alongside the per-class logits (columns 0..num_labels-1) and y_hat.
# Select columns explicitly instead of relying on column order, and avoid
# materialising the token-id columns that would be discarded anyway.
meta_columns = ["Domain", "Sub domain", "Intent", "Answer Format", "value"]
test_df = pd.DataFrame({col: tokenized_dataset[col] for col in meta_columns})
test_df = pd.concat([test_df, pd.DataFrame(preds)], axis=1)
test_df["y_hat"] = class_preds
# Seeded so the displayed sample is reproducible on re-run.
test_df.sample(5, random_state=0)

The accuracy of test set: 0.980958230958231


Unnamed: 0,Domain,Sub domain,Intent,Answer Format,value,0,1,2,3,4,...,65,66,67,68,69,70,71,72,73,y_hat
3041,About Business Component,Filling,Filling_sensitive_info,You can skip filling any field if you are not ...,Can I drop entering sensitive information?,-4.698543,-4.554119,-4.828384,-4.561625,-5.068871,...,-5.625214,-5.161419,-4.972987,-4.238535,-5.089555,-4.608479,-4.916294,-4.725637,-4.586147,54
209,About the team,Our Contact,Contact_method,You may contact us by email at xxxxxx@filleasy...,"So, what is your email address?",-5.617332,-4.065939,-5.067706,-3.835828,-3.918454,...,-5.554646,-5.756432,-5.21599,-5.930146,-5.115957,-5.255075,-3.930214,-4.863859,-4.24891,19
1114,Others,About Conversation,Greeting,"Hi there. Here is ""Matthew"", your personal ass...",The I,-5.024716,-4.531721,-4.849027,-5.359993,-5.041405,...,-4.019416,-5.326167,-4.740526,-5.19613,-5.686471,-4.490896,-5.141884,-4.935785,-3.793845,63
3021,About Business Component,Selection,Form_request_adding,"If you have a physical form, we encourage you ...",Can we upload a form that we cannot find on yo...,-5.247623,-4.216721,-4.782459,-4.971338,-5.50562,...,-5.396286,-5.282259,-3.975134,-4.259389,-5.723114,-4.968718,-4.96548,-4.259097,-4.457862,43
841,About Business Component,Overall,Service_Summary,We provide a 1-stop straight through processin...,What Do you guys?,-5.404927,-4.085706,-5.16707,-4.81214,-4.498517,...,-5.309199,-6.224485,-5.801836,-6.399615,-5.607286,-5.618229,-3.706496,-5.61678,-4.146214,74


In [7]:
# Persist metadata, logits and predicted class (without the pandas index) to the
# result path configured in chartbot_config.
test_df.to_csv(path_or_buf=result_file_path, index=False)