In [None]:
from typing import List, Dict, Any
import json
import os
import gc
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Fixed seed for PyTorch's global RNG. Generation below samples with
# temperature/top_p, so seeding keeps runs as repeatable as the backend allows.
SEED = 11037
torch.manual_seed(SEED)

In [None]:
# Sampling settings used when a LanguageModel is created without explicit params;
# also the per-key fallback when a supplied params dict omits a key.
default_params = {"temperature": 0.7, "top_p": 0.9}


class LanguageModel:
    """Lazily loaded wrapper around a Hugging Face causal LM and its tokenizer.

    The weights are fetched on the first `load()` / `call()` and can be released
    again with `unload()`, so several models can take turns on the same device.
    """

    # Token id that `call()` treats as the end of the reasoning segment when
    # thinking is enabled; everything up to and including its last occurrence is
    # dropped from the reply.
    # NOTE(review): presumably Qwen3's `</think>` token - confirm against the
    # tokenizer before pointing this class at another model family.
    THINK_END_TOKEN_ID = 151668

    def __init__(self,
                 model_name: str,
                 enable_thinking: bool = False,
                 params: dict = None) -> None:
        """Store the configuration; nothing is downloaded or loaded here.

        Args:
            model_name: Identifier passed to `from_pretrained`.
            enable_thinking: Forwarded to the chat template; also raises the
                generation budget and enables stripping of the reasoning part.
            params: Sampling settings (`temperature`, `top_p`). Missing keys
                fall back to `default_params`; `None` means all defaults.
        """
        self.model_name = model_name
        self.enable_thinking = enable_thinking
        # Copy so later mutation of the caller's dict cannot change this model.
        self.params = default_params.copy() if params is None else dict(params)
        self.tokenizer = None
        self.model = None
        self._loaded = False

    def load(self) -> None:
        """Load tokenizer and model weights; a no-op if already loaded."""
        if self._loaded:
            return

        print(f"  Loading model: {self.model_name} (thinking={'enabled' if self.enable_thinking else 'disabled'})")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map="auto",
        )
        self._loaded = True
        print(f"  Model loaded successfully")

    def unload(self) -> None:
        """Drop the model and tokenizer references and release cached GPU memory."""
        if not self._loaded:
            return

        print(f"  Unloading model: {self.model_name}")

        self.model = None
        self.tokenizer = None
        self._loaded = False

        # Collect first so the tensors are actually freed before the CUDA
        # caching allocator is asked to hand its blocks back.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"  Model unloaded, GPU memory freed")

    def update_params(self, params: dict) -> None:
        """Replace the sampling settings; omitted keys fall back to `default_params`."""
        self.params = default_params.copy() if params is None else dict(params)

    def _param(self, key: str):
        """Return sampling setting `key`, falling back to `default_params`."""
        return self.params.get(key, default_params[key])

    def _answer_start(self, output_ids: list) -> int:
        """Index just past the last THINK_END_TOKEN_ID, or 0 if it is absent."""
        try:
            return len(output_ids) - output_ids[::-1].index(self.THINK_END_TOKEN_ID)
        except ValueError:
            return 0

    def call(self, prompt: str) -> str:
        """Send `prompt` as a single user message and return the decoded reply.

        Loads the model on first use. With thinking enabled, only the text
        after the last end-of-thinking token is returned.
        """
        if not self._loaded:
            self.load()

        messages = [{"role": "user", "content": prompt}]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=self.enable_thinking,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)

        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=8192 if self.enable_thinking else 2048,
            # temperature/top_p only take effect when sampling is on; state it
            # explicitly instead of relying on the model's generation config.
            do_sample=True,
            temperature=self._param("temperature"),
            top_p=self._param("top_p"),
        )

        # generate() returns prompt + completion; keep only the completion.
        prompt_length = len(model_inputs.input_ids[0])
        output_ids = generated_ids[0][prompt_length:].tolist()

        answer_start = self._answer_start(output_ids) if self.enable_thinking else 0
        return self.tokenizer.decode(output_ids[answer_start:], skip_special_tokens=True).strip("\n")


In [None]:
class DistillationResult:
    """Record holding one target's distilled analysis.

    Carries the summary fields requested from the distillation model plus the
    raw model response, and converts to/from a plain dict.
    """

    def __init__(self,
                 name: str,
                 description: str,
                 input_type: str,
                 example_inputs: List[str],
                 edge_cases: List[str],
                 maximize_coverage: str,
                 raw_distillation: str):
        """Store every argument unchanged on the instance."""
        self.name, self.description, self.input_type = name, description, input_type
        self.example_inputs, self.edge_cases = example_inputs, edge_cases
        self.maximize_coverage, self.raw_distillation = maximize_coverage, raw_distillation

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a dict keyed by attribute name."""
        field_names = (
            "name",
            "description",
            "input_type",
            "example_inputs",
            "edge_cases",
            "maximize_coverage",
            "raw_distillation",
        )
        return {field: getattr(self, field) for field in field_names}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "DistillationResult":
        """Build an instance from `data`, substituting a default for each absent key."""
        # Built per call so the list defaults are never shared between instances.
        fallbacks = {
            "name": "Unknown",
            "description": "",
            "input_type": "string",
            "example_inputs": [],
            "edge_cases": [],
            "maximize_coverage": "",
            "raw_distillation": "",
        }
        return cls(**{field: data.get(field, fallback) for field, fallback in fallbacks.items()})

In [None]:
class FuzzTargetDistillator:
    """Turns a group of source files into a structured fuzzing brief via an LLM."""

    # Maximum number of source characters embedded in the prompt; longer input
    # is cut off and marked as truncated.
    MAX_SOURCE_CHARS = 32768

    def __init__(self, language_model: "LanguageModel", base_dir: str):
        """
        Args:
            language_model: Object exposing `call(prompt) -> str`.
            base_dir: Directory that the relative `file_paths` are resolved against.
        """
        self.llm = language_model
        self.base_dir = base_dir

    def _read_files(self, file_paths: List[str]) -> str:
        """Concatenate the readable files, each under a `=== FILE: ... ===` header.

        Missing or unreadable files are reported and skipped, so the result is
        an empty string when nothing could be read.
        """
        file_contents = []

        for file_path in file_paths:
            full_path = os.path.join(self.base_dir, file_path)
            if not os.path.exists(full_path):
                print(f"  Warning: File not found, skipping: {file_path}")
                continue
            try:
                with open(full_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except Exception as e:
                # Best effort: one bad file must not abort the whole target.
                print(f"  Warning: Could not read {file_path}: {e}")
                continue
            file_contents.append(f"=== FILE: {file_path} ===\n{content}")

        return "\n\n".join(file_contents)

    @staticmethod
    def _as_text(value, default: str) -> str:
        """Coerce a parsed JSON value to a string (`default` for null)."""
        if isinstance(value, str):
            return value
        if value is None:
            return default
        if isinstance(value, (list, tuple)):
            return "\n".join(str(item) for item in value if item is not None)
        return json.dumps(value, ensure_ascii=False)

    @staticmethod
    def _as_str_list(value) -> List[str]:
        """Coerce a parsed JSON value to a list of strings (null -> [])."""
        if value is None:
            return []
        if isinstance(value, str):
            return [value]
        if not isinstance(value, (list, tuple)):
            value = [value]
        return [
            item if isinstance(item, str) else json.dumps(item, ensure_ascii=False)
            for item in value
            if item is not None
        ]

    @staticmethod
    def _extract_json_object(response: str):
        """Return the JSON object embedded in `response`, or None if there is none."""
        start = response.find('{')
        end = response.rfind('}') + 1
        if start < 0 or end <= start:
            return None
        try:
            data = json.loads(response[start:end])
        except json.JSONDecodeError as first_error:
            # Text after the object that itself contains '}' defeats the slice;
            # retry by decoding just the first complete value from `start`.
            try:
                data, _ = json.JSONDecoder().raw_decode(response, start)
            except json.JSONDecodeError:
                print(f"  Warning: Failed to parse JSON from distillation: {first_error}")
                return None
        return data if isinstance(data, dict) else None

    def _build_prompt(self, source_code: str) -> str:
        """Render the distillation prompt around `source_code`."""
        return f"""You are a code analysis expert. Your task is to analyze the following Go source code and prepare a CONCISE technical summary for a fuzzing input generator.

=== SOURCE CODE ===
{source_code}

=== YOUR TASK ===
Analyze this code and provide a structured JSON response with the following information:

1. **name**: A short descriptive name for this fuzzing target (e.g., "CEL Filter Parser", "Markdown Tag Extractor")

2. **description**: A CONCISE technical summary (max 500 words) that includes:
   - What the code does (main functionality)
   - Input format: exact syntax rules, grammar, allowed characters
   - How inputs are processed
   - Error handling patterns
   - Any validation or parsing logic

3. **input_type**: Either "string" or "bytes" - what type of input the main functions accept

4. **example_inputs**: Array of 5-10 valid input examples that would be accepted by the code, make them different to show variety of inputs

5. **edge_cases**: Array of 10-20 edge cases that might cause errors, panics, or unexpected behavior:
   - Boundary conditions (empty, very long, special chars)
   - Malformed inputs
   - Unicode edge cases
   - Security-relevant inputs (injections, bypasses)

6. **maximize_coverage**: Strictly and technically explain which cases would lead to getting into different parts of code and maximizing code coverage

Be CONCISE but TECHNICAL. Focus on details that help generate effective fuzz inputs.

IMPORTANT: Return ONLY valid JSON in this exact format:
{{
  "name": "...",
  "description": "...",
  "input_type": "string",
  "example_inputs": ["...", "..."],
  "edge_cases": ["...", "..."],
  "maximize_coverage": "..."
}}

No markdown, no explanations inside or outside JSON.
"""

    def distillate(self, target_key: str, file_paths: List[str]) -> "DistillationResult":
        """Analyze `file_paths` with the LLM and return the structured result.

        Args:
            target_key: Fallback name when the model supplies none.
            file_paths: Paths relative to `base_dir`.

        Returns:
            A DistillationResult. If no file could be read, a placeholder with
            the description "No source files found". If the reply holds no
            parsable JSON object, the raw reply becomes the description and the
            lists stay empty. Parsed fields are coerced to the declared types
            (str / list of str), so a JSON null or scalar cannot leak through.
        """
        source_code = self._read_files(file_paths)

        if not source_code:
            return DistillationResult(
                name=target_key,
                description="No source files found",
                input_type="string",
                example_inputs=[],
                edge_cases=[],
                maximize_coverage="",
                raw_distillation=""
            )

        if len(source_code) > self.MAX_SOURCE_CHARS:
            source_code = source_code[:self.MAX_SOURCE_CHARS] + "\n\n... (truncated)"

        response = self.llm.call(self._build_prompt(source_code))

        try:
            data = self._extract_json_object(response)
            if data is not None:
                return DistillationResult(
                    name=self._as_text(data.get("name"), target_key),
                    description=self._as_text(data.get("description"), ""),
                    input_type=self._as_text(data.get("input_type"), "string") or "string",
                    example_inputs=self._as_str_list(data.get("example_inputs")),
                    edge_cases=self._as_str_list(data.get("edge_cases")),
                    maximize_coverage=self._as_text(data.get("maximize_coverage"), ""),
                    raw_distillation=response
                )
        except Exception as e:
            print(f"  Warning: Error processing distillation: {e}")

        # No usable JSON: keep the raw reply as the description.
        return DistillationResult(
            name=target_key,
            description=response,
            input_type="string",
            example_inputs=[],
            edge_cases=[],
            maximize_coverage="",
            raw_distillation=response
        )

In [None]:
# Fuzz targets keyed by a short target name. Each entry's "files" list holds
# source paths relative to the pipeline's base_dir; together they form the code
# that is distilled for that target.
FUZZ_FILE_GROUPS = {
    "filter": dict(files=["plugin/filter/parser.go", "plugin/filter/engine.go"]),
    "url": dict(files=["plugin/httpgetter/html_meta.go"]),
    "email": dict(files=["internal/util/util.go"]),
    "uid": dict(files=["internal/base/resource_name.go"]),
}

In [None]:
class TargetedFuzzer:
    """Grows a fuzzing corpus for one target by repeatedly prompting an LLM."""

    # Per-iteration focus appended to the prompt; iterations beyond the table
    # use DEFAULT_GUIDANCE.
    ITERATION_GUIDANCE = {
        1: "Focus on: basic valid inputs, simple variations, common patterns",
        2: "Focus on: maximizing code coverage based on code structure analysis",
        3: "Focus on: Unicode characters, special symbols, escape sequences",
        4: "Focus on: malformed inputs, syntax errors, invalid combinations",
        5: "Focus on: security payloads, injection attempts, bypass patterns",
        6: "Focus on: boundary conditions, empty/null inputs, very long inputs",
    }
    DEFAULT_GUIDANCE = "Focus on: creative edge cases not yet covered"

    # How many of the most recent corpus entries are shown back to the model.
    MAX_SHOWN_INPUTS = 50

    def __init__(self, language_model: "LanguageModel"):
        """
        Args:
            language_model: Object exposing `call(prompt) -> str`.
        """
        self.llm = language_model

    def _flatten_inputs(self, inputs) -> List[str]:
        """Flatten a parsed JSON value into a flat list of strings.

        Nested lists are expanded, dicts are JSON-encoded, bytes are decoded
        as UTF-8 with replacement, None items are dropped and anything else is
        passed through `str()`. A non-list argument yields a one-element list.
        """
        if not isinstance(inputs, list):
            return [inputs] if isinstance(inputs, str) else [str(inputs)]

        result = []
        for item in inputs:
            if isinstance(item, list):
                result.extend(self._flatten_inputs(item))
            elif isinstance(item, dict):
                result.append(json.dumps(item, ensure_ascii=False))
            elif isinstance(item, bytes):
                result.append(item.decode('utf-8', errors='replace'))
            elif isinstance(item, str):
                result.append(item)
            elif item is not None:
                result.append(str(item))
        return result

    @staticmethod
    def _parse_json_array(response: str):
        """Return the JSON array embedded in `response`, or None if unparsable."""
        start = response.find('[')
        end = response.rfind(']') + 1
        if start < 0 or end <= start:
            return None
        try:
            return json.loads(response[start:end])
        except json.JSONDecodeError as first_error:
            # Trailing text containing ']' defeats the slice; retry by decoding
            # just the first complete value from `start`.
            try:
                parsed, _ = json.JSONDecoder().raw_decode(response, start)
                return parsed
            except json.JSONDecodeError:
                print(f"    Warning: JSON parse error: {first_error}")
                return None

    @staticmethod
    def _add_unique(candidates, seen: set, ordered: list) -> int:
        """Append the not-yet-seen strings of `candidates`; return how many were added."""
        added = 0
        for candidate in candidates:
            if isinstance(candidate, str) and candidate not in seen:
                seen.add(candidate)
                ordered.append(candidate)
                added += 1
        return added

    def generate_inputs(self,
                        distillation: "DistillationResult",
                        existing_inputs: List[str],
                        num_inputs: int = 20,
                        iteration: int = 1) -> List[str]:
        """Ask the model for one batch of inputs and return them as strings.

        Args:
            distillation: Analysis of the target used to build the prompt.
            existing_inputs: Corpus so far; its tail is shown to the model.
            num_inputs: Number of inputs requested in the prompt.
            iteration: 1-based round number; selects the guidance line.

        Returns:
            The parsed, flattened inputs, or [] when the reply holds no
            parsable JSON array.
        """
        if existing_inputs:
            max_show = min(self.MAX_SHOWN_INPUTS, len(existing_inputs))
            recent_inputs = existing_inputs[-max_show:]
            existing_inputs_section = f"""
Previously generated inputs ({len(existing_inputs)} total, showing last {max_show}):
{json.dumps(recent_inputs, indent=2, ensure_ascii=False)}

IMPORTANT: Generate NEW and DIFFERENT inputs. Do not repeat the above. Don't try to copy patterns above.
"""
        else:
            existing_inputs_section = """
This is the first iteration. Start with diverse inputs.
"""

        guidance = self.ITERATION_GUIDANCE.get(iteration, self.DEFAULT_GUIDANCE)

        examples_section = ""
        if distillation.example_inputs:
            examples_section = f"""
=== VALID INPUT EXAMPLES (from code analysis) ===
{json.dumps(distillation.example_inputs, indent=2, ensure_ascii=False)}
"""

        edge_cases_section = ""
        if distillation.edge_cases:
            edge_cases_section = f"""
=== KNOWN EDGE CASES ===
{json.dumps(distillation.edge_cases, indent=2, ensure_ascii=False)}
"""

        # NOTE(review): iteration 2's guidance targets code coverage, yet this
        # section only appears from iteration 5 onwards - confirm the threshold
        # is intended.
        coverage_section = ""
        if distillation.maximize_coverage and iteration >= 5:
            coverage_section = f"""
=== CODE COVERAGE GUIDANCE ===
{distillation.maximize_coverage}
"""

        prompt = f"""You are a fuzzing expert generating test inputs.

=== TARGET ===
Name: {distillation.name}
Input type: {distillation.input_type}

=== TECHNICAL CONTEXT ===
{distillation.description}
{examples_section}
{edge_cases_section}
{coverage_section}

=== EXISTING CORPUS ===
{existing_inputs_section}

=== ITERATION {iteration} GUIDANCE ===
{guidance}

=== YOUR TASK ===
Generate {num_inputs} NEW and UNIQUE test inputs.

Requirements:
1. Each input must be DIFFERENT from previously generated ones
2. Cover new edge cases and patterns
3. Include both valid and invalid inputs
4. Be creative but relevant to the target

CRITICAL: Return ONLY a valid JSON array of strings:
["input1", "input2", "input3", ...]

No explanations, no markdown, just the JSON array.
"""

        response = self.llm.call(prompt)

        try:
            parsed = self._parse_json_array(response)
            if parsed is not None:
                return self._flatten_inputs(parsed)
        except Exception as e:
            print(f"    Warning: Error processing inputs: {e}")

        return []

    def generate_corpus(self,
                        distillation: "DistillationResult",
                        iterations: int = 6,
                        inputs_per_iter: int = 20) -> List[str]:
        """Build a de-duplicated corpus over several generation rounds.

        The corpus is seeded with the distillation's example inputs and edge
        cases, then extended once per iteration. A failing iteration is
        reported and skipped. Insertion order is preserved.
        """
        all_inputs_list = []
        all_inputs_set = set()

        # Tolerate a missing (None) or scalar seed field instead of failing on
        # list concatenation before any generation has happened.
        seed_inputs = []
        for seed_source in (distillation.example_inputs, distillation.edge_cases):
            if isinstance(seed_source, str):
                seed_inputs.append(seed_source)
            elif isinstance(seed_source, (list, tuple)):
                seed_inputs.extend(seed_source)
        self._add_unique(seed_inputs, all_inputs_set, all_inputs_list)

        print(f"  Starting with {len(all_inputs_list)} seed inputs from distillation")

        for i in range(iterations):
            print(f"  Iteration {i+1}/{iterations}...")
            try:
                new_inputs = self.generate_inputs(
                    distillation=distillation,
                    existing_inputs=all_inputs_list,
                    num_inputs=inputs_per_iter,
                    iteration=i + 1
                )
                added_count = self._add_unique(new_inputs, all_inputs_set, all_inputs_list)
                print(f"    Generated {len(new_inputs)} inputs, added {added_count} new, total: {len(all_inputs_list)}")
            except Exception as e:
                print(f"    Error in iteration {i+1}: {e}")
                continue

        return all_inputs_list


def save_corpus(corpus: List[str], target_key: str, output_dir: str = "corpus"):
    """Write each corpus entry to its own file under `output_dir/target_key`.

    Files are named `input_0000`, `input_0001`, ... in corpus order and are
    written in binary mode. Existing files with the same names are overwritten;
    other files already in the directory are left alone.

    Args:
        corpus: Entries to write. Strings are UTF-8 encoded, bytes-like values
            are written as-is, anything else is written as its `str()` form.
        target_key: Sub-directory name for this target.
        output_dir: Root directory of all corpora.
    """
    target_dir = os.path.join(output_dir, target_key)
    os.makedirs(target_dir, exist_ok=True)

    for i, inp in enumerate(corpus):
        if isinstance(inp, (bytes, bytearray, memoryview)):
            payload = bytes(inp)
        else:
            # surrogatepass: strings decoded from JSON may hold lone surrogates
            # (e.g. "\ud800"), which strict UTF-8 encoding rejects. Writing
            # their raw byte form keeps such inputs instead of aborting the save.
            text = inp if isinstance(inp, str) else str(inp)
            payload = text.encode('utf-8', errors='surrogatepass')

        file_path = os.path.join(target_dir, f"input_{i:04d}")
        with open(file_path, 'wb') as f:
            f.write(payload)

    print(f"Saved {len(corpus)} inputs to {target_dir}")


def run_two_phase_pipeline(base_dir: str,
                           file_groups: Dict[str, Dict] = None,
                           distill_model_name: str = "Qwen/Qwen3-1.7B",
                           gen_model_name: str = "Qwen/Qwen3-1.7B",
                           iterations: int = 6,
                           inputs_per_iter: int = 20,
                           output_dir: str = "corpus"):
    """Distill every target, then generate and save a corpus for each.

    Phase 1 loads the distillation model (thinking enabled), analyzes each
    file group and unloads the model. Phase 2 loads the generation model
    (thinking disabled), builds a corpus per target and writes it to disk.
    Only one model is loaded at a time, and each is unloaded even if its
    phase raises.

    Args:
        base_dir: Directory the groups' relative file paths are resolved against.
        file_groups: Mapping of target key to a dict with a "files" list;
            defaults to FUZZ_FILE_GROUPS.
        distill_model_name: Model identifier used in phase 1.
        gen_model_name: Model identifier used in phase 2.
        iterations: Generation rounds per target.
        inputs_per_iter: Inputs requested per round.
        output_dir: Root directory the corpora are saved under.

    Returns:
        `(results, distillation_results)`. `results[target]` holds either
        `{"distillation", "corpus_size", "corpus"}` or `{"error": message}`;
        `distillation_results[target]` is the DistillationResult used for it.
    """
    if file_groups is None:
        file_groups = FUZZ_FILE_GROUPS

    results = {}
    distillation_results = {}

    print("\n" + "="*60)
    print("PHASE 1: DISTILLATION")
    print("="*60)

    distill_model = LanguageModel(
        model_name=distill_model_name,
        enable_thinking=True,
        params={"temperature": 0.6, "top_p": 0.95}
    )
    try:
        distill_model.load()
        distillator = FuzzTargetDistillator(distill_model, base_dir)

        for target_key, config in file_groups.items():
            # A group without "files" degrades to an empty list rather than
            # aborting the run for every other target.
            file_paths = config.get("files", [])
            print(f"\n--- Distilling: {target_key} ---")
            print(f"  Files: {file_paths}")

            try:
                distillation = distillator.distillate(target_key, file_paths)
                distillation_results[target_key] = distillation

                print(f"  Name: {distillation.name}")
                print(f"  Input type: {distillation.input_type}")
                print(f"  Examples: {len(distillation.example_inputs)}")
                print(f"  Edge cases: {len(distillation.edge_cases)}")
                print(f"  Description: {len(distillation.description)} chars")

            except Exception as e:
                print(f"  Error: {e}")
                # Keep the target alive for phase 2 with a minimal placeholder.
                distillation_results[target_key] = DistillationResult(
                    name=target_key,
                    description=f"Fuzzing target for files: {file_paths}",
                    input_type="string",
                    example_inputs=[],
                    edge_cases=[],
                    maximize_coverage="",
                    raw_distillation=""
                )
    finally:
        # Release the first model before the second one is loaded, and also
        # when this phase fails part-way.
        distill_model.unload()

    print("\n" + "="*60)
    print("PHASE 2: CORPUS GENERATION")
    print("="*60)

    gen_model = LanguageModel(
        model_name=gen_model_name,
        enable_thinking=False,
        params={"temperature": 0.83, "top_p": 0.85}
    )
    try:
        gen_model.load()
        fuzzer = TargetedFuzzer(gen_model)

        for target_key in file_groups.keys():
            distillation = distillation_results[target_key]
            print(f"\n--- Generating corpus: {distillation.name} ({target_key}) ---")

            try:
                corpus = fuzzer.generate_corpus(
                    distillation=distillation,
                    iterations=iterations,
                    inputs_per_iter=inputs_per_iter
                )

                save_corpus(corpus, target_key, output_dir)

                results[target_key] = {
                    "distillation": distillation.to_dict(),
                    "corpus_size": len(corpus),
                    "corpus": corpus
                }
                print(f"  Generated {len(corpus)} unique inputs")

            except Exception as e:
                print(f"  Error: {e}")
                results[target_key] = {"error": str(e)}
    finally:
        gen_model.unload()

    return results, distillation_results

In [None]:
# Run both phases against the `memos` directory (relative to the working
# directory), then persist and print a summary of what was produced.
results, distillations = run_two_phase_pipeline(
    base_dir="memos",
    file_groups=FUZZ_FILE_GROUPS,
    distill_model_name="Qwen/Qwen3-4B",
    gen_model_name="Qwen/Qwen3-1.7B",
    iterations=6,
    inputs_per_iter=20
)

# Build the report data before opening any file, so a failure here cannot
# leave a truncated JSON file behind.
summary = {}
for target_key, outcome in results.items():
    if "error" in outcome:
        summary[target_key] = {"error": outcome["error"]}
    else:
        summary[target_key] = {
            "name": outcome["distillation"]["name"],
            "input_type": outcome["distillation"]["input_type"],
            "corpus_size": outcome["corpus_size"],
            "sample_inputs": outcome["corpus"][:20]
        }

distill_data = {target_key: distillation.to_dict() for target_key, distillation in distillations.items()}

# ensure_ascii=False emits raw non-ASCII text, so the encoding must be pinned
# rather than left to the platform default. backslashreplace writes any lone
# surrogate as a \uXXXX escape, which is still valid inside a JSON string.
with open("fuzzing_results.json", "w", encoding="utf-8", errors="backslashreplace") as f:
    json.dump(summary, f, indent=2, ensure_ascii=False)

with open("distillation_results.json", "w", encoding="utf-8", errors="backslashreplace") as f:
    json.dump(distill_data, f, indent=2, ensure_ascii=False)

print("\n" + "="*60)
print("SUMMARY")
print("="*60)

print("\nDistillation Results:")
for target_key, distillation in distillations.items():
    print(f"  {target_key}:")
    print(f"    Name: {distillation.name}")
    print(f"    Type: {distillation.input_type}")
    print(f"    Examples: {len(distillation.example_inputs)}, Edge cases: {len(distillation.edge_cases)}")

print("\nCorpus Generation:")
total_inputs = 0
for target, data in results.items():
    if "error" in data:
        print(f"  {target}: ERROR - {data['error']}")
    else:
        print(f"  {target}: {data['corpus_size']} inputs")
        total_inputs += data['corpus_size']

print(f"\nTotal inputs generated: {total_inputs}")
print("Corpus saved to: ./corpus/")
print("Results saved to: fuzzing_results.json")
print("Distillation saved to: distillation_results.json")