# Robust Vision: Quick Start

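This notebook trains a small CNN on CIFAR-10, evaluates its robustness to input noise, and saves a checkpoint. With the demo settings used here (5 epochs, a 10-batch evaluation subset) it is meant to run in under 5 minutes.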

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/or4k2l/Truth-Seeking-Pattern-Matching/blob/main/notebooks/quickstart.ipynb)

## 1. Installation

Install the package (skip if already installed):

In [None]:
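# On Colab: uncomment the next line to install from GitHub
# !pip install git+https://github.com/or4k2l/Truth-Seeking-Pattern-Matching.git

# For local development: make ../src importable
# (assumes the notebook is run from the notebooks/ directory)
import sys
sys.path.insert(0, '../src')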

## 2. Import Libraries

In [None]:
import jax
import jax.numpy as jnp
from robust_vision.data.loaders import ScalableDataLoader
from robust_vision.data.noise import NoiseLibrary
from robust_vision.models.cnn import ProductionCNN
from robust_vision.training.trainer import ProductionTrainer
from robust_vision.evaluation.robustness import RobustnessEvaluator
from robust_vision.evaluation.visualization import plot_robustness_curves

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")

## 3. Load Dataset

We'll use CIFAR-10 for this quick demo:

In [None]:
# Create data loader
data_loader = ScalableDataLoader(
    dataset_name="cifar10",
    batch_size=128,
    image_size=(32, 32),
    cache=True,
    prefetch=True
)

# Load train and test datasets
train_ds = data_loader.get_train_loader()
test_ds = data_loader.get_test_loader()

print("Datasets loaded!")

## 4. Train Model

Train a small model for 5 epochs, which is enough for a quick demo but not for best accuracy:

In [None]:
# Initialize model
model = ProductionCNN(
    n_classes=10,
    features=[64, 128],  # Smaller for speed
    dropout_rate=0.3
)

# Initialize trainer
trainer = ProductionTrainer(
    model=model,
    num_classes=10,
    learning_rate=1e-3,
    weight_decay=1e-4,
    ema_decay=0.99,
    loss_type="label_smoothing",
    loss_kwargs={"smoothing": 0.1}
)

# Train
rng = jax.random.PRNGKey(42)
state = trainer.train(
    rng=rng,
    train_ds=train_ds,
    eval_ds=test_ds,
    num_epochs=5,  # Quick demo
    input_shape=(1, 32, 32, 3),
    eval_every=1
)

print("\nTraining completed!")
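The trainer above is configured with `loss_type="label_smoothing"` and `ema_decay=0.99`. As a rough illustration of what those two options usually mean, here is a minimal JAX sketch. The function names `label_smoothing_loss` and `ema_update` are hypothetical and are not taken from `robust_vision`; the package's own implementation may differ in detail.

```python
import jax
import jax.numpy as jnp


def label_smoothing_loss(logits, labels, num_classes, smoothing=0.1):
    """Cross-entropy against targets softened toward the uniform distribution."""
    one_hot = jax.nn.one_hot(labels, num_classes)
    # Each target keeps (1 - smoothing) on the true class and spreads the
    # remaining mass evenly over all classes.
    targets = one_hot * (1.0 - smoothing) + smoothing / num_classes
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


def ema_update(ema_params, params, decay=0.99):
    """One EMA step: ema <- decay * ema + (1 - decay) * params, leaf by leaf."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


# Toy inputs (made up for illustration, not real training data).
logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = jnp.array([0, 2])
loss = label_smoothing_loss(logits, labels, num_classes=3)

params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
ema = {"w": jnp.zeros((2, 2)), "b": jnp.zeros(2)}
ema = ema_update(ema, params)  # each leaf moves 1% of the way toward params
```

With a decay of 0.99 the averaged weights change slowly, which is why an EMA copy of the parameters is often preferred at evaluation time.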

## 5. Evaluate Robustness

Test the model against three noise types (Gaussian, salt-and-pepper, fog) at four severity levels:

In [None]:
# Use EMA parameters for evaluation
eval_params = state.ema_params if state.ema_params is not None else state.params

# Initialize robustness evaluator
evaluator = RobustnessEvaluator(
    model_apply_fn=model.apply,
    params=eval_params,
    num_classes=10,
    noise_types=['gaussian', 'salt_pepper', 'fog'],
    severities=[0.0, 0.1, 0.2, 0.3],
    rng_key=jax.random.PRNGKey(123)
)

# Evaluate (on subset for speed)
results = evaluator.evaluate_dataset(
    dataset=test_ds,
    max_batches=10  # Quick evaluation
)

# Print summary
evaluator.print_summary(results)
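To make the noise types and severities above more concrete, here is a minimal JAX sketch of two of them. It assumes images scaled to [0, 1], that Gaussian severity is the noise standard deviation, and that salt-and-pepper severity is the fraction of corrupted values. These are assumptions for illustration; the definitions in `NoiseLibrary` may differ, and fog is left out.

```python
import jax
import jax.numpy as jnp


def gaussian_noise(key, images, severity):
    """Add zero-mean Gaussian noise with std = severity, then clip to [0, 1]."""
    noise = severity * jax.random.normal(key, images.shape)
    return jnp.clip(images + noise, 0.0, 1.0)


def salt_pepper_noise(key, images, severity):
    """Set roughly a `severity` fraction of values to 0 or 1 with equal odds."""
    key_mask, key_value = jax.random.split(key)
    corrupt = jax.random.bernoulli(key_mask, severity, images.shape)
    salt = jax.random.bernoulli(key_value, 0.5, images.shape)
    return jnp.where(corrupt, salt.astype(images.dtype), images)


key = jax.random.PRNGKey(0)
images = jnp.full((4, 32, 32, 3), 0.5)  # placeholder batch, not CIFAR-10 data

for severity in [0.0, 0.1, 0.2, 0.3]:
    noisy = salt_pepper_noise(key, images, severity)
    changed = jnp.mean(noisy != images)
    print(f"severity={severity:.1f}  fraction changed={float(changed):.3f}")
```

Under these assumptions a severity of 0.0 leaves the images untouched, so it serves as the clean baseline for the other levels.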

## 6. Visualize Results

Plot robustness curves of accuracy for each noise type:

In [None]:
import matplotlib.pyplot as plt

# Plot robustness curves
fig = plot_robustness_curves(results, metric='accuracy')
plt.show()

print("\nDone! Your model is trained and evaluated.")

## 7. Save Model

Save the trained model:

In [None]:
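import os
from flax.training import checkpoints

# Save checkpoint. Recent Flax versions (Orbax backend) may reject relative paths.
ckpt_dir = os.path.abspath('./checkpoints')
checkpoints.save_checkpoint(
    ckpt_dir=ckpt_dir,
    target=state,
    step=5,  # matches num_epochs above
    prefix='quickstart_',
    overwrite=True  # allow re-running this cell
)

print(f"Model saved to {ckpt_dir}")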

## Next Steps

- Train for more epochs for better performance
- Try different loss functions (margin, focal, combined)
- Evaluate on the full test set instead of the 10-batch subset used in this demo
- Experiment with different architectures
- Check out the [documentation](../docs/) for more details
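
For the loss-function suggestion in the list above, the following is a minimal JAX sketch of a standard multi-class focal loss. The function name `focal_loss` and the default `gamma=2.0` are assumptions for illustration; the "focal" option in `robust_vision` may use different defaults or extra weighting terms.

```python
import jax
import jax.numpy as jnp


def focal_loss(logits, labels, gamma=2.0):
    """Multi-class focal loss: mean of -(1 - p_t)^gamma * log(p_t)."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-probability the model assigns to the true class of each example.
    log_pt = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    pt = jnp.exp(log_pt)
    return jnp.mean(-((1.0 - pt) ** gamma) * log_pt)


# Toy logits (made up): the first example is easy, the second is hard.
logits = jnp.array([[4.0, 0.0, 0.0], [0.2, 0.1, 0.0]])
labels = jnp.array([0, 0])

print("cross-entropy (gamma=0):", float(focal_loss(logits, labels, gamma=0.0)))
print("focal (gamma=2):        ", float(focal_loss(logits, labels, gamma=2.0)))
```

With `gamma=0` the expression reduces to plain cross-entropy. Larger values of `gamma` shrink the contribution of examples the model already classifies confidently.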