# Training script: reads the data produced by the data_preprocessing notebook.

In [1]:
!pip install datasets yake

In [None]:
import os
import torch
import transformers
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, RobertaForMaskedLM, AdamW, get_linear_schedule_with_warmup, get_scheduler
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
import yake
import spacy
import collections
import math
import random
import argparse
import shutil
import logging
import time
import pandas as pd
import datasets

In [None]:
# params
# Run-wide settings read by the cells below.
model_dir = "train_phrase_mask_lr5e5"  # output directory; checkpoints are saved here
src_file = None  # optional path; if it is an existing file it is copied into model_dir on first creation
model_size = 'base'  # not referenced in the visible part of this script
add_pos = False  # not read as a global in the visible code (the masking helpers take their own add_pos argument)
mask_random = False  # True -> masking helpers sample words at random instead of using YAKE keywords
num_train_epochs = 20  # upper bound on epochs; the training loop may stop earlier
batch_size = 64  # NOTE: reassigned to 12 in the data-loading cell before the DataLoaders are built
learning_rate = 5e-5  # passed to the optimizer as lr
is_test = False  # not referenced in the visible part of this script
max_len_english = 64  # max_length used by tokenize_sentence (applied to every sentence it tokenizes)
max_len = 128  # target sequence length used for padding in collate_mlm_data
m_ratio = 0.15  # fraction of words used to size the masking budget (see get_mask_phrases / get_mask_words)

is_test False


In [3]:
# Prepare the output directory: reuse it if it exists, otherwise create it and
# (optionally) drop a copy of src_file inside.
# log_dir is bound on both paths so the name exists on re-runs as well; it is
# only a path string here -- nothing in this cell creates that directory.
log_dir = os.path.join(model_dir, "logs")
if os.path.isdir(model_dir):
    print(f"Model Directory {model_dir} exists.")
else:
    # makedirs (unlike mkdir) also works when model_dir is a nested path whose
    # parent directories do not exist yet.
    os.makedirs(model_dir)
    if isinstance(src_file, str) and os.path.exists(src_file):
        shutil.copy(src_file, model_dir)
    print(f"Model directory {model_dir} created.")

# File logging: records are appended (filemode='a') to log_file in the current
# working directory, one "<timestamp> <message>" line per record.
log_file = "train_phrase_mask_lr5e5.txt"
logging.basicConfig(filename=log_file, filemode='a',
                    format='%(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
# Timestamps are rendered with time.gmtime (UTC) instead of local time.
logging.Formatter.converter = time.gmtime
logger = logging.getLogger(__name__)
# Let the transformers library log at INFO level as well.
transformers.utils.logging.set_verbosity(logging.INFO)

# Separator line marking the start of a new run in the (appended) log file.
logger.info("-"*30)


# Seed Python's and torch's RNGs so runs are repeatable, then pick the device.
SEED = 10
random.seed(SEED)
torch.manual_seed(SEED)
# Ask cuDNN for deterministic algorithms (see torch.backends.cudnn docs).
torch.backends.cudnn.deterministic = True
if torch.cuda.is_available():
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(SEED)
    device = torch.device("cuda")
    print('Using the GPU.')
    if torch.cuda.device_count() > 1:
        print(f"Let's use {torch.cuda.device_count()} GPUs!")
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

# Load the pretrained tokenizer and masked-LM weights.
# NOTE(review): the checkpoint is "xlm-roberta-base" but the model class is
# RobertaForMaskedLM -- confirm this pairing is intended (vs. an XLM-R/Auto class).
ROBERTA_MODEL = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(ROBERTA_MODEL)
model = RobertaForMaskedLM.from_pretrained(ROBERTA_MODEL)
# With more than one GPU the model is wrapped in DataParallel; later code then
# reaches the underlying model through model.module.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
model.to(device)

In [None]:
# yake parameters
# Each value below is forwarded to yake.KeywordExtractor under the keyword
# shown in the constructor call (top, lan, n, dedupLim, dedupFunc, windowsSize).
top_n = 20  # passed as top
language = "en"  # passed as lan
max_ngram_size = 1  # passed as n; presumably single-word keywords only -- confirm against YAKE docs
deduplication_threshold = 0.9  # passed as dedupLim
deduplication_algo = 'seqm'  # passed as dedupFunc
windowSize = 1  # passed as windowsSize
custom_kw_extractor = yake.KeywordExtractor(lan=language, n=max_ngram_size, dedupLim=deduplication_threshold, dedupFunc=deduplication_algo, windowsSize=windowSize, top=top_n, features=None)
# spaCy English pipeline; used by the masking helpers when POS-based fallback is requested.
nlp = spacy.load("en_core_web_sm")
logger.info("Modules Loaded")
print("Modules Loaded")


Modules Loaded


In [None]:
def tokenize_sentence(txt, tokenizer):
    """
    Tokenize one sentence to a fixed length and attach its word alignment.

    The sentence is truncated / padded to ``max_len_english`` tokens
    (module-level setting).

    Args:
        txt: the sentence to tokenize.
        tokenizer: tokenizer to call; its ``is_fast`` flag decides whether
            word alignment is available.

    Returns:
        The tokenizer output. For fast tokenizers it additionally carries a
        ``"word_ids"`` entry: one word index per token, as reported by
        ``result.word_ids()``.
    """
    result = tokenizer(txt, max_length=max_len_english, padding='max_length', truncation=True)
    # word_ids() is only available on fast-tokenizer output, so it has to be
    # called inside the guard; calling it first made the guard ineffective.
    if tokenizer.is_fast:
        result["word_ids"] = list(result.word_ids())
    return result

def get_word_mapping(tok):
    """
    Group token positions by the word they belong to.

    Walks ``tok["word_ids"]`` and assigns consecutive indices (0, 1, 2, ...)
    to each run of equal, non-None word ids. Positions whose word id is None
    are skipped.

    Returns:
        A ``defaultdict(list)`` mapping the running word index to the list of
        token positions that make up that word.
    """
    mapping = collections.defaultdict(list)
    word_counter = -1
    last_seen = None
    for position, wid in enumerate(tok["word_ids"].copy()):
        if wid is None:
            continue
        # A change of word id starts the next word.
        if wid != last_seen:
            last_seen = wid
            word_counter += 1
        mapping[word_counter].append(position)
    return mapping

def get_pos_tags(doc):
    """
    Collect the fine-grained tag of every content token in ``doc``.

    A token is kept when it is not a stop word, punctuation or whitespace,
    its lower-cased text is not in ``stop_words`` and its ``tag_`` is listed
    in ``lst_pos_tags``.

    NOTE(review): ``stop_words`` and ``lst_pos_tags`` are read as module-level
    names; they are not defined in the visible part of this file -- confirm
    they exist before enabling the POS fallback.

    Returns:
        dict mapping token text to its tag (a repeated text keeps the tag of
        its last occurrence).
    """
    pos_tags = {}
    for token in doc:
        # Guard clause: drop function words, punctuation and whitespace first.
        if token.is_stop or token.is_punct or token.is_space or token.text.lower() in stop_words:
            continue
        if token.tag_ in lst_pos_tags:
            pos_tags[token.text] = token.tag_
    return pos_tags

def get_mask_phrases(txt, tok, mapping, add_pos):
    """
    Choose the words to mask, growing each keyword into a short phrase.

    Two modes, selected by the module-level ``mask_random`` flag:

    * random: sample ``ceil(m_ratio * len(mapping))`` word indices uniformly.
    * keyword: take up to ``max(3, ceil(m_ratio * len(mapping)))`` YAKE
      keywords from ``txt`` (optionally topped up with words returned by
      ``get_pos_tags`` when ``add_pos`` is true). Every keyword is masked
      together with the unmasked word directly before it and the word
      directly after it.

    Args:
        txt: the raw sentence.
        tok: tokenizer output for ``txt``; must provide ``word_to_chars``.
        mapping: word index -> token positions, as built by get_word_mapping.
        add_pos: whether to top up missing keywords from POS-tagged words.

    Returns:
        ``(mask, mask_words)``: the selected word indices and the lower-cased
        text of each selected word, in matching order.
    """
    if mask_random:
        # Use the shared m_ratio setting rather than a second hard-coded 0.15.
        n_sample = math.ceil(m_ratio * len(mapping))
        mask = random.sample(range(len(mapping)), n_sample)
        mask_words = []
        for idx in mask:
            start, end = tok.word_to_chars(idx)
            mask_words.append(txt[start:end].lower())
        return mask, mask_words

    # Strip the tokenizer's sentence markers before keyword extraction.
    yake_doc = txt.replace(tokenizer.eos_token, "")
    yake_doc = yake_doc.replace(tokenizer.bos_token, "")
    yake_doc = yake_doc.strip()
    max_keyword = max(3, math.ceil(m_ratio * len(mapping)))
    keywords = custom_kw_extractor.extract_keywords(yake_doc)[:max_keyword]
    lst_kw = [kw[0].lower() for kw in keywords]
    if len(lst_kw) < max_keyword and add_pos:
        n = max_keyword - len(lst_kw)
        for w in get_pos_tags(nlp(txt)):
            # lst_kw holds lower-cased entries, so compare in lower case too;
            # otherwise "Apple" would be re-added next to an existing "apple"
            # and consume one of the remaining slots.
            w_lower = w.lower()
            if w_lower not in lst_kw:
                lst_kw.append(w_lower)
                n -= 1
                if n == 0:
                    break

    mask = []
    mask_words = []
    prev_word = None   # text of the directly preceding word, if it was not masked
    prev_id = None     # its word index
    mask_next = False  # True right after a keyword: the following word is masked too
    for idx in mapping:
        start, end = tok.word_to_chars(idx)
        word = txt[start:end].lower()
        is_keyword = word in lst_kw
        if is_keyword or mask_next:
            # Pull in the unmasked left neighbour first so indices stay in order.
            if prev_word is not None:
                mask.append(prev_id)
                mask_words.append(prev_word)
            mask.append(idx)
            mask_words.append(word)
            prev_word = None
            # Only a real keyword extends the phrase to the next word.
            mask_next = is_keyword
        else:
            prev_word = word
            prev_id = idx
            mask_next = False
    return mask, mask_words


def get_mask_words(txt, tok, mapping, add_pos):
    """
    Choose the individual words to mask.

    Two modes, selected by the module-level ``mask_random`` flag:

    * random: sample ``ceil(m_ratio * len(mapping))`` word indices uniformly.
    * keyword: take up to ``max(3, ceil(m_ratio * len(mapping)))`` YAKE
      keywords from ``txt`` (optionally topped up with words returned by
      ``get_pos_tags`` when ``add_pos`` is true) and mask every word whose
      lower-cased text is one of them.

    Args:
        txt: the raw sentence.
        tok: tokenizer output for ``txt``; must provide ``word_to_chars``.
        mapping: word index -> token positions, as built by get_word_mapping.
        add_pos: whether to top up missing keywords from POS-tagged words.

    Returns:
        ``(mask, mask_words)``: the selected word indices and the lower-cased
        text of each selected word, in matching order.
    """
    if mask_random:
        # Use the shared m_ratio setting rather than a second hard-coded 0.15.
        n_sample = math.ceil(m_ratio * len(mapping))
        mask = random.sample(range(len(mapping)), n_sample)
        mask_words = []
        for idx in mask:
            start, end = tok.word_to_chars(idx)
            mask_words.append(txt[start:end].lower())
        return mask, mask_words

    # Strip the tokenizer's sentence markers before keyword extraction.
    yake_doc = txt.replace(tokenizer.eos_token, "")
    yake_doc = yake_doc.replace(tokenizer.bos_token, "")
    yake_doc = yake_doc.strip()
    max_keyword = max(3, math.ceil(m_ratio * len(mapping)))
    keywords = custom_kw_extractor.extract_keywords(yake_doc)[:max_keyword]
    lst_kw = [kw[0].lower() for kw in keywords]
    if len(lst_kw) < max_keyword and add_pos:
        n = max_keyword - len(lst_kw)
        for w in get_pos_tags(nlp(txt)):
            # lst_kw holds lower-cased entries, so compare in lower case too;
            # otherwise "Apple" would be re-added next to an existing "apple"
            # and consume one of the remaining slots.
            w_lower = w.lower()
            if w_lower not in lst_kw:
                lst_kw.append(w_lower)
                n -= 1
                if n == 0:
                    break

    mask = []
    mask_words = []
    for idx in mapping:
        start, end = tok.word_to_chars(idx)
        word = txt[start:end].lower()
        if word in lst_kw:
            mask.append(idx)
            mask_words.append(word)
    return mask, mask_words

def get_masked_tokens(tokenizer, tok, mapping, mask):
    """
    Replace the tokens of the selected words with the mask token.

    Args:
        tokenizer: object exposing ``mask_token_id``.
        tok: tokenizer output; ``tok["input_ids"]`` is read, never modified.
        mapping: word index -> list of token positions.
        mask: word indices to mask.

    Returns:
        ``(input_ids, labels)``: a copy of the input ids with every token of a
        masked word replaced by ``tokenizer.mask_token_id``, and a label list
        holding the original id at masked positions and -100 everywhere else.
    """
    source_ids = tok["input_ids"]
    input_ids = list(source_ids)
    labels = [-100] * len(input_ids)
    for word_id in mask:
        for idx in mapping[word_id]:
            # Take the label from the untouched source ids: if a word id shows
            # up twice in `mask`, reading from the already-masked copy would
            # store the mask token id as the label.
            labels[idx] = source_ids[idx]
            input_ids[idx] = tokenizer.mask_token_id
    return input_ids, labels

def prepare_features(df):
    """
    Build one training example from a record with 'English' and 'German' text.

    The English sentence is tokenized and its selected words are masked
    (via get_mask_words with add_pos disabled); the German sentence is
    tokenized and left unmasked. The German ids come first in the output,
    followed by the masked English ids; labels are -100 on the German part.

    Returns:
        dict with "input_ids" and "label" lists of equal length.
    """
    english_text, german_text = df['English'], df['German']

    # English side: tokenize, align words to tokens, pick words, mask them.
    tok_en = tokenize_sentence(english_text, tokenizer)
    en_word_map = get_word_mapping(tok_en)
    masked_word_ids, _ = get_mask_words(english_text, tok_en, en_word_map, False)
    masked_en_ids, en_labels = get_masked_tokens(tokenizer, tok_en, en_word_map, masked_word_ids)

    # German side: context only, so none of its positions contribute to the loss.
    tok_de = tokenize_sentence(german_text, tokenizer)
    ignored = [-100] * len(tok_de['input_ids'])

    return {
        "input_ids": tok_de['input_ids'] + masked_en_ids,
        "label": ignored + en_labels,
    }

def collate_mlm_data(features):
    """
    Collate a list of examples into padded batch tensors.

    Each example must provide "input_ids" and "label" lists. Sequences shorter
    than the module-level ``max_len`` are right-padded: ids with
    ``tokenizer.pad_token_id``, labels with -100, attention with 0.
    Sequences of length ``max_len`` or more are passed through unchanged.

    The input examples are not modified. The previous implementation extended
    the callers' lists in place, so an example that was collated twice arrived
    already padded and received an all-ones attention mask.

    NOTE(review): attention is 1 for every position present in the incoming
    "input_ids". If examples were already padded upstream (tokenize_sentence
    uses padding='max_length'), those pad positions are attended to -- confirm
    this is intended.

    Returns:
        dict of torch.long tensors: "input_ids", "attn_mask", "labels".
    """
    pad_id = tokenizer.pad_token_id
    lst_input_ids = []
    lst_labels = []
    lst_attn_mask = []
    for f in features:
        # Work on copies so the dataset's own lists are left untouched.
        input_ids = list(f["input_ids"])
        labels = list(f["label"])
        m = len(input_ids)
        n_pad = max_len - m  # <= 0 yields empty padding below
        lst_input_ids.append(input_ids + [pad_id] * n_pad)
        lst_labels.append(labels + [-100] * n_pad)
        lst_attn_mask.append([1] * m + [0] * n_pad)

    batch = {}
    batch["input_ids"] = torch.tensor(lst_input_ids, dtype=torch.long)
    batch["attn_mask"] = torch.tensor(lst_attn_mask, dtype=torch.long)
    batch["labels"] = torch.tensor(lst_labels, dtype=torch.long)
    return batch

In [None]:
# Load the pre-tokenized dataset saved on disk; its "train" and "test" splits
# are used below for training and validation respectively.
tokenized_dataset = datasets.load_from_disk("tokenized_dataset_sample_batch")

# NOTE: this overrides the batch_size set in the params cell.
batch_size = 12
logger.info("Train Data:-")
train_dataloader = DataLoader(tokenized_dataset["train"], batch_size,
                              shuffle=True, collate_fn=collate_mlm_data)
logger.info(f'len(train_dataloader): {len(train_dataloader)}')

# Validation batches come from the "test" split and keep their order.
logger.info("Validation Data:-")
valid_dataloader = DataLoader(tokenized_dataset["test"], batch_size,
                              shuffle=False, collate_fn=collate_mlm_data)
logger.info(f'len(valid_dataloader): {len(valid_dataloader)}')

logger.info(f"Data Loaded")
print("Data Loaded")

#----------------------------

Data Loaded


In [None]:
def evaluate_loss(dataloader):
    """
    Return the mean per-batch loss of the global ``model`` over ``dataloader``.

    The model is switched to eval mode and left there; gradients are disabled
    for the whole pass. Each batch must provide "input_ids", "attn_mask" and
    "labels" tensors.
    """
    multi_gpu = torch.cuda.device_count() > 1
    model.eval()
    running = 0.0
    with torch.no_grad():
        for batch in dataloader:
            ids = batch["input_ids"].to(device)
            attn = batch["attn_mask"].to(device)
            targets = batch["labels"].to(device)
            # Under DataParallel the real model lives behind .module.
            backbone = model.module if multi_gpu else model
            embeds = backbone.roberta.embeddings.word_embeddings(ids)
            loss = model(inputs_embeds=embeds, attention_mask=attn, labels=targets).loss
            if multi_gpu:
                # One loss per replica -> reduce to a scalar.
                loss = loss.mean()
            running += loss.item()

    return running / len(dataloader)

In [2]:
# Optimizer and linear learning-rate decay (no warmup) over all planned steps.
# NOTE(review): AdamW is the class imported from transformers at the top of the
# file; newer transformers releases deprecate it in favour of torch.optim.AdamW
# -- confirm the installed version still exports it.
optimizer = AdamW(model.parameters(), lr=learning_rate)
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# tokenizer.save_pretrained(model_dir)
progress_bar = tqdm(range(num_training_steps))
# Start from +inf so the first epoch always counts as an improvement, however
# large its validation loss is (a fixed 9999 sentinel could be exceeded).
best_valid_loss = float("inf")
best_epoch = -1
logger.info("-"*40)


# Train until the validation loss stops improving: the best model is saved to
# model_dir after every improving epoch, and the first non-improving epoch ends
# training after saving that (worse) model under model_dir/back_up_model.
for epoch in range(num_train_epochs):
    model.train()
    logger.info("Epoch {} ----".format(epoch+1))
    total_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        b_input_ids = batch["input_ids"].to(device)
        b_attn_mask = batch["attn_mask"].to(device)
        b_labels = batch["labels"].to(device)

        model.zero_grad(set_to_none=True)
        # Under DataParallel the real model lives behind .module.
        if torch.cuda.device_count() > 1:
            b_inputs_embeds = model.module.roberta.embeddings.word_embeddings(b_input_ids)
        else:
            b_inputs_embeds = model.roberta.embeddings.word_embeddings(b_input_ids)
        output = model(inputs_embeds=b_inputs_embeds, attention_mask=b_attn_mask, labels=b_labels)
        loss = output.loss
        if torch.cuda.device_count() > 1:
            # One loss per replica -> reduce once and reuse for logging and backward.
            loss = loss.mean()
        total_loss = total_loss + loss.item()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        progress_bar.update(1)

    train_loss = total_loss/len(train_dataloader)
    valid_loss = evaluate_loss(valid_dataloader)
    logger.info("Epoch {} : Train loss = {} : Valid loss = {}".format(epoch+1, train_loss, valid_loss))
    print("Epoch {} : Train loss = {} : Valid loss = {}".format(epoch+1, train_loss, valid_loss))

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_epoch = epoch+1
        if torch.cuda.device_count() > 1:
            model.module.save_pretrained(model_dir)
        else:
            model.save_pretrained(model_dir)
        logger.info(f"Epoch {epoch+1} model saved.")
        print(f"Epoch {epoch+1} model saved.")
        logger.info("-"*40)
        print("-"*40)
    else:
        logger.info("Early Stopping ...")
        print("Early Stopping ...")
        logger.info("-"*40)
        print("-"*40)
        logger.info("Saving back up model")
        print("Saving back up model")
        back_dir = os.path.join(model_dir, "back_up_model")
        # exist_ok: a previous run may have left this directory behind, and a
        # plain mkdir would then crash right before the final save.
        os.makedirs(back_dir, exist_ok=True)
        if torch.cuda.device_count() > 1:
            model.module.save_pretrained(back_dir)
        else:
            model.save_pretrained(back_dir)
        break

progress_bar.close()
logger.info(f"Best Model: {best_epoch}")
print(f"Best Model: {best_epoch}")
logger.info("-"*40)
print("done")