In [7]:
import os
import time  # For measuring epoch runtime
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from torchvision.models import ResNet18_Weights
from sklearn.metrics import precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from PIL import ImageFile, Image
from torch.optim.lr_scheduler import CosineAnnealingLR
# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
import torch
import jax

# Optionally, try importing PyCUDA if available
try:
    import pycuda.driver as cuda
except ImportError:
    cuda = None

# ---------------------------
# Check GPU with PyTorch
# ---------------------------
# `device` is "cuda" when PyTorch reports an available GPU and "cpu" otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"PyTorch sees device: {device}")

# ---------------------------
# Check GPU with JAX
# ---------------------------
# List every device JAX reports, collecting the ones whose platform is "gpu"
# while iterating so the summary below can tell whether a GPU is usable.
print("\nJAX devices:")
jax_devices = jax.devices()
gpu_jax_devices = []
for d in jax_devices:
    print(f" - {d} (platform: {d.platform})")
    if d.platform == "gpu":
        gpu_jax_devices.append(d)

if not gpu_jax_devices:
    print("\nNo GPU device available for JAX; it will use the CPU.")
    print("If you expect a GPU, ensure you have installed the GPU-enabled version of jaxlib.")
else:
    print("\nJAX GPU device(s) available:")
    for gpu in gpu_jax_devices:
        print(f" - {gpu}")

# ---------------------------
# Optionally, Check GPU with PyCUDA
# ---------------------------
# `cuda` is the pycuda.driver module when the optional import succeeded and
# None otherwise; the check is skipped entirely in the latter case.
if not cuda:
    print("\nPyCUDA is not installed. Skipping PyCUDA GPU check.")
else:
    try:
        # The driver must be initialised before devices can be queried.
        cuda.init()
        num_gpus = cuda.Device.count()
        print(f"\nPyCUDA: Found {num_gpus} GPU(s).")
        for i, dev in enumerate(map(cuda.Device, range(num_gpus))):
            print(f" - GPU {i}: {dev.name()}")
    except Exception as e:
        # Best-effort diagnostic: report the failure instead of aborting the cell.
        print(f"\nPyCUDA encountered an error: {e}")


PyTorch sees device: cuda

JAX devices:
 - TFRT_CPU_0 (platform: cpu)

No GPU device available for JAX; it will use the CPU.
If you expect a GPU, ensure you have installed the GPU-enabled version of jaxlib.

PyCUDA is not installed. Skipping PyCUDA GPU check.


In [None]:
import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp

# -----------------------------------------------------------------------------
# Check for GPU availability
# -----------------------------------------------------------------------------
# Keep only the JAX devices whose platform is "gpu"; an empty list means CPU-only.
gpu_devices = list(filter(lambda dev: dev.platform == "gpu", jax.devices()))
if not gpu_devices:
    print("No GPU device available; using CPU:", jax.devices())
else:
    print("GPU device(s) available:", gpu_devices)

# -----------------------------------------------------------------------------
# 1. Unpickle and load the CIFAR-10 dataset
# -----------------------------------------------------------------------------

def unpickle(file):
    """Load and return the object stored in the pickle file at `file`.

    The file is read in binary mode with ``encoding='bytes'``, which makes
    8-bit strings pickled by Python 2 come back as ``bytes`` objects.

    NOTE: unpickling can execute arbitrary code, so only call this on files
    obtained from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

data_dir = "cifar-10-batches-py"

# Read the five training batch files, gathering the raw pixel rows and labels.
train_data_list, train_labels_list = [], []
for i in range(1, 6):
    batch = unpickle(os.path.join(data_dir, f"data_batch_{i}"))
    train_data_list += [batch[b"data"]]
    train_labels_list += batch[b"labels"]

# Labels become a single numpy array; pixel rows are stacked along axis 0.
train_labels = np.array(train_labels_list)
train_data = np.concatenate(train_data_list, axis=0)

# Each row is laid out channel by channel (1024 red, 1024 green, 1024 blue values),
# so view it as (N, 3, 32, 32) and move the channel axis last to get NHWC.
train_data = np.transpose(train_data.reshape(-1, 3, 32, 32), (0, 2, 3, 1))
# Normalize pixel values to [0, 1].
train_data = train_data.astype(np.float32) / 255.0

# The test batch goes through exactly the same reshape / transpose / normalize steps.
test_batch = unpickle(os.path.join(data_dir, "test_batch"))
test_labels = np.array(test_batch[b"labels"])
test_data = test_batch[b"data"]
test_data = np.transpose(test_data.reshape(-1, 3, 32, 32), (0, 2, 3, 1))
test_data = test_data.astype(np.float32) / 255.0

# The metadata file carries the class names as bytes; decode them to str.
meta = unpickle(os.path.join(data_dir, "batches.meta"))
label_names = list(map(lambda raw_name: raw_name.decode('utf-8'), meta[b"label_names"]))

print("CIFAR-10 dataset loaded:")
print(" Training data shape:", train_data.shape)
print(" Test data shape:", test_data.shape)
print(" Label names:", label_names)

# -----------------------------------------------------------------------------
# 2. Define a simple CNN using JAX (from scratch)
# -----------------------------------------------------------------------------

# Helper functions for parameter initialization

def init_conv_params(key, filter_shape):
    """
    Create the kernel and bias of one convolutional layer.

    key: JAX PRNG key; it is split once and only the first half is consumed.
    filter_shape: (filter_height, filter_width, in_channels, out_channels)

    Returns (W, b): W has shape `filter_shape` and is drawn from a normal
    distribution scaled by sqrt(2 / fan_in) (He initialization for ReLU),
    where fan_in is the product of the first three entries of `filter_shape`;
    b is a zero vector with one entry per output channel.
    """
    weight_key = jax.random.split(key)[0]
    out_channels = filter_shape[-1]
    he_scale = jnp.sqrt(2.0 / np.prod(filter_shape[:3]))
    kernel = he_scale * jax.random.normal(weight_key, filter_shape)
    bias = jnp.zeros((out_channels,))
    return kernel, bias

def init_dense_params(key, in_dim, out_dim):
    """
    Create the weight matrix and bias of one dense (fully-connected) layer.

    key: JAX PRNG key; it is split once and only the first half is consumed.
    in_dim, out_dim: number of input and output features.

    Returns (W, b): W has shape (in_dim, out_dim) and is drawn from a normal
    distribution scaled by sqrt(2 / in_dim); b is a zero vector of length out_dim.
    """
    weight_key = jax.random.split(key)[0]
    he_scale = jnp.sqrt(2.0 / in_dim)
    weights = he_scale * jax.random.normal(weight_key, (in_dim, out_dim))
    return weights, jnp.zeros((out_dim,))

def init_cnn_params(key):
    """
    Build the parameter dictionary for the CNN.

    The key is split four ways, one sub-key per layer. The returned dict maps
    'conv1', 'conv2', 'dense1' and 'dense2' to a (W, b) pair each.

    Architecture:
      - Conv1: 3x3 kernel, 3 in-channels, 32 out-channels, stride 1, SAME padding.
      - MaxPool: 2x2 window, stride 2.
      - Conv2: 3x3 kernel, 32 in-channels, 64 out-channels, stride 1, SAME padding.
      - MaxPool: 2x2 window, stride 2.
      - Dense1: Fully connected layer from flattened features to 256 hidden units.
      - Dense2: Output layer (256 -> 10 classes)

    Note: With 'SAME' padding and two 2×2 poolings, 32×32 images become 8×8
    feature maps, hence the 8 * 8 * 64 input width of Dense1.
    """
    conv1_key, conv2_key, dense1_key, dense2_key = jax.random.split(key, 4)
    flattened_dim = 8 * 8 * 64
    return {
        'conv1': init_conv_params(conv1_key, (3, 3, 3, 32)),
        'conv2': init_conv_params(conv2_key, (3, 3, 32, 64)),
        'dense1': init_dense_params(dense1_key, flattened_dim, 256),
        'dense2': init_dense_params(dense2_key, 256, 10),
    }

def cnn_forward(params, x):
    """
    Run the CNN on a batch and return the class logits.

    params: dict with 'conv1', 'conv2', 'dense1', 'dense2' entries, each a (W, b) pair.
    x: input batch in NHWC layout, e.g. (batch, 32, 32, 3).
    Returns the logits (unnormalized log probabilities) for each class.
    """
    conv_layers = (params['conv1'], params['conv2'])
    dense1_W, dense1_b = params['dense1']
    dense2_W, dense2_b = params['dense2']

    def conv_relu_pool(inputs, kernel, bias):
        # Stride-1 SAME convolution; the bias broadcasts over batch, H and W.
        convolved = jax.lax.conv_general_dilated(
            inputs, kernel,
            window_strides=(1, 1),
            padding='SAME',
            dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
        )
        activated = jax.nn.relu(convolved + bias)
        # 2x2 max pooling with stride 2 over the two spatial axes only.
        return jax.lax.reduce_window(
            activated, -jnp.inf, jax.lax.max,
            window_dimensions=(1, 2, 2, 1),
            window_strides=(1, 2, 2, 1),
            padding='SAME',
        )

    # Two identical conv -> ReLU -> max-pool stages.
    features = x
    for kernel, bias in conv_layers:
        features = conv_relu_pool(features, kernel, bias)

    # Flatten each example, then hidden dense layer with ReLU and linear output layer.
    flat = features.reshape((features.shape[0], -1))
    hidden = jax.nn.relu(jnp.dot(flat, dense1_W) + dense1_b)
    return jnp.dot(hidden, dense2_W) + dense2_b

def cross_entropy_loss(logits, labels):
    """
    Compute the mean softmax cross-entropy between logits and integer labels.

    logits: array of shape (batch, num_classes) with unnormalized scores.
    labels: integer array of shape (batch,) holding the true class indices.
    Returns a scalar: the negative log-probability assigned to the true class,
    averaged over the batch.

    The number of classes is read from the last axis of `logits` instead of
    being fixed at 10, so the loss works for any number of classes while
    giving the same result as before for 10-class logits.
    """
    num_classes = jnp.shape(logits)[-1]
    # One-hot rows select the log-probability of the true class; per the
    # jax.nn.one_hot documentation, an out-of-range label encodes as all zeros
    # and therefore contributes zero loss rather than raising.
    one_hot = jax.nn.one_hot(labels, num_classes=num_classes)
    log_probs = jax.nn.log_softmax(logits)
    per_example_loss = -jnp.sum(one_hot * log_probs, axis=-1)
    return jnp.mean(per_example_loss)

def compute_loss(params, x, y):
    """Return the cross-entropy loss of the CNN on one batch.

    params: parameter dict passed to `cnn_forward`.
    x: batch of inputs; y: integer class labels for that batch.
    """
    return cross_entropy_loss(cnn_forward(params, x), y)

def compute_accuracy(params, x, y):
    """Return the fraction of examples in the batch whose arg-max class equals y.

    params: parameter dict passed to `cnn_forward`.
    x: batch of inputs; y: integer class labels for that batch.
    """
    predicted_classes = jnp.argmax(cnn_forward(params, x), axis=-1)
    hits = predicted_classes == y
    return jnp.mean(hits)

# -----------------------------------------------------------------------------
# 3. Training Loop
# -----------------------------------------------------------------------------

# Define one update (training) step using SGD.
# We use jax.grad to compute gradients and jax.tree_util.tree_map to update parameters.
@jax.jit
def update(params, x, y, learning_rate):
    """Apply one plain SGD step and return the updated parameters.

    The gradient of `compute_loss` is taken with respect to `params` (its
    first argument) on the batch (x, y); every parameter leaf is then moved
    against its gradient by `learning_rate`. The input `params` is not modified.
    """
    gradients = jax.grad(compute_loss)(params, x, y)

    def sgd_step(param, grad):
        return param - learning_rate * grad

    return jax.tree_util.tree_map(sgd_step, params, gradients)

# Training hyperparameters
num_epochs = 10
batch_size = 128
learning_rate = 0.001

# Seed both sources of randomness so re-running the cell reproduces the run:
# the JAX key drives parameter initialisation, the NumPy generator drives
# the per-epoch shuffling.
key = jax.random.PRNGKey(42)
params = init_cnn_params(key)
shuffle_rng = np.random.default_rng(42)

num_train = train_data.shape[0]
# Integer division drops the final partial batch, so every mini-batch passed
# to `update` has exactly `batch_size` examples.
num_batches = num_train // batch_size

# Fixed slice of the training set used for the per-epoch training metrics.
# Because the training arrays are no longer reshuffled in place, this is the
# same 1000 examples every epoch, which keeps the reported numbers comparable.
train_eval_data = train_data[:1000]
train_eval_labels = train_labels[:1000]

print("\nStarting training...")
for epoch in range(num_epochs):
    # Draw a new sample order each epoch. Batches are gathered through the
    # index permutation, so `train_data` / `train_labels` are never
    # overwritten (no full-dataset copy per epoch, and the cell can be re-run).
    permutation = shuffle_rng.permutation(num_train)

    # Process mini-batches
    for i in range(num_batches):
        start = i * batch_size
        end = start + batch_size
        batch_idx = permutation[start:end]
        x_batch = train_data[batch_idx]
        y_batch = train_labels[batch_idx]
        params = update(params, x_batch, y_batch, learning_rate)

    # Evaluate performance on (a subset of) training and test data
    train_loss = compute_loss(params, train_eval_data, train_eval_labels)
    train_acc  = compute_accuracy(params, train_eval_data, train_eval_labels)
    test_loss  = compute_loss(params, test_data, test_labels)
    test_acc   = compute_accuracy(params, test_data, test_labels)
    print(f"Epoch {epoch+1:02d} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}% | "
          f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc*100:.2f}%")


No GPU device available; using CPU: [CpuDevice(id=0)]
CIFAR-10 dataset loaded:
 Training data shape: (50000, 32, 32, 3)
 Test data shape: (10000, 32, 32, 3)
 Label names: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

Starting training...
