# Fine-tune a Transformer-based architecture on the IMDB Movie Reviews dataset for Sentiment Analysis

## Install dependencies, save requirements.txt

In [1]:
!pip install -q pandas tqdm
!pip install -U torch 
!pip install -q pytorch_transformers pytorch-ignite

Requirement already up-to-date: torch in /opt/conda/lib/python3.6/site-packages (1.1.0)


In [2]:
!rm requirements.txt

In [3]:
%%writefile requirements.txt
pandas
tqdm
torch==1.1.0
pytorch_transformers
pytorch-ignite

Writing requirements.txt


In [1]:
import sys
import os
import logging
from tqdm import tqdm_notebook as tqdm

logger = logging.getLogger()

# text and label column names
TEXT_COL = "text"
LABEL_COL = "label"

# path to data 
DATA_DIR = os.path.abspath('./data')

# path to IMDB
IMDB = os.path.join(DATA_DIR, "aclImdb")

# url to dataset
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

## Download imdb data

In [2]:
import requests
import tarfile

def download_url(url:str, dest:str, overwrite:bool=True, show_progress=True, 
                 chunk_size=1024*1024, timeout=4, retries=5)->None:
    "Download `url` to `dest` unless it exists and not `overwrite`."
    
    dest = os.path.join(dest, os.path.basename(url))
    if os.path.exists(dest) and not overwrite: 
        print(f"File {dest} already exists!")
        return dest

    s = requests.Session()
    s.mount('http://',requests.adapters.HTTPAdapter(max_retries=retries))
    u = s.get(url, stream=True, timeout=timeout)
    try: file_size = int(u.headers["Content-Length"])
    except: show_progress = False
    print(f"Downloading {url}")
    with open(dest, 'wb') as f:
        nbytes = 0
        if show_progress: 
            pbar = tqdm(range(file_size), leave=False)
        try:
            for chunk in u.iter_content(chunk_size=chunk_size):
                nbytes += len(chunk)
                if show_progress: pbar.update(nbytes)
                f.write(chunk)
        except requests.exceptions.ConnectionError as e:
            print(f"Download failed after {retries} retries.")
            import sys;sys.exit(1)
        finally:
            return dest
        
def untar(file_path, dest:str):
    "Untar `file_path` to `dest`"
    print(f"Untar {os.path.basename(file_path)} to {dest}")
    with tarfile.open(file_path) as tf:
        tf.extractall(path=str(dest))


In [32]:
# download imdb dataset
file_path = download_url(url, '/tmp', overwrite=True)

# untar imdb dataset to DATA_DIR
untar(file_path, DATA_DIR)

Downloading https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz


HBox(children=(IntProgress(value=0, max=84125825), HTML(value='')))

Untar aclImdb_v1.tar.gz to /Users/d069049/Develop/transformer-finetuning/data


In [3]:
!ls -lh $IMDB

total 1.7M
-rw-r--r--. 1 7297 1000 882K Jun 11  2011 imdbEr.txt
-rw-r--r--. 1 7297 1000 827K Apr 12  2011 imdb.vocab
-rw-r--r--. 1 7297 1000 4.0K Jun 26  2011 README
drwxr-xr-x. 4 7297 1000  115 Apr 12  2011 test
drwxr-xr-x. 5 7297 1000  183 Jun 26  2011 train


## Read imdb data

In [3]:
import pandas as pd
import re
from pathlib import Path

def clean_html(raw: str):
    "remove html tags and whitespaces"
    cleanr = re.compile('<.*?>')
    clean = re.sub(cleanr, '  ', raw)
    return re.sub(' +', ' ', clean)


def read_imdb(imdb_dir: str, max_lengths={"train": None, "test": None}):
    "Read imdb data to a DataFrame in {'label', 'text'} format."
    imdb_dir = Path(imdb_dir)
    datasets = {}
    
    for t in ["train", "test"]:
        texts, labels = [], []
        for p in ["pos", "neg"]:
            for file in tqdm((imdb_dir/"train"/p).glob("*.txt"), desc=f"reading {t}/{p}"):
                with open(file, 'r') as fin:
                    text = fin.readlines()[0].replace(r'\n', ' ')
                    text = clean_html(text).strip()
                    texts +=  [text]
                    labels += [p]
        df = pd.DataFrame(
            {LABEL_COL: labels, TEXT_COL: texts})
        
        max_length = max_lengths.get(t)
        if max_length is not None and max_length <= len(df):
            # pick max_length
            datasets[t] = df.sample(n=max_length)
        else:
            # just shuffle
            datasets[t] = df.sample(frac=1)
            
    return datasets

In [4]:
MAX_TRAIN = 5000
MAX_TEST = 5000

# read data, 5000-5000 each
datasets = read_imdb(IMDB, max_lengths={"train": MAX_TRAIN, "test": MAX_TEST})

# list of labels
labels = list(set(datasets["train"][LABEL_COL].tolist()))

# labels to integers mapping
label2int = {label: i for i, label in enumerate(labels)}

HBox(children=(IntProgress(value=1, bar_style='info', description='reading train/pos', max=1, style=ProgressSt…




HBox(children=(IntProgress(value=1, bar_style='info', description='reading train/neg', max=1, style=ProgressSt…




HBox(children=(IntProgress(value=1, bar_style='info', description='reading test/pos', max=1, style=ProgressSty…




HBox(children=(IntProgress(value=1, bar_style='info', description='reading test/neg', max=1, style=ProgressSty…




In [5]:
df_train = datasets["train"]
df_test = datasets["test"]

In [6]:
df_train.to_csv("imdb5k_train.csv", index=False)
df_test.to_csv("imdb5k_test.csv", index=False)

In [7]:
dft = pd.read_csv("imdb5k_train.csv")

In [8]:
datasets["train"].head()

Unnamed: 0,label,text
12468,pos,There's a great deal of material from the Mode...
11118,pos,Cliffhanger is a decent action crime adventure...
3410,pos,"In terms of the arts, the 1970s were a very tu..."
5786,pos,"Overall, I agree wholly with Ebert's review. I..."
19117,neg,I watched this movie after having so much of t...


## DataProcessor

In [9]:
import torch
from torch.utils.data import TensorDataset, random_split, DataLoader
import numpy as np
import warnings
from tqdm import tqdm_notebook as tqdm
from typing import List, Tuple

NUM_MAX_POSITIONS = 256
BATCH_SIZE = 32

class TextProcessor:
    
    # special tokens for classification and padding
    CLS = '[CLS]'
    PAD = '[PAD]'
    
    def __init__(self, tokenizer, label2id: dict, num_max_positions:int=512):
        self.tokenizer=tokenizer
        self.label2id = label2id
        self.num_labels = len(label2id)
        self.num_max_positions = num_max_positions
        
    
    def process_example(self, example: Tuple[str, str]):
        "Convert text (example[0]) to sequence of IDs and label (example[1] to integer"
        assert len(example) == 2
        label, text = example[0], example[1]
        assert isinstance(text, str)
        tokens = self.tokenizer.tokenize(text)

        # truncate if too long
        if len(tokens) >= self.num_max_positions:
            tokens = tokens[:self.num_max_positions-1] 
            ids =  self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]]
        # pad if too short
        else:
            pad = [self.tokenizer.vocab[self.PAD]] * (self.num_max_positions-len(tokens)-1)
            ids =  self.tokenizer.convert_tokens_to_ids(tokens) + [self.tokenizer.vocab[self.CLS]] + pad
        
        return ids, self.label2id[label]
    
# download the 'bert-base-cased' tokenizer
from pytorch_transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)

# initialize a TextProcessor
processor = TextProcessor(tokenizer, label2int, num_max_positions=NUM_MAX_POSITIONS)

## Config

In [10]:
from collections import namedtuple
import torch

LOG_DIR = "./logs/"
CACHE_DIR = "./cache/"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

FineTuningConfig = namedtuple('FineTuningConfig',
      field_names="num_classes, dropout, init_range, batch_size, lr, max_norm, n_epochs,"
                  "n_warmup, valid_pct, gradient_acc_steps, device, log_dir, dataset_cache")

finetuning_config = FineTuningConfig(
                2, 0.1, 0.02, BATCH_SIZE, 6.5e-5, 1.0, 2,
                10, 0.1, 1, device, LOG_DIR, 
                CACHE_DIR+'dataset_cache.bin')

finetuning_config

FineTuningConfig(num_classes=2, dropout=0.1, init_range=0.02, batch_size=32, lr=6.5e-05, max_norm=1.0, n_epochs=2, n_warmup=10, valid_pct=0.1, gradient_acc_steps=1, device=device(type='cuda', index=0), log_dir='./logs/', dataset_cache='./cache/dataset_cache.bin')

## Create datasets

In [11]:
def create_dataloaders(df: pd.DataFrame, processor: TextProcessor, batch_size:int=32, shuffle:bool=False, valid_pct:float=None, 
                   text_col:str="text", label_col:str="label"):
    "Process rows in `df` with `processor` and return a  DataLoader"
    
    features, labels = [], [] 
    for i, row in tqdm(df.iterrows(), total=len(df)):
        ids, lbl = processor.process_example((row[LABEL_COL], row[TEXT_COL]))
        features += [ids]
        labels += [lbl]
    
    dataset = TensorDataset(
                    torch.tensor(features, dtype=torch.long), 
                    torch.tensor(labels, dtype=torch.long))
    
    if valid_pct is not None:
        valid_size = int(valid_pct * len(df))
        train_size = len(df) - valid_size
        valid_dataset, train_dataset = random_split(dataset, [valid_size, train_size])
        valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)    
        return train_loader, valid_loader

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return data_loader


In [12]:
# create train and valid sets by splitting
train_dl, valid_dl = create_dataloaders(datasets["train"], processor, 
                                    batch_size=finetuning_config.batch_size, 
                                    valid_pct=finetuning_config.valid_pct)

test_dl = create_dataloaders(datasets["test"], processor, 
                             batch_size=finetuning_config.batch_size, 
                             valid_pct=None)

HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




HBox(children=(IntProgress(value=0, max=5000), HTML(value='')))




## TransformerWithClfHead

In [13]:
import torch.nn as nn

def get_num_params(model):
    mp = filter(lambda p: p.requires_grad, model.parameters())
    return sum(np.prod(p.size()) for p in mp)

class Transformer(nn.Module):
    "Adopted from https://github.com/huggingface/naacl_transfer_learning_tutorial"

    def __init__(self, embed_dim, hidden_dim, num_embeddings, num_max_positions, num_heads, num_layers, dropout, causal):
        super().__init__()
        self.causal = causal
        self.tokens_embeddings = nn.Embedding(num_embeddings, embed_dim)
        self.position_embeddings = nn.Embedding(num_max_positions, embed_dim)
        self.dropout = nn.Dropout(dropout)

        self.attentions, self.feed_forwards = nn.ModuleList(), nn.ModuleList()
        self.layer_norms_1, self.layer_norms_2 = nn.ModuleList(), nn.ModuleList()
        for _ in range(num_layers):
            self.attentions.append(nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout))
            self.feed_forwards.append(nn.Sequential(nn.Linear(embed_dim, hidden_dim),
                                                    nn.ReLU(),
                                                    nn.Linear(hidden_dim, embed_dim)))
            self.layer_norms_1.append(nn.LayerNorm(embed_dim, eps=1e-12))
            self.layer_norms_2.append(nn.LayerNorm(embed_dim, eps=1e-12))

    def forward(self, x, padding_mask=None):
        """ x has shape [seq length, batch], padding_mask has shape [batch, seq length] """
        positions = torch.arange(len(x), device=x.device).unsqueeze(-1)
        h = self.tokens_embeddings(x)
        h = h + self.position_embeddings(positions).expand_as(h)
        h = self.dropout(h)

        attn_mask = None
        if self.causal:
            attn_mask = torch.full((len(x), len(x)), -float('Inf'), device=h.device, dtype=h.dtype)
            attn_mask = torch.triu(attn_mask, diagonal=1)

        for layer_norm_1, attention, layer_norm_2, feed_forward in zip(self.layer_norms_1, self.attentions,
                                                                       self.layer_norms_2, self.feed_forwards):
            h = layer_norm_1(h)
            x, _ = attention(h, h, h, attn_mask=attn_mask, need_weights=False, key_padding_mask=padding_mask)
            x = self.dropout(x)
            h = x + h

            h = layer_norm_2(h)
            x = feed_forward(h)
            x = self.dropout(x)
            h = x + h
        return h


class TransformerWithClfHead(nn.Module):
    "Adopted from https://github.com/huggingface/naacl_transfer_learning_tutorial"
    def __init__(self, config, fine_tuning_config):
        super().__init__()
        self.config = fine_tuning_config
        self.transformer = Transformer(config.embed_dim, config.hidden_dim, config.num_embeddings,
                                       config.num_max_positions, config.num_heads, config.num_layers,
                                       fine_tuning_config.dropout, causal=not config.mlm)
        
        self.classification_head = nn.Linear(config.embed_dim, fine_tuning_config.num_classes)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.LayerNorm)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_range)
        if isinstance(module, (nn.Linear, nn.LayerNorm)) and module.bias is not None:
            module.bias.data.zero_()

    def forward(self, x, clf_tokens_mask, clf_labels=None, padding_mask=None):
        hidden_states = self.transformer(x, padding_mask)

        clf_tokens_states = (hidden_states * clf_tokens_mask.unsqueeze(-1).float()).sum(dim=0)
        clf_logits = self.classification_head(clf_tokens_states)

        if clf_labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(clf_logits.view(-1, clf_logits.size(-1)), clf_labels.view(-1))
            return clf_logits, loss
        return clf_logits

In [14]:
from pytorch_transformers import cached_path

# download pre-trained model and config
state_dict = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                    "naacl-2019-tutorial/model_checkpoint.pth"), map_location='cpu')

config = torch.load(cached_path("https://s3.amazonaws.com/models.huggingface.co/"
                                        "naacl-2019-tutorial/model_training_args.bin"))

# init model: Transformer base + classifier head
model = TransformerWithClfHead(config=config, fine_tuning_config=finetuning_config).to(finetuning_config.device)

incompatible_keys = model.load_state_dict(state_dict, strict=False)
print(f"Parameters discarded from the pretrained model: {incompatible_keys.unexpected_keys}")
print(f"Parameters added in the model: {incompatible_keys.missing_keys}")

Parameters discarded from the pretrained model: ['lm_head.weight']
Parameters added in the model: ['classification_head.weight', 'classification_head.bias']


In [15]:
get_num_params(model)

50397182

## Prepare fine-tuning loop

In [16]:
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage, Accuracy 
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import CosineAnnealingScheduler, PiecewiseLinear, create_lr_scheduler_with_warmup, ProgressBar
import torch.nn.functional as F
from pytorch_transformers.optimization import AdamW

# Bert optimizer
optimizer = AdamW(model.parameters(), lr=finetuning_config.lr, correct_bias=False) 

def update(engine, batch):
    "update function for training"
    model.train()
    inputs, labels = (t.to(finetuning_config.device) for t in batch)
    inputs = inputs.transpose(0, 1).contiguous() # [S, B]
    _, loss = model(inputs, 
                    clf_tokens_mask = (inputs == tokenizer.vocab[processor.CLS]), 
                    clf_labels=labels)
    loss = loss / finetuning_config.gradient_acc_steps
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(model.parameters(), finetuning_config.max_norm)
    if engine.state.iteration % finetuning_config.gradient_acc_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()

def inference(engine, batch):
    "update function for evaluation"
    model.eval()
    with torch.no_grad():
        batch, labels = (t.to(finetuning_config.device) for t in batch)
        inputs = batch.transpose(0, 1).contiguous()
        logits = model(inputs,
                       clf_tokens_mask = (inputs == tokenizer.vocab[processor.CLS]),
                       padding_mask = (batch == tokenizer.vocab[processor.PAD]))
    return logits, labels

def predict(model, tokenizer, int2label, input="test"):
    "predict `input` with `model`"
    tok = tokenizer.tokenize(input)
    ids = tokenizer.convert_tokens_to_ids(tok) + [tokenizer.vocab['[CLS]']]
    tensor = torch.tensor(ids, dtype=torch.long)
    tensor = tensor.to(device)
    tensor = tensor.reshape(1, -1)
    tensor_in = tensor.transpose(0, 1).contiguous() # [S, 1]
    logits = model(tensor_in,
                   clf_tokens_mask = (tensor_in == tokenizer.vocab['[CLS]']),
                   padding_mask = (tensor == tokenizer.vocab['[PAD]']))
    val, _ = torch.max(logits, 0)
    val = F.softmax(val, dim=0).detach().cpu().numpy()    
    return {int2label[val.argmax()]: val.max(),
            int2label[val.argmin()]: val.min()}
trainer = Engine(update)
evaluator = Engine(inference)

# add metric to evaluator 
Accuracy().attach(evaluator, "accuracy")

# add evaluator to trainer: eval on valid set after each epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(valid_dl)
    print(f"validation epoch: {engine.state.epoch} acc: {100*evaluator.state.metrics['accuracy']}")
          
# lr schedule: linearly warm-up to lr and then to zero
scheduler = PiecewiseLinear(optimizer, 'lr', [(0, 0.0), (finetuning_config.n_warmup, finetuning_config.lr),
                                              (len(train_dl)*finetuning_config.n_epochs, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)


# add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

# save checkpoints and finetuning config
checkpoint_handler = ModelCheckpoint(finetuning_config.log_dir, 'finetuning_checkpoint', 
                                     save_interval=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'imdb_model': model})

# save config to logdir
torch.save(finetuning_config, os.path.join(finetuning_config.log_dir, 'fine_tuning_args.bin'))          



## Lets fine-tune on imdb!

In [17]:
# fit the model on train_dl
trainer.run(train_dl, max_epochs=finetuning_config.n_epochs)

# evaluate the model on test_dl
evaluator.run(test_dl)
print(f"test results - acc: {100*evaluator.state.metrics['accuracy']:.3f}")

HBox(children=(IntProgress(value=0, max=141), HTML(value='')))

validation epoch: 1 acc: 84.2



HBox(children=(IntProgress(value=0, max=141), HTML(value='')))

validation epoch: 2 acc: 86.4

test results - acc: 89.800


In [18]:
!ls -l $finetuning_config.log_dir

total 196912
-rw-r--r--. 1 root root       318 Jul 17 09:56 fine_tuning_args.bin
-rw-------. 1 root root 201630224 Jul 17 09:59 finetuning_checkpoint_imdb_model_2.pth


In [19]:
int2label = {i:label for label,i in label2int.items()}

In [20]:
predict(model, tokenizer, int2label, input = "I just love how the actors are playing")

{'pos': 0.9117301, 'neg': 0.08826993}

In [21]:
predict(model, tokenizer, int2label, input = "This movie is poorly directed")

{'neg': 0.9916163, 'pos': 0.008383713}

## Build flask app

In [None]:
!wget https://bottlepy.org/bottle.py