In [16]:
# Classification dataset 1: MNIST
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np

# Load MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_set = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
test_set = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1000, shuffle=False)

# Define Neural Network
class PrunableNet(nn.Module):
    def __init__(self):
        super(PrunableNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

        # Initialize binary masks for pruning
        self.masks = {
            "fc1": torch.ones_like(self.fc1.weight),
            "fc2": torch.ones_like(self.fc2.weight),
            "fc3": torch.ones_like(self.fc3.weight),
        }

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))  # Apply weight masks during computation
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def apply_pruning(self):
        """Apply pruning by setting weights to zero based on masks."""
        self.fc1.weight.data *= self.masks["fc1"]
        self.fc2.weight.data *= self.masks["fc2"]
        self.fc3.weight.data *= self.masks["fc3"]

    def count_connections(self):
        """Count active connections based on binary masks."""
        active_connections = {
            "fc1": torch.sum(self.masks["fc1"]),
            "fc2": torch.sum(self.masks["fc2"]),
            "fc3": torch.sum(self.masks["fc3"]),
        }
        return active_connections

# Pruning Logic
class AutomataPruner:
    def __init__(self, weights, mask, threshold=0.01, memory_depth=5):
        self.weights = weights
        self.mask = mask
        self.threshold = threshold
        self.memory_depth = memory_depth
        self.states = np.zeros(weights.shape, dtype=int)

    def update(self):
        """Automaton updates the pruning decisions."""
        for index in np.ndindex(self.weights.shape):
            if abs(self.weights[index]) < self.threshold:
                self.states[index] += 1
                if self.states[index] >= self.memory_depth:
                    self.mask[index] = 0  # Prune
            else:
                self.states[index] = max(self.states[index] - 1, 0)  # Reward

# Model Summary
def model_summary(model, input_size):
    def register_hook(module):
        def hook(module, inputs, outputs):
            class_name = module.__class__.__name__
            module_idx = len(summary)
            m_key = f"{class_name}-{module_idx + 1}"
            summary[m_key] = {
                "input_shape": list(inputs[0].size()),
                "output_shape": list(outputs.size()),
                "num_params": sum(p.numel() for p in module.parameters() if p.requires_grad),
            }

        if not isinstance(module, nn.Sequential) and not isinstance(module, nn.ModuleList) and module != model:
            hooks.append(module.register_forward_hook(hook))

    summary = {}
    hooks = []
    model.apply(register_hook)
    with torch.no_grad():
        model(torch.zeros(1, *input_size))
    for h in hooks:
        h.remove()
    print("Model Summary")
    print("=" * 80)
    print(f"{'Layer':<30} {'Input Shape':<30} {'Output Shape':<30} {'Param #':<20}")
    print("=" * 80)
    total_params = 0
    for layer, info in summary.items():
        total_params += info["num_params"]
        print(f"{layer:<30} {str(info['input_shape']):<30} {str(info['output_shape']):<30} {info['num_params']:<20}")
    print("=" * 80)
    print(f"Total Parameters: {total_params}")
    return total_params

# Train and Prune
def train_and_prune(model, train_loader, criterion, optimizer, pruners=None, epochs=10):
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        if pruners:  # Update pruners if pruning is enabled
            with torch.no_grad():
                for pruner in pruners:
                    pruner.update()
            model.apply_pruning()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

    training_time = time.time() - start_time
    return training_time

# Evaluate Function
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset)
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Main Execution
if __name__ == "__main__":
    input_size = (1, 28, 28)

    # Unpruned Model
    model_without_pruning = PrunableNet()
    print("\nUnpruned Model Summary:")
    params_no_pruning = model_summary(model_without_pruning, input_size)

    optimizer = optim.SGD(model_without_pruning.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    print("\nTraining without pruning...")
    time_no_pruning = train_and_prune(model_without_pruning, train_loader, criterion, optimizer)
    acc_no_pruning = evaluate(model_without_pruning, test_loader)

    # Pruned Model
    pruned_model = PrunableNet()
    print("\nPruned Model Summary:")
    params_with_pruning = model_summary(pruned_model, input_size)

    optimizer = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9)
    pruners = [
        AutomataPruner(pruned_model.fc1.weight.data.numpy(), pruned_model.masks["fc1"]),
        AutomataPruner(pruned_model.fc2.weight.data.numpy(), pruned_model.masks["fc2"]),
        AutomataPruner(pruned_model.fc3.weight.data.numpy(), pruned_model.masks["fc3"]),
    ]
    print("\nTraining with pruning...")
    time_with_pruning = train_and_prune(pruned_model, train_loader, criterion, optimizer, pruners)
    acc_with_pruning = evaluate(pruned_model, test_loader)

    # Count active connections (weights)
    connections_no_pruning = sum(p.numel() for p in model_without_pruning.parameters() if p.requires_grad)
    active_connections_with_pruning = pruned_model.count_connections()

    print("\nComparison Report")
    print(f"Training Time (No Pruning): {time_no_pruning:.2f} seconds")
    print(f"Training Time (With Pruning): {time_with_pruning:.2f} seconds")
    print(f"Accuracy (No Pruning): {acc_no_pruning * 100:.2f}%")
    print(f"Accuracy (With Pruning): {acc_with_pruning * 100:.2f}%")
    print(f"Active Connections (No Pruning): {connections_no_pruning}")
    print(f"Active Connections (With Pruning):")
    print(f"  fc1: {active_connections_with_pruning['fc1']}")
    print(f"  fc2: {active_connections_with_pruning['fc2']}")
    print(f"  fc3: {active_connections_with_pruning['fc3']}")



Unpruned Model Summary:
Model Summary
Layer                          Input Shape                    Output Shape                   Param #             
Linear-1                       [1, 784]                       [1, 512]                       401920              
Linear-2                       [1, 512]                       [1, 256]                       131328              
Linear-3                       [1, 256]                       [1, 10]                        2570                
Total Parameters: 535818

Training without pruning...
Epoch 1/10, Loss: 0.16198521852493286
Epoch 2/10, Loss: 0.11243375390768051
Epoch 3/10, Loss: 0.1772448718547821
Epoch 4/10, Loss: 0.04708769917488098
Epoch 5/10, Loss: 0.047526564449071884
Epoch 6/10, Loss: 0.03601180389523506
Epoch 7/10, Loss: 0.0051420629024505615
Epoch 8/10, Loss: 0.14726479351520538
Epoch 9/10, Loss: 0.0082174614071846
Epoch 10/10, Loss: 0.08859815448522568
Accuracy: 97.81%

Pruned Model Summary:
Model Summary
Layer          

In [19]:
# Classification dataset 1: Breast Cancer
import time
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import numpy as np

# Load and prepare breast cancer dataset
def load_breast_cancer_data(batch_size=32):
    data = load_breast_cancer()
    X = data.data
    y = data.target

    # Standardize features
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Convert to PyTorch datasets
    class BreastCancerDataset(torch.utils.data.Dataset):
        def __init__(self, X, y):
            self.X = torch.FloatTensor(X)
            self.y = torch.LongTensor(y)

        def __len__(self):
            return len(self.X)

        def __getitem__(self, idx):
            return self.X[idx], self.y[idx]

    train_set = BreastCancerDataset(X_train, y_train)
    test_set = BreastCancerDataset(X_test, y_test)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader, X.shape[1]

# Neural Network for Breast Cancer classification
class PrunableNet(nn.Module):
    def __init__(self, input_size):
        super(PrunableNet, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 2)  # 2 classes: malignant or benign

        # Initialize binary masks for pruning
        self.masks = {
            "fc1": torch.ones_like(self.fc1.weight),
            "fc2": torch.ones_like(self.fc2.weight),
            "fc3": torch.ones_like(self.fc3.weight),
        }

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def apply_pruning(self):
        """Apply pruning by setting weights to zero based on masks."""
        self.fc1.weight.data *= self.masks["fc1"]
        self.fc2.weight.data *= self.masks["fc2"]
        self.fc3.weight.data *= self.masks["fc3"]

    def count_connections(self):
        """Count active connections based on binary masks."""
        active_connections = {
            "fc1": torch.sum(self.masks["fc1"]),
            "fc2": torch.sum(self.masks["fc2"]),
            "fc3": torch.sum(self.masks["fc3"]),
        }
        return active_connections

# Pruning Logic
class AutomataPruner:
    def __init__(self, weights, mask, threshold=0.01, memory_depth=5):
        self.weights = weights
        self.mask = mask
        self.threshold = threshold
        self.memory_depth = memory_depth
        self.states = np.zeros(weights.shape, dtype=int)

    def update(self):
        """Automaton updates the pruning decisions."""
        for index in np.ndindex(self.weights.shape):
            if abs(self.weights[index]) < self.threshold:
                self.states[index] += 1
                if self.states[index] >= self.memory_depth:
                    self.mask[index] = 0  # Prune
            else:
                self.states[index] = max(self.states[index] - 1, 0)  # Reward

# Train and Prune
def train_and_prune(model, train_loader, criterion, optimizer, pruners=None, epochs=10):
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        if pruners:  # Update pruners if pruning is enabled
            with torch.no_grad():
                for pruner in pruners:
                    pruner.update()
            model.apply_pruning()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")

    training_time = time.time() - start_time
    return training_time

# Evaluate Function
def evaluate(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = correct / total
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

if __name__ == "__main__":
    # Load breast cancer dataset
    train_loader, test_loader, input_size = load_breast_cancer_data()
    print(f"Input features: {input_size}")

    # Train model without pruning
    model_without_pruning = PrunableNet(input_size)
    optimizer = optim.Adam(model_without_pruning.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("\nTraining without pruning...")
    time_no_pruning = train_and_prune(model_without_pruning, train_loader, criterion, optimizer)
    acc_no_pruning = evaluate(model_without_pruning, test_loader)

    # Count connections in unpruned model
    connections_no_pruning = sum(p.numel() for p in model_without_pruning.parameters() if p.requires_grad)
    print(f"Active Connections (No Pruning): {connections_no_pruning}")

    # Train model with pruning
    model_with_pruning = PrunableNet(input_size)
    optimizer = optim.Adam(model_with_pruning.parameters(), lr=0.001)

    pruners = [
        AutomataPruner(model_with_pruning.fc1.weight.data.numpy(), model_with_pruning.masks["fc1"]),
        AutomataPruner(model_with_pruning.fc2.weight.data.numpy(), model_with_pruning.masks["fc2"]),
        AutomataPruner(model_with_pruning.fc3.weight.data.numpy(), model_with_pruning.masks["fc3"]),
    ]

    print("\nTraining with pruning...")
    time_with_pruning = train_and_prune(model_with_pruning, train_loader, criterion, optimizer, pruners)
    acc_with_pruning = evaluate(model_with_pruning, test_loader)

    # Count active connections in pruned model
    active_connections_with_pruning = model_with_pruning.count_connections()

    # Report results
    print("\nComparison Report")
    print(f"Training Time (No Pruning): {time_no_pruning:.2f} seconds")
    print(f"Training Time (With Pruning): {time_with_pruning:.2f} seconds")
    print(f"Accuracy (No Pruning): {acc_no_pruning * 100:.2f}%")
    print(f"Accuracy (With Pruning): {acc_with_pruning * 100:.2f}%")
    print(f"Active Connections (No Pruning): {connections_no_pruning}")
    print(f"Active Connections (With Pruning):")
    print(f"  fc1: {active_connections_with_pruning['fc1']}")
    print(f"  fc2: {active_connections_with_pruning['fc2']}")
    print(f"  fc3: {active_connections_with_pruning['fc3']}")


Input features: 30

Training without pruning...
Epoch 1/10, Average Loss: 0.6425
Epoch 2/10, Average Loss: 0.4882
Epoch 3/10, Average Loss: 0.2989
Epoch 4/10, Average Loss: 0.1628
Epoch 5/10, Average Loss: 0.1087
Epoch 6/10, Average Loss: 0.0869
Epoch 7/10, Average Loss: 0.0692
Epoch 8/10, Average Loss: 0.0599
Epoch 9/10, Average Loss: 0.0553
Epoch 10/10, Average Loss: 0.0525
Accuracy: 98.25%
Active Connections (No Pruning): 4130

Training with pruning...
Epoch 1/10, Average Loss: 0.6411
Epoch 2/10, Average Loss: 0.4521
Epoch 3/10, Average Loss: 0.2736
Epoch 4/10, Average Loss: 0.1577
Epoch 5/10, Average Loss: 0.1119
Epoch 6/10, Average Loss: 0.0910
Epoch 7/10, Average Loss: 0.1077
Epoch 8/10, Average Loss: 0.0646
Epoch 9/10, Average Loss: 0.0606
Epoch 10/10, Average Loss: 0.0551
Accuracy: 99.12%

Comparison Report
Training Time (No Pruning): 0.33 seconds
Training Time (With Pruning): 0.29 seconds
Accuracy (No Pruning): 98.25%
Accuracy (With Pruning): 99.12%
Active Connections (No Prun

In [18]:
# Regression Dataset: California Housing
import time
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np

# Load Dataset
data = fetch_california_housing()
X, y = data.data, data.target

# Preprocess the data
scaler = StandardScaler()
X = scaler.fit_transform(X)
y = np.expand_dims(y, axis=1)  # Make y a 2D array for PyTorch compatibility

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Convert to PyTorch tensors
X_train, y_train = torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.float32)
X_test, y_test = torch.tensor(X_test, dtype=torch.float32), torch.tensor(y_test, dtype=torch.float32)

train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_train, y_train), batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(X_test, y_test), batch_size=64, shuffle=False)

# Define Neural Network for Regression
class PrunableNet(nn.Module):
    def __init__(self):
        super(PrunableNet, self).__init__()
        self.fc1 = nn.Linear(X.shape[1], 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

        # Initialize binary masks for pruning
        self.masks = {
            "fc1": torch.ones_like(self.fc1.weight),
            "fc2": torch.ones_like(self.fc2.weight),
            "fc3": torch.ones_like(self.fc3.weight),
        }

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def apply_pruning(self):
        """Apply pruning by setting weights to zero based on masks."""
        self.fc1.weight.data *= self.masks["fc1"]
        self.fc2.weight.data *= self.masks["fc2"]
        self.fc3.weight.data *= self.masks["fc3"]

    def count_connections(self):
        """Count active connections based on binary masks."""
        active_connections = {
            "fc1": torch.sum(self.masks["fc1"]),
            "fc2": torch.sum(self.masks["fc2"]),
            "fc3": torch.sum(self.masks["fc3"]),
        }
        return active_connections

# Train and Prune
def train_and_prune(model, train_loader, criterion, optimizer, pruners=None, epochs=10):
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        if pruners:  # Update pruners if pruning is enabled
            with torch.no_grad():
                for pruner in pruners:
                    pruner.update()
            model.apply_pruning()

        print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss.item()}")

    training_time = time.time() - start_time
    return training_time

# Evaluate Function
def evaluate(model, test_loader, criterion):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()

    test_loss /= len(test_loader)
    print(f"Test Loss (MSE): {test_loss:.4f}")
    return test_loss

# Main Execution
if __name__ == "__main__":
    # Unpruned Model
    model_without_pruning = PrunableNet()
    optimizer = optim.SGD(model_without_pruning.parameters(), lr=0.01, momentum=0.9)
    criterion = nn.MSELoss()

    print("\nTraining without pruning...")
    time_no_pruning = train_and_prune(model_without_pruning, train_loader, criterion, optimizer)
    mse_no_pruning = evaluate(model_without_pruning, test_loader, criterion)

    # Count connections in unpruned model
    connections_no_pruning = sum(p.numel() for p in model_without_pruning.parameters() if p.requires_grad)
    print(f"Active Connections (No Pruning): {connections_no_pruning}")

    # Pruned Model
    pruned_model = PrunableNet()
    optimizer = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9)
    pruners = [
        AutomataPruner(pruned_model.fc1.weight.data.numpy(), pruned_model.masks["fc1"]),
        AutomataPruner(pruned_model.fc2.weight.data.numpy(), pruned_model.masks["fc2"]),
        AutomataPruner(pruned_model.fc3.weight.data.numpy(), pruned_model.masks["fc3"]),
    ]

    print("\nTraining with pruning...")
    time_with_pruning = train_and_prune(pruned_model, train_loader, criterion, optimizer, pruners)
    mse_with_pruning = evaluate(pruned_model, test_loader, criterion)

    # Count active connections in pruned model
    active_connections_with_pruning = pruned_model.count_connections()

    # Report results
    print("\nComparison Report")
    print(f"Training Time (No Pruning): {time_no_pruning:.2f} seconds")
    print(f"Training Time (With Pruning): {time_with_pruning:.2f} seconds")
    print(f"MSE (No Pruning): {mse_no_pruning:.4f}")
    print(f"MSE (With Pruning): {mse_with_pruning:.4f}")
    print(f"Active Connections (No Pruning): {connections_no_pruning}")
    print(f"Active Connections (With Pruning):")
    print(f"  fc1: {active_connections_with_pruning['fc1']}")
    print(f"  fc2: {active_connections_with_pruning['fc2']}")
    print(f"  fc3: {active_connections_with_pruning['fc3']}")



Training without pruning...
Epoch 1/10, Loss: 0.3446616232395172
Epoch 2/10, Loss: 0.45953094959259033
Epoch 3/10, Loss: 0.5902723670005798
Epoch 4/10, Loss: 0.4364682734012604
Epoch 5/10, Loss: 0.6320103406906128
Epoch 6/10, Loss: 0.3662148416042328
Epoch 7/10, Loss: 0.31736329197883606
Epoch 8/10, Loss: 0.3980819880962372
Epoch 9/10, Loss: 0.17703774571418762
Epoch 10/10, Loss: 0.3322415053844452
Test Loss (MSE): 0.3181
Active Connections (No Pruning): 9473

Training with pruning...
Epoch 1/10, Loss: 0.5458833575248718
Epoch 2/10, Loss: 0.2958527207374573
Epoch 3/10, Loss: 0.3071359395980835
Epoch 4/10, Loss: 0.27384474873542786
Epoch 5/10, Loss: 0.3515761196613312
Epoch 6/10, Loss: 0.41151392459869385
Epoch 7/10, Loss: 0.1601962447166443
Epoch 8/10, Loss: 0.12301412969827652
Epoch 9/10, Loss: 0.22080036997795105
Epoch 10/10, Loss: 0.2222491353750229
Test Loss (MSE): 0.3265

Comparison Report
Training Time (No Pruning): 2.72 seconds
Training Time (With Pruning): 5.22 seconds
MSE (No

In [None]:
!sudo apt update
!sudo apt install libcairo2-dev ffmpeg \
    texlive texlive-latex-extra texlive-fonts-extra \
    texlive-latex-recommended texlive-science \
    tipa libpango1.0-dev
!pip install manim
!pip install IPython==8.21.0

In [None]:
from manim import *

In [None]:
from manim import *
import numpy as np

class AutomataVisualization(Scene):
    def construct(self):
        # Colors and styling
        STATE_COLOR = "#2196F3"
        PRUNED_COLOR = "#FF5252"
        REWARD_COLOR = "#4CAF50"
        THRESHOLD_COLOR = "#FFA726"

        # Create title with reduced font size and adjusted position
        title = Text("Weight Pruning Automata", font_size=36)
        subtitle = Text("Deterministic Finite Automaton for Neural Network Pruning", font_size=20)
        subtitle.set_color(BLUE)
        title_group = VGroup(title, subtitle).arrange(DOWN, buff=0.2)
        title_group.to_edge(UP, buff=0.5)
        self.play(Write(title_group), run_time=1.5)

        # Create state diagram with adjusted positioning
        def create_state_diagram():
            states = VGroup()
            arrows = VGroup()
            labels = VGroup()
            x_positions = np.linspace(-5.5, 5.5, 6)

            for i, x in enumerate(x_positions):
                circle = Circle(radius=0.4, color=STATE_COLOR)
                label = Text("M" if i == len(x_positions)-1 else f"S{i}", font_size=20)
                state = VGroup(circle, label).move_to([x, 0.5, 0])
                states.add(state)

                if i == len(x_positions)-1:
                    circle.set_color(PRUNED_COLOR)
                    pruning_label = Text("(Pruning State)", font_size=14, color=PRUNED_COLOR)
                    pruning_label.next_to(state, UP, buff=0.1)
                    labels.add(pruning_label)

            for i in range(len(states)-1):
                forward = Arrow(
                    states[i].get_right(),
                    states[i+1].get_left(),
                    buff=0.1,
                    color=THRESHOLD_COLOR
                )
                backward = Arrow(
                    states[i+1].get_bottom(),
                    states[i].get_bottom(),
                    buff=0.1,
                    color=REWARD_COLOR,
                    path_arc=-2
                )
                arrows.add(forward, backward)

            return states, arrows, labels

        states, arrows, labels = create_state_diagram()

        # Show initial diagram
        self.play(Create(states), run_time=2)
        self.play(Create(arrows), run_time=2)
        self.play(Write(labels), run_time=1.5)

        # Add transition labels with reduced font size and adjusted position
        forward_label = Text("|w| < threshold", font_size=16, color=THRESHOLD_COLOR)
        forward_label.next_to(arrows[0], UP, buff=0.2)  # Increased buffer to move label up
        backward_label = Text("|w| ≥ threshold", font_size=16, color=REWARD_COLOR)
        backward_label.next_to(arrows[1], DOWN, buff=0.1)
        self.play(Write(forward_label), Write(backward_label), run_time=1.5)

        # Add weight value counter at bottom left with adjusted position
        weight_counter = Variable(0.5, Text("Current Weight |w|", font_size=16), var_type=DecimalNumber)
        weight_counter.to_corner(DL).shift(RIGHT * 0.5 + UP * 0.5)
        self.play(Write(weight_counter), run_time=1)

        # Add threshold line with adjusted position to be fully visible
        threshold_line = DashedLine(
            start=[-6, -1.5, 0],
            end=[5.5, -1.5, 0],  # Shortened the line to fit within the screen
            color=THRESHOLD_COLOR,
            dash_length=0.2
        )
        threshold_text = Text("T=0.5", font_size=16, color=THRESHOLD_COLOR)
        threshold_text.next_to(threshold_line, LEFT, buff=0.1)  # Moved text to the left of the line
        self.play(Create(threshold_line), Write(threshold_text), run_time=1.5)

        # Create state pointer
        pointer = Triangle(color=YELLOW).scale(0.15).rotate(PI/2)
        pointer.next_to(states[0], UP, buff=0.1)
        self.play(Create(pointer), run_time=1)

        # Add explanation box with reduced font size and adjusted position
        explanation = VGroup(
            Text("Automaton Rules:", color=WHITE, font_size=18),
            Text("1. If |w| < threshold: Move right", color=THRESHOLD_COLOR, font_size=14),
            Text("2. If |w| ≥ threshold: Move left", color=REWARD_COLOR, font_size=14),
            Text("3. At state M: Weight is pruned", color=PRUNED_COLOR, font_size=14)
        ).arrange(DOWN, aligned_edge=LEFT, buff=0.1)
        explanation.scale(0.9).to_corner(DR).shift(LEFT * 0.5 + UP * 0.5)
        self.play(Write(explanation), run_time=2)

        # Simulate automata transitions to S4
        weight_values = [0.8, 0.3, 0.2, 0.1, 0.05]
        current_state = 0

        for weight in weight_values:
            self.play(
                weight_counter.tracker.animate.set_value(weight),
                run_time=0.8
            )

            if weight < 0.5:
                if current_state < 4:
                    current_state += 1
                    self.play(
                        pointer.animate.next_to(states[current_state], UP, buff=0.1),
                        states[current_state][0].animate.set_color(STATE_COLOR),
                        run_time=0.8
                    )
            else:
                if current_state > 0:
                    self.play(
                        states[current_state][0].animate.set_color(STATE_COLOR),
                        run_time=0.8
                    )
                    current_state -= 1
                    self.play(
                        pointer.animate.next_to(states[current_state], UP, buff=0.1),
                        run_time=0.8
                    )

        # Special transition to state M
        self.wait(0.5)

        # Add transition message with adjusted position
        transition_message = Text("Weight remains below threshold...", color=YELLOW, font_size=18)
        transition_message.to_edge(UP, buff=1.5)
        self.play(Write(transition_message), run_time=1)
        self.wait(1)
        self.play(FadeOut(transition_message))

        # Move to state M
        self.play(
            pointer.animate.next_to(states[5], UP, buff=0.1),
            run_time=1.2
        )

        # Create pruning notification box with adjusted size and position
        notification_box = VGroup(
            Rectangle(width=3.5, height=1.5, fill_opacity=0.2, color=PRUNED_COLOR),
            Text("Memory Depth Reached", color=PRUNED_COLOR, font_size=20),
            Text("Weight is Pruned", color=PRUNED_COLOR, font_size=16)
        )
        notification_box[1:].arrange(DOWN, buff=0.1)
        notification_box[1:].move_to(notification_box[0])
        notification_box.to_edge(UP, buff=0.5)  # Moved to top of screen

        # Fade out title and subtitle, then show notification box
        self.play(
            FadeOut(title_group),
            FadeIn(notification_box),
            run_time=1.5
        )
        self.wait(1.5)
        self.play(FadeOut(notification_box))

        # Final wait
        self.wait(2)

if __name__ == "__main__":
    config.pixel_height = 1080
    config.pixel_width = 1920
    config.frame_rate = 60
    with tempconfig({"quality": "production_quality", "preview": True}):
        scene = AutomataVisualization()
        scene.render()
