In [None]:
# Reload edited library modules (e.g. nn_pruning) automatically, so source changes are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
# All imports in one place: third-party libraries first, then the nn_pruning project modules.
import numpy as np
import torch
import datasets
import transformers
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

from nn_pruning.sparse_trainer import SparseTrainer
from nn_pruning.patch_coordinator import SparseTrainingArguments, ModelPatchingCoordinator
from nn_pruning.inference_model_patcher import optimize_model

# Only show errors from the HF libraries so the notebook output stays readable.
datasets.logging.set_verbosity_error()
transformers.logging.set_verbosity_error()

# Use the GPU when one is available; the model is moved to this device later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using transformers v{transformers.__version__} and datasets v{datasets.__version__} and torch v{torch.__version__}")
print(f"Running on device: {device}")

Using transformers v4.3.3 and datasets v1.4.1 and torch v1.8.0
Running on device: cuda


In [None]:
# sms_spam ships a single "train" split, so carve out a held-out test split ourselves.
# A fixed seed keeps the split (and everything trained and evaluated on it) identical
# across kernel restarts; without it the rows are shuffled differently on every run.
ds = load_dataset("sms_spam")["train"].train_test_split(seed=42)
ds

DatasetDict({
    train: Dataset({
        features: ['sms', 'label'],
        num_rows: 4180
    })
    test: Dataset({
        features: ['sms', 'label'],
        num_rows: 1394
    })
})

In [None]:
# Inspect one raw training example: an `sms` text string plus an integer `label`.
ds["train"][0]

{'label': 0,
 'sms': "Your board is working fine. The issue of overheating is also reslove. But still software inst is pending. I will come around 8'o clock.\n"}

In [None]:
# Copy the integer `label` column into a `labels` column, the argument name Hugging Face
# models use for the training targets. The dataset has no "labels" field, and each
# `label` is a plain int (not a sequence), so it is copied as-is rather than indexed.
ds = ds.map(lambda x: {"labels": x["label"]})

## Tokenize and encode

In [None]:
# Pretrained checkpoint shared by this tokenizer and the classification model loaded later.
bert_ckpt = "bert-base-uncased"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_ckpt)

In [None]:
def tokenize_sms(batch):
    """Tokenize a batch of SMS texts with the BERT tokenizer, truncating over-long messages."""
    return bert_tokenizer(batch["sms"], truncation=True)


# Encode every split; batched=True hands the function many rows at once for speed.
ds_enc = ds.map(tokenize_sms, batched=True)

In [None]:
accuracy_score = load_metric('accuracy')


def compute_metrics(pred):
    """Turn a Trainer evaluation prediction into an accuracy metric dict.

    `pred` unpacks into (logits, label_ids); the predicted class is the arg-max over
    the last axis of the 2-D logits array. Returns the dict produced by the `accuracy`
    metric (reported by the Trainer as `eval_accuracy`).
    """
    logits, label_ids = pred
    predicted_class = np.argmax(logits, axis=1)
    return accuracy_score.compute(predictions=predicted_class, references=label_ids)

In [None]:
class PruningTrainer(SparseTrainer, Trainer):
    """HF Trainer combined with nn_pruning's SparseTrainer.

    Both base initializers are called explicitly. Trainer is initialized first, with all
    the regular Trainer arguments. SparseTrainer is initialized next, with the
    sparse-training configuration.
    """

    def __init__(self, sparse_args, *args, **kwargs):
        Trainer.__init__(self, *args, **kwargs)
        SparseTrainer.__init__(self, sparse_args)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Plain model loss, also accumulated into the sparse-training metrics.

        This replaces SparseTrainer's loss, which throws an error when run without
        distillation. Each call adds the loss value to ``self.metrics["ce_loss"]``
        and increments ``self.loss_counter``.

        Returns the loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        outputs = model(**inputs)

        # Keep the model's "past" output when the training args request it (index >= 0).
        past_index = self.args.past_index
        if past_index >= 0:
            self._past = outputs[past_index]

        # The model may return a dict-like ModelOutput or a plain tuple with the loss first.
        if isinstance(outputs, dict):
            loss = outputs["loss"]
        else:
            loss = outputs[0]

        self.metrics["ce_loss"] += float(loss)
        self.loss_counter += 1

        if return_outputs:
            return loss, outputs
        return loss

In [None]:
sparse_args = SparseTrainingArguments()

# Pruning hyper-parameters applied on top of the SparseTrainingArguments defaults.
d = dict(
    initial_warmup=1,
    final_warmup=3,
    # topK: the initial density, so 1.0. sigmoied_threshold: use 0.0 (the cutoff).
    initial_threshold=1.0,
    # topK: the final density. sigmoied_threshold: use 0.1, the final cutoff. That value is
    # somewhat arbitrary; tune regularization_final_lambda to adjust the final sparsity.
    final_threshold=0.5,
    dense_pruning_method="topK:1d_alt",  # alternative: "sigmoied_threshold:1d_alt"
    dense_block_rows=1,
    dense_block_cols=1,
    dense_lambda=0.25,
    attention_pruning_method="topK",  # alternative: "sigmoied_threshold"
    attention_block_rows=32,
    attention_block_cols=32,
    attention_lambda=1.0,
    ampere_pruning_method="disabled",
    mask_init="constant",
    mask_scale=0.0,
    regularization=None,  # use "l1" when the pruning methods are sigmoied_threshold
    # Higher values give a sparser model; try several values, doubling each time.
    regularization_final_lambda=20,
    distil_teacher_name_or_path=None,
    distil_alpha_ce=0.1,
    distil_alpha_teacher=0.9,
    attention_output_with_dense=0,
    layer_norm_patch=0,
    gelu_patch=0,
)

# Copy each known setting onto sparse_args; report unknown keys instead of adding them.
for arg_name, arg_value in d.items():
    if not hasattr(sparse_args, arg_name):
        print(f"sparse_args does not have an argument {arg_name}")
        continue
    setattr(sparse_args, arg_name, arg_value)



In [None]:
batch_size = 32
learning_rate = 2e-5
num_train_epochs = 3
# Log roughly once per epoch. max(1, ...) guards against a train split smaller than one batch,
# which would otherwise give logging_steps == 0.
logging_steps = max(1, len(ds_enc["train"]) // batch_size)
# Warm up the learning rate over ~10% of the training steps. TrainingArguments.warmup_steps
# is an integer step count, and it must be passed below for the warmup to take effect.
warmup_steps = int(logging_steps * num_train_epochs * 0.1)

args = TrainingArguments(
    output_dir='checkpoints',
    evaluation_strategy='epoch',
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    learning_rate=learning_rate,
    weight_decay=0.01,
    warmup_steps=warmup_steps,
    logging_steps=logging_steps,
    disable_tqdm=False,
    # An empty list disables all logging integrations. `report_to=None` means
    # "every installed integration" (wandb, tensorboard, ...).
    report_to=[],
)

# Seed before loading the model so the freshly initialized classification head is reproducible.
transformers.set_seed(args.seed)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_ckpt, num_labels=2).to(device)

In [None]:
# Coordinator built from the sparse_args configuration above. It patches the model in the
# next cell and is later attached to the trainer via set_patch_coordinator.
# NOTE(review): teacher_constructor is presumably only used to build a distillation teacher
# when distil_teacher_name_or_path is set (it is None here) — confirm in nn_pruning.
mpc = ModelPatchingCoordinator(
    sparse_args=sparse_args, 
    device=device, 
    cache_dir="checkpoints", 
    logit_names="logits", 
    teacher_constructor=AutoModelForSequenceClassification)


In [None]:
# Patch bert_model in place with the pruning modules configured by sparse_args. The result
# is not reassigned; the same bert_model object is handed to the trainer below.
mpc.patch_model(bert_model)

# Optional: save the patched (not yet trained) model.
# bert_model.save_pretrained("patched")

LAYER NORM PATCH {'patched': 72}


<nn_pruning.training_patcher.BertLinearModelPatcher at 0x7fdedeafb460>

In [None]:
# Trainer over the tokenized train/test splits. compute_metrics adds accuracy to each evaluation.
trainer = PruningTrainer(
    sparse_args=sparse_args,
    args=args,
    model=bert_model,
    train_dataset=ds_enc["train"],
    eval_dataset=ds_enc["test"],
    tokenizer=bert_tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
# Give the trainer access to the patch coordinator that patched the model.
trainer.set_patch_coordinator(mpc)

In [None]:
# Baseline evaluation of the patched model before any training.
trainer.evaluate()

{'eval_loss': 3.1822924613952637,
 'eval_accuracy': 0.0,
 'eval_runtime': 14.01,
 'eval_samples_per_second': 99.5,
 'eval_threshold': 0.5,
 'eval_ampere_temperature': 20.0,
 'eval_regu_lambda': 20.0,
 'ce_loss': 3.182358671318401}

In [None]:
# Fine-tune the patched model; evaluation runs after every epoch (evaluation_strategy='epoch').
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second,Threshold,Ampere Temperature,Regu Lambda,Loss
1,0.094,0.193136,0.988522,16.0435,86.889,0.5,20.0,20.0,0.188796
2,0.0326,0.034705,0.991392,16.1713,86.202,0.5,20.0,20.0,0.033099
3,0.0153,0.040123,0.991392,16.4002,84.999,0.5,20.0,20.0,0.037339


TrainOutput(global_step=393, training_loss=0.046956961923551394, metrics={'train_runtime': 297.5098, 'train_samples_per_second': 1.321, 'total_flos': 625040476196640, 'eval_threshold': 0.5, 'eval_ampere_temperature': 20.0, 'eval_regu_lambda': 20.0, 'epoch': 3.0})