# Zero-Shot Event Classification

In [1]:
import json
import torch
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from sentence_transformers import SentenceTransformer, util

## Implementing Simple Zero-Shot Classifier

Our approach in a nutshell:
* We use a sentence encoder from `sentence-transformers` to convert both the label descriptions and the texts to be classified into embeddings that live in the same embedding space.
* At test time, we embed a new text and compare it to each label embedding via cosine similarity.
* We assign each text the label whose embedding has the highest similarity to it.
* Optionally, we define a minimum similarity threshold that a label needs to pass. If no label passes this threshold, we assign the "OTHER" class.


In [2]:
class ZeroShotClassifier:
    """Zero-shot classifier that assigns each text the label with the most similar embedding.

    Label descriptions and input texts are embedded with the same encoder
    (``model.encode``) and compared with cosine similarity. If a ``threshold``
    is set and the best similarity falls below it, ``null_label`` is returned
    instead of the best-matching label.

    Args:
        model: Encoder object exposing ``encode(list_of_texts)``; it is used for
            both the label descriptions and the input texts.
        threshold: Optional minimum similarity the top label must reach.
            ``None`` disables the check.
        null_label: Label returned when no label reaches ``threshold``.
    """

    def __init__(self, model=None, threshold=None, null_label="OTHER"):
        self.model = model
        self.labels = []
        self.label_embeddings = None
        self.threshold = threshold
        self.null_label = null_label

    def train(self, labels, descriptions):
        """Register the labels and embed their descriptions.

        Args:
            labels: Sequence of label names.
            descriptions: Sequence of description texts, aligned by position
                with ``labels``.

        Raises:
            ValueError: If no model was supplied, if ``labels`` is empty, or if
                ``labels`` and ``descriptions`` differ in length.
        """
        labels = list(labels)
        descriptions = list(descriptions)
        if self.model is None:
            raise ValueError("ZeroShotClassifier needs a model to embed the label descriptions")
        if not labels:
            raise ValueError("at least one label is required")
        if len(labels) != len(descriptions):
            raise ValueError(
                "labels and descriptions must have the same length "
                f"(got {len(labels)} and {len(descriptions)})"
            )

        self.labels = labels
        # Use the instance's own encoder rather than whatever global `model`
        # happens to exist, so labels and inputs share one embedding space.
        self.label_embeddings = self.model.encode(descriptions)

    def predict(self, input_texts=None, input_embeddings=None, output_scores=False):
        """Predict a label for every input text or precomputed embedding.

        Args:
            input_texts: Texts to classify; encoded with ``self.model`` when
                ``input_embeddings`` is not given.
            input_embeddings: Optional precomputed embeddings; when provided,
                ``input_texts`` is ignored.
            output_scores: If True, also return the ranked scores per input.

        Returns:
            A list of predicted labels, one per input. When ``output_scores``
            is True, a tuple ``(labels, scores)`` where ``scores[i]`` is the
            list of ``(label, similarity)`` pairs for input ``i``, sorted by
            descending similarity.

        Raises:
            RuntimeError: If ``train`` has not been called yet.
            ValueError: If neither ``input_texts`` nor ``input_embeddings`` is given.
        """
        if self.label_embeddings is None:
            raise RuntimeError("train() must be called before predict()")

        if input_embeddings is None:
            if input_texts is None:
                raise ValueError("provide either input_texts or input_embeddings")
            input_embeddings = self.model.encode(input_texts)

        # One row of similarities per input, one column per label.
        similarities = util.pytorch_cos_sim(input_embeddings, self.label_embeddings)

        predicted_labels = []
        predicted_scores = []
        # Iterate over the similarity rows themselves so the number of
        # predictions always matches the number of rows actually scored.
        for row in similarities:
            ranked = sorted(
                zip(self.labels, row.tolist()),
                key=lambda pair: pair[1],
                reverse=True,
            )
            best_label, best_score = ranked[0]
            if self.threshold is not None and best_score < self.threshold:
                best_label = self.null_label

            predicted_scores.append(ranked)
            predicted_labels.append(best_label)

        if output_scores:
            return predicted_labels, predicted_scores
        return predicted_labels

## Preparing Data

In [3]:
import pandas as pd

# Read the raw OCC manager log export and preview its first rows.
occ_df = pd.read_csv(filepath_or_buffer="occ-mgr-logs-2022-0124-0131.csv")
occ_df.head(n=20)

Unnamed: 0,_time,ENTRY
0,2022-01-28T02:06:00.000-0800,EOR.
1,2022-01-28T02:02:00.000-0800,M-Line Blanket established.
2,2022-01-28T01:52:00.000-0800,A-Line Blanket established.
3,2022-01-28T01:45:00.000-0800,"SA at M60 reports a sleeper on the platform, w..."
4,2022-01-28T01:38:00.000-0800,Core-Line Blanket established.
5,2022-01-28T01:33:00.000-0800,BPD requesting an extended dwell at W30-1 for ...
6,2022-01-28T01:31:00.000-0800,"(ref 0027) T323 released ATO, now (14) minutes..."
7,2022-01-28T01:25:00.000-0800,"T203 A50-1, no ATO doors, 311. Lead unit is C434."
8,2022-01-28T01:08:00.000-0800,Power Support reports when KMA D02 was tripped...
9,2022-01-28T01:03:00.000-0800,Final Trains for the evening released at K30.


In [4]:
import importlib.util

# Load the helper module straight from its file path (resolved relative to the
# current working directory) instead of via a regular package import.
# The order matters: build the spec, create the module object, then execute it.
spec = importlib.util.spec_from_file_location("classifier_utils", "../../classifier_utils.py")
classifier_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(classifier_utils)

# Expand acronyms in the ENTRY column using bart-acronyms.csv (the preview
# shows e.g. "BPD" becoming "BART Police Department"). This rebinds occ_df, so
# later cells see the expanded text.
# NOTE(review): expand_acronyms is defined outside this notebook — confirm
# whether re-running this cell on already-expanded text is harmless.
occ_df = classifier_utils.expand_acronyms(occ_df, "ENTRY", "bart-acronyms.csv")
occ_df.head(20)

Unnamed: 0,_time,ENTRY
0,2022-01-28T02:06:00.000-0800,EOR.
1,2022-01-28T02:02:00.000-0800,M-Line Blanket established.
2,2022-01-28T01:52:00.000-0800,A-Line Blanket established.
3,2022-01-28T01:45:00.000-0800,"SA at M60 reports a sleeper on the platform, w..."
4,2022-01-28T01:38:00.000-0800,Core-Line Blanket established.
5,2022-01-28T01:33:00.000-0800,BART Police Department requesting an extended...
6,2022-01-28T01:31:00.000-0800,(ref 0027) T323 released Automatic Train Opera...
7,2022-01-28T01:25:00.000-0800,"T203 A50-1, no Automatic Train Operation doors..."
8,2022-01-28T01:08:00.000-0800,Power Support reports when MacArthur Substatio...
9,2022-01-28T01:03:00.000-0800,Final Trains for the evening released at K30.


## Initializing Classifier

In [5]:
from sentence_transformers import SentenceTransformer, LoggingHandler
from sentence_transformers import models, util, datasets, evaluation, losses
from torch.utils.data import DataLoader

import nltk
# Fetch NLTK's 'punkt' data package; the logged output reports it as already
# up-to-date when it is present locally.
# NOTE(review): presumably needed by the denoising dataset used for training — confirm.
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Jason\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [6]:
# Build the sentence encoder: a Transformer backbone loaded from the pretrained
# MiniLM checkpoint, followed by a pooling layer in 'cls' mode.
model_name = 'sentence-transformers/paraphrase-MiniLM-L6-v2'
word_embedding_model = models.Transformer(model_name)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# Training sentences: at most the first 2000 values of occ_df's ENTRY column
# (the original guidance here was to use roughly 1k - 100k sentences).
train_sentences = occ_df['ENTRY'].tolist()[:2000]

# Create the special denoising dataset that adds noise on-the-fly
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)

# DataLoader to batch data: one sentence per batch, reshuffled by the loader.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# Use the denoising auto-encoder loss.
# NOTE: the library warning in this cell's output says decoder_name_or_path is
# "invalid" when tie_encoder_decoder=True, so that argument likely has no effect here.
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

# Call the fit method -- may take a while to train without GPU utilization.
# Settings: 3 epochs, constant learning-rate schedule at 3e-5, no weight decay.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=3,
    weight_decay=0,
    scheduler='constantlr',
    optimizer_params={'lr': 3e-5},
    show_progress_bar=True
)

# Save the model to disk; the classifier cell further down is passed this
# in-memory `model` object rather than reloading from 'output/tsdae-model'.
model.save('output/tsdae-model')
print("done")

When tie_encoder_decoder=True, the decoder_name_or_path will be invalid.
Some weights of BertLMHeadModel were not initialized from the model checkpoint at sentence-transformers/paraphrase-MiniLM-L6-v2 and are newly initialized: ['encoder.layer.4.crossattention.self.key.bias', 'encoder.layer.5.crossattention.output.dense.weight', 'encoder.layer.0.crossattention.output.dense.bias', 'encoder.layer.1.crossattention.self.key.weight', 'encoder.layer.0.crossattention.self.value.weight', 'encoder.layer.2.crossattention.output.dense.bias', 'encoder.layer.2.crossattention.self.key.weight', 'encoder.layer.2.crossattention.output.dense.weight', 'encoder.layer.4.crossattention.output.dense.weight', 'encoder.layer.4.crossattention.self.value.bias', 'encoder.layer.1.crossattention.output.LayerNorm.bias', 'encoder.layer.2.crossattention.self.query.bias', 'encoder.layer.5.crossattention.self.query.weight', 'encoder.layer.1.crossattention.self.value.bias', 'encoder.layer.1.crossattention.output.dense.bi

Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1330 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1330 [00:00<?, ?it/s]

Iteration:   0%|          | 0/1330 [00:00<?, ?it/s]

done


## Evaluating Zero-Shot Classifier

In [10]:
import pandas as pd

# Load the unlabeled log file, keep only the two columns of interest and the
# first 100 rows, then expand acronyms in the log text.
df = pd.read_csv("unlabeled.csv")
df = df[["Time", "Log"]].head(n=100)
df = classifier_utils.expand_acronyms(df, "Log", "bart-acronyms.csv")

df.head(n=5)

Unnamed: 0,Time,Log
0,1940,T451 no Automatic Train Operation doors at R30...
1,1957,T507 A10-1 BART Police Department hold for lo...
2,1957,T371 no Automatic Train Operation doors at M16...
3,2000,Medic10 and Medic16 checked out
4,2003,T507 released Automatic Train Operation. 2 min...


In [14]:
# Classifier over the fine-tuned encoder; predictions whose best similarity is
# below 0.7 fall back to the "OTHER" label.
my_classifier = ZeroShotClassifier(model=model, threshold=0.7, null_label="OTHER")

# Each label is paired by position with the short phrase embedded for it.
my_classifier.train(
    labels=["Medical", "Police", "Delays", "Mechanical", "Electrical"],
    descriptions=["medical", "police activity", "delays late", "no ATO doors", "electrical"],
)

# Classify every log line and attach the predictions as a new column.
pred = my_classifier.predict(input_texts=df["Log"].tolist())
df["Pred"] = pred
df.head(n=40)

Unnamed: 0,Time,Log,Pred
0,1940,T451 no Automatic Train Operation doors at R30...,Mechanical
1,1957,T507 A10-1 BART Police Department hold for lo...,Police
2,1957,T371 no Automatic Train Operation doors at M16...,Mechanical
3,2000,Medic10 and Medic16 checked out,Medical
4,2003,T507 released Automatic Train Operation. 2 min...,Delays
5,2004,T365 no Automatic Train Operation doors at M90...,Mechanical
6,2022,T369 Y10-2 double dashes.\r\nNo call from Central,Mechanical
7,2033,T223 no Automatic Train Operation doors at S20...,Mechanical
8,2052,T445 R10-1 possible medical emergency.TO to ch...,Mechanical
9,2055,A99 is at R10 and checking on the patron.,Mechanical
