# Introduction to Triton DSL

**A Domain-Specific Language for Ternary Neural Networks**

---

## Table of Contents

1. [Overview](#overview)
2. [Why Ternary Neural Networks?](#why-ternary)
3. [Installation & Setup](#installation)
4. [Quick Start Tutorial](#quick-start)
5. [MNIST Example](#mnist-example)
6. [Interactive Visualizations](#visualizations)
7. [Next Steps](#next-steps)

## 1. Overview {#overview}

Triton is a high-performance Domain-Specific Language (DSL) designed to optimize **Ternary Neural Networks (TNNs)** by enforcing ternary constraints at the syntax level. Because every weight is restricted to {-1, 0, 1}, it can be stored in 2 bits rather than the 32 bits of a float32 value, a theoretical 16x improvement in memory density.

### Key Features

- 🎯 **Native Ternary Type System**: Built-in `trit` primitive and `TernaryTensor` data structures
- ⚡ **Zero-Cost Abstractions**: Compile-time type checking with runtime efficiency
- 🚀 **Hardware Optimization**: 2-bit packed storage, CUDA kernels, zero-skipping
- 🔗 **PyTorch Integration**: Seamless transpilation to PyTorch modules

### Architecture Pipeline

```
Triton Source (.tri)
    ↓
Lexer/Parser → AST
    ↓
Type Checker
    ↓
Code Generator → PyTorch/CUDA
```

## 2. Why Ternary Neural Networks? {#why-ternary}

Ternary Neural Networks constrain weights to **{-1, 0, 1}**, providing:

### Memory Savings

- **Float32**: 32 bits per weight
- **Ternary**: 2 bits per weight  
- **Compression**: 16x theoretical, 4-6x practical
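
To make the 2-bits-per-weight figure concrete, here is a minimal NumPy sketch that packs four ternary weights into each byte and unpacks them again. The function names `pack_ternary` and `unpack_ternary` and the code assignment (-1, 0, 1 stored as 0, 1, 2) are assumptions made for illustration; Triton's actual packed layout may differ.

```python
import numpy as np

def pack_ternary(weights: np.ndarray) -> np.ndarray:
    """Pack a 1-D array of {-1, 0, 1} values into bytes, four weights per byte."""
    codes = (weights.astype(np.int8) + 1).astype(np.uint8)   # -1, 0, 1 -> 0, 1, 2
    pad = (-len(codes)) % 4                                   # pad to a multiple of 4
    codes = np.concatenate([codes, np.ones(pad, dtype=np.uint8)])  # code 1 = weight 0
    groups = codes.reshape(-1, 4)
    # Each 2-bit code gets its own bit offset inside the byte.
    packed = (groups[:, 0]
              | (groups[:, 1] << 2)
              | (groups[:, 2] << 4)
              | (groups[:, 3] << 6))
    return packed.astype(np.uint8)

def unpack_ternary(packed: np.ndarray, n: int) -> np.ndarray:
    """Recover the first n ternary weights from a packed byte array."""
    shifts = np.array([0, 2, 4, 6], dtype=np.uint8)
    codes = (packed[:, None] >> shifts) & 0b11
    return codes.reshape(-1)[:n].astype(np.int8) - 1

rng = np.random.default_rng(0)
w = rng.integers(-1, 2, size=1000).astype(np.int8)

packed = pack_ternary(w)
restored = unpack_ternary(packed, len(w))
ratio = w.astype(np.float32).nbytes / packed.nbytes

print(f"round trip ok: {np.array_equal(restored, w)}")
print(f"float32 bytes / packed bytes: {ratio:.1f}")
```

The ratio here covers raw weight storage only; anything kept in floating point alongside the weights lowers the overall figure.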

### Speed Improvements

- Zero-skipping: Skip computations for zero weights
- Simpler operations: Multiplication becomes addition/subtraction
- **2-3x faster inference** on optimized hardware
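
The first two points can be illustrated with a short NumPy sketch: a matrix-vector product with ternary weights needs no multiplications, and zero weights contribute nothing. The helper name `ternary_matvec` is hypothetical, and the Python loop is written for readability, so it demonstrates the arithmetic rather than the speedup.

```python
import numpy as np

def ternary_matvec(W: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Compute W @ x for W in {-1, 0, 1} using only additions and subtractions."""
    out = np.zeros(W.shape[0], dtype=x.dtype)
    for i, row in enumerate(W):
        # +1 weights add the input, -1 weights subtract it, 0 weights are skipped.
        out[i] = x[row == 1].sum() - x[row == -1].sum()
    return out

rng = np.random.default_rng(0)
W = rng.integers(-1, 2, size=(4, 8))   # random ternary weight matrix
x = rng.standard_normal(8)

result = ternary_matvec(W, x)
print("matches dense matmul:", np.allclose(result, W @ x))
```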

### Accuracy Trade-offs

| Model | Float32 | Ternary | Accuracy Drop |
|-------|---------|---------|---------------|
| MNIST | 98.5% | 96-97% | 1-2% |
| CIFAR-10 | 92% | 88-90% | 2-4% |
| ImageNet (ResNet-18) | 70% | 65-67% | 3-5% |

## 3. Installation & Setup {#installation}

Let's set up the environment and import necessary libraries.

In [None]:
# Check if we're in the correct directory
import os
import sys
from pathlib import Path

# Add Triton root to path
triton_root = Path.cwd().parent.parent
if str(triton_root) not in sys.path:
    sys.path.insert(0, str(triton_root))

print(f"Triton Root: {triton_root}")
print(f"Python Version: {sys.version}")

In [None]:
# Install required packages (if needed)
# !pip install torch torchvision numpy matplotlib seaborn scikit-learn tqdm

# Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display, Markdown, HTML

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 4. Quick Start Tutorial {#quick-start}

Let's start with a simple example of defining a ternary layer in Triton DSL.

### 4.1 Triton DSL Syntax

The definition below declares a layer with a ternary weight matrix `W` and a `forward` function that applies it:

```triton
layer SimpleTernary(in_features: int, out_features: int) -> TernaryTensor {
    let W: TernaryTensor = random_ternary([out_features, in_features])
    
    fn forward(x: Tensor[float16]) -> Tensor[float16] {
        let output = ternary_matmul(x, W)
        return output
    }
}
```

**Key Points:**
- `TernaryTensor`: Type-safe tensor constrained to {-1, 0, 1}
- `random_ternary()`: Initializes weights to random ternary values
- `ternary_matmul()`: Optimized matrix multiplication for ternary weights

### 4.2 PyTorch Backend Implementation

Let's implement the core ternary quantization primitive in PyTorch:

In [None]:
class TernaryQuantize(torch.autograd.Function):
    """
    Ternary quantization with Straight-Through Estimator (STE).
    
    Forward: Quantize weights to {-1, 0, 1}
    Backward: Pass gradients straight through
    """
    
    @staticmethod
    def forward(ctx, input, threshold=0.5):
        """Quantize to ternary values {-1, 0, 1}"""
        output = torch.sign(input)
        output[torch.abs(input) < threshold] = 0
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: pass gradients unchanged"""
        return grad_output, None


class LinearTernary(nn.Module):
    """Ternary Linear Layer - weights constrained to {-1, 0, 1}"""
    
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Initialize latent float weights at unit scale so that a useful share
        # exceeds the default 0.5 threshold (at std 0.1 nearly all quantize to 0)
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))
    
    def forward(self, x):
        # Quantize weights during forward pass
        weight_ternary = TernaryQuantize.apply(self.weight)
        return F.linear(x, weight_ternary, self.bias)
    
    def extra_repr(self):
        return f'in_features={self.in_features}, out_features={self.out_features}'


# Test the layer
layer = LinearTernary(10, 5)
x = torch.randn(2, 10)
output = layer(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\nWeight values (before quantization): {layer.weight[0, :5].detach().numpy()}")
print(f"Quantized weights: {TernaryQuantize.apply(layer.weight)[0, :5].detach().numpy()}")
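
To see what the straight-through estimator does in practice, the sketch below checks the gradient that reaches the latent float weights. `SignWithThreshold` is a hypothetical stand-in that mirrors the quantizer above so the snippet runs on its own; it is not part of Triton.

```python
import torch

class SignWithThreshold(torch.autograd.Function):
    """Illustrative ternary quantizer with a straight-through backward pass."""

    @staticmethod
    def forward(ctx, w, threshold):
        # Values inside (-threshold, threshold) become 0, the rest keep their sign.
        return torch.where(w.abs() < threshold, torch.zeros_like(w), torch.sign(w))

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the quantizer as the identity for gradients.
        return grad_output, None

w = torch.tensor([-0.9, -0.2, 0.1, 0.7], requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])

q = SignWithThreshold.apply(w, 0.5)
loss = (q * x).sum()
loss.backward()

print("quantized:", q.detach())
print("grad on latent weights:", w.grad)
```

The gradient equals `x` for every entry, including the two weights that were quantized to zero. That is what lets small latent weights keep moving during training until they cross the threshold.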

### 4.3 Visualizing Quantization

Let's visualize how weights are quantized to ternary values:

In [None]:
# Generate random weights
weights = torch.randn(1000) * 2
quantized = TernaryQuantize.apply(weights)

# Plot
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original distribution
axes[0].hist(weights.numpy(), bins=50, alpha=0.7, color='blue', edgecolor='black')
axes[0].axvline(-0.5, color='red', linestyle='--', label='Threshold')
axes[0].axvline(0.5, color='red', linestyle='--')
axes[0].set_title('Original Weight Distribution')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Frequency')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Quantized distribution
axes[1].hist(quantized.numpy(), bins=[-1.5, -0.5, 0.5, 1.5], alpha=0.7, 
             color='green', edgecolor='black')
axes[1].set_title('Quantized Weight Distribution')
axes[1].set_xlabel('Weight Value')
axes[1].set_ylabel('Frequency')
axes[1].set_xticks([-1, 0, 1])
axes[1].grid(True, alpha=0.3)

# Scatter plot showing quantization mapping
sample_idx = torch.arange(100)
axes[2].scatter(weights[sample_idx].numpy(), quantized[sample_idx].numpy(), 
               alpha=0.5, s=30)
axes[2].axhline(0, color='gray', linestyle='-', linewidth=0.5)
axes[2].axvline(-0.5, color='red', linestyle='--', alpha=0.5)
axes[2].axvline(0.5, color='red', linestyle='--', alpha=0.5)
axes[2].set_title('Quantization Mapping')
axes[2].set_xlabel('Original Weight')
axes[2].set_ylabel('Quantized Weight')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Statistics
print(f"\nQuantization Statistics:")
print(f"  -1 values: {(quantized == -1).sum().item()} ({(quantized == -1).sum().item()/len(quantized)*100:.1f}%)")
print(f"   0 values: {(quantized == 0).sum().item()} ({(quantized == 0).sum().item()/len(quantized)*100:.1f}%)")
print(f"  +1 values: {(quantized == 1).sum().item()} ({(quantized == 1).sum().item()/len(quantized)*100:.1f}%)")
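
The share of zeros printed above can also be predicted. For weights drawn from a zero-mean normal distribution with standard deviation `std` (here 2.0, from `torch.randn(1000) * 2`), the probability of landing inside the threshold band is `erf(threshold / (std * sqrt(2)))`. The helper below, whose name is chosen only for this example, evaluates that expression for a few thresholds.

```python
import math

def expected_zero_fraction(std: float, threshold: float) -> float:
    """P(|w| < threshold) for w drawn from Normal(0, std**2)."""
    return math.erf(threshold / (std * math.sqrt(2.0)))

for t in (0.25, 0.5, 1.0, 2.0):
    share = expected_zero_fraction(2.0, t) * 100
    print(f"threshold={t:<4}  expected zeros: {share:.1f}%")
```

With `std=2.0` and `threshold=0.5` this gives roughly 20% zeros, which the empirical count from the cell above should approximate. The same formula shows how strongly sparsity depends on the ratio of threshold to weight scale.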

## 5. MNIST Example {#mnist-example}

Let's build a complete ternary neural network for MNIST digit classification.

### 5.1 Define the Ternary Network

In [None]:
class TernaryNet(nn.Module):
    """Simple Ternary Neural Network for MNIST"""
    
    def __init__(self):
        super().__init__()
        self.fc1 = LinearTernary(784, 256)
        self.fc2 = LinearTernary(256, 128)
        self.fc3 = LinearTernary(128, 10)
    
    def forward(self, x):
        x = x.view(-1, 784)  # Flatten
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Create model
model = TernaryNet().to(device)
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
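
The parameter count can be verified by hand from the layer shapes. This small sketch repeats the arithmetic without building the model; `layer_shapes` is simply the three `(in_features, out_features)` pairs used by `TernaryNet`.

```python
# (in_features, out_features) for fc1, fc2, fc3
layer_shapes = [(784, 256), (256, 128), (128, 10)]

weights = sum(i * o for i, o in layer_shapes)   # one weight per input/output pair
biases = sum(o for _, o in layer_shapes)        # one bias per output unit
total = weights + biases

print(f"weights: {weights:,}  biases: {biases:,}  total: {total:,}")
# weights: 234,752  biases: 394  total: 235,146
```

Only the weight matrices are quantized in `LinearTernary`; the 394 bias values remain ordinary float parameters.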

### 5.2 Load MNIST Dataset

In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Data transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

# Download and load datasets
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
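
As a quick check on the transform, `ToTensor` scales pixels to the range [0, 1] and `Normalize` then applies `(pixel - mean) / std`. The sketch below reproduces that formula in plain Python to show the resulting input range; the `normalize` helper is written for this example only.

```python
MEAN, STD = 0.1307, 0.3081  # the values passed to transforms.Normalize above

def normalize(pixel: float) -> float:
    """Apply the same per-pixel formula as transforms.Normalize."""
    return (pixel - MEAN) / STD

darkest, brightest = normalize(0.0), normalize(1.0)
print(f"black pixel -> {darkest:.4f}, white pixel -> {brightest:.4f}")
```

Inputs to the first ternary layer therefore span roughly -0.42 to 2.82 rather than 0 to 1.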

### 5.3 Visualize Sample Data

In [None]:
# Get a batch of training data
examples = iter(train_loader)
example_data, example_targets = next(examples)

# Plot samples
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(example_data[i].squeeze(), cmap='gray')
    ax.set_title(f'Label: {example_targets[i].item()}')
    ax.axis('off')
plt.suptitle('Sample MNIST Images', fontsize=16, y=1.02)
plt.tight_layout()
plt.show()

### 5.4 Training Loop

In [None]:
def train_epoch(model, loader, optimizer, criterion, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)
    
    return total_loss / len(loader), 100. * correct / total


def evaluate(model, loader, criterion, device):
    """Evaluate the model"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            
            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)
    
    return total_loss / len(loader), 100. * correct / total

### 5.5 Train the Model

In [None]:
from tqdm.notebook import tqdm

# Training configuration
epochs = 5
learning_rate = 0.01
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Training history
history = {
    'train_loss': [],
    'train_acc': [],
    'test_loss': [],
    'test_acc': []
}

# Training loop
print("Training Ternary Neural Network...\n")
for epoch in range(1, epochs + 1):
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['test_loss'].append(test_loss)
    history['test_acc'].append(test_acc)
    
    print(f"Epoch {epoch}/{epochs}:")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%")
    print()

## 6. Interactive Visualizations {#visualizations}

Let's visualize the training progress and model performance.

### 6.1 Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Loss curves
axes[0].plot(history['train_loss'], label='Train Loss', marker='o', linewidth=2)
axes[0].plot(history['test_loss'], label='Test Loss', marker='s', linewidth=2)
axes[0].set_title('Loss Curves', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy curves
axes[1].plot(history['train_acc'], label='Train Accuracy', marker='o', linewidth=2)
axes[1].plot(history['test_acc'], label='Test Accuracy', marker='s', linewidth=2)
axes[1].set_title('Accuracy Curves', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final Test Accuracy: {history['test_acc'][-1]:.2f}%")

### 6.2 Weight Distribution Analysis

In [None]:
# Get weights from first layer
weights_float = model.fc1.weight.detach().cpu()
weights_ternary = TernaryQuantize.apply(weights_float)

fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Float weights distribution
axes[0].hist(weights_float.flatten().numpy(), bins=50, alpha=0.7, 
            color='steelblue', edgecolor='black')
axes[0].set_title('Float Weights (Layer 1)', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Weight Value')
axes[0].set_ylabel('Frequency')
axes[0].grid(True, alpha=0.3)

# Ternary weights distribution
ternary_counts = [(weights_ternary == -1).sum().item(),
                  (weights_ternary == 0).sum().item(),
                  (weights_ternary == 1).sum().item()]
axes[1].bar([-1, 0, 1], ternary_counts, color=['red', 'gray', 'green'], 
           alpha=0.7, edgecolor='black')
axes[1].set_title('Ternary Weights (Layer 1)', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Weight Value')
axes[1].set_ylabel('Count')
axes[1].set_xticks([-1, 0, 1])
axes[1].grid(True, alpha=0.3)

# Sparsity visualization
sparsity_per_layer = []
layer_names = []
for name, module in model.named_modules():
    if isinstance(module, LinearTernary):
        w = TernaryQuantize.apply(module.weight.detach().cpu())
        sparsity = (w == 0).sum().item() / w.numel() * 100
        sparsity_per_layer.append(sparsity)
        layer_names.append(name)

axes[2].bar(range(len(sparsity_per_layer)), sparsity_per_layer, 
           color='coral', alpha=0.7, edgecolor='black')
axes[2].set_title('Sparsity by Layer', fontsize=12, fontweight='bold')
axes[2].set_xlabel('Layer')
axes[2].set_ylabel('Sparsity (%)')
axes[2].set_xticks(range(len(layer_names)))
axes[2].set_xticklabels(layer_names, rotation=45)
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nWeight Statistics (Layer 1):")
print(f"  -1 weights: {ternary_counts[0]} ({ternary_counts[0]/weights_ternary.numel()*100:.1f}%)")
print(f"   0 weights: {ternary_counts[1]} ({ternary_counts[1]/weights_ternary.numel()*100:.1f}%)")
print(f"  +1 weights: {ternary_counts[2]} ({ternary_counts[2]/weights_ternary.numel()*100:.1f}%)")

### 6.3 Predictions Visualization

In [None]:
# Get test batch
test_examples = iter(test_loader)
test_data, test_targets = next(test_examples)
test_data, test_targets = test_data.to(device), test_targets.to(device)

# Make predictions
model.eval()
with torch.no_grad():
    outputs = model(test_data)
    predictions = outputs.argmax(dim=1)

# Plot samples with predictions
fig, axes = plt.subplots(3, 8, figsize=(16, 6))
for i, ax in enumerate(axes.flat):
    if i < len(test_data):
        img = test_data[i].cpu().squeeze()
        true_label = test_targets[i].cpu().item()
        pred_label = predictions[i].cpu().item()
        
        ax.imshow(img, cmap='gray')
        color = 'green' if true_label == pred_label else 'red'
        ax.set_title(f'T:{true_label} P:{pred_label}', color=color, fontweight='bold')
        ax.axis('off')

plt.suptitle('Predictions (Green=Correct, Red=Incorrect)', 
            fontsize=16, fontweight='bold', y=1.0)
plt.tight_layout()
plt.show()

# Calculate accuracy
correct = (predictions == test_targets).sum().item()
total = len(test_targets)
print(f"Batch Accuracy: {100. * correct / total:.2f}%")

### 6.4 Model Size Comparison

In [None]:
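def calculate_model_size(model, quantized=False):
    """Calculate model size in bytes.

    With quantized=True, weight matrices are counted at 2 bits per value;
    biases stay at 32 bits because LinearTernary never quantizes them.
    """
    total_bits = 0
    
    for name, param in model.named_parameters():
        if quantized and name.endswith('weight'):
            # Ternary: 2 bits per weight
            total_bits += param.numel() * 2
        else:
            # Float32: 32 bits per value
            total_bits += param.numel() * 32
    
    return total_bits / 8  # Convert to bytes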

# Calculate sizes
float32_size = calculate_model_size(model, quantized=False) / 1024  # KB
ternary_size = calculate_model_size(model, quantized=True) / 1024   # KB
compression_ratio = float32_size / ternary_size

# Visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart
models = ['Float32', 'Ternary']
sizes = [float32_size, ternary_size]
colors = ['steelblue', 'green']

bars = ax1.bar(models, sizes, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
ax1.set_ylabel('Model Size (KB)', fontsize=12)
ax1.set_title('Model Size Comparison', fontsize=14, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Add value labels on bars
for bar, size in zip(bars, sizes):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
            f'{size:.1f} KB',
            ha='center', va='bottom', fontweight='bold')

# Compression visualization
ax2.text(0.5, 0.7, f'{compression_ratio:.1f}x', 
        ha='center', va='center', fontsize=72, fontweight='bold', color='green')
ax2.text(0.5, 0.3, 'Compression Ratio', 
        ha='center', va='center', fontsize=18, color='gray')
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.axis('off')

plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("MODEL SIZE COMPARISON")
print("="*60)
print(f"Float32 Model:  {float32_size:>10.2f} KB")
print(f"Ternary Model:  {ternary_size:>10.2f} KB")
print(f"Savings:        {float32_size - ternary_size:>10.2f} KB ({(1 - ternary_size/float32_size)*100:.1f}%)")
print(f"Compression:    {compression_ratio:>10.1f}x")
print("="*60)

## 7. Next Steps {#next-steps}

Congratulations! You've learned the basics of Triton DSL and ternary neural networks.

### Continue Learning:

1. **📘 [02_quantization_tutorial.ipynb](02_quantization_tutorial.ipynb)**: Deep dive into quantization techniques
   - Ternary, INT8, and mixed precision
   - QAT vs PTQ comparison
   - Advanced quantization strategies

2. **📘 [03_performance_analysis.ipynb](03_performance_analysis.ipynb)**: Performance optimization
   - Model profiling and benchmarking
   - Memory analysis
   - Speed optimization techniques

### Additional Resources:

- **Documentation**: [docs/QUICKSTART_PYTORCH_BACKEND.md](../../docs/QUICKSTART_PYTORCH_BACKEND.md)
- **Examples**: [examples/](../../examples/)
- **API Reference**: [docs/api/](../../docs/api/)

### Try This:

- Experiment with different network architectures
- Try different quantization thresholds
- Test on CIFAR-10 or other datasets
- Implement stochastic quantization
- Compare with binary neural networks
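
As a possible starting point for the stochastic quantization exercise, the sketch below rounds each weight to its sign with probability equal to its magnitude and to zero otherwise, so the quantized value matches the clipped float weight on average. This is one common scheme chosen for illustration, not necessarily what Triton implements, and `stochastic_ternarize` is a hypothetical name.

```python
import numpy as np

def stochastic_ternarize(w: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Round each weight to sign(w) with probability |w|, otherwise to 0."""
    w = np.clip(w, -1.0, 1.0)                   # probabilities must lie in [0, 1]
    keep = rng.random(w.shape) < np.abs(w)      # True with probability |w|
    return np.sign(w) * keep

rng = np.random.default_rng(0)
w = np.full(100_000, 0.3)                       # many copies of one weight value
q = stochastic_ternarize(w, rng)

print("values produced:", np.unique(q))
print(f"mean of quantized values: {q.mean():.3f} (float weight was 0.3)")
```

Unlike the fixed-threshold quantizer used in this notebook, small weights are not always zeroed, which trades determinism for an unbiased estimate.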

---

## Summary

In this notebook, you learned:

✅ What Triton DSL is and why ternary neural networks matter  
✅ How to implement ternary quantization with PyTorch  
✅ How to build and train a ternary neural network  
✅ How to visualize quantization and model performance  
✅ The memory savings from 2-bit ternary storage (up to 16x in theory)  

**Key Takeaway**: Ternary Neural Networks provide a practical trade-off between model size and accuracy, making them ideal for edge deployment and resource-constrained environments.

---

*Triton DSL - High-Performance Ternary Neural Networks*