In [1]:
import pandas as pd
from transformers import pipeline
from tqdm import tqdm
tqdm.pandas()

In [2]:
def zero_shot_predict_single_model(classifier, sequence_to_classify: str, candidate_labels: list) -> dict:
    """Score one text against a set of candidate labels with a zero-shot classifier.

    Args:
        classifier: Callable invoked as ``classifier(sequence_to_classify, candidate_labels)``.
            It must return a mapping with parallel ``'labels'`` and ``'scores'`` sequences,
            as the ``zero-shot-classification`` pipeline used in this notebook does.
        sequence_to_classify: The text to classify.
        candidate_labels: Labels to score the text against.

    Returns:
        A dict mapping each label to its score, in the order the classifier returned them.
        If anything raises during classification, the error is printed and an empty dict
        is returned, so one bad row does not abort a long batch.
    """
    try:
        predictions = classifier(sequence_to_classify, candidate_labels)
        return dict(zip(predictions['labels'], predictions['scores']))
    except Exception as e:
        # Best-effort by design: report the failure and continue with an empty result.
        print("The following error occurred, returned empty dict")
        print(e)
        return {}

In [3]:
# Load the combined gold-layer workbook; its BODY_SUMMARY column is split into sentences below.
gold_combined_path = '../temp_training/medallion/gold/gold_COMBINED.xlsx'
data_all = pd.read_excel(io=gold_combined_path)

In [9]:
# Split each article's BODY_SUMMARY into sentences: one row per sentence, with the other
# retained columns (URI, TOPIC, TITLE, RELEVANCE_CLASS) duplicated onto every sentence row.
data_sample = data_all.copy(deep=True)
retained_columns = ['URI', 'TOPIC', 'TITLE', 'TEXT', 'RELEVANCE_CLASS']
# TEXT is built from the body summary only; the title stays in its own TITLE column.
data_sample['TEXT'] = data_sample['BODY_SUMMARY'].apply(
    # Naive sentence split on '. '. Non-string values (e.g. NaN for a missing summary)
    # pass through instead of raising AttributeError on `.split`.
    lambda x: x.split('. ') if isinstance(x, str) else x
)
# Explode the list of sentences into separate rows, duplicating values in other columns
data_sample_long = data_sample.explode('TEXT').reset_index(drop=True)
# Drop rows without a usable sentence (missing summaries, or empty fragments such as the
# one left after a trailing '. ') so they are never sent to the classifier.
has_text = data_sample_long['TEXT'].apply(lambda s: isinstance(s, str) and s.strip() != '')
df_result = data_sample_long.loc[has_text, retained_columns].reset_index(drop=True)
df_result

Unnamed: 0,URI,TOPIC,TITLE,TEXT,RELEVANCE_CLASS
0,7501230662,warehouse_fire,"Fire destroys Sangre Grande block factory, house",Fire destroyed a block factory in Sangre Grand...,1
1,7501230662,warehouse_fire,"Fire destroys Sangre Grande block factory, house",Losses were said to be millions of dollars,1
2,7501230662,warehouse_fire,"Fire destroys Sangre Grande block factory, house",Owner of the Pallet and Brick Factory claimed ...,1
3,7501230662,warehouse_fire,"Fire destroys Sangre Grande block factory, house",He said by the time the fire officers arrived ...,1
4,7501230662,warehouse_fire,"Fire destroys Sangre Grande block factory, house","However, the officers contained the fire from ...",1
...,...,...,...,...,...
14498,7397746307,air,"Auckland Airport warns worst is yet to come, 4...",Airline is now focusing its efforts on getting...,1
14499,7397668836,protest_riot,Residents mark anniversary of 'Freedom Convoy'...,A small group of Ottawa residents marked a fla...,0
14500,7397668836,protest_riot,Residents mark anniversary of 'Freedom Convoy'...,"Dubbed the Battle of Billings Bridge, the coun...",0
14501,7397668836,protest_riot,Residents mark anniversary of 'Freedom Convoy'...,Citizens blocked the intersection for hours,0


In [10]:
# Persist the sentence-level table. The DataFrame index is written as the first (unnamed)
# column, which pandas reads back as "Unnamed: 0".
sentences_csv_path = 'gold_COMBINED_sentences_body_summary.csv'
df_result.to_csv(path_or_buf=sentences_csv_path)

In [10]:
# Draw a fixed-seed random sample of sentences so the (slow) zero-shot run is reproducible.
# The size is capped at the available row count, because sample(n) raises ValueError if n > len.
SAMPLE_SIZE = 100
RANDOM_SEED = 42
df_sample = df_result.sample(n=min(SAMPLE_SIZE, len(df_result)), random_state=RANDOM_SEED)

In [11]:
# 1. Load class specifications: CLASS_DESCRIPTION values are the zero-shot candidate labels.
print("Loading class specifications...")
class_labels_data_path = '../data/input/class_label_by_topic_v1.0.csv'
class_labels_data = pd.read_csv(class_labels_data_path)
candidate_labels = list(class_labels_data['CLASS_DESCRIPTION'])
# Lookup from label description to the rest of its row. to_dict(orient='index') raises
# ValueError if CLASS_DESCRIPTION contains duplicates.
class_labels_df = class_labels_data.set_index('CLASS_DESCRIPTION')
class_labels_dict = class_labels_df.to_dict(orient='index')

# 2. Load pretrained model
print("Loading pretrained model...")
model_path = '../models/pretrained/bart-large-mnli/'
loaded_classifier = pipeline("zero-shot-classification", model=model_path)

# 3. Predict on each sampled sentence (the TEXT column holds one body-summary sentence per row).
print("Predicting...")
# Apply over the TEXT column alone, which is cheaper than a row-wise apply(axis=1).
# `assign` returns a new frame rather than mutating the sampled one in place.
df_sample = df_sample.assign(
    PREDICTIONS=df_sample['TEXT'].progress_apply(
        lambda text: zero_shot_predict_single_model(
            classifier=loaded_classifier,
            sequence_to_classify=text,
            candidate_labels=candidate_labels,
        )
    )
)

Loading class specifications...
Loading pretrained model...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Predicting...


100%|██████████| 100/100 [14:47<00:00,  8.88s/it]


In [12]:
# Export the sampled sentences together with their PREDICTIONS column.
# NOTE(review): each PREDICTIONS cell is a dict; presumably it lands in Excel as its text repr — verify.
sample_excel_path = 'sample.xlsx'
df_sample.to_excel(excel_writer=sample_excel_path)