# Preventing AI Model Distillation

This demo explores techniques that make knowledge distillation of proprietary models harder, helping to protect model security and intellectual property. Each approach comes with a practical implementation.

### 1. Output Watermarking

This demo tags AI-generated text with a watermark derived from a keyed hash, so the text can later be checked for it. In this simple version the marker is a visible bracketed token, not a hidden one.

Features in this Code:
* Embeds a Keyed Marker in AI Responses
* Watermarks Can Be Extracted and Verified Later to Support Ownership Claims
* Changes the Text Only Minimally (One Short Token Is Inserted)

In [1]:
import hashlib
import random

class AIWatermarker:
    def __init__(self, secret_key="my_secret"):
        self.secret_key = secret_key  # Used for creating a unique watermark

    def generate_watermark(self, text):
        """Create a watermark from a hash of the whitespace-normalized text and the secret key."""
        normalized = " ".join(text.split())  # Embedding re-joins words with single spaces
        hash_obj = hashlib.sha256((normalized + self.secret_key).encode()).hexdigest()
        watermark = hash_obj[:6]  # Use the first 6 hex characters as the watermark
        return watermark

    def embed_watermark(self, text):
        """Inserts the watermark as a bracketed token at a random word position."""
        watermark = self.generate_watermark(text)
        words = text.split()
        pos = random.randint(0, len(words))  # Any gap, including the end; also safe for empty text
        words.insert(pos, f"[{watermark}]")  # Embed watermark in brackets
        return " ".join(words)

    def detect_watermark(self, text):
        """Extracts a bracketed token and verifies it against the hash of the remaining text."""
        words = text.split()
        for i, word in enumerate(words):
            if word.startswith("[") and word.endswith("]"):
                extracted = word[1:-1]
                remaining = " ".join(words[:i] + words[i + 1:])
                # Only report a match if the token is the correct keyed hash of the rest of the text
                if extracted == self.generate_watermark(remaining):
                    return f"Watermark Found: {extracted}"
        return "No Watermark Detected"

# === Demo Usage ===
ai_model = AIWatermarker()

# AI generates a response
generated_text = "AI security is crucial to protect models from extraction."
watermarked_text = ai_model.embed_watermark(generated_text)

# Simulated output usage
print("Watermarked AI Output:", watermarked_text)

# Check for watermark
verification = ai_model.detect_watermark(watermarked_text)
print(verification)


Watermarked AI Output: AI security is crucial to [6b9826] protect models from extraction.
Watermark Found: 6b9826
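The bracketed marker above is visible and easy to strip. One way to make it invisible to readers is to encode the same kind of 24-bit keyed tag (the first 6 hex characters of SHA-256 over text plus key) as zero-width Unicode characters, with U+200B for a 0 bit and U+200C for a 1 bit. This assumes the delivery channel preserves those characters. The function names `embed_invisible` and `verify_invisible` are hypothetical. This is a sketch, not a robust scheme.

```python
import hashlib

ZW0, ZW1 = "\u200b", "\u200c"  # zero-width space = bit 0, zero-width non-joiner = bit 1

def _tag_bits(text: str, key: str) -> str:
    """24-bit keyed tag: the first 6 hex characters of SHA-256(text + key), as a bit string."""
    hex6 = hashlib.sha256((text + key).encode()).hexdigest()[:6]
    return format(int(hex6, 16), "024b")

def embed_invisible(text: str, key: str) -> str:
    """Append the tag as zero-width characters; the visible text is unchanged."""
    return text + "".join(ZW1 if bit == "1" else ZW0 for bit in _tag_bits(text, key))

def verify_invisible(text: str, key: str) -> bool:
    """Split off the zero-width characters, recompute the tag on the visible text, and compare."""
    visible = "".join(ch for ch in text if ch not in (ZW0, ZW1))
    hidden = "".join("1" if ch == ZW1 else "0" for ch in text if ch in (ZW0, ZW1))
    return hidden == _tag_bits(visible, key)

original = "AI security is crucial to protect models from extraction."
marked = embed_invisible(original, "my_secret")
print(marked == original)                       # False: 24 hidden characters were appended
print(verify_invisible(marked, "my_secret"))    # True
print(verify_invisible(original, "my_secret"))  # False: no hidden tag present
```

Any Unicode normalization or paraphrase that drops the zero-width characters also removes the watermark.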


### 2. Legal and Policy Barriers

Legal and policy barriers embed legal disclaimers, licensing terms, and policy checks into AI responses to deter misuse and unauthorized use of the model. These barriers can:

* Prevent AI Misuse – Display legal warnings when detecting suspicious behavior.
* Deter Model Theft – Include licensing information in responses to enforce ownership.
* Enforce AI Ethics – Block harmful or policy-violating queries.

Honeypot Responses are trap responses designed to catch and flag bad actors trying to exploit the AI. These responses can:

* Identify Malicious Users – Trigger when someone asks about model extraction, adversarial attacks, or hacking.
* Track Unauthorized API Access – AI embeds unique markers to trace misuse.
* Provide Fake Data to Attackers – Serve decoy outputs so that real model behavior does not leak.

In [2]:
def detect_honeypot(response):
    """Checks for a specific phrase to identify unauthorized copies."""
    honeypot_phrases = ["UniquePatternXYZ123"]
    return any(phrase in response for phrase in honeypot_phrases)

generated_response = "This is an AI-generated response. UniquePatternXYZ123"
print(detect_honeypot(generated_response))  # Detects unauthorized use

True
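The check above uses one fixed phrase. It can show that a copy exists but not who leaked it.

To trace misuse back to an individual API client, as the "Track Unauthorized API Access" bullet describes, each client can receive its own canary token. This minimal sketch makes several assumptions:

* A server-side secret exists (`SERVER_SECRET`, hypothetical).
* Tagged responses are reused verbatim.
* The helper names are illustrative only.

```python
import hashlib
import hmac

SERVER_SECRET = b"server-side-secret"  # hypothetical secret, never sent to clients

def canary_for(user_id: str) -> str:
    """Derive a per-user canary token with HMAC-SHA256 (first 12 hex characters)."""
    digest = hmac.new(SERVER_SECRET, user_id.encode(), hashlib.sha256).hexdigest()
    return f"CANARY-{digest[:12]}"

def tag_response(user_id: str, response: str) -> str:
    """Attach the user's canary to a response before it leaves the API."""
    return f"{response} {canary_for(user_id)}"

def trace_leak(suspect_text: str, known_users: list[str]) -> list[str]:
    """Return every known user whose canary appears in the suspect text."""
    return [u for u in known_users if canary_for(u) in suspect_text]

users = ["alice", "bob", "mallory"]
leaked = tag_response("mallory", "This is an AI-generated response.")
print(trace_leak(leaked, users))  # ['mallory']
```

Because the canary is keyed with a secret, a client cannot forge another client's token. However, it can still strip its own token if it recognizes the pattern.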


In [3]:
import time

class SecureAIChatbot:
    def __init__(self):
        self.honeypot_queries = ["how to extract model", "bypass AI security", "steal AI responses"]
        self.legal_disclaimer = "⚠️ Legal Notice: Unauthorized use of AI-generated content may violate terms of service."

    def check_honeypot(self, query):
        """Detects if a user is trying to attack the AI system (case-insensitive)."""
        query = query.lower()
        # Lowercase the trap phrases too; "bypass AI security" would otherwise never match
        return any(honeypot.lower() in query for honeypot in self.honeypot_queries)

    def generate_response(self, query):
        """Handles AI responses with security checks."""
        if self.check_honeypot(query):
            # Log the suspicious query for monitoring (UTF-8 so the emoji can be written on any platform)
            with open("security_log.txt", "a", encoding="utf-8") as log:
                log.write(f"⚠️ Honeypot Triggered! Suspicious Query: '{query}' at {time.ctime()}\n")

            return "🚨 Security Alert: Your query has been flagged for review."

        # Apply legal disclaimer for sensitive topics
        if "copyright" in query.lower() or "legal" in query.lower():
            return f"{self.legal_disclaimer}\n\nAI cannot provide legal advice."

        # Normal AI response
        return f"🤖 AI Response: {query.capitalize()} is an interesting topic!"

# === Demo Usage ===
secure_ai = SecureAIChatbot()

# Example Queries
queries = [
    "Tell me about AI security",
    "How to extract model",
    "Legal implications of AI",
    "How to bypass AI security"
]

for q in queries:
    print(f"User: {q}")
    print(f"AI: {secure_ai.generate_response(q)}\n")


User: Tell me about AI security
AI: 🤖 AI Response: Tell me about ai security is an interesting topic!

User: How to extract model
AI: 🚨 Security Alert: Your query has been flagged for review.

User: Legal implications of AI
AI: ⚠️ Legal Notice: Unauthorized use of AI-generated content may violate terms of service.

AI cannot provide legal advice.

User: How to bypass AI security
AI: 🚨 Security Alert: Your query has been flagged for review.
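The chatbot above answers a flagged query with an explicit alert, which tells the attacker they were caught. The "Provide Fake Data to Attackers" bullet suggests an alternative: keep answering, but with decoys.

This minimal sketch makes several assumptions:

* Flagged users are already known (`flagged_users`).
* The protected model is a classifier, represented here by the hypothetical stand-in `real_model`.
* Decoy labels are derived from a hash of the query, so repeated queries get consistent answers that do not reveal the switch.

```python
import hashlib

LABELS = ["positive", "negative", "neutral"]
flagged_users = {"mallory"}  # e.g. users whose queries tripped a honeypot

def real_model(query: str) -> str:
    """Hypothetical stand-in for the protected classifier."""
    return "positive" if "good" in query.lower() else "negative"

def decoy_model(query: str) -> str:
    """Deterministic fake label from a hash of the query, independent of the real model."""
    h = int(hashlib.sha256(query.encode()).hexdigest(), 16)
    return LABELS[h % len(LABELS)]

def answer(user: str, query: str) -> str:
    """Serve decoys to flagged users and real predictions to everyone else."""
    return decoy_model(query) if user in flagged_users else real_model(query)

print(answer("alice", "A good film"))  # positive
print(answer("mallory", "A good film") == answer("mallory", "A good film"))  # True: decoys are consistent
```

A student model trained on decoy labels learns nothing about the real model. The trade-off is that a wrongly flagged legitimate user also gets useless answers, so this depends on a low false-positive rate.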



### 3. Noise Injection

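Controlled noise injection adds random noise to numeric outputs, so repeated queries do not return identical values and clean input/output pairs are harder to collect. The noise is drawn from a Laplace distribution with scale sensitivity / ε, which is the Laplace mechanism from differential privacy. In this demo, "sensitivity" is a hand-assigned weight per topic rather than the formal sensitivity of a query. The code therefore borrows the mechanism without providing a formal differential-privacy guarantee.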

Features of the code:

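* Adaptive Noise Injection: More noise for sensitive data, less for general queries. The scale is absolute, so a scale of a few units barely changes a value in the millions.
* Differential Privacy (ε-Tuning): A smaller ε means more noise, trading usability for security.
* Query Monitoring: Counts repeats of each query and raises the noise level with every repeat (capped at five repeats) to slow down averaging-based extraction.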

In [4]:
import numpy as np
import hashlib
from collections import defaultdict

class SecureAI:
    def __init__(self, epsilon=0.2, sensitivity_threshold=0.5):
        self.epsilon = epsilon  # Privacy budget ε: smaller values mean more noise
        self.sensitivity_threshold = sensitivity_threshold  # Not used by this demo (every listed topic gets noise)
        self.query_log = defaultdict(int)  # Counts how often each (hashed) query has been seen

    def _hash_query(self, query):
        """Generate a hashed identifier for the query to track repetition."""
        return hashlib.sha256(query.encode()).hexdigest()

    def _add_noise(self, value, sensitivity):
        """Apply Laplacian noise based on sensitivity and privacy budget (epsilon)."""
        scale = sensitivity / self.epsilon  # More sensitive = more noise
        noise = np.random.laplace(0, scale)
        return round(value + noise, 2)  # Keep output readable

    def process_query(self, query):
        """Handle a user query with adaptive noise injection."""
        query_id = self._hash_query(query)
        self.query_log[query_id] += 1

        # Simulated AI responses
        responses = {
            "company revenue": (5000000, 0.8),  # Sensitive data
            "weather forecast": (25, 0.2),      # General data
            "stock price prediction": (150, 0.6),  # Semi-sensitive
        }

        if query not in responses:
            return "I can't answer that."

        base_value, sensitivity = responses[query]

        # Each repeat of the same query raises the noise level (anti-extraction), capped at 5 repeats
        repeats = self.query_log[query_id] - 1  # 0 on the first request
        sensitivity += 0.1 * min(repeats, 5)

        # Apply controlled noise
        noised_value = self._add_noise(base_value, sensitivity)
        
        return f"Approximate {query}: {noised_value}"

# Demo Usage
secure_ai = SecureAI(epsilon=0.3)

queries = ["company revenue", "weather forecast", "stock price prediction", "company revenue", "company revenue"]
for q in queries:
    print(secure_ai.process_query(q))


Approximate company revenue: 4999995.8
Approximate weather forecast: 21.94
Approximate stock price prediction: 150.82
Approximate company revenue: 5000001.23
Approximate company revenue: 5000002.78
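Fresh noise on every call has a known weakness. The draws are independent, so an attacker can repeat a query and average the answers. In differential-privacy terms, k answers at budget ε each cost kε in total (sequential composition).

The sketch below first shows this averaging effect, assuming a Laplace scale of 3 (in the range the demo uses). It then shows one mitigation: deriving the noise deterministically from a keyed hash of the query, so repeats return the same value. `NOISE_KEY` and `consistent_noisy_answer` are hypothetical names.

```python
import hashlib
import hmac
import numpy as np

TRUE_VALUE, SCALE = 150.0, 3.0  # e.g. the stock-price answer and a noise scale of a few units

# 1) Independent noise on every call: the mean of many answers converges to the true value.
rng = np.random.default_rng(0)
answers = TRUE_VALUE + rng.laplace(0.0, SCALE, size=10_000)
print(abs(answers.mean() - TRUE_VALUE) < 0.5)  # True: the noise averages away

# 2) Mitigation sketch: seed the noise from a keyed hash of the query text.
NOISE_KEY = b"server-side-noise-key"  # hypothetical secret

def consistent_noisy_answer(query, true_value, scale):
    """Same query -> same seed -> same Laplace draw, so repeating a query reveals nothing new."""
    digest = hmac.new(NOISE_KEY, query.encode(), hashlib.sha256).digest()
    seed = int.from_bytes(digest[:8], "big")
    return round(true_value + np.random.default_rng(seed).laplace(0.0, scale), 2)

first = consistent_noisy_answer("stock price prediction", TRUE_VALUE, SCALE)
second = consistent_noisy_answer("stock price prediction", TRUE_VALUE, SCALE)
print(first == second)  # True
```

An attacker can still rephrase a query to get a fresh draw. This works best combined with query normalization or with monitoring like that in section 5.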


### 4. Gradient Masking

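Gradient masking adds noise to gradients so that they reveal less about the model's parameters. This demo adds Gaussian noise to the parameter gradients during training. That matters mainly when gradients are exposed to others, for example when they are shared during collaborative or federated training. A black-box API that returns only predictions does not expose training gradients, so this technique alone does not stop output-based distillation.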

Features in This Code:
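* Gradient Obfuscation: Adds Gaussian noise to the parameter gradients after each backward pass.
* Detects Query-Based Attacks: The intended idea is to inject stronger noise when gradients are requested too often; this cell does not implement it.
* Maintains Usability: With small noise and a suitably small learning rate, the model still learns while its gradients become less informative.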
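The "Detects Query-Based Attacks" feature listed above is not part of the training cell that follows.

This minimal sketch of the idea assumes a hypothetical endpoint that exposes input gradients, for example to power saliency maps. For f(x) = tanh(w·x), the exact input gradient is a scaled copy of w, so unmasked gradients would leak the weight direction in a single query. The noise grows with each client's request count. The class name and the noise schedule (`base_noise`, `step`, `max_noise`) are assumptions.

```python
import numpy as np
from collections import defaultdict

class NoisyGradientEndpoint:
    """Hypothetical endpoint exposing input gradients of f(x) = tanh(w·x), e.g. for saliency maps."""

    def __init__(self, weights, base_noise=0.05, step=0.05, max_noise=1.0, seed=0):
        self.weights = np.asarray(weights, dtype=float)  # the secret parameters
        self.base_noise, self.step, self.max_noise = base_noise, step, max_noise
        self.requests = defaultdict(int)  # gradient requests seen per client
        self.rng = np.random.default_rng(seed)

    def noise_level(self, client):
        """Noise std grows linearly with the client's request count, capped at max_noise."""
        return min(self.base_noise + self.step * self.requests[client], self.max_noise)

    def input_gradient(self, client, x):
        """Return d tanh(w·x) / dx plus Gaussian noise scaled to the client's request count."""
        self.requests[client] += 1
        z = float(self.weights @ np.asarray(x, dtype=float))
        exact = (1.0 - np.tanh(z) ** 2) * self.weights  # exact gradient is proportional to w
        return exact + self.rng.normal(0.0, self.noise_level(client), size=exact.shape)

endpoint = NoisyGradientEndpoint(weights=[0.8, -1.5])
x = [0.1, 0.2]
print(endpoint.noise_level("scraper"))  # 0.05 before any request
for _ in range(50):
    endpoint.input_gradient("scraper", x)
print(endpoint.noise_level("scraper"))  # 1.0: capped after many requests
```

Independent noise can still be averaged away over many requests, so in practice a cap like this would be paired with a hard request limit.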

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

class GradientMaskedModel(nn.Module):
    def __init__(self, input_size=2, output_size=1, noise_level=0.1):
        super().__init__()
        self.fc = nn.Linear(input_size, output_size)
        self.noise_level = noise_level  # Standard deviation of the Gaussian gradient noise

    def forward(self, x):
        return self.fc(x)

    def add_gradient_masking(self):
        """Add Gaussian noise to parameter gradients (call after backward(), before step())."""
        for param in self.parameters():
            if param.grad is not None:
                noise = torch.randn_like(param.grad) * self.noise_level  # Add noise
                param.grad += noise  # Mask the true gradient

# Generate dummy data (the inputs do not need gradients for training)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
y = torch.tensor([[1.0], [2.0], [3.0]])

# Initialize model
model = GradientMaskedModel()
criterion = nn.MSELoss()
# lr=0.1 is above the stable limit (about 0.066) for these inputs, so the loss below diverges; lr=0.01 converges
optimizer = optim.SGD(model.parameters(), lr=0.1)

summary(model, (2,), device="cpu")  # Input shape: 2 features per sample

# Training loop with Gradient Masking
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()

    # Apply gradient masking
    model.add_gradient_masking()

    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Checking modified gradients
for name, param in model.named_parameters():
    print(f"Parameter: {name}, Gradient: {param.grad}")


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                    [-1, 1]               3
Total params: 3
Trainable params: 3
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------
Epoch 1, Loss: 3.983057975769043
Epoch 2, Loss: 16.132280349731445
Epoch 3, Loss: 66.94312286376953
Epoch 4, Loss: 279.5446472167969
Epoch 5, Loss: 1162.6827392578125
Parameter: fc.weight, Gradient: tensor([[-146.7402, -211.9694]])
Parameter: fc.bias, Gradient: tensor([-65.0913])
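The loss in the run above grows roughly fourfold per epoch. That is a learning-rate problem, not an effect of the masking.

For these inputs, the largest eigenvalue of the MSE Hessian is about 30.4. Plain gradient descent is therefore only stable for lr < 2/30.4 ≈ 0.066. At lr = 0.1, the error along that direction is multiplied by |1 − 0.1 × 30.4| ≈ 2.04 per step, which is about 4.2× in loss.

The sketch below re-implements the same linear model in NumPy (it is not the torch cell itself). It trains with the same gradient-noise level at both rates.

```python
import numpy as np

X = np.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
y = np.array([1.0, 2.0, 3.0])

def train(lr, noise_level=0.1, epochs=50, seed=0):
    """Full-batch gradient descent on MSE for y ≈ Xw + b, adding Gaussian noise to every gradient."""
    rng = np.random.default_rng(seed)
    w, b = np.zeros(2), 0.0
    losses = []
    for _ in range(epochs):
        err = X @ w + b - y                      # residuals
        losses.append(float(np.mean(err ** 2)))  # MSE, as nn.MSELoss computes it
        grad_w = 2 * X.T @ err / len(y)          # d(MSE)/dw
        grad_b = 2 * err.mean()                  # d(MSE)/db
        grad_w += rng.normal(0.0, noise_level, size=2)  # gradient masking
        grad_b += rng.normal(0.0, noise_level)
        w -= lr * grad_w
        b -= lr * grad_b
    return losses

stable, unstable = train(lr=0.01), train(lr=0.1)
print(stable[-1] < stable[0])      # True: with lr=0.01 the masked model still learns
print(unstable[-1] > unstable[0])  # True: with lr=0.1 the loss explodes
```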


### 5. Active Monitoring and Detection

Anomaly Detection for API Queries

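This cell flags users whose recent queries are highly repetitive. Repetition is one possible sign of extraction, for example an attempt to average away injected noise. Distillation itself usually relies on many distinct queries, which volume limits such as the rate limiting in section 6 target.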

In [6]:
import time
from collections import defaultdict, deque, Counter

# Thresholds for detection
UNIQUENESS_THRESHOLD = 0.5  # Alert when fewer than 50% of the windowed queries are unique
WINDOW_SIZE = 10  # Track the last N queries per user

# Per-user sliding window of (query, timestamp) pairs; deque(maxlen=...) drops the oldest entry automatically
query_log = defaultdict(lambda: deque(maxlen=WINDOW_SIZE))

def log_request(user_id, query):
    """Logs user queries and detects high-frequency repetition attempts."""
    timestamp = time.time()
    query_log[user_id].append((query, timestamp))

    # Analyze uniqueness within the window
    user_queries = [q for q, _ in query_log[user_id]]
    unique_queries = set(user_queries)

    if len(unique_queries) < len(user_queries) * UNIQUENESS_THRESHOLD:
        print(f"[ALERT] Potential distillation attempt detected for user '{user_id}'!")
        print(f"Recent queries: {Counter(user_queries)}\n")

# Simulating repeated queries
simulated_queries = [
    "What is AI?", "What is AI?", "Explain AI.", "What is AI?", "What is AI?",
    "Tell me about AI.", "What is AI?", "What is AI?", "What is AI?", "What is AI?", "What is AI?",
]
for q in simulated_queries:
    log_request("Black Hat", q)


[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 4, 'Explain AI.': 1})

[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 5, 'Explain AI.': 1, 'Tell me about AI.': 1})

[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 6, 'Explain AI.': 1, 'Tell me about AI.': 1})

[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 7, 'Explain AI.': 1, 'Tell me about AI.': 1})

[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 8, 'Explain AI.': 1, 'Tell me about AI.': 1})

[ALERT] Potential distillation attempt detected for user 'Black Hat'!
Recent queries: Counter({'What is AI?': 8, 'Explain AI.': 1, 'Tell me about AI.': 1})
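The check above compares exact strings, so "What is AI?", "what is AI" and "What is AI exactly?" all count as different queries.

This rough sketch also treats such variants as near-duplicates. It assumes that word-set Jaccard similarity is a good-enough proxy. The thresholds are hypothetical: `JACCARD_THRESHOLD`, `NEAR_DUP_RATIO`, and a minimum of four queries before alerting.

```python
import re
from collections import defaultdict, deque

WINDOW_SIZE = 10
JACCARD_THRESHOLD = 0.6  # hypothetical: queries at least this similar count as near-duplicates
NEAR_DUP_RATIO = 0.5     # hypothetical: alert when more than half the window are near-duplicates

windows = defaultdict(lambda: deque(maxlen=WINDOW_SIZE))

def tokens(query):
    """Lowercase word set, ignoring punctuation."""
    return frozenset(re.findall(r"[a-z0-9]+", query.lower()))

def jaccard(a, b):
    """Size of the intersection over size of the union (1.0 for two empty sets)."""
    return len(a & b) / len(a | b) if a | b else 1.0

def is_suspicious(user_id, query):
    """Record the query; report whether most of the window consists of near-duplicates."""
    window = windows[user_id]
    window.append(tokens(query))
    near_dups = sum(
        any(jaccard(q, other) >= JACCARD_THRESHOLD for j, other in enumerate(window) if j != i)
        for i, q in enumerate(window)
    )
    return len(window) >= 4 and near_dups > NEAR_DUP_RATIO * len(window)

for q in ["What is AI?", "what is AI", "What is AI exactly?", "Explain what AI is"]:
    flagged = is_suspicious("Black Hat", q)
print(flagged)  # True: four near-identical queries in the window
```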



### 6. Model Access Control

Restricting access to the model ensures that only authorized users can interact with it, reducing the risk of distillation attempts.
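The rate limiter below controls volume but not identity. This minimal sketch covers the "only authorized users" part. It assumes API keys are issued per user and only their SHA-256 hashes are stored server-side. The names `issue_key`, `is_authorized` and `API_KEY_HASHES` are hypothetical.

```python
import hashlib
import hmac
import secrets

API_KEY_HASHES = {}  # hypothetical store: user -> SHA-256 hex digest of that user's API key

def issue_key(user):
    """Generate a random API key, store only its hash, and return the key once to the user."""
    key = secrets.token_urlsafe(32)
    API_KEY_HASHES[user] = hashlib.sha256(key.encode()).hexdigest()
    return key

def is_authorized(user, presented_key):
    """Check a presented key against the stored hash using a constant-time comparison."""
    expected = API_KEY_HASHES.get(user)
    if expected is None:
        return False
    presented = hashlib.sha256(presented_key.encode()).hexdigest()
    return hmac.compare_digest(expected, presented)

alice_key = issue_key("alice")
print(is_authorized("alice", alice_key))      # True
print(is_authorized("alice", "guessed-key"))  # False
print(is_authorized("bob", alice_key))        # False: unknown user
```

Storing only hashes means a leaked key store does not expose usable keys. Because the keys are long random tokens rather than passwords, a fast unsalted hash is reasonable here.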

Rate Limiting: Restrict the number of queries per user to slow down data collection for distillation.

In [7]:
import time
from collections import defaultdict

request_log = defaultdict(list)  # user -> timestamps of accepted requests
RATE_LIMIT = 3  # Maximum queries per user in any 60-second window

def predict(user, request):
    """Return a placeholder model output, or an error once the user exceeds RATE_LIMIT."""
    current_time = time.time()
    # Keep only timestamps from the last 60 seconds (sliding window)
    request_log[user] = [t for t in request_log[user] if t > current_time - 60]

    if len(request_log[user]) >= RATE_LIMIT:
        return {"error": f"Too many requests from {user}. Limit to {RATE_LIMIT} queries per minute."}

    request_log[user].append(current_time)
    return {"response": "Model output"}

for prompt in ["Tell me a joke", "How to diagnose a car?", "What is LINQ in .NET?", "What is an LLM prompt?"]:
    print(predict("User", prompt))

{'response': 'Model output'}
{'response': 'Model output'}
{'response': 'Model output'}
{'error': 'Too many requests from User. Limit to 3 queries per minute.'}
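A per-minute window alone still admits 3 × 60 × 24 = 4,320 queries per user per day. That is enough to collect a sizeable distillation dataset over time.

This sketch stacks a short window and a long window, using hypothetical limits of 3 per minute and 500 per day. It takes an optional `now` argument so the behavior can be simulated without waiting.

```python
import time
from collections import defaultdict, deque

# Hypothetical limits as (window in seconds, max accepted requests in that window)
LIMITS = [(60, 3), (24 * 3600, 500)]
LONGEST_WINDOW = max(window for window, _ in LIMITS)

history = defaultdict(deque)  # user -> timestamps of accepted requests, oldest first

def allow(user, now=None):
    """Accept a request only if every window still has room; record it when accepted."""
    now = time.time() if now is None else now
    stamps = history[user]
    while stamps and stamps[0] <= now - LONGEST_WINDOW:
        stamps.popleft()  # Forget requests older than the longest window
    for window, limit in LIMITS:
        if sum(1 for t in stamps if t > now - window) >= limit:
            return False
    stamps.append(now)
    return True

# Simulated clock (seconds): four requests within one minute, the fourth is rejected
print([allow("User", t) for t in (0.0, 1.0, 2.0, 3.0)])  # [True, True, True, False]
```

A client pacing itself at one request every 20 seconds never trips the minute limit. It is still cut off after 500 requests in a day.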