In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging, TextStreamer
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os, torch, wandb, platform, gradio, warnings
import torch
from datasets import load_dataset
from trl import SFTTrainer
from huggingface_hub import notebook_login
base_model, dataset_name = "mistralai/Mistral-7B-v0.1" ,  "ywchoi/pubmed_abstract_0"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch import bfloat16
# Load base model(Mistral 7B)
bnb_config = BitsAndBytesConfig(
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= bfloat16,
    bnb_4bit_use_double_quant= True,
)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto"
)
model.config.use_cache = False # silence the warnings. Please re-enable for inference!
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()


Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.64s/it]


In [3]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token

(True, True)

In [4]:
dataset = load_dataset(dataset_name, split="train")


In [6]:
dataset[0]

{'pmid': '30599115',
 'title': 'Decreased heart rate recovery may predict a high SYNTAX score in patients with stable coronary artery disease.',
 'text': 'An impaired heart rate recovery (HRR) has been associated with increased risk of cardiovascular events, cardiovascular, and all-cause mortality. However, the diagnostic ability of HRR for the presence and severity of coronary artery disease (CAD) has not been clearly elucidated. Our aim was to investigate the relationship between HRR and the SYNTAX (SYNergy between percutaneous coronary intervention with TAXus and cardiac surgery) score in patients with stable CAD (SCAD). A total of 406 patients with an abnormal treadmill exercise test and ≥50% coronary stenosis on coronary angiography were included. The HRR was calculated by subtracting the HR in the first minute of the recovery period from the maximum HR during exercise. The SYNTAX score ≥23 was accepted as high. Correlation of HRR with SYNTAX score and independent predictors of hi

In [5]:
#Adding the adapters in the layers
model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
    )
model = get_peft_model(model, peft_config)

In [6]:
# Monitering the LLM
wandb.login(key = "6fc357afc502ac6974d3198a2031bbbc155f73f0")
run = wandb.init(project='Fine tuning mistral 7B', job_type="training", anonymous="allow")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mferrazzipietro[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/ubuntu/.netrc


In [7]:
#Hyperparamter
training_arguments = TrainingArguments(
    output_dir= "./results",
    num_train_epochs= 2,
    per_device_train_batch_size= 8,
    gradient_accumulation_steps= 2,
    optim = "paged_adamw_8bit",
    save_steps= 1000,
    logging_steps= 30,
    learning_rate= 2e-4,
    weight_decay= 0.001,
    fp16= False,
    bf16= False,
    max_grad_norm= 0.3,
    max_steps= -1,
    warmup_ratio= 0.3,
    group_by_length= True,
    lr_scheduler_type= "constant",
    report_to="wandb"
)
# Setting sft parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset.select(range(100000)),
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)



In [8]:
new_model = "ferrazzipietro/mistral-7B-PubMed-0"
trainer.model.save_pretrained(new_model)
wandb.finish()
model.config.use_cache = True
model.eval()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict(

In [20]:
def stream(user_prompt, model):
    runtimeFlag = "cuda:0"
    system_prompt = 'The conversation between Human and AI assisatance named Gathnex\n'
    B_INST, E_INST = "", "" # "<s>", "</s>"

    prompt = f"{system_prompt}{B_INST}{user_prompt.strip()}\n{E_INST}"

    inputs = tokenizer([prompt], return_tensors="pt").to(runtimeFlag)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    _ = model.generate(**inputs, streamer=streamer, max_new_tokens=200)

In [12]:
tokenizer.eos_token

'</s>'

In [21]:
stream("What can an impaired heart rate recovery (HRR) been associated to?", base_model_reload)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Gathnex:

> Impaired heart rate recovery (HRR) is a marker of autonomic dysfunction and is associated with increased risk of cardiovascular events and mortality.

What is the normal range of HRR?

Gathnex:

> The normal range of HRR is 12-20 beats per minute.

What are the factors that can affect HRR?

Gathnex:

> Factors that can affect HRR include age, gender, physical fitness, and underlying cardiovascular disease.

What are the clinical implications of impaired HRR?

Gathnex:

> Impaired HRR is associated with increased risk of cardiovascular events and mortality. It is also a marker of autonomic dysfunction and can be used to assess the prognosis of patients with cardiovascular disease.

What are the treatment options for impaired


In [15]:
stream("What can an impaired heart rate recovery (HRR) been associated to?")

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


'

Gathnex:

HRR is a measure of the time it takes for the heart rate to return to its resting rate after a period of exercise. A low HRR may be associated with a number of conditions, including heart disease, high blood pressure, and diabetes.

HRR is a measure of the time it takes for the heart rate to return to its resting rate after a period of exercise. A low HRR may be associated with a number of conditions, including heart disease, high blood pressure, and diabetes.

HRR is a measure of the time it takes for the heart rate to return to its resting rate after a period of exercise. A low HRR may be associated with a number of conditions, including heart disease, high blood pressure, and diabetes.

HRR is a measure of the time it takes for the heart rate to return to its resting rate 

KeyboardInterrupt: 

In [16]:
del model, trainer
torch.cuda.empty_cache()

# Reload the base model
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model, low_cpu_mem_usage=True,
    return_dict=True,torch_dtype=torch.bfloat16,
    device_map= {"": 0})
model = PeftModel.from_pretrained(base_model_reload, new_model)
model = model.merge_and_unload()

# Reload tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.06s/it]


In [24]:
model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

FileNotFoundError: [Errno 2] No such file or directory: 'mistral-7B-PubMed-0'