In [1]:
import accelerate
import pandas as pd
import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hub id shared by the tokenizer and the model.
checkpoint = "Qwen/Qwen2-1.5B"

# 4-bit NF4 weights with double quantisation; matmuls are computed in fp16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
# The end-of-sequence token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
# Last expression: show the module tree of the loaded model.
model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1536)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear4bit(in_features=1536, out_features=1536, bias=True)
          (k_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (v_proj): Linear4bit(in_features=1536, out_features=256, bias=True)
          (o_proj): Linear4bit(in_features=1536, out_features=1536, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (up_proj): Linear4bit(in_features=1536, out_features=8960, bias=False)
          (down_proj): Linear4bit(in_features=8960, out_features=1536, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((1536,), eps=1e-06)


In [3]:
# Fetch the MedMCQA training split and mirror it into a DataFrame for inspection.
data = load_dataset(path="openlifescienceai/medmcqa", split="train")
data_df = data.to_pandas()
data_df.head()

Unnamed: 0,id,question,opa,opb,opc,opd,cop,choice_type,exp,subject_name,topic_name
0,e9ad821a-c438-4965-9f77-760819dfa155,Chronic urethral obstruction due to benign pri...,Hyperplasia,Hyperophy,Atrophy,Dyplasia,2,single,Chronic urethral obstruction because of urinar...,Anatomy,Urinary tract
1,e3d3c4e1-4fb2-45e7-9f88-247cc8f373b3,Which vitamin is supplied from only animal sou...,Vitamin C,Vitamin B7,Vitamin B12,Vitamin D,2,single,Ans. (c) Vitamin B12 Ref: Harrison's 19th ed. ...,Biochemistry,Vitamins and Minerals
2,5c38bea6-787a-44a9-b2df-88f4218ab914,All of the following are surgical options for ...,Adjustable gastric banding,Biliopancreatic diversion,Duodenal Switch,Roux en Y Duodenal By pass,3,multi,"Ans. is 'd' i.e., Roux en Y Duodenal Bypass Ba...",Surgery,Surgical Treatment Obesity
3,cdeedb04-fbe9-432c-937c-d53ac24475de,Following endaerectomy on the right common car...,Central aery of the retina,Infraorbital aery,Lacrimal aery,Nasociliary aretry,0,multi,The central aery of the retina is a branch of ...,Ophthalmology,
4,dc6794a3-b108-47c5-8b1b-3b4931577249,Growth hormone has its effect on growth through?,Directly,IG1-1,Thyroxine,Intranuclear receptors,1,single,"Ans. is 'b' i.e., IGI-1GH has two major functi...",Physiology,


In [4]:
# Build the two text columns used for fine-tuning:
#   input - the question followed by its four answer options
#   exp   - the explanation, prefixed with "Context: "
#
# The pieces are concatenated column-wise instead of with a row-wise
# ``apply(axis=1)``. ``astype(str)`` mirrors what the f-string did to each
# value (a missing option still renders as its ``str()`` form).
prompt_parts = data_df[["question", "opa", "opb", "opc", "opd"]].astype(str)
prompt = (
    "Question: " + prompt_parts["question"]
    + " \n Choice: \n A. " + prompt_parts["opa"]
    + " \n B. " + prompt_parts["opb"]
    + " \n C. " + prompt_parts["opc"]
    + " \n D. " + prompt_parts["opd"]
)

# A missing explanation would make the whole concatenation NaN, which the
# later ``str()`` coercion turns into the literal text "nan". Treat a missing
# explanation as empty text instead.
explanation = "Context: " + data_df["exp"].fillna("")

# A brand-new frame (not a column slice of the old one), so later cells can
# assign columns to it without triggering SettingWithCopyWarning.
data_df = pd.DataFrame({"input": prompt, "exp": explanation})
data_df.head()

Unnamed: 0,input,exp
0,Question: Chronic urethral obstruction due to ...,Context: Chronic urethral obstruction because ...
1,Question: Which vitamin is supplied from only ...,Context: Ans. (c) Vitamin B12 Ref: Harrison's ...
2,Question: All of the following are surgical op...,"Context: Ans. is 'd' i.e., Roux en Y Duodenal ..."
3,Question: Following endaerectomy on the right ...,Context: The central aery of the retina is a b...
4,Question: Growth hormone has its effect on gro...,"Context: Ans. is 'b' i.e., IGI-1GH has two maj..."


In [5]:
def _as_text(value):
    """Return ``value`` as one string; list values are joined with spaces."""
    return " ".join(value) if isinstance(value, list) else str(value)


# ``assign`` returns a new frame rather than writing into ``data_df`` in place,
# which avoids pandas' SettingWithCopyWarning when ``data_df`` is a column
# selection of an earlier frame.
data_df = data_df.assign(
    input=data_df["input"].map(_as_text),
    exp=data_df["exp"].map(_as_text),
)
# Full training text: prompt and explanation separated by a single space.
data_df = data_df.assign(text=data_df["input"] + " " + data_df["exp"])
data = Dataset.from_pandas(data_df)

def tokenize(sample, max_length=512):
    """Tokenize a batch of prompt/explanation pairs for causal-LM fine-tuning.

    The prompt (``sample["input"]``) and the explanation (``sample["exp"]``)
    are joined with a single space and encoded as ONE sequence, so ``labels``
    line up token-for-token with ``input_ids``. Encoding the two columns
    separately produces ``input_ids`` and ``labels`` of unrelated lengths,
    which a causal LM cannot compute a loss on.

    Every row is padded/truncated to ``max_length`` so that ``input_ids``,
    ``attention_mask`` and ``labels`` all have the same fixed length and can
    be stacked into a batch without any label-specific padding. This costs
    extra storage compared with dynamic padding.

    Args:
        sample: Batch mapping with equal-length ``"input"`` and ``"exp"``
            lists of strings (as passed by ``Dataset.map(..., batched=True)``).
        max_length: Length every encoded sequence is padded or truncated to.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``. ``labels``
        is a copy of ``input_ids`` in which padded positions are replaced by
        ``-100`` (the index ignored by the loss).
    """
    texts = [
        f"{prompt} {explanation}"
        for prompt, explanation in zip(sample["input"], sample["exp"])
    ]
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

    # Mask by attention mask rather than by token id: the pad token is the
    # same id as EOS here, so an id-based mask could not tell them apart.
    labels = [
        [token_id if attended else -100 for token_id, attended in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]

    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": labels,
    }

# Tokenize batch-wise; the raw string columns are dropped from the result.
tokenized_data = data.map(
    tokenize,
    batched=True,
    remove_columns=["input", "exp", "text"],
    desc="Tokenizing data",
)
tokenized_data

Tokenizing data: 100%|██████████| 182822/182822 [00:44<00:00, 4117.08 examples/s]


Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 182822
})

In [None]:
# Held-out split for evaluation. Loading "train" here would score the model on
# the same rows it is fine-tuned on.
# NOTE(review): this is the raw dataset; it has to go through the same prompt
# formatting and tokenization as the training data before a Trainer can
# evaluate on it.
eval_data = load_dataset("openlifescienceai/medmcqa", split="validation")

In [None]:
training_args = TrainingArguments(
    output_dir="./qwen_gsm8k_finetune",
    report_to="none",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    save_strategy="epoch",
    logging_steps=100,
    # max_steps bounds the run, so num_train_epochs is not set.
    max_steps=1000,
    fp16=True,
)

# The base model is loaded in 4-bit; Trainer refuses to fine-tune a purely
# quantized model, so trainable LoRA adapters are attached to the attention
# projections (module names as shown in the printed model architecture).
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Guarded so that re-running this cell does not wrap the model a second time.
if not hasattr(model, "peft_config"):
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# ``eval_dataset`` is intentionally not passed: no evaluation strategy is
# configured above, and ``eval_data`` is a raw, untokenized dataset that the
# Trainer could not consume as-is.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

In [None]:
# Run the fine-tuning loop.
trainer.train()
print("The training process is ended.")

# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()
print("Cuda cache is removed.")

#################################TODO####################################
```py
# GSM8K, "main" configuration, test split: the questions used for inference.
gsm8k = load_dataset("openai/gsm8k", "main", split="test")

class GSM8KDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that yields GSM8K question strings.

    The base class is spelled out as ``torch.utils.data.Dataset`` because the
    bare name ``Dataset`` in this notebook is ``datasets.Dataset`` (imported
    from Hugging Face ``datasets``), which is not the base class a torch
    ``DataLoader`` expects.
    """

    def __init__(self, data):
        """Keep the ``'question'`` column of ``data``.

        Args:
            data: Any object for which ``data['question']`` returns an
                indexable, sized collection of question strings.
        """
        self.questions = data['question']

    def __len__(self):
        """Return the number of questions."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return the question at integer position ``idx``."""
        return self.questions[idx]

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = GSM8KDataset(gsm8k)
batch_size = 16  # Adjust as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-1.5B', trust_remote_code=True)
# Relative path so it matches the training ``output_dir`` instead of a
# machine-specific absolute location.
model = AutoModelForCausalLM.from_pretrained(
    './qwen_gsm8k_finetune/checkpoint-1000',
    trust_remote_code=True,
    torch_dtype=torch.float16,
).to(device)

if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
# Left padding keeps every prompt flush against its generated continuation.
tokenizer.padding_side = 'left'

model.config.pad_token_id = tokenizer.pad_token_id

model.eval()
# NOTE(review): confirm torch.compile actually benefits ``generate`` here.
model = torch.compile(model)
results = []

max_results = 160  # Adjust as needed
num_batches = (max_results + batch_size - 1) // batch_size

with torch.inference_mode():
    for batch in tqdm(dataloader, desc="Inference", unit="batch", total=num_batches):
        inputs = tokenizer(
            batch,
            return_tensors='pt',
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)

        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            num_beams=1,
            do_sample=True,
            temperature=0.2,
            top_p=0.9,
            top_k=50,
            repetition_penalty=1.0,
            length_penalty=1.0,
            pad_token_id=tokenizer.pad_token_id
        )

        # ``generate`` returns prompt + continuation for a decoder-only model;
        # drop the (left-padded) prompt so 'Answer' holds only generated text.
        prompt_len = inputs["input_ids"].shape[1]
        answers = tokenizer.batch_decode(
            outputs[:, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        for question, answer in zip(batch, answers):
            results.append({'Question': question, 'Answer': answer})

        # Comment out if for running whole dataset
        if len(results) >= max_results:
            break

df = pd.DataFrame(results)
display(df)
df.to_csv("output.csv", index=False)
```
