In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset, Dataset
from trl import SFTTrainer
from tqdm import tqdm
import json
import gc
import os

In [2]:
from config import storage_dir

In [6]:
model_name = "Qwen/Qwen3-30B-A3B"
model_storage_dir = os.path.join(storage_dir, "lm_sys", model_name.split("/")[-1])
response_path = os.path.join(model_storage_dir, 'lm_sys_responses')

In [7]:
from datasets import load_from_disk

dataset = load_from_disk(response_path)
print(dataset)

Dataset({
    features: ['conversation'],
    num_rows: 256
})


In [8]:
dataset[0]

{'conversation': [{'content': 'how can identity protection services help protect me against identity theft',
   'role': 'user'},
  {'content': "Identity protection services can play a crucial role in helping you protect yourself against identity theft by monitoring, detecting, and responding to potential threats. Here's how they can help:\n\n---\n\n### **1. Continuous Monitoring of Your Personal Information**\n- **Credit Reports:** Many services monitor your credit reports for suspicious activity, such as new accounts or loans opened in your name.\n- **Dark Web Monitoring:** They scan the dark web (hidden parts of the internet) for your personal information, like Social Security numbers, credit card details, or login credentials.\n- **Public Records:** They check public records for any unauthorized use of your name, address, or other personal details.\n\n---\n\n### **2. Early Detection of Fraud**\n- **Unusual Activity Alerts:** If someone tries to open a new account, apply for a loan, 

In [9]:
dataset[1]

{'conversation': [{'content': "Beside OFAC's selective sanction that target the listed individiuals and entities, please elaborate on the other types of US's sanctions, for example, comprehensive and sectoral sanctions. Please be detailed as much as possible",
   'role': 'user'},
  {'content': 'The U.S. sanctions regime is a multifaceted and comprehensive tool used to advance foreign policy and national security objectives. While the **Office of Foreign Assets Control (OFAC)** within the U.S. Department of the Treasury is primarily responsible for enforcing **targeted sanctions** (such as **individual and entity-specific sanctions**), the U.S. government also employs a range of **other types of sanctions**, including **comprehensive sanctions**, **sectoral sanctions**, **embargoes**, and **economic sanctions**. These are typically imposed by the **Department of the Treasury**, the **Department of State**, and the **Department of Commerce**, and are often authorized under various **stat

In [None]:
if not os.path.exists(model_storage_dir):
    os.makedirs(model_storage_dir)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto", 
    cache_dir="/n/holylfs06/LABS/krajan_lab/Lab/cfang/hf_cache/"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Load dataset

In [4]:
ds = load_dataset("lmsys/lmsys-chat-1m")


# Helper functions

In [5]:
def format_conversation_for_qwen(user_prompt):
    """Format conversation for Qwen3-30B-A3B input with enable_thinking=False"""
    formatted = tokenizer.apply_chat_template(
        user_prompt,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False  # <-- Set this flag
    )
    return formatted

In [6]:
def process_dataset_batch(dataset, batch_size=1, max_samples=None, context_length=4096):
    """Process dataset in batches to manage memory and parallelize generation"""
    processed_conversations = []
    
    # Limit samples if specified
    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    
    for i in tqdm(range(0, len(dataset), batch_size), desc="Processing conversations"):
        batch = dataset[i:i+batch_size]
        user_prompts = []
        user_turns = []
        for conversation in batch['conversation']:
            user_prompt = conversation[0]
            assert user_prompt['role'] == 'user'
            formatted_prompt = format_conversation_for_qwen([user_prompt])
            user_prompts.append(formatted_prompt)
            user_turns.append(user_prompt)
        
        # Tokenize as a batch
        inputs = tokenizer(
            user_prompts,
            return_tensors="pt",
            padding_side='left',
            padding=True,
            truncation=True,
            max_length=context_length
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        
        # Decode only the generated part for each sample
        for idx in range(len(user_prompts)):
            input_len = inputs['input_ids'][idx].shape[0]
            generated_ids = outputs[idx][inputs['input_ids'][idx].shape[0]:]
            response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            new_conversation = [user_turns[idx], {'role': 'assistant', 'content': response}]
            processed_conversations.append({'conversation': new_conversation})
        
        # Clear GPU memory periodically
        if i % 100 == 0:
            torch.cuda.empty_cache()
            gc.collect()
    
    return processed_conversations

# Parallel


In [7]:
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)

def process_on_gpu(rank, dataset_chunk, batch_size, context_length, return_dict):
    torch.cuda.set_device(rank)
    # Load model and tokenizer inside each process
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map={"": rank},  # assign model to specific GPU
        cache_dir="/n/holylfs06/LABS/krajan_lab/Lab/cfang/hf_cache/"
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    results = process_dataset_batch(
        dataset_chunk, batch_size=batch_size, context_length=context_length
    )
    return_dict[rank] = results
    del model
    torch.cuda.empty_cache()
    gc.collect()

def process_dataset_batch_multigpu(dataset, batch_size=1, context_length=4096, max_samples=None):
    num_gpus = torch.cuda.device_count()
    print(f"Using {num_gpus} GPUs for parallel generation.")
    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))
    chunk_size = len(dataset) // num_gpus
    processes = []
    manager = mp.Manager()
    return_dict = manager.dict()
    for rank in range(num_gpus):
        start = rank * chunk_size
        end = (rank + 1) * chunk_size if rank < num_gpus - 1 else len(dataset)
        dataset_chunk = dataset.select(range(start, end))
        p = mp.Process(
            target=process_on_gpu,
            args=(rank, dataset_chunk, batch_size, context_length, return_dict)
        )
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    # Combine results from all GPUs
    all_results = []
    for rank in range(num_gpus):
        all_results.extend(return_dict[rank])
    return all_results

# Generate conversations and save dataset

In [8]:
# ...existing code...
processed_data = process_dataset_batch_multigpu(
    ds['train'], 
    batch_size=32,      # Set batch size as appropriate
    context_length=2048,
    max_samples=5,     # Adjust as needed
)
# ...existing code...

Using 2 GPUs for parallel generation.


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/n/home04/cfang/.conda/envs/jax/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/n/home04/cfang/.conda/envs/jax/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: Can't get attribute 'process_on_gpu' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/n/home04/cfang/.conda/envs/jax/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/n/home04/cfang/.conda/envs/jax/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
          

KeyError: 0

In [None]:
# Process the dataset
print("Starting dataset processing...")
processed_data = process_dataset_batch_multigpu(
    ds['train'], 
    batch_size=10,
    context_length=2048,
    max_samples=5,   # Adjust based on your needs and compute resources
)

Starting dataset processing...


Processing conversations:   0%|          | 0/1 [00:00<?, ?it/s]

Processing conversations: 100%|██████████| 1/1 [08:43<00:00, 523.60s/it]


In [None]:
# Create new dataset
new_dataset = Dataset.from_list(processed_data)
print(f"Created dataset with {len(new_dataset)} conversations")

# Save the processed dataset
new_dataset.save_to_disk("qwq_generated_dataset")

# Clear the generation model to free memory for training
del model
torch.cuda.empty_cache()
gc.collect()
