In [None]:
!pip install transformers
!pip install accelerate

In [1]:

import torch
import json
from tqdm import tqdm
from datetime import datetime
from pathlib import Path
import random
from templates_and_instructions import templates, instructions

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hub identifier of the checkpoint; also recorded in every result record and
# used to build the default output filename.
model_name = "PKU-Alignment/alpaca-7b-reproduced"

# Load weights in bfloat16 and let `device_map='auto'` decide device placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

# Prompts are generated in padded batches further down; padding on the left
# keeps the end of every prompt directly before the newly generated tokens.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side='left',
)

In [4]:
def collect_model_responses(model, tokenizer, instructions, templates, model_name,
                           num_samples, batch_size, output_filepath=None,
                           generation_params=None, random_seed=None):
    """
    Collect model responses for instruction categories × templates × num_samples.

    For every (category, template) pair, `num_samples` instructions are drawn
    (with replacement) from that category's variations, rendered through the
    template, and sent to the model in batches. One JSON line per response is
    appended to the output file; run-level metadata is written to a sibling
    JSON file before generation starts.

    Args:
        model: The loaded model
        tokenizer: The loaded tokenizer
        instructions: Dict mapping category names to lists of instruction variations
        templates: Dict of template_name -> Jinja2 Template objects; each is
            rendered with `messages`, `add_generation_prompt` and `eos_token`
        model_name: Name of the model
        num_samples: Number of samples per category-template combination (>= 0)
        batch_size: Batch size for generation (>= 1)
        output_filepath: Path to save results. Defaults to a timestamped
            `./results/<model_name>_<timestamp>.jsonl`. An existing file at
            this path is deleted before writing.
        generation_params: Parameters for generation, forwarded as keyword
            arguments to `generate_response`. Defaults to
            max_new_tokens=100, temperature=0.7, do_sample=True.
        random_seed: Random seed for reproducible sampling (optional); seeds
            both `random` and `torch`.

    Returns:
        Tuple `(output_filepath, metadata_path)` of the JSONL results file and
        the metadata JSON file.

    Raises:
        ValueError: If `batch_size` is smaller than 1 or `num_samples` is negative.
    """

    # Validate before touching the filesystem: a zero batch size would otherwise
    # fail only after an existing results file was deleted, and a negative one
    # would make the generation loop silently produce nothing.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples!r}")

    if random_seed is not None:
        random.seed(random_seed)
        torch.manual_seed(random_seed)

    if generation_params is None:
        generation_params = {
            'max_new_tokens': 100,
            'temperature': 0.7,
            'do_sample': True
        }

    #==== DEFINE RESULTS FILE ====#
    if output_filepath is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Path separators in hub-style names ("org/model") would create subdirectories.
        model_name_clean = model_name.replace('/', '_').replace('\\', '_')
        output_filepath = Path(f"./results/{model_name_clean}_{timestamp}.jsonl")

    output_filepath = Path(output_filepath)
    output_filepath.parent.mkdir(parents=True, exist_ok=True)

    # Records are appended one by one below, so start from a clean file.
    if output_filepath.exists():
        output_filepath.unlink()

    #==== PREPARE METADATA ====#
    total_combinations = len(instructions) * len(templates)
    total_responses = total_combinations * num_samples

    # Filled in place while prompts are prepared; `metadata` holds a reference to it.
    sampled_instructions = {}

    metadata = {
        'model_name': model_name,
        'num_samples': num_samples,
        'generation_params': generation_params,
        'batch_size': batch_size,
        'random_seed': random_seed,
        'total_combinations': total_combinations,
        'total_responses': total_responses,
        'instruction_categories_count': len(instructions),
        'templates_count': len(templates),
        'instruction_categories': list(instructions.keys()),
        'templates': list(templates.keys()),
        'timestamp': datetime.now().isoformat(),
        'output_file': str(output_filepath),
        'sampled_instructions': sampled_instructions
    }

    #==== PREPARE PROMPTS ====#
    prompts = []
    prompt_metadata = []

    with tqdm(total=total_responses, desc="Preparing prompts") as pbar:
        for category_name, instruction_variations in instructions.items():
            for template_name, template in templates.items():
                sampled_for_combo = []
                for sample_idx in range(num_samples):
                    selected_instruction = random.choice(instruction_variations)
                    sampled_for_combo.append(selected_instruction)

                    messages = [{"role": "user", "content": selected_instruction}]
                    formatted_prompt = template.render(
                        messages=messages,
                        add_generation_prompt=True,
                        eos_token=tokenizer.eos_token
                    )

                    prompts.append(formatted_prompt)
                    prompt_metadata.append({
                        'category': category_name,
                        'instruction': selected_instruction,
                        'template_name': template_name,
                        'sample_idx': sample_idx,
                        'formatted_prompt': formatted_prompt
                    })

                    pbar.update(1)

                combo_key = f"{category_name}-{template_name}"
                sampled_instructions[combo_key] = sampled_for_combo

    # Persist metadata before the (slow) generation phase so an interrupted run
    # still leaves a description of what was being generated.
    metadata_path = write_metadata_json(output_filepath, metadata)
    print(f"Metadata saved to {metadata_path}")

    total_prompts = len(prompts)
    print(f"Generating {total_prompts} responses using batch_size={batch_size}...")
    print(f"Writing responses to {output_filepath}")

    #==== PROCESS PROMPTS ====#
    with tqdm(total=total_prompts, desc="Generating responses") as pbar:
        for i in range(0, total_prompts, batch_size):
            batch_prompts = prompts[i:i+batch_size]
            batch_metadata = prompt_metadata[i:i+batch_size]

            full_responses = generate_response(model, tokenizer, batch_prompts, **generation_params)

            # `prompt_info` (not `metadata`) so the run-level metadata dict is not shadowed.
            for response, prompt_info in zip(full_responses, batch_metadata):
                record = {
                    'model_name': model_name,
                    'category': prompt_info['category'],
                    'instruction': prompt_info['instruction'],
                    'template': prompt_info['template_name'],
                    'sample_idx': prompt_info['sample_idx'],
                    'full_response': response,
                    'formatted_prompt': prompt_info['formatted_prompt'],
                }

                # Written immediately so completed work survives a later failure.
                write_response_jsonl(output_filepath, record)

            pbar.update(len(batch_prompts))

    print(f"All responses saved to {output_filepath}")
    print(f"Total combinations: {len(instructions)} categories × {len(templates)} templates × {num_samples} samples = {total_prompts} responses")
    return output_filepath, metadata_path


def generate_response(model, tokenizer, prompts: list[str], max_new_tokens, temperature, do_sample,
                      repetition_penalty=1.1, max_length=2048):
    """
    Generate one decoded output string per prompt in a single padded batch.

    Args:
        model: Object exposing `.device` and `.generate(**kwargs)`.
        tokenizer: Callable tokenizer exposing `pad_token_id`, `eos_token_id`
            and `batch_decode`.
        prompts: Prompt strings to process together as one batch.
        max_new_tokens: Forwarded to `model.generate`.
        temperature: Sampling temperature; only forwarded when `do_sample` is true.
        do_sample: Forwarded to `model.generate`.
        repetition_penalty: Forwarded to `model.generate` (default 1.1).
        max_length: Prompts are truncated to this many tokens when tokenized
            (default 2048).

    Returns:
        List of strings from `tokenizer.batch_decode(..., skip_special_tokens=True)`,
        one per prompt. NOTE(review): the whole generated sequence is decoded,
        which for causal LMs presumably includes the prompt text — confirm.
    """
    # Nothing to tokenize or generate for an empty batch.
    if not prompts:
        return []

    inputs = tokenizer(prompts, return_tensors="pt", truncation=True,
                       max_length=max_length, padding=True, return_attention_mask=True)
    # Move every tensor to the device the model reports.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    generate_kwargs = {
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        'pad_token_id': tokenizer.pad_token_id,
        'eos_token_id': tokenizer.eos_token_id,
        'repetition_penalty': repetition_penalty,
    }
    # Temperature only affects sampling; passing it with greedy decoding changes
    # nothing and makes transformers warn about an unused generation flag.
    if do_sample:
        generate_kwargs['temperature'] = temperature

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(**inputs, **generate_kwargs)

    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def write_response_jsonl(filepath, record):
    """Append `record` to `filepath` as one UTF-8 JSON line (non-ASCII kept unescaped)."""
    with open(filepath, mode='a', encoding='utf-8') as sink:
        json_line = json.dumps(record, ensure_ascii=False)
        sink.write(f"{json_line}\n")


def write_metadata_json(filepath, metadata):
    """
    Write `metadata` as indented UTF-8 JSON next to the results file.

    Args:
        filepath: Path of the results file, as a `str` or `Path`. The metadata
            file uses the same location with the final suffix replaced by
            `.metadata.json` (e.g. `run.jsonl` -> `run.metadata.json`).
        metadata: JSON-serializable object to store.

    Returns:
        `Path` of the metadata file that was written (overwriting any existing file).
    """
    # Coerce first so plain string paths work, as they already do for
    # write_response_jsonl; `with_suffix` exists only on Path objects.
    metadata_path = Path(filepath).with_suffix('.metadata.json')
    with open(metadata_path, 'w', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    return metadata_path


In [None]:
# Run the full sweep: 100 sampled instructions per (category, template) pair,
# generated in batches of 16, with a fixed seed so the sampling is repeatable.
# `results` is the (output_filepath, metadata_path) pair returned by the call.
results = collect_model_responses(
    model=model,
    tokenizer=tokenizer,
    instructions=instructions,
    templates=templates,
    model_name=model_name,
    num_samples=100,
    batch_size=16,
    random_seed=26,
)