# **Pretrained BERT with adaptors (additional bottlenecks inside the Transformer Block ![alt text](https://drive.google.com/uc?id=1ZeVWel3DetaY3Y9YUc3BczS-Q4FP0Pfj))**
For more detail, please refer to the original paper @ "Parameter-Efficient Transfer Learning for BERT" Houlsby et al., 2019


In [9]:
!pip install pytorch-pretrained-bert pytorch-ignite ipdb

Collecting pytorch-pretrained-bert
[?25l  Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)
[K     |████████████████████████████████| 133kB 3.4MB/s 
[?25hCollecting pytorch-ignite
[?25l  Downloading https://files.pythonhosted.org/packages/35/55/41e8a995876fd2ade29bdba0c3efefa38e7d605cb353c70f3173c04928b5/pytorch_ignite-0.3.0-py2.py3-none-any.whl (103kB)
[K     |████████████████████████████████| 112kB 10.7MB/s 
[?25hCollecting ipdb
  Downloading https://files.pythonhosted.org/packages/2c/bb/a3e1a441719ebd75c6dac8170d3ddba884b7ee8a5c0f9aefa7297386627a/ipdb-0.13.2.tar.gz
Building wheels for collected packages: ipdb
  Building wheel for ipdb (setup.py) ... [?25l[?25hdone
  Created wheel for ipdb: filename=ipdb-0.13.2-cp36-none-any.whl size=10522 sha256=3cfc15906e2f7dddcebca50341e891ad10474fbc42c009b5fa6385a92058064a
  Stored in directory: /root/.cache/pip/whee

In [0]:
## Transformer Blocks


import torch
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout, causal):
        super().__init__()
        self.causal = causal
        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        """ x has shape [seq length, batch], padding_mask has shape [batch, seq length] """
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False, key_padding_mask=padding_mask)
            x = self.dropout(x)
            h = x + h

            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
            print(h)
        return h

In [0]:
##Transformer with Adapter

class TransformerWithAdapters(Transformer): # Inherit from the pretrained Transformer Block  to have all the modules
    def __init__(self, adapters_dim, embed_dim, hidden_dim, num_embeddings, num_max_positions,
                 num_heads, num_layers, dropout, causal):
        """ Transformer with adapters (small bottleneck layers) """
        super().__init__(embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers,
                         dropout, causal)
        self.adapters_1 = nn.ModuleList()
        self.adapters_2 = nn.ModuleList()
        for _ in range(num_layers):
          
            self.adapters_1.append(nn.Sequential(nn.Linear(embed_dim, adapters_dim),
                                                 nn.ReLU(),
                                                 nn.Linear(adapters_dim, embed_dim))) ## First adapter: bottleneck layers with 2 linear layers and a ReLU(dim is usually small
                                                                                        #32, 64, 128, 256
            
            self.adapters_2.append(nn.Sequential(nn.Linear(embed_dim, adapters_dim), ## Second adapter
                                                 nn.ReLU(),
                                                 nn.Linear(adapters_dim, embed_dim)))

    def forward(self, x, padding_mask=None):
        """ x has shape [seq length, batch], padding_mask has shape [batch, seq length] """
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)  
        h = self.dropout(h)

        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for (layer_norm_1, attention, adapter_1, layer_norm_2, feed_forward, adapter_2)\
                          in zip(self.layer_norms_1, self.attentions,    self.adapters_1,
                                 self.layer_norms_2, self.feed_forwards, self.adapters_2):
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False, key_padding_mask=padding_mask)
            x = self.dropout(x)
            
            x = adapter_1(x) + x  # Add an adapter with a skip-connection after attention module
            
            h = x + h

            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            
            x = adapter_2(x) + x  # Add an adapter with a skip-connection after feed-forward module
            
            h = x + h
        return h

In [0]:
## Adding Classification layer on top of the blocks



class TransformerWithClfHeadAndAdapters(nn.Module):
    def __init__(self, config, fine_tuning_config):
        """ Transformer with a classification head and adapters. """
        super().__init__()
        self.config = fine_tuning_config
        self.transformer = TransformerWithAdapters(fine_tuning_config.adapters_dim, config.embed_dim, config.hidden_dim,
                                                   config.num_embeddings, config.num_max_positions, config.num_heads,
                                                   config.num_layers, fine_tuning_config.dropout, causal=not config.mlm)

        self.classification_head = nn.Linear(config.embed_dim, fine_tuning_config.num_classes)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, clf_tokens_mask, lm_labels=None, clf_labels=None, padding_mask=None):
        hidden_states = self.transformer(x, padding_mask) #return hidden_states with dimensions of (embed_dim, seq_length, batch)

        clf_tokens_states = (hidden_states * clf_tokens_mask.unsqueeze(-1).float()).sum(dim=0)
        clf_logits = self.classification_head(clf_tokens_states)

        if clf_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(clf_logits.view(-1, clf_logits.size(-1)), clf_labels.view(-1))
            return clf_logits, loss

        return clf_logits

In [12]:
from pytorch_pretrained_bert import BertTokenizer, cached_path

tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

100%|██████████| 213450/213450 [00:00<00:00, 2355786.33B/s]


In [0]:
## Configuration
from collections import namedtuple

Config = namedtuple('Config',
  field_names="embed_dim, hidden_dim, num_max_positions, num_embeddings      , num_heads, num_layers," 
              "dropout, initializer_range, batch_size, lr, max_norm, n_epochs, n_warmup,"
              "mlm, gradient_accumulation_steps, device, log_dir, dataset_cache")
args = Config( 410      , 2100      , 256              , len(tokenizer.vocab), 10       , 16        ,
               0.1    , 0.02             , 16        , 2.5e-4, 1.0 , 50     , 1000    ,
               False, 4, "cuda" if torch.cuda.is_available() else "cpu", "./"   , "./dataset_cache.bin")

AdaptationConfig = namedtuple('AdaptationConfig',
  field_names="adapters_dim, num_classes, dropout, initializer_range, batch_size, lr, max_norm, n_epochs,"
              "n_warmup, valid_set_prop, gradient_accumulation_steps, device,"
              "log_dir, dataset_cache")
adapt_args = AdaptationConfig(
               32         , 6          , 0.1    , 0.02             , 16        , 6.5e-4, 1.0   , 3,
               10      , 0.1           , 1, "cuda" if torch.cuda.is_available() else "cpu",
               "./"   , "./dataset_cache.bin")

In [14]:
# If you have pretrained a model in the first section, you can use its weigths
# state_dict = model.state_dict()

# Otherwise, just load pretrained model weigths (and reload the training config as well)
state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                    "naacl-2019-tutorial/model_checkpoint.pth"), map_location='cpu')
args = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                    "naacl-2019-tutorial/model_training_args.bin"))

adaptation_model = TransformerWithClfHeadAndAdapters(config=args, fine_tuning_config=adapt_args)
adaptation_model.to(adapt_args.device)

incompatible_keys = adaptation_model.load_state_dict(state_dict, strict=False)
print(f"Parameters discarded from the pretrained model: {incompatible_keys.unexpected_keys}")
print(f"Parameters added in the adaptation model: {incompatible_keys.missing_keys}")

100%|██████████| 201626725/201626725 [00:04<00:00, 43081899.88B/s]
100%|██████████| 837/837 [00:00<00:00, 580941.99B/s]


Parameters discarded from the pretrained model: ['lm_head.weight']
Parameters added in the adaptation model: ['transformer.adapters_1.0.0.weight', 'transformer.adapters_1.0.0.bias', 'transformer.adapters_1.0.2.weight', 'transformer.adapters_1.0.2.bias', 'transformer.adapters_1.1.0.weight', 'transformer.adapters_1.1.0.bias', 'transformer.adapters_1.1.2.weight', 'transformer.adapters_1.1.2.bias', 'transformer.adapters_1.2.0.weight', 'transformer.adapters_1.2.0.bias', 'transformer.adapters_1.2.2.weight', 'transformer.adapters_1.2.2.bias', 'transformer.adapters_1.3.0.weight', 'transformer.adapters_1.3.0.bias', 'transformer.adapters_1.3.2.weight', 'transformer.adapters_1.3.2.bias', 'transformer.adapters_1.4.0.weight', 'transformer.adapters_1.4.0.bias', 'transformer.adapters_1.4.2.weight', 'transformer.adapters_1.4.2.bias', 'transformer.adapters_1.5.0.weight', 'transformer.adapters_1.5.0.bias', 'transformer.adapters_1.5.2.weight', 'transformer.adapters_1.5.2.bias', 'transformer.adapters_1.6.

In [21]:
## Only train embeddings, classification head, adapter 1, and adapter 2, other weights are not being trained
for name, param in adaptation_model.named_parameters():
    if 'embedding' not in name and 'classification' not in name and 'adapter_1' not in name and 'adapter_2' not in name:
        param.detach_()
        param.requires_grad = False
    
    else:
        param.requires_grad = True

full_parameters = sum(p.numel() for p in adaptation_model.parameters())
trained_parameters = sum(p.numel() for p in adaptation_model.parameters() if p.requires_grad )

print(f"We will train {trained_parameters:3e} parameters out of {full_parameters:3e},"
      f" i.e. {100 * trained_parameters/full_parameters:.2f}%")

We will train 1.199579e+07 parameters out of 5.125265e+07, i.e. 23.41%


In [0]:
import os
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage, Accuracy
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import CosineAnnealingScheduler, PiecewiseLinear, create_lr_scheduler_with_warmup, ProgressBar

In [31]:
import random
from torch.utils.data import TensorDataset, random_split

dataset_file = cached_path("https://s3.amazonaws.com/datasets.huggingface.co/trec/"
                           "trec-tokenized-bert.bin")
datasets = torch.load(dataset_file)

for split_name in ['train', 'test']:

    # Trim the samples to the transformer's input length minus 1 & add a classification token
    datasets[split_name] = [x[:args.num_max_positions-1] + [tokenizer.vocab['[CLS]']]
                            for x in datasets[split_name]]

    # Pad the dataset to max length
    padding_length = max(len(x) for x in datasets[split_name])
    datasets[split_name] = [x + [tokenizer.vocab['[PAD]']] * (padding_length - len(x))
                            for x in datasets[split_name]]

    # Convert to torch.Tensor and gather inputs and labels
    tensor = torch.tensor(datasets[split_name], dtype=torch.long)
    labels = torch.tensor(datasets[split_name + '_labels'], dtype=torch.long)
    datasets[split_name] = TensorDataset(tensor, labels)

# Create a validation dataset from a fraction of the training dataset
valid_size = int(adapt_args.valid_set_prop * len(datasets['train']))
train_size = len(datasets['train']) - valid_size
valid_dataset, train_dataset = random_split(datasets['train'], [valid_size, train_size])

train_loader = DataLoader(train_dataset, batch_size=adapt_args.batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=adapt_args.batch_size, shuffle=False)
test_loader = DataLoader(datasets['test'], batch_size=adapt_args.batch_size, shuffle=False)

100%|██████████| 250835/250835 [00:00<00:00, 2186940.95B/s]


In [34]:
## training loop
optimizer = torch.optim.Adam(adaptation_model.parameters(), lr=adapt_args.lr)

##training function and trainer
def update(engine, batch):
    adaptation_model.train()
    batch, labels = (t.to(adapt_args.device) for t in batch)
    inputs = batch.transpose(0,1).contiguous() #to shape [seq_length, batch]

    _, loss = adaptation_model(inputs, clf_tokens_mask=(inputs == tokenizer.vocab['[CLS]']), clf_labels=labels,
                                                        padding_mask = (batch ==tokenizer.vocab['[PAD]']))
    
    loss = loss/adapt_args.gradient_accumulation_steps
    loss.backward()
    if engine.state.iteration % adapt_args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
trainer = Engine(update)


def inference(engine, batch):
    adaptation_model.eval()
    with torch.no_grad():
        batch, labels = (t.to(adapt_args.device) for t in batch)
        inputs = batch.transpose(0, 1).contiguous()  # to shape [seq length, batch]
        clf_logits = adaptation_model(inputs, clf_tokens_mask=(inputs == tokenizer.vocab['[CLS]']),
                                      padding_mask=(batch == tokenizer.vocab['[PAD]']))
    return clf_logits, labels
evaluator = Engine(inference)

# Attache metric to evaluator & evaluation to trainer: evaluate on valid set after each epoch
Accuracy().attach(evaluator, "accuracy")
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(valid_loader)
    print(f"Validation Epoch: {engine.state.epoch} Error rate: {100*(1 - evaluator.state.metrics['accuracy'])}")

# Learning rate schedule: linearly warm-up to lr and then to zero
scheduler = PiecewiseLinear(optimizer, 'lr', [(0, 0.0), (adapt_args.n_warmup, adapt_args.lr),
                                              (len(train_loader)*adapt_args.n_epochs, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# Add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

# Save checkpoints and finetuning config
checkpoint_handler = ModelCheckpoint(adapt_args.log_dir, 'finetuning_checkpoint', save_interval=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': adaptation_model})
torch.save(args, os.path.join(adapt_args.log_dir, 'fine_tuning_args.bin'))







In [35]:
trainer.run(train_loader, max_epochs=adapt_args.n_epochs)

HBox(children=(IntProgress(value=0, max=307), HTML(value='')))

Validation Epoch: 1 Error rate: 29.357798165137616



HBox(children=(IntProgress(value=0, max=307), HTML(value='')))

Validation Epoch: 2 Error rate: 21.65137614678899



HBox(children=(IntProgress(value=0, max=307), HTML(value='')))

Validation Epoch: 3 Error rate: 18.899082568807334



State:
	iteration: 921
	epoch: 3
	epoch_length: 307
	max_epochs: 3
	output: 0.2108326107263565
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: 12

In [36]:
evaluator.run(test_loader)
print(f"Test Results - Error rate: {100*(1.00 - evaluator.state.metrics['accuracy']):.3f}")

Test Results - Error rate: 15.200
