# Assignment 1: Classical Adversarial Thinking

This assignment builds on the concepts from Recitation 1. You will implement your own collision attacks and defense mechanisms.

**Prerequisites**: Complete Recitation 1 first to understand hash functions, collision attacks, and mitigation strategies.

**Instructions**
1. Complete all exercises below
2. Submit your completed notebook (.ipynb file)
3. Include a brief report (1-2 pages) analyzing your findings and discussing the security implications


## Setup Code

Run the cell below to import the required libraries and define the helper classes from Recitation 1.

In [None]:
# Setup: Import libraries and helper functions from Recitation 1
import hashlib
import time
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict
import random
import string

# Helper functions from Recitation 1
class SimpleHashFunctions:
    @staticmethod
    def sum_ascii(s, table_size=1000):
        return sum(ord(c) for c in s) % table_size
    
    @staticmethod
    def polynomial_rolling(s, table_size=1000, base=31):
        hash_value = 0
        for char in s:
            hash_value = (hash_value * base + ord(char)) % table_size
        return hash_value
    
    @staticmethod
    def djb2_hash(s, table_size=1000):
        hash_value = 5381
        for char in s:
            hash_value = ((hash_value << 5) + hash_value + ord(char)) % table_size
        return hash_value
    
    @staticmethod
    def custom_weak_hash(s, table_size=1000):
        # Explicit parentheses: the conditional applies to the whole sum, not just ord(s[0])
        return ((len(s) + ord(s[0])) if s else 0) % table_size

class HashTable:
    def __init__(self, size=100, hash_function=None):
        self.size = size
        self.table = [[] for _ in range(size)]
        self.hash_function = hash_function or SimpleHashFunctions.sum_ascii
        self.collision_count = 0
        self.total_operations = 0
        
    def _hash(self, key):
        return self.hash_function(str(key), self.size)
    
    def insert(self, key, value):
        self.total_operations += 1
        index = self._hash(key)
        bucket = self.table[index]
        
        for i, (k, v) in enumerate(bucket):
            if k == key:
                bucket[i] = (key, value)
                return
        
        if len(bucket) > 0:
            self.collision_count += 1
        bucket.append((key, value))
    
    def get(self, key):
        self.total_operations += 1
        index = self._hash(key)
        bucket = self.table[index]
        
        for k, v in bucket:
            if k == key:
                return v
        raise KeyError(f"Key '{key}' not found")
    
    def get_statistics(self):
        bucket_lengths = [len(bucket) for bucket in self.table]
        non_empty_buckets = sum(1 for length in bucket_lengths if length > 0)
        max_chain_length = max(bucket_lengths) if bucket_lengths else 0
        avg_chain_length = sum(bucket_lengths) / non_empty_buckets if non_empty_buckets > 0 else 0
        
        return {
            'total_items': sum(bucket_lengths),
            'collision_count': self.collision_count,
            'collision_rate': self.collision_count / self.total_operations if self.total_operations > 0 else 0,
            'non_empty_buckets': non_empty_buckets,
            'max_chain_length': max_chain_length,
            'avg_chain_length': avg_chain_length,
            'load_factor': sum(bucket_lengths) / self.size
        }

class SecureHashFunctions:
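    def __init__(self, salt=None):
        # Pick a random per-instance salt unless the caller supplies one
        self.salt = salt if salt is not None else random.randint(1, 1000000)
    
    def cryptographic_hash(self, s, table_size=1000):
        # Mix the salt into the input so bucket indices cannot be precomputed offline
        hash_object = hashlib.sha256(f"{self.salt}:{s}".encode())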
        hash_bytes = hash_object.digest()[:8]
        hash_int = int.from_bytes(hash_bytes, byteorder='big')
        return hash_int % table_size

print("✅ Setup complete! Ready to start the exercises.")


## Exercise 1: Hash Function Quality Analysis

When evaluating a hash function, a key question is how evenly it spreads its inputs across the output range. A good hash function should produce an approximately uniform distribution, meaning that each possible hash value is equally likely.

### Your Task
Implement a chi-square test to statistically analyze the uniformity of different hash functions. The chi-square test compares observed frequencies with expected frequencies.

**Chi-square formula**: 

$$\chi^2 = \sum_{i=1}^{n} \frac{(O_i - E_i)^2}{E_i}$$

Where:
- $O_i$ = observed frequency for bucket $i$
- $E_i$ = expected frequency for bucket $i$  
- $n$ = number of buckets
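
To make the formula concrete, here is a minimal sketch that evaluates it on two hand-picked sets of bucket counts. The counts, the helper name `chi_square_statistic`, and the constant name are illustrative assumptions and are not part of the assignment code; the critical value 16.919 is the one for 10 buckets (9 degrees of freedom) at α = 0.05.

```python
def chi_square_statistic(observed, expected):
    """Sum of (O_i - E_i)^2 / E_i over all buckets (hypothetical helper)."""
    return sum((o - expected) ** 2 / expected for o in observed)

CRITICAL_VALUE_DF9 = 16.919  # alpha = 0.05, 9 degrees of freedom (10 buckets)

# Made-up bucket counts: 100 samples over 10 buckets, so E_i = 10
balanced = [12, 8, 11, 9, 10, 13, 7, 10, 9, 11]
skewed = [30, 5, 5, 5, 5, 10, 10, 10, 10, 10]
expected = sum(balanced) / len(balanced)

for label, counts in [("balanced", balanced), ("skewed", skewed)]:
    stat = chi_square_statistic(counts, expected)
    verdict = "consistent with uniform" if stat <= CRITICAL_VALUE_DF9 else "not uniform"
    print(f"{label}: chi-square = {stat:.2f} ({verdict})")
```

The balanced counts give a statistic of about 3, well below the critical value, while the skewed counts give about 50 and are rejected as non-uniform.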


In [None]:
def chi_square_uniformity_test(hash_function, table_size, num_samples=1000):
    """
    Test if a hash function produces uniform distribution using chi-square test
    
    TODO: Complete this function
    Args:
        hash_function: The hash function to test
        table_size: Size of hash table (number of buckets)
        num_samples: Number of random strings to generate and test
        
    Returns:
        dict with results including chi-square statistic and uniformity assessment
    """
    
    # Step 1: Initialize frequency counter for each hash bucket
    frequencies = [0] * table_size
    
    # Step 2: Generate random strings and count hash frequencies
    # TODO: Generate num_samples random strings, hash them, and count frequencies
    
    # TODO: Implement your solution here
    
    # Step 3: Calculate expected frequency (should be equal for uniform distribution)
    expected_frequency = num_samples / table_size
    
    # Step 4: Calculate chi-square statistic
    # TODO: Implement chi-square formula: χ² = Σ((O_i - E_i)² / E_i)
    
    # TODO: Implement your solution here
    
    # Step 5: Determine if distribution appears uniform
    # Critical value for α=0.05 with (table_size-1) degrees of freedom
    # For table_size=10, critical value ≈ 16.919
    # If chi-square > critical value, distribution is NOT uniform
    
    # TODO: Implement your solution here
    
    return {
        'chi_square_stat': chi_square_stat,
        'critical_value': critical_value,
        'is_uniform': is_uniform,
        'frequencies': frequencies,
        'expected_frequency': expected_frequency,
        'num_samples': num_samples
    }

## Testing Exercise 1

Run this cell after you have completed Exercise 1 above.

In [None]:
print("🧪 Exercise 1: Testing Hash Function Uniformity")
print("=" * 50)

hash_functions_to_test = [
    ("Sum ASCII", SimpleHashFunctions.sum_ascii),
    ("Polynomial Rolling", SimpleHashFunctions.polynomial_rolling),
    ("DJB2", SimpleHashFunctions.djb2_hash),
    ("Weak Hash", SimpleHashFunctions.custom_weak_hash)
]

# Test each hash function
results = {}
for name, func in hash_functions_to_test:
    print(f"\nTesting {name}...")
    result = chi_square_uniformity_test(func, table_size=10, num_samples=1000)
    results[name] = result
    
    print(f"  Chi-square statistic: {result['chi_square_stat']:.3f}")
    print(f"  Critical value: {result['critical_value']:.3f}")
    print(f"  Is uniform: {'✅ Yes' if result['is_uniform'] else '❌ No'}")
    print(f"  Max frequency: {max(result['frequencies'])}")
    print(f"  Min frequency: {min(result['frequencies'])}")

# Visualize the bucket distributions of all hash functions side by side
print("\n📊 Creating visualization...")
plt.figure(figsize=(16, 4))

for i, (name, func) in enumerate(hash_functions_to_test):
    plt.subplot(1, len(hash_functions_to_test), i+1)
    result = results[name]
    
    # Bar chart of frequencies (one bar per bucket, whatever the table size)
    bars = plt.bar(range(len(result['frequencies'])), result['frequencies'], alpha=0.7)
    
    # Add expected frequency line
    plt.axhline(y=result['expected_frequency'], color='red', linestyle='--', 
                linewidth=2, label=f"Expected ({result['expected_frequency']:.1f})")
    
    # Color bars based on deviation from expected
    for j, bar in enumerate(bars):
        deviation = abs(result['frequencies'][j] - result['expected_frequency'])
        if deviation > result['expected_frequency'] * 0.3:  # More than 30% deviation
            bar.set_color('orange')
    
    plt.title(f'{name}\nχ² = {result["chi_square_stat"]:.2f}\n{"Uniform" if result["is_uniform"] else "Non-uniform"}')
    plt.xlabel('Hash Bucket')
    plt.ylabel('Frequency')
    plt.ylim(0, max([max(r['frequencies']) for r in results.values()]) * 1.1)
    
    if i == 0:
        plt.legend()

plt.tight_layout()
plt.show()

## Exercise 2: Polynomial Hash Collision Attack

The polynomial rolling hash function computes:

$$h(s) = \left(\sum_{i=1}^{n} c_i \cdot b^{n-i}\right) \bmod T$$

Which can be written as:
$$h(s) = (c_1 \cdot b^{n-1} + c_2 \cdot b^{n-2} + \ldots + c_n \cdot b^0) \bmod T$$

Where:
- $c_i$ = ASCII value of character $i$
- $b$ = the base, typically 31
- $n$ = string length
- $s$ = input string
- $T$ = the table size
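
As a quick sanity check of the formula, the sketch below evaluates it in two ways: directly from the summation, and one character at a time with a reduction mod $T$ at each step (the form used by `polynomial_rolling`). The function names `poly_hash_closed_form` and `poly_hash_horner` are hypothetical and exist only for this illustration.

```python
def poly_hash_closed_form(s, base=31, table_size=1000):
    """Evaluate h(s) directly from the summation formula."""
    n = len(s)
    total = sum(ord(c) * base ** (n - i) for i, c in enumerate(s, start=1))
    return total % table_size

def poly_hash_horner(s, base=31, table_size=1000):
    """Evaluate h(s) one character at a time, reducing mod T at each step."""
    h = 0
    for c in s:
        h = (h * base + ord(c)) % table_size
    return h

# Worked example for "ab": c_1 = 97, c_2 = 98, so 97 * 31 + 98 = 3105
print(poly_hash_closed_form("ab"))  # 3105 mod 1000
print(poly_hash_horner("ab"))       # same value, computed iteratively
```

Both forms agree because reducing mod $T$ after every step does not change the final residue, which is also why any attack can reason about the unreduced polynomial.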

Design and implement a systematic collision attack against this hash function. Don't just find random collisions; create a **strategic method** for generating many colliding inputs.

In [None]:
def create_polynomial_collision_attack(base=31, table_size=1000, target_collisions=10):
    """
    Create adversarial inputs that collide in polynomial rolling hash
    
    TODO: Implement your collision attack strategy
    
    The polynomial rolling hash computes: 
    h(s) = (c₁·base^(n-1) + c₂·base^(n-2) + ... + cₙ·base^0) mod table_size
    
    Your attack strategy should be systematic, not random!
    
    Args:
        base: Base used in polynomial hash (typically 31)
        table_size: Size of hash table
        target_collisions: Number of colliding strings to generate
        
    Returns:
        list: Strings that all hash to the same value using polynomial_rolling
    """
    
    print(f"🎯 Exercise 2: Polynomial Hash Collision Attack")
    print(f"Target: Generate {target_collisions} strings that collide")
    print(f"Hash function: polynomial_rolling(base={base}, table_size={table_size})")
    print("=" * 60)
    
    colliding_strings = []
    
    # TODO: Implement your attack strategy here
    # Some approaches to consider:
    
    # Strategy 1: Exploit modular arithmetic
    # If base^k ≡ 1 (mod table_size), characters k positions apart carry the same weight,
    # so swapping two such characters leaves the hash unchanged
    
    # Strategy 2: Character swapping for 2-char strings
    # For string "ab": h("ab") = (a·base + b) mod table_size
    # Find chars x,y such that (a·base + b) ≡ (x·base + y) (mod table_size)
    
    # Strategy 3: Length manipulation
    # A single character c hashes to ord(c) mod table_size; look for longer strings
    # whose polynomial value has the same residue
    
    # Strategy 4: Systematic character replacement for 3-char strings
    # For string "abc": h("abc") = (a·base² + b·base + c) mod table_size
    # Find string "xyz" where: a·base² + b·base + c ≡ x·base² + y·base + z (mod table_size)
    
    # Implement your chosen strategy below, appending each colliding string to colliding_strings:
    
    # TODO: Implement your solution here
    
    # Hint: Mathematical approach for specific patterns
    # TODO: Implement more sophisticated collision generation
    # For 2-character strings, solve: a·base + b ≡ target_hash (mod table_size)
    # For 3-character strings, solve: a·base² + b·base + c ≡ target_hash (mod table_size)
    # This gives you systematic ways to find colliding character combinations
    
    print(f"\n✅ Generated {len(colliding_strings)} colliding strings")
    return colliding_strings

colliding_strings = create_polynomial_collision_attack(base=31, table_size=100, target_collisions=10)

## Testing Exercise 2
Run this cell to test your changes from Exercise 2.

In [None]:
def demonstrate_attack_effectiveness(colliding_strings, table_size=50):
    """
    Demonstrate how the collision attack affects hash table performance    
    """
    
    print(f"\n🎭 Demonstrating Attack Effectiveness")
    print("=" * 50)
    
    # Create two hash tables: one with normal data, one with attack data
    normal_ht = HashTable(size=table_size, hash_function=SimpleHashFunctions.polynomial_rolling)
    attacked_ht = HashTable(size=table_size, hash_function=SimpleHashFunctions.polynomial_rolling)
    
    # Insert normal data
    normal_data = [f"normal_key_{i}" for i in range(len(colliding_strings))]
    for key in normal_data:
        normal_ht.insert(key, f"value_for_{key}")
    
    # Insert attack data using the colliding strings from the attack
    print(f"Inserting {len(colliding_strings)} attack strings that all collide...\n")
    
    for i, key in enumerate(colliding_strings):
        attacked_ht.insert(key, f"attack_value_{i}")
    
    # Compare statistics
    normal_stats = normal_ht.get_statistics()
    attack_stats = attacked_ht.get_statistics()
    
    print(f"Normal hash table:")
    print(f"  Collision rate: {normal_stats['collision_rate']:.2%}")
    print(f"  Max chain length: {normal_stats['max_chain_length']}")
    print(f"  Average chain length: {normal_stats['avg_chain_length']:.2f}")
    
    print(f"\nAttacked hash table:")
    print(f"  Collision rate: {attack_stats['collision_rate']:.2%}")
    print(f"  Max chain length: {attack_stats['max_chain_length']}")
    print(f"  Average chain length: {attack_stats['avg_chain_length']:.2f}")
    
    # TODO: Add timing comparison
    # TODO: Add visualization
    
    return normal_stats, attack_stats

print(f"\n📋 Attack Results:")
print("-" * 30)
for i, s in enumerate(colliding_strings):
    hash_val = SimpleHashFunctions.polynomial_rolling(s, 100, 31)
    print(f"{i+1:2d}. '{s}' -> {hash_val}")

# Verify all strings hash to the same value
hash_values = [SimpleHashFunctions.polynomial_rolling(s, 100, 31) for s in colliding_strings]
if len(set(hash_values)) == 1:
    print(f"\n✅ SUCCESS: All {len(colliding_strings)} strings hash to value {hash_values[0]}")
else:
    print(f"\n❌ FAILURE: Found {len(set(hash_values))} different hash values: {set(hash_values)}")

# Demonstrate the attack's effectiveness
_ = demonstrate_attack_effectiveness(colliding_strings, table_size=100)

## Exercise 3: Adaptive Hash Table

In the real world, systems need to detect and defend against hash collision attacks automatically. An adaptive hash table should monitor its performance and switch to more secure hash functions when under attack. Create a hash table that can detect collision attacks and automatically switch to a more secure hash function, then rehash all existing data.
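
One possible monitoring signal is the collision rate over only the most recent insertions, which reacts to a sudden spike faster than a lifetime average. The sketch below is a standalone illustration under that assumption; the class `RecentCollisionMonitor`, its window size, and its threshold are hypothetical and are not the required design for this exercise.

```python
from collections import deque

class RecentCollisionMonitor:
    """Hypothetical helper: tracks whether each of the last `window` inserts collided."""

    def __init__(self, window=20):
        # True means the insert landed in a non-empty bucket
        self.events = deque(maxlen=window)

    def record(self, collided):
        self.events.append(bool(collided))

    def rate(self):
        return sum(self.events) / len(self.events) if self.events else 0.0

    def is_suspicious(self, threshold=0.6):
        # Require a full window so a handful of early inserts cannot raise an alarm
        return len(self.events) == self.events.maxlen and self.rate() > threshold

monitor = RecentCollisionMonitor(window=5)
for collided in [False, False, True, True, True, True]:
    monitor.record(collided)
print(monitor.rate(), monitor.is_suspicious())
```

A bounded `deque` discards the oldest event automatically, so the rate always reflects the latest window. How to combine such a signal with the overall collision rate and a minimum operation count is left to your own design.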

In [None]:
class AdaptiveHashTable(HashTable):
    """
    Hash table that adapts its hash function when under attack
    
    TODO: Implement adaptive defense mechanism
    """
    
    def __init__(self, size=100, initial_hash_function=None):
        # Initialize base hash table
        super().__init__(size, initial_hash_function or SimpleHashFunctions.polynomial_rolling)
        
        # Defense system attributes
        self.secure_hasher = SecureHashFunctions()
        self.attack_detected = False
        self.collision_threshold = 0.6  # Switch when collision rate > 60%
        self.min_operations_before_check = 20  # Don't check too early
        self.defense_activated = False
        self.original_hash_function = self.hash_function
        
        # Performance tracking
        self.defense_activation_time = None
        self.operations_before_defense = 0
        self.operations_after_defense = 0
        
        print(f"🛡️  Adaptive hash table initialized")
        print(f"   Initial hash function: {self.original_hash_function.__name__}")
        print(f"   Collision threshold: {self.collision_threshold:.1%}")
        print(f"   Table size: {size}")
        
    def _detect_attack(self):
        """
        TODO: Implement attack detection logic
        
        Should return True if a collision attack is detected
        Consider:
        - Collision rate above threshold
        - Sudden spike in collisions  
        - Minimum number of operations to avoid false positives
        - Statistical significance of the collision pattern
        """
        # === YOUR CODE HERE ===
        # TODO: Implement your solution here
        
    def _switch_hash_function(self):
        """
        Switch to secure hash function and rehash all data
        
        Steps:
        1. Extract all current key-value pairs
        2. Switch to cryptographic hash function
        3. Clear the table
        4. Reinsert all data with new hash function
        5. Update performance metrics
        """
        print(f"🔄 ACTIVATING DEFENSE SYSTEM...")
        
        # Record timing
        switch_start_time = time.time()
        self.operations_before_defense = self.total_operations
        
        # Step 1: Extract all current data
        all_data = []
        for bucket in self.table:
            for key, value in bucket:
                all_data.append((key, value))
        
        print(f"   Extracted {len(all_data)} items from hash table")
        
        # Step 2: Switch to secure hash function
        self.hash_function = self.secure_hasher.cryptographic_hash
        print(f"   Switched to cryptographic hash function (SHA-256)")
        
        # Step 3: Clear the table and reset counters
        self.table = [[] for _ in range(self.size)]
        old_collision_count = self.collision_count
        self.collision_count = 0
        
        # Step 4: Reinsert all data with new hash function
        print(f"   Rehashing {len(all_data)} items...")
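        for key, value in all_data:
            # Insert without triggering detection (use parent method)
            super().insert(key, value)
        # Rehash inserts are internal bookkeeping, so do not count them as user operations
        self.total_operations = self.operations_before_defense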
        
        # Step 5: Update status and metrics
        self.defense_activated = True
        self.attack_detected = True
        self.defense_activation_time = time.time() - switch_start_time
        
        # Show results
        new_stats = self.get_statistics()
        print(f"✅ DEFENSE ACTIVATED SUCCESSFULLY")
        print(f"   Switch time: {self.defense_activation_time:.4f} seconds")
        print(f"   Old collision rate: {old_collision_count / self.operations_before_defense:.2%}")
        print(f"   New collision rate: {new_stats['collision_rate']:.2%}")
        print(f"   New max chain length: {new_stats['max_chain_length']}")
        
    def insert(self, key, value):
        """Insert with attack detection and adaptation"""        
        # Perform normal insertion
        super().insert(key, value)
        
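        # Check for attack after insertion (but not every time for performance).
        # Skip the check once the defense is active so the table is not rehashed repeatedly.
        if not self.defense_activated and self.total_operations % 5 == 0:  # Check every 5 operations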
            if self._detect_attack():
                self._switch_hash_function()
    
    def get_defense_report(self):
        """Get detailed report on defense system status"""
        
        stats = self.get_statistics()
        
        report = {
            'defense_activated': self.defense_activated,
            'attack_detected': self.attack_detected,
            'current_collision_rate': stats['collision_rate'],
            'collision_threshold': self.collision_threshold,
            'total_operations': self.total_operations,
            'operations_before_defense': self.operations_before_defense,
            'operations_after_defense': self.total_operations - self.operations_before_defense if self.defense_activated else 0,
            'defense_activation_time': self.defense_activation_time,
            'current_hash_function': 'Cryptographic (SHA-256)' if self.defense_activated else f'Original ({self.original_hash_function.__name__})',
            'max_chain_length': stats['max_chain_length'],
            'avg_chain_length': stats['avg_chain_length']
        }
        
        return report

## Testing Exercise 3
Run this cell to test your changes from Exercise 3.

In [None]:
def test_adaptive_defense(colliding_strings):
    """
    Test the adaptive defense system against collision attack    
    """
    
    print(f"\n🧪 Exercise 3: Testing Adaptive Defense System")
    print("=" * 60)
    
    # Create adaptive hash table
    adaptive_ht = AdaptiveHashTable(size=50)
    
    print(f"\n📝 Phase 1: Normal operations")
    # Insert some normal data first
    normal_keys = [f"normal_key_{i}" for i in range(15)]
    for key in normal_keys:
        adaptive_ht.insert(key, f"value_for_{key}")
    
    stats_before = adaptive_ht.get_statistics()
    print(f"Before attack - Collision rate: {stats_before['collision_rate']:.2%}")
    
    print(f"\n💥 Phase 2: Launching collision attack")
    # Now launch the collision attack
    attack_data = []
    for i in range(30):  # Insert many colliding items
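        for j, colliding_key in enumerate(colliding_strings[:5]):  # Use first 5 colliding strings
            # Share one suffix per round: appending an identical suffix to colliding
            # strings keeps their polynomial hashes equal, so the keys still collide
            attack_key = f"{colliding_key}_attack_{i}"
            attack_data.append(attack_key)
            adaptive_ht.insert(attack_key, f"attack_value_{i}_{j}")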
            
            # Check if defense was activated
            if adaptive_ht.defense_activated:
                print(f"🛡️  Defense activated after {len(attack_data)} attack insertions")
                break
        if adaptive_ht.defense_activated:
            break
    
    print(f"\n📊 Phase 3: Final analysis")
    final_report = adaptive_ht.get_defense_report()
    
    print(f"Defense System Report:")
    print(f"  Attack detected: {'✅ Yes' if final_report['attack_detected'] else '❌ No'}")
    print(f"  Defense activated: {'✅ Yes' if final_report['defense_activated'] else '❌ No'}")
    print(f"  Final collision rate: {final_report['current_collision_rate']:.2%}")
    print(f"  Operations before defense: {final_report['operations_before_defense']}")
    print(f"  Operations after defense: {final_report['operations_after_defense']}")
    activation_time = final_report['defense_activation_time']
    print(f"  Defense activation time: {activation_time:.4f}s" if activation_time is not None else "  Defense activation time: N/A")
    print(f"  Current hash function: {final_report['current_hash_function']}")
    print(f"  Max chain length: {final_report['max_chain_length']}")
    
    # TODO: Add performance comparison
    # TODO: Add visualization of before/after
    
    return adaptive_ht, final_report

# Test the adaptive defense system
if 'colliding_strings' in locals():
    adaptive_table, defense_report = test_adaptive_defense(colliding_strings)
else:
    print("⚠️  Run Exercise 2 first to generate colliding strings for testing")

## Submission Checklist

- [ ] Completed all three exercises with working code
- [ ] Notebook runs without errors
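- [ ] Code is well-commented and readable
- [ ] Included the brief report (1-2 pages) analyzing your findings and their security implications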

**Good luck with your implementation! 🚀**
