In [1]:
from sklearn.base import BaseEstimator, TransformerMixin
from transformers import pipeline
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from utils import get_reviews, results
import pandas as pd

class BertFeatures(BaseEstimator, TransformerMixin):
    """Scikit-learn transformer that turns raw text into token embeddings.

    The embeddings come from the default Hugging Face
    ``pipeline('feature-extraction')``, which is instantiated once in the
    constructor and stored on ``self.nlp``.

    Parameters
    ----------
    batch_size : int or None, default=None
        Number of sequences handed to the pipeline per call. ``None`` sends
        every sequence in a single call (the original behaviour), which can
        exhaust memory on large inputs. With an integer, the input is
        processed chunk by chunk and the per-chunk arrays are concatenated.
        NOTE(review): chunking assumes ``pad_to_max_length=True`` pads every
        chunk to the same token length so the chunks can be stacked --
        confirm against the installed transformers version.
    """

    def __init__(self, batch_size=None):
        self.batch_size = batch_size
        self.nlp = pipeline('feature-extraction')

    def fit(self, X, y=None):
        """No-op: nothing is learned from the data. Returns ``self``."""
        return self

    def _extract(self, sequences):
        """Run the pipeline on ``sequences`` and return the result as an array."""
        return np.array(self.nlp(sequences, pad_to_max_length=True))

    def transform(self, X: np.ndarray):
        """Extract features for every sequence in ``X``.

        Parameters
        ----------
        X : iterable of str
            The texts to embed; materialised into a list before use.

        Returns
        -------
        numpy.ndarray
            The pipeline output converted to an array. The downstream
            ``PooledOutput`` step indexes it as a 3-D array with the
            sequence on the first axis and the token on the second.

        Raises
        ------
        ValueError
            If ``batch_size`` is set but smaller than 1.
        """
        X = list(X)
        print(f'Extracts bert features for:{len(X)} sequences')
        if self.batch_size is None or not X:
            # Single call: identical to the original code path. Also used for
            # empty input, where there are no chunks to concatenate.
            features = self._extract(X)
        else:
            if self.batch_size < 1:
                raise ValueError(
                    f'batch_size must be a positive integer or None, '
                    f'got {self.batch_size!r}'
                )
            # Only one chunk's worth of nested Python lists (and model
            # activations) is alive at a time, which lowers peak memory.
            features = np.concatenate([
                self._extract(X[start:start + self.batch_size])
                for start in range(0, len(X), self.batch_size)
            ])
        print('Done extracting bert features')
        return features
    
class PooledOutput(BaseEstimator, TransformerMixin):
    """Stateless transformer that keeps only the first token's embedding.

    It reduces a per-token embedding array to a single vector per sequence
    by selecting index 0 along the second axis.
    """

    def fit(self, X, y=None):
        """No-op: there is nothing to learn. Returns ``self``."""
        return self

    def transform(self, X: np.array, y=None):
        """Return the embedding of the first token of every sequence.

        Per the original author's note, ``X`` has shape
        n_datapoints x (n_tokens + 2) x nlp_model_dimension, where the
        "+ 2" is because BERT adds a CLS token in front and a SEP token at
        the end. ``y`` is accepted but ignored.
        """
        first_token = 0
        return X[:, first_token, :]
    
def make_pooled_bert():
    """Build a fresh, unfitted text-classification pipeline.

    The pipeline chains three steps: ``BertFeatures`` (token embeddings),
    ``PooledOutput`` (first-token embedding per sequence) and an
    ``MLPClassifier`` with two hidden layers of 64 units trained with Adam.
    """
    steps = [
        ('bert_features', BertFeatures()),
        ('pooling', PooledOutput()),
        (
            'classifier',
            MLPClassifier(
                hidden_layer_sizes=(64, 64),
                batch_size=100,
                solver='adam',
                max_iter=400,
                verbose=True,
            ),
        ),
    ]
    return Pipeline(steps)


# Load the review data through the project helper in `utils`.
# NOTE(review): presumably a pandas DataFrame with a 'review' text column --
# the code further down calls `df.assign` and reads `df['review']`; confirm
# against `utils.get_reviews`.
df = get_reviews()


def first_n_words(series, n):
    """Truncate each string in ``series`` to its first ``n`` tokens.

    Tokens are delimited by single space characters, so consecutive spaces
    yield empty tokens that count towards ``n``. The kept tokens are joined
    back together with single spaces.

    Parameters
    ----------
    series : pandas.Series
        Series of strings to truncate.
    n : int
        Maximum number of space-separated tokens to keep per string.

    Returns
    -------
    pandas.Series
        Series of truncated strings with the same index as ``series``.
    """
    tokens = series.str.split(' ')
    # Slicing through the .str accessor is applied to each token list.
    leading_tokens = tokens.str[:n]
    return leading_tokens.str.join(' ')


# Cap every review at its first `n` words: running the feature extractor on
# the full-length sequences caused a memory overload.
n = 150
df = df.assign(review=first_n_words(df['review'], n))
# For a quick smoke test, subsample first: df = df.sample(100)

# Evaluate the pooled-BERT model at several training-set sizes and plot the
# reported values against the number of data points.
# NOTE(review): what `results` fits and measures is defined in `utils`;
# confirm there that the plotted metric matches the plot title.
(
    pd.DataFrame(
        results(
            df,
            make_model=make_pooled_bert,
            # Smoke-test alternative: n_data_points=[50, 100]
            n_data_points=[1000, 10000, 20000, len(df)],
            batch_size_inference=1000,
        )
    )
    .set_index('n_data_points')
    .plot(title='Number of data points and train set accuracy')
);

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=230.0, style=ProgressStyle(description_…


Fits for number of data points:50
Extracts bert features for:49 sequences


KeyboardInterrupt: 