# ImageNet-1K AST Training with TensorFlow Datasets

**Developed by Oluwafemi Idiakhoa**

**Advantage**: No need to download 150GB up front. Data streams through TensorFlow Datasets (TFDS).

**Goal**: Validate Adaptive Sparse Training (AST) on the full ImageNet-1K training set (1.28M images)

**Expected Results**:
- Accuracy: 70-72%
- Energy Savings: 80%
- Training Time: ~5 hours on A100

---

## Step 1: Setup and Install Dependencies

In [None]:
# Check GPU
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install TensorFlow Datasets and other dependencies
!pip install -q tensorflow-datasets tensorflow torch torchvision tqdm matplotlib numpy

print("✅ Dependencies installed!")

## Step 2: Mount Google Drive for Checkpoints

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
os.makedirs("/content/drive/MyDrive/ast_imagenet1k_tfds_checkpoints", exist_ok=True)
print("✅ Google Drive mounted - checkpoints will be saved here")
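
A checkpoint directory only helps if a resumed Colab session can find the most recent file in it. The sketch below shows one way to do that using only the standard library. It assumes a hypothetical naming scheme, `checkpoint_epoch_<N>.pt`; the actual training script may name its files differently, so treat `checkpoint_path` and `latest_checkpoint` as illustrative helpers rather than part of the AST repository.

```python
import os
import re

# Hypothetical naming scheme (an assumption, not taken from the AST repo):
# checkpoint_epoch_<N>.pt
CKPT_PATTERN = re.compile(r"^checkpoint_epoch_(\d+)\.pt$")


def checkpoint_path(ckpt_dir, epoch):
    """Build the file path for the checkpoint written after `epoch`."""
    return os.path.join(ckpt_dir, f"checkpoint_epoch_{epoch}.pt")


def latest_checkpoint(ckpt_dir):
    """Return (epoch, path) for the newest checkpoint, or None if there is none."""
    best = None
    for name in os.listdir(ckpt_dir):
        match = CKPT_PATTERN.match(name)
        if match:
            epoch = int(match.group(1))
            # Compare epochs numerically so that 10 sorts after 2.
            if best is None or epoch > best[0]:
                best = (epoch, os.path.join(ckpt_dir, name))
    return best
```

On a first run `latest_checkpoint("/content/drive/MyDrive/ast_imagenet1k_tfds_checkpoints")` returns `None`; after a disconnect it returns the epoch to resume from and the file to load.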

## Step 3: Load ImageNet-1K via TensorFlow Datasets

This will stream the data without downloading 150GB!

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

print("🔄 Loading ImageNet-1K metadata...")
print("Note: First run will download ~6GB of metadata, then streams during training")
print()

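# Prepare the dataset. Note: the TFDS catalog lists `imagenet2012` as a manual-download
# dataset, so this call expects the ILSVRC2012 train/val archives in the TFDS manual_dir.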
builder = tfds.builder('imagenet2012')
builder.download_and_prepare()

info = builder.info
print("✅ ImageNet-1K ready!")
print(f"   Training samples: {info.splits['train'].num_examples:,}")
print(f"   Validation samples: {info.splits['validation'].num_examples:,}")
print(f"   Number of classes: {info.features['label'].num_classes}")
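
Before calling `download_and_prepare()`, it can save time to check whether TFDS will find the source archives it expects. The TFDS catalog entry for `imagenet2012` names two archives, `ILSVRC2012_img_train.tar` and `ILSVRC2012_img_val.tar`, to be placed in the manual directory. The sketch below is a pre-flight check under the assumption that the default manual directory (`~/tensorflow_datasets/downloads/manual`) is in use; `missing_archives` is a hypothetical helper, and the path should be adjusted if you pass a custom `data_dir`.

```python
import os

# Archive names that the TFDS catalog lists for `imagenet2012`.
REQUIRED_ARCHIVES = ("ILSVRC2012_img_train.tar", "ILSVRC2012_img_val.tar")

# Assumed default location; adjust if you use a custom data_dir or manual_dir.
DEFAULT_MANUAL_DIR = os.path.expanduser("~/tensorflow_datasets/downloads/manual")


def missing_archives(manual_dir=DEFAULT_MANUAL_DIR):
    """Return the required archive names that are not present in manual_dir."""
    return [
        name
        for name in REQUIRED_ARCHIVES
        if not os.path.isfile(os.path.join(manual_dir, name))
    ]


missing = missing_archives()
if missing:
    print(f"Missing from {DEFAULT_MANUAL_DIR}: {', '.join(missing)}")
else:
    print("Both ImageNet archives found; download_and_prepare() can proceed.")
```

If the check reports missing files, the builder step is likely to stop with an error asking for them, so it is cheaper to find out here than after the runtime has been set up.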

## Step 4: Create PyTorch-Compatible DataLoaders

In [None]:
import torch
from torch.utils.data import DataLoader, IterableDataset
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

class TFDSImageNetDataset(IterableDataset):
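    """Wrap a TFDS ImageNet split as a PyTorch IterableDataset."""
    
    def __init__(self, split, transform=None):
        self.ds = tfds.load('imagenet2012', split=split, shuffle_files=True)
        # Repeat only the training stream; validation must end after one pass,
        # otherwise an evaluation loop over the loader would never finish.
        if split.startswith('train'):
            self.ds = self.ds.repeat()
        self.transform = transform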
    
    def __iter__(self):
        for example in tfds.as_numpy(self.ds):
            image = example['image']
            label = example['label']
            
            # Convert to PIL Image
            image = Image.fromarray(image)
            
            if self.transform:
                image = self.transform(image)
            
            yield image, label

# Define transforms
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    normalize,
    transforms.RandomErasing(p=0.25),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

print("✅ Transforms configured")
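
To make the `normalize` step concrete, the sketch below reproduces its arithmetic for a single pixel in plain Python: scale the 8-bit value to [0, 1] (what `ToTensor` does), then apply `(x - mean) / std` per channel. `normalize_pixel` is a hypothetical helper for illustration only; the real transforms operate on whole tensors.

```python
# ImageNet channel statistics, the same values passed to transforms.Normalize.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then apply per-channel (x - mean) / std."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_uint8, MEAN, STD)
    )


print(normalize_pixel((0, 0, 0)))        # a black pixel maps to negative values
print(normalize_pixel((255, 255, 255)))  # a white pixel maps to positive values
```

Because the statistics differ per channel, a neutral gray pixel does not map to exactly zero in every channel. This is expected.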

## Step 5: Clone AST Repository and Load Configuration

In [None]:
# Clone repository
!git clone https://github.com/oluwafemidiakhoa/adaptive-sparse-training.git
%cd adaptive-sparse-training

from KAGGLE_IMAGENET1K_AST_CONFIGS import get_config

# Get Ultra configuration
config = get_config("ultra")

# Adjust for A100/V100
config.batch_size = 256  # Adjust based on your GPU
config.num_workers = 0   # TFDS handles threading internally

print("="*70)
print("ULTRA CONFIGURATION - ImageNet-1K via TensorFlow Datasets")
print("="*70)
print(f"Classes: {config.num_classes}")
print(f"Total Epochs: {config.num_epochs}")
print(f"Batch Size: {config.batch_size}")
print(f"Target Activation Rate: {config.target_activation_rate:.0%}")
print(f"Expected Energy Savings: {(1-config.target_activation_rate)*100:.0f}%")
print("="*70)
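
The "Expected Energy Savings" figure printed above is one minus the target activation rate. To compare that target with what a run actually does, one option is to count how many samples in each batch were selected for training. The sketch below assumes energy scales with the number of activated samples, which is a simplification. `ActivationTracker` is a hypothetical helper, not part of the AST repository.

```python
class ActivationTracker:
    """Running count of activated vs. seen samples (hypothetical helper)."""

    def __init__(self):
        self.seen = 0
        self.activated = 0

    def update(self, batch_size, num_activated):
        """Record one batch: how many samples it held and how many were trained on."""
        if not 0 <= num_activated <= batch_size:
            raise ValueError("num_activated must be between 0 and batch_size")
        self.seen += batch_size
        self.activated += num_activated

    @property
    def activation_rate(self):
        """Fraction of seen samples that were activated (0.0 before any update)."""
        return self.activated / self.seen if self.seen else 0.0

    @property
    def energy_savings(self):
        """Same formula as the config printout: 1 - activation rate."""
        return 1.0 - self.activation_rate if self.seen else 0.0


tracker = ActivationTracker()
tracker.update(batch_size=256, num_activated=64)
tracker.update(batch_size=256, num_activated=38)
print(f"Activation rate: {tracker.activation_rate:.1%}")
print(f"Estimated savings: {tracker.energy_savings:.1%}")
```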

## Step 6: Create DataLoaders

In [None]:
# Create datasets
train_dataset = TFDSImageNetDataset('train', transform=train_transform)
val_dataset = TFDSImageNetDataset('validation', transform=val_transform)

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    num_workers=0,  # TFDS handles this
)

val_loader = DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    num_workers=0,
)

print("✅ DataLoaders created")
print("📦 Streaming ImageNet-1K from TensorFlow Datasets")
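
A loader built on a repeating stream has no natural end, so the training loop has to decide how many batches make up an epoch. A minimal sketch of that bookkeeping follows, using only the standard library. `steps_per_epoch` and `one_epoch` are hypothetical helper names, and the training-set size is the standard ImageNet-1K count of 1,281,167 images.

```python
import itertools
import math

IMAGENET_TRAIN_SIZE = 1_281_167  # ImageNet-1K training images


def steps_per_epoch(num_examples, batch_size):
    """Number of batches needed to cover num_examples once, rounding up."""
    return math.ceil(num_examples / batch_size)


def one_epoch(batch_iterator, num_steps):
    """Yield at most num_steps batches from a possibly endless iterator."""
    return itertools.islice(batch_iterator, num_steps)


print(steps_per_epoch(IMAGENET_TRAIN_SIZE, 256))  # 5005
```

In the notebook this would be used by creating the iterator once, for example `train_iter = iter(train_loader)`, and then looping `for images, labels in one_epoch(train_iter, steps)` inside each epoch. Reusing one iterator lets each epoch continue where the previous one stopped instead of restarting the stream.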

## Step 7: Training Script

**Note**: This uses the same AST implementation that was used for ImageNet-100.

In [None]:
# The rest is identical to your ImageNet-1K training script
# Just copy the training loop from ImageNet1K_Ultra_Colab.ipynb cell 16

print("⚠️  Copy the training script from your ImageNet1K_Ultra_Colab.ipynb")
print("    (The one starting with: import torch.nn as nn...)")
print()
print("The only difference is we're using TFDS streaming instead of local files!")

## Advantages of This Approach

✅ **No 150GB download** - Streams data as needed

✅ **Works immediately** - No waiting for dataset access

✅ **Same training code** - Just different data loading

✅ **Legitimate source** - Official TensorFlow Datasets

⚠️ **Small downside**: First epoch may be slightly slower due to streaming