# 🚀 AWS IAM Policy Generator: Model Training & Evaluation
**Project**: Fine-tuning Mistral-7B to convert natural language → valid AWS IAM JSON policies  
**Author**: Vatsal Naik | Northeastern University
## This notebook covers:
1. **Model Setup**: Mistral-7B-v0.3 with QLoRA (4-bit quantization)
2. **Training**: 4 hyperparameter configurations with early stopping
3. **Baseline Comparison**: Zero-shot vs Few-shot vs Fine-tuned
4. **Evaluation**: JSON validity, schema compliance, service accuracy
5. **Error Analysis**: Failure categorization, patterns by complexity
6. **Inference Pipeline**: Gradio demo application

## 1. Environment Setup

**Hardware**: NVIDIA A100-SXM4 40GB GPU

**CUDA**: 12.1

**Key Libraries**: transformers, peft (LoRA), bitsandbytes (4-bit quantization), trl (SFTTrainer)

In [None]:
!pip install transformers datasets peft accelerate bitsandbytes sentencepiece protobuf trl scikit-learn matplotlib --quiet

In [2]:
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.0f} GB")

PyTorch: 2.5.1+cu121
CUDA available: True
GPU: NVIDIA A100-SXM4-40GB
VRAM: 42 GB


## 2. Authentication & Configuration

In [None]:
from huggingface_hub import login
import os

# Login using token stored in environment variable
# Set via: export HF_TOKEN=hf_xxxxx before launching Jupyter
login(token=os.environ.get("HF_TOKEN", "SET_YOUR_TOKEN"))
os.environ["WANDB_DISABLED"] = "true"
print("✅ Ready!")

## 3. Load Dataset

This section loads the preprocessed dataset created in `01_Dataset_Preparation.ipynb`.

In [4]:
from datasets import load_dataset
import json

# Load the JSONL files you created
dataset = load_dataset("json", data_files={
    "train": "iam-finetuning/dataset/processed/train.jsonl",
    "validation": "iam-finetuning/dataset/processed/val.jsonl",
    "test": "iam-finetuning/dataset/processed/test.jsonl",
})

print(dataset)
print(f"\nSample:\n{dataset['train'][0]['text'][:300]}...")

DatasetDict({
    train: Dataset({
        features: ['text', 'id', 'complexity', 'services', 'source'],
        num_rows: 1189
    })
    validation: Dataset({
        features: ['text', 'id', 'complexity', 'services', 'source'],
        num_rows: 148
    })
    test: Dataset({
        features: ['text', 'id', 'complexity', 'services', 'source'],
        num_rows: 151
    })
})

Sample:
### Instruction:
Provide an IAM policy for ROSA Cloud Network Config Operator Policy

### Response:
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "DescribeNetworkResources",
      "Effect": "Allow",
      "Action": [
        "ec2:DescribeInstances",
        "ec2:DescribeInstanceSt...

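The sample above shows the prompt template used throughout this notebook: an `### Instruction:` block, a blank line, then a `### Response:` block holding the policy JSON. As a minimal sketch, the template could be built and split with helpers like the ones below. The function names are hypothetical and not part of the notebook, and the 2-space JSON indentation is an assumption based on the printed sample.

```python
import json

INSTRUCTION_TAG = "### Instruction:\n"
RESPONSE_TAG = "### Response:\n"

def build_prompt(instruction: str) -> str:
    """Return the inference prompt: the instruction plus an empty response slot."""
    return f"{INSTRUCTION_TAG}{instruction}\n\n{RESPONSE_TAG}"

def format_example(instruction: str, policy: dict) -> str:
    """Return a full training example with the policy pretty-printed as JSON."""
    return build_prompt(instruction) + json.dumps(policy, indent=2)

def split_response(text: str) -> str:
    """Return everything after the last response tag, stripped of whitespace."""
    return text.split(RESPONSE_TAG)[-1].strip()

# Hypothetical policy used only to exercise the helpers.
policy = {
    "Version": "2012-10-17",
    "Statement": [{"Effect": "Allow", "Action": "s3:GetObject", "Resource": "*"}],
}
example = format_example("Allow reading S3 objects", policy)
print(example)
print(json.loads(split_response(example)) == policy)
```

Keeping the template in one place helps ensure that training text and inference prompts use exactly the same delimiters.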

## 4. Tokenizer Setup

In [5]:
from transformers import AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.3"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Test tokenization
sample = dataset["train"][0]["text"]
tokens = tokenizer(sample, truncation=True, max_length=1024)
print(f"Sample token count: {len(tokens['input_ids'])}")
print(f"Vocab size: {tokenizer.vocab_size}")

Sample token count: 294
Vocab size: 32768


## 5. Model Loading with QLoRA

We use **4-bit NF4 quantization** (QLoRA) to reduce the weight memory of the 7B-parameter model from ~14 GB (16-bit) to ~4 GB of VRAM, enabling efficient fine-tuning on a single GPU.

| Parameter | Value |
|-----------|-------|
| Base model | Mistral-7B-v0.3 |
| Quantization | 4-bit NF4 |
| Compute dtype | bfloat16 |
| Double quantization | Enabled |
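The ~14 GB and ~4 GB figures can be checked with back-of-the-envelope arithmetic. The sketch below is a rough estimate only: it uses the 7248M parameter count that the loading cell reports, assumes a quantization block size of 64 with the double-quantization overhead described in the QLoRA paper, and ignores activations, adapter weights, and any layers that stay unquantized.

```python
# Rough weight-memory estimate for the base model (weights only).
N_PARAMS = 7_248_000_000  # "Model parameters: 7248M" as reported by the notebook

def weight_gb(bits_per_param: float, n_params: int = N_PARAMS) -> float:
    """Convert a per-parameter bit budget into gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

# Assumption: one 32-bit scale is stored per block of 64 weights.
plain_overhead = 32 / 64                      # 0.5 extra bits per parameter
# Assumption: double quantization stores those scales in 8 bits, plus one
# 32-bit second-level scale per 256 blocks.
double_overhead = 8 / 64 + 32 / (64 * 256)    # about 0.127 extra bits per parameter

bf16_gb = weight_gb(16)
nf4_gb = weight_gb(4 + plain_overhead)
nf4_dq_gb = weight_gb(4 + double_overhead)

print(f"bf16 weights:       {bf16_gb:.1f} GB")
print(f"NF4:                {nf4_gb:.1f} GB")
print(f"NF4 + double quant: {nf4_dq_gb:.1f} GB")
```

The measured footprint can come out somewhat higher than the last estimate, which would be consistent with some layers (for example embeddings) being kept in higher precision.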

In [6]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

model.config.use_cache = False

# Check memory usage
allocated = torch.cuda.memory_allocated() / 1e9
print(f"Model loaded! GPU memory used: {allocated:.1f} GB")
print(f"Model parameters: {model.num_parameters() / 1e6:.0f}M")

`torch_dtype` is deprecated! Use `dtype` instead!



Model loaded! GPU memory used: 4.1 GB
Model parameters: 7248M


## 6. LoRA Adapter Configuration

Low-Rank Adaptation (LoRA) trains only a small set of adapter weights (~0.6% of total parameters), which makes fine-tuning much faster and more memory-efficient than updating all weights.

| Parameter | Config A (Best) |
|-----------|----------------|
| Rank (r) | 16 |
| Alpha | 32 |
| Target modules | q, k, v, o, gate, up, down projections |
| Dropout | 0.05 |
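The trainable-parameter count follows directly from the rank and the shapes of the targeted layers: each adapted linear layer gains two small matrices with `r * (in_features + out_features)` parameters in total. The sketch below works this out for the seven target modules. The layer dimensions are assumptions taken from the public Mistral-7B configuration and are not printed anywhere in this notebook.

```python
# Assumed Mistral-7B layer shapes (from the public model config).
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size of the MLP
KV_DIM = 1024          # 8 key/value heads x 128 head_dim (grouped-query attention)
N_LAYERS = 32

# (in_features, out_features) of every linear layer targeted by LoRA.
TARGET_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

def lora_param_count(rank: int) -> int:
    """Each adapted layer adds A (rank x in) and B (out x rank) matrices."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * N_LAYERS

for rank in (8, 16, 32):
    print(f"r={rank:<2} -> {lora_param_count(rank):,} trainable parameters")
```

Under these assumptions, `r=16` gives 41,943,040 parameters, which matches the count reported by `print_trainable_parameters()` for Config A. The count scales linearly with the rank.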

In [7]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

# LoRA Config: this is hyperparameter config A (balanced)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Should print something like: trainable params: 41M || all params: 7,283M || trainable%: 0.56%

trainable params: 41,943,040 || all params: 7,289,966,592 || trainable%: 0.5754


## 7. Training

### Config A: Balanced (LR=2e-4, r=16, α=32, 3 epochs)

In [8]:
from trl import SFTTrainer, SFTConfig

# Check what version of trl we have
import trl
print(f"trl version: {trl.__version__}")

training_args = SFTConfig(
    output_dir="./results/config_a",
    
    # Training params
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    
    # Learning rate
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    
    # Eval
    eval_strategy="steps",
    eval_steps=50,
    
    # Saving
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    
    # Logging
    logging_steps=10,
    report_to="none",
    
    # Performance
    bf16=True,
    max_length=1024,  # changed from max_seq_length
    
    # Misc
    seed=42,
)

print("Training config ready!")
print(f"Effective batch size: {training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps}")
print(f"Epochs: {training_args.num_train_epochs}")
print(f"Learning rate: {training_args.learning_rate}")

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.


trl version: 0.27.2
Training config ready!
Effective batch size: 16
Epochs: 3
Learning rate: 0.0002
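Config A pairs a cosine learning-rate schedule with a 5% warmup. To make the shape concrete, the sketch below recomputes the learning rate at a few steps, assuming the usual linear warmup followed by cosine decay to zero. The step counts are assumptions: 74 optimizer steps per epoch (1189 training examples at an effective batch size of 16, rounded down) over 3 epochs. The exact numbers used by the trainer may differ slightly.

```python
import math

def lr_at_step(step: int, total_steps: int, warmup_steps: int, peak_lr: float) -> float:
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Assumed values, see the paragraph above.
TOTAL_STEPS = 74 * 3
WARMUP_STEPS = math.ceil(0.05 * TOTAL_STEPS)
PEAK_LR = 2e-4

for step in (0, WARMUP_STEPS, TOTAL_STEPS // 2, TOTAL_STEPS):
    lr = lr_at_step(step, TOTAL_STEPS, WARMUP_STEPS, PEAK_LR)
    print(f"step {step:>3}: lr = {lr:.2e}")
```

The short warmup avoids large updates while the adapter weights are still near their initialization, and the cosine decay lowers the learning rate smoothly toward the end of training.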


In [9]:
from transformers import EarlyStoppingCallback

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    processing_class=tokenizer,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

print(f"Training samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")
print(f"Steps per epoch: {len(dataset['train']) // (training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps)}")
print("\nStarting training...")

train_result = trainer.train()

# Print results
print(f"\nTraining complete!")
print(f"Train loss: {train_result.metrics['train_loss']:.4f}")
print(f"Training runtime: {train_result.metrics['train_runtime']:.0f} seconds")


The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Training samples: 1189
Validation samples: 148
Steps per epoch: 74

Starting training...


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 0.34142 | 0.349021 |
| 100 | 0.245109 | 0.321393 |
| 150 | 0.214992 | 0.301875 |
| 200 | 0.146991 | 0.315044 |



Training complete!
Train loss: 0.2506
Training runtime: 1489 seconds
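With `early_stopping_patience=3`, training only stops early after three consecutive evaluations without improvement. The sketch below replays the validation losses from the training log above through a simple patience counter. It is an illustration of the idea, not the callback's actual implementation, and `simulate_early_stopping` is a hypothetical helper.

```python
# Validation losses logged every 50 steps for Config A (from the training log above).
eval_history = [(50, 0.349021), (100, 0.321393), (150, 0.301875), (200, 0.315044)]

def simulate_early_stopping(history, patience=3, min_delta=0.0):
    """Replay eval losses; return (best_step, best_loss, stopped_at_step or None)."""
    best_step, best_loss = None, float("inf")
    bad_evals = 0
    for step, loss in history:
        if loss < best_loss - min_delta:
            best_step, best_loss = step, loss
            bad_evals = 0          # an improvement resets the patience counter
        else:
            bad_evals += 1
            if bad_evals >= patience:
                return best_step, best_loss, step
    return best_step, best_loss, None

best_step, best_loss, stopped_at = simulate_early_stopping(eval_history)
print(f"best checkpoint: step {best_step} (eval_loss={best_loss:.4f})")
print(f"stopped early: {stopped_at is not None}")
```

In this replay the best validation loss occurs at step 150 and only one later evaluation is worse, so the patience limit is never reached. With `load_best_model_at_end=True`, the step-150 checkpoint is the one restored after training.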


In [10]:
import json

# Save metrics
log_history = trainer.state.log_history
with open("results/config_a/training_logs.json", "w") as f:
    json.dump(log_history, f, indent=2)

# Extract and display key metrics
train_losses = [(l["step"], l["loss"]) for l in log_history if "loss" in l and "eval_loss" not in l]
eval_losses = [(l["step"], l["eval_loss"]) for l in log_history if "eval_loss" in l]

print("Config A Results:")
print(f"  Final train loss: {train_losses[-1][1]:.4f}")
print(f"  Best eval loss: {min(l[1] for l in eval_losses):.4f}")
print(f"  Training steps: {train_losses[-1][0]}")

# Save these for later comparison
config_a_results = {
    "config": "A - balanced",
    "lr": 2e-4,
    "lora_r": 16,
    "lora_alpha": 32,
    "epochs": 3,
    "batch_size": "4 x 4 = 16",
    "final_train_loss": train_losses[-1][1],
    "best_eval_loss": min(l[1] for l in eval_losses),
}
with open("results/config_a/summary.json", "w") as f:
    json.dump(config_a_results, f, indent=2)

print("\nLogs saved!")

Config A Results:
  Final train loss: 0.1522
  Best eval loss: 0.3019
  Training steps: 220

Logs saved!


### Quick Inference Test

In [12]:
# Quick inference test
prompt = "### Instruction:\nAllow read-only access to S3 bucket named customer-data and write logs to CloudWatch\n\n### Response:\n"

inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.1,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.1,
    )

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_policy = response.split("### Response:\n")[-1].strip()

print("Generated policy:")
print(generated_policy)

# Validate JSON
try:
    policy = json.loads(generated_policy)
    print("\n✅ Valid JSON!")
except json.JSONDecodeError:
    print("\n❌ Invalid JSON: model may need more training")

Generated policy:
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": [
        "arn:aws:s3:::customer-data",
        "arn:aws:s3:::customer-data/*"
      ]
    },
    {
      "Effect": "Allow",
      "Action": [
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:PutLogEvents"
      ],
      "Resource": "*"
    }
  ]
}

✅ Valid JSON!


### Additional Configurations

We trained three additional configurations to compare hyperparameter choices:

| Config | Learning Rate | LoRA Rank | LoRA Alpha | Epochs | Purpose |
|--------|--------------|-----------|------------|--------|---------|
| **A** | 2e-4 | 16 | 32 | 3 | Balanced baseline |
| **B** | 1e-4 | 32 | 64 | 3 | Lower LR, higher capacity |
| **C** | 5e-4 | 16 | 32 | 5 | Aggressive LR, more epochs |
| **D** | 2e-4 | 8 | 16 | 3 | Minimal capacity |

## 8. Hyperparameter Comparison

In [6]:
config_details = {
    "config_a": {"lr": "2e-4", "lora_r": 16, "lora_alpha": 32, "epochs": 3},
    "config_b": {"lr": "1e-4", "lora_r": 32, "lora_alpha": 64, "epochs": 3},
    "config_c": {"lr": "5e-4", "lora_r": 16, "lora_alpha": 32, "epochs": 5},
    "config_d": {"lr": "2e-4", "lora_r": 8, "lora_alpha": 16, "epochs": 3},
}

all_configs = {}
for config_name in ["config_a", "config_b", "config_c", "config_d"]:
    log_path = f"results/{config_name}/training_logs.json"
    if os.path.exists(log_path):
        with open(log_path) as f:
            logs = json.load(f)
        eval_losses = [l["eval_loss"] for l in logs if "eval_loss" in l]
        train_losses = [l["loss"] for l in logs if "loss" in l and "eval_loss" not in l]
        all_configs[config_name] = {
            **config_details[config_name],
            "best_eval_loss": min(eval_losses) if eval_losses else None,
            "final_train_loss": train_losses[-1] if train_losses else None,
        }

print("=== HYPERPARAMETER COMPARISON ===")
print(f"{'Config':<12} {'LR':<8} {'LoRA r':<8} {'Epochs':<8} {'Train Loss':<14} {'Best Eval Loss':<14}")
print("-" * 64)
for name, m in all_configs.items():
    print(f"{name:<12} {m['lr']:<8} {m['lora_r']:<8} {m['epochs']:<8} {m['final_train_loss']:<14.4f} {m['best_eval_loss']:<14.4f}")

with open("results/hp_comparison.json", "w") as f:
    json.dump(all_configs, f, indent=2)
print("\nSaved!")

=== HYPERPARAMETER COMPARISON ===
Config       LR       LoRA r   Epochs   Train Loss     Best Eval Loss
----------------------------------------------------------------
config_a     2e-4     16       3        0.1522         0.3019        
config_b     1e-4     32       3        0.1584         0.3025        
config_c     5e-4     16       5        0.0481         0.3207        
config_d     2e-4     8        3        0.1887         0.3114        

Saved!
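One way to read the comparison table printed above is to look at both the best validation loss and the gap between validation and training loss. The sketch below does this with the printed values. Note that the gap is only a rough indicator, since the final train loss and the best eval loss are taken at different steps.

```python
# Final train loss and best eval loss per configuration, copied from the
# comparison table printed above.
results = {
    "config_a": {"train": 0.1522, "eval": 0.3019},
    "config_b": {"train": 0.1584, "eval": 0.3025},
    "config_c": {"train": 0.0481, "eval": 0.3207},
    "config_d": {"train": 0.1887, "eval": 0.3114},
}

# Gap between validation and training loss; a large gap hints at overfitting.
gaps = {name: r["eval"] - r["train"] for name, r in results.items()}

best_by_eval = min(results, key=lambda name: results[name]["eval"])
largest_gap = max(gaps, key=gaps.get)

for name in sorted(results, key=lambda n: results[n]["eval"]):
    print(f"{name}: eval={results[name]['eval']:.4f}  gap={gaps[name]:.4f}")
print(f"lowest eval loss: {best_by_eval}; largest train/eval gap: {largest_gap}")
```

Config A has the lowest validation loss, while Config C combines the lowest training loss with the highest validation loss, a pattern that suggests the higher learning rate and extra epochs led to overfitting.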


## 9. Baseline vs Fine-Tuned Comparison

We evaluate three approaches on the same 151 test examples:
1. **Zero-shot**: Base Mistral-7B with no examples
2. **Few-shot**: Base Mistral-7B with 3 in-context examples
3. **Fine-tuned**: Our best model (Config A)

In [3]:
test_data = []
with open("iam-finetuning/dataset/processed/test.jsonl") as f:
    for line in f:
        test_data.append(json.loads(line))
print(f"Test set: {len(test_data)} examples")

def extract_instruction_and_expected(item):
    text = item["text"]
    instruction = text.split("### Instruction:\n")[1].split("\n\n### Response:")[0]
    expected = text.split("### Response:\n")[1]
    return instruction, expected

def generate_fast(model, tokenizer, prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1500).to("cuda")
    with torch.no_grad():
        outputs = model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True,
            top_p=0.9, repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True).split("### Response:\n")[-1].strip()

few_shot_examples = """### Instruction:
Allow read-only access to S3 bucket named my-data

### Response:
{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetObject","s3:ListBucket"],"Resource":["arn:aws:s3:::my-data","arn:aws:s3:::my-data/*"]}]}

### Instruction:
Allow a user to start and stop EC2 instances

### Response:
{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["ec2:StartInstances","ec2:StopInstances","ec2:DescribeInstances"],"Resource":"*"}]}

### Instruction:
Allow writing logs to CloudWatch

### Response:
{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["logs:CreateLogGroup","logs:CreateLogStream","logs:PutLogEvents"],"Resource":"*"}]}

"""
print("Functions ready!")

Test set: 151 examples
Functions ready!
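`generate_fast` caps generation at 256 new tokens, while the training sample shown earlier tokenizes to 294 tokens for instruction and policy together. Longer policies may therefore be cut off before the closing brace. This is a hypothesis, not something the notebook measures. As a minimal sketch, a bracket-balance check like the hypothetical `looks_truncated` below could flag such outputs before they are scored.

```python
def looks_truncated(text: str) -> bool:
    """Return True if text ends inside a string or with unclosed { or [ brackets."""
    depth = 0
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False        # character after a backslash is literal
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
    return in_string or depth > 0

# Hypothetical outputs used only to exercise the check.
complete = '{"Version": "2012-10-17", "Statement": [{"Effect": "Allow"}]}'
cut_off = '{"Version": "2012-10-17", "Statement": [{"Effect": "Allow", "Action": ["s3:Get'

print(looks_truncated(complete))
print(looks_truncated(cut_off))
```

If many failed generations are flagged this way, raising `max_new_tokens` would be a cheaper first fix than retraining.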



## 10. Evaluation Metrics
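We measure five key metrics:

- **JSON Valid Rate**: % of outputs that parse directly as valid JSON
- **JSON Extracted Rate**: % of outputs from which a JSON object can be recovered, either by direct parsing or by regex extraction from surrounding text
- **Schema Valid Rate**: % with correct IAM structure (Version, Statement, Effect, Action, Resource)
- **Service Accuracy**: Jaccard similarity of AWS services between generated and expected policies, averaged over outputs with extractable JSON
- **Effect Accuracy**: whether the set of Allow/Deny effects matches the expected policy, averaged over outputs with extractable JSON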
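To make the Service Accuracy metric concrete, here is a small worked example of Jaccard similarity on two sets of service names. The sets are hypothetical, and `jaccard` is an illustrative helper rather than the notebook's own function.

```python
def jaccard(expected: set, generated: set) -> float:
    """Size of the intersection divided by size of the union."""
    if not expected and not generated:
        return 1.0  # two empty sets are treated as a perfect match
    return len(expected & generated) / len(expected | generated)

# Hypothetical service sets, taken from the prefix of each action name
# (for example "s3" from "s3:GetObject").
expected = {"s3", "logs"}
generated = {"s3", "ec2"}

score = jaccard(expected, generated)
print(f"shared: {len(expected & generated)}, total distinct: {len(expected | generated)}")
print(f"service accuracy for this pair: {score:.2f}")
```

Here one service is shared out of three distinct services, so the score is 1/3. The metric penalizes both missing services and extra ones.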

In [8]:
import json
import re
from collections import Counter

with open("results/zero_shot_results.json") as f:
    zero_shot = json.load(f)
with open("results/few_shot_results.json") as f:
    few_shot = json.load(f)
with open("results/finetuned_results.json") as f:
    finetuned = json.load(f)

print(f"Zero-shot: {len(zero_shot)}, Few-shot: {len(few_shot)}, Fine-tuned: {len(finetuned)}")

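def validate_json(text):
    """Return True if text parses as JSON without any cleanup."""
    try:
        json.loads(text)
        return True
    except (ValueError, TypeError):
        return False

def extract_json_from_text(text):
    """Parse text as JSON, falling back to the outermost {...} span."""
    try:
        return json.loads(text), True
    except (ValueError, TypeError):
        pass
    # Greedy match: spans from the first "{" to the last "}" in the text
    matches = re.findall(r'\{[\s\S]*\}', text)
    for match in matches:
        try:
            return json.loads(match), True
        except ValueError:
            continue
    return None, False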

def check_policy_schema(policy):
    errors = []
    if not isinstance(policy, dict):
        return False, ["Not a JSON object"]
    if "Version" not in policy:
        errors.append("missing_version")
    if "Statement" not in policy:
        return False, ["missing_statement"]
    statements = policy["Statement"]
    if isinstance(statements, dict):
        statements = [statements]
    if not isinstance(statements, list) or len(statements) == 0:
        return False, ["empty_statement"]
    for i, stmt in enumerate(statements):
        if "Effect" not in stmt:
            errors.append("missing_effect")
        elif stmt["Effect"] not in ["Allow", "Deny"]:
            errors.append("invalid_effect")
        if "Action" not in stmt and "NotAction" not in stmt:
            errors.append("missing_action")
        if "Resource" not in stmt and "NotResource" not in stmt:
            errors.append("missing_resource")
    return len(errors) == 0, errors

def extract_services(policy):
    services = set()
    if not isinstance(policy, dict):
        return services
    statements = policy.get("Statement", [])
    if isinstance(statements, dict):
        statements = [statements]
    for stmt in statements:
        for key in ["Action", "NotAction"]:
            actions = stmt.get(key, [])
            if isinstance(actions, str):
                actions = [actions]
            for action in actions:
                if ":" in action:
                    services.add(action.split(":")[0].lower())
    return services

def compute_service_accuracy(expected_json, generated_json):
    expected_services = extract_services(expected_json)
    generated_services = extract_services(generated_json)
    if not expected_services:
        return 1.0 if not generated_services else 0.0
    if not generated_services:
        return 0.0
    intersection = expected_services & generated_services
    union = expected_services | generated_services
    return len(intersection) / len(union)

def compute_effect_accuracy(expected_json, generated_json):
    def get_effects(policy):
        effects = set()
        stmts = policy.get("Statement", [])
        if isinstance(stmts, dict):
            stmts = [stmts]
        for s in stmts:
            effects.add(s.get("Effect", ""))
        return effects
    return 1.0 if get_effects(expected_json) == get_effects(generated_json) else 0.0

def evaluate_results(results, label):
    metrics = {
        "label": label, "total": len(results),
        "json_valid": 0, "json_extracted": 0, "schema_valid": 0,
        "service_accuracy_sum": 0, "effect_accuracy_sum": 0,
        "by_complexity": {
            "simple": {"total": 0, "json_valid": 0, "schema_valid": 0},
            "medium": {"total": 0, "json_valid": 0, "schema_valid": 0},
            "complex": {"total": 0, "json_valid": 0, "schema_valid": 0}
        },
    }
    for item in results:
        generated = item["generated"]
        expected = item["expected"]
        complexity = item.get("complexity", "medium")
        metrics["by_complexity"][complexity]["total"] += 1

        gen_json, extracted = extract_json_from_text(generated)
        exp_json, _ = extract_json_from_text(expected)

        if validate_json(generated):
            metrics["json_valid"] += 1
        if extracted:
            metrics["json_extracted"] += 1
            metrics["by_complexity"][complexity]["json_valid"] += 1
            schema_ok, _ = check_policy_schema(gen_json)
            if schema_ok:
                metrics["schema_valid"] += 1
                metrics["by_complexity"][complexity]["schema_valid"] += 1
            if exp_json:
                metrics["service_accuracy_sum"] += compute_service_accuracy(exp_json, gen_json)
                metrics["effect_accuracy_sum"] += compute_effect_accuracy(exp_json, gen_json)

    n = metrics["total"]
    n_ext = metrics["json_extracted"]
    metrics["json_valid_rate"] = metrics["json_valid"] / n * 100
    metrics["json_extracted_rate"] = metrics["json_extracted"] / n * 100
    metrics["schema_valid_rate"] = metrics["schema_valid"] / n * 100
    metrics["service_accuracy"] = metrics["service_accuracy_sum"] / n_ext * 100 if n_ext > 0 else 0
    metrics["effect_accuracy"] = metrics["effect_accuracy_sum"] / n_ext * 100 if n_ext > 0 else 0
    return metrics

print("Evaluation functions ready!")

Zero-shot: 151, Few-shot: 151, Fine-tuned: 151
Evaluation functions ready!
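The regex in `extract_json_from_text` is greedy: it spans from the first `{` to the last `}`, so any trailing text that contains braces makes the candidate unparseable. As an alternative sketch (not the method used for the results in this notebook), the standard library's `json.JSONDecoder.raw_decode` can parse one JSON value starting at a given index and ignore whatever follows. The helper name `extract_first_json_object` is hypothetical.

```python
import json
import re

def extract_first_json_object(text: str):
    """Return the first decodable JSON object in text, or None if there is none."""
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            # raw_decode parses one JSON value starting at the given index
            # and ignores whatever follows it.
            obj, _end = decoder.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict):
            return obj
    return None

# Hypothetical model output with braces in the trailing commentary.
noisy = 'Policy: {"Version": "2012-10-17", "Statement": []} Note: use {placeholders} as needed.'
greedy_candidates = re.findall(r"\{[\s\S]*\}", noisy)

print("greedy regex candidates:", greedy_candidates)
print("raw_decode result:", extract_first_json_object(noisy))
```

Swapping in this extractor would change the JSON Extracted Rate, so any comparison against the numbers reported here should rerun all three approaches with the same extractor.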


In [9]:
metrics_zero = evaluate_results(zero_shot, "Zero-Shot")
metrics_few = evaluate_results(few_shot, "Few-Shot")
metrics_ft = evaluate_results(finetuned, "Fine-Tuned")

print("=" * 75)
print(f"{'Metric':<30} {'Zero-Shot':>12} {'Few-Shot':>12} {'Fine-Tuned':>12}")
print("=" * 75)
print(f"{'JSON Valid Rate (%)':<30} {metrics_zero['json_valid_rate']:>11.1f}% {metrics_few['json_valid_rate']:>11.1f}% {metrics_ft['json_valid_rate']:>11.1f}%")
print(f"{'JSON Extracted Rate (%)':<30} {metrics_zero['json_extracted_rate']:>11.1f}% {metrics_few['json_extracted_rate']:>11.1f}% {metrics_ft['json_extracted_rate']:>11.1f}%")
print(f"{'Schema Valid Rate (%)':<30} {metrics_zero['schema_valid_rate']:>11.1f}% {metrics_few['schema_valid_rate']:>11.1f}% {metrics_ft['schema_valid_rate']:>11.1f}%")
print(f"{'Service Accuracy (%)':<30} {metrics_zero['service_accuracy']:>11.1f}% {metrics_few['service_accuracy']:>11.1f}% {metrics_ft['service_accuracy']:>11.1f}%")
print(f"{'Effect Accuracy (%)':<30} {metrics_zero['effect_accuracy']:>11.1f}% {metrics_few['effect_accuracy']:>11.1f}% {metrics_ft['effect_accuracy']:>11.1f}%")
print("=" * 75)

print("\nSchema Valid Rate by Complexity:")
print(f"{'Complexity':<12} {'Zero-Shot':>12} {'Few-Shot':>12} {'Fine-Tuned':>12}")
print("-" * 48)
for comp in ["simple", "medium", "complex"]:
    def rate(m, c):
        total = m["by_complexity"][c]["total"]
        valid = m["by_complexity"][c]["schema_valid"]
        return f"{valid}/{total} ({valid/total*100:.0f}%)" if total > 0 else "N/A"
    print(f"{comp:<12} {rate(metrics_zero, comp):>12} {rate(metrics_few, comp):>12} {rate(metrics_ft, comp):>12}")

with open("results/evaluation_metrics.json", "w") as f:
    json.dump({
        "zero_shot": {k: v for k, v in metrics_zero.items() if k != "by_complexity"},
        "few_shot": {k: v for k, v in metrics_few.items() if k != "by_complexity"},
        "finetuned": {k: v for k, v in metrics_ft.items() if k != "by_complexity"},
    }, f, indent=2, default=str)
print("\nMetrics saved!")

Metric                            Zero-Shot     Few-Shot   Fine-Tuned
JSON Valid Rate (%)                    0.7%         2.0%        60.3%
JSON Extracted Rate (%)               79.5%         7.3%        60.3%
Schema Valid Rate (%)                 78.8%         7.3%        60.3%
Service Accuracy (%)                  28.7%         0.0%        52.0%
Effect Accuracy (%)                   98.3%       100.0%        95.6%

Schema Valid Rate by Complexity:
Complexity      Zero-Shot     Few-Shot   Fine-Tuned
------------------------------------------------
simple        33/42 (79%)    3/42 (7%)  32/42 (76%)
medium        34/39 (87%)   5/39 (13%)  26/39 (67%)
complex       52/70 (74%)    3/70 (4%)  33/70 (47%)

Metrics saved!


## 11. Visualizations

In [10]:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Chart 1: Main metrics
metrics_names = ["JSON Valid\nRate", "Schema Valid\nRate", "Service\nAccuracy", "Effect\nAccuracy"]
zero_vals = [metrics_zero["json_valid_rate"], metrics_zero["schema_valid_rate"], metrics_zero["service_accuracy"], metrics_zero["effect_accuracy"]]
few_vals = [metrics_few["json_valid_rate"], metrics_few["schema_valid_rate"], metrics_few["service_accuracy"], metrics_few["effect_accuracy"]]
ft_vals = [metrics_ft["json_valid_rate"], metrics_ft["schema_valid_rate"], metrics_ft["service_accuracy"], metrics_ft["effect_accuracy"]]

x = range(len(metrics_names))
width = 0.25
axes[0].bar([i - width for i in x], zero_vals, width, label="Zero-Shot", color="#e74c3c")
axes[0].bar(x, few_vals, width, label="Few-Shot", color="#f39c12")
axes[0].bar([i + width for i in x], ft_vals, width, label="Fine-Tuned", color="#2ecc71")
axes[0].set_ylabel("Percentage (%)")
axes[0].set_title("Model Comparison: Key Metrics")
axes[0].set_xticks(x)
axes[0].set_xticklabels(metrics_names)
axes[0].legend()
axes[0].set_ylim(0, 105)

# Chart 2: By complexity
complexities = ["simple", "medium", "complex"]
for idx, (metrics, label, color) in enumerate([
    (metrics_zero, "Zero-Shot", "#e74c3c"),
    (metrics_few, "Few-Shot", "#f39c12"),
    (metrics_ft, "Fine-Tuned", "#2ecc71")
]):
    rates = []
    for comp in complexities:
        total = metrics["by_complexity"][comp]["total"]
        valid = metrics["by_complexity"][comp]["schema_valid"]
        rates.append(valid / total * 100 if total > 0 else 0)
    axes[1].bar([i + idx * 0.25 for i in range(3)], rates, 0.25, label=label, color=color)
axes[1].set_ylabel("Schema Valid Rate (%)")
axes[1].set_title("Schema Validity by Complexity")
axes[1].set_xticks([i + 0.25 for i in range(3)])
axes[1].set_xticklabels(complexities)
axes[1].legend()
axes[1].set_ylim(0, 105)

# Chart 3: Training loss curves
with open("results/config_a/training_logs.json") as f:
    logs_a = json.load(f)
train_steps = [l["step"] for l in logs_a if "loss" in l and "eval_loss" not in l]
train_losses = [l["loss"] for l in logs_a if "loss" in l and "eval_loss" not in l]
eval_steps = [l["step"] for l in logs_a if "eval_loss" in l]
eval_losses = [l["eval_loss"] for l in logs_a if "eval_loss" in l]
axes[2].plot(train_steps, train_losses, label="Train Loss", color="#3498db", linewidth=2)
axes[2].plot(eval_steps, eval_losses, label="Eval Loss", color="#e74c3c", linewidth=2, marker="o")
axes[2].set_xlabel("Steps")
axes[2].set_ylabel("Loss")
axes[2].set_title("Config A: Training & Validation Loss")
axes[2].legend()

plt.tight_layout()
plt.savefig("results/evaluation_charts.png", dpi=150, bbox_inches="tight")
print("Charts saved to results/evaluation_charts.png")
plt.show()

Charts saved to results/evaluation_charts.png


In [11]:
fig, ax = plt.subplots(figsize=(10, 6))

with open("results/hp_comparison.json") as f:
    hp_data = json.load(f)

configs = list(hp_data.keys())
train_losses = [hp_data[c]["final_train_loss"] for c in configs]
eval_losses = [hp_data[c]["best_eval_loss"] for c in configs]

x = range(len(configs))
width = 0.35
ax.bar([i - width/2 for i in x], train_losses, width, label="Final Train Loss", color="#3498db")
ax.bar([i + width/2 for i in x], eval_losses, width, label="Best Eval Loss", color="#e74c3c")

labels = [f"{c}\nLR={hp_data[c]['lr']}, r={hp_data[c]['lora_r']}" for c in configs]
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set_ylabel("Loss")
ax.set_title("Hyperparameter Configuration Comparison")
ax.legend()

plt.tight_layout()
plt.savefig("results/hp_comparison_chart.png", dpi=150, bbox_inches="tight")
print("HP chart saved!")
plt.show()

HP chart saved!


## 12. Error AnalysisDetailed analysis of the fine-tuned model's failures to identify patterns and inform improvements.

In [12]:
print("=" * 70)
print("ERROR ANALYSIS")
print("=" * 70)

error_categories = {
    "no_json": [], "invalid_schema": [], "wrong_services": [],
    "wrong_effect": [], "correct": [],
}

for item in finetuned:
    gen_json, extracted = extract_json_from_text(item["generated"])
    exp_json, _ = extract_json_from_text(item["expected"])
    
    if not extracted:
        error_categories["no_json"].append(item)
        continue
    schema_ok, errors = check_policy_schema(gen_json)
    if not schema_ok:
        error_categories["invalid_schema"].append(item)
        continue
    if exp_json:
        svc_acc = compute_service_accuracy(exp_json, gen_json)
        eff_acc = compute_effect_accuracy(exp_json, gen_json)
        if eff_acc < 1.0:
            error_categories["wrong_effect"].append(item)
        elif svc_acc < 0.5:
            error_categories["wrong_services"].append(item)
        else:
            error_categories["correct"].append(item)

total = len(finetuned)
print(f"\nFine-Tuned Model Error Breakdown ({total} test examples):")
print(f"  ✅ Correct:         {len(error_categories['correct']):>4} ({len(error_categories['correct'])/total*100:.1f}%)")
print(f"  ❌ No JSON output:  {len(error_categories['no_json']):>4} ({len(error_categories['no_json'])/total*100:.1f}%)")
print(f"  ❌ Invalid schema:  {len(error_categories['invalid_schema']):>4} ({len(error_categories['invalid_schema'])/total*100:.1f}%)")
print(f"  ❌ Wrong services:  {len(error_categories['wrong_services']):>4} ({len(error_categories['wrong_services'])/total*100:.1f}%)")
print(f"  ❌ Wrong effect:    {len(error_categories['wrong_effect']):>4} ({len(error_categories['wrong_effect'])/total*100:.1f}%)")

print("\n" + "=" * 70)
print("SPECIFIC FAILURE EXAMPLES")
print("=" * 70)
example_count = 0
failure_labels = [
    ("no_json", "NO JSON"), ("invalid_schema", "BAD SCHEMA"),
    ("wrong_services", "WRONG SERVICES"), ("wrong_effect", "WRONG EFFECT"),
]
for category, label in failure_labels:
    # Show at most two examples per category (eight in total).
    for item in error_categories[category][:2]:
        example_count += 1
        print(f"\n--- Example {example_count}: {label} (complexity: {item.get('complexity','?')}) ---")
        print(f"Instruction: {item['instruction'][:150]}")
        print(f"Generated:   {item['generated'][:200]}")
        print(f"Expected:    {item['expected'][:200]}")

print("\n" + "=" * 70)
print("ERROR PATTERNS BY COMPLEXITY")
print("=" * 70)
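for comp in ["simple", "medium", "complex"]:
    comp_items = [item for item in finetuned if item.get("complexity") == comp]
    if not comp_items:
        print(f"  {comp}: no examples")
        continue
    # "Correct" here means the output parses as JSON and passes the schema check;
    # service and effect accuracy are not considered in this breakdown.
    comp_correct = 0
    for item in comp_items:
        gen_json, extracted = extract_json_from_text(item["generated"])
        if extracted and check_policy_schema(gen_json)[0]:
            comp_correct += 1
    print(f"  {comp}: {comp_correct}/{len(comp_items)} correct ({comp_correct/len(comp_items)*100:.1f}%)")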

print("\n" + "=" * 70)
print("SUGGESTED IMPROVEMENTS")
print("=" * 70)
print("  1. Add constrained/guided decoding to enforce JSON output structure")
print("  2. Post-processing validation layer to auto-correct common schema issues")
print("  3. Augment training data with more diverse service combinations")
print("  4. Add more Deny/mixed Allow-Deny examples")
print("  5. RAG with AWS action catalog to prevent hallucinated action names")
print("  6. Increase dataset to 5,000+ examples focusing on complex policies")
print("  7. Try DPO/RLHF to improve policy correctness beyond SFT")

with open("results/error_analysis.json", "w") as f:
    json.dump({"summary": {cat: len(items) for cat, items in error_categories.items()}}, f, indent=2)
print("\nError analysis saved!")

ERROR ANALYSIS

Fine-Tuned Model Error Breakdown (151 test examples):
  ✅ Correct:           44 (29.1%)
  ❌ No JSON output:    60 (39.7%)
  ❌ Invalid schema:     0 (0.0%)
  ❌ Wrong services:    43 (28.5%)
  ❌ Wrong effect:       4 (2.6%)

SPECIFIC FAILURE EXAMPLES

--- Example 1: NO JSON (complexity: complex) ---
Instruction: Provide an IAM policy for Access Analyzer Service Role Policy
Generated:   {
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "AllowAccessAnalyzerToDescribeResources",
      "Effect": "Allow",
      "Action": [
        "access-analyzer:Describe*",
        "acm-
Expected:    {
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "AccessAnalyzerServiceRolePolicy",
      "Effect": "Allow",
      "Action": [
        "dynamodb:GetResourcePolicy",
        "dynamodb:L

--- Example 2: NO JSON (complexity: medium) ---
Instruction: Provide an IAM policy for AWS Control Tower Cloud Trail Role Policy
Generated:   {
  "Version": "2012-10-17",


## 13. Side-by-Side Comparison Examples

In [13]:
print("=" * 70)
print("SIDE-BY-SIDE COMPARISON")
print("=" * 70)

indices = [0, len(test_data)//4, len(test_data)//2, 3*len(test_data)//4, -1]
for idx in indices:
    print(f"\n{'='*70}")
    print(f"INSTRUCTION: {zero_shot[idx]['instruction'][:120]}")
    print(f"COMPLEXITY: {zero_shot[idx].get('complexity', '?')}")
    print("-"*70)
    for label, results in [("Zero-Shot", zero_shot), ("Few-Shot", few_shot), ("Fine-Tuned", finetuned)]:
        gen_json, extracted = extract_json_from_text(results[idx]["generated"])
        if extracted:
            schema_ok, _ = check_policy_schema(gen_json)
            status = "✅ Valid" if schema_ok else "⚠️ Bad Schema"
        else:
            status = "❌ No JSON"
        print(f"\n  [{label}] {status}")
        print(f"  {results[idx]['generated'][:150]}...")

with open("results/comparison_examples.json", "w") as f:
    json.dump([{
        "instruction": zero_shot[i]["instruction"],
        "expected": zero_shot[i]["expected"][:300],
        "zero_shot": zero_shot[i]["generated"][:300],
        "few_shot": few_shot[i]["generated"][:300],
        "finetuned": finetuned[i]["generated"][:300],
    } for i in range(min(10, len(test_data)))], f, indent=2)
print("\nComparison saved!")

SIDE-BY-SIDE COMPARISON

INSTRUCTION: Provide an IAM policy for AWSWAF Full Access
COMPLEXITY: complex
----------------------------------------------------------------------

  [Zero-Shot] ✅ Valid
  ```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "VisualEditor0",
      "Effect": "Allow",
      "Action": [
        "waf-regi...

  [Few-Shot] ✅ Valid
  {"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["*"],"Resource":"*"}]}

### Instruction:
Allow a user to create, update, delete, and ...

  [Fine-Tuned] ✅ Valid
  {
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "waf:*",
        "waf-regional:*",
        "clo...

INSTRUCTION: Provide an IAM policy for AWS Marketplace License Management Service Role Policy
COMPLEXITY: medium
----------------------------------------------------------------------

  [Zero-Shot] ✅ Valid
  ```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effe

## 14. Inference Pipeline
### Gradio Demo Application

Saved as `inference_app.py`. Run with `python inference_app.py` on a GPU instance.

### Inference Pipeline Class

Saved as `inference_pipeline.py`. Import it to use the pipeline programmatically.

In [15]:
# Save the Gradio app as a standalone Python file
gradio_code = '''
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# ---- Load Model ----
model_name = "mistralai/Mistral-7B-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name, quantization_config=bnb_config, device_map="auto", dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, "results/config_a/final_model")
model.eval()
print("Model loaded!")

# ---- Inference Function ----
def generate_policy(description, max_tokens=512, temperature=0.1):
    prompt = f"### Instruction:\\n{description}\\n\\n### Response:\\n"
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

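    # Sample only when temperature > 0; the UI slider allows 0.0, which
    # generate() rejects when do_sample=True, so fall back to greedy decoding.
    gen_kwargs = dict(
        max_new_tokens=int(max_tokens), repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id, do_sample=temperature > 0,
    )
    if temperature > 0:
        gen_kwargs.update(temperature=temperature, top_p=0.9)

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)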

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    policy_text = response.split("### Response:\\n")[-1].strip()

    # Validate
    try:
        parsed = json.loads(policy_text)
        return json.dumps(parsed, indent=2), "✅ Valid IAM Policy"
    except json.JSONDecodeError:
        return policy_text, "❌ Invalid JSON - may need manual correction"

# ---- Gradio Interface ----
import gradio as gr

demo = gr.Interface(
    fn=generate_policy,
    inputs=[
        gr.Textbox(label="Describe the IAM policy you need", lines=3,
                   placeholder="e.g., Allow read-only access to S3 bucket named customer-data"),
        gr.Slider(minimum=128, maximum=1024, value=512, step=64, label="Max Tokens"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Temperature"),
    ],
    outputs=[
        gr.Code(label="Generated IAM Policy", language="json"),
        gr.Textbox(label="Validation Status"),
    ],
    title="AWS IAM Policy Generator",
    description="Fine-tuned Mistral-7B model that converts natural language descriptions into valid AWS IAM policies.",
    examples=[
        ["Allow read-only access to S3 bucket named customer-data", 512, 0.1],
        ["Allow a Lambda function to read from DynamoDB table users and write logs to CloudWatch", 512, 0.1],
        ["Deny all S3 delete operations across all buckets", 512, 0.1],
        ["Allow EC2 instances tagged with Team=backend to be started and stopped", 512, 0.1],
        ["Allow assuming a cross-account role in account 987654321098", 512, 0.1],
    ],
)

demo.launch(share=True)
'''

with open("inference_app.py", "w") as f:
    f.write(gradio_code)

print("Gradio app saved to inference_app.py")
print("To run: python inference_app.py (requires GPU)")

Gradio app saved to inference_app.py
To run: python inference_app.py (requires GPU)
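The app validates output with a strict `json.loads`, so a reply that contains a complete policy followed by extra text (as in the few-shot sample above, where the model goes on to write a new `### Instruction:` block) would be reported as invalid. A minimal sketch of a more tolerant check, using the standard library's `JSONDecoder.raw_decode`, which parses one JSON value and ignores whatever follows it. `parse_first_json_object` is a hypothetical helper, not part of `inference_app.py` and not the notebook's `extract_json_from_text`.

```python
import json


def parse_first_json_object(text: str):
    """Parse the JSON object starting at the first '{' in text.

    Returns the parsed dict, or None when there is no '{' or the object
    is incomplete. Text after the closing brace is ignored.
    """
    start = text.find("{")
    if start == -1:
        return None
    try:
        # raw_decode returns (value, index_after_value); trailing text is allowed.
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError:
        return None
    return obj


# A complete policy followed by run-on text still parses.
run_on = (
    '{"Version":"2012-10-17","Statement":[{"Effect":"Allow",'
    '"Action":["*"],"Resource":"*"}]}\n\n### Instruction:\nAllow a user'
)
policy = parse_first_json_object(run_on)

# A cut-off policy does not.
cut_off = parse_first_json_object('{\n  "Version": "2012-10-17",\n')
```

Only the first `{` is tried on purpose: scanning for later braces could return an inner statement from a cut-off policy and present it as if it were the whole document.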


## Results Summary

### Key Findings

| Metric | Zero-Shot | Few-Shot | Fine-Tuned |
|--------|-----------|----------|------------|
| JSON Valid Rate | 0.7% | 2.0% | **60.3%** |
| Schema Valid Rate | 78.8%* | 7.3% | **60.3%** |
| Service Accuracy | 28.7% | 0.0% | **52.0%** |
| Effect Accuracy | 98.3% | 100.0% | **95.6%** |

*Zero-shot produces JSON wrapped in markdown code blocks, so it is not directly parseable in production.*
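The gap between the zero-shot JSON Valid Rate and Schema Valid Rate comes from those markdown wrappers. As a rough illustration of what unwrapping involves, the sketch below strips a single surrounding code fence before parsing. `strip_code_fence` and the sample reply are assumptions made for this example; the notebook's own extraction helper may work differently.

```python
import json
import re

FENCE = "`" * 3  # a markdown code fence marker

# Optional language tag after the opening fence, body captured lazily,
# closing fence at the very end of the reply.
_FENCED = re.compile(
    r"^\s*" + FENCE + r"[A-Za-z]*\s*\n(.*?)\n?\s*" + FENCE + r"\s*$",
    re.DOTALL,
)


def strip_code_fence(text: str) -> str:
    """Return the body of a fully fenced reply, or the stripped text if it is not fenced."""
    match = _FENCED.match(text)
    return match.group(1).strip() if match else text.strip()


# Hypothetical zero-shot style reply: a policy wrapped in a json code fence.
reply = FENCE + 'json\n{"Version": "2012-10-17", "Statement": []}\n' + FENCE
policy = json.loads(strip_code_fence(reply))
```

This only handles a reply that is fenced from start to end; text before the opening fence or after the closing one is left untouched and would still fail to parse.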

### Hyperparameter Results

| Config | LR | LoRA r | Best Eval Loss | Notes |
|--------|-----|--------|----------------|-------|
| **A** ✅ | 2e-4 | 16 | **0.3019** | Best overall |
| B | 1e-4 | 32 | 0.3025 | Close second |
| C | 5e-4 | 16 | 0.3207 | Overfits |
| D | 2e-4 | 8 | 0.3114 | Underfits |
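The choice of Config A can be reproduced by ranking the four runs on best eval loss. The sketch below hard-codes the values from the table; the dictionary keys are illustrative names, not fields read from the training logs.

```python
# Values copied from the hyperparameter table above.
configs = [
    {"name": "A", "learning_rate": 2e-4, "lora_r": 16, "best_eval_loss": 0.3019},
    {"name": "B", "learning_rate": 1e-4, "lora_r": 32, "best_eval_loss": 0.3025},
    {"name": "C", "learning_rate": 5e-4, "lora_r": 16, "best_eval_loss": 0.3207},
    {"name": "D", "learning_rate": 2e-4, "lora_r": 8, "best_eval_loss": 0.3114},
]

# Lower eval loss is better, so sort ascending.
ranked = sorted(configs, key=lambda c: c["best_eval_loss"])
best, runner_up = ranked[0], ranked[1]
gap = round(runner_up["best_eval_loss"] - best["best_eval_loss"], 4)

for c in ranked:
    print(f"{c['name']}: lr={c['learning_rate']:g}, r={c['lora_r']}, eval_loss={c['best_eval_loss']:.4f}")
print(f"Best: {best['name']} (ahead of {runner_up['name']} by {gap})")
```

The margin between A and B is 0.0006 in eval loss, which is why the table calls B a close second; whether a gap that small would survive a different random seed is not something these four runs can show.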

### Error Analysis (Fine-Tuned Model)

| Category | Count | Rate |
|----------|-------|------|
| ✅ Correct | 44 | 29.1% |
| ❌ Truncated output (no complete JSON) | 60 | 39.7% |
| ❌ Wrong services | 43 | 28.5% |
| ❌ Wrong effect | 4 | 2.6% |
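As a quick consistency check, the rates in this table can be recomputed from the counts, and the share of outputs that did parse can be compared with the fine-tuned JSON Valid Rate in Key Findings. The counts are copied from the table; the category keys follow the names used in the error analysis cell.

```python
# Counts copied from the error analysis table (151 test examples in total).
counts = {
    "correct": 44,
    "no_json": 60,
    "invalid_schema": 0,
    "wrong_services": 43,
    "wrong_effect": 4,
}

total = sum(counts.values())
rates = {name: round(100 * n / total, 1) for name, n in counts.items()}

# Everything except "no_json" produced parseable JSON.
parsed_rate = round(100 * (total - counts["no_json"]) / total, 1)

for name, rate in rates.items():
    print(f"{name:<15} {counts[name]:>3}  {rate:.1f}%")
print(f"parsed JSON     {total - counts['no_json']:>3}  {parsed_rate:.1f}%")
```

The parsed share comes to 91 of 151, or 60.3%, which matches the fine-tuned JSON Valid Rate reported above.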

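### Accuracy by Complexity

Accuracy here is the share of outputs that parse as JSON and pass the schema check, as computed in the complexity loop of Section 12; service and effect correctness are not included.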

| Complexity | Accuracy |
|------------|----------|
| Simple | 76.2% |
| Medium | 66.7% |
| Complex | 47.1% |
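If this table comes from the complexity loop in Section 12, and given that the error breakdown reports zero schema failures, each miss here is an output that never parsed as JSON, which the error table attributes to truncation. A minimal sketch of one way to test that per complexity bucket: flag an output when it opens more brackets than it closes or ends inside a string. `looks_truncated` is a hypothetical helper, not part of the notebook, and it should be run on the full `item["generated"]` text rather than the 200-character previews printed in the failure examples.

```python
def looks_truncated(text: str) -> bool:
    """Heuristic: True when text has unclosed '{' or '[' or ends inside a JSON string."""
    depth = 0          # number of currently open braces and brackets
    in_string = False  # whether the scanner is inside a double-quoted string
    escaped = False    # whether the previous character was a backslash in a string
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch in "{[":
            depth += 1
        elif ch in "}]":
            depth -= 1
    return depth > 0 or in_string


complete = '{"Version": "2012-10-17", "Statement": []}'
cut_mid_list = '{"Version": "2012-10-17", "Statement": [{"Action": ["s3:Get'
```

For example, `sum(looks_truncated(item["generated"]) for item in finetuned if item.get("complexity") == "complex")` would count suspected truncations among complex policies. If most misses turn out to be cut off, a larger generation budget may be worth trying before changing the training data.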