<a href="https://colab.research.google.com/github/meti-94/TextClassification/blob/main/notebooks/GuideToTransformersDomainAdaptation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Guide to Transformers Domain Adaptation
This guide illustrates an end-to-end workflow of domain adaptation, where we domain-adapt a transfomer model for biomedical NLP applications.

It showcases the two domain adaptation techniques we investigated in our research:
1. Data Selection
2. Vocabulary Augmentation

Following that, we demonstrate how such a domain-adapted Transformers model is compatible with 🤗 `transformers`'s training interface and how it outperforms an out-of-the-box (non-domain adapted) model.

These techniques are applied to BERT small but the codebase is written to be generalizable to other classes of Transformers supported by HuggingFace.

### Caveats
For this guide, we use a much smaller subset (<0.05%) of the in-domain corpora due to memory and time constraints. 

### Setup: Install dependencies
We begin by installing `transformers-domain-adaptation` using `pip`.

In [3]:
%%capture
!pip install -U pip
!pip install transformers-domain-adaptation

### Setup: Download demo files

In [13]:
%%capture
!wget --no-check-certificate --no-proxy http://georgian-toolkit.s3.amazonaws.com/transformers-domain-adaptation/colab/files.zip
!unzip files.zip

In [30]:
!rm -rf results/
!rm -rf runs/
!rm -rf output/

## Constants
We first define some constants, including the appropriate model card and relevant paths to text corpora.

There are two types of corpora in the context of Domain Adaptation:

1. Fine-Tuning Corpus
> Given an NLP task (e.g. text classification, summarization, etc.), the text portion of this dataset is the fine-tuning corpus.

2. In-Domain Corpus
> This is an unsupervised text dataset that is used for domain pre-training. The text domain is the same as, if not broader than, the domain of fine-tuning corpus.

In [4]:
model_card = 'HooshvareLab/bert-fa-zwnj-base'

# Domain-pre-training corpora
dpt_corpus_train = 'data/pubmed_subset_train.txt'
dpt_corpus_train_data_selected = 'data/pubmed_subset_train_data_selected.txt'
dpt_corpus_val = 'data/pubmed_subset_val.txt'

# Fine-tuning corpora
# If there are multiple downstream NLP tasks/corpora, you can concatenate those files together
ft_corpus_train = 'data/BC2GM_train.txt'

### Load model and tokenizer
Next we load the model and its corresponding tokenizer.

In [5]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

model = AutoModelForMaskedLM.from_pretrained(model_card)
tokenizer = AutoTokenizer.from_pretrained(model_card)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=565.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=473451616.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=426422.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1108824.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=134.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=292.0, style=ProgressStyle(description_…




## Data Selection
Not all data in the in-domain corpora may be helpful or relevant during domain pre-training. For irrelevant documents, at best, it does not degrade the domain-adapted model performance. At worst, the model regresses and loses valuable pre-trained information — catastrophic forgetting.

As such, we select documents from the in-domain corpus that are likely to be relevant for the downstream fine-tuning dataset(s), using a variety of similarity and diversity metrics designed by Ruder & Plank.

Reference:
- Sebastian Ruder and Barbara Plank. Learning to select data for transfer learning with Bayesian Optimization. In *Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing.* 2017.

In [9]:
from pathlib import Path

from transformers_domain_adaptation import DataSelector


selector = DataSelector(
    keep=0.75,  # TODO Replace with `keep`
    tokenizer=tokenizer,
    similarity_metrics=['euclidean'],
    diversity_metrics=[
        "type_token_ratio",
        "entropy",
    ],
)

In [10]:
# Load text data into memory
fine_tuning_texts = Path(ft_corpus_train).read_text().splitlines()
print(fine_tuning_texts[:10])
training_texts = Path(dpt_corpus_train).read_text().splitlines()
print(training_texts[:10])
# Fit on fine-tuning corpus
selector.fit(fine_tuning_texts)

# Select relevant documents from in-domain training corpus
selected_corpus = selector.transform(training_texts)

# Save selected corpus to disk under `dpt_corpus_train_data_selected`
Path(dpt_corpus_train_data_selected).write_text('\n'.join(selected_corpus));

['حتی', 'در', 'جنگ', 'جهانی', 'دوم', '،', '<e1>', 'سربازان', '</e1>', 'آمریکایی']
['اولین انتقال و نفوذ طبیعی فرهنگ و تمدن اسلامی به اروپا از طریق کانون های جغرافیایی مصر، اندلس و سیسیل انجام گرفت و آنچه توانست به روند این انتقال سرعت بخشد جنگ های صلیبی بود.', 'اولین انتقال و نفوذ طبیعی فرهنگ و تمدن اسلامی به اروپا از طریق کانون های جغرافیایی مصر، اندلس و سیسیل انجام گرفت و آنچه توانست به روند این انتقال سرعت بخشد جنگ های صلیبی بود.', 'اولین انتقال و نفوذ طبیعی فرهنگ و تمدن اسلامی به اروپا از طریق کانون های جغرافیایی مصر، اندلس و سیسیل انجام گرفت و آنچه توانست به روند این انتقال سرعت بخشد جنگ های صلیبی بود.', 'ویژگی های هنر عصر اموی: ۱- تلفیقی بودن ۲- بازنمایی نوعی تفنن و تفریح ۳- نقاشی های تزئینی و تندیس های بی کیفیت', 'ویژگی های هنر عصر اموی: ۱- تلفیقی بودن ۲- بازنمایی نوعی تفنن و تفریح ۳- نقاشی های تزئینی و تندیس های بی کیفیت', 'ویژگی های هنر عصر اموی: ۱- تلفیقی بودن ۲- بازنمایی نوعی تفنن و تفریح ۳- نقاشی های تزئینی و تندیس های بی کیفیت', 'قبه الصخره یکی از تجلی گاه های زیبایی و ظرا

computing similarity: 100%|██████████| 1/1 [00:02<00:00,  2.93s/metric]
computing diversity: 100%|██████████| 2/2 [00:00<00:00,  2.46metric/s]


Since we specified `keep=0.5` in the `DataSelector`, the selected corpus should be half the size of the in-domain corpus, containing the top 50% most relevant documents.

In [11]:
len(training_texts), len(selected_corpus)

(14532, 10899)

In [12]:
selected_corpus[0]

'روز سیزده آبان در تاریخ ایران با وقایع مهمی مصادف شده است: سیزدهم آبان ماه سال ۴۳، نظام طاغوت در اقدامی عجولانه به اقامتگاه امام خمینی رحمه الله در قم هجوم برد. امام رحمه الله با اتومبیل به تهران منتقل شدند و همان روز با هواپیما به ترکیه تبعید گردیدند.همچنین تصرف سفارت آمریکا که در ادبیات سیاسی جمهوری اسلامی ایران تسخیر لانهٔ جاسوسی خوانده می\u200cشود، در تاریخ ۱۳ آبان ۱۳۵۸اتفاق افتاد.در روز ۱۳ آبان ۱۳۵۷ نیز واقعه کشتار جمعی دانش\u200cآموزان تهرانی که به نشانه اعتراض به حکومت پهلوی در محوطه دانشگاه تهران جمع شده بودند، اتفاق افتاد. به منظور گرامی\u200cداشت این روز، سیزده آبان در تقویم جمهوری اسلامی ایران به عنوان روز دانش\u200cآموز نامگذاری شده\u200cاست. '

## Vocabulary Augmentation
We can extend the existing vocabulary of the model to include domain-specific terminology. This allows for the representation such terminology to be explicit learnt during domain pre-training.

In [13]:
from transformers_domain_adaptation import VocabAugmentor

target_vocab_size = 42_300  # len(tokenizer) == 30_522

augmentor = VocabAugmentor(
    tokenizer=tokenizer, 
    cased=False, 
    target_vocab_size=target_vocab_size
)

# Obtain new domain-specific terminology based on the fine-tuning corpus
new_tokens = augmentor.get_new_tokens(ft_corpus_train)

In [14]:
print(new_tokens[:200])

['e2', '</', 'e1', 'آن', 'آنها', 'آب', 'ارائه', 'آسیب', 'آتش', 'آمد', 'آغاز', 'آورد', 'آمده', 'رئیس', 'آنجا', 'آموزان', 'آزمایش', 'آید', 'آماده', 'آموزش', 'آینده', 'آزاد', 'آمریکایی', 'آبی', 'آهنگ', 'مسائل', 'آشپزخانه', 'آوری', 'آورده', 'آنچه', 'آمریکا', 'هیئت', 'آموزشی', 'آثار', 'آبجو', 'آخرین', 'كه', 'آرد', 'مي', 'آلودگی', 'جزئیات', 'آنلاین', 'آمیز', 'آسمان', 'آوردن', 'آکنه', 'آموز', '\u200b\u200b', 'آلبوم', 'درآمد', 'آپارتمان', 'آویزان', 'آنرا', "''", 'پروتئین', 'آخر', 'آقای', 'آلمانی', 'آهن', 'مسئول', 'وسیلهی', 'آفریقا', 'آزادی', 'آرام', 'يك', 'آرامی', 'کنندهی', 'آلوده', 'دربارهی', 'lrb', 'شدهی', 'فرآیند', 'آفریقای', 'آشنا', 'آزمایشگاه', 'مسئولیت', 'آیا', 'آیند', 'آشپزی', 'آمدن', 'تئوری', 'آنتی', 'آنکه', 'تأثیر', 'علائم', 'آورند', 'آلمان', 'کوکائین', 'دهندهی', 'آزمون', 'آگاه', 'آمدند', 'اسرائیل', 'آسان', 'آوردند', 'ریختیم', 'آهنگساز', 'آلات', 'آل', 'مطمئن', 'آزمایشی', 'تزئین', 'اين', 'آژانس', 'ژوئن', 'آرامش', 'آنجایی', 'آزمایشات', 'ژوئیه', 'آوریل', 'آنجلس', 'آوردم', 'آنتن', 'دستهی'

#### Update model and tokenizer with new vocab terminologies

In [15]:
tokenizer.add_tokens(new_tokens)
model.resize_token_embeddings(len(tokenizer))

Embedding(42300, 768)

## Domain Pre-Training
Domain pre-training is the third step in domain adaptation — we continue training Transformer models with the same pre-training procedure on the in-domain corpus.

#### Create dataset

In [16]:
import itertools as it
from pathlib import Path
from typing import Sequence, Union, Generator

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

In [17]:
datasets = load_dataset(
    'text', 
    data_files={
        "train": dpt_corpus_train_data_selected, 
        "val": dpt_corpus_val
    }
)

tokenized_datasets = datasets.map(
    lambda examples: tokenizer(examples['text'], truncation=True, max_length=model.config.max_position_embeddings), 
    batched=True
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1038.0, style=ProgressStyle(description…




Using custom data configuration default


Downloading and preparing dataset text/default-c950a03e5fbc56a8 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/text/default-c950a03e5fbc56a8/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/default-c950a03e5fbc56a8/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




#### Instantiate TrainingArguments and Trainer

In [18]:
training_args = TrainingArguments(
    output_dir="./results/domain_pre_training",
    overwrite_output_dir=True,
    max_steps=500,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    save_steps=50,
    save_total_limit=2,
    logging_steps=50,
    seed=42,
    # fp16=True,
    dataloader_num_workers=2,
    disable_tqdm=False
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['val'],
    data_collator=data_collator,
    tokenizer=tokenizer,  # This tokenizer has new tokens
)

In [19]:
trainer.train()

Step,Training Loss,Validation Loss,Runtime,Samples Per Second
50,3.2371,3.318074,26.4545,118.241


Step,Training Loss,Validation Loss,Runtime,Samples Per Second
50,3.2371,3.318074,26.4545,118.241
100,3.2359,3.278585,26.4311,118.346
150,3.1945,3.213656,26.3455,118.73
200,3.0867,3.121018,26.26,119.116
250,2.9879,3.127592,26.2454,119.183
300,2.9228,3.131969,26.7772,116.816
350,2.9969,3.072414,26.7657,116.866
400,2.8201,3.052264,26.6905,117.195
450,2.9118,3.007958,26.6972,117.166
500,3.0384,3.026256,26.6788,117.247


TrainOutput(global_step=500, training_loss=3.0432140197753905, metrics={'train_runtime': 501.6731, 'train_samples_per_second': 0.997, 'total_flos': 204948604598400, 'epoch': 0.37})

In [21]:
from transformers import AutoConfig, AutoTokenizer, AutoModel, TFAutoModel, BertForSequenceClassification, GPT2ForSequenceClassification
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from sklearn import preprocessing
import numpy as np
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report
from sklearn.metrics import plot_confusion_matrix
import pickle
from sklearn.metrics import confusion_matrix
# from cmpp import plot_confusion_matrix_from_data, pretty_plot_confusion_matrix
import os

In [22]:
model_name_or_path = "/content/results/domain_pre_training/checkpoint-500"
config = AutoConfig.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = BertForSequenceClassification.from_pretrained(model_name_or_path, num_labels=19)

Some weights of the model checkpoint at /content/results/domain_pre_training/checkpoint-500 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpo

In [25]:
bertified = lambda sample:sample.strip('"').replace('<e2>', ' <e2> ').replace('</e2>', ' </e2> ').replace('<e1>', ' <e1> ').replace('</e1>', ' </e1> ').replace('\u200c', ' ')
le = preprocessing.LabelEncoder()
with open('./train.txt', 'r', encoding='utf-8') as fin:
    content = fin.read().split('\n')
X_train = [bertified(item.split('\t')[-1]) for item in content[0::4]]
y_train = [item.strip().strip('"') for item in content[1::4]]
major_class = [item.split('(')[0] for item in y_train]
y_train = le.fit_transform(y_train)
with open('./test.txt', 'r', encoding='utf-8') as fin:
    content = fin.read().split('\n')
X_test = [bertified(item.split('\t')[-1]) for item in content[0::4]]
y_test = [item.strip().strip('"') for item in content[1::4]]
y_test = le.transform(y_test)
y_test = y_test.astype('int64')

# print(type(y_test), type(y_train), y_test.dtype, y_train.dtype)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.1, random_state=99, stratify=major_class)

In [26]:
SPECIAL_TOKENS = ["<e1>", "<e2>", '</e1>', '</e2>', "<pad>"]

ATTR_TO_SPECIAL_TOKEN = {'bos_token': '<bos>', 'eos_token': '<eos>', 'pad_token': '<pad>',
                         'additional_special_tokens': ["<e>", "<e2>", '</e1>', '</e2>']}
tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
model.resize_token_embeddings(len(tokenizer))

Embedding(42307, 768)

In [27]:
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
X_test_tokenized = tokenizer(X_test, padding=True, truncation=True, max_length=512)

In [28]:
class Dataset(torch.utils.data.Dataset):    
    def __init__(self, encodings, labels=None):          
        self.encodings = encodings        
        self.labels = labels
     
    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if any(self.labels):
            item["labels"] = torch.tensor(self.labels[idx])
        return item     
    def __len__(self):
        return len(self.encodings["input_ids"])

train_dataset = Dataset(X_train_tokenized, y_train)
val_dataset = Dataset(X_val_tokenized, y_val)
test_dataset = Dataset(X_test_tokenized, y_test) 

In [None]:
def compute_metrics(p):    
    pred, labels = p
    pred = np.argmax(pred, axis=1)
    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average='macro')
    precision = precision_score(y_true=labels, y_pred=pred, average='macro')
    f1 = f1_score(y_true=labels, y_pred=pred, average='macro')    
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1} # Define Trainer


args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",
    eval_steps=100,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
#     num_train_epochs=3,
    save_steps=100,
    seed=0,
    load_best_model_at_end=True,
    num_train_epochs=100,
    fp16 =True,
    learning_rate=2e-5, )

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],)
 
# Train pre-trained model
trainer.train()

Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Runtime,Samples Per Second
100,No log,2.122516,0.32625,0.300694,0.200658,0.151411,21.1105,37.896
200,No log,1.557455,0.51625,0.389253,0.376674,0.359038,21.0824,37.946


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1,Runtime,Samples Per Second
100,No log,2.122516,0.32625,0.300694,0.200658,0.151411,21.1105,37.896
200,No log,1.557455,0.51625,0.389253,0.376674,0.359038,21.0824,37.946
300,No log,1.25435,0.6275,0.587541,0.575367,0.561533,21.1439,37.836
400,No log,1.08859,0.66625,0.641128,0.596549,0.601788,21.1373,37.848
500,1.615100,1.03013,0.6875,0.662383,0.624392,0.628962,21.1704,37.789
600,1.615100,1.015365,0.68875,0.629418,0.667964,0.637736,21.1776,37.776
700,1.615100,0.952845,0.71625,0.635689,0.680764,0.650028,21.1758,37.779
800,1.615100,0.914506,0.715,0.719936,0.67848,0.678838,21.1453,37.834
900,1.615100,0.918964,0.7225,0.789992,0.695751,0.721741,21.1477,37.829
1000,0.701900,1.009469,0.69875,0.720962,0.691779,0.695678,21.1385,37.846


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [28]:
model_path = "output/checkpoint-900"
model = BertForSequenceClassification.from_pretrained(model_path, num_labels=19)
# Define test trainer
test_trainer = Trainer(model) 
# Make prediction
raw_pred, _, _ = test_trainer.predict(test_dataset) # Preprocess raw predictions
y_pred = np.argmax(raw_pred, axis=1)

In [29]:
print(classification_report(y_pred, y_test))

              precision    recall  f1-score   support

           0       0.90      0.90      0.90       135
           1       0.89      0.79      0.84       217
           2       0.77      0.76      0.76       163
           3       0.75      0.72      0.74       157
           4       0.83      0.80      0.82       158
           5       0.77      0.81      0.79        37
           6       0.82      0.87      0.84       273
           7       0.00      0.00      0.00         0
           8       0.75      0.78      0.77       203
           9       0.70      0.94      0.80        35
          10       0.45      0.71      0.56        14
          11       0.68      0.81      0.74       112
          12       0.31      0.67      0.43        15
          13       0.80      0.76      0.78       213
          14       0.90      0.74      0.81       257
          15       0.76      0.70      0.73        56
          16       0.50      0.44      0.47       520
          17       0.58    

  _warn_prf(average, modifier, msg_start, len(result))
