In [1]:
import torch
import torch.nn as nn
import transformers
import gensim
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import Dataset
import pandas as pd
import os
import json
from tqdm.autonotebook import trange, tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
class TranslationDataset:
    """
    Prepare tokenized datasets for training and evaluation without relying on DataLoader.
    """
    @staticmethod
    def prepare_dataset(english_texts, chinese_texts, tokenizer):
        # Tokenize parallel corpus
        tokenized_data = {
            "source": tokenizer(
                english_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors='pt'
            ),
            "target": tokenizer(
                chinese_texts,
                padding=True,
                truncation=True,
                max_length=128,
                return_tensors='pt'
            )
        }

        # Prepare data dictionary for Hugging Face Dataset
        dataset_dict = {
            "input_ids": tokenized_data["source"]["input_ids"],
            "attention_mask": tokenized_data["source"]["attention_mask"],
            "labels": tokenized_data["target"]["input_ids"]
        }

        # Convert to Hugging Face Dataset
        return Dataset.from_dict({key: value.tolist() for key, value in dataset_dict.items()})


class BiomedicalMarianMTEnhancer(nn.Module):
    """
    Wraps MarianMT with additional medical term embeddings.
    """
    def __init__(self, base_model, tokenizer, biowordvec_path='C:\\Users\\Gaming\\Documents\\GitHub\\MIE2\\2024-fall-assignment-linaaron88\\project\\BioWordVec_PubMed_MIMICIII_d200.vec.bin'):
        super().__init__()
        self.base_model = base_model
        self.tokenizer = tokenizer

        # Load BioWordVec embeddings
        print(biowordvec_path)
        self.biowordvec = gensim.models.KeyedVectors.load_word2vec_format(
            biowordvec_path,
            binary=True
        )

        # Create a custom embedding layer for medical terms
        embedding_dim = self.biowordvec.vector_size
        vocab_size = base_model.config.vocab_size

        # Create a custom embedding layer
        self.medical_embedding_layer = nn.Embedding(
            vocab_size,
            embedding_dim
        )

        # Initialize medical embedding layer
        self._init_medical_embeddings()

        # Additional projection layer to align embeddings
        self.projection = nn.Linear(
            embedding_dim,
            base_model.config.d_model
        )

    def _init_medical_embeddings(self):
        weight = self.medical_embedding_layer.weight.data

        for token, idx in self.tokenizer.get_vocab().items():
            clean_token = token.replace('▁', '').strip()

            try:
                # Try to get embedding for the token
                vec = self.biowordvec[clean_token]
                weight[idx] = torch.tensor(vec)
            except KeyError:
                # Fallback to default initialization
                nn.init.xavier_uniform_(weight[idx].unsqueeze(0))

    def forward(self, input_ids, labels=None, attention_mask=None):
        # Get base model embeddings
        base_embeddings = self.base_model.model.get_input_embeddings()(input_ids)

        # Get medical term embeddings
        medical_embeddings = self.medical_embedding_layer(input_ids)

        # Project medical embeddings
        projected_medical_embeddings = self.projection(medical_embeddings)

        # Combine base and medical embeddings
        combined_embeddings = base_embeddings + projected_medical_embeddings

        # Continue with standard MarianMT forward pass
        outputs = self.base_model(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs
    
    def generate(self, input_ids=None, attention_mask=None, **kwargs):
        """
        Generate translations with custom embeddings and pass them into MarianMT method as input_embeddings
        """
        if input_ids is not None:
            # Compute the base embeddings
            base_embeddings = self.base_model.model.get_input_embeddings()(input_ids)

            # Compute the medical term embeddings
            medical_embeddings = self.medical_embedding_layer(input_ids)

            # Project medical embeddings
            projected_medical_embeddings = self.projection(medical_embeddings)

            # Combine base and medical embeddings
            combined_embeddings = base_embeddings + projected_medical_embeddings

            # Use the combined embeddings for generation
            return self.base_model.generate(
                inputs_embeds=combined_embeddings,
                attention_mask=attention_mask,
                **kwargs
            )
        else:
            raise ValueError("`input_ids` must be provided for generating embeddings.")
        
    def save_custom(self, save_directory, tokenizer=None):
        """
        Save the model and custom embeddings.
        """
        os.makedirs(save_directory, exist_ok=True)

        # Paths
        model_save_path = os.path.join(save_directory, "model")
        embedding_save_path = os.path.join(model_save_path, "medical_embeddings.pth")
        projection_save_path = os.path.join(model_save_path, "projection_layer.pth")
        custom_config_path = os.path.join(model_save_path, "custom_config.json")
        tokenizer_save_path = os.path.join(save_directory, "tokenizer")

        os.makedirs(model_save_path, exist_ok=True)

        # Save the base model
        self.base_model.save_pretrained(model_save_path)

        # Save the medical embedding and projection layer
        torch.save(self.medical_embedding_layer.state_dict(), embedding_save_path)
        torch.save(self.projection.state_dict(), projection_save_path)

        # Save custom configuration
        custom_config = {
            "embedding_dim": self.medical_embedding_layer.embedding_dim,
            "vocab_size": self.medical_embedding_layer.num_embeddings
        }
        with open(custom_config_path, "w") as f:
            json.dump(custom_config, f)

        # Save tokenizer
        if tokenizer is not None:
            tokenizer.save_pretrained(tokenizer_save_path)

    @classmethod
    def from_custom(cls, save_directory):
        """
        Load the model and custom embeddings.
        """
        # Paths
        model_save_path = os.path.join(save_directory, "model")
        embedding_save_path = os.path.join(model_save_path, "medical_embeddings.pth")
        projection_save_path = os.path.join(model_save_path, "projection_layer.pth")
        custom_config_path = os.path.join(model_save_path, "custom_config.json")
        tokenizer_save_path = os.path.join(save_directory, "tokenizer")

        # Load tokenizer
        tokenizer = transformers.MarianTokenizer.from_pretrained(tokenizer_save_path)

        # Load the base model
        base_model = transformers.MarianMTModel.from_pretrained(model_save_path)

        # Load custom configuration
        with open(custom_config_path, "r") as f:
            custom_config = json.load(f)

        # Extract custom configuration values
        embedding_dim = custom_config.get("embedding_dim")
        vocab_size = custom_config.get("vocab_size")

        # Create an instance of the enhanced model
        enhancer = cls(
            base_model=base_model,
            tokenizer=tokenizer,  # Replace with tokenizer if required
        )

        # Resize and initialize the medical embedding layer based on the saved config
        enhancer.medical_embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim
        )

        # Load the medical embedding and projection layer states
        medical_embedding_state = torch.load(embedding_save_path)
        projection_state = torch.load(projection_save_path)
        enhancer.medical_embedding_layer.load_state_dict(medical_embedding_state)
        enhancer.projection.load_state_dict(projection_state)

        return enhancer, tokenizer



def train_biomedical_translation_model(
    base_model,
    tokenizer,
    english_texts,
    chinese_texts,
    biowordvec_path,
    test_size=0.1,
    batch_size=16,
    learning_rate=1e-4,
    num_train_epochs=3,
    output_dir="./results"
):
    # Prepare datasets
    full_dataset = TranslationDataset.prepare_dataset(english_texts, chinese_texts, tokenizer)
    split_dataset = full_dataset.train_test_split(test_size=test_size, seed=42)

    # Wrap the base model with the enhancer
    enhanced_model = BiomedicalMarianMTEnhancer(
        base_model,
        tokenizer,
        biowordvec_path
    )

    # Define Seq2Seq training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        save_safetensors=False,
        num_train_epochs=num_train_epochs,
        logging_dir="./logs",
        logging_steps=500,
        predict_with_generate=True,  # This is essential for seq2seq tasks like translation
        generation_num_beams=4,  # Beam search during generation
        # load_best_model_at_end=True
    )

    # Initialize Seq2SeqTrainer
    trainer = Seq2SeqTrainer(
        model=enhanced_model,
        args=training_args,
        train_dataset=split_dataset["train"],
        eval_dataset=split_dataset["test"],
        tokenizer=tokenizer
    )

    # Train the model
    trainer.train()

    return enhanced_model


In [3]:
# Main execution
# Load pretrained MarianMT model
model_name = "Helsinki-NLP/opus-mt-en-zh"
tokenizer = transformers.MarianTokenizer.from_pretrained(model_name)
base_model = transformers.MarianMTModel.from_pretrained(model_name)

# Load your parallel corpus
dataset = pd.read_parquet("C:\\Users\\Gaming\\Documents\\GitHub\\MIE2\\2024-fall-assignment-linaaron88\\project\\nejm\\nejm_train.parquet")
english_texts = dataset["english"].tolist()
chinese_texts = dataset["chinese"].tolist()

# Train the biomedical translation model
enhanced_model = train_biomedical_translation_model(
    base_model,
    tokenizer,
    english_texts,
    chinese_texts,
    biowordvec_path='C:\\Users\\Gaming\\Documents\\GitHub\\MIE2\\2024-fall-assignment-linaaron88\\project\\BioWordVec_PubMed_MIMICIII_d200.vec.bin',
    num_train_epochs=3
)

  trainer = Seq2SeqTrainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33maalin[0m ([33maalin-uc-berkeley[0m). Use [1m`wandb login --relogin`[0m to force relogin


  5%|▍         | 501/10485 [01:31<30:13,  5.50it/s]

{'loss': 1.104, 'grad_norm': 1.0410629510879517, 'learning_rate': 9.523128278493086e-05, 'epoch': 0.14}


 10%|▉         | 1001/10485 [03:02<28:41,  5.51it/s]

{'loss': 0.9542, 'grad_norm': 1.2827123403549194, 'learning_rate': 9.046256556986171e-05, 'epoch': 0.29}


 14%|█▍        | 1501/10485 [04:32<27:01,  5.54it/s]

{'loss': 0.898, 'grad_norm': 1.0106414556503296, 'learning_rate': 8.569384835479256e-05, 'epoch': 0.43}


 19%|█▉        | 2001/10485 [06:02<25:25,  5.56it/s]

{'loss': 0.8579, 'grad_norm': 1.1509881019592285, 'learning_rate': 8.092513113972342e-05, 'epoch': 0.57}


 24%|██▍       | 2501/10485 [07:33<24:12,  5.50it/s]

{'loss': 0.7751, 'grad_norm': 1.3229390382766724, 'learning_rate': 7.615641392465427e-05, 'epoch': 0.72}


 29%|██▊       | 3001/10485 [09:04<22:39,  5.51it/s]

{'loss': 0.6159, 'grad_norm': 1.4556788206100464, 'learning_rate': 7.138769670958512e-05, 'epoch': 0.86}


                                                    
 33%|███▎      | 3495/10485 [10:58<19:43,  5.90it/s]

{'eval_loss': 0.4532158076763153, 'eval_runtime': 23.7015, 'eval_samples_per_second': 262.135, 'eval_steps_per_second': 16.412, 'epoch': 1.0}


 33%|███▎      | 3501/10485 [11:00<2:49:35,  1.46s/it] 

{'loss': 0.512, 'grad_norm': 1.518500804901123, 'learning_rate': 6.661897949451598e-05, 'epoch': 1.0}


 38%|███▊      | 4001/10485 [12:30<19:29,  5.55it/s]  

{'loss': 0.4448, 'grad_norm': 1.333016276359558, 'learning_rate': 6.185026227944683e-05, 'epoch': 1.14}


 43%|████▎     | 4501/10485 [14:00<17:50,  5.59it/s]

{'loss': 0.4139, 'grad_norm': 0.8244467973709106, 'learning_rate': 5.7081545064377684e-05, 'epoch': 1.29}


 48%|████▊     | 5001/10485 [15:30<16:30,  5.54it/s]

{'loss': 0.3897, 'grad_norm': 1.1088536977767944, 'learning_rate': 5.2312827849308544e-05, 'epoch': 1.43}


 52%|█████▏    | 5501/10485 [16:59<14:55,  5.57it/s]

{'loss': 0.3825, 'grad_norm': 1.311639428138733, 'learning_rate': 4.754411063423939e-05, 'epoch': 1.57}


 57%|█████▋    | 6001/10485 [18:29<13:29,  5.54it/s]

{'loss': 0.3704, 'grad_norm': 1.1753039360046387, 'learning_rate': 4.2775393419170244e-05, 'epoch': 1.72}


 62%|██████▏   | 6501/10485 [19:59<11:57,  5.56it/s]

{'loss': 0.3532, 'grad_norm': 1.1417341232299805, 'learning_rate': 3.80066762041011e-05, 'epoch': 1.86}


                                                    
 67%|██████▋   | 6990/10485 [21:49<09:27,  6.15it/s]

{'eval_loss': 0.335773766040802, 'eval_runtime': 22.875, 'eval_samples_per_second': 271.606, 'eval_steps_per_second': 17.005, 'epoch': 2.0}


 67%|██████▋   | 7001/10485 [21:52<22:12,  2.61it/s]  

{'loss': 0.345, 'grad_norm': 1.3449496030807495, 'learning_rate': 3.323795898903195e-05, 'epoch': 2.0}


 72%|███████▏  | 7501/10485 [23:21<09:01,  5.51it/s]

{'loss': 0.309, 'grad_norm': 0.9399582147598267, 'learning_rate': 2.8469241773962807e-05, 'epoch': 2.15}


 76%|███████▋  | 8001/10485 [24:51<07:30,  5.51it/s]

{'loss': 0.2977, 'grad_norm': 0.895879864692688, 'learning_rate': 2.3700524558893657e-05, 'epoch': 2.29}


 81%|████████  | 8501/10485 [26:21<05:58,  5.54it/s]

{'loss': 0.2991, 'grad_norm': 1.0221785306930542, 'learning_rate': 1.8931807343824514e-05, 'epoch': 2.43}


 86%|████████▌ | 9001/10485 [27:51<04:28,  5.54it/s]

{'loss': 0.2989, 'grad_norm': 1.0727914571762085, 'learning_rate': 1.4163090128755365e-05, 'epoch': 2.58}


 91%|█████████ | 9501/10485 [29:21<02:56,  5.57it/s]

{'loss': 0.2904, 'grad_norm': 0.9679888486862183, 'learning_rate': 9.394372913686218e-06, 'epoch': 2.72}


 95%|█████████▌| 10001/10485 [30:51<01:27,  5.56it/s]

{'loss': 0.292, 'grad_norm': 0.9334255456924438, 'learning_rate': 4.6256556986170724e-06, 'epoch': 2.86}


                                                     
100%|██████████| 10485/10485 [32:42<00:00,  6.13it/s]

{'eval_loss': 0.3116215467453003, 'eval_runtime': 23.1653, 'eval_samples_per_second': 268.203, 'eval_steps_per_second': 16.792, 'epoch': 3.0}


100%|██████████| 10485/10485 [32:44<00:00,  5.34it/s]

{'train_runtime': 1966.5563, 'train_samples_per_second': 85.297, 'train_steps_per_second': 5.332, 'train_loss': 0.49984790254446454, 'epoch': 3.0}





In [4]:
save_dir = ".//saved_embedding_model"
enhanced_model.save_custom(save_dir, tokenizer)



In [3]:
from evaluate import load

def evaluate_model_metrics(predictions, references, save_path=None):
    # Load the evaluation metrics
    bleu_metric = load("bleu")
    rouge_metric = load("rouge")
    bertscore_metric = load("bertscore")
    ter_metric = load("ter")

    # Format references for metric calculation
    references = [[ref] for ref in references]

    # Evaluate BLEU score
    bleu_result = bleu_metric.compute(predictions=predictions, references=references)

    # Evaluate ROUGE score
    rouge_result = rouge_metric.compute(predictions=predictions, references=references)

    # Evaluate BERTScore
    bertscore_result = bertscore_metric.compute(predictions=predictions, references=references, lang="en")

    # Evaluate TER (Translation Edit Rate)
    ter_result = ter_metric.compute(predictions=predictions, references=references)

    # Extract summary statistics for BERTScore
    bertscore_summary = {
        "mean": sum(bertscore_result["f1"]) / len(bertscore_result["f1"]),
        "median": sorted(bertscore_result["f1"])[len(bertscore_result["f1"]) // 2],
        "std": (sum((x - sum(bertscore_result["f1"]) / len(bertscore_result["f1"]))**2 for x in bertscore_result["f1"]) / len(bertscore_result["f1"]))**0.5
    }

    # Consolidate results
    results = {
        "BLEU": bleu_result,
        "ROUGE": rouge_result,
        "BERTScore": bertscore_summary,
        "TER": ter_result,
    }

    return results

In [4]:
class BiomedicalTranslationEvaluator:
    """
    Evaluate the performance of a biomedical translation model.
    """
    def __init__(self, model, tokenizer, device=None):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Move model to the specified device
        self.model.to(self.device)

    def prepare_dataset(self, english_texts, chinese_texts, max_length=512):
        """
        Prepare a dataset for evaluation.
        """
        # Tokenize source (English) texts
        source_encodings = self.tokenizer(
            english_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        # Tokenize target (Chinese) texts for comparison (optional)
        target_encodings = self.tokenizer(
            chinese_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        )

        # Move all tensors to the appropriate device
        return {
            "source_input_ids": source_encodings["input_ids"].to(self.device),
            "source_attention_mask": source_encodings["attention_mask"].to(self.device),
            "target_input_ids": target_encodings["input_ids"].to(self.device)
        }

    def generate_translations(self, dataset, batch_size=16):
        translations = []
        for i in trange(0, len(dataset["source_input_ids"]), batch_size):
            batch_input_ids = dataset["source_input_ids"][i:i + batch_size].to(self.device)
            batch_attention_mask = dataset["source_attention_mask"][i:i + batch_size].to(self.device)

            # Generate translations for the batch
            outputs = self.model.generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                num_beams=3,
                max_length=128,  # Adjust if needed
            )
            translations.extend(self.tokenizer.batch_decode(outputs, skip_special_tokens=True))

        return translations


    def run_evaluation(self, english_texts, chinese_texts):
        """
        Run the evaluation process.
        """
        # Prepare dataset
        dataset = self.prepare_dataset(english_texts, chinese_texts)

        # Generate translations
        translations = self.generate_translations(dataset)

        # Decode target inputs for human-readable comparison
        target_texts = self.tokenizer.batch_decode(
            dataset["target_input_ids"].to("cpu"), skip_special_tokens=True
        )

        return {
            "translations": translations,
            "targets": target_texts
        }


In [9]:
save_dir = ".//saved_embedding_model"
enhanced_model, tokenizer = BiomedicalMarianMTEnhancer.from_custom(save_dir)

C:\Users\Gaming\Documents\GitHub\MIE2\2024-fall-assignment-linaaron88\project\BioWordVec_PubMed_MIMICIII_d200.vec.bin


  medical_embedding_state = torch.load(embedding_save_path)
  projection_state = torch.load(projection_save_path)


In [11]:
# Initialize the evaluator
evaluator = BiomedicalTranslationEvaluator(
    enhanced_model,
    tokenizer
)

test_dataset = pd.read_parquet("C:\\Users\\Gaming\\Documents\\GitHub\\MIE2\\2024-fall-assignment-linaaron88\\project\\nejm\\nejm_test.parquet")
english_test_texts = test_dataset["english"].tolist()
chinese_test_texts = test_dataset["chinese"].tolist()

torch.cuda.empty_cache()
# Run evaluation
results = evaluator.run_evaluation(
    english_test_texts,  # List of English sentences
    chinese_test_texts   # List of Chinese reference translations
)

# Print results
print("Translations:")
print(results["translations"][0:5])
print("\nTargets:")
print(results["targets"][0:5])


100%|██████████| 132/132 [05:07<00:00,  2.33s/it]


Translations:
['是 一种   , 它 结合 了 BCR - ABL1  的 一个   , 通过 与 所有 其他 ABL   不同 的 机制 , 将 BCR - ABL1   .', '•  BCR - ABL1 的    , 包括  T315I  .', '在     患者 中 ,   的  和   尚 不 清楚 .', '在 这项 1 期 剂量  研究 中 , 我们 纳入 了 141 例 慢性  和 9 例  慢性   ( CML ) 患者 , 这些 患者 对 之前 至少 有 两种 ATP     ( TKI ) 产生 的  或 无法 接受 的  产生  .', '主要 目的 是 确定 最大  剂量 或 推荐 剂量 ( 或 两者 ) 的   .']

Targets:
['asciminib 是 与 BCR - ABL1  的  酰  相结合 的  剂 , 它 可 通过 不同于 所有 其他 ABL   的 机制 将 BCR - ABL1 锁定 在 非   .', 'asciminib 同时  作用 于 天然 和  的 BCR - ABL1 , 包括  基因 ( gatekeeper ) T315I  .', 'asciminib 用于     患者 的  和 抗   尚未 明确 .', '在 这项 1 期 剂量  研究 中 , 我们 纳入 了 141 例  和 9 例 加速 期 慢性   ( CML ) 患者 , 这些 患者  对 至少 两种 ATP 竞争性    ( TKI )  或 发生 不可 接受 的  .', '本 试验 的 主要 目的 是 确定 asciminib 的 最大  剂量 或 推荐 剂量 ( 或 这 两者 ) .']


In [20]:
test_dataset["english"][0]

'asciminib is an allosteric inhibitor that binds a myristoyl site of the BCR @-@ ABL1 protein , locking BCR @-@ ABL1 into an inactive conformation through a mechanism distinct from those for all other ABL kinase inhibitors .'

In [18]:
evaluate_model_metrics(results["translations"], results["targets"])

Downloading builder script: 100%|██████████| 7.95k/7.95k [00:00<00:00, 7.94MB/s]
Downloading builder script: 100%|██████████| 9.99k/9.99k [00:00<?, ?B/s]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'BLEU': {'bleu': 0.38178167014553976,
  'precisions': [0.7368004631416442,
   0.47385673490729197,
   0.3332773708841367,
   0.2465452950208379],
  'brevity_penalty': 0.9276629254558418,
  'length_ratio': 0.9301574195401268,
  'translation_length': 51820,
  'reference_length': 55711},
 'ROUGE': {'rouge1': 0.5572755308465489,
  'rouge2': 0.3422585235701242,
  'rougeL': 0.5463369893398838,
  'rougeLsum': 0.5470428823339599},
 'BERTScore': {'mean': 0.9302092782663913,
  'median': 0.9339444637298584,
  'std': 0.06674159867862674},
 'TER': {'score': 46.32471801950705,
  'num_edits': 25505,
  'ref_length': 55057.0}}