In [2]:
import torch

from util_funcs import *
from data_processors import *
import logging
from trainer import Trainer, DeepTwistTrainer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [74]:
import re
from glob import glob

# Checkpoints produced by the DeepTwist runs. glob() returns paths in an
# arbitrary, filesystem-dependent order; sort them so the row order of the
# results table (and its alignment with the evaluation lists) is reproducible.
f_names = sorted(glob('output/deeptwist/*/pytorch_model.bin'))

def parse_fname(f_name):
    """Split a DeepTwist checkpoint path into its metadata fields.

    Expected layout::

        output/deeptwist/<type><param>_<date>/pytorch_model.bin

    where ``<type>`` is one of ``prune``, ``svd`` or ``diff_prune``,
    ``<param>`` is a (possibly empty) run of digits, and ``<date>`` is made of
    digits, ``-`` and ``_`` (e.g. ``2019-04-29_03_23``).

    Args:
        f_name: checkpoint path using ``/`` separators.

    Returns:
        Tuple of strings ``(framework, type, param, date, f_name)``. The
        original path is appended so each row keeps a pointer to its file.

    Raises:
        ValueError: if ``f_name`` does not follow the expected layout.
    """
    # Raw string: '\-' inside a normal string literal is an invalid escape
    # sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    m = re.match(r'output/(deeptwist)/(prune|svd|diff_prune)([0-9]*)_([0-9\-_]*)/pytorch_model[.]bin',
                 f_name)
    if m is None:
        # Fail with a clear message instead of AttributeError on None.groups().
        raise ValueError(f"Unrecognised checkpoint path: {f_name!r}")
    return m.groups() + (f_name,)
    

# One row per checkpoint: the fields parsed from its path plus the path itself.
parsed_rows = [parse_fname(path) for path in f_names]
models_df = pd.DataFrame(
    parsed_rows,
    columns=['framework', 'type', 'param', 'date', 'f_name'],
)
# Run folders are stamped like 2019-04-29_03_23 (date, hour, minute).
models_df['date'] = pd.to_datetime(models_df['date'], format='%Y-%m-%d_%H_%M')

In [76]:
processor = processors['sst-2']()

# logging.basicConfig(filename=...) opens the log file immediately and raises
# FileNotFoundError when the parent directory does not exist, so create it.
# Note: basicConfig is a no-op if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel).
import os
LOG_DIR = "log_dir"
os.makedirs(LOG_DIR, exist_ok=True)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO,
                    filename=f"{LOG_DIR}/{get_log_name()}.txt")
logger = logging.getLogger(__name__)

# Run configuration. The whole dict is forwarded to get_data, get_dataloader,
# get_optimizer and DeepTwistTrainer below.
# NOTE(review): do_train, num_train_epochs, learning_rate and warmup_proportion
# look like training settings, although this notebook only evaluates saved
# checkpoints. They are kept unchanged because the helpers receive the dict.
runtime_config = dict(data_dir = "glue_data/SST-2",
                      bert_model = "bert-base-uncased",
                      output_mode = "classification",
                      max_seq_length = 64,
                      local_rank = -1,
                      batch_size = 32,
                      num_train_epochs = 32,
                      do_lower_case=True,
                      do_train=True,
                      train_batch_size=32,
                      gradient_accumulation_steps = 1,
                      n_gpu = 1,
                      learning_rate = 5e-5,
                      logger=logger,
                      warmup_proportion = 0.1)
# Expose each entry as a notebook-level variable (later cells use data_dir,
# bert_model, ... directly). At notebook top level locals() happens to be
# globals(), but the Python docs say the locals() mapping should not be
# modified, so update globals() explicitly.
globals().update(runtime_config)
assert train_batch_size == batch_size, (
    f"train_batch_size ({train_batch_size}) must equal batch_size ({batch_size})")

# Training-side objects from the project helper get_data (unpacked in the
# order it returns them).
(label_list, num_labels, tokenizer, train_examples,
 num_train_optimization_steps, train_dataloader) = get_data(processor, runtime_config)

# Dev split: every checkpoint below is scored on this loader.
eval_examples = processor.get_dev_examples(data_dir)
eval_dataloader = get_dataloader(eval_examples, label_list, tokenizer,
                                 eval_data=True, **runtime_config)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
loss_fn = CrossEntropyLoss()

# Reference weights for the diff_prune distortion: a separately loaded
# pretrained model, moved to the CPU and reduced to its state dict.
# Despite the name, `base_model` holds a state dict, not a module.
# NOTE(review): this load and the one for `model` below presumably each
# initialise the classification head randomly, so their order is kept as-is.
base_model = (
    BertForSequenceClassification
    .from_pretrained(bert_model, num_labels=num_labels)
    .cpu()
    .state_dict()
)

# to_bert(base_model) yields a decorator that is applied to weight_prune.
# The result is handed to the trainer as its `distort` step.
decorator = to_bert(base_model)
diff_prune = decorator(weight_prune)

# The module that saved checkpoints are loaded into for evaluation.
model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=num_labels)
model = model.to(device)

optimizer = get_optimizer(model,
                          num_train_optimization_steps=num_train_optimization_steps,
                          **runtime_config)

# TensorBoard event files go under this directory.
tensorboard_log_dir = "tensorboard_data/"
tb_writer = SummaryWriter(log_dir=tensorboard_log_dir)

# Trainer built with the same arguments as for training. In this notebook it
# is only used through trainer.validate().
# NOTE(review): `output_dir` is not assigned in any visible cell. Presumably
# it comes from one of the star imports; confirm this, otherwise the cell
# raises NameError on a fresh kernel.
trainer_kwargs = dict(
    model=model,
    data=train_dataloader,
    val_data=eval_dataloader,
    num_labels=num_labels,
    output_dir=output_dir,
    twist_frequency=10,
    loss_fn=loss_fn,
    optimizer=optimizer,
    distort=diff_prune,
    twist_args={'p': 0.1},
    writer=tb_writer,
    device=device,
)
trainer = DeepTwistTrainer(**trainer_kwargs, **runtime_config)

In [77]:
def eval_model(model):
    """Run the shared trainer's validation on ``model``.

    Side effect: rebinds the notebook-level ``trainer.model`` to ``model`` so
    that ``trainer.validate()`` evaluates it.

    Returns:
        Whatever ``trainer.validate()`` returns. Callers in this notebook
        unpack it as ``(loss, accuracy)``.
    """
    trainer.model = model
    validation_result = trainer.validate()
    return validation_result

def eval_bert_state_dict(f_name, model=model):
    """Load a saved state dict into ``model`` and evaluate it on the dev set.

    Args:
        f_name: path to a ``pytorch_model.bin`` state-dict file.
        model: module to load the weights into. Defaults to the
            notebook-level ``model`` as bound when this function was defined.

    Returns:
        ``eval_model(model)`` for the loaded weights, or ``(nan, nan)`` if
        loading or evaluation raised ``RuntimeError`` (for example a state
        dict whose keys or shapes do not match ``model``). A warning names
        the file and the error, so failures are not silently hidden as NaNs.
    """
    import warnings

    try:
        # Deserialize onto the CPU so checkpoints saved from a GPU run also
        # load on CPU-only machines. load_state_dict then copies the tensors
        # into model's parameters on whatever device they live on.
        # torch.load unpickles the file, so only use it on trusted checkpoints.
        state_dict = torch.load(f_name, map_location='cpu')
        model.load_state_dict(state_dict)
        return eval_model(model)
    except RuntimeError as err:
        warnings.warn(f"Could not evaluate {f_name}: {err}")
        # float('nan') instead of np.NaN: that alias was removed in NumPy 2.0.
        return (float('nan'), float('nan'))

In [78]:
# Score every checkpoint in f_names order (also models_df's row order) and
# collect its dev-set loss and accuracy in parallel lists.
val_loss, val_acc = [], []
for checkpoint_path in tqdm(f_names):
    ckpt_loss, ckpt_acc = eval_bert_state_dict(checkpoint_path)
    val_loss.append(ckpt_loss)
    val_acc.append(ckpt_acc)

HBox(children=(IntProgress(value=0, max=19), HTML(value='')))




In [79]:
# Attach the per-checkpoint dev metrics. They were computed in f_names order,
# which is also models_df's row order.
assert len(val_loss) == len(val_acc) == len(models_df), (
    f"{len(val_loss)} losses / {len(val_acc)} accuracies for {len(models_df)} checkpoints")
models_df['val_loss'] = val_loss
models_df['val_acc'] = val_acc
models_df = models_df[['framework', 'type', 'param', 'val_loss', 'val_acc', 'date', 'f_name']]

# Persist the leaderboard, best accuracy first. to_csv does not create
# missing directories, so make sure report/ exists.
import os
os.makedirs('report', exist_ok=True)
models_df.sort_values(by='val_acc', ascending=False).reset_index(drop=True).to_csv('report/result_table.csv')

In [83]:
# Leaderboard: checkpoints ranked by dev accuracy, drawn with an in-cell bar
# on a fixed 0..1 scale so accuracies are comparable across runs.
ranked_models = (
    models_df
    .sort_values(by='val_acc', ascending=False)
    .reset_index(drop=True)
)
ranked_models.style.bar(subset=['val_acc'], color='lightgreen', vmin=0, vmax=1)

Unnamed: 0,framework,type,param,val_loss,val_acc,date,f_name
0,deeptwist,diff_prune,60,0.20893,0.919725,2019-04-29 03:23:00,output/deeptwist/diff_prune60_2019-04-29_03_23/pytorch_model.bin
1,deeptwist,diff_prune,80,0.22175,0.916284,2019-04-29 07:46:00,output/deeptwist/diff_prune80_2019-04-29_07_46/pytorch_model.bin
2,deeptwist,diff_prune,90,0.216181,0.916284,2019-04-29 13:48:00,output/deeptwist/diff_prune90_2019-04-29_13_48/pytorch_model.bin
3,deeptwist,diff_prune,95,0.220871,0.916284,2019-04-29 21:58:00,output/deeptwist/diff_prune95_2019-04-29_21_58/pytorch_model.bin
4,deeptwist,diff_prune,99,0.221661,0.912844,2019-04-30 05:42:00,output/deeptwist/diff_prune99_2019-04-30_05_42/pytorch_model.bin
5,deeptwist,diff_prune,40,0.203041,0.909404,2019-04-28 23:29:00,output/deeptwist/diff_prune40_2019-04-28_23_29/pytorch_model.bin
6,deeptwist,prune,20,0.36991,0.889908,2019-04-26 20:44:00,output/deeptwist/prune20_2019-04-26_20_44/pytorch_model.bin
7,deeptwist,prune,20,0.291031,0.886468,2019-04-28 00:57:00,output/deeptwist/prune20_2019-04-28_00_57/pytorch_model.bin
8,deeptwist,svd,200,0.45869,0.845183,2019-04-27 17:31:00,output/deeptwist/svd200_2019-04-27_17_31/pytorch_model.bin
9,deeptwist,prune,40,0.424183,0.845183,2019-04-28 10:22:00,output/deeptwist/prune40_2019-04-28_10_22/pytorch_model.bin
