In [1]:
# Lightweight SNR–SER comparison (load UNet, train CAE/DnCNN)
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Model, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.optimizers.schedules import ExponentialDecay

# Report the runtime so results can be tied to a TF build and accelerator.
tf_version = tf.__version__
gpu_devices = tf.config.list_physical_devices('GPU')
print("TensorFlow:", tf_version)
print("GPU:", gpu_devices)


TensorFlow: 2.15.0
GPU: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:
# Data: CIFAR-10 and z-score helpers
# CIFAR-10 pixels scaled to float32 in [0, 1]; label arrays flattened to 1-D.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
y_train, y_test = y_train.flatten(), y_test.flatten()

# Per-channel statistics from the TRAINING split only (reduced over axes 0, 1, 2);
# the small epsilon keeps the z-score division finite.
_PIXEL_AXES = (0, 1, 2)
MEAN = tf.constant(x_train.mean(axis=_PIXEL_AXES), dtype=tf.float32)
STD = tf.constant(x_train.std(axis=_PIXEL_AXES) + 1e-6, dtype=tf.float32)

def to_zscore(x):
    """Map [0, 1] images to per-channel z-scores using the train-set MEAN/STD."""
    centered = x - MEAN
    return centered / STD

def from_zscore(z):
    """Inverse of `to_zscore`: map per-channel z-scores back to pixel space."""
    scaled = z * STD
    return scaled + MEAN

# Human-readable names indexed by CIFAR-10 label id.
cifar10_class_names = 'airplane automobile bird cat deer dog frog horse ship truck'.split()
print('Data ready:', x_train.shape, x_test.shape)


2025-10-13 15:30:34.285574: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M4
2025-10-13 15:30:34.285594: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2025-10-13 15:30:34.285600: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2025-10-13 15:30:34.285628: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2025-10-13 15:30:34.285644: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Data ready: (50000, 32, 32, 3) (10000, 32, 32, 3)


In [3]:
# Noise: fixed SNR generators (Gaussian/S&P/Burst)

def gaussian_snr_to_cond_vector(snr_db) -> tf.Tensor:
    """Build the 4-element conditioning vector for Gaussian noise at `snr_db`.

    Layout is [1, 0, 0, level]: the first three slots mark the noise type
    (the S&P and burst generators set slots 1 and 2 instead), and
    level = clip(log10(sigma) - 0.5, 0, 1) with log10(sigma) = -snr_db / 20.

    Accepts a Python number or a tensor; uses only TF ops so it can run
    inside a tf.data map function.
    """
    level = tf.clip_by_value(-tf.cast(snr_db, tf.float32) / 20.0 - 0.5, 0.0, 1.0)
    type_slots = [tf.constant(1.0), tf.constant(0.0), tf.constant(0.0)]
    return tf.stack(type_slots + [level])


def add_gaussian_noise_fixed_snr(clean_img_01: tf.Tensor, snr_db):
    """Add white Gaussian noise in z-score space at a fixed SNR.

    The noise std is sigma = 10**(-snr_db / 20). That defines SNR relative to
    a unit signal power, not the image's measured power (the S&P and burst
    generators instead rescale via `snr_scale_noise` against the measured power).

    Returns:
        (noisy_z, cond): the noisy z-scored image and the conditioning vector
        from `gaussian_snr_to_cond_vector`.
    """
    snr_db = tf.cast(snr_db, tf.float32)
    clean_z = to_zscore(clean_img_01)
    noise_std = tf.pow(10.0, -snr_db / 20.0)
    noise = tf.random.normal(tf.shape(clean_z), stddev=noise_std, dtype=tf.float32)
    return clean_z + noise, gaussian_snr_to_cond_vector(snr_db)


def snr_scale_noise(clean_z: tf.Tensor, noisy_z: tf.Tensor, target_snr_db: tf.Tensor):
    """Rescale the residual (noisy_z - clean_z) to hit a target SNR.

    After scaling, mean(clean_z**2) / mean(residual**2) equals
    10**(target_snr_db / 10). Powers are averaged over every element of the
    inputs. The 1e-12 terms keep the division and the sqrt finite when the
    residual is all zeros.

    Returns:
        clean_z + gain * residual.
    """
    residual = noisy_z - clean_z
    signal_power = tf.reduce_mean(tf.square(clean_z))
    noise_power = tf.reduce_mean(tf.square(residual)) + 1e-12
    wanted_noise_power = signal_power / tf.pow(10.0, target_snr_db / 10.0)
    gain = tf.sqrt(tf.maximum(wanted_noise_power / noise_power, 1e-12))
    return clean_z + gain * residual


def add_sp_noise_fixed_snr(clean_img_01: tf.Tensor, snr_db: tf.Tensor, amount: float = 0.15):
    """Salt-and-pepper corruption in z-score space, rescaled to `snr_db`.

    Each element is handled independently, per pixel and per channel, because
    the uniform draw has the image's full shape. With probability amount/2 it
    becomes "salt" (pixel value 1.0); with probability amount/2 it becomes
    "pepper" (pixel value 0.0). Both values are expressed as z-scores so they
    live in the same space as the rest of the image. The resulting residual is
    then rescaled with `snr_scale_noise` so the per-image SNR equals `snr_db`.

    Returns:
        (noisy_z, cond) with cond = [0, 1, 0, amount].
    """
    img_z = to_zscore(clean_img_01)
    u = tf.random.uniform(tf.shape(img_z))
    salt = tf.cast(u < amount * 0.5, tf.float32)
    pepper = tf.cast(u > 1.0 - amount * 0.5, tf.float32)
    # Impulse values must be z-scores too: writing raw 1.0 / 0.0 into z-space
    # corresponds to pixel values MEAN+STD / MEAN, not white / black.
    z_white = to_zscore(tf.ones_like(img_z))
    z_black = to_zscore(tf.zeros_like(img_z))
    noisy_z = img_z * (1.0 - salt - pepper) + salt * z_white + pepper * z_black
    noisy_z = snr_scale_noise(img_z, noisy_z, snr_db)
    return noisy_z, tf.convert_to_tensor([0.0, 1.0, 0.0, amount], dtype=tf.float32)


def add_burst_noise_fixed_snr(clean_img_01: tf.Tensor, snr_db: tf.Tensor, size_factor: float = 0.3, intensity: float = 0.85):
    """Additive Gaussian noise confined to one random rectangle, rescaled to `snr_db`.

    The patch is about (size_factor*H) x (size_factor*W). It is clamped to at
    least 1 px and at most the image size. It is placed uniformly at any
    offset that keeps it fully inside the image. The indexing below assumes
    a single rank-3 (H, W, C) image; TODO confirm for other callers.

    Returns:
        (noisy_z, cond) with cond = [0, 0, 1, clip(size_factor*intensity, 0, 1)].
    """
    img_z = to_zscore(clean_img_01)
    shape = tf.shape(img_z)
    h, w, cch = shape[0], shape[1], shape[2]
    # Clamp so the padding amounts below can never go negative (size_factor > 1).
    bh = tf.clip_by_value(tf.cast(tf.cast(h, tf.float32) * size_factor, tf.int32), 1, h)
    bw = tf.clip_by_value(tf.cast(tf.cast(w, tf.float32) * size_factor, tf.int32), 1, w)
    # maxval is exclusive for integer draws; valid offsets are 0..h-bh inclusive,
    # so +1 lets the patch also touch the bottom/right edge.
    sy = tf.random.uniform([], maxval=h - bh + 1, dtype=tf.int32)
    sx = tf.random.uniform([], maxval=w - bw + 1, dtype=tf.int32)
    patch = tf.random.normal([bh, bw, cch], stddev=intensity)
    # Zero outside the rectangle, so adding it perturbs only the burst region.
    noise = tf.pad(patch, [[sy, h - sy - bh], [sx, w - sx - bw], [0, 0]])
    noisy_z = snr_scale_noise(img_z, img_z + noise, snr_db)
    c = tf.clip_by_value(size_factor * intensity, 0.0, 1.0)
    return noisy_z, tf.convert_to_tensor([0.0, 0.0, 1.0, c], dtype=tf.float32)


def make_fixed_snr_dataset_noise(x, y, snr_db: float, noise_type: str = 'gaussian', batch_size: int = 128):
    """Build a batched tf.data pipeline that corrupts every image at one fixed SNR.

    Elements are ((noisy_z, cond), (clean_z, label)), with images in z-score
    space. noise_type is one of 'gaussian', 'sp' / 's&p', or 'burst'. Any
    other value raises ValueError(noise_type) from inside the map function.
    """
    generators = {
        'gaussian': add_gaussian_noise_fixed_snr,
        'sp': add_sp_noise_fixed_snr,
        's&p': add_sp_noise_fixed_snr,
        'burst': add_burst_noise_fixed_snr,
    }

    def _corrupt(clean_img, label):
        clean_img = tf.cast(clean_img, tf.float32)
        fixed_snr = tf.cast(snr_db, tf.float32)
        if noise_type not in generators:
            raise ValueError(noise_type)
        noisy_z, cond = generators[noise_type](clean_img, fixed_snr)
        return (noisy_z, cond), (to_zscore(clean_img), label)

    pipeline = tf.data.Dataset.from_tensor_slices((x, y))
    pipeline = pipeline.map(_corrupt, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.batch(batch_size).prefetch(tf.data.AUTOTUNE)


In [4]:
# Load pretrained UNet (conditional multitask UNet)
unet_path = '/Users/ihaegwon/Lab/best_cifar10_conditional_model.keras'
unet_model = tf.keras.models.load_model(unet_path)
# The SER evaluators pass [noisy_z, cond] positionally, so the printed input
# order should be (image, noise map).
unet_input_names = [tensor.name for tensor in unet_model.inputs]
print('Loaded UNet from:', unet_path)
print('UNet inputs:', unet_input_names)




Loaded UNet from: /Users/ihaegwon/Lab/best_cifar10_conditional_model.keras
UNet inputs: ['image_input', 'noise_map_input']


In [5]:
# CAE/DnCNN multitask models

def build_cae_multitask(input_shape_img=(32,32,3), input_shape_map=(4,), num_classes=10):
    """Two-output convolutional autoencoder with skip connections.

    Outputs, in order:
      - 'restoration_output': linear 3-channel image, same spatial size as the input;
      - 'classification_output': softmax over `num_classes`, computed from the
        pooled bottleneck concatenated with the conditioning vector.
    The conditioning input feeds only the classifier head. Two 2x pool/upsample
    stages mean H and W should be divisible by 4 for the skip concatenations
    to line up.
    """
    def conv_pair(tensor, n_filters):
        # Two 3x3 'same' ReLU convolutions.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, 3, padding='same', activation='relu')(tensor)
        return tensor

    image = layers.Input(shape=input_shape_img, name='image_input')
    noise_map = layers.Input(shape=input_shape_map, name='noise_map_input')

    # Encoder: full -> 1/2 -> 1/4 resolution.
    skip_full = conv_pair(image, 32)
    skip_half = conv_pair(layers.MaxPooling2D(2)(skip_full), 64)
    bottleneck = conv_pair(layers.MaxPooling2D(2)(skip_half), 128)

    # Classification head.
    head = layers.GlobalAveragePooling2D()(bottleneck)
    head = layers.Concatenate()([head, noise_map])
    head = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(1e-4))(head)
    head = layers.Dropout(0.5)(head)
    class_probs = layers.Dense(num_classes, activation='softmax', name='classification_output')(head)

    # Decoder: upsample and fuse the matching encoder skip.
    decoded = bottleneck
    for n_filters, skip in ((64, skip_half), (32, skip_full)):
        decoded = layers.Conv2DTranspose(n_filters, 2, strides=2, padding='same')(decoded)
        decoded = layers.Concatenate()([decoded, skip])
        decoded = conv_pair(decoded, n_filters)
    restored = layers.Conv2D(3, 1, activation='linear', name='restoration_output')(decoded)

    return Model(inputs=[image, noise_map], outputs=[restored, class_probs], name='CAE_multitask')


def build_dncnn_multitask(input_shape_img=(32,32,3), input_shape_map=(4,), num_classes=10, depth=17, filters=64):
    """DnCNN residual denoiser plus a classification head on its last feature map.

    Trunk: conv+ReLU, then (depth - 2) x [conv (no bias) -> BatchNorm -> ReLU].
    A linear conv ('residual_pred') predicts a 3-channel residual. The
    restoration output is image - residual, which only type-checks for
    3-channel inputs. The classifier pools the trunk features and concatenates
    the conditioning vector.
    Outputs, in order: 'restoration_output', 'classification_output'.
    """
    image = layers.Input(shape=input_shape_img, name='image_input')
    noise_map = layers.Input(shape=input_shape_map, name='noise_map_input')

    features = layers.Conv2D(filters, 3, padding='same', activation='relu')(image)
    for _ in range(depth - 2):
        for block_layer in (layers.Conv2D(filters, 3, padding='same', use_bias=False),
                            layers.BatchNormalization(),
                            layers.Activation('relu')):
            features = block_layer(features)

    noise_estimate = layers.Conv2D(3, 3, padding='same', activation='linear', name='residual_pred')(features)
    restored = layers.Subtract(name='restoration_output')([image, noise_estimate])

    head = layers.GlobalAveragePooling2D()(features)
    head = layers.Concatenate()([head, noise_map])
    head = layers.Dense(128, activation='relu', kernel_regularizer=regularizers.l2(1e-4))(head)
    head = layers.Dropout(0.5)(head)
    class_probs = layers.Dense(num_classes, activation='softmax', name='classification_output')(head)

    return Model(inputs=[image, noise_map], outputs=[restored, class_probs], name='DnCNN_multitask')

# Learning-rate schedule matching the UNet recipe. ExponentialDecay is evaluated
# per optimizer step; with staircase=True the LR is multiplied by 0.96 once every
# STEPS_PER_EPOCH steps, i.e. once per epoch.
# NOTE: the 128 below must stay equal to BATCH_SIZE (set in the dataset cell)
# for "once per epoch" to hold.
STEPS_PER_EPOCH = int(np.ceil(len(x_train) / 128))
initial_learning_rate = 1e-4
lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps=STEPS_PER_EPOCH,
                               decay_rate=0.96, staircase=True)

cae_model = build_cae_multitask(num_classes=10)
dncnn_model = build_dncnn_multitask(num_classes=10)

# Joint objective: restoration MAE weighted 0.8 and classification CE weighted 0.2.
# Accuracy is tracked on the classification head only. Each model gets its own
# Adam instance.
for mtl_model in (cae_model, dncnn_model):
    mtl_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss={'restoration_output': 'mae', 'classification_output': 'sparse_categorical_crossentropy'},
        loss_weights={'restoration_output': 0.8, 'classification_output': 0.2},
        metrics={'classification_output': 'accuracy'},
    )

print('CAE/DnCNN ready with LR schedule')




CAE/DnCNN ready with LR schedule


In [6]:
# Mixed-SNR training dataset (Gaussian only by default)
BATCH_SIZE = 128

def gen_mixed_gaussian_sample(clean_img, label):
    """tf.data map fn: add Gaussian noise at an SNR drawn uniformly from [-30, -10) dB.

    Returns ((noisy_z, cond), (clean_z, label)), the same element structure as
    `make_fixed_snr_dataset_noise`, so the MTL models consume both alike.
    """
    clean_img = tf.cast(clean_img, tf.float32)
    sampled_snr_db = tf.random.uniform([], minval=-30.0, maxval=-10.0)
    noisy_z, cond = add_gaussian_noise_fixed_snr(clean_img, sampled_snr_db)
    return (noisy_z, cond), (to_zscore(clean_img), label)

# Training noise (SNR and noise realisation) is re-drawn every epoch, which acts
# as data augmentation.
train_ds_mixed = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                  .shuffle(len(x_train))
                  .map(gen_mixed_gaussian_sample, num_parallel_calls=tf.data.AUTOTUNE)
                  .batch(BATCH_SIZE)
                  .prefetch(tf.data.AUTOTUNE))

# Validation noise is drawn once and cached, so every epoch is scored on the
# same noisy images. Without .cache() the val metrics that drive
# EarlyStopping/ModelCheckpoint would also reflect fresh noise draws, not just
# model changes.
# Memory: roughly 2 x 10k x 32x32x3 float32 (noisy + clean), about 250 MB.
val_ds_mixed = (tf.data.Dataset.from_tensor_slices((x_test, y_test))
                .map(gen_mixed_gaussian_sample, num_parallel_calls=tf.data.AUTOTUNE)
                .cache()
                .batch(BATCH_SIZE)
                .prefetch(tf.data.AUTOTUNE))

print('Datasets ready')


Datasets ready


In [None]:
# SER evaluators and plotting

def eval_model_ser_over_snrs(model, x, y, snr_list_db, noise_type='gaussian', batch_size=512):
    """Symbol (classification) error rate of a two-output multitask model at each SNR.

    `model` is called on [noisy_z, cond] and must return
    (restoration, class_probs) in that order. Only the second output is used:
    argmax vs the true label. Noise is random, so repeated calls give slightly
    different numbers.

    Returns:
        dict mapping float(snr) -> fraction of misclassified samples.
    """
    results = {}
    for snr in snr_list_db:
        snr = float(snr)
        ds = make_fixed_snr_dataset_noise(x, y, snr_db=snr, noise_type=noise_type, batch_size=batch_size)
        total = 0
        errors = 0
        for (noisy_z_b, cond_b), (_clean_z_b, label_b) in ds:
            # predict_on_batch runs the compiled predict step on this one batch.
            # Keras advises against calling predict() inside a loop over batches.
            _, probs_b = model.predict_on_batch([noisy_z_b, cond_b])
            pred = np.argmax(probs_b, axis=-1)
            labels = label_b.numpy()
            total += labels.shape[0]
            errors += int(np.sum(pred != labels))
        results[snr] = errors / max(1, total)
    return results


def plot_snr_ser(models_ser_dict, title='SNR vs SER', threshold=0.10):
    """Plot SER vs SNR for several models and mark where each crosses `threshold`.

    Args:
        models_ser_dict: {curve label: {snr_db: ser}}; points are plotted in SNR order.
        title: figure title.
        threshold: SER level drawn as a dashed line. For each curve, the first
            grid interval where SER moves across it is linearly interpolated
            and annotated in dB.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    for name, ser_map in models_ser_dict.items():
        snrs = np.array(sorted(ser_map.keys()))
        sers = np.array([ser_map[s] for s in snrs])
        (line,) = ax.plot(snrs, sers, marker='o', label=name)
        # Reuse the line's colour: scatter() draws from a separate colour
        # cycle, so without this the marker would not match its curve.
        color = line.get_color()
        crossings = np.flatnonzero(np.diff((sers <= threshold).astype(int)))
        if crossings.size:
            i = crossings[0]
            x0, x1 = snrs[i], snrs[i + 1]
            y0, y1 = sers[i], sers[i + 1]
            if y1 != y0:
                x_cross = x0 + (threshold - y0) * (x1 - x0) / (y1 - y0)
                ax.scatter([x_cross], [threshold], marker='x', s=80, color=color)
                ax.text(x_cross, threshold + 0.02, f"{name}: {x_cross:.1f} dB",
                        ha='center', fontsize=9, color=color)
    # Label reflects the actual threshold (previously hard-coded to 0.10).
    ax.axhline(threshold, color='gray', ls='--', lw=1, label=f'SER={threshold:.2f}')
    ax.set(ylim=(0, 1), xlabel='SNR (dB)', ylabel='SER', title=title)
    ax.grid(ls=':')
    ax.legend()
    plt.show()


: 

In [None]:
# Evaluate UNet SER first (Gaussian)
# SNR grid: -30, -28, ..., -10 dB (range stop is exclusive).
snr_grid = [*range(-30, -9, 2)]
unet_ser = eval_model_ser_over_snrs(unet_model, x_test, y_test, snr_grid,
                                    noise_type='gaussian', batch_size=512)
print('UNet SER computed for', len(snr_grid), 'SNR points')


In [None]:
# Train/Load CAE MTL only, then evaluate SER (Gaussian)
EPOCHS = 200
ckpt_dir = '/Users/ihaegwon/Lab'
cae_ckpt = os.path.join(ckpt_dir, 'best_cae_multitask.keras')

# Model selection tracks validation accuracy of the classification head.
# EarlyStopping restores the best weights in memory; ModelCheckpoint persists them.
cae_monitor = 'val_classification_output_accuracy'
callbacks_cae = [
    EarlyStopping(monitor=cae_monitor, patience=20, restore_best_weights=True),
    ModelCheckpoint(filepath=cae_ckpt, save_weights_only=False, monitor=cae_monitor,
                    mode='max', save_best_only=True),
]

if not os.path.exists(cae_ckpt):
    print('\nTraining CAE (up to 200 epochs, early-stop) ...')
    cae_model.fit(train_ds_mixed, epochs=EPOCHS, validation_data=val_ds_mixed,
                  callbacks=callbacks_cae, verbose=1)
else:
    # Loads the full saved model (architecture + weights), replacing the freshly built one.
    print(f'Loading CAE weights from {cae_ckpt}')
    cae_model = tf.keras.models.load_model(cae_ckpt)

cae_ser = eval_model_ser_over_snrs(cae_model, x_test, y_test, snr_grid,
                                   noise_type='gaussian', batch_size=512)

models_map = {'UNet (MTL)': unet_ser, 'CAE (MTL)': cae_ser}
plot_snr_ser(models_map, title='Gaussian: SNR vs SER (UNet vs CAE)')


Training CAE (up to 200 epochs, early-stop) ...
Epoch 1/200


2025-10-13 15:30:40.806612: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2025-10-13 15:30:40.892496: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 2/200
Epoch 3/200
Epoch 4/200
Epoch 5/200
Epoch 6/200
Epoch 7/200
Epoch 8/200
Epoch 9/200
Epoch 10/200
Epoch 11/200
Epoch 12/200
Epoch 13/200
Epoch 14/200
Epoch 15/200
Epoch 16/200
Epoch 17/200
Epoch 18/200
Epoch 19/200
Epoch 20/200
Epoch 21/200
Epoch 22/200
Epoch 23/200
Epoch 24/200
Epoch 25/200
Epoch 26/200
Epoch 27/200
Epoch 28/200
Epoch 29/200
Epoch 30/200
Epoch 31/200
Epoch 32/200
Epoch 33/200
Epoch 34/200
Epoch 35/200
Epoch 36/200
Epoch 37/200
Epoch 38/200
Epoch 39/200
Epoch 40/200
Epoch 41/200
Epoch 42/200
Epoch 43/200
Epoch 44/200
Epoch 45/200
Epoch 46/200
Epoch 47/200
Epoch 48/200
Epoch 49/200
Epoch 50/200
Epoch 51/200
Epoch 52/200
Epoch 53/200
Epoch 54/200
Epoch 55/200

Training DnCNN (up to 200 epochs, early-stop) ...
Epoch 1/200


In [None]:
# Fixed classifier (clean CIFAR-10) with load-if-exists

def build_fixed_classifier(input_shape=(32,32,3), num_classes=10):
    """Small single-output CNN classifier (softmax over `num_classes`).

    Stack: 2x conv32 -> max-pool -> 2x conv64 -> global average pool -> dense128 -> softmax.
    """
    inputs = layers.Input(shape=input_shape)
    x = inputs
    for n_filters, pool_after in ((32, True), (64, False)):
        for _ in range(2):
            x = layers.Conv2D(n_filters, 3, padding='same', activation='relu')(x)
        if pool_after:
            x = layers.MaxPooling2D()(x)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return Model(inputs, outputs, name='FixedClassifier')

clf_ckpt = '/Users/ihaegwon/Lab/best_fixed_classifier.keras'

if not os.path.exists(clf_ckpt):
    fixed_clf = build_fixed_classifier()
    fixed_clf.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
    # NOTE(review): x_test doubles as the validation set for early stopping and
    # checkpoint selection, and pipeline SER is later reported on x_test too.
    # A held-out slice of x_train would avoid that selection bias.
    clf_callbacks = [
        EarlyStopping(monitor='val_accuracy', patience=8, restore_best_weights=True),
        ModelCheckpoint(filepath=clf_ckpt, save_weights_only=False, monitor='val_accuracy',
                        mode='max', save_best_only=True),
    ]
    print('\nTraining fixed classifier on clean CIFAR-10 ...')
    fixed_clf.fit(x_train, y_train, validation_data=(x_test, y_test),
                  epochs=50, batch_size=256, callbacks=clf_callbacks, verbose=1)
    print('Saved best fixed classifier to', clf_ckpt)
else:
    print(f'Loading fixed classifier from {clf_ckpt}')
    fixed_clf = tf.keras.models.load_model(clf_ckpt)



In [None]:
# Basic restoration-only CAE/DnCNN (no conditioning, no classifier head)

def build_cae_restoration(input_shape=(32,32,3)):
    """Restoration-only convolutional autoencoder with two concatenated skip connections.

    Single unnamed image input, single linear 3-channel output. It has no
    conditioning input and no classifier head. Two 2x pool/upsample stages
    mean H and W should be divisible by 4.
    """
    def conv_pair(tensor, n_filters):
        # Two 3x3 'same' ReLU convolutions.
        for _ in range(2):
            tensor = layers.Conv2D(n_filters, 3, padding='same', activation='relu')(tensor)
        return tensor

    image = layers.Input(shape=input_shape)
    skip_full = conv_pair(image, 32)
    skip_half = conv_pair(layers.MaxPooling2D(2)(skip_full), 64)
    decoded = conv_pair(layers.MaxPooling2D(2)(skip_half), 128)
    for n_filters, skip in ((64, skip_half), (32, skip_full)):
        decoded = layers.Conv2DTranspose(n_filters, 2, strides=2, padding='same')(decoded)
        decoded = layers.Concatenate()([decoded, skip])
        decoded = conv_pair(decoded, n_filters)
    restored = layers.Conv2D(3, 1, activation='linear')(decoded)
    return Model(image, restored, name='CAE_restoration')


def build_dncnn_restoration(input_shape=(32,32,3), depth=17, filters=64):
    """Restoration-only DnCNN: predict a 3-channel residual and subtract it from the input.

    Stack: conv+ReLU, (depth - 2) x [conv (no bias) -> BatchNorm -> ReLU], linear conv.
    """
    image = layers.Input(shape=input_shape)
    features = layers.Conv2D(filters, 3, padding='same', activation='relu')(image)
    for _ in range(depth - 2):
        for block_layer in (layers.Conv2D(filters, 3, padding='same', use_bias=False),
                            layers.BatchNormalization(),
                            layers.Activation('relu')):
            features = block_layer(features)
    residual = layers.Conv2D(3, 3, padding='same', activation='linear')(features)
    restored = layers.Subtract()([image, residual])
    return Model(image, restored, name='DnCNN_restoration')

cae_rest_ckpt = '/Users/ihaegwon/Lab/best_cae_restoration.keras'
dncnn_rest_ckpt = '/Users/ihaegwon/Lab/best_dncnn_restoration.keras'

cae_rest = build_cae_restoration()
dncnn_rest = build_dncnn_restoration()
for rest_model in (cae_rest, dncnn_rest):
    rest_model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss='mae')


def _to_pixel_pair(inputs, targets):
    # Drop cond/label and map (noisy_z, clean_z) back to pixel space, the space
    # the fixed classifier consumes.
    return from_zscore(inputs[0]), from_zscore(targets[0])


# Restoration-only models train on the same mixed-SNR Gaussian data as the MTL models.
rest_train = train_ds_mixed.map(_to_pixel_pair)
rest_val = val_ds_mixed.map(_to_pixel_pair)


def _restoration_callbacks(ckpt_path):
    # Early-stop on validation MAE and persist the best full model to ckpt_path.
    return [EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
            ModelCheckpoint(filepath=ckpt_path, save_weights_only=False, monitor='val_loss',
                            mode='min', save_best_only=True)]


rest_callbacks_cae = _restoration_callbacks(cae_rest_ckpt)
rest_callbacks_dn = _restoration_callbacks(dncnn_rest_ckpt)

if not os.path.exists(cae_rest_ckpt):
    print('\nTraining CAE restoration ...')
    cae_rest.fit(rest_train, validation_data=rest_val, epochs=100, callbacks=rest_callbacks_cae, verbose=1)
else:
    print(f'Loading CAE restoration from {cae_rest_ckpt}')
    cae_rest = tf.keras.models.load_model(cae_rest_ckpt)

if not os.path.exists(dncnn_rest_ckpt):
    print('\nTraining DnCNN restoration ...')
    dncnn_rest.fit(rest_train, validation_data=rest_val, epochs=100, callbacks=rest_callbacks_dn, verbose=1)
else:
    print(f'Loading DnCNN restoration from {dncnn_rest_ckpt}')
    dncnn_rest = tf.keras.models.load_model(dncnn_rest_ckpt)



In [None]:
# Pipeline SER evaluators (no-rest and restoration+classifier)

def eval_pipeline_ser_over_snrs(classifier, restorer, x, y, snr_list_db, noise_type='gaussian', batch_size=512):
    """SER of a two-stage pipeline: optional restorer, then a single-output classifier.

    Noisy z-score batches are mapped back to pixel space with `from_zscore`.
    They then pass through `restorer`, or go straight to `classifier` when
    `restorer` is None. Restored images are not clipped. The conditioning
    vector is ignored.

    Returns:
        dict mapping float(snr) -> fraction of misclassified samples.
    """
    results = {}
    for snr in snr_list_db:
        snr = float(snr)
        ds = make_fixed_snr_dataset_noise(x, y, snr_db=snr, noise_type=noise_type, batch_size=batch_size)
        total = 0
        errors = 0
        for (noisy_z_b, _cond_b), (_clean_z_b, label_b) in ds:
            pixels_b = from_zscore(noisy_z_b)
            # predict_on_batch runs the compiled predict step on this one batch.
            # Keras advises against calling predict() inside a loop over batches.
            if restorer is not None:
                pixels_b = restorer.predict_on_batch(pixels_b)
            probs_b = classifier.predict_on_batch(pixels_b)
            pred = np.argmax(probs_b, axis=-1)
            labels = label_b.numpy()
            total += labels.shape[0]
            errors += int(np.sum(pred != labels))
        results[snr] = errors / max(1, total)
    return results

snr_grid = list(range(-30, -9, 2))

# Evaluate pipeline baselines (Gaussian)
no_rest_ser  = eval_pipeline_ser_over_snrs(fixed_clf, None,       x_test, y_test, snr_grid, noise_type='gaussian', batch_size=512)
cae_pipe_ser = eval_pipeline_ser_over_snrs(fixed_clf, cae_rest,   x_test, y_test, snr_grid, noise_type='gaussian', batch_size=512)
dn_pipe_ser  = eval_pipeline_ser_over_snrs(fixed_clf, dncnn_rest, x_test, y_test, snr_grid, noise_type='gaussian', batch_size=512)

# MTL curves are produced by other cells: UNet/CAE above, and DnCNN in the cell
# *below* this one. Include only the ones already computed, so this cell does not
# raise NameError on Restart & Run All.
_mtl_sources = (('UNet (MTL)', 'unet_ser'), ('CAE (MTL)', 'cae_ser'), ('DnCNN (MTL)', 'dncnn_ser'))
_missing_mtl = [var for _, var in _mtl_sources if var not in globals()]
if _missing_mtl:
    print('MTL curves not computed yet, skipped:', ', '.join(_missing_mtl))

models_map = {name: globals()[var] for name, var in _mtl_sources if var in globals()}
models_map.update({
    'No-Rest + FixedClf': no_rest_ser,
    'CAE-Rest + FixedClf': cae_pipe_ser,
    'DnCNN-Rest + FixedClf': dn_pipe_ser,
})
plot_snr_ser(models_map, title='Gaussian: SNR vs SER (MTL vs Pipelines)')



In [None]:
# Train/Load DnCNN MTL separately, then evaluate SER (Gaussian)
EPOCHS = 200
ckpt_dir = '/Users/ihaegwon/Lab'
dncnn_ckpt = os.path.join(ckpt_dir, 'best_dncnn_multitask.keras')

# Same selection criterion as the CAE MTL run: validation accuracy of the classifier head.
dncnn_monitor = 'val_classification_output_accuracy'
callbacks_dncnn = [
    EarlyStopping(monitor=dncnn_monitor, patience=20, restore_best_weights=True),
    ModelCheckpoint(filepath=dncnn_ckpt, save_weights_only=False, monitor=dncnn_monitor,
                    mode='max', save_best_only=True),
]

if not os.path.exists(dncnn_ckpt):
    print('\nTraining DnCNN (up to 200 epochs, early-stop) ...')
    dncnn_model.fit(train_ds_mixed, epochs=EPOCHS, validation_data=val_ds_mixed,
                    callbacks=callbacks_dncnn, verbose=1)
else:
    print(f'Loading DnCNN weights from {dncnn_ckpt}')
    dncnn_model = tf.keras.models.load_model(dncnn_ckpt)

# Best-effort evaluation: any failure is reported rather than aborting the notebook.
try:
    dncnn_ser = eval_model_ser_over_snrs(dncnn_model, x_test, y_test, snr_grid,
                                         noise_type='gaussian', batch_size=512)
    optional_cae = {'CAE (MTL)': cae_ser} if 'cae_ser' in globals() else {}
    models_map = {'UNet (MTL)': unet_ser, **optional_cae, 'DnCNN (MTL)': dncnn_ser}
    plot_snr_ser(models_map, title='Gaussian: SNR vs SER (UNet vs CAE vs DnCNN)')
except Exception as err:
    print('DnCNN evaluation skipped due to error:', err)

