In [8]:
# Import libraries
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings('ignore')

# Set random seed
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

TensorFlow version: 2.20.0
Keras version: 3.12.0


In [9]:
# Load dataset
data = np.load('dataset_dev_3000.npz')
X = data['X']
y = data['y']

print(f"Input X shape: {X.shape}")
print(f"Targets y shape: {y.shape}")
print(f"Target A (10-class): range [{y[:, 0].min():.0f}, {y[:, 0].max():.0f}]")
print(f"Target B (32-class): {len(np.unique(y[:, 1]))} classes")
print(f"Target C (Regression): range [{y[:, 2].min():.4f}, {y[:, 2].max():.4f}]")

Input X shape: (3000, 32, 32)
Targets y shape: (3000, 3)
Target A (10-class): range [0, 9]
Target B (32-class): 32 classes
Target C (Regression): range [0.0003, 0.9996]


In [10]:
@tf.function
def apply_2d_fft(images):
    """Apply 2D FFT to batch of images"""
    # Cast to complex for FFT
    images_complex = tf.cast(images, tf.complex64)
    
    # Apply 2D FFT
    fft_result = tf.signal.fft2d(images_complex)
    
    # Get magnitude (amplitude spectrum)
    magnitude = tf.abs(fft_result)
    
    # Shift zero frequency to center
    magnitude = tf.signal.fftshift(magnitude, axes=[1, 2])
    
    # Log scale for better visualization
    magnitude = tf.math.log(magnitude + 1.0)
    
    # Normalize
    magnitude = (magnitude - tf.reduce_min(magnitude)) / (tf.reduce_max(magnitude) - tf.reduce_min(magnitude) + 1e-8)
    
    return magnitude

class FourierTransformLayer(layers.Layer):
    """Custom layer to apply 2D Fourier Transform"""
    def call(self, inputs):
        return apply_2d_fft(inputs)
    
    def compute_output_shape(self, input_shape):
        return input_shape

print("Fourier Transform layer defined")

Fourier Transform layer defined


## Compile Model

In [11]:
def build_simple_mtl_model():
    # Input
    input_layer = layers.Input(shape=(32, 32, 1), name='input')
    
    # Apply Fourier Transform
    fft_features = FourierTransformLayer(name='fourier_transform')(input_layer)
    
    # Shared backbone - simple CNN
    x = layers.Conv2D(32, 3, activation='relu', padding='same')(fft_features)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(64, 3, activation='relu', padding='same')(x)
    x = layers.MaxPooling2D(2)(x)
    x = layers.Conv2D(128, 3, activation='relu', padding='same')(x)
    x = layers.GlobalAveragePooling2D()(x)
    
    # Shared dense layers
    shared = layers.Dense(256, activation='relu')(x)
    # shared = layers.Dropout(0.3)(shared)
    shared = layers.Dense(128, activation='relu')(shared)
    # shared = layers.Dropout(0.3)(shared)
    
    # Head A: 10-class classification
    head_a = layers.Dense(64, activation='relu', name='head_a_dense')(shared)
    head_a_out = layers.Dense(10, activation='softmax', name='head_a')(head_a)
    
    # Head B: 32-class classification
    head_b = layers.Dense(64, activation='relu', name='head_b_dense')(shared)
    head_b_out = layers.Dense(32, activation='softmax', name='head_b')(head_b)
    
    # Head C: Regression
    head_c = layers.Dense(64, activation='relu', name='head_c_dense')(shared)
    head_c_out = layers.Dense(1, activation='linear', name='head_c')(head_c)
    
    # Build model
    model = models.Model(
        inputs=input_layer,
        outputs=[head_a_out, head_b_out, head_c_out],
        name='simple_mtl_fourier'
    )
    
    return model

# Build model
model = build_simple_mtl_model()
model.summary()

## Fourier Transformation Layer

In [12]:
# Train/validation split stratified on Target B (32 classes)
X_train, X_val, y_train, y_val = train_test_split(
    X, y, 
    test_size=0.2, 
    random_state=SEED,
    stratify=y[:, 1]
)

print(f"Training samples: {len(X_train)}")
print(f"Validation samples: {len(X_val)}")

Training samples: 2400
Validation samples: 600


In [13]:
# Compile model
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss={
        'head_a': 'sparse_categorical_crossentropy',
        'head_b': 'sparse_categorical_crossentropy',
        'head_c': 'mse'
    },
    metrics={
        'head_a': ['accuracy'],
        'head_b': ['accuracy'],
        'head_c': ['mae']
    }
)

print("Model compiled successfully")

Model compiled successfully


In [14]:
# Prepare targets
y_train_dict = {
    'head_a': y_train[:, 0],
    'head_b': y_train[:, 1],
    'head_c': y_train[:, 2]
}

y_val_dict = {
    'head_a': y_val[:, 0],
    'head_b': y_val[:, 1],
    'head_c': y_val[:, 2]
}

# Train model
history = model.fit(
    X_train, y_train_dict,
    validation_data=(X_val, y_val_dict),
    epochs=30,
    batch_size=32,
    verbose=1
)

print("\nTraining completed!")

Epoch 1/30


[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 15ms/step - head_a_accuracy: 0.0875 - head_a_loss: 2.3062 - head_b_accuracy: 0.0321 - head_b_loss: 3.4687 - head_c_loss: 0.1006 - head_c_mae: 0.2660 - loss: 5.8754 - val_head_a_accuracy: 0.0867 - val_head_a_loss: 2.3037 - val_head_b_accuracy: 0.0383 - val_head_b_loss: 3.4636 - val_head_c_loss: 0.0879 - val_head_c_mae: 0.2583 - val_loss: 5.8551
Epoch 2/30
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 14ms/step - head_a_accuracy: 0.1004 - head_a_loss: 2.3035 - head_b_accuracy: 0.0375 - head_b_loss: 3.4654 - head_c_loss: 0.0757 - head_c_mae: 0.2339 - loss: 5.8446 - val_head_a_accuracy: 0.1000 - val_head_a_loss: 2.3032 - val_head_b_accuracy: 0.0383 - val_head_b_loss: 3.4632 - val_head_c_loss: 0.0569 - val_head_c_mae: 0.2018 - val_loss: 5.8232
Epoch 3/30
[1m75/75[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 15ms/step - head_a_accuracy: 0.1096 - head_a_loss: 2.2985 - head_b_accuracy: 0.0371 - head_b_l

In [15]:
# Evaluate on validation set
results = model.evaluate(X_val, y_val_dict, verbose=0)

print("\n" + "="*60)
print("FINAL VALIDATION RESULTS")
print("="*60)

# Get predictions
predictions = model.predict(X_val, verbose=0)
pred_a, pred_b, pred_c = predictions

# Calculate accuracies
acc_a = np.mean(np.argmax(pred_a, axis=1) == y_val[:, 0])
acc_b = np.mean(np.argmax(pred_b, axis=1) == y_val[:, 1])
mae_c = np.mean(np.abs(pred_c.flatten() - y_val[:, 2]))

print(f"Head A (10-class) Accuracy: {acc_a:.4f} ({acc_a*100:.2f}%)")
print(f"Head B (32-class) Accuracy: {acc_b:.4f} ({acc_b*100:.2f}%)")
print(f"Head C (Regression) MAE: {mae_c:.4f}")
print("="*60)


FINAL VALIDATION RESULTS
Head A (10-class) Accuracy: 0.1933 (19.33%)
Head B (32-class) Accuracy: 0.0600 (6.00%)
Head C (Regression) MAE: 0.1493


In [None]:
# Visualize all Conv2D layer activations
images_per_row = 8

for layer_name, layer_activation in zip(layer_names[1:], activations[1:]):  # Skip Fourier layer
    if 'conv' not in layer_name:
        continue
        
    n_features = layer_activation.shape[-1]
    size = layer_activation.shape[1]
    
    # Calculate grid dimensions
    n_cols = n_features // images_per_row
    if n_features % images_per_row != 0:
        n_cols += 1
    
    display_grid = np.zeros((size * n_cols, images_per_row * size))
    
    # Tile each filter into the grid
    for col in range(n_cols):
        for row in range(images_per_row):
            channel_idx = col * images_per_row + row
            if channel_idx >= n_features:
                break
                
            channel_image = layer_activation[0, :, :, channel_idx]
            
            # Post-process for visualization
            channel_image -= channel_image.mean()
            channel_image /= (channel_image.std() + 1e-5)
            channel_image *= 64
            channel_image += 128
            channel_image = np.clip(channel_image, 0, 255).astype('uint8')
            
            display_grid[col * size : (col + 1) * size,
                        row * size : (row + 1) * size] = channel_image
    
    # Display the grid
    scale = 1.5 / size
    plt.figure(figsize=(scale * display_grid.shape[1],
                       scale * display_grid.shape[0]))
    plt.title(f'{layer_name} - {n_features} filters', fontsize=14)
    plt.grid(False)
    plt.imshow(display_grid, aspect='auto', cmap='viridis')
    plt.axis('off')
    
plt.show()

In [None]:
# Visualize the original image and Fourier transform
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Original image
axes[0].imshow(sample_img[0, :, :, 0], cmap='gray')
axes[0].set_title('Original Image')
axes[0].axis('off')

# Fourier magnitude (channel 0)
if activations[0].shape[-1] >= 2:
    axes[1].imshow(activations[0][0, :, :, 0], cmap='viridis')
    axes[1].set_title('Fourier Magnitude')
    axes[1].axis('off')
    
    # Fourier phase (channel 1)
    axes[2].imshow(activations[0][0, :, :, 1], cmap='twilight')
    axes[2].set_title('Fourier Phase')
    axes[2].axis('off')
else:
    axes[1].imshow(activations[0][0, :, :, 0], cmap='viridis')
    axes[1].set_title('Fourier Transform')
    axes[1].axis('off')
    axes[2].axis('off')

plt.tight_layout()
plt.show()

print(f"Ground truth labels: A={y_val[0, 0]:.0f}, B={y_val[0, 1]:.0f}, C={y_val[0, 2]:.4f}")

In [None]:
# Get activations for a sample validation image
sample_img = X_val[0:1]  # Take first validation sample
activations = activation_model.predict(sample_img, verbose=0)

print(f"\nNumber of activation maps: {len(activations)}")
for i, (name, activation) in enumerate(zip(layer_names, activations)):
    print(f"{i+1}. {name}: {activation.shape}")

In [None]:
# Create a model that outputs intermediate layer activations
import matplotlib.pyplot as plt

# Get the names and outputs of all layers we want to visualize
layer_names = []
layer_outputs = []

for layer in model.layers:
    if 'conv' in layer.name or 'fourier' in layer.name or 'pooling' in layer.name:
        layer_names.append(layer.name)
        layer_outputs.append(layer.output)

print(f"Visualizing {len(layer_names)} layers:")
for name in layer_names:
    print(f"  - {name}")

# Create activation model
activation_model = models.Model(inputs=model.input, outputs=layer_outputs)

## Visualize Intermediate Activations (What ConvNet Learns)