In [21]:
import os
import random
import sys

import numpy as np
import pandas as pd
import tqdm
from datasets import load_dataset
from transformers import AutoModelForMaskedLM, AutoTokenizer, Pipeline, pipeline

# Make the repository root importable before pulling in project modules.
sys.path.append(os.getcwd()+"/../..")

from src import paths

In [22]:
# Model and tokenizer are loaded from the same checkpoint id, kept in one place.
MODEL_CHECKPOINT = "GerMedBERT/medbert-512"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(MODEL_CHECKPOINT)

In [23]:
# Load the train/validation/test CSV splits of the ms-diag dataset
ms_diag_dir = os.path.join(paths.DATA_PATH_PREPROCESSED,'ms-diag')
data_files = {
    "train": "ms-diag_clean_train.csv",
    "validation": "ms-diag_clean_val.csv",
    "test": "ms-diag_clean_test.csv",
}
df = load_dataset(ms_diag_dir, data_files = data_files)

# Create a masked language model pipeline.
# top_k=4 is only the pipeline default; the augmentation functions in this
# notebook pass their own top_k on every call.
unmasker = pipeline(
    'fill-mask',
    model=model,
    tokenizer=tokenizer,
    device_map="auto",
    top_k=4,
)

In [28]:
def random_replacement(text:str, unmasker:"Pipeline", n:int=5, n_words=10)->list:
    """
    Replaces n_words randomly chosen words of text with predictions of a fill-mask model.

    For every replacement a random word position is drawn (positions may repeat),
    the word at that position is swapped for the mask token in a copy of the
    original text, and the unmasker is asked for its n most likely tokens.
    The i-th returned text uses the i-th most likely token at every position.

    The text is first round-tripped through ``unmasker.tokenizer`` with
    truncation enabled, so the returned texts are built from the truncated,
    decoded text (which may differ from the raw input, depending on the
    tokenizer's normalisation).

    Args:
        text (str): text to be augmented
        unmasker (Pipeline): HF fill-mask pipeline. Must accept a string containing a single
            pipeline.tokenizer.mask_token and return a list of dictionaries with the key 'token_str'.
        n (int, optional): number of augmented texts to be returned. Defaults to 5.
        n_words (int, optional): number of words to be replaced. Defaults to 10.

    Returns:
        list(str): list of n augmented texts. If n_words is 0 or the text has fewer
            than 4 words (no replaceable position), the texts are returned unchanged.

    Raises:
        Exception: whatever the unmasker raises is propagated unchanged.
    """

    mask_token = unmasker.tokenizer.mask_token

    # Truncate text to max length of model *before* splitting it into words,
    # otherwise the truncation has no effect on the words that get masked.
    token_ids = unmasker.tokenizer.encode(text, truncation=True, add_special_tokens=False)
    orig_text_list = unmasker.tokenizer.decode(token_ids).split()
    len_input = len(orig_text_list)

    # Initialize augmented text array (one row per augmented text).
    # dtype=object is required: a fixed-width unicode array would silently cut
    # every new word down to the length of the longest original word.
    augmented_texts = np.tile(np.array(orig_text_list, dtype=object), (n, 1))

    # Replaceable positions are 1 .. len_input-3 (first word and the last two are never touched)
    max_replace_id = len_input - 3

    if n_words > 0 and max_replace_id >= 1:
        # New words and their positions
        replace_ids = []
        new_words = []

        for _ in tqdm.tqdm(range(n_words)):
            replace_id = random.randint(1, max_replace_id)
            replace_ids.append(replace_id)

            # Mask word
            masked_text_list = orig_text_list.copy()
            masked_text_list[replace_id] = mask_token
            masked_text = ' '.join(masked_text_list)

            # Get most likely words
            results = unmasker(masked_text, top_k=n, tokenizer_kwargs={"truncation": True})
            new_words.append([result['token_str'] for result in results])

        # Shape (n, n_words): row i holds the i-th best candidate for every position.
        # If a position was drawn more than once, the last drawn candidate wins.
        augmented_texts[:, replace_ids] = np.stack(new_words, axis=1)

    # Join words to sentences
    return [' '.join(augmented_text) for augmented_text in augmented_texts]
        
        
def random_insertion(text:str, unmasker:"Pipeline", n:int=5, n_words=10)->list:
    """
    Inserts n_words model-predicted words at random positions of text.

    For every insertion a random position is drawn (positions may repeat), the
    mask token is inserted at that position into a copy of the original text,
    and the unmasker is asked for its n most likely tokens. The i-th returned
    text receives the i-th most likely token at every position.

    The text is first round-tripped through ``unmasker.tokenizer`` with
    truncation enabled, so the returned texts are built from the truncated,
    decoded text.

    Args:
        text (str): text to be augmented
        unmasker (Pipeline): HF fill-mask pipeline. Must accept a string containing a single
            pipeline.tokenizer.mask_token and return a list of dictionaries with the key 'token_str'.
        n (int, optional): number of augmented texts to be returned. Defaults to 5.
        n_words (int, optional): number of words to be inserted. Defaults to 10.

    Returns:
        list(str): list of n augmented texts. If n_words is 0 or the text has fewer
            than 4 words (no valid insert position), the texts are returned unchanged.

    Raises:
        Exception: whatever the unmasker raises is propagated unchanged.
    """

    mask_token = unmasker.tokenizer.mask_token

    # Truncate text to max length of model
    token_ids = unmasker.tokenizer.encode(text, truncation=True, add_special_tokens=False)
    orig_text_list = unmasker.tokenizer.decode(token_ids).split()
    len_input = len(orig_text_list)

    # Initialize augmented text array (one row per augmented text).
    # dtype=object is required: np.insert casts the new words to the array's
    # dtype, and a fixed-width unicode dtype would silently cut them down to
    # the length of the longest original word.
    augmented_texts = np.tile(np.array(orig_text_list, dtype=object), (n, 1))

    # -3 because CLS and SEP tokens are not counted
    max_insert_id = len_input - 3

    if n_words > 0 and max_insert_id >= 1:
        # New words and their positions
        insert_ids = []
        new_words = []

        for _ in tqdm.tqdm(range(n_words)):
            insert_id = random.randint(1, max_insert_id)
            insert_ids.append(insert_id)

            # Insert mask token
            masked_text_list = orig_text_list.copy()
            masked_text_list.insert(insert_id, mask_token)
            masked_text = ' '.join(masked_text_list)

            # Get most likely words
            results = unmasker(masked_text, top_k=n, tokenizer_kwargs={"truncation": True})
            new_words.append([result['token_str'] for result in results])

        # Shape (n, n_words): row i holds the i-th best candidate for every position
        new_words = np.stack(new_words, axis=1)

        # Insert columns in augmented texts (positions refer to the original word order)
        augmented_texts = np.insert(augmented_texts, insert_ids, new_words, axis=1)

    # Join words to sentences
    return [' '.join(augmented_text) for augmented_text in augmented_texts]

def augment_text(text:str, unmasker:"Pipeline", n:int=5, n_replacements = 10, n_insertions = 10)->list:
    """
    Augments text by replacing and inserting words predicted by a fill-mask model.

    Replacements and insertions are both computed against the (truncated)
    original text: one position is masked at a time and the unmasker is asked
    for its n most likely tokens. The i-th returned text uses the i-th most
    likely token at every position. Random positions may repeat.

    If the unmasker fails for a position, that position is left as in the
    original text (replacement keeps the original word, insertion adds nothing).

    Args:
        text (str): text to be augmented
        unmasker (Pipeline): HF fill-mask pipeline. Must accept a string containing a single
            pipeline.tokenizer.mask_token and return a list of dictionaries with the key 'token_str'.
        n (int, optional): number of augmented texts to be returned. Defaults to 5.
        n_replacements (int, optional): number of words to be replaced. Defaults to 10.
        n_insertions (int, optional): number of words to be inserted. Defaults to 10.

    Returns:
        list(str): list of n augmented texts. Replacement is skipped for texts with
            fewer than 4 words, insertion for texts with fewer than 2 words.
    """

    mask_token = unmasker.tokenizer.mask_token

    # Truncate text to max length of model
    token_ids = unmasker.tokenizer.encode(text, truncation=True, add_special_tokens=False)
    orig_text_list = unmasker.tokenizer.decode(token_ids).split()
    len_input = len(orig_text_list)

    # Initialize augmented text array (one row per augmented text).
    # dtype=object is required: a fixed-width unicode array would silently cut
    # every new word down to the length of the longest original word.
    augmented_texts = np.tile(np.array(orig_text_list, dtype=object), (n, 1))

    def predict_candidates(masked_text_list, fallback):
        # Returns the n most likely tokens for the single mask in masked_text_list,
        # or n copies of fallback if the unmasker (or its output format) fails.
        try:
            results = unmasker(' '.join(masked_text_list), top_k=n, tokenizer_kwargs={"truncation": True})
            return [result['token_str'] for result in results]
        except Exception:
            # Best effort: a single failing position must not abort the whole augmentation.
            # Exception (not a bare except) so that KeyboardInterrupt still stops a long run.
            return n*[fallback]

    # Replacement
    max_replace_id = len_input - 3 # -3 because CLS and SEP tokens are not counted
    if n_replacements > 0 and max_replace_id >= 1:
        replace_ids = []
        new_replacements = []
        for _ in range(n_replacements):
            replace_id = random.randint(1, max_replace_id)
            replace_ids.append(replace_id)

            # Mask word
            masked_text_list = orig_text_list.copy()
            masked_text_list[replace_id] = mask_token

            # Get most likely words. If error just keep the original word
            new_replacements.append(predict_candidates(masked_text_list, orig_text_list[replace_id]))

        # Replace words in augmented texts; shape (n, n_replacements)
        augmented_texts[:, replace_ids] = np.stack(new_replacements, axis=1)

    # Insertion
    max_insert_id = len_input - 1
    if n_insertions > 0 and max_insert_id >= 1:
        insert_ids = []
        new_insertions = []
        for _ in range(n_insertions):
            insert_id = random.randint(1, max_insert_id)
            insert_ids.append(insert_id)

            # Insert mask token
            masked_text_list = orig_text_list.copy()
            masked_text_list.insert(insert_id, mask_token)

            # Get most likely words. If error insert an empty placeholder
            new_insertions.append(predict_candidates(masked_text_list, ""))

        # Insert columns in augmented texts (positions refer to the original word order)
        augmented_texts = np.insert(augmented_texts, insert_ids, np.stack(new_insertions, axis=1), axis=1)

    # Join words to sentences; empty placeholders from failed insertions are
    # dropped so they do not leave doubled spaces behind.
    return [' '.join(word for word in augmented_text if word) for augmented_text in augmented_texts]

def augment_df(df, unmasker, n:int=5, n_replacement:int=10, n_insertion:int=5, n_replacements_ratio:float=None, n_inserations_ratio:float=None)->pd.DataFrame:
    """
    Augments a dataset by replacing and inserting words in its 'text' column.

    Every row is passed to augment_text; each augmented text inherits the
    'labels' value of the row it was generated from.

    Args:
        df: dataset to be augmented. Must support len(df) and column access via
            df['text'] and df['labels'] (the notebook passes HF datasets).
        unmasker: HF fill-mask pipeline that is forwarded to augment_text. Its
            tokenizer.model_max_length is used to cap the ratio-based counts.
        n (int, optional): number of augmented texts to be returned per example. Defaults to 5.
        n_replacement (int, optional): number of replacements per text. Defaults to 10.
        n_insertion (int, optional): number of insertions per text. Defaults to 5.
        n_replacements_ratio (float, optional): ratio of replacements compared to text length per text.
            Defaults to None. If given, it overrides n_replacement (at least 1 replacement is made).
        n_inserations_ratio (float, optional): ratio of insertions compared to text length per text.
            Defaults to None. If given, it overrides n_insertion (at least 1 insertion is made).

    Returns:
        pd.DataFrame: augmented dataframe with the columns 'text' and 'labels'
    """

    # Read both columns once instead of re-reading the whole column for every row
    texts = df['text']
    text_labels = df['labels']

    # Upper bound for the text length used by the ratios (-2 for the special tokens)
    max_model_len = unmasker.tokenizer.model_max_length-2

    # Augment data
    augmented_data = []
    labels = []
    for text, label in tqdm.tqdm(zip(texts, text_labels), total=len(df)):

        # Text length is measured in whitespace-separated words
        max_len = min(len(text.split()), max_model_len)

        # Check if ratios are used
        if n_replacements_ratio is not None:
            n_replacement = int(max(1,max_len*n_replacements_ratio))
        if n_inserations_ratio is not None:
            n_insertion = int(max(1, max_len*n_inserations_ratio))

        # Replace and insert words
        augmented_text = augment_text(text, unmasker, n=n, n_replacements=n_replacement, n_insertions=n_insertion)
        augmented_data.extend(augmented_text)
        labels.extend([label]*len(augmented_text))

    # Create augmented dataframe
    return pd.DataFrame({"text": augmented_data, "labels": labels})

In [25]:
def _train_rows_with_label(label):
    """Return the rows of the training split whose 'labels' value equals label."""
    return df["train"].filter(lambda example: example['labels'] == label)

# One subset of the training split per diagnosis class
df_spms = _train_rows_with_label('secondary_progressive_multiple_sclerosis')
df_ppms = _train_rows_with_label('primary_progressive_multiple_sclerosis')
df_rrms = _train_rows_with_label('relapsing_remitting_multiple_sclerosis')

In [26]:
# Augment data for ppms, spms and rrms creating even class sizes.
# Each minority class is augmented several times (3x ppms, 4x spms); every pass
# draws new random positions, so the passes produce different texts.
augmented_dfs = []
print("Starting With Minority Classes")
# The loop variable must not be named `df`: that would overwrite the loaded
# dataset dict and break any re-run of the cells that read df["train"].
for minority_df in (3*[df_ppms] + 4*[df_spms]):
    augmented_df = augment_df(minority_df, unmasker, n=5, n_replacements_ratio=0.2, n_inserations_ratio=0.1)
    augmented_dfs.append(augmented_df)

# print("Starting With Majority Class")
# augmented_dfs.append(augment_df(df_rrms, unmasker, n=1, n_replacements_ratio=0.1, n_inserations_ratio=0.05))

Starting With Minority Classes


100%|██████████| 8/8 [08:53<00:00, 66.75s/it]
100%|██████████| 5/5 [04:12<00:00, 50.45s/it]
100%|██████████| 8/8 [08:59<00:00, 67.39s/it]
100%|██████████| 5/5 [04:04<00:00, 48.92s/it]
100%|██████████| 8/8 [08:39<00:00, 64.99s/it]
 60%|██████    | 3/5 [04:02<02:41, 80.92s/it]


TypeError: string indices must be integers

In [33]:
# Concatenate augmented dataframes and add the (empty) 'date' and 'rid' columns
augmented_df = pd.concat(augmented_dfs).assign(date="", rid="")

# Save augmented dataframe
augmented_csv_path = os.path.join(paths.DATA_PATH_PREPROCESSED, 'ms-diag', 'ms-diag_augmented.csv')
augmented_df.to_csv(augmented_csv_path, index=False)