In [2]:
import os
# Point the Hugging Face cache root at scratch storage. This runs before the
# Hugging Face libraries are imported in the next cell.
os.environ.update({'HF_HOME': '/pscratch/sd/g/gzhao27/huggingface'})

In [3]:
import pandas as pd
import numpy as np
from tqdm import tqdm, trange
import csv
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import statistics

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertConfig

from keras_preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split

import transformers
from transformers import BertForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup

from seqeval.metrics import f1_score, accuracy_score

# Hugging Face Hub identifier of the pretrained checkpoint; used below to load
# both the tokenizer and the masked-LM model.
model_name = "allenai/scibert_scivocab_cased"
# Directory passed to model.save_pretrained() at the end of the notebook
# (distinct from the Trainer's output_dir, "fine_tuned_mlm_model").
fine_tune_save_path = "fine_tuned_mlm"

In [4]:
# Select the compute device: the first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

# Only query the GPU name when CUDA is actually present; calling
# torch.cuda.get_device_name(0) on a CPU-only machine raises.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CUDA not available; running on CPU")

# Report the number of logical CPU cores (os.cpu_count() may return None).
num_cores = os.cpu_count()
print("Number of CPU cores:", num_cores)



NVIDIA A100-PCIE-40GB
Number of CPU cores: 256


## Preprocessing SSL data

In [5]:
import os

def get_all_file_paths(directory):
    """Return the paths of all files found under ``directory``, recursively.

    Args:
        directory: Root directory to traverse with ``os.walk``.

    Returns:
        A sorted list of file paths, each formed by joining the directory
        reported by ``os.walk`` with the file name. The list is empty when
        no files are found.
    """
    file_paths = [
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
    ]
    # os.walk yields entries in a filesystem-dependent order; sorting keeps the
    # file list (and anything built from it) reproducible across runs.
    file_paths.sort()
    return file_paths

# Gather every file under the sample-articles directory (recursively).
directory_path = "./sample_articles/"
file_paths_list = get_all_file_paths(directory=directory_path)


In [6]:
from datasets import load_dataset
# Load all collected files with the generic "text" builder as one 'train' split.
ssl_dataset = load_dataset(
    "text",
    data_files=file_paths_list,
    split='train',
)

Resolving data files:   0%|          | 0/1670 [00:00<?, ?it/s]

Found cached dataset text (/pscratch/sd/g/gzhao27/huggingface/datasets/text/default-2d45f076affa0596/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2)


In [32]:
# Display the dataset summary (feature names and row count).
# NOTE(review): this cell's execution count (32) is higher than that of the
# cells below it, so the displayed row count may come from a later kernel
# state rather than a fresh top-to-bottom run — confirm with Restart & Run All.
ssl_dataset

Dataset({
    features: ['text'],
    num_rows: 10000
})

In [7]:
# Keep only rows whose 'text' value is truthy (i.e. drop empty strings).
ssl_dataset = ssl_dataset.filter(lambda row: bool(row['text']))

Loading cached processed dataset at /pscratch/sd/g/gzhao27/huggingface/datasets/text/default-2d45f076affa0596/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-142f1d00a0804776.arrow


In [21]:
# Optional: shrink the SSL dataset to its first 10,000 rows for faster experiments (uncomment the line below to enable).
# ssl_dataset = ssl_dataset.select(range(10000))

In [22]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to the pretrained checkpoint named by model_name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
def preprocess_function(examples):
    """Tokenize the "text" column of a batch with the notebook-level tokenizer.

    Args:
        examples: A batch exposing a "text" entry (as passed by a batched
            ``map`` call).

    Returns:
        Whatever the tokenizer returns for that batch of texts. No truncation
        or padding arguments are passed here.
    """
    texts = examples["text"]
    encoded = tokenizer(texts)
    return encoded

In [23]:
# Tokenize the corpus in parallel and drop the raw text column so that only
# the tokenizer outputs remain. The worker count follows the machine's CPU
# count instead of a hard-coded 256, which only matched the original host.
ssl_tokenized = ssl_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=os.cpu_count() or 1,  # cpu_count() may return None
    remove_columns=ssl_dataset.column_names,
)

Loading cached processed dataset at /pscratch/sd/g/gzhao27/huggingface/datasets/text/default-2d45f076affa0596/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-a39c583148aa3085_*_of_00256.arrow


In [24]:
block_size = 128  # number of tokens in each grouped training example


def group_texts(examples):
    """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

    Args:
        examples: Mapping from column name to a list of per-row lists
            (the form a batched ``map`` passes in).

    Returns:
        A dict with the same keys, where each value is a list of chunks of
        ``block_size`` items. If the concatenated length is at least
        ``block_size``, the trailing remainder is dropped; if it is shorter,
        a single short chunk holding everything is returned. A batch with no
        columns yields an empty dict.
    """
    from itertools import chain

    # Flatten each column in linear time (sum(list_of_lists, []) is quadratic).
    concatenated_examples = {
        k: list(chain.from_iterable(examples[k])) for k in examples.keys()
    }
    if not concatenated_examples:
        return {}

    # The first column's length sets the chunking for every column.
    total_length = len(next(iter(concatenated_examples.values())))

    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size

    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    return result

In [26]:
# Re-chunk the tokenized rows with group_texts, in a single process.
ssl_lm = ssl_tokenized.map(
    group_texts,
    num_proc=1,
    batched=True,
)

Loading cached processed dataset at /pscratch/sd/g/gzhao27/huggingface/datasets/text/default-2d45f076affa0596/0.0.0/cb1e9bd71a82ad27976be3b12b407850fe2837d80c22c5e03a28949843a8ace2/cache-061edbdd8d760c28.arrow


In [27]:
from transformers import DataCollatorForLanguageModeling

# Make sure the tokenizer has a padding token registered.
# NOTE(review): if the tokenizer already defines '[PAD]' this is presumably a
# no-op; if it ever adds a new token, the model's embeddings would need
# resizing — verify.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Collator for masked-language-model batches, with a masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    mlm_probability=0.15,
    tokenizer=tokenizer,
)

In [28]:
from transformers import AutoModelForMaskedLM, TrainingArguments, Trainer
# Load the pretrained checkpoint with a masked-language-modeling head; this is
# the model handed to the Trainer and saved at the end of the notebook.
model = AutoModelForMaskedLM.from_pretrained(model_name)

Some weights of the model checkpoint at allenai/scibert_scivocab_cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [29]:
# model.bert.encoder.layer[2].attention.self.query.weight

In [30]:
# Hold out 20% of the grouped blocks for evaluation. A fixed seed makes the
# shuffle-and-split reproducible across kernel restarts.
ssl_grouped = ssl_lm.train_test_split(test_size=0.2, seed=42)

In [31]:
# Training configuration for masked-LM fine-tuning.
# Checkpoints and logs go to output_dir, which is a different directory from
# fine_tune_save_path used for the final save.
training_args = TrainingArguments(
    output_dir="fine_tuned_mlm_model",
    evaluation_strategy="epoch",  # evaluate on eval_dataset once per epoch
    learning_rate=2e-5,
    num_train_epochs=15,
    weight_decay=0.01,
    push_to_hub=False,  # keep everything local
    report_to="tensorboard",
)

# Wire the model, the train/test portions of the split, and the MLM collator together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ssl_grouped['train'],
    eval_dataset=ssl_grouped['test'],
    data_collator=data_collator,
)

# Run fine-tuning; as the cell's last expression, the returned value is displayed.
trainer.train()

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
1,1.8661,1.615575
2,1.686,1.586112
3,1.6059,1.567068
4,1.5642,1.534472
5,1.497,1.52044
6,1.479,1.465342
7,1.4306,1.465601
8,1.3722,1.483973
9,1.3473,1.439742
10,1.3318,1.455248


TrainOutput(global_step=7620, training_loss=1.436728091502753, metrics={'train_runtime': 799.4697, 'train_samples_per_second': 76.175, 'train_steps_per_second': 9.531, 'total_flos': 4007321056972800.0, 'train_loss': 1.436728091502753, 'epoch': 15.0})

In [None]:
# Persist the fine-tuned masked-LM model to fine_tune_save_path.
model.save_pretrained(fine_tune_save_path)
# Save the tokenizer next to the model so the directory is self-contained and
# can be loaded on its own with from_pretrained(fine_tune_save_path).
tokenizer.save_pretrained(fine_tune_save_path);