In [1]:
import os
import pandas as pd
import numpy as np
import torch
from transformers import ZeroShotClassificationPipeline
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tabulate import tabulate


In [2]:
data = pd.concat([
    pd.read_csv('../data/atis/train.tsv', delimiter="\t", names=['text', 'label']),
    pd.read_csv('../data/atis/test.tsv', delimiter="\t", names=['text', 'label']),
])

data['label_cnt'] = data['label'].apply(lambda s: len(s.split('+')))
print(f"MAX classes per `text`: {data['label_cnt'].max()}")

data

MAX classes per `text`: 3


Unnamed: 0,text,label,label_cnt
0,i want to fly from boston at 838 am and arrive...,flight,1
1,what flights are available from pittsburgh to ...,flight,1
2,what is the arrival time in san francisco for ...,flight_time,1
3,cheapest airfare from tacoma to orlando,airfare,1
4,round trip fares from pittsburgh to philadelph...,airfare,1
...,...,...,...
845,please find all the flights from cincinnati to...,flight,1
846,find me a flight from cincinnati to any airpor...,flight,1
847,i 'd like to fly from miami to chicago on amer...,flight,1
848,i would like to book a round trip flight from ...,flight,1


In [3]:
print(f"Median `text` words: {int(np.median(data['text'].apply(lambda t: len(t.split()))))}")
print(f"Max `text` words: {np.max(data['text'].apply(lambda t: len(t.split())))}")
print(f"Max `text` length: {np.max(data['text'].apply(lambda t: len(t)))}")

Median `text` words: 11
Max `text` words: 46
Max `text` length: 259


Let's have a look into `label` values and distribution.

In [4]:
print(f"Label distribution:\n{data['label'].value_counts()}\n")


Label distribution:
label
flight                        4039
airfare                        451
ground_service                 271
airline                        176
abbreviation                   134
aircraft                        86
flight_time                     53
quantity                        52
capacity                        37
airport                         31
distance                        30
flight+airfare                  29
ground_fare                     24
city                            23
flight_no                       20
meal                            12
restriction                      5
day_name                         2
airline+flight_no                2
airfare+flight_time              1
cheapest                         1
aircraft+flight+flight_no        1
ground_service+ground_fare       1
airfare+flight                   1
flight+airline                   1
flight_no+airline                1
Name: count, dtype: int64



Having entries that look like the following: 'airline+flight_no', 'ground_service+ground_fare', 'aircraft+flight+flight_no' etc., where each of those seems to be build from multiple independent "base level" labels, allow us to conclude that we are dealing with a **multiclass classification** case.

Let's have a look at "base-level" labels in both datasets.

In [5]:
sorted(data['label'].str.split('+').explode().unique())

['abbreviation',
 'aircraft',
 'airfare',
 'airline',
 'airport',
 'capacity',
 'cheapest',
 'city',
 'day_name',
 'distance',
 'flight',
 'flight_no',
 'flight_time',
 'ground_fare',
 'ground_service',
 'meal',
 'quantity',
 'restriction']

## Choosing a model

It appears that trying out a pre-trained ` Zero-Shot Classification`, `Transformers` model from Hugging Face can be a good choice.

Using a pre-trained zero-shot model 'raw' without fine-tuning requires careful label preparation to avoid model confusion. To ensure accurate embeddings, we need labels with explicit, simple words, avoiding shortenings or unconventional modifications. Therefore, we replace 'flight_no' with the more explicit 'flight_number' to enhance clarity.

Also, to make label tokens more distinct, let's get rid of underscores in the labels too.

In [6]:
data['label'] = data['label'].str.replace('flight_no', 'flight_number')
data['label'] = data['label'].str.replace('_', ' ')

# let's have a look at the fupdated labels once again
sorted(data['label'].str.split('+').explode().unique())

['abbreviation',
 'aircraft',
 'airfare',
 'airline',
 'airport',
 'capacity',
 'cheapest',
 'city',
 'day name',
 'distance',
 'flight',
 'flight number',
 'flight time',
 'ground fare',
 'ground service',
 'meal',
 'quantity',
 'restriction']

Let's see how the text that's labeled as more "fancier" categories (that's not present in the `data` DF too) look like in the `test` DF.

In [7]:
for l in ['airline+flight number', 'ground service+ground fare', 'aircraft+flight+flight number']:
    print(f"For the '{l}':\n{data.loc[data['label'] == l].iloc[0]['text']}\n")

For the 'airline+flight number':
airline and flight number from columbus to minneapolis

For the 'ground service+ground fare':
what ground transportation is available from the pittsburgh airport to downtown and how much does it cost

For the 'aircraft+flight+flight number':
i want to fly from detroit to st. petersburg on northwest airlines and leave around 9 am tell me what aircraft are used by this flight and tell me the flight number



In [8]:
def determine_device():
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')

device = determine_device()

## Testing a model on a few "manual" inputs

In [9]:
LABELS = [
    'abbreviation',
    'aircraft',
    'airfare',
    'airline',
    'airport',
    'capacity',
    'cheapest',
    'city',
    'distance',
    'flight',
    'flight number',
    'flight time',
    'ground fare',
    'ground service',
    # 'flight_number',
    # 'flight_time',
    # 'ground_fare',
    # 'ground_service',
    'meal',
    'quantity',
    'restriction'
]
print(f'LABELS len: {len(LABELS)}')

model = 'typeform/distilbert-base-uncased-mnli'
# model = 'MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli'

# Download the model and tokenizer
model_name = "typeform/distilbert-base-uncased-mnli"
model_dir = "../models"
os.makedirs(model_dir, exist_ok=True)

model = AutoModelForSequenceClassification.from_pretrained(model_name, cache_dir=model_dir)
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_dir)

classifier_pipeline = ZeroShotClassificationPipeline(
    model=model,
    tokenizer=tokenizer,
    device=device
)

print(classifier_pipeline)

inference = classifier_pipeline(
    # "i want to fly from detroit to st. petersburg on northwest airlines and leave around 9 am tell me what aircraft are used by this flight and tell me the flight number",
    "what is the arrival time in san francisco for the 755 am flight leaving washington",
    # "what day of the week do flights from nashville to tacoma fly on?",
    candidate_labels=LABELS
)

print(f"'inference' = {inference}")

def top_k_classes(inference: dict, K: int = 3) -> list:
    label_score_pairs = list(zip(inference['labels'], inference['scores']))
    # no need to sort as 'labels' come sort by 'score' in descending order in the output
    return [pair[0] for pair in label_score_pairs[:K]]


print(top_k_classes(inference, 3))

# Let's see how the model is reacting to a new extra label and a text on a topic unrelated to ATIS data
lbls = LABELS.copy()
lbls.append('fashion')
inf = classifier_pipeline("i like my jeans", candidate_labels=lbls)
print(inf)
# the predicted top class 'fashion' looks like a great choice for our test premise ☝️

# let's check also the output to an input that's unrelated to our class set from a human POV
inf = classifier_pipeline("little puppies and toddlers are cute", candidate_labels=lbls)
print(inf)

LABELS len: 17


The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


<transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline object at 0x37a905e50>
'inference' = {'sequence': 'what is the arrival time in san francisco for the 755 am flight leaving washington', 'labels': ['flight time', 'flight number', 'flight', 'airfare', 'distance', 'airport', 'airline', 'capacity', 'abbreviation', 'aircraft', 'quantity', 'city', 'restriction', 'cheapest', 'ground fare', 'meal', 'ground service'], 'scores': [0.7422202229499817, 0.13428395986557007, 0.0673116073012352, 0.01873125694692135, 0.006247242446988821, 0.005457504186779261, 0.005328348372131586, 0.004026705399155617, 0.0036638150922954082, 0.0036268967669457197, 0.0027513313107192516, 0.001577554503455758, 0.0014490586472675204, 0.001412139623425901, 0.0009985595243051648, 0.000516661733854562, 0.0003971454279962927]}
['flight time', 'flight number', 'flight']
{'sequence': 'i like my jeans', 'labels': ['fashion', 'distance', 'restriction', 'quantity', 'capacity', 'airfare', 'abbreviatio

Without replacing underscored to spaces in labels we got:
```
['flight', 'aircraft', 'flight_number']
```
as the result for the 
```
"i want to fly from detroit to st. petersburg on northwest airlines and leave around 9 am tell me what aircraft are used by this flight and tell me the flight number"
```
whereas after the pre-processing (fixing wording, replacing underscores with spaces) we got
```
['flight number', 'flight', 'aircraft']
```
which appears to be a more accurate outcome (top class 'flight number').

# Testing

As we are not performing any training in this notebook and rather testing a pre-trained zero-shot model's ability to infer intent against our set of labels, we will simply apply the classifier pipline to the whole data set and then calculate and assess metrics.

## Inference application

In [10]:
# Let's run inference against the data set
# (omitting 'sequence' inference's key)
data['inference'] = data['text'].apply(lambda t: {k: v for k, v in classifier_pipeline(t, LABELS).items() if k != 'sequence'})
data.head(3)

Unnamed: 0,text,label,label_cnt,inference
0,i want to fly from boston at 838 am and arrive...,flight,1,"{'labels': ['flight number', 'flight', 'airlin..."
1,what flights are available from pittsburgh to ...,flight,1,"{'labels': ['flight', 'flight number', 'flight..."
2,what is the arrival time in san francisco for ...,flight time,1,"{'labels': ['flight time', 'flight number', 'f..."


## Result processing

Let's create and populate two new binary (0/1) columns for:
 * inference's top 3 classes hitting any of the `label` values and 
 * inference's top class hitting  any of the `label` values

In [11]:
data['top_3_classes'] = data['inference'].apply(lambda i: top_k_classes(inference=i, K=3))
data['top_class'] = data['inference'].apply(lambda i: top_k_classes(inference=i)[0])

# when 'top_3_classes' "hits" (has intersections) with any class in the `label` subset - count as a hit
data['top_3_hit'] = data.apply(lambda row: int(len(set(row['label'].split('+')) & set(row['top_3_classes'])) > 0), axis=1)
data['top_hit'] = data.apply(lambda row: int(row['top_class'] in row['label'].split('+')), axis=1)
data.head(3)

Unnamed: 0,text,label,label_cnt,inference,top_3_classes,top_class,top_3_hit,top_hit
0,i want to fly from boston at 838 am and arrive...,flight,1,"{'labels': ['flight number', 'flight', 'airlin...","[flight number, flight, airline]",flight number,1,0
1,what flights are available from pittsburgh to ...,flight,1,"{'labels': ['flight', 'flight number', 'flight...","[flight, flight number, flight time]",flight,1,1
2,what is the arrival time in san francisco for ...,flight time,1,"{'labels': ['flight time', 'flight number', 'f...","[flight time, flight number, flight]",flight time,1,1


## Model Performance Metrics

Not let's calculate metrics, namely **Precision** and **F1-score**. 

Note that we don't mention **Recall**. That's because technically speaking we don't have a strictly negative class unlike with binary classification and our labels simply can be farther or closer to the meaning and ontent of a message, but hardly 100% unrelated.

With that in mind and knowing that the formula of **Recall** is TP / (TP + FN) and our FN = 0.0 - the concept of **Recall** becomes trivial and we set it to **1.0**.

One could argue we could or even should have some sort of a top class threshold, below which the ineference is considered as "negative", but we'll leave it outside of the scope of the current exercize for now.

In [12]:
sample_size = data.shape[0]
top_3_hits = data['top_3_hit'].sum()
top_hits = data['top_hit'].sum()

# In our case, by our admission - there are no negative classes, therefore
recall = 1.0

# Calculate precision, recall, and F1 for top 3 classes
top_3_precision = top_3_hits / sample_size
top_3_f1 = 2 * (top_3_precision * recall) / (top_3_precision + recall)

# Calculate precision, recall, and F1 for top class
top_precision = top_hits / sample_size
top_f1 = 2 * (top_precision * recall) / (top_precision + recall)

headers = ["Metric", "Value"]
report = [
    ["Total sample size", sample_size],
    ["Number of classes", len(LABELS)],
    ["Top 3 classes inference Precision", f"{top_3_precision:.2f}"],
    ["Top 3 classes inference F1", f"{top_3_f1:.2f}"],
    ["Top class inference Precision", f"{top_precision:.2f}"],
    ["Top class inference F1", f"{top_f1:.2f}"],
]

print(tabulate(report, headers, tablefmt="grid", stralign="left", numalign="right"))


+-----------------------------------+---------+
| Metric                            |   Value |
| Total sample size                 |    5484 |
+-----------------------------------+---------+
| Number of classes                 |      17 |
+-----------------------------------+---------+
| Top 3 classes inference Precision |    0.85 |
+-----------------------------------+---------+
| Top 3 classes inference F1        |    0.92 |
+-----------------------------------+---------+
| Top class inference Precision     |    0.61 |
+-----------------------------------+---------+
| Top class inference F1            |    0.76 |
+-----------------------------------+---------+


### Metrics Interpretation

#### Top 3 Classes Inference
* The model achieves an **F1 score** of **0.92**, indicating excellent performance and high accuracy in identifying message intents.
* The **Precision** score of **0.85** is acceptable, with some errors in precision, but still correctly identifying a significant proportion of message intents.

#### Top Class Inference
* The **F1 score** of **0.76** indicates reasonable accuracy, with some potential errors in either precision or recall.
* The **Precision** score of **0.61** is low, with many predicted positive instances being False Positives.

Based on these metrics, the "typeform/distilbert-base-uncased-mnli" model appears to be a suitable choice for our project's first iteration, as it provides sufficient accuracy in identifying the top 3 intents for a message.

## Looking into the "weakest" classes

In [13]:
df_labels_exploded = data.assign(label=data['label'].str.split('+')).explode('label')
label_precision = df_labels_exploded.groupby('label')['top_3_hit'].agg(['count', 'sum']).reset_index()
label_precision = label_precision.rename(columns={'sum': 'hits', 'count': 'total entries'})
label_precision['precision'] = label_precision['hits'] / label_precision['total entries']
label_precision = label_precision.sort_values('precision')

label_precision[label_precision['precision'] < 0.8]

Unnamed: 0,label,total entries,hits,precision
8,day name,2,0,0.0
13,ground fare,25,3,0.12
16,quantity,52,9,0.173077
2,airfare,482,177,0.36722
14,ground service,272,183,0.672794
12,flight time,54,41,0.759259


Looking at the list of the least performing (lowest precisoin) classes, i have an intuition that one thing that can be done to improve the detection of those intents is rephrasing or "better" more explicit / accurate wording.

For example `day name` can be replaced with `day of the week` etc.

## Investigate the topic of FPs

Let's also have a look at "low-score" inferences. This might allow us to discover score-thresholds that never lead to a TP inference and can be accepted by us as the model being 'confused' and a case of FP.

In [14]:
report = []
headers = ["Score Threshold", "Score <= Threshold, Total Entries", "Top 3 Cls Precision"]

for ls_threshold in np.linspace(0.0, 0.25, 12):
    ls_inference = data[[x['scores'][0] <= ls_threshold for x in data['inference']]]
    ls_total = ls_inference.shape[0]
    hits_cnt = ls_inference['top_3_hit'].sum()
    precision = hits_cnt / ls_total if hits_cnt > 0 and ls_total > 0 else 0.0
    report.append([f'{ls_threshold:.2f}', ls_total, f'{precision:.2f}'])

print(tabulate(report, headers, tablefmt="grid"))

ls_inference.head(5)

+-------------------+-------------------------------------+-----------------------+
|   Score Threshold |   Score <= Threshold, Total Entries |   Top 3 Cls Precision |
|              0    |                                   0 |                  0    |
+-------------------+-------------------------------------+-----------------------+
|              0.02 |                                   0 |                  0    |
+-------------------+-------------------------------------+-----------------------+
|              0.05 |                                   0 |                  0    |
+-------------------+-------------------------------------+-----------------------+
|              0.07 |                                   0 |                  0    |
+-------------------+-------------------------------------+-----------------------+
|              0.09 |                                  15 |                  0.47 |
+-------------------+-------------------------------------+-----------------

Unnamed: 0,text,label,label_cnt,inference,top_3_classes,top_class,top_3_hit,top_hit
0,i want to fly from boston at 838 am and arrive...,flight,1,"{'labels': ['flight number', 'flight', 'airlin...","[flight number, flight, airline]",flight number,1,0
15,show me the first class fares from boston to d...,airfare,1,"{'labels': ['distance', 'abbreviation', 'fligh...","[distance, abbreviation, flight number]",distance,0,0
19,please give me the flights from boston to pitt...,flight,1,"{'labels': ['flight number', 'flight', 'distan...","[flight number, flight, distance]",flight number,1,0
20,i would like to fly from denver to pittsburgh ...,flight,1,"{'labels': ['airline', 'flight', 'aircraft', '...","[airline, flight, aircraft]",airline,1,0
24,i 'd like to have some information on a ticket...,airfare,1,"{'labels': ['distance', 'airport', 'airfare', ...","[distance, airport, airfare]",distance,1,0


As we can see, there are no entries with a Score less than 0.09 and even at this level we got Precision = 0.47 which is a not bad result considering that we were running inference against 17 classes.

With that, accepting Recall as 1.0 seems to be a good call.