<a href="https://colab.research.google.com/github/marble999/bert-distillation/blob/master/Transformer2CNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Fetch and unpack the GloVe 6B word vectors. `-nc` (wget) and `-n` (unzip)
# skip work when the files are already present, so the cell is safe to re-run.
!wget -nc http://nlp.stanford.edu/data/glove.6B.zip
!unzip -n glove.6B.zip

# Transformer to CNN: Label-scarce distillation for efficient text classification
Source code for paper submission to NeurIPS Workshop on Compact Deep Neural Networks with Industry Applications

In [0]:
# Mount Google Drive at ./drive; the PATH configured in the next cell points
# into this mount.
from google.colab import drive
drive.mount('drive')

In [0]:
# Notebook configuration (Colab form fields): which dataset to run and the
# Drive folder that holds the dataset archives and receives the artefacts.
TASK = "ag_news" #@param ["ag_news", "dbpedia", "yahoo_answers"]

from pathlib import Path
PATH = "drive/My Drive/deep/compact_demo" #@param {type:"string"}
PATH = Path(PATH)  
#@markdown Path to google drive datasets folder (you can add [this folder](https://drive.google.com/drive/folders/1272ZQbiUr-U8lrKy5EEdvYoVxYT7_cAq?usp=sharing) to your drive)
print('TASK: {}, PATH: {}'.format(TASK, PATH))

### Utilities & Models

In [0]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score


def shell_format(path_string):
    """Escape spaces in a path so it can be interpolated into a ``!`` shell line.

    Args:
        path_string: A ``str`` or path-like object; it is converted with ``str``.

    Returns:
        The path as a string with every space preceded by a backslash.
        Only spaces are escaped; other shell metacharacters are left untouched.
    """
    # Spell the replacement as an explicit backslash + space. The previous
    # literal relied on an invalid escape sequence, which newer Python
    # versions warn about.
    return str(path_string).replace(' ', '\\ ')


def split_data(data, labelled_limit=100, unlabelled_limit=1000):
    """Carve stratified labelled, validation and unlabelled subsets out of ``data``.

    Args:
        data: Object exposing ``data`` (inputs) and ``target`` (class labels).
        labelled_limit: Labelled examples per class; also sets the validation
            size, which equals the labelled size.
        unlabelled_limit: Examples per class kept for the unlabelled pool.

    Returns:
        ``[x_trn, x_val, x_tra, y_trn, y_val, y_tra]`` as numpy arrays:
        labelled train, validation and unlabelled-pool inputs, followed by the
        matching labels.
    """
    n_classes = len(set(data.target))
    labelled_size = labelled_limit * n_classes
    unlabelled_size = unlabelled_limit * n_classes
    print('Labelled size: {}, Unlabelled_size: {}'.format(
        labelled_size, unlabelled_size))

    # Stratified draw of the labelled set; everything else is "the rest".
    x_trn, x_rest, y_trn, y_rest = train_test_split(
        data.data,
        data.target,
        train_size=labelled_size,
        stratify=data.target,
        random_state=42)

    def _carve(rest):
        # First `labelled_size` held-out items form the validation set; the
        # following `unlabelled_size` items form the unlabelled pool.
        return rest[:labelled_size], rest[labelled_size:][:unlabelled_size]

    x_val, x_tra = _carve(x_rest)
    y_val, y_tra = _carve(y_rest)

    parts = (x_trn, x_val, x_tra, y_trn, y_val, y_tra)
    return [np.asarray(part) for part in parts]

In [0]:
import matplotlib.pyplot as plt
from IPython.display import clear_output

class LossPlot():
    """Keeps a running list of values and redraws them as a line plot."""

    def __init__(self):
        # Recorded values, in insertion order.
        self.history = list()

    def update(self, value):
        """Record one more value at the end of the history."""
        self.history += [value]

    def plot(self):
        """Clear the current cell output and plot the history recorded so far."""
        clear_output(wait=True)
        plt.plot(self.history)
        plt.show()

In [0]:
# Install the pinned finetune release (the Classifier subclassed below comes
# from it) and download the spaCy English model.
!pip install -q finetune==0.4.1
!python -m spacy download en

import warnings
import numpy as np
from tqdm import tqdm
from finetune import Classifier, Regressor
from finetune.utils import iter_data
from sklearn.metrics import f1_score
from sklearn.metrics.pairwise import cosine_similarity


def process_proba(proba):
    """Turn per-sample probability mappings into a 2-D numpy array.

    Args:
        proba: Non-empty sequence in which every item can be indexed by the
            integers ``0 .. n_classes - 1`` (e.g. a dict keyed by class id).

    Returns:
        Array of shape ``(n_samples, n_classes)`` with entry ``[i, j]`` taken
        from ``proba[i][j]``.
    """
    n_samples = len(proba)
    n_classes = len(proba[0])
    print('n_classes: {}, n_samples: {}'.format(n_classes, n_samples))

    rows = []
    for sample in proba:
        # Read the classes in index order so columns line up across samples.
        rows.append([sample[label] for label in range(n_classes)])
    return np.array(rows)


class OpenAIClassifier(Classifier):
    """finetune ``Classifier`` that returns raw logits and validates while training.

    ``fit`` is replaced by a training loop that periodically scores a held-out
    set with weighted F1, plots the score history and autosaves the model
    whenever the score improves.
    """

    def _predict_proba_op(self, logits, **kwargs):
        # Pass the logits through unchanged instead of normalising them.
        # NOTE(review): presumably this is what predict_proba then reports
        # (the notebook stores that output as distillation logits) -- confirm
        # against the finetune implementation.
        return logits

    def scorer(self, target, pred):
        """Weighted F1 between true labels ``target`` and predictions ``pred``."""
        return f1_score(target, pred, average='weighted')

    def fit(self, x_trn, x_val, y_trn, y_val, val_freq=5):
        """Train on ``(x_trn, y_trn)`` while tracking F1 on ``(x_val, y_val)``.

        Args:
            x_trn: Training texts.
            x_val: Validation texts.
            y_trn: Training labels.
            y_val: Validation labels.
            val_freq: Divisor applied to ``len(x_trn) / batch_size`` to obtain
                the number of steps between two validation passes.

        Returns:
            The best validation score seen (0 if no score ever exceeded 0).
        """
        lossplot = LossPlot()
        # Clamp to at least 1: for small training sets the truncated quotient
        # is 0, and `global_step % 0` below would raise ZeroDivisionError.
        self.config.val_interval = max(1, int(
            len(x_trn) / self.config.batch_size / val_freq))
        print('N train: {}, Val interval: {}, Val frequency: {}'.format(
            len(x_trn), self.config.val_interval, val_freq))

        arr_encoded_trn = self._text_to_ids(x_trn)
        arr_encoded_val = self._text_to_ids(x_val)

        batch_size = self.config.batch_size
        n_batch_train = batch_size * max(len(self.config.visible_gpus), 1)
        n_examples = len(x_trn)
        n_updates_total = (n_examples // n_batch_train) * self.config.n_epochs
        self.label_encoder = self._target_encoder()

        # The label encoder is fitted on the training labels only and then
        # reused for the validation labels.
        ds_trn = (arr_encoded_trn.token_ids, arr_encoded_trn.mask,
                  self.label_encoder.fit_transform(np.asarray(y_trn)))
        ds_val = (arr_encoded_val.token_ids, arr_encoded_val.mask,
                  self.label_encoder.transform(np.asarray(y_val)))

        target_dim = self.label_encoder.target_dim
        self._build_model(
            n_updates_total=n_updates_total, target_dim=target_dim)
        self.is_trained = True

        prev_best = global_step = 0

        for i in range(self.config.n_epochs):
            for xtrn, mtrn, ytrn in iter_data(
                    *ds_trn,
                    tqdm_desc="Training Epoch {}".format(i),
                    n_batch=n_batch_train,
                    verbose=self.config.verbose):
                # One optimisation step with dropout enabled.
                self._eval(
                    self.target_loss,
                    self.train_op,
                    feed_dict={
                        self.X: xtrn,
                        self.M: mtrn,
                        self.Y: ytrn,
                        self.do_dropout: int(True)
                    })

                global_step += 1
                # NOTE(review): global_step has already been incremented, so
                # the second condition fires one step before n_updates_total
                # -- confirm that is the intended "final" validation.
                if (global_step % self.config.val_interval == 0
                        or global_step == n_updates_total - 1):

                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore")
                        preds = []
                        for xval, mval, yval in iter_data(
                                *ds_val,
                                tqdm_desc="Validation",
                                n_batch=n_batch_train):
                            # Inference pass with dropout disabled.
                            output = self._eval(
                                self.predict_op,
                                feed_dict={
                                    self.X: xval,
                                    self.M: mval,
                                    self.do_dropout: int(False)
                                })
                            preds.append(
                                self.label_encoder.inverse_transform(
                                    output.get(self.predict_op)))

                    preds = np.concatenate(preds).tolist()
                    score = self.scorer(y_val, preds)
                    # Keep only the best-scoring model on disk.
                    if score > prev_best:
                        prev_best = score
                        if self.config.autosave_path:
                            self.save(self.config.autosave_path)
                    lossplot.update(score)
                    lossplot.plot()
        return prev_best

In [0]:
# Quiet download/extract of the GloVe 6B vectors; both commands skip files
# that already exist, so this is a no-op when an earlier cell fetched them.
!wget -q -nc http://nlp.stanford.edu/data/glove.6B.zip
!unzip -qn glove.6B.zip

def glove_matrix(tok, emb_dim=100, n_special=1):
    """Build an embedding matrix for the tokenizer vocabulary from GloVe 6B.

    Args:
        tok: Object with a ``word_index`` mapping from word to integer row id.
        emb_dim: Vector size; selects the file ``glove.6B.<emb_dim>d.txt`` in
            the current working directory.
        n_special: Extra rows added on top of ``len(tok.word_index)``.

    Returns:
        Float array of shape ``(len(tok.word_index) + n_special, emb_dim)``.
        Rows of words that do not occur in the GloVe file stay all-zero.
    """
    # The GloVe vocabulary contains non-ASCII tokens; decode explicitly as
    # UTF-8 so the result does not depend on the platform default encoding.
    with open('glove.6B.{}d.txt'.format(emb_dim), encoding='utf-8') as file:
        emb_matrix = np.zeros((len(tok.word_index) + n_special, emb_dim))
        for line in file:
            values = line.split()
            if not values:
                # Tolerate blank lines instead of failing on values[0].
                continue
            token_id = tok.word_index.get(values[0])
            if token_id is not None:
                # Convert explicitly rather than relying on implicit
                # string-to-float casting during assignment.
                emb_matrix[token_id] = np.asarray(values[1:], dtype=float)
    return emb_matrix

In [0]:
from keras import (layers, optimizers, models, losses, 
                   metrics, callbacks, preprocessing)

def model_input_with_optional_embedding(
    maxlen, vocab_size, emb_dim, use_embedding):
    """Create the model input and the tensor the first layer should consume.

    Args:
        maxlen: Sequence length of each example.
        vocab_size: Vocabulary size; the embedding gets ``vocab_size + 1`` rows.
        emb_dim: Width of each embedding vector.
        use_embedding: If true, the input has shape ``(maxlen,)`` and is passed
            through an ``Embedding`` layer. If false, the input has shape
            ``(maxlen, emb_dim)`` and is handed on unchanged.

    Returns:
        ``(inp, z)``: the Keras input tensor and the tensor to build on.
    """
    if use_embedding:
        inp = layers.Input(shape=(maxlen, ))
        z = layers.Embedding(vocab_size + 1, emb_dim)(inp)
        return inp, z

    # Pre-embedded input: build only the layer that is actually used.
    # Previously an Input/Embedding pair was created first and then discarded.
    inp = layers.Input(shape=(maxlen, emb_dim, ))
    return inp, inp

def blendcnn(maxlen,
             vocab_size,
             output_dim,
             emb_dim=100,
             n_layers=8,
             n_filters=100,
             kernel_size=5,
             use_embedding=True):
    """Build a CNN that blends globally max-pooled features from every conv layer.

    Each of the ``n_layers`` stages applies a same-padded ReLU ``Conv1D``,
    records a global max-pool of its output, then downsamples with
    ``MaxPool1D``. The recorded pools are concatenated and passed through two
    768-unit ReLU layers and a linear ``output_dim`` layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inp, features = model_input_with_optional_embedding(
        maxlen, vocab_size, emb_dim, use_embedding)

    pooled_per_layer = []
    for _ in range(n_layers):
        conv = layers.Conv1D(
            n_filters, kernel_size, padding='same', activation='relu')
        features = conv(features)
        # Summarise this depth before the sequence is downsampled.
        pooled_per_layer.append(layers.GlobalMaxPool1D()(features))
        features = layers.MaxPool1D()(features)

    hidden = layers.Concatenate()(pooled_per_layer)
    for _ in range(2):
        hidden = layers.Dense(768, activation='relu')(hidden)

    outputs = layers.Dense(output_dim)(hidden)
    return models.Model(inp, outputs)


def stackcnn(maxlen,
             vocab_size,
             output_dim,
             emb_dim=100,
             n_layers=8,
             n_filters=100,
             kernel_size=5,
             use_embedding=True):
    """Build a plain stack of ``n_layers`` conv + max-pool stages.

    The final feature map is globally max-pooled and passed through two
    100-unit ReLU layers and a linear ``output_dim`` layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inp, features = model_input_with_optional_embedding(
        maxlen, vocab_size, emb_dim, use_embedding)

    for _ in range(n_layers):
        conv = layers.Conv1D(
            n_filters, kernel_size, padding='same', activation='relu')
        features = layers.MaxPool1D()(conv(features))

    hidden = layers.GlobalMaxPool1D()(features)
    for _ in range(2):
        hidden = layers.Dense(100, activation='relu')(hidden)

    outputs = layers.Dense(output_dim)(hidden)
    return models.Model(inp, outputs)


def kimcnn(maxlen,
           vocab_size,
           output_dim,
           emb_dim=100,
           n_filters=100,
           kernel_sizes=[3, 4, 5],
           n_layers=1,
           use_embedding=True):
    """Build a CNN with one parallel conv branch per kernel size.

    Every branch is a same-padded ReLU ``Conv1D`` followed by a global
    max-pool; the branches are concatenated and fed to a linear ``output_dim``
    layer. ``n_layers`` is accepted so the signature matches the other
    builders, but it is not used here.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inp, embedded = model_input_with_optional_embedding(
        maxlen, vocab_size, emb_dim, use_embedding)

    def _branch(width):
        # One convolution of the given width, reduced to a fixed-size vector.
        conv = layers.Conv1D(
            n_filters, kernel_size=width, padding='same',
            activation='relu')(embedded)
        return layers.GlobalMaxPool1D()(conv)

    branches = [_branch(width) for width in kernel_sizes]
    merged = layers.Concatenate()(branches)
    outputs = layers.Dense(output_dim)(merged)
    return models.Model(inp, outputs)


def bilstm(maxlen,
           vocab_size,
           output_dim,
           emb_dim=100,
           n_filters=100,
           n_layers=2,
           use_embedding=True):
    """Build a stack of bidirectional ``CuDNNLSTM`` layers with a linear head.

    ``n_filters`` is the number of units per LSTM direction. All layers except
    the last return full sequences; the last one returns only its final output,
    which is fed to a linear ``output_dim`` layer.

    Returns:
        An uncompiled Keras ``Model``.
    """
    inp, features = model_input_with_optional_embedding(
        maxlen, vocab_size, emb_dim, use_embedding)

    # Intermediate layers must emit sequences for the next LSTM to consume.
    for _ in range(n_layers - 1):
        recurrent = layers.CuDNNLSTM(n_filters, return_sequences=True)
        features = layers.Bidirectional(recurrent)(features)

    last_recurrent = layers.CuDNNLSTM(n_filters)
    summary = layers.Bidirectional(last_recurrent)(features)
    outputs = layers.Dense(output_dim)(summary)
    return models.Model(inp, outputs)


class KerasSaver(callbacks.Callback):
    """Keras callback that validates, tracks the score and saves the best model.

    Args:
        val_data: ``(x, y)`` pair used for validation.
        emb_labels: Optional label-embedding matrix. When given, predictions
            are mapped to the label with the highest cosine similarity;
            otherwise the arg-max of the prediction is used.
        score_threshold: Per-batch validation starts only once the best score
            exceeds this value.
    """

    def __init__(self, val_data, emb_labels=None, score_threshold=0):
        super().__init__()
        self.val_data = val_data
        self.prev_best = 0
        self.score_threshold = score_threshold
        self.lossplot = LossPlot()
        self.emb_labels = emb_labels
        # `if emb_labels:` raises ValueError for a multi-element numpy array
        # (ambiguous truth value), so test for presence explicitly. An empty
        # container still selects the logits scorer, as before.
        has_emb_labels = emb_labels is not None and len(emb_labels) > 0
        self.scorer = (
            self.emb_labels_scorer if has_emb_labels else self.logits_scorer)

    def logits_scorer(self, target, pred):
        """Weighted F1 using the arg-max over axis 1 of ``pred`` as the class."""
        return f1_score(target, np.argmax(pred, axis=1), average='weighted')

    def emb_labels_scorer(self, target, pred):
        """Weighted F1 using the most cosine-similar row of ``emb_labels``."""
        nearest_labels = np.argmax(
            cosine_similarity(pred, self.emb_labels), axis=1)
        return f1_score(target, nearest_labels, average='weighted')

    def validate_and_plot(self):
        """Score the validation data; on a new best, save the model and plot."""
        x, y = self.val_data
        score = self.scorer(y, self.model.predict(x))
        self.lossplot.update(score)
        # Saving, plotting and printing happen only when the score improves.
        if score > self.prev_best:
            models.save_model(self.model, 'model')
            self.prev_best = score
            self.lossplot.plot()
            print('Best score: {}'.format(self.prev_best))

    def on_epoch_end(self, epoch, logs=None):
        self.validate_and_plot()

    def on_batch_end(self, batch, logs=None):
        # Validate after every batch once the best score has passed the
        # threshold; until then only the end-of-epoch validation runs.
        if self.prev_best > self.score_threshold:
            self.validate_and_plot()

### Loading data

In [0]:
# Unpack the task's CSV archive from the Drive folder into the working
# directory. The path is space-escaped because it is interpolated into a
# shell command.
path = shell_format(PATH/'{}_csv.tar.gz'.format(TASK))
!tar -xzf {path}

In [0]:
import pandas as pd
# Load the training CSV; it has no header row, so columns are numbered 0..N.
data = pd.read_csv('{}_csv/train.csv'.format(TASK), header=None)
# Drop rows with any missing field (in place).
data.dropna(inplace=True)
# Join the text columns into a single string per example.
data['data'] = data[1] + ' ' + data[2]
if TASK == 'yahoo_answers': 
    # This task has one more text column to append.
    data.data = data.data + ' ' + data[3]
data['target'] = data[0] - 1  # Convert to zero indexing
# Human-readable class names per task.
# NOTE(review): `classes` is not referenced anywhere else in the visible notebook.
classes = {
    'ag_news': ['world', 'sports', 'business', 'science technology'],
    'dbpedia': ['company', 'educational institution', 'artist', 'athlete',
                'office holder', 'mean of transportation', 'building', 
                'natural place', 'village', 'animal', 'plant', 'album', 'film', 
                'written work'],
    'yahoo_answers': ['society culture', 'science mathematics', 'health',
                      'education reference', 'computers internet', 'sports',
                      'business finance', 'entertainment music', 
                      'family relationships', 'politics government']
}
# Uses split_data's default per-class limits (100 labelled, 1000 unlabelled).
x_trn, x_val, x_tra, y_trn, y_val, y_tra = split_data(data)

### Train simple classifiers as baseline

In [0]:
# fastText
# Download the v0.1.0 sources, build the binary inside the source directory,
# then switch back to the previous working directory.
!wget -q -nc https://github.com/facebookresearch/fastText/archive/v0.1.0.zip
!unzip -qn v0.1.0.zip
%cd fastText-0.1.0
!make -s
%cd -

from string import ascii_letters
def format_x(string):
    """Lower-case ``string`` after dropping all but ASCII letters and spaces."""
    allowed = ascii_letters + ' '
    kept = [char for char in string if char in allowed]
    return ''.join(kept).lower()

def format_example(x, y):
    """Render one fastText line of the form ``__label__<x> <y>``.

    Note the argument order: ``x`` fills the label slot and ``y`` the text slot.
    """
    template = '__label__{} {}'
    return template.format(x, y)

def write_fasttext_data(x, y, filename):
    """Write examples to ``filename`` in fastText's ``__label__`` format.

    Args:
        x: Indexable collection of raw texts; each is cleaned with ``format_x``.
        y: Indexable collection of labels, aligned with ``x`` by position.
        filename: Destination path; an existing file is overwritten.

    One example is written per line, with no trailing newline.
    """
    lines = [format_example(y[i], format_x(x[i])) for i in range(len(x))]
    # Use a context manager so the file is flushed and closed before the
    # fastText binary reads it, without relying on garbage collection.
    with open(filename, 'w') as out_file:
        out_file.write('\n'.join(lines))
        
# Dump the labelled and validation splits to the files 'trn' and 'val'.
write_fasttext_data(x_trn, y_trn, 'trn')
write_fasttext_data(x_val, y_val, 'val')

# Train a supervised fastText model on 'trn', then evaluate it on 'val'.
ft = "fastText-0.1.0/fasttext"
!{ft} supervised -input trn -output model -epoch 2048
!{ft} test model.bin val

In [0]:
# SVM
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import *
from sklearn.svm import *
from sklearn.pipeline import make_pipeline

# TF-IDF features into a linear SVM, fitted on the labelled split.
# pipe = make_pipeline(TfidfVectorizer(), RidgeClassifier())
pipe = make_pipeline(TfidfVectorizer(), LinearSVC())
pipe.fit(x_trn, y_trn)
# The cell's displayed value: weighted F1 on the validation split.
f1_score(y_val, pipe.predict(x_val), average='weighted')

### Fine-tune OpenAI Classifier

In [0]:
# Fine-tune the teacher. Whenever the validation score improves, fit() saves
# the model to the autosave path ('model'). The displayed value is the best score.
model = OpenAIClassifier(autosave_path='model', lm_loss_coef=0.5)
model.fit(x_trn, x_val, y_trn, y_val)

In [0]:
# Copy the autosaved teacher ('model') to the Drive folder under a
# task-specific name.
path = shell_format(PATH/'{}_openai_classifier'.format(TASK))
!cp -a model {path}

### Setup for distillation

In [0]:
# Generate distillation logits
# The teacher is run on the labelled texts followed by the unlabelled pool, so
# the first len(x_trn) rows of the saved array belong to the labelled set.
model = OpenAIClassifier.load(PATH/'{}_openai_classifier'.format(TASK))
distill_logits = process_proba(model.predict_proba(np.hstack([x_trn, x_tra])))
np.save(PATH/'{}_y_combined_logits_trn.npy'.format(TASK), distill_logits)

In [0]:
# Fit tokenizer on dataset
import pickle
from keras.preprocessing.text import Tokenizer
# The vocabulary is built from the labelled training texts only.
tok = Tokenizer()
tok.fit_on_texts(x_trn)
# Persist the tokenizer through a context manager so the file is flushed and
# closed deterministically instead of leaking the handle.
with open(PATH/'{}_tok.pkl'.format(TASK), 'wb') as tok_file:
    pickle.dump(tok, tok_file, protocol=pickle.HIGHEST_PROTOCOL)

### Train student model

In [0]:
# Student training options (Colab form fields); echoed below for the run log.
DISTILL_MODE = "logits" #@param ["none", "logits"]  
#@markdown 'none' -> original labels, logits -> normal distillation
STUDENT_MODEL = "8stackcnn" #@param ["3blendcnn", "8blendcnn", "1kimcnn", "2bilstm", "3stackcnn", "8stackcnn"]
#@markdown Integer prefix denotes number of layers
USE_PSEUDO_LOGITS = True #@param ["True", "False"] {type:"raw"}
#@markdown 'True' -> Labelled + Unlabelled data, 'False' -> Only Labelled data

print('Distill mode:', DISTILL_MODE)
print('Student model:', STUDENT_MODEL)
print('Use pseudo logits:', USE_PSEUDO_LOGITS)

In [0]:
import pickle
from keras.utils import to_categorical
from keras.preprocessing.sequence import pad_sequences

def choose_x_trn(x_trn, x_tra, distill_mode, use_pseudo_logits):
    """Pick the student's training inputs.

    Returns ``x_trn`` followed by ``x_tra`` when pseudo logits are enabled and
    ``distill_mode`` is not ``'none'``; otherwise returns ``x_trn`` unchanged.
    """
    # Without distillation targets the unlabelled pool cannot be used.
    if not use_pseudo_logits or distill_mode == 'none':
        return x_trn
    return np.hstack([x_trn, x_tra])

def choose_y_trn(y_trn, task, distill_mode, use_pseudo_logits):
    """Pick the student's training targets.

    With ``distill_mode == 'none'`` the targets are the one-hot encoded
    ``y_trn``. Otherwise they are loaded from
    ``PATH/<task>_y_combined_<distill_mode>_trn.npy``: the whole array when
    ``use_pseudo_logits`` is true, else only its first ``len(y_trn)`` rows.
    """
    if distill_mode == 'none':
        return to_categorical(y_trn)

    targets_file = PATH/'{}_y_combined_{}_trn.npy'.format(task, distill_mode)
    pseudo_logits = np.load(targets_file)
    # The leading rows are the ones kept for the labelled-only case.
    return pseudo_logits if use_pseudo_logits else pseudo_logits[:len(y_trn)]
    
# Every text is padded/truncated to this many tokens.
MAXLEN = 1000
# NOTE: pickle.load can execute arbitrary code from the file; only load a
# tokenizer pickle that this notebook wrote itself.
# The context manager closes the file handle instead of leaking it.
with open(PATH/'{}_tok.pkl'.format(TASK), 'rb') as tok_file:
    tok = pickle.load(tok_file)

# [inputs, targets] for training; which texts and targets are used depends on
# DISTILL_MODE and USE_PSEUDO_LOGITS.
data_trn = [pad_sequences(tok.texts_to_sequences(
    choose_x_trn(x_trn, x_tra, DISTILL_MODE, USE_PSEUDO_LOGITS)), MAXLEN), 
    choose_y_trn(y_trn, TASK, DISTILL_MODE, USE_PSEUDO_LOGITS)]
# [inputs, integer labels] for validation.
data_val = [pad_sequences(tok.texts_to_sequences(x_val), MAXLEN), y_val]

In [0]:
import warnings
# STUDENT_MODEL is "<digit><name>": the name selects the builder and the
# leading character is parsed as n_layers (so only single-digit depths work).
architecture = {
    'blendcnn':blendcnn, 
    'kimcnn':kimcnn, 
    'bilstm':bilstm,
    'stackcnn':stackcnn
}[STUDENT_MODEL[1:]]

model = architecture(
    maxlen=MAXLEN, vocab_size=len(tok.word_index), 
    output_dim=data_trn[1].shape[1], n_layers=int(STUDENT_MODEL[0])
)
# Train with mean absolute error between the student's outputs and the targets.
model.compile(optimizers.Adam(), losses.mean_absolute_error)
# layers[0] is the Input, so layers[1] is the Embedding layer created by
# model_input_with_optional_embedding; initialise it with GloVe vectors.
model.layers[1].set_weights([glove_matrix(tok)])
    
# KerasSaver validates on data_val and saves the best-scoring model to 'model'.
saver = KerasSaver(data_val)
with warnings.catch_warnings():
    # F-score undefined and slow method on_batch_end warnings
    warnings.simplefilter('ignore')  
    model.fit(*data_trn, epochs=10, batch_size=64, 
              verbose=2, callbacks=[saver])

### Convert student model to TFLite

In [0]:
# Fetch the keras_to_tensorflow script that a later cell runs.
!git clone -q https://github.com/amir-abdi/keras_to_tensorflow.git

In [0]:
# Rebuild the selected architecture with fixed dimensions for conversion.
# NOTE(review): this is a freshly initialised model, and `model.save('model')`
# writes to the same path KerasSaver uses for the best trained student --
# confirm that overwriting that checkpoint is intended.
architecture = {
    'blendcnn':blendcnn, 
    'kimcnn':kimcnn, 
    'bilstm':bilstm,
    'stackcnn':stackcnn
}[STUDENT_MODEL[1:]]

MAXLEN = 1000
model = architecture(
    # Standard vocab_size and output_dim for consistency across models
    maxlen=MAXLEN, vocab_size=20000, 
    output_dim=10, n_layers=int(STUDENT_MODEL[0])
)
model.summary()
model.save('model')

In [0]:
# Run the keras_to_tensorflow script on the saved Keras model; it writes
# <saved_model_name>.pb, which the TFLite converter cell loads as a frozen graph.
saved_model_name = 'model'
!python keras_to_tensorflow/keras_to_tensorflow.py \
--input_model={saved_model_name} \
--output_model={saved_model_name}.pb

In [0]:
import keras
# Reset the Keras session, then reload the saved model; its input/output op
# names are read by the TFLite converter cell.
keras.backend.clear_session()
model = models.load_model(saved_model_name)

In [0]:
from tensorflow.contrib import lite
converter = lite.TFLiteConverter.from_frozen_graph(
    graph_def_file='{}.pb'.format(saved_model_name), 
    input_arrays=[model.input.op.name], 
    output_arrays=[model.output.op.name]
)
for use_quantize in [True, False]:
    print('Use quantize:', use_quantize)
    converter.post_training_quantize=use_quantize
    converter.allow_custom_ops=True
    tflite_model = converter.convert()
    open('{}.tflite'.format(saved_model_name), 'wb').write(tflite_model)
    !du -sh model.tflite