In [1]:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from evaluate import load
from trl import SFTTrainer
from peft import LoraConfig
import numpy as np
from transformers import logging
logging.set_verbosity_error()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
checkpoint_path = '/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/retrain_checkpoints/checkpoint-62046'
config_path = "/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/checkpoints/checkpoint-3236/config.json"

config = MambaConfig.from_pretrained(config_path)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaForCausalLM.from_pretrained(checkpoint_path, config=config)

In [None]:
dataset_dir = '/scratch/vetgpt/repo/MedVetGPT/qa_generate/0508_short2_nodigit/'

datasets = load_dataset('json', data_files={
    'train': dataset_dir + 'train.json',
    'test': dataset_dir + 'test.json'
})

In [4]:
# Define the tokenization function
def tokenize_function(examples):
    inputs = [q for q in examples["Question"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    model_inputs["labels"] = tokenizer(examples["Answer"], max_length=128, truncation=True, padding="max_length")["input_ids"]
    return model_inputs

# Apply tokenization to both train and test datasets
tokenized_datasets = datasets.map(tokenize_function, batched=True)

In [6]:
tokenized_datasets["train"].num_rows

558894

In [None]:
rouge_metric = load("rouge")

In [None]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}  # Scale scores by 100

    result["avg_rouge"] = np.mean(list(result.values()))
    return result

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
from datasets import load_dataset

training_args = TrainingArguments(
    output_dir="/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/mamba_finetune_results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_dir='/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/mamba_finetune_results/mamba_logs',
    logging_steps=100,
    learning_rate=2e-3,
    disable_tqdm=False,
    save_steps=500, 
    save_total_limit=3      
)

lora_config =  LoraConfig(
        r=8,
        target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
        task_type="CAUSAL_LM",
        bias="none"
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"]
)


In [None]:
trainer.train()

In [None]:
eval_results = trainer.evaluate()

print(eval_results)

In [15]:
import torch
from torch.cuda.amp import autocast

checkpoint_path = "/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/mamba_finetune_results/checkpoint-100"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = MambaForCausalLM.from_pretrained(checkpoint_path)
model.to("cuda")

from datasets import load_dataset

# Load the test dataset
dataset_dir = '/scratch/vetgpt/repo/MedVetGPT/qa_generate/0508_short2_nodigit/'
test_dataset = load_dataset('json', data_files={'test': dataset_dir + 'test.json'})['test']

# Tokenize the test dataset
def preprocess_function(examples):
    inputs = [q for q in examples["Question"]]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    return model_inputs

tokenized_test_dataset = test_dataset.map(preprocess_function, batched=True)

def generate_predictions(batch):
    inputs = tokenizer(batch["Question"], return_tensors="pt", padding=True, truncation=True, max_length=128).input_ids
    inputs = inputs.to("cuda")
    
    with torch.no_grad(), autocast():
        outputs = model.generate(inputs, max_new_tokens=30, num_beams=1)
    
    batch["predictions"] = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return batch

predicted_test_dataset = test_dataset.map(generate_predictions, batched=True, batch_size=128)


  with torch.no_grad(), autocast():
Map: 100%|██████████| 62100/62100 [23:30<00:00, 44.03 examples/s]


In [None]:
from rouge_score import rouge_scorer
import pandas as pd

scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

rouge1_scores = {'r': [], 'p': [], 'f': []}
rouge2_scores = {'r': [], 'p': [], 'f': []}
rougeL_scores = {'r': [], 'p': [], 'f': []}

for pred, ref in zip(predicted_test_dataset["predictions"], predicted_test_dataset["Answer"]):
    scores = scorer.score(ref, pred)
    rouge1_scores['r'].append(scores['rouge1'].recall)
    rouge1_scores['p'].append(scores['rouge1'].precision)
    rouge1_scores['f'].append(scores['rouge1'].fmeasure)
    
    rouge2_scores['r'].append(scores['rouge2'].recall)
    rouge2_scores['p'].append(scores['rouge2'].precision)
    rouge2_scores['f'].append(scores['rouge2'].fmeasure)
    
    rougeL_scores['r'].append(scores['rougeL'].recall)
    rougeL_scores['p'].append(scores['rougeL'].precision)
    rougeL_scores['f'].append(scores['rougeL'].fmeasure)

rouge1_avg = {k: round(sum(v)/len(v), 4) for k, v in rouge1_scores.items()}
rouge2_avg = {k: round(sum(v)/len(v), 4) for k, v in rouge2_scores.items()}
rougeL_avg = {k: round(sum(v)/len(v), 4) for k, v in rougeL_scores.items()}

data = {
    "Model": ["MambaLMHead"],
    "ROUGE-1 r": [rouge1_avg['r']],
    "ROUGE-1 p": [rouge1_avg['p']],
    "ROUGE-1 f": [rouge1_avg['f']],
    "ROUGE-2 r": [rouge2_avg['r']],
    "ROUGE-2 p": [rouge2_avg['p']],
    "ROUGE-2 f": [rouge2_avg['f']],
    "ROUGE-L r": [rougeL_avg['r']],
    "ROUGE-L p": [rougeL_avg['p']],
    "ROUGE-L f": [rougeL_avg['f']]
}

df = pd.DataFrame(data)

print(df)

df.to_csv("/scratch/vetgpt/vetgpt-rlp/mamba/mamba_ssm/rouge_scores.csv", index=False, float_format="%.4f")

         Model ROUGE-1 r ROUGE-1 p ROUGE-1 f ROUGE-2 r ROUGE-2 p ROUGE-2 f  \
0  MambaLMHead    0.2735    0.0701    0.1118    0.0369    0.0072    0.0122   

  ROUGE-L r ROUGE-L p ROUGE-L f  
0    0.2456    0.0614    0.0954  
