In [None]:
%%capture

# Install / upgrade the fine-tuning stack; %%capture hides pip's output.
# NOTE(review): `-U` always pulls the latest releases, so a later re-run may get
# APIs that differ from the ones this notebook was written against (e.g. the
# SFTTrainer arguments used below). Consider pinning exact versions once a
# working set is known.
%pip install -U transformers
%pip install -U datasets
%pip install -U accelerate
%pip install -U peft
%pip install -U trl
%pip install -U bitsandbytes
%pip install -U wandb

In [None]:
import os, torch, wandb

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    TrainerCallback,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)

from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from dataclasses import dataclass

In [None]:
@dataclass
class Config:
    """Central run configuration: base model, dataset, output name and load options.

    The attributes carry type annotations so that ``@dataclass`` turns them into
    real fields. Without annotations the decorator generates an argument-less
    ``__init__``, and any override such as ``Config(model_name=...)`` raises
    ``TypeError``.
    """

    # Hugging Face Hub id of the base model to fine-tune.
    model_name: str = "Qwen/Qwen2.5-0.5B"
    # Hub id of the training dataset (the preprocessing cell reads its
    # 'question' and 'answer' columns).
    dataset_name: str = "elvispresniy/synthetic-allenai"
    # Local output directory used by TrainingArguments(output_dir=...).
    new_model: str = "MMP-0.5b-it"
    # Desired dtype for the base model's weights.
    torch_dtype: torch.dtype = torch.bfloat16
    # Attention backend passed to from_pretrained.
    attn_implementation: str = "eager"


cfg = Config()

In [None]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient

# Credentials come from Kaggle secrets so they never appear in the notebook text.
user_secrets = UserSecretsClient()

# Hugging Face Hub: needed because the training arguments push to the Hub.
hf_token = user_secrets.get_secret("huggingface_token")
login(token=hf_token)

# Weights & Biases: the training arguments report to "wandb".
wb_token = user_secrets.get_secret("wandb_api_key")
wandb.login(key=wb_token)

run = wandb.init(
    project='SciMMP-0.5b-it',
    job_type="training",
    anonymous="allow",
)

In [None]:
# Load the base model. `cfg.torch_dtype` (bf16) is passed explicitly so the
# weights match the bf16=True training setup. Without it, the weights are
# loaded at the transformers default dtype (fp32 in many releases), which
# doubles weight memory.
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    torch_dtype=cfg.torch_dtype,
    device_map="auto",
    attn_implementation=cfg.attn_implementation,
)

In [None]:
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
# Apply trl's ChatML chat template and align the model's special-token config with it.
model, tokenizer = setup_chat_format(model, tokenizer)
tokenizer.padding_side = 'right'

# Padding token. The attribute the tokenizer actually reads is `pad_token`
# (an assignment to `padding_token` is silently ignored). The token must also
# already exist in the vocabulary; "<|pad|>" is not a known token here.
# NOTE(review): setup_chat_format's ChatML tokens reportedly use "<|im_end|>"
# for both eos and pad. A language-modeling collator that masks pad ids in the
# labels would then also mask every end-of-turn token, and the model would
# never learn to stop. "<|endoftext|>" is expected to be an existing Qwen
# token that the ChatML samples do not contain, so it serves as a distinct pad
# token without resizing embeddings. The check below verifies both assumptions.
tokenizer.pad_token = "<|endoftext|>"
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == tokenizer.eos_token_id:
    raise ValueError(
        f"pad token {tokenizer.pad_token!r} must be a known token distinct from "
        f"eos {tokenizer.eos_token!r}"
    )
# Keep the model-side config consistent with the tokenizer.
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None) is not None:
    model.generation_config.pad_token_id = tokenizer.pad_token_id

In [None]:
# LoRA adapters for causal-LM fine-tuning. The target module names presumably
# correspond to Qwen2's attention (q/k/v/o) and MLP (gate/up/down) projections.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=4,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
    target_modules=[
        'up_proj', 'down_proj', 'gate_proj',
        'k_proj', 'q_proj', 'v_proj', 'o_proj',
    ],
)
# Wrap the model so only the adapter weights are trainable.
model = get_peft_model(model, peft_config)

In [None]:
# Load the whole dataset as one Dataset; the train/eval split is done further below.
dataset = load_dataset(path=cfg.dataset_name, split="all")

In [None]:
# Wrappers baked into the raw dataset strings. They are used only as split
# delimiters, so they must match the data byte-for-byte (spelling included).
prefix_string = "<|im_start|>user\nAnswer only on the subject. Don't be too much verbose. Provide scientific prooves. As soon as the answer is provided return [EOS] token. "
postfix_string = "<|im_end|>\n"
prefix_answer = "assistant\n"
eos_token = "[EOS]"


def preprocess_dataset(row):
    """Turn one raw (question, answer) row into a single chat-formatted ``text`` field.

    The question is the part of ``row['question']`` after the last
    ``prefix_string`` (or from the start if absent), up to the first
    ``postfix_string``. The answer is the part of ``row['answer']`` after the
    last ``prefix_answer``, up to the first ``eos_token``. Both are rendered with
    the tokenizer's chat template as a user/assistant exchange.

    Every key is popped from ``row`` before returning. Presumably this is so that
    ``Dataset.map`` keeps only the returned ``text`` column; verify against the
    installed ``datasets`` version.

    Args:
        row: mapping with at least ``'question'`` and ``'answer'`` strings;
            it is emptied in place.

    Returns:
        ``{"text": <chat-formatted conversation string>}``.
    """
    user_text = row['question'].rpartition(prefix_string)[2]
    user_text = user_text.partition(postfix_string)[0]

    assistant_text = row['answer'].rpartition(prefix_answer)[2]
    assistant_text = assistant_text.partition(eos_token)[0]

    conversation = [
        dict(role="user", content=user_text),
        dict(role="assistant", content=assistant_text),
    ]
    text = tokenizer.apply_chat_template(conversation, tokenize=False)

    for key in [*row.keys()]:
        row.pop(key)

    return {"text": text}

In [None]:
SEED = 2025
# Number of rows held out for evaluation.
EVAL_SIZE = 50

dataset_sh = dataset.shuffle(seed=SEED).map(preprocess_dataset)
# An int test_size is an exact row count. A fraction like 50/11679 only gives
# 50 rows for one particular dataset length, and it depends on float rounding
# when the split size is computed.
dataset_sh = dataset_sh.train_test_split(test_size=EVAL_SIZE, seed=SEED)
dataset_sh

In [None]:
@torch.no_grad()
def _sample_continuation(input_ids, max_length):
    """Autoregressively sample up to `max_length` new tokens after `input_ids`.

    Plain ancestral sampling from the global `model` (softmax over raw logits, so
    no temperature, top-k or top-p). Sampling stops early once
    `tokenizer.eos_token_id` is drawn. It runs under `torch.no_grad()`, so no
    autograd graph is built for the trainable (LoRA) parameters.

    Args:
        input_ids: prompt token ids, presumably shape (1, prompt_len) as returned
            by `tokenizer.encode(..., return_tensors='pt')`. Only batch size 1 is
            supported because of `.item()`.
        max_length: maximum number of *new* tokens to sample.

    Returns:
        The decoded prompt + continuation with special tokens removed.
    """
    generated = input_ids.to(model.device)
    for _ in range(max_length):
        # No KV cache: the full sequence is re-run each step (fine for short samples).
        next_token_logits = model(generated).logits[:, -1, :]
        # Softmax in fp32 so sampling never works on low-precision probabilities.
        probabilities = torch.softmax(next_token_logits.float(), dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)
        generated = torch.cat((generated, next_token), dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def generate_text(prompt, max_length=50):
    """Sample a raw (non-chat) continuation of `prompt`.

    Args:
        prompt: plain text to continue.
        max_length: maximum number of new tokens to sample.

    Returns:
        The decoded prompt followed by the sampled continuation.
    """
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    return _sample_continuation(input_ids, max_length)


def generate_text_it(prompt, max_length=50):
    """Sample an assistant reply to `prompt` using the tokenizer's chat template.

    Args:
        prompt: the user message.
        max_length: maximum number of new tokens to sample.

    Returns:
        The decoded chat prompt followed by the sampled reply, with special
        tokens removed.
    """
    messages = [{"role": "user", "content": prompt}]
    chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = tokenizer.encode(chat_prompt, return_tensors='pt')
    return _sample_continuation(input_ids, max_length)

In [None]:
# Raw-text completion prompts (English, plus two Russian ones at the end) for
# sanity-checking plain continuation; not referenced by the evaluation callback.
set_of_prompts = [
    "The results of the experiment demonstrated a significant increase in cell proliferation when exposed to ",
    "According to the theory of general relativity, space-time is curved by ",
    "The chemical reaction between sodium chloride and water results in the formation of",
    "Quantum mechanics suggests that particles exist in multiple states until they are",
    "The primary advantage of using CRISPR technology in gene editing is its ability to",
    "In a double-blind clinical trial, the control group was administered a placebo while the experimental group received",
    "The neural network was trained using a dataset of over one million images, and its performance was evaluated based on ",
    "Photosynthesis in plants is driven by light energy, which is absorbed by chlorophyll molecules located in the",
    (
        "The Role of Inflammatory Markers in Cardiovascular Disease\n"
        "Cardiovascular disease is one of the leading causes of mortality worldwide. Recent studies have highlighted the significance of inflammatory markers such as C-reactive protein (CRP) and interleukin-6 (IL-6) in"
    ),
    (
        "Advances in Nanotechnology for Drug Delivery\n"
        "Nanotechnology has revolutionized the field of drug delivery by enabling the targeted delivery of therapeutics with minimal side effects. In recent years, several novel nanocarriers such as liposomes, dendrimers, and"
    ),
    (
        "Machine Learning Approaches to Predict Protein-Protein Interactions\n"
        "Protein-protein interactions (PPIs) play a crucial role in biological processes. However, experimental identification of PPIs is time-consuming and costly. Machine learning models have been developed to"
    ),
    "Recent advances in artificial intelligence have led to the development of deep learning models capable of surpassing traditional machine learning algorithms in various tasks. For example, Convolutional Neural Networks (CNNs) have shown remarkable performance in",
    "Several studies have investigated the effect of atmospheric CO2 levels on global climate patterns. The seminal work by Smith et al. (2010) demonstrated a clear correlation between ",
    "A study measured the blood pressure of 200 patients before and after administering a new antihypertensive drug. On average, blood pressure decreased by 15%. This suggests that the drug...",
    "In a recent clinical trial, 60% of participants in the treatment group showed symptom improvement, while only 20% of participants in the control group reported similar results. This indicates that",
    "Результаты эксперимента продемонстрировали значительное увеличение пролиферации клеток при воздействии",
    "Согласно общей теории относительности, пространство-время искривляется на",
]

# Instruction prompts cycled through by the evaluation callback to sample chat answers.
set_of_prompts_it = [
    "Explain the role of transcription factors in gene expression, and give an example of how they can regulate cellular differentiation.",
    "Summarize the process of mitosis, focusing on the key stages and their significance for cellular division.",
    "Describe the evolutionary significance of horizontal gene transfer in bacteria and its impact on antibiotic resistance.",
    "How does the structure of the phospholipid bilayer contribute to the selective permeability of the cell membrane?",
    "Compare and contrast the mechanisms of SN1 and SN2 nucleophilic substitution reactions, including the factors that favor each.",
    "Explain the concept of Gibbs free energy and its relevance in predicting the spontaneity of chemical reactions.",
    "Describe how infrared spectroscopy can be used to identify functional groups in an organic compound.",
    "Write a step-by-step explanation of the process of acid-base titration, including how to determine the equivalence point.",
    "Explain how the Heisenberg Uncertainty Principle limits our ability to measure both the position and momentum of a particle simultaneously.",
    "Describe the fundamental differences between general relativity and quantum mechanics, and explain why they are difficult to reconcile.",
    "In the context of thermodynamics, explain the second law and how it relates to entropy in isolated systems.",
    "Provide a detailed explanation of how the photoelectric effect supports the quantum theory of light.",
]

In [None]:
class TextGenerationCallback(TrainerCallback):
    """Trainer callback that samples one chat answer per evaluation, then prunes checkpoints.

    On every evaluation it prints a sampled reply to the next prompt from
    `set_of_prompts_it` (round-robin). It then deletes all `checkpoint-*`
    directories under the configured output dir to save disk space.
    """

    def __init__(self):
        super().__init__()
        # Round-robin cursor into `set_of_prompts` (raw-completion demo; not used here).
        self.cnt = 0
        # Round-robin cursor into `set_of_prompts_it`.
        self.cnt_it = 0

    def on_evaluate(self, args, state, control, **kwargs):
        """Print a sampled reply for the next instruction prompt, then remove checkpoints."""
        import glob
        import shutil

        # Sample in eval mode (no dropout) and restore the previous mode afterwards.
        was_training = model.training
        model.eval()
        try:
            sample_input = set_of_prompts_it[self.cnt_it]
            self.cnt_it = (self.cnt_it + 1) % len(set_of_prompts_it)
            generated_text = generate_text_it(sample_input)
            print(f"Generated instruct text at step {state.global_step}: {generated_text}")
        finally:
            if was_training:
                model.train()

        # Delete checkpoints under the directory the Trainer actually writes to,
        # rather than a hard-coded Kaggle path via IPython `!rm`; this is a no-op
        # when no checkpoint exists.
        # NOTE(review): with hub_strategy="every_save" a background upload may still
        # be reading these folders; consider TrainingArguments(save_total_limit=...).
        for checkpoint_dir in glob.glob(os.path.join(args.output_dir, "checkpoint-*")):
            shutil.rmtree(checkpoint_dir, ignore_errors=True)

In [None]:
# NOTE(review): no save_steps / save_total_limit is set, so checkpoint cadence
# follows the library defaults; the evaluation callback deletes checkpoints.
training_arguments = TrainingArguments(
    # Outputs, experiment tracking and Hub publishing
    output_dir=cfg.new_model,
    run_name="SciMMP-0.5b-it",
    report_to="wandb",
    push_to_hub=True,
    hub_model_id="elvispresniy/SciMMP-0.5b-it",
    hub_strategy="every_save",
    # Optimisation schedule
    num_train_epochs=1,
    learning_rate=2e-4,
    warmup_steps=10,
    optim="paged_adamw_32bit",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    group_by_length=True,
    # Mixed precision: bf16 only
    bf16=True,
    fp16=False,
    # Evaluation and logging cadence
    eval_strategy="steps",
    eval_steps=1000,
    per_device_eval_batch_size=2,
    logging_strategy="steps",
    logging_steps=100,
)

In [None]:
# `model` is already a PeftModel (get_peft_model in the LoRA cell), so the LoRA
# config is not passed again. Depending on the trl version, a PeftModel plus
# `peft_config` is either ignored, rejected, or merged and re-wrapped, and the
# last case detaches the global `model` used by the callback.
# NOTE(review): newer trl releases move max_seq_length / dataset_text_field into
# SFTConfig and rename `tokenizer` to `processing_class`; pin trl or migrate.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset_sh["train"],
    eval_dataset=dataset_sh["test"],
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=4096,
    packing=False,
    callbacks=[TextGenerationCallback()],
)

In [None]:
# Run the fine-tuning loop; its return value is displayed as this cell's output.
trainer.train()

In [None]:
# Merge the trained LoRA adapters into the base weights via PEFT and drop the
# PEFT wrapper, rebinding `model` to the resulting plain model.
# NOTE(review): the merged model is not saved or pushed in this cell.
model = model.merge_and_unload()