In [None]:
import tensorflow as tf
import numpy as np
import os
import gc
from tensorflow import keras

# Fetch MNIST, then shrink it to a tiny slice that is quick to train on.
print("Loading MNIST data...")
try:
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    # Keep only the first 1000 training and 100 test samples. Pixels are cast
    # to float32 and divided by 255.0, then each 28x28 image is flattened into
    # a single 784-element row for the dense model.
    x_train = (x_train[:1000].astype('float32') / 255.0).reshape(-1, 28 * 28)
    x_test = (x_test[:100].astype('float32') / 255.0).reshape(-1, 28 * 28)
    y_train, y_test = y_train[:1000], y_test[:100]

    print(f"Data loaded. x_train shape: {x_train.shape}")
except Exception as load_err:
    # Report the failure, then propagate it so the notebook cell stops here.
    print(f"Error loading data: {load_err}")
    raise

# Create a very simple dense classifier for the flattened 784-pixel inputs.
print("Creating model...")
try:
    model = keras.Sequential([
        keras.layers.Input(shape=(784,)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    # Use the legacy Adam implementation when this Keras version provides it.
    # Under Keras 3 (the default `tf.keras` from TF 2.16 on), the
    # `keras.optimizers.legacy` classes raise ImportError when instantiated,
    # and some builds may lack the module entirely (AttributeError). In either
    # case fall back to the standard Adam so the script still runs.
    try:
        optimizer = keras.optimizers.legacy.Adam(learning_rate=0.001)
    except (AttributeError, ImportError):
        optimizer = keras.optimizers.Adam(learning_rate=0.001)

    # 0.001 is Adam's default learning rate. sparse_categorical_crossentropy
    # expects integer class labels rather than one-hot vectors.
    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    print("Model created and compiled")
except Exception as e:
    print(f"Error creating model: {e}")
    raise

# Try to free up memory before training.
# NOTE(review): gc.collect() reclaims unreachable Python objects only; memory
# already held by TensorFlow's device allocator is presumably not returned by
# this call. Verify if GPU memory is the actual concern.
print("Cleaning up memory...")
gc.collect()

# Train for a single epoch with a small batch size. An out-of-memory error is
# reported rather than re-raised, so the script can finish and print its
# closing message. Any other error is reported and re-raised.
print("Starting training with a small batch size...")
history = None  # Remains None if training hits ResourceExhaustedError.
try:
    history = model.fit(
        x_train, y_train,
        batch_size=32,
        epochs=1,
        verbose=1,
        validation_data=(x_test, y_test),
    )
except tf.errors.ResourceExhaustedError as e:
    print(f"Memory error during training: {e}")
    print("Try reducing batch size, model size, or use CPU only.")
except Exception as e:
    print(f"Error during training: {e}")
    print(f"Error type: {type(e)}")
    raise
else:
    # Success path lives in `else` so the try body guards only the fit call.
    print("Training completed successfully!")

print("Test completed without kernel crash!")