# In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import mixed_precision, backend as K
from sklearn.metrics import classification_report, roc_curve, auc, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Seed TensorFlow's global random generator.
tf.random.set_seed(42)
# Use the Keras mixed-precision policy for every layer created afterwards.
mixed_precision.set_global_policy('mixed_float16')

DATASET_PATH = "/kaggle/input/cell-images-2/cell_images"
IMG_SIZE = (160, 160)
BATCH_SIZE = 64
EPOCHS = 30
# Fraction of each class directory reserved for validation. Both generators
# below must use the same value so the train/validation partitions stay
# complementary.
VALIDATION_SPLIT = 0.2

# Training pipeline: rescale to [0, 1] and apply random flip/zoom augmentation.
datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
    horizontal_flip=True,
    zoom_range=0.2
)

# Validation pipeline: same rescaling and split fraction but NO random
# augmentation, so validation metrics and evaluate_model() see the unmodified
# held-out images instead of randomly flipped/zoomed ones.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT
)

train_gen = datagen.flow_from_directory(
    DATASET_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
    shuffle=True
)

# shuffle=False keeps the batch order aligned with val_gen.classes, which
# evaluate_model() relies on when comparing predictions to labels.
val_gen = val_datagen.flow_from_directory(
    DATASET_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False
)


def wavelet_transform(x):
    """Return four progressively smaller resized copies of ``x``.

    This is the "simulated" multi-scale input: no wavelet filtering is
    performed, the batch is simply resized with ``tf.image.resize`` to
    80x80, 40x40, 20x20 and 10x10.

    Args:
        x: Image tensor accepted by ``tf.image.resize``.

    Returns:
        A list of four tensors, ordered from the largest to the smallest size.
    """
    return [tf.image.resize(x, (side, side)) for side in (80, 40, 20, 10)]


def WaveletTransformAxisY(x):
    """Average and difference neighbouring samples along the third axis.

    Elements at even and odd positions of axis 2 are paired up; the function
    returns their half-sum and half-difference.

    Args:
        x: Array or tensor indexable as ``x[:, :, k]``. Axis 2 is assumed to
            have even length so both slices have the same size.

    Returns:
        Tuple ``(L, H)`` where ``L = (even + odd) / 2`` and
        ``H = (even - odd) / 2``.
    """
    evens, odds = x[:, :, ::2], x[:, :, 1::2]
    return (evens + odds) / 2, (evens - odds) / 2

def WaveletTransformAxisX(x):
    """Average and difference neighbouring samples along the second axis.

    Elements at even and odd positions of axis 1 are paired up; the function
    returns their half-sum and half-difference.

    Args:
        x: Array or tensor indexable as ``x[:, k, :]``. Axis 1 is assumed to
            have even length so both slices have the same size.

    Returns:
        Tuple ``(L, H)`` where ``L = (even + odd) / 2`` and
        ``H = (even - odd) / 2``.
    """
    evens, odds = x[:, ::2, :], x[:, 1::2, :]
    return (evens + odds) / 2, (evens - odds) / 2

def Wavelet(batch_image):
    """Compute a four-level averaging/differencing decomposition per colour channel.

    The batch is permuted to channels-first, the first three channels are
    decomposed independently, and at each level the four sub-bands of every
    channel are stacked together and permuted back to channels-last.

    Args:
        batch_image: Rank-4 tensor in (batch, height, width, channels) order;
            only channels 0, 1 and 2 are read.

    Returns:
        A list of four tensors, one per level, each holding 12 feature maps
        (4 sub-bands x 3 channels) on the last axis. Every level halves the
        spatial size of the previous one.
    """
    channels_first = K.permute_dimensions(batch_image, [0, 3, 1, 2])
    planes = [channels_first[:, idx] for idx in range(3)]

    def split_once(plane):
        # One level: split along axis 2 first, then split each half along
        # axis 1, yielding (LL, LH, HL, HH).
        low, high = WaveletTransformAxisY(plane)
        return WaveletTransformAxisX(low) + WaveletTransformAxisX(high)

    def pyramid(plane):
        # Four successive levels; each level re-decomposes the LL band of the
        # previous one and stores its four sub-bands stacked on a new axis 1.
        stacks = []
        approximation = plane
        for _ in range(4):
            bands = split_once(approximation)
            stacks.append(K.stack(bands, axis=1))
            approximation = bands[0]
        return stacks

    per_channel = [pyramid(plane) for plane in planes]

    levels = []
    for r_level, g_level, b_level in zip(*per_channel):
        merged = K.concatenate([r_level, g_level, b_level], axis=1)
        # Back to channels-last for the convolutional layers that follow.
        levels.append(K.permute_dimensions(merged, [0, 2, 3, 1]))
    return levels

def Wavelet_out_shape(input_shapes):
    """Return the output shapes declared for the ``Wavelet`` Lambda layer.

    The shapes are derived from the module-level ``IMG_SIZE``; the
    ``input_shapes`` argument is accepted for the ``output_shape`` callback
    signature but is not read.

    Returns:
        A list of four ``(None, height, width, 12)`` tuples with the spatial
        size divided by 2, 4, 8 and 16.
    """
    return [
        (None, IMG_SIZE[0] // factor, IMG_SIZE[1] // factor, 12)
        for factor in (2, 4, 8, 16)
    ]


def spatial_attention(input_feature):
    """Re-weight a feature map with a single-channel sigmoid attention mask.

    The mask is produced by a 7x7 bias-free convolution over the channel-wise
    mean and channel-wise max of ``input_feature`` and is multiplied back onto
    the input.

    Args:
        input_feature: Keras tensor with channels on the last axis.

    Returns:
        Keras tensor: ``input_feature`` multiplied by the attention mask.
    """
    channel_mean = Lambda(lambda t: tf.reduce_mean(t, axis=-1, keepdims=True))(input_feature)
    channel_max = Lambda(lambda t: tf.reduce_max(t, axis=-1, keepdims=True))(input_feature)
    pooled = Concatenate(axis=-1)([channel_mean, channel_max])
    mask_layer = Conv2D(1, kernel_size=7, padding='same', activation='sigmoid', use_bias=False)
    mask = mask_layer(pooled)
    return Multiply()([input_feature, mask])


class CheckpointedTransformerBlock(tf.keras.layers.Layer):
    """Transformer encoder block whose forward pass runs under ``tf.recompute_grad``.

    Forward pass: layer norm -> multi-head self-attention -> residual add ->
    layer norm -> Dense(ff_dim, relu) -> Dense(input feature size) ->
    dropout -> residual add.

    Args:
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer; also passed as
            ``key_dim`` to ``MultiHeadAttention``.
        dropout_rate: Dropout rate applied to the feed-forward output.
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, num_heads=4, ff_dim=128, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        self.ln1 = LayerNormalization(epsilon=1e-6)
        self.mha = MultiHeadAttention(num_heads=num_heads, key_dim=ff_dim)
        self.ln2 = LayerNormalization(epsilon=1e-6)
        self.dense1 = Dense(ff_dim, activation='relu')
        # The output projection needs the input feature size, which is only
        # known once the input shape is available; it is created in build().
        self.dense2 = None
        self.dropout = Dropout(dropout_rate)

    def build(self, input_shape):
        """Create the shape-dependent output projection.

        Keras expects sub-layers to be created in ``__init__`` or ``build``;
        creating them inside ``call`` adds state to an already-built layer.
        """
        self.dense2 = Dense(input_shape[-1])
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Apply the block to ``inputs`` and return a tensor of the same feature size."""

        def forward_pass(x):
            x1 = self.ln1(x)
            attn_output = self.mha(x1, x1)
            # First residual connection uses the un-normalized input.
            x2 = x + attn_output
            x3 = self.ln2(x2)
            ff = self.dense1(x3)
            ff = self.dense2(ff)
            ff = self.dropout(ff, training=training)
            return x2 + ff

        # recompute_grad re-runs forward_pass during backprop rather than
        # keeping its intermediate activations alive.
        return tf.recompute_grad(forward_pass)(inputs)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'dropout_rate': self.dropout_rate,
        })
        return config

def transformer_block(inputs, num_heads=4, ff_dim=128):
    """Build a transformer-style block from functional Keras layers.

    Sequence: layer norm -> multi-head self-attention -> add -> layer norm ->
    Dense(ff_dim, relu) -> Dense(input feature size) -> add. Both additions
    use the layer-normalized activations as the skip branch, and no dropout
    is applied.

    Args:
        inputs: Keras tensor whose last axis is the feature axis.
        num_heads: Number of attention heads.
        ff_dim: Hidden feed-forward width; also used as the attention ``key_dim``.

    Returns:
        Keras tensor with the same feature size as ``inputs``.
    """
    normed = LayerNormalization(epsilon=1e-6)(inputs)
    attn_out = MultiHeadAttention(num_heads=num_heads, key_dim=ff_dim)(normed, normed)
    attended = Add()([normed, attn_out])
    renormed = LayerNormalization(epsilon=1e-6)(attended)
    hidden = Dense(ff_dim, activation='relu')(renormed)
    projected = Dense(inputs.shape[-1])(hidden)
    return Add()([renormed, projected])


def _conv_bn_relu(tensor, filters, strides=1):
    """Apply a 3x3 same-padded convolution, batch normalization and ReLU."""
    tensor = Conv2D(filters, 3, strides=strides, padding='same')(tensor)
    tensor = BatchNormalization()(tensor)
    return Activation('relu')(tensor)


def build_model(use_checkpointing=False, use_full_wavelet=False):
    """Build and compile the multi-scale CNN + transformer binary classifier.

    Four scales of the input image pass through spatial attention, are merged
    step by step into a convolutional trunk, flattened into a token sequence,
    processed by one transformer block and classified by a dense head with a
    float32 sigmoid output.

    Args:
        use_checkpointing: If True use ``CheckpointedTransformerBlock``,
            otherwise the functional ``transformer_block``.
        use_full_wavelet: If True decompose the input with ``Wavelet``,
            otherwise use the resize-based ``wavelet_transform``.

    Returns:
        A ``Model`` compiled with Adam, binary cross-entropy and accuracy.
    """
    inp = Input(shape=(IMG_SIZE[0], IMG_SIZE[1], 3))

    # Multi-scale front end: list of four tensors, largest scale first.
    if use_full_wavelet:
        scales = Lambda(Wavelet, output_shape=Wavelet_out_shape)(inp)
    else:
        scales = Lambda(wavelet_transform)(inp)

    w1, w2, w3, w4 = [spatial_attention(scales[idx]) for idx in range(4)]

    # Scale 1 branch: two conv units, the second one strided.
    x1 = _conv_bn_relu(w1, 64)
    x1 = _conv_bn_relu(x1, 64, strides=2)

    # Scale 2 branch.
    xa = _conv_bn_relu(w2, 64, strides=2)

    # Pool scale 1 down to the spatial size of the scale 2 branch and merge.
    x1 = MaxPooling2D(pool_size=(2, 2))(x1)
    x = Concatenate()([x1, xa])

    x = _conv_bn_relu(x, 128)
    x = _conv_bn_relu(x, 128, strides=2)

    # Scale 3 branch.
    xb = _conv_bn_relu(w3, 64, strides=2)
    xb = _conv_bn_relu(xb, 128)

    x = Concatenate()([x, xb])
    x = _conv_bn_relu(x, 256)
    x = _conv_bn_relu(x, 256, strides=2)

    # Scale 4 branch.
    xc = _conv_bn_relu(w4, 64, strides=2)
    xc = _conv_bn_relu(xc, 256)
    xc = _conv_bn_relu(xc, 256)

    x = Concatenate()([x, xc])
    x = _conv_bn_relu(x, 256)
    x = _conv_bn_relu(x, 256, strides=2)

    # Flatten the spatial grid into a sequence of tokens for the transformer.
    _, rows, cols, channels = tf.keras.backend.int_shape(x)
    x = Reshape((rows * cols, channels))(x)

    if use_checkpointing:
        x = CheckpointedTransformerBlock(num_heads=4, ff_dim=128)(x)
    else:
        x = transformer_block(x, num_heads=4, ff_dim=128)

    # Classification head.
    x = GlobalAveragePooling1D()(x)
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.4)(x)
    x = Dense(128, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.3)(x)
    # The output layer is forced to float32 independently of the global policy.
    out = Dense(1, activation='sigmoid', dtype='float32')(x)

    model = Model(inputs=inp, outputs=out)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model


def evaluate_model(model, val_gen, title):
    """Print classification metrics for ``model`` on ``val_gen`` and plot its ROC curve.

    Args:
        model: Model exposing ``predict``; expected to output one probability
            per sample.
        val_gen: Directory iterator exposing ``reset()`` and ``classes``.
            Assumed to be created with ``shuffle=False`` so the prediction
            order matches ``val_gen.classes``.
        title: Label used in the printed headers and the plot title/legend.

    Returns:
        None. Results are printed and shown with matplotlib.
    """
    val_gen.reset()
    preds = model.predict(val_gen, verbose=0)
    y_true = val_gen.classes
    y_pred = (preds > 0.5).astype(int).reshape(-1)
    y_prob = preds.reshape(-1)

    print(f"\n📋 Classification Report for {title}:\n")
    print(classification_report(y_true, y_pred, digits=4))

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    # Pin both labels so the matrix is always 2x2; without this a run where
    # only one class occurs yields a 1x1 matrix and the 4-way unpack fails.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    negatives = tn + fp
    # Specificity is undefined when there are no negative samples.
    spec = tn / negatives if negatives else float('nan')
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)

    print(f"✅ Overall Metrics for {title}:")
    print(f"Accuracy:     {acc:.4f}")
    print(f"Precision:    {prec:.4f}")
    print(f"Recall:       {rec:.4f}")
    print(f"F1 Score:     {f1:.4f}")
    print(f"Specificity:  {spec:.4f}")
    print(f"AUC:          {roc_auc:.4f}")

    plt.figure()
    plt.plot(fpr, tpr, label=f'{title} (AUC = {roc_auc:.4f})')
    # Diagonal reference line (TPR == FPR).
    plt.plot([0, 1], [0, 1], 'k--')
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"ROC Curve - {title}")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()


def train_and_measure_memory(use_checkpointing=False, use_full_wavelet=False):
    """Train a fresh model, evaluate it, and report GPU memory statistics.

    Uses the module-level ``train_gen`` / ``val_gen`` generators and ``EPOCHS``.

    Args:
        use_checkpointing: Forwarded to ``build_model``.
        use_full_wavelet: Forwarded to ``build_model``.

    Returns:
        The ``'peak'`` value reported for ``GPU:0`` after training and
        evaluation (``'current'`` if no peak entry exists), or 0 when no GPU
        is visible.
    """
    tf.keras.backend.clear_session()
    model = build_model(use_checkpointing=use_checkpointing, use_full_wavelet=use_full_wavelet)

    has_gpu = bool(tf.config.list_physical_devices('GPU'))
    if has_gpu:
        # Reset the counters so the peak reflects this run only.
        tf.config.experimental.reset_memory_stats('GPU:0')

    model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=EPOCHS,
        steps_per_epoch=len(train_gen),
        validation_steps=len(val_gen),
        verbose=1
    )

    checkpoint_label = 'Checkpointing' if use_checkpointing else 'No Checkpointing'
    wavelet_label = 'Full Wavelet' if use_full_wavelet else 'Simulated Wavelet'
    title = f"{checkpoint_label} | {wavelet_label}"
    evaluate_model(model, val_gen, title)

    if not has_gpu:
        return 0

    memory_info = tf.config.experimental.get_memory_info('GPU:0')
    if 'peak' in memory_info:
        return memory_info['peak']
    return memory_info['current']


# Run 1: functional transformer block + resize-based ("simulated") wavelet.
print("🚀 Training WITHOUT checkpointing and simulated wavelet...")
peak_mem_no_ckpt = train_and_measure_memory(use_checkpointing=False, use_full_wavelet=False)
# A return value of 0 means no GPU was visible, so there is nothing to report.
if peak_mem_no_ckpt > 0:
    print(f"💾 Peak GPU memory usage: {peak_mem_no_ckpt / 1e6:.2f} MB")

# Run 2: checkpointed transformer block + full wavelet decomposition.
print("\n🚀 Training WITH checkpointing and full wavelet...")
peak_mem_ckpt = train_and_measure_memory(use_checkpointing=True, use_full_wavelet=True)
if peak_mem_ckpt > 0:
    print(f"💾 Peak GPU memory usage: {peak_mem_ckpt / 1e6:.2f} MB")

# NOTE(review): the two runs differ in BOTH use_checkpointing and
# use_full_wavelet, so the difference printed below reflects both changes,
# not gradient checkpointing alone. The value is negative when run 2 used
# more memory than run 1.
if peak_mem_no_ckpt > 0 and peak_mem_ckpt > 0:
    saved = peak_mem_no_ckpt - peak_mem_ckpt
    saved_percent = (saved / peak_mem_no_ckpt) * 100
    print(f"\n✅ Memory saved by gradient checkpointing: {saved / 1e6:.2f} MB ({saved_percent:.2f}%)")


# In [None]:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2
from tensorflow.keras.layers import Layer, Conv2D, Concatenate, Multiply

# Default target size used by load_and_preprocess_image() when resizing.
IMAGE_SIZE = (160, 160)


def WaveletTransformAxisY(input_array):
    """Average and difference neighbouring columns of a 2-D array.

    Columns at even and odd positions of axis 1 are paired up.

    Args:
        input_array: Array or tensor indexable as ``input_array[:, k]``.
            Axis 1 is assumed to have even length.

    Returns:
        Tuple ``(low_frequency, high_frequency)``: the half-sum and the
        half-difference of each column pair.
    """
    left, right = input_array[:, ::2], input_array[:, 1::2]
    return (left + right) / 2, (left - right) / 2

def WaveletTransformAxisX(input_array):
    """Average and difference neighbouring rows of a 2-D array.

    Rows at even and odd positions of axis 0 are paired up.

    Args:
        input_array: Array or tensor indexable as ``input_array[k, :]``.
            Axis 0 is assumed to have even length.

    Returns:
        Tuple ``(low_frequency, high_frequency)``: the half-sum and the
        half-difference of each row pair.
    """
    upper, lower = input_array[::2, :], input_array[1::2, :]
    return (upper + lower) / 2, (upper - lower) / 2

def wavelet_decompose_image(image_tensor):
    """Decompose channel 0 of an image into four single-level sub-bands.

    Args:
        image_tensor: Image indexable as ``image_tensor[:, :, c]``; only
            channel 0 (red for an RGB image) is used, for visualization.

    Returns:
        List ``[low_low, low_high, high_low, high_high]``.
    """
    red_plane = image_tensor[:, :, 0]
    # Split along columns first, then split each half along rows.
    low_band, high_band = WaveletTransformAxisY(red_plane)
    return [*WaveletTransformAxisX(low_band), *WaveletTransformAxisX(high_band)]


class SpatialAttention(Layer):
    """Spatial attention layer: re-weights a feature map with a sigmoid mask.

    The mask comes from a 7x7 bias-free convolution over the channel-wise mean
    and channel-wise max of the input, and is multiplied back onto the input.

    Args:
        **kwargs: Standard ``Layer`` keyword arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # All sub-layers are created once here instead of on every call().
        self.concat = Concatenate(axis=-1)
        self.conv2d = Conv2D(filters=1, kernel_size=7, strides=1,
                             padding='same', activation='sigmoid', use_bias=False)
        self.multiply = Multiply()

    def call(self, input_feature):
        """Return ``input_feature`` scaled by its spatial attention mask."""
        avg_pool = tf.reduce_mean(input_feature, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(input_feature, axis=-1, keepdims=True)
        attention = self.conv2d(self.concat([avg_pool, max_pool]))
        return self.multiply([input_feature, attention])


def load_and_preprocess_image(image_path, target_size=IMAGE_SIZE):
    """Load an image from disk as an RGB array scaled to [0, 1].

    Args:
        image_path: Path of the image file to read.
        target_size: Size passed to ``cv2.resize`` (OpenCV interprets it as
            ``(width, height)``).

    Returns:
        Floating-point array: the resized RGB image divided by 255.

    Raises:
        OSError: If OpenCV cannot read an image from ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than inside cvtColor.
    if image is None:
        raise OSError(f"Could not read image (missing or unreadable file): {image_path!r}")
    # OpenCV loads images in BGR channel order.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, target_size)
    image = image / 255.0
    return image

def normalize_component(component_array):
    """Min-max scale an array to approximately the [0, 1] range.

    Args:
        component_array: Non-empty array exposing ``min()`` and ``max()``.

    Returns:
        ``(component_array - min) / (max - min + 1e-8)``; the small constant
        keeps the division defined when every element is equal.
    """
    lowest = component_array.min()
    value_range = component_array.max() - lowest
    return (component_array - lowest) / (value_range + 1e-8)


def visualize_wavelet_subbands_with_attention(image_path):
    """Plot the four wavelet sub-bands of an image next to an attention panel.

    The image is loaded, its channel 0 is decomposed into four sub-bands, each
    sub-band is min-max normalized, and the stacked sub-bands are passed
    through a freshly created ``SpatialAttention`` layer. The fifth panel
    shows the channel-wise mean of that layer's output.

    Args:
        image_path: Path of the image file to visualize.

    Returns:
        None. The figure is shown with matplotlib.
    """
    rgb_image = load_and_preprocess_image(image_path)
    subbands = wavelet_decompose_image(tf.convert_to_tensor(rgb_image, dtype=tf.float32))
    normalized = [normalize_component(band.numpy()) for band in subbands]

    # Sub-bands become channels of a single-image batch.
    batch = tf.expand_dims(
        tf.convert_to_tensor(np.stack(normalized, axis=-1), dtype=tf.float32),
        axis=0,
    )
    attended = SpatialAttention()(batch)
    attention_map = tf.reduce_mean(attended, axis=-1)[0].numpy()

    titles = ['Low-Low (Approximation)',
              'Low-High (Horizontal Details)',
              'High-Low (Vertical Details)',
              'High-High (Diagonal Details)',
              'Spatial Attention Map']

    # Four grayscale sub-band panels followed by the attention panel.
    panels = [(band, 'gray') for band in normalized] + [(attention_map, 'viridis')]

    plt.figure(figsize=(15, 4))
    for position, ((panel, colormap), panel_title) in enumerate(zip(panels, titles), start=1):
        plt.subplot(1, 5, position)
        plt.imshow(panel, cmap=colormap)
        plt.title(panel_title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()


# Render the sub-band / attention figure for one sample image from the
# "Parasitized" class directory of the dataset.
visualize_wavelet_subbands_with_attention(
    "/kaggle/input/cell-images-2/cell_images/Parasitized/C100P61ThinF_IMG_20150918_144104_cell_162.png"
)
