In [None]:
# !pip install transformers==4.30.2
# !pip install datasets==2.12.0

In [54]:
from datasets import load_dataset, Dataset
import pandas as pd
import numpy as np
import string
from unidecode import unidecode
import tensorflow as tf 
from sklearn.utils import class_weight
from tensorflow.keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import cloudpickle

In [69]:
class TextPreprocessor:
    def __init__(self, remove_punct: bool = True, remove_digits: bool = True,
                 remove_stop_words: bool = True,
                 remove_short_words: bool = True, minlen: int = 1, maxlen: int = 1, top_p: float = None,
                 bottom_p: float = None):
        self.remove_punct = remove_punct
        self.remove_digits = remove_digits
        self.remove_stop_words = remove_stop_words
        self.remove_short_words = remove_short_words
        self.minlen = minlen
        self.maxlen = maxlen
        self.top_p = top_p
        self.bottom_p = bottom_p
        self.words_to_remove = []
        self.stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 'ours', 'ourselves', 'you',
                           'your', 'yours', 'yourself', 'yourselves', 'he', 'him', 'his', 'himself',
                           'she', 'her', 'hers', 'herself', 'it', 'its', 'itself', 'they', 'them',
                           'their', 'theirs', 'themselves', 'what', 'which', 'who', 'whom', 'this', 'that',
                           'these', 'those', 'am', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has',
                           'had', 'having', 'do', 'does', 'did', 'doing', 'a', 'an', 'the', 'and', 'if', 'or',
                           'because', 'as', 'until', 'while', 'of', 'at', 'by', 'for', 'with', 'about',
                           'into', 'through', 'during', 'before', 'after', 'to', 'from',
                           'in', 'out', 'on', 'off', 'further', 'then', 'once',
                           'here', 'there', 'when', 'where', 'why', 'how', 'all', 'any', 'both', 'each',
                           'other', 'such', 'only', 'own', 'same', 'so', 'than',
                           'too', 'can', 'will', 'just', 'should',
                           'now']

        

    @staticmethod
    def __remove_double_whitespaces(string: str):
        return " ".join(string.split())

    def __remove_punct(self, string_series: pd.Series):
        """
       Removes punctuations from the input string.
       :param string_series: pd.Series, input string series
       :return: pd.Series, cleaned string series
       """
        clean_string_series = string_series.copy()
        puncts = [r'\n', r'\r', r'\t']
        puncts.extend(list(string.punctuation))
        for i in puncts:
            clean_string_series = clean_string_series.str.replace(pat=i, repl=" ", regex=False).copy()
        return clean_string_series.map(self.__remove_double_whitespaces)

    def __remove_digits(self, string_series: pd.Series):
        """
       Removes digits from the input string.
       :param string_series: pd.Series, input string series
       :return: pd.Series, cleaned string series
       """
        clean_string_series = string_series.str.replace(pat=r'\d', repl=" ", regex=True).copy()
        return clean_string_series.map(self.__remove_double_whitespaces)
 

    def __remove_stop_words(self, string_series: pd.Series):
        """
       Removes stop words from the input string.
       :param string_series: pd.Series, input string series
       :return: pd.Series, cleaned string series
       """
        def str_remove_stop_words(string: str):
            stops = self.stop_words
            return " ".join([token for token in string.split() if token not in stops])

        return string_series.map(str_remove_stop_words)

    

    def preprocess(self, string_series: pd.Series, dataset: str = "train"):
        """
        Entry point.
        :param string_series: pd.Series, input string series
        :param dataset: str, "train" for training set, "tesrt" for val/dev/test set.
        :return: pd.Series, cleaned string series
        """
        string_series = string_series.str.lower().copy()
        if self.remove_punct:
            string_series = self.__remove_punct(string_series=string_series)
        if self.remove_digits:
            string_series = self.__remove_digits(string_series=string_series)
        if self.remove_stop_words:
            string_series = self.__remove_stop_words(string_series=string_series)
        

        string_series = string_series.str.strip().copy()
        string_series.replace(to_replace="", value="this is an empty message", inplace=True)

        return string_series

In [61]:
data = pd.read_csv("/kaggle/input/dbpedia-classes/DBP_wiki_data.csv")

In [62]:
data

Unnamed: 0,text,l1,l2,l3,wiki_name,word_count
0,The 1994 Mindoro earthquake occurred on Novemb...,Event,NaturalEvent,Earthquake,1994_Mindoro_earthquake,59
1,The 1917 Bali earthquake occurred at 06:50 loc...,Event,NaturalEvent,Earthquake,1917_Bali_earthquake,68
2,The 1941 Colima earthquake occurred on April 1...,Event,NaturalEvent,Earthquake,1941_Colima_earthquake,194
3,The 1983 Coalinga earthquake occurred on May 2...,Event,NaturalEvent,Earthquake,1983_Coalinga_earthquake,98
4,The 2013 Bushehr earthquake occurred with a mo...,Event,NaturalEvent,Earthquake,2013_Bushehr_earthquake,61
...,...,...,...,...,...,...
342776,WCSH is the NBC-affiliated television station ...,Agent,Broadcaster,TelevisionStation,WCSH,198
342777,Al Jazeera America (AJAM) was an American basi...,Agent,Broadcaster,TelevisionStation,Al_Jazeera_America,226
342778,"CJOH-DT, VHF channel 13, is a CTV owned-and-op...",Agent,Broadcaster,TelevisionStation,CJOH-DT,234
342779,NTTV (North Texas Television) is a student tel...,Agent,Broadcaster,TelevisionStation,NTTV,126


In [63]:
data = data[['text', 'l2']].copy()
data.rename(columns={"l2": "category"}, inplace=True)

In [65]:
data = data.sample(frac=0.4, random_state=42)

In [66]:
data

Unnamed: 0,text,category
48627,"\""Y Yo Sigo Aquí\"" (English: \""And I'm Still R...",MusicalWork
126091,The Contemporary Art Gallery (CAG) is a non-co...,Building
337974,BBC Radio Foyle (Irish: BBC Raidió Feabhail) i...,Broadcaster
87372,The Marsabit clawed frog (Xenopus borealis) is...,Animal
138102,Mount Fiske Glacier is a small glacier located...,NaturalPlace
...,...,...
123348,Sequoia Hospital is a hospital in Redwood City...,Building
74806,Ganolytes is an extinct genus of prehistoric b...,Animal
283346,Juliet Cariaga is an American glamour model wh...,Actor
250753,Stefan Holtz (born 27 February 1981) is a Germ...,Athlete


In [67]:
data['text'][0]

'The 1994 Mindoro earthquake occurred on November 15 at 03:15 local time near Mindoro, the Philippines. It had a moment magnitude of 7.1. It is associated with a 35 kilometer-long ground rupture, called the Aglubang River fault. Seventy eight people were reported dead, and 7,566 houses were damaged. The earthquake generated a tsunami and landslides on the Verde Island.'

In [70]:
tp = TextPreprocessor()
data['text'] = tp.preprocess(data['text'])

In [71]:
data['text'][0]

'mindoro earthquake occurred november local time near mindoro philippines moment magnitude associated kilometer long ground rupture called aglubang river fault seventy eight people reported dead houses damaged earthquake generated tsunami landslides verde island'

In [72]:
label_dist = pd.Series(data['category']).value_counts()
label_dist

Athlete             17705
Person              11146
Animal               8531
Building             6016
Politician           5371
                    ...  
MusicalArtist         126
RaceTrack             104
ComicsCharacter        82
Database               80
VolleyballPlayer       64
Name: category, Length: 70, dtype: int64

In [73]:
# finding categories with proportion > 0.8%
ld = pd.Series(data['category']).value_counts(normalize=True)
cats_to_keep = [*ld[ld > 0.008].index]

In [74]:
# categories with proportion > 0.8%
cats_to_keep

['Athlete',
 'Person',
 'Animal',
 'Building',
 'Politician',
 'Company',
 'Organisation',
 'MusicalWork',
 'WinterSportPlayer',
 'SocietalEvent',
 'RouteOfTransportation',
 'PeriodicalLiterature',
 'SportsTeam',
 'NaturalPlace',
 'Artist',
 'EducationalInstitution',
 'Broadcaster',
 'Cleric',
 'Tournament',
 'SportsEvent',
 'SportsTeamSeason',
 'Settlement',
 'Infrastructure',
 'Plant',
 'CelestialBody',
 'SportFacility',
 'SportsLeague',
 'Stream',
 'Race',
 'Comic',
 'FictionalCharacter',
 'ClericalAdministrativeRegion',
 'FootballLeagueSeason',
 'Cartoon']

In [75]:
# renaming categories < 0.8% to others
data.loc[data['category'].isin(cats_to_keep) == False, 'category'] = "Others"

In [76]:
data

Unnamed: 0,text,category
48627,y yo sigo aquí english m still right song mexi...,MusicalWork
126091,contemporary art gallery cag non collecting pu...,Building
337974,bbc radio foyle irish bbc raidió feabhail bbc ...,Broadcaster
87372,marsabit clawed frog xenopus borealis species ...,Animal
138102,mount fiske glacier small glacier located sier...,NaturalPlace
...,...,...
123348,sequoia hospital hospital redwood city califor...,Building
74806,ganolytes extinct genus prehistoric bony fish ...,Animal
283346,juliet cariaga american glamour model selected...,Others
250753,stefan holtz born february german sprint canoe...,Athlete


In [78]:
label_dist = pd.Series(data['category']).value_counts()
label_dist

Others                          20058
Athlete                         17705
Person                          11146
Animal                           8531
Building                         6016
Politician                       5371
Company                          4722
Organisation                     4117
MusicalWork                      4029
WinterSportPlayer                3634
SocietalEvent                    3439
RouteOfTransportation            3412
PeriodicalLiterature             3198
SportsTeam                       3109
NaturalPlace                     3028
Artist                           2792
EducationalInstitution           2592
Broadcaster                      2568
Cleric                           2546
Tournament                       2335
SportsEvent                      2266
SportsTeamSeason                 2253
Settlement                       2193
Infrastructure                   2162
Plant                            1627
CelestialBody                    1386
SportFacilit

In [79]:
le = LabelEncoder()
data['target'] = le.fit_transform(data['category'])

In [20]:
with open("/kaggle/working/preprocessor_labelencoder.bin", "wb") as model_file_obj:
    cloudpickle.dump((tp, le), model_file_obj)

In [80]:
data

Unnamed: 0,text,category,target
48627,y yo sigo aquí english m still right song mexi...,MusicalWork,15
126091,contemporary art gallery cag non collecting pu...,Building,4
337974,bbc radio foyle irish bbc raidió feabhail bbc ...,Broadcaster,3
87372,marsabit clawed frog xenopus borealis species ...,Animal,0
138102,mount fiske glacier small glacier located sier...,NaturalPlace,16
...,...,...,...
123348,sequoia hospital hospital redwood city califor...,Building,4
74806,ganolytes extinct genus prehistoric bony fish ...,Animal,0
283346,juliet cariaga american glamour model selected...,Others,18
250753,stefan holtz born february german sprint canoe...,Athlete,2


In [81]:
x = data['text'].copy()
y = data['target'].copy()

In [82]:
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1, random_state=42, stratify=y)

In [83]:
x_train.shape, x_test.shape, y_train.shape, y_test.shape

((123400,), (13712,), (123400,), (13712,))

In [84]:
x_train, x_test, y_train, y_test = x_train.to_list(), x_test.to_list(), y_train.to_list(), y_test.to_list()

In [85]:
classes_ = sorted([*y.unique()]).copy()
classes_

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34]

In [86]:
from transformers import DistilBertTokenizerFast

In [87]:
model_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

In [88]:
print(x_train[0])
print(tokenizer.tokenize(x_train[0]))
print(tokenizer(x_train[0]))

hiromiyuki cd main belt asteroid discovered february m arai h mori yorii photometric observations asteroid gave light curve period ± hours brightness variation magnitude
['hi', '##rom', '##iy', '##uki', 'cd', 'main', 'belt', 'asteroid', 'discovered', 'february', 'm', 'ara', '##i', 'h', 'mori', 'yo', '##ri', '##i', 'photo', '##metric', 'observations', 'asteroid', 'gave', 'light', 'curve', 'period', '±', 'hours', 'brightness', 'variation', 'magnitude']
{'input_ids': [101, 7632, 21716, 28008, 14228, 3729, 2364, 5583, 12175, 3603, 2337, 1049, 19027, 2072, 1044, 22993, 10930, 3089, 2072, 6302, 12589, 9420, 12175, 2435, 2422, 7774, 2558, 1081, 2847, 18295, 8386, 10194, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [35]:
# train_token_lens = [len(tokenizer.tokenize(i)) for i in x_train]

In [36]:
# pd.Series(train_token_lens).quantile(0.98)

In [89]:
strategy = tf.distribute.MirroredStrategy()

In [90]:
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
N_TOKENS = 200
N_CLASSES = len(classes_)

In [91]:
train_tokens = tokenizer(x_train, max_length=N_TOKENS, padding="max_length", truncation=True, return_tensors="tf", return_attention_mask=True)
test_tokens = tokenizer(x_test, max_length=N_TOKENS, padding="max_length", truncation=True, return_tensors="tf", return_attention_mask=True)

In [40]:
train_tokens[:5]

[Encoding(num_tokens=200, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=200, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=200, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=200, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing]),
 Encoding(num_tokens=200, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])]

In [92]:
train_tf_data = tf.data.Dataset.from_tensor_slices((dict(train_tokens), to_categorical(y_train)))
test_tf_data = tf.data.Dataset.from_tensor_slices((dict(test_tokens), to_categorical(y_test)))

In [43]:
del(data)
del(train_tokens)
del(test_tokens)

In [44]:
# Using class_weight to handle class imbalance
class_weight_param = {class_: weight for class_, weight in
                          zip(classes_, class_weight.compute_class_weight(class_weight='balanced',
                                                                          classes=classes_, y=y_train))}
class_weight_param

{0: 0.45919696349495776,
 1: 1.4029901654255017,
 2: 0.22126988111674945,
 3: 1.5256227977993448,
 4: 0.6512217003535806,
 5: 3.5434314429289304,
 6: 2.827357085576813,
 7: 1.5389411984785184,
 8: 3.4330226735290026,
 9: 3.2675757976962796,
 10: 0.8295798319327731,
 11: 1.5112362990631314,
 12: 3.307424283034039,
 13: 3.501205844800681,
 14: 1.8117750697401263,
 15: 0.9723426049956663,
 16: 1.2938401048492791,
 17: 0.9516097937150568,
 18: 0.19530879047830077,
 19: 1.2250570832919687,
 20: 0.3514818348832904,
 21: 2.4082747853239654,
 22: 0.7293575270406052,
 23: 3.213960151061336,
 24: 1.1480671721635578,
 25: 1.7860761325806918,
 26: 1.139164551119317,
 27: 2.9234778488509834,
 28: 1.7291389336509493,
 29: 2.930768317301983,
 30: 1.2600837332788726,
 31: 1.7385178923640463,
 32: 3.1763191763191765,
 33: 1.6773141226043224,
 34: 1.077870463379482}

In [45]:
train_tf_data=train_tf_data.prefetch(tf.data.AUTOTUNE)
test_tf_data=test_tf_data.prefetch(tf.data.AUTOTUNE)

In [46]:
for i in train_tf_data.take(1):
    print(i)

({'input_ids': <tf.Tensor: shape=(200,), dtype=int32, numpy=
array([  101,  7632, 21716, 28008, 14228,  3729,  2364,  5583, 12175,
        3603,  2337,  1049, 19027,  2072,  1044, 22993, 10930,  3089,
        2072,  6302, 12589,  9420, 12175,  2435,  2422,  7774,  2558,
        1081,  2847, 18295,  8386, 10194,   102,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,  

In [47]:
from transformers import TFDistilBertModel, DistilBertConfig
from tensorflow.keras.layers import Input, Dense, Dropout, Average, BatchNormalization

In [48]:
config = DistilBertConfig.from_pretrained(model_checkpoint, output_hidden_states=True)

In [49]:
with strategy.scope():
    model = TFDistilBertModel.from_pretrained(model_checkpoint, config=config)
    input_ids = Input(shape=(N_TOKENS,), dtype=tf.int32, name="input_ids")
    attention_mask = Input(shape=(N_TOKENS,), dtype=tf.int32, name="attention_mask")
    x = model([input_ids, attention_mask])
    x = Average()([x[1][0][:,0,:], x[1][1][:,0,:],x[1][2][:,0,:],x[1][3][:,0,:]]) # Finding mean of [CLS] token of last 4 hidden layers
    x = Dropout(0.3)(x)
    x = BatchNormalization()(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.3)(x)
    x = BatchNormalization()(x)
    x = Dense(128, activation="relu")(x)
    x = Dropout(0.3)(x)
    x = BatchNormalization()(x)
    output = Dense(N_CLASSES, activation="softmax", name="output")(x)
    model = tf.keras.Model(inputs=[input_ids, attention_mask],outputs=output)
    metric = tf.keras.metrics.AUC()
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), metrics=[metric, "categorical_accuracy"], loss="categorical_crossentropy")

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFDistilBertModel: ['vocab_projector.bias', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight']
- This IS expected if you are initializing TFDistilBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFDistilBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.


In [50]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 200)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 200)]        0           []                               
                                                                                                  
 tf_distil_bert_model (TFDistil  TFBaseModelOutput(l  66362880   ['input_ids[0][0]',              
 BertModel)                     ast_hidden_state=(N               'attention_mask[0][0]']         
                                one, 200, 768),                                                   
                                 hidden_states=((No                                           

In [51]:
model.layers[2].trainable = False # Freezing the weights of DistilBERT

In [52]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_ids (InputLayer)         [(None, 200)]        0           []                               
                                                                                                  
 attention_mask (InputLayer)    [(None, 200)]        0           []                               
                                                                                                  
 tf_distil_bert_model (TFDistil  TFBaseModelOutput(l  66362880   ['input_ids[0][0]',              
 BertModel)                     ast_hidden_state=(N               'attention_mask[0][0]']         
                                one, 200, 768),                                                   
                                 hidden_states=((No                                           

In [53]:
from tensorflow.keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor="val_loss",patience=1,mode="min")

In [55]:
model.fit(train_tf_data.shuffle(len(train_tf_data)).batch(BATCH_SIZE), validation_data=test_tf_data.shuffle(len(test_tf_data)).batch(BATCH_SIZE), 
          epochs=2, callbacks=[early_stop], class_weight=class_weight_param)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7d02002cfcd0>

In [56]:
del(train_tf_data)
del(test_tf_data)

In [57]:
model.save("/kaggle/working/dbpedia_classifier_hf_distilbert.h5")

In [58]:
del(model)

**INFERENCE**

In [10]:
import transformers
import tensorflow as tf
import cloudpickle
from transformers import DistilBertTokenizerFast

with open("/kaggle/input/dbpedia-model/preprocessor_labelencoder.bin", "rb") as model_file_obj:
    text_preprocessor, label_encoder = cloudpickle.load(model_file_obj)
model = tf.keras.models.load_model('/kaggle/input/dbpedia-model/dbpedia_classifier_hf_distilbert.h5', custom_objects={"TFDistilBertModel": transformers.TFDistilBertModel})
model_checkpoint = "distilbert-base-uncased"
tokenizer = DistilBertTokenizerFast.from_pretrained(model_checkpoint)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt: 0.00B [00:00, ?B/s]

Downloading (…)/main/tokenizer.json: 0.00B [00:00, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [234]:
import numpy as np
import pandas as pd
def inference(text: str):
    text = text_preprocessor.preprocess(pd.Series(text))[0]
    input_=tf.data.Dataset.from_tensor_slices((dict(tokenizer([text], 
                                                     max_length=200, padding="max_length", 
                                                     truncation=True, return_tensors="tf"))))
    pred = model.predict(input_, verbose=0)
    arg_max = np.argmax(pred[0])
    return [label_encoder.inverse_transform([arg_max])[0],  pred[0][arg_max]]

In [235]:
txt = '''
Shri Narendra Modi was sworn-in as India’s Prime Minister on 30th May 2019, marking the start of his second term in office. The first ever Prime Minister to be born after Independence, Shri Modi has previously served as the Prime Minister of India from 2014 to 2019. He also has the distinction of being the longest serving Chief Minister of Gujarat with his term spanning from October 2001 to May 2014. In the 2014 and 2019 Parliamentary elections, Shri Modi led the Bharatiya Janata Party to record wins, securing absolute majority on both occasions. The last time that a political party 
'''
inference(txt)

['Politician', 0.9767795]

**CONVERT TO ONNX**

In [None]:
!pip install tf2onnx
!pip install onnx
!pip install onnxruntime

In [124]:
import tensorflow as tf
import tf2onnx
import onnx

In [196]:
onnx_model, _ = tf2onnx.convert.from_keras(model)
onnx.save(onnx_model, "/kaggle/working/dbpedia_classifier_hf_distilbert.onnx")

**INFERENCE WITH ONNX**

In [198]:
import onnxruntime

In [230]:
def onnx_inference(text):
    session = onnxruntime.InferenceSession("/kaggle/working/dbpedia_classifier_hf_distilbert.onnx", None)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    text = text_preprocessor.preprocess(pd.Series(text))[0]
    input_ = tokenizer([text], max_length=200, padding="max_length", 
                        truncation=True, return_tensors="tf")
    inputs = {key: value.numpy() for key, value in input_.items()}
    result = session.run(None, inputs)[0][0]
    arg_max = np.argmax(result)  
    return [label_encoder.inverse_transform([arg_max])[0], result[arg_max]]

In [231]:
txt = '''
Shri Narendra Modi was sworn-in as India’s Prime Minister on 30th May 2019, marking the start of his second term in office. The first ever Prime Minister to be born after Independence, Shri Modi has previously served as the Prime Minister of India from 2014 to 2019. He also has the distinction of being the longest serving Chief Minister of Gujarat with his term spanning from October 2001 to May 2014. In the 2014 and 2019 Parliamentary elections, Shri Modi led the Bharatiya Janata Party to record wins, securing absolute majority on both occasions. The last time that a political party 
'''

onnx_inference(txt)

['Politician', 0.9767795]

In [236]:
def compare_tf_vs_onnx_prediction(text):
    print(f"TF: {inference(text)}\tONNX: {onnx_inference(text)}")

In [237]:
text = '''
Shri Narendra Modi was sworn-in as India’s Prime Minister on 30th May 2019, marking the start of his second term in office. The first ever Prime Minister to be born after Independence, Shri Modi has previously served as the Prime Minister of India from 2014 to 2019. He also has the distinction of being the longest serving Chief Minister of Gujarat with his term spanning from October 2001 to May 2014. In the 2014 and 2019 Parliamentary elections, Shri Modi led the Bharatiya Janata Party to record wins, securing absolute majority on both occasions. The last time that a political party 
'''
compare_tf_vs_onnx_prediction(text)

TF: ['Politician', 0.9767795]	ONNX: ['Politician', 0.9767795]


In [238]:
text = '''
Kalvakuntla Chandrashekar Rao, popularly known as KCR, was born in Chintamadaka, Medak Dist, Telangana, India to Sri Raghava Rao and Smt. Venkatamma on February 17, 1954 . After leading the Telangana movement to its desired end, he became the first Chief Minister of the State. He is also the founder President of Telangana Rashtra Samithi (TRS) which was in the forefront of the movement for Statehood to Telangana. He has been a Member of the Legislative Assembly (MLA) for multiple terms from different constituencies. He was also a Member of Parliament and has also held the position of a Cabinet Minister for Labour and Employment.
'''
compare_tf_vs_onnx_prediction(text)

TF: ['Politician', 0.6683515]	ONNX: ['Politician', 0.66835153]


In [242]:
text = '''
Tata Motors Group (Tata Motors) is a $37 billion organisation. It is a leading global automobile manufacturing company. Its diverse portfolio includes an extensive range of cars, sports utility vehicles, trucks, buses and defence vehicles. Tata Motors is one of India's largest OEMs offering an extensive range of integrated, smart and e-mobility solutions
'''
compare_tf_vs_onnx_prediction(text)


TF: ['Company', 0.98643446]	ONNX: ['Company', 0.98643434]


In [246]:
text = '''
Switzerland, federated country of central Europe. Switzerland’s administrative capital is Bern, while Lausanne serves as its judicial centre. Switzerland’s small size—its total area is about half that of Scotland—and its modest population give little indication of its international significance.
'''
compare_tf_vs_onnx_prediction(text)


TF: ['Settlement', 0.9996996]	ONNX: ['Settlement', 0.9996996]


In [248]:
text = '''
IPL, Indian professional Twenty20 (T20) cricket league established in 2008. The league, which is based on a round-robin group and knockout format, has teams in major Indian cities.

The brainchild of the Board of Control for Cricket in India (BCCI), the IPL has developed into the most lucrative and most popular outlet for the game of cricket. Matches generally begin in late afternoon or evening so that at least a portion of them are played under floodlights at night to maximize the television audience for worldwide broadcasts. Initially, league matches were played on a home-and-away basis between all teams, but, with the planned expansion to 10 clubs (divided into two groups of five) in 2011, that format changed so that matches between some teams would be limited to a single encounter. The top four teams contest three play-off matches, with one losing team being given a second chance to reach the final, a wrinkle aimed at maximizing potential television revenue. The play-off portion of the tournament involves the four teams that finished at the top of the tables in a series of knockout games that allows one team that lost its first-round game a second chance to advance to the final match.
'''
compare_tf_vs_onnx_prediction(text)


TF: ['SportsLeague', 0.9999522]	ONNX: ['SportsLeague', 0.9999522]


In [249]:
text = '''
baseball, game played with a bat, a ball, and gloves between two teams of nine players each on a field with four white bases laid out in a diamond (i.e., a square oriented so that its diagonal line is vertical). Teams alternate positions as batters (offense) and fielders (defense), exchanging places when three members of the batting team are “put out.” As batters, players try to hit the ball out of the reach of the fielding team and make a complete circuit around the bases for a “run.” The team that scores the most runs in nine innings (times at bat) wins the game.
'''
compare_tf_vs_onnx_prediction(text)


TF: ['SportsEvent', 0.40093556]	ONNX: ['SportsEvent', 0.4009356]


In [250]:
text = '''
Malla Reddy Engineering College (Autonomous) – MREC, is one of the top notch and highly reputed engineering colleges in Hyderabad, Telangana. MREC is part of Malla Reddy Group of Institutions (MRGI), founded by Sri. Ch. Malla Reddy, currently Hon’ble Minister, Labor and Employment, Factories and Skill Development, Govt. of Telangana State, who has invaluable insights into technical education of highest quality. The college is situated in a serene, lush green environment on Kompally- Bahadurpally Road, opposite Forest Academy, Medchal-Malkajgiri District, Telangana State and adjacent to Urban Forest.
'''
compare_tf_vs_onnx_prediction(text)


TF: ['EducationalInstitution', 0.9998272]	ONNX: ['EducationalInstitution', 0.9998273]
