In [1]:
%load_ext autoreload
%autoreload 2

import logging
import os

import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from torch.nn import CrossEntropyLoss, MSELoss

from util_funcs import *
from data_processors import *
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from tensorboardX import SummaryWriter
from distortions import *

from tqdm import tqdm_notebook as tqdm
from tqdm import trange

from trainer import Trainer, DeepTwistTrainer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
# Experiment setup: task processor, file logging, runtime configuration,
# train/eval data loaders, compute device and loss function.

# Instantiate the entry registered under 'qqp' in the `processors` mapping.
processor = processors['qqp']()

# logging.basicConfig(filename=...) opens the log file right away and does not
# create missing parent directories, so make sure the folder exists first.
os.makedirs("log_dir", exist_ok=True)
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO,
                    filename=f"log_dir/{get_log_name()}.txt")
logger = logging.getLogger(__name__)

# Every tunable of the run lives in this one dict; it is forwarded as keyword
# arguments to the data helpers in this cell and to the optimizer/trainer
# constructors in later cells.
runtime_config = dict(data_dir = "glue_data/QQP",
                      bert_model = "bert-base-uncased",
                      output_mode = "classification",
                      max_seq_length = 64,
                      local_rank = -1,
                      batch_size = 32,
                      num_train_epochs = 32,
                      do_lower_case=True,
                      do_train=True,
                      train_batch_size=32,
                      gradient_accumulation_steps = 1,
                      n_gpu = 1,
                      learning_rate = 5e-5,
                      logger=logger,
                      warmup_proportion = 0.1)

# Publish each config entry as a notebook-level variable (data_dir, bert_model,
# num_train_epochs, ...). `globals()` states that intent directly; the previous
# `locals()` call only worked because this cell runs at module scope.
globals().update(runtime_config)

# Both keys are present in the config; they must not drift apart.
assert train_batch_size == batch_size

# get_data hands back the label set, tokenizer, training examples, the number
# of optimisation steps and the training DataLoader in a single call.
(label_list, num_labels, tokenizer, train_examples,
 num_train_optimization_steps, train_dataloader) = get_data(processor, runtime_config)

# Evaluation loader built from the dev split of the same task.
eval_examples = processor.get_dev_examples(data_dir)
eval_dataloader = get_dataloader(
    eval_examples, label_list,
    tokenizer, eval_data=True,
    **runtime_config)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
loss_fn = CrossEntropyLoss()

In [None]:
# Snapshot of the pretrained weights, moved to the CPU.
# Note: despite its name, `base_model` holds a state dict, not a module.
base_model = (
    BertForSequenceClassification
    .from_pretrained(bert_model, num_labels=num_labels)
    .cpu()
    .state_dict()
)

# `to_bert(base_model)` returns a decorator; applying it to `weight_prune`
# yields the distortion handed to the trainer in the training cell.
# NOTE(review): presumably the wrapped function prunes relative to the
# pretrained snapshot -- confirm against the definition of `to_bert`.
decorator = to_bert(base_model)
diff_prune = decorator(weight_prune)

In [None]:
# Run one DeepTwist training per value of `p`. Each run starts from a freshly
# loaded pretrained classifier and writes into its own output directory.
# NOTE(review): `p` is forwarded as twist_args={'p': p}; presumably the prune
# fraction used by the distortion -- confirm against `weight_prune`.
for p in [.99]:
    model = BertForSequenceClassification.from_pretrained(
        bert_model, num_labels=num_labels).to(device).train()
    optimizer = get_optimizer(
        model, num_train_optimization_steps=num_train_optimization_steps,
        **runtime_config)

    tensorboard_log_dir = "tensorboard_data/"
    output_dir = f"output/qqp/deeptwist/diff_prune{int(100*p)}_{get_log_name()}/"
    # The path is several levels deep, so create missing parents as well
    # (os.mkdir fails on a fresh checkout). `exist_ok` is deliberately left
    # off: an already existing run directory still raises instead of being
    # silently reused.
    os.makedirs(output_dir)

    tb_writer = SummaryWriter(log_dir=tensorboard_log_dir)
    try:
        trainer = DeepTwistTrainer(
            model=model, data=train_dataloader, val_data=eval_dataloader,
            num_labels=num_labels,
            output_dir=output_dir,
            twist_frequency=10,
            loss_fn=loss_fn,
            optimizer=optimizer,
            distort=diff_prune,
            twist_args={'p': p},
            writer=tb_writer,
            device=device,
            **runtime_config,
        )

        trainer.train(num_train_epochs=num_train_epochs,
                      report_frequency=12, patience=1, report_validation_frequency=100,
                      report_validation=True)
    finally:
        # Flush and release the event file even when training is interrupted
        # or raises, so already-logged scalars are not lost.
        tb_writer.close()