# MeDAL — Minimal MLM fine-tune + mask prediction

Compact workflow: expand MeDAL rows into single-abbreviation rows, fine-tune a masked language model on a small subset, then predict expansions for a masked token. The GPU is used when one is available.


In [None]:
# Minimal imports
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)

In [None]:
# Load the dataset from a local directory on disk.
dataset_path = "medal_dataset"
dataset = load_from_disk(dataset_path)

# Alternative source (disabled): load the dataset by its Hub identifier instead.
# dataset = load_dataset("lutful2004/MeDAL-dataset")

In [None]:
# Process and reduce dataset size to 20% for faster training
# Process and reduce dataset size to 20% for faster training
def _subsample(split, fraction=0.2, seed=42):
    """Return a seeded, shuffled subset holding ``fraction`` of ``split``'s rows."""
    keep = int(len(split) * fraction)
    return split.shuffle(seed=seed).select(range(keep))


train_split = _subsample(dataset["train"])
test_split = _subsample(dataset["test"])
# The source key is "validation", not "valid"
valid_split = _subsample(dataset["validation"])

# Update dataset with smaller splits
dataset["train"] = train_split
dataset["test"] = test_split
dataset["valid"] = valid_split  # shorter key used by the later cells
# Drop the full-size original: leaving it in place would make the later
# dataset.map(tokenize) call tokenize the entire validation split for nothing.
# NOTE: because of this, reload `dataset` before re-running this cell.
del dataset["validation"]
dataset

NameError: name 'dataset' is not defined

In [None]:
# csv_path = "/kaggle/input/medal-emnlp/pretrain_subset/test.csv"
# NOTE(review): csv_path is disabled above, yet later cells in this file still
# call pd.read_csv(csv_path) — confirm it is defined before running them.
# Base checkpoint to fine-tune, and the directory passed as output_dir to training.
model_checkpoint = "bert-base-uncased"
output_dir = "bert_analyzed"

In [None]:
# Tokenizer matching the base checkpoint; used by tokenize() below.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


def tokenize(ex):
    """Tokenize the ``TEXT`` field of ``ex``, truncating and padding to 256 tokens."""
    encoding_options = {"truncation": True, "padding": "max_length", "max_length": 256}
    return tokenizer(ex["TEXT"], **encoding_options)

In [None]:
# Tokenize every split. batched=True hands tokenize() lists of texts instead of
# one row per call, which produces the same columns with far fewer tokenizer calls.
tokenized_data = dataset.map(tokenize, batched=True)

In [None]:
# def expand_dataset(df):
#     "Convert rows with pipe-separated LABEL and LOCATION into one row per abbreviation."
#     out = []
#     for _, r in df.iterrows():
#         text = str(r.get("TEXT", ""))
#         labels = str(r.get("LABEL", "")).split("|")
#         locations = str(r.get("LOCATION", "")).split("|")
#         tokens = text.split()
#         for loc, lab in zip(locations, labels):
#             try:
#                 i = int(loc)
#             except Exception:
#                 continue
#             if 0 <= i < len(tokens):
#                 out.append(
#                     {"TEXT": text, "LOCATION": i, "ABBREV": tokens[i], "LABEL": lab}
#                 )
#     return pd.DataFrame(out)


# def prepare_tokenized_dataset(
#     csv_path,
#     model_checkpoint,
#     output_dir,
#     max_rows=500,
#     max_length=128,
#     test_size=0.1,
#     val_size=0.1,
# ):
#     """
#     Load CSV data, expand it to one-abbreviation-per-row format, tokenize it, and split into train/test/val sets.

#     Args:
#         csv_path: Path to the CSV file
#         tokenizer: Tokenizer to use
#         max_rows: Maximum number of rows to use
#         max_length: Maximum length for tokenization
#         test_size: Proportion of data for testing (default 0.1)
#         val_size: Proportion of data for validation (default 0.1)

#     Returns:
#         tokenized_datasets: Dictionary containing train, test, and validation datasets
#     """
#     df = pd.read_csv(csv_path)

#     df = expand_dataset(df)

#     print(df.shape)
#     if len(df) == 0:
#         raise ValueError(
#             "No rows after expand_dataset — check CSV columns TEXT/LABEL/LOCATION"
#         )
#     # df = df.sample(frac=1, random_state=42).reset_index(drop=True).iloc[:max_rows]

#     ds = Dataset.from_pandas(df)
#     tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

#     def tokenize(ex):
#         return tokenizer(
#             ex["TEXT"], truncation=True, padding="max_length", max_length=max_length
#         )

#     tokenized = ds.map(tokenize, batched=True, remove_columns=ds.column_names)

#     # Split into train/test/validation sets
#     test_val_size = test_size + val_size
#     train_test_split = tokenized.train_test_split(test_size=test_val_size, seed=42)

#     # If validation set is requested, split the test set further
#     if val_size > 0:
#         # Calculate the proportion of test_val that should be validation
#         val_proportion = val_size / test_val_size
#         test_val_split = train_test_split["test"].train_test_split(
#             test_size=val_proportion, seed=42
#         )

#         tokenized_datasets = {
#             "train": train_test_split["train"],
#             "test": test_val_split["train"],  # This is now the test set
#             "valid": test_val_split["test"],  # This is the validation set
#         }
#     else:
#         tokenized_datasets = {
#             "train": train_test_split["train"],
#             "test": train_test_split["test"],
#             "valid": None,
#         }

#         tokenizer.save_pretrained(output_dir)
#     return tokenizer, tokenized_datasets

In [None]:
# csv_path = "/kaggle/input/medal-emnlp/pretrain_subset/test.csv"
# model_checkpoint = "bert-base-uncased"
# output_dir="bert_analyzed"

# tokenizer,tokenized_data=prepare_tokenized_dataset(csv_path=csv_path,model_checkpoint=model_checkpoint,output_dir=output_dir)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torch
import gc
import numpy as np
from transformers import (
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments
)

def _masked_token_accuracy(preds, labels):
    """Return the fraction of masked positions whose predicted token id is correct.

    Positions labelled -100 are skipped; returns 0.0 when no position is masked.
    """
    mask = labels != -100
    total = int(mask.sum())
    if total == 0:
        return 0.0
    return float((preds[mask] == labels[mask]).sum()) / total


def _logits_to_token_ids(logits, labels):
    """Reduce raw logits to argmax token ids before the Trainer accumulates them."""
    # `labels` is unused but required by the Trainer callback signature.
    if isinstance(logits, tuple):
        logits = logits[0]
    return logits.argmax(dim=-1)


def select_and_fine_tune_model(
    model_checkpoint,
    tokenizer,
    train_data,
    test_data,
    output_dir,
    device,
    epochs=5,
    batch_size=16,
):
    """Fine-tune a masked language model on already-tokenized data and save it.

    Uses mixed precision when CUDA is available and frees cached memory afterwards.

    Args:
        model_checkpoint: Name or path passed to ``AutoModelForMaskedLM.from_pretrained``.
        tokenizer: Tokenizer used by the MLM data collator and handed to the Trainer.
        train_data: Tokenized dataset used for training.
        test_data: Tokenized dataset evaluated at the end of every epoch; its
            masked-token accuracy selects the best checkpoint.
        output_dir: Directory for checkpoints and the final saved model.
        device: Torch device the model is moved to before training.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.

    Returns:
        A ``(model, device)`` tuple; ``device`` is returned unchanged.
    """

    model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
    model.to(device)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )

    def compute_metrics(eval_pred):
        # Predictions are already token ids here (see _logits_to_token_ids).
        preds, labels = eval_pred
        return {"accuracy": _masked_token_accuracy(preds, labels)}

    # TrainingArguments rejects fp16 and bf16 both being True, so pick exactly one.
    # bf16 is chosen only on compute capability >= 8 (Ampere or newer); every
    # CUDA query is guarded so CPU-only machines never touch the CUDA runtime.
    cuda_available = torch.cuda.is_available()
    use_bf16 = (
        cuda_available
        and torch.cuda.get_device_capability()[0] >= 8
        and torch.cuda.is_bf16_supported()
    )
    use_fp16 = cuda_available and not use_bf16

    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=4,

        # runtime mixed precision (mutually exclusive flags)
        fp16=use_fp16,
        bf16=use_bf16,

        # NOTE(review): newer transformers releases may expect `eval_strategy`
        # instead — confirm against the installed version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        save_total_limit=1,
        eval_accumulation_steps=50,
        logging_steps=500,
        report_to=[],
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_data,
        eval_dataset=test_data,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        # Keep only argmax ids during evaluation; storing full vocab-sized
        # logits for every token of the eval set would exhaust memory.
        preprocess_logits_for_metrics=_logits_to_token_ids,
    )

    trainer.train()
    trainer.save_model(output_dir)

    # --- Memory cleanup ---
    # Collect garbage first so tensors freed with the trainer can actually be
    # returned when the CUDA cache is emptied.
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    return model, device


In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Fine-tune on the "train" split; the "test" split is passed as the evaluation
# set used during training. The call returns the model and the device.
model, device = select_and_fine_tune_model(
    model_checkpoint=model_checkpoint,
    tokenizer=tokenizer,
    train_data=tokenized_data["train"],
    test_data=tokenized_data["test"],
    output_dir=output_dir,
    device=device,
)

In [None]:
def mask_text(text, location, mask_token="[MASK]"):
    """Replace the whitespace-delimited token at ``location`` with ``mask_token``.

    Args:
        text: Input string; it is split on whitespace.
        location: Zero-based index of the token to replace.
        mask_token: Replacement string; defaults to BERT's ``"[MASK]"``.

    Returns:
        The tokens re-joined with single spaces, with one token masked.

    Raises:
        IndexError: If ``location`` is negative or past the last token.
    """
    tokens = text.split()
    if not (0 <= location < len(tokens)):
        raise IndexError("location out of range")
    tokens[location] = mask_token
    return " ".join(tokens)


def predict_expansion_single(text, location, tokenizer, model, device, top_k=5):
    """Predict the top candidates for one masked token of ``text``.

    Args:
        text: Input string.
        location: Whitespace-token index to mask (see ``mask_text``).
        tokenizer: Tokenizer providing ``mask_token`` / ``mask_token_id``.
        model: Masked-LM model; it is switched to eval mode by this call.
        device: Torch device the encoded inputs are moved to.
        top_k: Number of candidates to return (clamped to the vocabulary size).

    Returns:
        A dict with keys ``"original"``, ``"masked"`` and ``"preds"``. ``"preds"``
        is ``["[No mask found]"]`` when the mask token is absent from the encoded
        input (for example, cut off by truncation).

    Raises:
        IndexError: If ``location`` is out of range for ``text``.
    """
    # Use the tokenizer's own mask token so non-BERT checkpoints work as well.
    masked = mask_text(text, location, mask_token=tokenizer.mask_token or "[MASK]")
    inputs = tokenizer(masked, return_tensors="pt", truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Disable dropout: the model may still be in training mode after fine-tuning.
    model.eval()
    with torch.no_grad():
        out = model(**inputs)

    is_mask = inputs["input_ids"][0] == tokenizer.mask_token_id
    mask_positions = is_mask.nonzero(as_tuple=True)[0]
    if mask_positions.numel() == 0:
        return {"original": text, "masked": masked, "preds": ["[No mask found]"]}

    # Score only the first mask position in the encoded sequence.
    mask_logits = out.logits[0, mask_positions[0]]
    k = min(top_k, mask_logits.shape[-1])
    top_ids = torch.topk(mask_logits, k).indices.cpu().tolist()
    preds = [tokenizer.decode([t]).strip() for t in top_ids]

    return {"original": text, "masked": masked, "preds": preds}


def predict_expansion(text, location, tokenizer, model, device, top_k=5):
    """Delegate to ``predict_expansion_single`` and return its result unchanged."""
    return predict_expansion_single(
        text=text,
        location=location,
        tokenizer=tokenizer,
        model=model,
        device=device,
        top_k=top_k,
    )


def evaluate_model_accuracy(valid_data, tokenizer, model, device):
    """Evaluate a masked-LM on a held-out dataset and print loss and perplexity.

    Despite the name, no accuracy metric is computed here: the Trainer is built
    without ``compute_metrics``, so only the evaluation loss (and the perplexity
    derived from it) is reported.

    Args:
        valid_data: Tokenized evaluation dataset, or ``None`` to skip evaluation.
        tokenizer: Tokenizer used by the MLM data collator and the Trainer.
        model: Model to evaluate.
        device: Unused; kept so existing callers keep working.

    Returns:
        The metrics dict from ``Trainer.evaluate()``, or ``None`` when
        ``valid_data`` is ``None``.
    """
    if valid_data is None:
        # The dataset object itself used to be interpolated here ("No None dataset ...").
        print("No validation dataset available for evaluation")
        return None

    # Create data collator for masked language modeling
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=True, mlm_probability=0.15
    )

    # Setup evaluation arguments
    eval_args = TrainingArguments(
        output_dir="./eval_output",
        per_device_eval_batch_size=16,
        report_to=[],
    )

    # Create trainer for evaluation
    trainer = Trainer(
        model=model,
        args=eval_args,
        eval_dataset=valid_data,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Evaluate and get metrics
    eval_results = trainer.evaluate()
    eval_loss = eval_results["eval_loss"]

    # Print a fixed label rather than the dataset's repr.
    print("Evaluation results on validation set:")
    print(f"Perplexity: {np.exp(eval_loss):.2f}")
    print(f"Loss: {eval_loss:.4f}")

    return eval_results

In [None]:
# Evaluate the fine-tuned model on the held-out "valid" split.
evaluate_model_accuracy(
    valid_data=tokenized_data["valid"],
    tokenizer=tokenizer,
    model=model,
    device=device,
)

In [None]:
# # Example usage (small quick run)
# csv_path = "smaller_datasets/full_data_small.csv"
# checkpoint = "bert-base-uncased"
# out_dir = "fine_tuned_medal"
# # Fine-tune briefly (epochs=1, small max_rows). Increase for production.
# tokenizer, model, device = fine_tune_model_and_tokenizer(
#     checkpoint, csv_path, output_dir=out_dir, epochs=1, batch_size=8, max_rows=200
# )
# # Reload as MLM for prediction
# tokenizer = AutoTokenizer.from_pretrained(out_dir)
# model = AutoModelForMaskedLM.from_pretrained(out_dir).to(device)

In [None]:
# Smoke-test prediction using the first row of the CSV.
# NOTE(review): csv_path and expand_dataset appear only in commented-out code in
# this file; define them before running this cell.
df = pd.read_csv(csv_path).head(1)
if df.empty:
    print("CSV is empty or missing")
else:
    df_exp = expand_dataset(df)
    if df_exp.empty:
        print("No expandable rows in the sample")
    else:
        ex = df_exp.iloc[0]
        res = predict_expansion(
            ex["TEXT"], int(ex["LOCATION"]), tokenizer, model, device, top_k=5
        )
        print("Masked:", res["masked"])
        print("Predictions:", res["preds"])

In [None]:
# Preview the first row of the raw CSV.
# NOTE(review): relies on csv_path, which is only present in commented-out lines
# in this file — confirm it is defined before running.
pd.read_csv(csv_path).head(1)