<a href="https://colab.research.google.com/github/oluwafemidiakhoa/AIGenesis_Engine/blob/main/Untitled429.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install adaptive-sparse-training


In [None]:
!pip install -q adaptive-sparse-training gradio torchvision torch


In [None]:
import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from PIL import Image

from adaptive_sparse_training import ASTConfig, AdaptiveSparseTrainer


# 1. CIFAR-10 -> ImageFolder
def prepare_cifar10_imagefolder(root: str = "data/cifar10_imagefolder") -> str:
    """Export CIFAR-10 into an ImageFolder-style directory tree.

    The resulting layout is ``<root>/train/<class>/<idx>.png`` and
    ``<root>/val/<class>/<idx>.png`` (the CIFAR-10 test split is written as
    ``val``). A ``.complete`` marker file is written in ``<root>`` once both
    splits have been exported; later calls return immediately when the marker
    and both split directories are present. Without the marker the export is
    (re)run, so an interrupted export is never mistaken for a finished one.

    Args:
        root: Destination directory for the ImageFolder dataset.

    Returns:
        ``root`` as a string.
    """
    root_path = Path(root)
    train_root = root_path / "train"
    val_root = root_path / "val"
    done_marker = root_path / ".complete"

    if done_marker.is_file() and train_root.is_dir() and val_root.is_dir():
        print(f"[INFO] Using existing ImageFolder dataset at: {root_path}")
        return str(root_path)

    print(f"[INFO] Preparing CIFAR-10 at {root_path} ...")
    train_root.mkdir(parents=True, exist_ok=True)
    val_root.mkdir(parents=True, exist_ok=True)

    cifar_train = datasets.CIFAR10(root="data", train=True, download=True)
    cifar_test = datasets.CIFAR10(root="data", train=False, download=True)

    classes = cifar_train.classes
    print("[INFO] CIFAR-10 classes:", classes)

    to_pil = transforms.ToPILImage()

    def dump_split(split_set, split_root: Path, split_name: str):
        print(f"[INFO] Saving {split_name} images to {split_root} ...")
        # Create every class directory once instead of once per image.
        for class_name in classes:
            (split_root / class_name).mkdir(parents=True, exist_ok=True)
        for idx, (img, label) in enumerate(split_set):
            # With no transform configured the dataset already yields PIL
            # images, which ToPILImage rejects; only convert tensors/arrays.
            img_pil = img if isinstance(img, Image.Image) else to_pil(img)
            img_pil.save(split_root / classes[label] / f"{idx:06d}.png")
        print(f"[INFO] Finished saving {split_name}.")

    dump_split(cifar_train, train_root, "train")
    dump_split(cifar_test, val_root, "val")

    # Written last so that its presence implies both splits were fully saved.
    done_marker.touch()

    print(f"[INFO] Done. ImageFolder dataset at: {root_path}")
    return str(root_path)


# 2. Dataloaders
def build_dataloaders(data_root: str, batch_size: int = 128, num_workers: int = 2):
    """Build train/val DataLoaders over an ImageFolder dataset.

    Expects ``<data_root>/train`` and ``<data_root>/val`` to exist. Both
    splits are resized to 224x224 and converted to tensors; the training
    split additionally gets a random horizontal flip and is shuffled.

    Args:
        data_root: Directory containing the ``train`` and ``val`` folders.
        batch_size: Batch size used for both loaders.
        num_workers: Worker process count used for both loaders.

    Returns:
        ``(train_loader, val_loader, classes)`` where ``classes`` is the
        class-name list discovered in the training folder.

    Raises:
        FileNotFoundError: If either split directory is missing.
    """
    train_dir = os.path.join(data_root, "train")
    val_dir = os.path.join(data_root, "val")

    # Explicit raises instead of asserts: asserts vanish under ``python -O``.
    if not os.path.isdir(train_dir):
        raise FileNotFoundError(f"Train directory not found: {train_dir}")
    if not os.path.isdir(val_dir):
        raise FileNotFoundError(f"Val directory not found: {val_dir}")

    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    val_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    train_ds = datasets.ImageFolder(train_dir, transform=train_tf)
    val_ds = datasets.ImageFolder(val_dir, transform=val_tf)

    classes = train_ds.classes
    print(f"[INFO] Classes ({len(classes)}): {classes}")
    print(f"[INFO] Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

    # Settings shared by both loaders. Pinned host memory only helps when
    # batches are copied to a CUDA device, so it is skipped on CPU-only hosts.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, classes


# 3. Model
def build_model(num_classes: int, device: str = "cuda", pretrained: bool = True):
    """Create a ResNet18 whose final layer outputs ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        device: Device the model is moved to before being returned.
        pretrained: When True (default, the original behavior) start from
            torchvision's ImageNet weights. Pass False to skip the download,
            e.g. when a saved state dict is loaded right afterwards.

    Returns:
        The model, already moved to ``device``.
    """
    if not pretrained:
        model = models.resnet18()
    else:
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            model = models.resnet18(weights=weights_enum.DEFAULT)
        else:
            # torchvision releases without the weights enum do not accept the
            # ``weights=`` keyword either; they use the ``pretrained`` flag.
            model = models.resnet18(pretrained=True)

    # Swap the 1000-way ImageNet head for one sized to this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)
    return model


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: str = "cuda") -> float:
    """Return top-1 accuracy (as a percentage) of ``model`` over ``loader``.

    The model is switched to eval mode and left in that mode. Gradients are
    disabled for the whole call by the decorator.

    Args:
        model: Classifier producing one row of class logits per input.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``100 * correct / total``, or ``0.0`` when ``loader`` yields no
        samples (instead of dividing by zero).
    """
    model.eval()
    correct = 0
    total = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        predicted = model(inputs).argmax(dim=1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
    if total == 0:
        return 0.0
    return 100.0 * correct / total


# 4. Train with AST + save model
def main_train():
    """Run the end-to-end pipeline: data prep, AST training, eval, export.

    Writes the trained weights to ``ast_cifar10_resnet18.pth`` and the class
    names (one per line) to ``cifar10_classes.txt`` in the working directory.
    """
    # --- Data / loader settings ---
    DATA_ROOT = "data/cifar10_imagefolder"
    BATCH_SIZE = 128
    NUM_WORKERS = 2

    # --- Optimisation settings ---
    EPOCHS = 5
    WARMUP_EPOCHS = 1
    LEARNING_RATE = 3e-4
    TARGET_ACTIVATION_RATE = 0.6

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {DEVICE}")

    # Dataset on disk -> loaders -> model sized to the discovered classes.
    train_loader, val_loader, classes = build_dataloaders(
        prepare_cifar10_imagefolder(DATA_ROOT),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
    )
    model = build_model(len(classes), device=DEVICE)

    # reduction="none" keeps one loss value per sample.
    # NOTE(review): presumably the AST trainer needs per-sample losses for its
    # sample selection; confirm against the adaptive_sparse_training docs.
    per_sample_loss = nn.CrossEntropyLoss(reduction="none")
    adamw = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    trainer = AdaptiveSparseTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=ASTConfig(
            target_activation_rate=TARGET_ACTIVATION_RATE,
            adapt_kp=0.002,
            adapt_ki=5e-5,
            ema_alpha=0.2,
            use_amp=True,
            device=DEVICE,
        ),
        optimizer=adamw,
        criterion=per_sample_loss,
    )

    print("Starting training...")
    print(f"Starting training: {EPOCHS} epochs ({WARMUP_EPOCHS} warmup + {EPOCHS - WARMUP_EPOCHS} AST)")
    results = trainer.train(epochs=EPOCHS, warmup_epochs=WARMUP_EPOCHS)

    final_val_acc = evaluate(model, val_loader, device=DEVICE)
    # The trainer's return value is only consulted when it is a dict.
    energy_savings = results.get("energy_savings", None) if isinstance(results, dict) else None

    print("\n==================== SUMMARY ====================")
    print(f"Final Validation Accuracy: {final_val_acc:.2f}%")
    if energy_savings is None:
        print("Final Energy Savings: (not returned in results dict)")
    else:
        print(f"Final Energy Savings: {energy_savings:.1f}%")
    print("=================================================\n")

    # Persist the weights for the inference cells that follow.
    torch.save(model.state_dict(), "ast_cifar10_resnet18.pth")
    print("[INFO] Saved model to ast_cifar10_resnet18.pth")

    # Persist the label order so predictions can be mapped back to names.
    with open("cifar10_classes.txt", "w") as f:
        f.writelines(c + "\n" for c in classes)
    print("[INFO] Saved class names to cifar10_classes.txt")


# Run the full prepare / train / evaluate / save pipeline when this cell executes.
main_train()


In [None]:
import gradio as gr
from torchvision import transforms
from PIL import Image

# Inference setup for the Gradio demo. Reuses build_model from the training
# cell and the two artifacts written by main_train().
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device for Gradio:", DEVICE)

# Load class names, one per line. Blank lines are skipped so a stray empty
# line cannot inflate num_classes and break load_state_dict below.
with open("cifar10_classes.txt") as f:
    CLASSES = [name for name in (line.strip() for line in f) if name]
if not CLASSES:
    raise ValueError("cifar10_classes.txt contains no class names")

# Rebuild the architecture, then overwrite its weights with the trained ones.
num_classes = len(CLASSES)
model = build_model(num_classes, device=DEVICE)
state_dict = torch.load("ast_cifar10_resnet18.pth", map_location=DEVICE)
model.load_state_dict(state_dict)
model.eval()

# Same resize + tensor conversion as the validation transform used in training.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

def predict(image: Image.Image):
    """Classify one image and return a ``{class_name: probability}`` dict.

    Args:
        image: PIL image to classify, or None when nothing was uploaded.

    Returns:
        An empty dict for ``None`` input; otherwise the softmax probability
        of every entry in ``CLASSES``.
    """
    if image is None:
        return {}
    # Force three channels so grayscale or RGBA images do not break the
    # model's 3-channel first convolution.
    x = preprocess(image.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)[0]
    return dict(zip(CLASSES, probs.tolist()))

# Wire predict() into a minimal web UI: one image upload in, a label widget
# showing the three highest-scoring classes out.
demo_input = gr.Image(type="pil", label="Upload CIFAR-like image")
demo_output = gr.Label(num_top_classes=3, label="Top-3 Predictions")

demo = gr.Interface(
    fn=predict,
    inputs=demo_input,
    outputs=demo_output,
    title="AST CIFAR-10 Classifier",
    description="ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.",
)

# NOTE(review): debug=True presumably keeps the cell attached so server errors
# show up in the notebook output -- confirm against the Gradio launch() docs.
demo.launch(debug=True)


In [None]:
import textwrap
import os

# Output folder for the generated Hugging Face Space files; created if missing.
space_dir = "space_app"
os.makedirs(space_dir, exist_ok=True)

# Source of the standalone Space entry point (written to space_app/app.py).
# It must be valid Python on its own: it cannot rely on anything defined in
# this notebook, only on the two artifact files sitting next to it.
app_py = textwrap.dedent("""
    import torch
    import torch.nn as nn
    from torchvision import models, transforms
    from torch.utils.data import DataLoader
    from PIL import Image
    import gradio as gr

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Load class names (make sure this file is in the Space)
    with open("cifar10_classes.txt") as f:
        CLASSES = [line.strip() for line in f.readlines()]

    def build_model(num_classes: int, device: str = "cpu"):
        # No pretrained download here: every parameter and buffer is
        # overwritten by the fine-tuned state dict loaded below.
        model = models.resnet18()
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        model = model.to(device)
        return model

    num_classes = len(CLASSES)
    model = build_model(num_classes, device=DEVICE)

    state_dict = torch.load("ast_cifar10_resnet18.pth", map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    def predict(image: Image.Image):
        if image is None:
            return {}
        x = preprocess(image).unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            logits = model(x)
            probs = torch.softmax(logits, dim=1)[0]
        return {CLASSES[i]: float(probs[i]) for i in range(len(CLASSES))}

    demo = gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil", label="Upload CIFAR-like image"),
        outputs=gr.Label(num_top_classes=3, label="Top-3 Predictions"),
        title="AST CIFAR-10 Classifier",
        description="ResNet18 fine-tuned with Adaptive Sparse Training (AST) on CIFAR-10.",
    )

    if __name__ == "__main__":
        demo.launch()
""").strip() + "\n"

# Contents of the Space's requirements.txt: one package per line, with a
# trailing newline. Only what the generated app.py imports is listed.
requirements_txt = "\n".join(("gradio", "torch", "torchvision", "pillow")) + "\n"

import shutil

# Write the generated sources into the Space folder.
for file_name, content in (("app.py", app_py), ("requirements.txt", requirements_txt)):
    with open(os.path.join(space_dir, file_name), "w") as f:
        f.write(content)

print("[INFO] Created Hugging Face Space app files in 'space_app/'")

# The generated app.py opens these two files at startup, so ship them with
# the Space when they are available instead of leaving the bundle incomplete.
for artifact in ("ast_cifar10_resnet18.pth", "cifar10_classes.txt"):
    if os.path.isfile(artifact):
        shutil.copy2(artifact, space_dir)
        print(f"[INFO] Copied {artifact} into {space_dir}/")
    else:
        print(f"[WARN] {artifact} not found; add it to the Space manually before deploying")

# Optional: zip it so you can download easily. make_archive rewrites the
# archive from scratch (no stale entries) and does not need a `zip` binary
# or an IPython shell escape.
shutil.make_archive(space_dir, "zip", root_dir=".", base_dir=space_dir)
print(f"[INFO] Zipped to {space_dir}.zip")
