In [1]:
# Import libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Set random seed
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")



TensorFlow version: 2.20.0
Keras version: 3.12.0


In [2]:
# Load dataset
data = np.load('dataset_dev_3000.npz')
X = data['X']
y = data['y']

print(f"Input X shape: {X.shape}")
print(f"Targets y shape: {y.shape}")
print(f"Target A (10-class): range [{y[:, 0].min():.0f}, {y[:, 0].max():.0f}]")
print(f"Target B (32-class): {len(np.unique(y[:, 1]))} classes")
print(f"Target C (Regression): range [{y[:, 2].min():.4f}, {y[:, 2].max():.4f}]")

Input X shape: (3000, 32, 32)
Targets y shape: (3000, 3)
Target A (10-class): range [0, 9]
Target B (32-class): 32 classes
Target C (Regression): range [0.0003, 0.9996]


In [None]:
def apply_2d_fft(images):
    """Apply simple 2D FFT to batch of images - returns magnitude and phase"""
    # Input shape: (batch, 32, 32, 1) or (batch, 32, 32)
    # Ensure we have the right shape
    if images.ndim == 4:
        images_squeezed = images[:, :, :, 0]  # (batch, 32, 32)
    else:
        images_squeezed = images
    
    # Apply 2D FFT using numpy
    fft_result = np.fft.fft2(images_squeezed)
    
    # Shift zero frequency to center
    fft_shifted = np.fft.fftshift(fft_result)
    
    # Get magnitude
    magnitude = np.abs(fft_shifted)
    
    # Get phase
    phase = np.angle(fft_shifted)
    
    # Stack magnitude and phase as 2 channels: (batch, 32, 32, 2)
    output = np.stack([magnitude, phase], axis=-1)
    
    return output


class FourierTransformLayer(layers.Layer):
    """Custom layer to apply 2D Fourier Transform"""

    def call(self, inputs):
        # Convert to numpy, apply FFT, convert back to tensor
        result = tf.numpy_function(
            lambda x: apply_2d_fft(x.numpy() if hasattr(x, 'numpy') else x),
            [inputs],
            tf.float32
        )
        result.set_shape([inputs.shape[0], inputs.shape[1], inputs.shape[2], 2])
        return result

    def compute_output_shape(self, input_shape):
        # Input: (batch, 32, 32, 1) -> Output: (batch, 32, 32, 2)
        return (input_shape[0], input_shape[1], input_shape[2], 2)

print("Fourier Transform layer defined")

Fourier Transform layer defined


## Compile Model

In [4]:
def build_simple_cnn_head_a():
    """Ultra simple CNN - proven architecture"""
    model = models.Sequential(
        [
            # Input
            layers.Input(shape=(32, 32, 1)),
            # Conv block 1
            layers.Conv2D(32, 3, activation="relu", padding="same"),
            layers.MaxPooling2D(2),
            # Conv block 2
            layers.Conv2D(64, 3, activation="relu", padding="same"),
            layers.MaxPooling2D(2),
            # Conv block 3
            layers.Conv2D(64, 3, activation="relu", padding="same"),
            layers.MaxPooling2D(2),
            # Dense
            layers.Flatten(),
            layers.Dense(64, activation="relu"),
            layers.Dropout(0.5),
            # was 10
            layers.Dense(10, activation="softmax"),
        ],
        name="simple_cnn_head_a",
    )

    return model


model = build_simple_cnn_head_a()
model.summary()

## Fourier Transformation Layer

In [5]:
# Train/validation split stratified on Target A (10 classes)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=SEED,
    stratify=y[:, 0]  # Stratify on Target A
)

# Extract only Target A
y_train_a = y_train[:, 0]
y_val_a = y_val[:, 0]

print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")
print(f"Target A classes: {len(np.unique(y_train_a))}")

Training samples: 2400
Validation samples: 600
Target A classes: 10


In [6]:
# Compile model for single task
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

print("Model compiled successfully")

Model compiled successfully


In [7]:
# Train model (Head A only)
history = model.fit(
    X_train, y_train_a,
    validation_data=(X_val, y_val_a),
    epochs=50,
    batch_size=32,
    verbose=1
)

print("\nTraining completed!")

Epoch 1/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 12ms/step - accuracy: 0.1058 - loss: 2.3021 - val_accuracy: 0.1300 - val_loss: 2.2823
Epoch 2/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 11ms/step - accuracy: 0.1392 - loss: 2.2762 - val_accuracy: 0.1217 - val_loss: 2.2568
Epoch 3/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.1492 - loss: 2.2552 - val_accuracy: 0.1733 - val_loss: 2.2155
Epoch 4/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.1838 - loss: 2.2142 - val_accuracy: 0.2283 - val_loss: 2.1416
Epoch 5/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.1825 - loss: 2.1693 - val_accuracy: 0.2367 - val_loss: 2.0835
Epoch 6/50
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - accuracy: 0.1954 - loss: 2.1286 - val_accuracy: 0.2517 - val_loss: 2.0159
Epoch 7/50
[1m75/75[0m [32m━━━━

In [8]:
# Evaluate on validation set
val_loss, val_accuracy = model.evaluate(X_val, y_val_a, verbose=0)

print("\n" + "="*60)
print("FINAL VALIDATION RESULTS - HEAD A (10-class)")
print("="*60)

# Get predictions
predictions = model.predict(X_val, verbose=0)
pred_classes = np.argmax(predictions, axis=1)

# Calculate accuracy
accuracy = np.mean(pred_classes == y_val_a)

print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Random baseline: {1/10:.4f} ({100/10:.2f}%)")
print("="*60)


FINAL VALIDATION RESULTS - HEAD A (10-class)
Validation Loss: 2.7766
Validation Accuracy: 0.3183 (31.83%)
Random baseline: 0.1000 (10.00%)


In [9]:
# Visualize individual samples with Fourier Transform as heatmaps
import matplotlib.pyplot as plt

# Configure how many samples to display (from 0 to n)
n = 10  # <-- Change this value to display more or fewer samples

# Display n samples from validation set
for i in range(n):
    sample_image = X_val[i].squeeze()  # Shape: (32, 32)
    true_class = y_val_a[i]
    
    # Predict class
    pred_probs = model.predict(X_val[i:i+1], verbose=0)
    pred_class = np.argmax(pred_probs)
    confidence = pred_probs[0][pred_class]
    
    # Apply FFT to get magnitude and phase
    fft_output = apply_2d_fft(X_val[i:i+1])  # Shape: (1, 32, 32, 2)
    fft_magnitude = fft_output[0, :, :, 0]  # Magnitude
    fft_phase = fft_output[0, :, :, 1]  # Phase
    
    # Create 3-panel visualization
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Original image
    im0 = axes[0].imshow(sample_image, cmap='gray', interpolation='nearest')
    axes[0].set_title('Original Image', fontsize=11, fontweight='bold')
    axes[0].set_xlabel('X coordinate')
    axes[0].set_ylabel('Y coordinate')
    plt.colorbar(im0, ax=axes[0], label='Intensity')
    
    # FFT Magnitude
    im1 = axes[1].imshow(fft_magnitude, cmap='hot', interpolation='nearest')
    axes[1].set_title('FFT Magnitude (Log Scale)', fontsize=11, fontweight='bold')
    axes[1].set_xlabel('Frequency X')
    axes[1].set_ylabel('Frequency Y')
    plt.colorbar(im1, ax=axes[1], label='Magnitude')
    
    # FFT Phase
    im2 = axes[2].imshow(fft_phase, cmap='twilight', interpolation='nearest')
    axes[2].set_title('FFT Phase', fontsize=11, fontweight='bold')
    axes[2].set_xlabel('Frequency X')
    axes[2].set_ylabel('Frequency Y')
    plt.colorbar(im2, ax=axes[2], label='Phase')
    
    # Overall title
    fig.suptitle(f'Sample {i} | True Class: {true_class} | Predicted: {pred_class} (confidence: {confidence:.2%})', 
                 fontsize=13, fontweight='bold', y=1.02)
    
    plt.tight_layout()
    plt.show()

# Check class distribution
print("\nClass Distribution in Validation Set:")
unique, counts = np.unique(y_val_a, return_counts=True)
for cls, cnt in zip(unique, counts):
    print(f"  Class {cls}: {cnt} samples ({cnt/len(y_val_a)*100:.1f}%)")

ValueError: cannot select an axis to squeeze out which has size not equal to one

In [None]:
# Analyze confusion matrix to see which classes are confused
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Get confusion matrix
cm = confusion_matrix(y_val_a, pred_classes)

# Plot
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)
plt.title('Confusion Matrix - Target A (10 classes)', fontsize=14, fontweight='bold')
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.tight_layout()
plt.show()

# Classification report
print("\nPer-Class Performance:")
print(classification_report(y_val_a, pred_classes, digits=3))

# Find most confused pairs
print("\nMost Confused Class Pairs:")
for i in range(10):
    for j in range(i+1, 10):
        if cm[i, j] + cm[j, i] > 10:  # Significant confusion
            print(f"  Classes {i} ↔ {j}: {cm[i,j]} + {cm[j,i]} = {cm[i,j] + cm[j,i]} confusions")

In [None]:
# Plot training and validation metrics
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Accuracy
axes[0].plot(history.history['accuracy'], label='Train', linewidth=2)
axes[0].plot(history.history['val_accuracy'], label='Validation', linewidth=2)
axes[0].axhline(y=0.1, color='r', linestyle='--', label='Random (10%)', alpha=0.5)
axes[0].set_title('Head A: 10-Class Classification Accuracy', fontsize=12, fontweight='bold')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Accuracy')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Loss
axes[1].plot(history.history['loss'], label='Train', linewidth=2)
axes[1].plot(history.history['val_loss'], label='Validation', linewidth=2)
axes[1].set_title('Loss', fontsize=12, fontweight='bold')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Training and validation plots displayed above")