In [1]:
from transformers import AutoTokenizer
import torch
import deepspeed
import pandas as pd
from datasets import Dataset, load_dataset, load_from_disk
from transformers import AutoTokenizer

print(torch.cuda.is_available())      # Should print True if a GPU is detected
print(torch.cuda.device_count())      # Number of GPUs available
print(torch.cuda.get_device_name(0))  # Name of the first GPU (if available)
print(deepspeed.__version__)


[2025-06-06 09:51:26,795] [INFO] [real_accelerator.py:239:get_accelerator] Setting ds_accelerator to cuda (auto detect)


W0606 09:51:31.215000 12632 site-packages\torch\distributed\elastic\multiprocessing\redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


True
1
NVIDIA GeForce RTX 5090
0.16.5


In [None]:
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", 'en')
dataset = dataset['train'].train_test_split(test_size=0.1)
print(len(dataset["train"]))
print(len(dataset["test"]))
print(dataset["test"].to_pandas().head())


17733
1971
                                            Question  \
0  In a newborn baby presenting with an absent an...   
1  A patient presents with back pain, elevated ES...   
2  In a 45-year-old man with AIDS presenting with...   
3  A patient with psoriasis was treated with syst...   
4  A 50-year-old man, who underwent a kidney tran...   

                                         Complex_CoT  \
0  Alright, let’s think this through step by step...   
1  Alright, let's see what's going on with this p...   
2  Alright, so let's think about this. We've got ...   
3  I'm thinking about this patient with psoriasis...   
4  So, we have a 50-year-old man who had a kidney...   

                                            Response  
0  In a newborn presenting with an absent anal or...  
1  Given the presentation of back pain, elevated ...  
2  In a 45-year-old man with AIDS presenting with...  
3  The development of generalized pustules in a p...  
4  If a biopsy is taken from the transpl

In [3]:
model_name = "mistralai/Mistral-7B-v0.1"  # Base model, replace with your preferred Mistral version
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Set padding token

In [4]:
# Function to tokenize and format the data
def preprocess_function(examples):
    # Format: <s>[INST] {instruction} [/INST] {response}</s>
    prompts = [f"[INST] {question} [/INST]" for question in examples['Question']]
    responses = examples['Response']
    
    # Combine prompts and responses
    texts = []
    for prompt, response in zip(prompts, responses):
        texts.append(f"{prompt} {response}")
    
    print(texts[:1])  # Print first 5 formatted texts for debugging
    # Tokenize
    tokenized = tokenizer(
        texts,
        truncation=True,
        max_length=512,  # Adjust based on your needs
        padding="max_length",
        return_tensors="pt"
    )
    
    # Prepare labels for the decoder (same as input_ids for causal language modeling)
    tokenized["labels"] = tokenized["input_ids"].clone()
    
    return tokenized

In [5]:
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

Map:   0%|          | 0/17733 [00:00<?, ? examples/s]

["[INST] A patient named Sunder is admitted to the emergency department with severe bradycardia due to a drug overdose. This drug was being used to treat his hypertension. Based on the pharmacological effects, which drug being used for hypertension would not cause bradycardia as a side effect? [/INST] In the scenario where a patient like Sunder is experiencing severe bradycardia due to an overdose of a hypertension medication, it's essential to focus on the type of medication that wouldn't typically cause a slow heart rate. Lisinopril, an ACE inhibitor, is one such drug. ACE inhibitors, like lisinopril, primarily act on the renin-angiotensin-aldosterone system to reduce blood pressure by lowering blood vessel constriction and fluid volume without having a direct effect on heart rate. Therefore, lisinopril does not cause bradycardia as a side effect, making it an unlikely culprit in this scenario."]
["[INST] An 11-year-old boy has a 4-day history of increasing left-sided pain below his 

Map:   0%|          | 0/1971 [00:00<?, ? examples/s]

["[INST] In a newborn baby presenting with an absent anal orifice and meconuria, what is the most appropriate immediate management approach to address these symptoms? [/INST] In a newborn presenting with an absent anal orifice and meconuria, the most appropriate immediate management approach is to perform a colostomy. This procedure addresses the critical issue of intestinal obstruction by creating an opening for the passage of feces, thereby preventing complications like bowel perforation and reducing the risk of infection associated with the fistula. By relieving the obstruction, it stabilizes the infant's condition and allows time for further evaluation and definitive surgical planning for the anorectal malformation."]
["[INST] One day after undergoing a left carotid endarterectomy, a 63-year-old man has a severe headache. He describes it as 9 out of 10 in intensity. He has nausea. He had 80% stenosis in the left carotid artery and received heparin prior to the surgery. He has a his

In [6]:
print(tokenized_dataset["train"].to_pandas().head())

                                           input_ids  \
0  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...   
1  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...   
2  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...   
3  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...   
4  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...   

                                      attention_mask  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
1  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
2  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
3  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   

                                              labels  
0  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...  
1  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...  
2  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...  
3  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...  
4  [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ..

In [7]:
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
import torch

In [8]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # Automatically distribute across available GPUs
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
def try_import_wandb():
    try:
        import wandb
        return True
    except ImportError:
        return False

In [10]:
training_args = TrainingArguments(
    output_dir="./mistral-finetuned",
    overwrite_output_dir=True,
    num_train_epochs=1,  # 2-4 is typical; increase for more thorough training
    per_device_train_batch_size=4,  # Try 2 or 4; increase if VRAM allows, decrease if OOM
    per_device_eval_batch_size=4,   # Match train batch size for eval
    gradient_accumulation_steps=8,  # Increase if you want a larger effective batch size
    save_steps=1000,                # Save every 1000 steps
    save_total_limit=1,             # Keep last 3 checkpoints
    eval_steps=500,                 # Evaluate every 500 steps
    logging_dir="./logs",
    logging_steps=50,               # Log more frequently for better monitoring
    learning_rate=1.3566154061739588e-05,             # Good starting point for LLMs
    weight_decay=0.01,
    warmup_steps=200,               # Slightly higher for stability with larger datasets
    lr_scheduler_type="cosine",
    report_to="wandb" if try_import_wandb() else "none",
    fp16=False,                     # Use bf16 if your GPU supports it (RTX 5090 does)
    bf16=True,
    gradient_checkpointing=True,
    # DeepSpeed configuration (optional, for multi-GPU training)
    # deepspeed="ds_config.json"  # Uncomment and specify your DeepSpeed config file
)

In [11]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

In [12]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mluis-orellana777[0m ([33mluis-orellana777-sngular[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
50,1.8294
100,0.4762
150,0.4849
200,0.5147
250,0.5181
300,0.5127
350,0.5039
400,0.4848
450,0.4669
500,0.4704


TrainOutput(global_step=554, training_loss=0.6112520815233031, metrics={'train_runtime': 87269.0943, 'train_samples_per_second': 0.203, 'train_steps_per_second': 0.006, 'total_flos': 3.8724950686275994e+17, 'train_loss': 0.6112520815233031, 'epoch': 0.9995489400090212})

In [None]:
trainer.save_model("./model-mistral-finetuned-medical-reasoning")

In [None]:
tokenizer.save_pretrained("./tokenizer-mistral-finetuned-medical-reasoning")

('./tokenizer-mistral-finetuned-medical-reasoning\\tokenizer_config.json',
 './tokenizer-mistral-finetuned-medical-reasoning\\special_tokens_map.json',
 './tokenizer-mistral-finetuned-medical-reasoning\\tokenizer.model',
 './tokenizer-mistral-finetuned-medical-reasoning\\added_tokens.json',
 './tokenizer-mistral-finetuned-medical-reasoning\\tokenizer.json')

In [None]:
trainer.push_to_hub("Luis-Orellana777/model-mistral-finetuned-medical-reasoning")

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.65k [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/Luis-Orellana777/mistral-finetuned/commit/4d7713346a3f24480dd6a737bcc4f109c00d2e66', commit_message='Luis-Orellana777/model-mistral-finetuned-medical-reasoning', commit_description='', oid='4d7713346a3f24480dd6a737bcc4f109c00d2e66', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Luis-Orellana777/mistral-finetuned', endpoint='https://huggingface.co', repo_type='model', repo_id='Luis-Orellana777/mistral-finetuned'), pr_revision=None, pr_num=None)

In [None]:
tokenizer.push_to_hub("Luis-Orellana777/model-mistral-finetuned-medical-reasoning")

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Luis-Orellana777/tokenizer-mistral-finetuned-medical-reasoning/commit/31241a69f29716ee8967cfbcdf266199b3c6f860', commit_message='Upload tokenizer', commit_description='', oid='31241a69f29716ee8967cfbcdf266199b3c6f860', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Luis-Orellana777/tokenizer-mistral-finetuned-medical-reasoning', endpoint='https://huggingface.co', repo_type='model', repo_id='Luis-Orellana777/tokenizer-mistral-finetuned-medical-reasoning'), pr_revision=None, pr_num=None)

In [17]:
def generate_response(prompt, model, tokenizer, max_length=512):
    input_ids = tokenizer.encode(f"[INST] {prompt} [/INST]", return_tensors="pt").to(model.device)
    attention_mask = (input_ids != tokenizer.pad_token_id).long().to(model.device)
    
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id  # Explicitly set pad_token_id
    )
    
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [18]:
test_prompt = "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?"
response = generate_response(test_prompt, model, tokenizer)
print(response)



[INST] Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings? [/INST] The symptoms you've described—sudden weakness in the left arm and leg, recent long-distance travel, and swollen and tender right lower leg—are highly suggestive of a condition known as embolism. Embolism occurs when a blood clot, typically originating from the deep veins of the legs (a condition known as deep vein thrombosis or DVT), breaks loose and travels to the heart.

In this scenario, the embolism is most likely to come from the right lower leg, as evidenced by the swelling and tenderness. The clot then travels to the heart and can subsequently be dislodged, causing an embolic event in the cerebral arteries, leading to the sudden weakness in the left arm and leg.

The recent long-distance travel is a s