In [21]:
import pickle
from sklearn.base import BaseEstimator, TransformerMixin

# Load the pre-trained models from pickled files.
# NOTE(security): pickle.load can execute arbitrary code while deserializing,
# so only load pickle files produced by this project, never untrusted ones.
MODEL_DIR = '../../pickled_models'  # single place to change if the models move

# First-stage model: its 0/1 prediction decides whether the second stage runs.
with open(MODEL_DIR + '/15_nmb_pipe_model.pkl', 'rb') as f:
    binary_model = pickle.load(f)

# Second-stage model: assigns the specific toxicity labels.
with open(MODEL_DIR + '/15_binary_relevance_mnb.pkl', 'rb') as f:
    multilabel_model = pickle.load(f)
       
# Custom two-stage estimator.
# SequentialModel inherits from scikit-learn's BaseEstimator and TransformerMixin:
# BaseEstimator supplies get_params/set_params, so the object behaves like any other estimator.
# TransformerMixin supplies fit_transform, which chains the class's own fit and transform.

class SequentialModel(BaseEstimator, TransformerMixin):
    """Two-stage toxicity classifier: a binary gate followed by a multilabel model.

    Text is first scored by ``binary_model``. Only inputs it labels ``1``
    (toxic) are passed on to ``multilabel_model``, which assigns the specific
    toxicity labels. Both models are used for prediction only; this class
    never trains them.

    Parameters
    ----------
    binary_model : object with a ``predict`` method
        First-stage model; a prediction of ``1`` sends the input on to the
        second stage.
    multilabel_model : object with a ``predict`` method
        Second-stage model; expected to return one indicator row per input
        with one column per entry of ``TOXICITY_TYPES``.
    """

    # Label names looked up by column position in the multilabel prediction.
    # NOTE(review): the order is assumed to match the columns the multilabel
    # model was trained on; confirm against the training notebook.
    TOXICITY_TYPES = ['toxic', 'severe_toxic', 'obscene',
                      'threat', 'insult', 'identity_hate']

    def __init__(self, binary_model, multilabel_model):
        # Stored under the same names as the constructor arguments so that
        # BaseEstimator.get_params / set_params can find them.
        self.binary_model = binary_model
        self.multilabel_model = multilabel_model

    def fit(self, X, y=None):
        """Do nothing and return ``self``; both wrapped models are pre-trained.

        Present so that ``fit_transform`` (inherited from ``TransformerMixin``,
        which calls ``fit`` and then ``transform``) works.
        """
        return self

    @staticmethod
    def _select_rows(X, positions):
        """Return the rows of ``X`` at the given integer positions, in order."""
        if hasattr(X, 'iloc'):
            # pandas: plain [] with a list of ints is label-based, so use the
            # positional indexer to match the enumerate() positions.
            return X.iloc[positions]
        if isinstance(X, (list, tuple)):
            # Python sequences do not support indexing with a list.
            return [X[i] for i in positions]
        # numpy arrays and similar containers support list ("fancy") indexing.
        return X[positions]

    def transform(self, X, y=None):
        """Return multilabel predictions for the rows the binary model flags.

        Parameters
        ----------
        X : list, tuple, numpy array, or pandas Series/DataFrame
            Inputs in whatever form both wrapped models accept.
        y : ignored
            Accepted only for scikit-learn API compatibility.

        Returns
        -------
        The output of ``multilabel_model.predict`` for the flagged rows, in
        input order. Rows the binary model does not label ``1`` are dropped,
        so the result can have fewer rows than ``X``.

        Notes
        -----
        If no row is flagged, the multilabel model is called with an empty
        selection; whether it accepts that depends on the model.
        """
        binary_predictions = self.binary_model.predict(X)

        # Positions of the inputs the first stage considers toxic.
        flagged_positions = [i for i, pred in enumerate(binary_predictions) if pred == 1]
        flagged_rows = self._select_rows(X, flagged_positions)

        return self.multilabel_model.predict(flagged_rows)

    def predict(self, text):
        """Classify a single string and return its labels as a list of names.

        Parameters
        ----------
        text : str
            One comment to classify.

        Returns
        -------
        list of str
            ``['neutral']`` when the binary model does not predict ``1``.
            Otherwise the entries of ``TOXICITY_TYPES`` whose column in the
            multilabel prediction equals ``1``; this list is empty if the
            multilabel model sets no label.
        """
        # Both models are called with a batch, so wrap the single string.
        batch = [text]

        if self.binary_model.predict(batch)[0] != 1:
            return ['neutral']

        multilabel_prediction = self.multilabel_model.predict(batch)
        # Some multilabel models return a sparse matrix, others a dense array;
        # densify only when needed.
        if hasattr(multilabel_prediction, 'toarray'):
            multilabel_prediction = multilabel_prediction.toarray()
        # One input row, so flattening yields that row's label indicators.
        predicted_labels = multilabel_prediction.flatten()

        # Map each positive column position to its label name.
        return [self.TOXICITY_TYPES[i]
                for i, label in enumerate(predicted_labels) if label == 1]

  


In [22]:
# Wire the two loaded models into the two-stage classifier.
sequential_model = SequentialModel(
    binary_model=binary_model,
    multilabel_model=multilabel_model,
)

In [23]:
# Two sample inputs for the classifier: one benign, one abusive.
comment1, comment2 = (
    'Today is a great day for a presentation',
    'Today super sucks and I hate it! You stupid idiot',
)

In [25]:
# Run the abusive sample through both stages; the cell output lists the predicted labels.
sequential_model.predict(text=comment2)

['toxic', 'obscene', 'insult']

In [None]:
# Benign sample: predict returns ['neutral'] whenever the binary model does not flag the text.
sequential_model.predict(comment1)

Sources:

https://www.andrewvillazon.com/custom-scikit-learn-transformers/

https://www.youtube.com/watch?v=DctmeFx8s_k