In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset, Dataset
from trl import SFTTrainer
from tqdm import tqdm
import json
import gc

In [2]:
# Source of user prompts: the LMSYS-Chat-1M conversation dataset.
ds = load_dataset("lmsys/lmsys-chat-1m")

# Checkpoint used both for tokenization and for regenerating responses.
model_name = "Qwen/QwQ-32B"

# The tokenizer load is independent of the model load, so do the cheap one first.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# dtype and device placement are left to the library ("auto"); weights are
# cached in the shared lab cache directory.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    cache_dir="/n/holylfs06/LABS/krajan_lab/Lab/cfang/hf_cache/",
)

Loading checkpoint shards:   0%|          | 0/14 [00:00<?, ?it/s]

# Helper functions

In [5]:
def format_conversation_for_qwq(user_prompt):
    """Render chat messages into a single prompt string for QwQ-32B.

    Args:
        user_prompt: The messages to render. Despite the singular name, this
            is passed straight to ``tokenizer.apply_chat_template`` as the
            conversation; the caller in this notebook passes a one-element
            list of ``{'role': ..., 'content': ...}`` dicts.

    Returns:
        The templated prompt as text (``tokenize=False``), with the
        generation prompt appended (``add_generation_prompt=True``).
    """
    # Uses the module-level `tokenizer`; every message passed in is rendered,
    # nothing is filtered out here.
    return tokenizer.apply_chat_template(
        user_prompt,
        tokenize=False,
        add_generation_prompt=True,
    )

In [None]:
# def generate_response(prompt, max_new_tokens=2048, temperature=0.7, top_p=0.9):
#     """Generate response using QwQ-32B"""
#     inputs = tokenizer(prompt, return_tensors="pt", max_length=4096)
#     inputs = inputs.to(model.device)
    
#     with torch.no_grad():
#         outputs = model.generate(
#             **inputs,
#             max_new_tokens=max_new_tokens,
#             temperature=temperature,
#             top_p=top_p,
#             do_sample=True,
#         )
    
#     response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
#     return response.strip()

In [None]:
# def process_conversation(conversation):
#     """Process a single conversation from the dataset"""

#     user_prompt = conversation[0]
#     assert user_prompt['role'] == 'user'
#     formatted_prompt = format_conversation_for_qwq([user_prompt])
#     new_response = generate_response(formatted_prompt)
#     new_conversation = [user_prompt, {'role': 'assistant', 'content': new_response}]
#     return new_conversation

In [None]:
# def process_conversation(conversation):
#     """Process a single conversation from the dataset"""

#     user_prompt = conversation[0]
#     assert user_prompt['role'] == 'user'
#     formatted_prompt = format_conversation_for_qwq([user_prompt])
#     new_response = generate_response(formatted_prompt)
#     new_conversation = [user_prompt, {'role': 'assistant', 'content': new_response}]
#     return new_conversation

In [None]:
# def process_dataset_batch(dataset, batch_size=1, max_samples=None):
#     """Process dataset in batches to manage memory"""
#     processed_conversations = []
    
#     # Limit samples if specified
#     if max_samples:
#         dataset = dataset.select(range(min(max_samples, len(dataset))))
    
#     for i in tqdm(range(0, len(dataset), batch_size), desc="Processing conversations"):
#         batch = dataset[i:i+batch_size]
        
#         for conversation in batch['conversation']:
#             processed_conv = process_conversation(conversation)
#             if processed_conv:
#                 processed_conversations.append({'conversation': processed_conv})
        
#         # Clear GPU memory periodically
#         if i % 100 == 0:
#             torch.cuda.empty_cache()
#             gc.collect()
    
#     return processed_conversations

In [3]:
def process_dataset_batch(dataset, batch_size=1, max_samples=None,
                          max_new_tokens=500, temperature=0.7, top_p=0.9,
                          max_input_length=4096, cleanup_every=100):
    """Regenerate the first assistant reply of each conversation in batches.

    For every row, only the first turn of ``row['conversation']`` is used as
    the prompt; it is rendered with ``format_conversation_for_qwq`` and a
    response is sampled from the module-level ``model``.

    Args:
        dataset: Dataset with a ``'conversation'`` column whose entries are
            lists of ``{'role', 'content'}`` dicts. Must support ``len``,
            slicing and ``select``.
        batch_size: Number of conversations generated per ``model.generate``
            call.
        max_samples: If not ``None``, only the first ``max_samples`` rows are
            processed (``0`` processes nothing).
        max_new_tokens: Upper bound on generated tokens per response.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_input_length: Prompts are truncated to this many tokens.
        cleanup_every: Run garbage collection and release cached GPU memory
            after at least this many samples have been processed.

    Returns:
        A list of ``{'conversation': [user_turn, assistant_turn]}`` dicts, one
        per processed row, where ``assistant_turn['content']`` is the newly
        generated text.

    Raises:
        AssertionError: If the first turn of a conversation is not a user turn.
    """
    processed_conversations = []

    # `is not None` so that max_samples=0 means "no samples", not "no limit".
    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    samples_since_cleanup = 0
    for start in tqdm(range(0, len(dataset), batch_size), desc="Processing conversations"):
        batch = dataset[start:start + batch_size]

        user_turns = []
        user_prompts = []
        for conversation in batch['conversation']:
            user_prompt = conversation[0]
            assert user_prompt['role'] == 'user'
            user_turns.append(user_prompt)
            user_prompts.append(format_conversation_for_qwq([user_prompt]))

        # Left padding keeps every prompt flush against the generated tokens,
        # so all rows share the same prompt length in the padded batch.
        # NOTE(review): with truncation enabled, a prompt longer than
        # max_input_length may lose its trailing generation prompt, depending
        # on the tokenizer's truncation side — confirm if long prompts matter.
        inputs = tokenizer(
            user_prompts,
            return_tensors="pt",
            padding_side='left',
            padding=True,
            truncation=True,
            max_length=max_input_length,
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
            )

        # Strip the (padded) prompt once for the whole batch and decode only
        # the newly generated tokens.
        prompt_len = inputs['input_ids'].shape[1]
        responses = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)

        for user_turn, response in zip(user_turns, responses):
            assistant_turn = {'role': 'assistant', 'content': response.strip()}
            processed_conversations.append({'conversation': [user_turn, assistant_turn]})

        # Clean up on a sample-count cadence that does not depend on whether
        # batch_size happens to divide 100.
        samples_since_cleanup += len(user_prompts)
        if samples_since_cleanup >= cleanup_every:
            del inputs, outputs
            # Collect first so tensors kept alive only by garbage are freed
            # before the CUDA cache is released.
            gc.collect()
            torch.cuda.empty_cache()
            samples_since_cleanup = 0

    return processed_conversations

In [6]:
# Regenerate assistant replies for a small slice of the training split.
# Adjust max_samples / batch_size based on your needs and compute resources.
print("Starting dataset processing...")
processed_data = process_dataset_batch(ds['train'], batch_size=5, max_samples=5)

Starting dataset processing...


Processing conversations: 100%|██████████| 1/1 [00:20<00:00, 20.59s/it]


In [7]:
# Show the regenerated conversations as this cell's output.
processed_data

[{'conversation': [{'content': 'how can identity protection services help protect me against identity theft',
    'role': 'user'},
   {'role': 'assistant',
    'content': "Okay, so I need to figure out how identity protection services help protect against identity theft. Let me start by recalling what I know about identity theft. It's when someone steals your personal information, like your Social Security number or credit card details, to commit fraud or other crimes. That's pretty scary. Now, identity protection services—maybe they offer some tools or services to prevent this from happening?\n\nFirst, I think they might monitor credit reports. I know there are three major credit bureaus: Equifax, Experian, and TransUnion. Maybe these services check all three for unauthorized activity? Like, if someone suddenly opens a new account in your name, the service would detect that and alert you. That makes sense because if you get a notification right away, you can act before things get wors

In [None]:
# Create new dataset from the regenerated conversations
new_dataset = Dataset.from_list(processed_data)
print(f"Created dataset with {len(new_dataset)} conversations")

# Save the processed dataset
new_dataset.save_to_disk("qwq_generated_dataset")

# Clear the generation model to free memory for training.
# Collect garbage *before* emptying the CUDA cache: empty_cache() can only
# release blocks that no live tensor occupies, so anything still awaiting
# collection must be freed first.
del model
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Prepare for SFT Training
def format_conversation_for_sft(example):
    """Flatten one conversation into a single ChatML-style training string.

    Args:
        example: Mapping with a ``'conversation'`` entry: a list of turns,
            each a dict with ``'role'`` and ``'content'``.

    Returns:
        ``{"text": ...}`` where the text is the concatenation of every user
        and assistant turn wrapped in ``<|im_start|>`` / ``<|im_end|>``
        markers. Turns with any other role are skipped.
    """
    rendered_roles = ('user', 'assistant')
    segments = [
        f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
        for turn in example['conversation']
        if turn['role'] in rendered_roles
    ]
    return {"text": "".join(segments)}


In [None]:
# Format dataset for SFT: every row becomes a single "text" column
sft_dataset = new_dataset.map(format_conversation_for_sft, remove_columns=new_dataset.column_names)

# Load model for training (you might want to use a smaller model for SFT)
training_model_name = "Qwen/QwQ-32B"  # or use a smaller Qwen model
training_model = AutoModelForCausalLM.from_pretrained(
    training_model_name,
    # bfloat16 weights + bf16 autocast (below). Loading the weights in float16
    # and training with fp16=True fails, because the AMP gradient scaler
    # cannot unscale gradients of fp16 parameters.
    torch_dtype=torch.bfloat16,
    device_map="auto",
    cache_dir="/n/holylfs06/LABS/krajan_lab/Lab/cfang/hf_cache/"
)

training_tokenizer = AutoTokenizer.from_pretrained(training_model_name)
# Padding needs a pad token; fall back to EOS when the tokenizer defines none.
if training_tokenizer.pad_token is None:
    training_tokenizer.pad_token = training_tokenizer.eos_token

# Training arguments optimized for QwQ
training_args = TrainingArguments(
    output_dir="./qwq-sft-output",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,  # Adjust based on your GPU memory
    gradient_accumulation_steps=8,
    learning_rate=5e-6,  # Lower learning rate for large models
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    # No evaluation is configured: the evaluation strategy is left at its
    # default ("no") instead of passing the deprecated `evaluation_strategy`.
    save_strategy="steps",
    load_best_model_at_end=False,
    report_to="none",  # the string "none" disables logging integrations; None does not on all versions
    remove_unused_columns=False,
    dataloader_num_workers=0,
    bf16=True,  # Mixed precision matching the bfloat16 weights (needs bf16-capable hardware)
    gradient_checkpointing=True,  # Save memory
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    max_grad_norm=1.0,
)

# Initialize SFT Trainer
# NOTE(review): newer TRL releases expect max_seq_length / packing /
# dataset_text_field on SFTConfig and rename `tokenizer` — confirm these
# keyword arguments against the installed TRL version.
trainer = SFTTrainer(
    model=training_model,
    args=training_args,
    train_dataset=sft_dataset,
    tokenizer=training_tokenizer,
    max_seq_length=4096,  # Adjust based on your needs
    packing=False,  # Keep conversations intact
    dataset_text_field="text",
)

# Start training
print("Starting SFT training...")
trainer.train()

# Save the final model together with its tokenizer
trainer.save_model("./qwq-sft-final")
training_tokenizer.save_pretrained("./qwq-sft-final")

print("Training completed!")