In [1]:
import os
import pandas as pd
import tensorflow as tf
import requests
from tqdm import tqdm
import zipfile

class TinyImageNetTF:
    def __init__(self, root_dir):
        """
        Initialize TinyImageNet dataset
        Args:
            root_dir (str): Directory to store/load the dataset
        """
        self.root_dir = root_dir
        self.image_size = 64
        self.num_classes = 200

        # Download dataset if it doesn't exist
        if not os.path.exists(os.path.join(root_dir, 'tiny-imagenet-200')):
            self._download_dataset()

        # Create class mapping
        train_dir = os.path.join(root_dir, 'tiny-imagenet-200', 'train')
        self.class_names = sorted(os.listdir(train_dir))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.class_names)}

    def _download_dataset(self):
        """Download and extract the Tiny ImageNet dataset"""
        url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
        print(f"Downloading Tiny ImageNet from {url}...")

        # Create directory
        os.makedirs(self.root_dir, exist_ok=True)

        # Download the file
        response = requests.get(url, stream=True)
        total_size = int(response.headers.get('content-length', 0))

        # Save the zip file
        zip_path = os.path.join(self.root_dir, "tiny-imagenet-200.zip")
        with open(zip_path, 'wb') as f:
            for data in tqdm(response.iter_content(chunk_size=1024),
                           total=total_size//1024, unit='KB'):
                f.write(data)

        # Extract the archive
        print("Extracting files...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.root_dir)

        # Clean up
        os.remove(zip_path)
        print("Download and extraction complete!")

    def _parse_image(self, filename, label):
        """Parse image and convert to float32."""
        image = tf.io.read_file(filename)
        image = tf.image.decode_jpeg(image, channels=3)
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    def _augment(self, image, label):
        """Apply data augmentation to training images."""
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_brightness(image, 0.2)
        return image, label

    def get_dataset(self, split='train', batch_size=128, shuffle=True, augment=True):
        """
        Get TensorFlow dataset for specified split
        Args:
            split (str): 'train', 'val', or 'test'
            batch_size (int): Batch size
            shuffle (bool): Whether to shuffle the dataset
            augment (bool): Whether to apply data augmentation (only for training)
        Returns:
            tf.data.Dataset
        """
        base_path = os.path.join(self.root_dir, 'tiny-imagenet-200')

        if split == 'train':
            # Process training data
            filenames = []
            labels = []
            for class_name in self.class_names:
                class_dir = os.path.join(base_path, 'train', class_name, 'images')
                class_files = [os.path.join(class_dir, f) for f in os.listdir(class_dir)]
                filenames.extend(class_files)
                labels.extend([self.class_to_idx[class_name]] * len(class_files))

        elif split == 'val':
            # Process validation data
            val_annotations = pd.read_csv(
                os.path.join(base_path, 'val', 'val_annotations.txt'),
                sep='\t', header=None,
                names=['filename', 'class', 'x', 'y', 'w', 'h']
            )
            filenames = [os.path.join(base_path, 'val', 'images', f)
                        for f in val_annotations['filename']]
            labels = [self.class_to_idx[c] for c in val_annotations['class']]

        else:  # test
            # For test data, we only have images without labels
            test_dir = os.path.join(base_path, 'test', 'images')
            filenames = [os.path.join(test_dir, f) for f in os.listdir(test_dir)]
            labels = [0] * len(filenames)  # Dummy labels for test set

        # Create dataset
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

        # Parse images
        dataset = dataset.map(self._parse_image,
                            num_parallel_calls=tf.data.AUTOTUNE)

        # Apply augmentation if needed
        if split == 'train' and augment:
            dataset = dataset.map(self._augment,
                                num_parallel_calls=tf.data.AUTOTUNE)

        # Shuffle if needed
        if shuffle:
            dataset = dataset.shuffle(buffer_size=10000)

        # Batch and prefetch
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

def get_tiny_imagenet_datasets(root_dir, batch_size=128):
    """
    Get train, validation, and test datasets
    """
    dataset = TinyImageNetTF(root_dir)

    train_ds = dataset.get_dataset('train', batch_size=batch_size,
                                 shuffle=True, augment=True)
    val_ds = dataset.get_dataset('val', batch_size=batch_size,
                               shuffle=False, augment=False)
    test_ds = dataset.get_dataset('test', batch_size=batch_size,
                                shuffle=False, augment=False)

    return train_ds, val_ds, test_ds

In [2]:
import tensorflow as tf
print("GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

GPUs Available:  1


In [3]:

import tensorflow as tf

# Set the directory
root_dir = "./tiny-imagenet-200"

# Get the datasets
train_ds, val_ds, test_ds = get_tiny_imagenet_datasets(root_dir)

# 1. Check the first batch
for images, labels in train_ds.take(1):
    print("\nFirst batch information:")
    print(f"Image batch shape: {images.shape}")  # Should be (batch_size, 64, 64, 3)
    print(f"Label batch shape: {labels.shape}")  # Should be (batch_size,)
    print(f"Image value range: ({tf.reduce_min(images).numpy():.2f}, {tf.reduce_max(images).numpy():.2f})")

# 2. Count total samples (corrected version)
def count_samples(dataset):
    count = 0
    for images, labels in dataset:
        count += images.shape[0]  # Add batch size
    return count

print("\nCounting samples (this might take a moment)...")
try:
    train_count = count_samples(train_ds)
    val_count = count_samples(val_ds)
    test_count = count_samples(test_ds)

    print(f"Training samples: {train_count}")     # Should be ~100,000
    print(f"Validation samples: {val_count}")     # Should be ~10,000
    print(f"Test samples: {test_count}")         # Should be ~10,000
except Exception as e:
    print(f"Error counting samples: {e}")

# Quick alternative count (faster but might be less accurate)
print("\nQuick dataset size check:")
for images, labels in train_ds.take(1):
    batch_size = images.shape[0]
    print(f"Batch size: {batch_size}")

# 3. Check directory structure and files
import os

base_path = os.path.join(root_dir, 'tiny-imagenet-200')
print("\nChecking directory structure:")
print(f"Train directory exists: {os.path.exists(os.path.join(base_path, 'train'))}")
print(f"Val directory exists: {os.path.exists(os.path.join(base_path, 'val'))}")
print(f"Test directory exists: {os.path.exists(os.path.join(base_path, 'test'))}")

# Check number of class directories in train
if os.path.exists(os.path.join(base_path, 'train')):
    train_classes = len(os.listdir(os.path.join(base_path, 'train')))
    print(f"Number of training classes: {train_classes}")  # Should be 200

# 4. Display a few images
try:
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 5))
    for images, labels in train_ds.take(1):
        for j in range(min(5, images.shape[0])):  # Show first 5 images or less
            plt.subplot(1, 5, j+1)
            plt.imshow(images[j].numpy())
            plt.title(f'Label: {labels[j].numpy()}')
            plt.axis('off')
    plt.show()
except Exception as e:
    print(f"Error displaying images: {e}")


First batch information:
Image batch shape: (128, 64, 64, 3)
Label batch shape: (128,)
Image value range: (-0.20, 1.20)

Counting samples (this might take a moment)...
Training samples: 100000
Validation samples: 10000
Test samples: 10000

Quick dataset size check:
Batch size: 128

Checking directory structure:
Train directory exists: True
Val directory exists: True
Test directory exists: True
Number of training classes: 200
Error displaying images: No module named 'matplotlib'


In [4]:
import tensorflow as tf


# Rest of the imports
import numpy as np
import pandas as pd
import tensorflow_datasets as tfds
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, Activation,
                                   Add, GlobalAveragePooling2D, Dense, MaxPooling2D,
                                   Dropout, Flatten)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tabulate import tabulate
import time
import logging
import os
import sys
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import gc
import atexit
import signal
import traceback
from scipy.ndimage import gaussian_filter


# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

@dataclass
class UnlearningConfig:
    """Configuration class for unlearning parameters"""
    base_learning_rates: Dict[str, float] = None
    momentum: float = 0.95  # Increased momentum
    epochs: int = 20
    batch_size: int = 32
    damping: float = 0.05  # Increased damping
    memory_threshold: float = 0.9

    def __post_init__(self):
        if self.base_learning_rates is None:
            self.base_learning_rates = {
                'cifar10': 0.002,  # Increased learning rates
                'cifar100': 0.001,
                'fashion_mnist': 0.0015,
                'svhn': 0.0015,
                'tiny_imagenet': 0.0001  # Lower learning rate for ImageNet
            }

class ImprovedUnlearningMethods:
    def __init__(self, model: tf.keras.Model, dataset_name: str):
        self.model = model
        self.dataset_name = dataset_name
        self.config = UnlearningConfig()
        self.original_weights = [tf.identity(w) for w in model.trainable_variables]
        self.weight_importance = [tf.Variable(tf.zeros_like(w)) for w in model.trainable_variables]

    def _compute_approximate_hessian(self, x_batch: tf.Tensor, y_batch: tf.Tensor) -> List[tf.Tensor]:
        """Enhanced Hessian approximation with better stability"""
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                predictions = self.model(x_batch, training=True)
                loss = tf.reduce_mean(
                    tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))
            gradients = tape1.gradient(loss, self.model.trainable_variables)

        hessian = []
        for g in gradients:
            if g is not None:
                h = tf.square(g) + self.config.damping * tf.ones_like(g)
                hessian.append(h)

        return hessian

    def combined_unlearning(self, forget_data: Tuple[tf.Tensor, tf.Tensor],
                          retain_data: Tuple[tf.Tensor, tf.Tensor]) -> float:
        """Combined method using gradient, influence, and Hessian-guided approaches"""
        try:
            start_time = time.time()
            x_forget, y_forget = forget_data
            x_retain, y_retain = retain_data

            # Phase 1: Gradient-based initial unlearning with momentum
            momentum_vars = [tf.Variable(tf.zeros_like(var))
                           for var in self.model.trainable_variables]

            batch_size = 32
            for epoch in range(3):  # Multiple epochs for better convergence
                indices = tf.range(start=0, limit=tf.shape(x_forget)[0], dtype=tf.int32)
                shuffled_indices = tf.random.shuffle(indices)
                x_shuffled = tf.gather(x_forget, shuffled_indices)
                y_shuffled = tf.gather(y_forget, shuffled_indices)

                for i in range(0, len(x_forget), batch_size):
                    x_batch = x_shuffled[i:i + batch_size]
                    y_batch = y_shuffled[i:i + batch_size]

                    with tf.GradientTape() as tape:
                        predictions = self.model(x_batch, training=True)
                        loss = tf.reduce_mean(
                            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

                    gradients = tape.gradient(loss, self.model.trainable_variables)

                    # Apply momentum updates
                    for idx, (var, grad, mom) in enumerate(
                            zip(self.model.trainable_variables, gradients, momentum_vars)):
                        if grad is not None:
                            mom.assign(self.config.momentum * mom +
                                     (1 - self.config.momentum) * grad)
                            update = self.config.base_learning_rates[self.dataset_name] * mom
                            var.assign_sub(update)

            # Phase 2: Hessian-guided refinement
            for i in range(0, len(x_forget), batch_size):
                x_batch = x_forget[i:i + batch_size]
                y_batch = y_forget[i:i + batch_size]

                hessian = self._compute_approximate_hessian(x_batch, y_batch)

                with tf.GradientTape() as tape:
                    predictions = self.model(x_batch, training=True)
                    loss = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

                gradients = tape.gradient(loss, self.model.trainable_variables)

                # Apply Hessian-guided updates
                for var, grad, hess in zip(self.model.trainable_variables, gradients, hessian):
                    if grad is not None:
                        update = grad * tf.sqrt(hess + 1e-8)
                        var.assign_sub(0.01 * update)

            # Phase 3: Knowledge retention
            retain_batch_size = 64  # Larger batch size for retention
            for i in range(0, len(x_retain), retain_batch_size):
                x_batch = x_retain[i:i + retain_batch_size]
                y_batch = y_retain[i:i + retain_batch_size]

                with tf.GradientTape() as tape:
                    predictions = self.model(x_batch, training=True)
                    retain_loss = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

                retain_grads = tape.gradient(retain_loss, self.model.trainable_variables)

                # Apply small updates for retention
                for var, grad in zip(self.model.trainable_variables, retain_grads):
                    if grad is not None:
                        update = 0.0005 * tf.clip_by_norm(grad, 1.0)  # Very small updates
                        var.assign_sub(update)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in combined_unlearning: {str(e)}")
            raise

    def improved_gradient_unlearning(self, forget_data: Tuple[tf.Tensor, tf.Tensor],
                                   retain_data: Optional[Tuple[tf.Tensor, tf.Tensor]] = None) -> float:
        """Enhanced gradient-based unlearning with adaptive learning rates"""
        try:
            x_forget, y_forget = forget_data
            start_time = time.time()

            # Initialize adaptive learning rates
            base_lr = self.config.base_learning_rates[self.dataset_name]
            lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
                base_lr,
                first_decay_steps=3 * len(x_forget) // self.config.batch_size,
                t_mul=1.5,
                m_mul=0.95,
                alpha=0.2
            )

            # Enhanced momentum with Nesterov acceleration
            momentum_vars = [tf.Variable(tf.zeros_like(var))
                           for var in self.model.trainable_variables]
            velocity_vars = [tf.Variable(tf.zeros_like(var))
                           for var in self.model.trainable_variables]

            batch_size = 32
            for epoch in range(5):  # Increased epochs
                indices = tf.range(start=0, limit=tf.shape(x_forget)[0], dtype=tf.int32)
                shuffled_indices = tf.random.shuffle(indices)
                x_shuffled = tf.gather(x_forget, shuffled_indices)
                y_shuffled = tf.gather(y_forget, shuffled_indices)

                for i in range(0, len(x_forget), batch_size):
                    x_batch = x_shuffled[i:i + batch_size]
                    y_batch = y_shuffled[i:i + batch_size]
                    current_lr = lr_schedule(epoch)

                    with tf.GradientTape() as tape:
                        predictions = self.model(x_batch, training=True)
                        loss = tf.reduce_mean(
                            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))
                        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.model.trainable_variables])
                        loss += 0.0005 * l2_loss  # L2 regularization

                    gradients = tape.gradient(loss, self.model.trainable_variables)

                    # Apply Nesterov momentum updates
                    for idx, (var, grad, mom, vel) in enumerate(
                            zip(self.model.trainable_variables, gradients,
                                momentum_vars, velocity_vars)):
                        if grad is not None:
                            # Update velocity
                            old_vel = vel
                            vel.assign(self.config.momentum * vel - current_lr * grad)

                            # Compute Nesterov momentum
                            nesterov_grad = grad + self.config.momentum * (vel - old_vel)

                            # Update weights with gradient clipping
                            update = tf.clip_by_norm(nesterov_grad, 1.0)
                            var.assign_sub(current_lr * update)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in improved_gradient_unlearning: {str(e)}")
            raise

    def improved_influence_functions(self, forget_data: Tuple[tf.Tensor, tf.Tensor],
                                    retain_data: Optional[Tuple[tf.Tensor, tf.Tensor]] = None) -> float:
        """Enhanced influence function method with improved Hessian approximation"""
        try:
            x_forget, y_forget = forget_data
            start_time = time.time()

            # Enhanced damping strategy
            adaptive_damping = self.config.damping * tf.exp(
                -tf.cast(tf.shape(x_forget)[0], tf.float32) / 1000.0)

            # Initialize influence accumulators with momentum
            accumulated_influence = [tf.Variable(tf.zeros_like(w))
                                    for w in self.model.trainable_variables]
            momentum_influence = [tf.Variable(tf.zeros_like(w))
                                for w in self.model.trainable_variables]

            batch_size = 32
            for epoch in range(3):  # Multiple epochs for better convergence
                for i in range(0, len(x_forget), batch_size):
                    x_batch = x_forget[i:i + batch_size]
                    y_batch = y_forget[i:i + batch_size]

                    # Compute improved Hessian approximation
                    hessian = self._compute_approximate_hessian(x_batch, y_batch)

                    with tf.GradientTape() as tape:
                        predictions = self.model(x_batch, training=True)
                        loss = tf.reduce_mean(
                            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))
                        # Add L2 regularization
                        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.model.trainable_variables])
                        loss += 0.0001 * l2_loss

                    gradients = tape.gradient(loss, self.model.trainable_variables)

                    # Update influence with momentum
                    for idx, (grad, hess, acc_inf, mom_inf) in enumerate(
                            zip(gradients, hessian, accumulated_influence, momentum_influence)):
                        if grad is not None:
                            # Compute scaled influence
                            scaled_grad = grad / (hess + adaptive_damping)

                            # Update momentum
                            mom_inf.assign(0.9 * mom_inf + 0.1 * scaled_grad)

                            # Update accumulated influence
                            acc_inf.assign_add(mom_inf)

                    # Apply updates with adaptive learning rate
                    lr = self.config.base_learning_rates[self.dataset_name] * \
                        tf.exp(-tf.cast(epoch, tf.float32) / 2.0)

                    for var, inf in zip(self.model.trainable_variables, accumulated_influence):
                        update = tf.clip_by_norm(inf, 1.0)
                        var.assign_sub(lr * update)

            # Optional retain data regularization with improved scaling
            if retain_data is not None:
                x_retain, y_retain = retain_data
                retain_batch_size = 64

                for i in range(0, len(x_retain), retain_batch_size):
                    x_batch = x_retain[i:i + retain_batch_size]
                    y_batch = y_retain[i:i + retain_batch_size]

                    with tf.GradientTape() as tape:
                        predictions = self.model(x_batch, training=True)
                        retain_loss = tf.reduce_mean(
                            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

                    retain_grads = tape.gradient(retain_loss, self.model.trainable_variables)

                    # Apply selective updates based on gradient magnitude
                    for var, grad in zip(self.model.trainable_variables, retain_grads):
                        if grad is not None:
                            grad_norm = tf.norm(grad)
                            update_scale = 0.0001 * tf.exp(-grad_norm)  # Adaptive scaling
                            update = update_scale * tf.clip_by_norm(grad, 0.5)
                            var.assign_sub(update)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in improved_influence_functions: {str(e)}")
            raise

    def improved_hessian_guided_unlearning(self, forget_data: Tuple[tf.Tensor, tf.Tensor],
                                          retain_data: Tuple[tf.Tensor, tf.Tensor]) -> float:
        """Enhanced Hessian-guided unlearning with better stability"""
        try:
            start_time = time.time()
            x_forget, y_forget = forget_data
            x_retain, y_retain = retain_data

            # Phase 1: Initial unlearning with Hessian guidance
            batch_size = 32
            for epoch in range(3):
                for i in range(0, len(x_forget), batch_size):
                    x_batch = x_forget[i:i + batch_size]
                    y_batch = y_forget[i:i + batch_size]

                    # Compute improved Hessian approximation
                    hessian = self._compute_approximate_hessian(x_batch, y_batch)

                    with tf.GradientTape() as tape:
                        predictions = self.model(x_batch, training=True)
                        forget_loss = tf.reduce_mean(
                            tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))
                        # Add stability term
                        l2_loss = tf.add_n([tf.nn.l2_loss(v) for v in self.model.trainable_variables])
                        forget_loss += 0.0001 * l2_loss

                    gradients = tape.gradient(forget_loss, self.model.trainable_variables)

                    # Apply Hessian-guided updates with adaptive learning rate
                    lr = self.config.base_learning_rates[self.dataset_name] * \
                        tf.exp(-tf.cast(epoch, tf.float32) / 3.0)

                    for var, grad, hess in zip(self.model.trainable_variables, gradients, hessian):
                        if grad is not None:
                            # Compute update with Hessian scaling
                            scaled_grad = grad * tf.sqrt(hess + 1e-8)
                            update = tf.clip_by_norm(scaled_grad, 1.0)
                            var.assign_sub(lr * update)

            # Phase 2: Enhanced knowledge retention
            retain_batch_size = 64
            for i in range(0, len(x_retain), retain_batch_size):
                x_batch = x_retain[i:i + retain_batch_size]
                y_batch = y_retain[i:i + retain_batch_size]

                with tf.GradientTape() as tape:
                    predictions = self.model(x_batch, training=True)
                    retain_loss = tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions))

                gradients = tape.gradient(retain_loss, self.model.trainable_variables)

                # Apply selective updates based on importance
                for var, grad, importance in zip(self.model.trainable_variables, gradients, self.weight_importance):
                    if grad is not None:
                        # Compute importance-based mask
                        mask = tf.cast(importance > tf.reduce_mean(importance), tf.float32)
                        # Apply smaller updates to important weights
                        update = 0.0005 * tf.clip_by_norm(grad * mask, 0.5)
                        var.assign_sub(update)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in improved_hessian_guided_unlearning: {str(e)}")
            return 0.0

    def improved_post_unlearning_masking(self) -> float:
        """Enhanced masking with adaptive thresholds and smoother transitions"""
        try:
            start_time = time.time()

            for layer in self.model.layers:
                if isinstance(layer, (tf.keras.layers.Dense, tf.keras.layers.Conv2D)):
                    weights = layer.get_weights()
                    if len(weights) > 0:
                        weight_abs = np.abs(weights[0])

                        # Enhanced dynamic thresholding
                        mean_activation = np.mean(weight_abs)
                        std_activation = np.std(weight_abs)
                        threshold = mean_activation + std_activation

                        # Improved smooth transition
                        transition_width = 0.3 * threshold
                        smooth_mask = 1.0 / (1.0 + np.exp(
                            -(weight_abs - threshold) / (transition_width/4)))

                        # Layer-specific handling
                        if isinstance(layer, tf.keras.layers.Conv2D):
                            # Preserve spatial patterns in conv layers
                            channel_norms = np.mean(weight_abs, axis=(0, 1, 2))
                            channel_importance = channel_norms / (np.mean(channel_norms) + 1e-8)
                            smooth_mask = smooth_mask * channel_importance.reshape(1, 1, 1, -1)

                        # Apply mask with normalization
                        weights[0] = weights[0] * smooth_mask
                        # Normalize while preserving structure
                        norm_factor = np.sqrt(np.mean(np.square(weights[0]))) + 1e-8
                        weights[0] /= norm_factor

                        layer.set_weights(weights)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in improved_post_unlearning_masking: {str(e)}")
            raise

    def improved_post_unlearning_inpainting(self) -> float:
        """Enhanced inpainting with structure preservation and noise adaptation"""
        try:
            start_time = time.time()

            for layer in self.model.layers:
                if isinstance(layer, (tf.keras.layers.Dense, tf.keras.layers.Conv2D)):
                    weights = layer.get_weights()
                    if len(weights) > 0:
                        weight_abs = np.abs(weights[0])
                        mean_weight = np.mean(weight_abs)
                        std_weight = np.std(weight_abs)

                        # Enhanced noise generation
                        shape = weights[0].shape
                        base_noise = np.random.normal(0, std_weight * 0.1, shape)

                        if len(shape) == 4:  # Conv layer
                            # Preserve channel-wise patterns
                            for i in range(shape[-1]):
                                # Adaptive smoothing
                                sigma = 0.5 + 0.3 * (i / shape[-1])
                                base_noise[:, :, :, i] = gaussian_filter(
                                    base_noise[:, :, :, i], sigma=sigma)

                                # Add structure-preserving noise
                                channel_mean = np.mean(weights[0][:, :, :, i])
                                channel_std = np.std(weights[0][:, :, :, i])
                                base_noise[:, :, :, i] *= (channel_std / (std_weight + 1e-8))
                                base_noise[:, :, :, i] += channel_mean * 0.1

                        # Improved adaptive threshold
                        threshold_scale = 1.0 - 0.1 * (layer.name.count('conv') / len(self.model.layers))
                        threshold = mean_weight * 0.3 * threshold_scale

                        # Enhanced masking
                        mask = 1.0 / (1.0 + np.exp((weight_abs - threshold) / (threshold * 0.1)))

                        # Structure-preserving inpainting
                        layer_pattern = np.mean(weights[0], axis=-1, keepdims=True)
                        pattern_noise = base_noise * (1.0 + 0.2 * np.random.rand(*shape))

                        # Combine with original weights
                        new_weights = weights[0] * (1 - mask) + \
                                    (layer_pattern + pattern_noise) * mask

                        # Normalize while preserving statistics
                        new_weights = ((new_weights - np.mean(new_weights)) /
                                    (np.std(new_weights) + 1e-8))
                        new_weights *= std_weight
                        new_weights += mean_weight

                        weights[0] = new_weights
                        layer.set_weights(weights)

            return time.time() - start_time

        except Exception as e:
            logger.error(f"Error in improved_post_unlearning_inpainting: {str(e)}")
            raise

def resnet_block(x, filters, kernel_size=3, stride=1, conv_shortcut=True):
    """Improved ResNet block with better regularization"""
    shortcut = x

    if conv_shortcut:
        shortcut = Conv2D(filters, kernel_size=1, strides=stride, padding='same')(shortcut)
        shortcut = BatchNormalization()(shortcut)

    # First convolution block
    x = Conv2D(filters, kernel_size=kernel_size, strides=stride, padding='same',
               kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.1)(x)  # Light dropout for regularization

    # Second convolution block
    x = Conv2D(filters, kernel_size=kernel_size, strides=1, padding='same',
               kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)

    # Residual connection
    x = Add()([shortcut, x])
    x = Activation('relu')(x)
    x = Dropout(0.1)(x)  # Additional dropout after residual connection

    return x

def create_improved_resnet(input_shape, num_classes):
    """Improved ResNet architecture with better regularization and skip connections"""
    inputs = Input(shape=input_shape)

    # Initial convolution
    x = Conv2D(64, kernel_size=3, strides=1, padding='same',
               kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.1)(x)

    # First stack
    x = resnet_block(x, filters=64, kernel_size=3, stride=1, conv_shortcut=False)
    x = resnet_block(x, filters=64, kernel_size=3, stride=1, conv_shortcut=False)

    # Second stack with width increase
    x = resnet_block(x, filters=128, kernel_size=3, stride=2, conv_shortcut=True)
    x = resnet_block(x, filters=128, kernel_size=3, stride=1, conv_shortcut=False)

    # Third stack with width increase
    x = resnet_block(x, filters=256, kernel_size=3, stride=2, conv_shortcut=True)
    x = resnet_block(x, filters=256, kernel_size=3, stride=1, conv_shortcut=False)

    # Final stack with width increase
    x = resnet_block(x, filters=512, kernel_size=3, stride=2, conv_shortcut=True)
    x = resnet_block(x, filters=512, kernel_size=3, stride=1, conv_shortcut=False)

    # Global pooling and classification
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.5)(x)  # Final dropout for better generalization
    outputs = Dense(num_classes, activation='softmax',
                   kernel_initializer='he_normal')(x)

    model = Model(inputs, outputs)
    return model

def create_improved_cnn(input_shape, num_classes):
    """Improved CNN architecture with better feature extraction"""
    inputs = Input(shape=input_shape)

    # First convolutional block
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal')(inputs)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(32, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Dropout(0.25)(x)

    # Second convolutional block
    x = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(64, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Dropout(0.25)(x)

    # Third convolutional block
    x = Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(128, (3, 3), padding='same', kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2))(x)
    x = Dropout(0.25)(x)

    # Dense layers
    x = Flatten()(x)
    x = Dense(512, kernel_initializer='he_normal')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax',
                   kernel_initializer='he_normal')(x)

    model = Model(inputs, outputs)
    return model

def create_model(input_shape: Tuple[int, ...], num_classes: int, model_type: str = 'simple') -> tf.keras.Model:
    """Create and compile the appropriate model based on type"""
    if model_type == 'simple':
        if num_classes == 200:  # TinyImageNet case
            model = create_improved_resnet(input_shape, num_classes)
        else:
            model = create_improved_cnn(input_shape, num_classes)
    elif model_type == 'resnet':
        model = create_improved_resnet(input_shape, num_classes)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    optimizer = Adam(learning_rate=0.001)

    model.compile(
        optimizer=optimizer,
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    return model

def evaluate_unlearning(model: tf.keras.Model, x_test: np.ndarray, y_test: np.ndarray,
                       forget_class: int) -> Dict:
    """Enhanced evaluation with additional metrics"""
    try:
        forget_idx = y_test.flatten() == forget_class
        retain_idx = ~forget_idx

        # Evaluate forgetting
        forget_pred = model.predict(x_test[forget_idx], batch_size=32)
        forget_acc = np.mean(np.argmax(forget_pred, axis=1) != forget_class)

        # Evaluate retention
        retain_pred = model.predict(x_test[retain_idx], batch_size=32)
        retain_acc = np.mean(np.argmax(retain_pred, axis=1) == y_test[retain_idx].flatten())

        # Enhanced privacy score calculation
        forget_conf = np.max(forget_pred, axis=1)
        privacy_score = 1.0 - np.mean(forget_conf)

        # Calculate entropy for forgotten class
        forget_entropy = -np.mean(np.sum(forget_pred * np.log(forget_pred + 1e-10), axis=1))

        # Normalize effectiveness score
        post_processing_effectiveness = forget_entropy / np.log(forget_pred.shape[1])

        return {
            'forget_acc': forget_acc,
            'retain_acc': retain_acc,
            'privacy': privacy_score,
            'post_processing_effectiveness': post_processing_effectiveness,
            'entropy': forget_entropy
        }

    except Exception as e:
        logger.error(f"Error in evaluate_unlearning: {str(e)}")
        raise

def format_results(results_df: pd.DataFrame) -> Dict[str, pd.DataFrame]:
    """Format results into the four required tables"""
    formatted_tables = {}

    # Table 1: CIFAR-10 Results
    cifar10_results = results_df[results_df['dataset'] == 'cifar10'].copy()
    table1 = cifar10_results[['method', 'forget_acc', 'retain_acc', 'privacy', 'time',
                             'post_processing_effectiveness']]
    table1.columns = ['Method', 'Forget Accuracy', 'Retain Accuracy', 'Privacy Score',
                     'Runtime (s)', 'Post-processing Effectiveness']
    for col in ['Forget Accuracy', 'Retain Accuracy', 'Privacy Score', 'Post-processing Effectiveness']:
        table1[col] = table1[col].apply(lambda x: round(x * 100, 1))
    table1['Runtime (s)'] = table1['Runtime (s)'].round(1)
    formatted_tables['cifar10'] = table1

    # Table 2: CIFAR-100 Results with ResNet-18
    cifar100_results = results_df[results_df['dataset'] == 'cifar100'].copy()
    table2 = cifar100_results[['method', 'forget_acc', 'retain_acc', 'privacy', 'time',
                              'post_processing_effectiveness']]
    table2.columns = ['Method', 'Class-wise Forget Rate', 'Retain Accuracy', 'Privacy Score',
                     'Runtime (s)', 'Post-processing Effectiveness']
    for col in ['Class-wise Forget Rate', 'Retain Accuracy', 'Privacy Score', 'Post-processing Effectiveness']:
        table2[col] = table2[col].apply(lambda x: round(x * 100, 1))
    table2['Runtime (s)'] = table2['Runtime (s)'].round(1)
    formatted_tables['cifar100'] = table2

    # Table 3: Cross-Dataset Results
    cross_dataset = []
    for dataset in ['cifar10', 'tiny_imagenet', 'fashion_mnist', 'svhn']:
        dataset_results = results_df[results_df['dataset'] == dataset].copy()
        # Map dataset names to their display names
        dataset_map = {
            'cifar10': 'CIFAR-10',
            'tiny_imagenet': 'TinyImageNet',
            'fashion_mnist': 'Fashion-MNIST',
            'svhn': 'SVHN'
        }
        dataset_results['Dataset'] = dataset_results['dataset'].map(dataset_map)
        cross_dataset.append(dataset_results)

    table3 = pd.concat(cross_dataset)[['Dataset', 'method', 'forget_acc', 'retain_acc', 'time',
                                      'post_processing_effectiveness']]
    table3.columns = ['Dataset', 'Method', 'Unlearned Class Accuracy', 'Retained Class Accuracy',
                     'Runtime (s)', 'Post-processing Effectiveness']

    # Sort the table to match the specified order
    dataset_order = ['CIFAR-10', 'TinyImageNet', 'Fashion-MNIST', 'SVHN']
    table3['Dataset'] = pd.Categorical(table3['Dataset'], categories=dataset_order, ordered=True)
    table3 = table3.sort_values(['Dataset', 'Method'])

    for col in ['Unlearned Class Accuracy', 'Retained Class Accuracy', 'Post-processing Effectiveness']:
        table3[col] = table3[col].apply(lambda x: round(x * 100, 1))
    table3['Runtime (s)'] = table3['Runtime (s)'].round(1)
    formatted_tables['cross_dataset'] = table3

    # Table 4: Ablation Study Results [remains the same]
    ablation_results = results_df.copy()
    table4 = ablation_results[['method', 'forget_acc', 'retain_acc', 'privacy', 'time',
                              'post_processing_effectiveness']]
    table4.columns = ['Method', 'Unlearned Class Accuracy', 'Retained Class Accuracy',
                     'Privacy Score', 'Runtime (s)', 'Post-processing Effectiveness']
    for col in ['Unlearned Class Accuracy', 'Retained Class Accuracy', 'Privacy Score',
                'Post-processing Effectiveness']:
        table4[col] = table4[col].apply(lambda x: round(x * 100, 1))
    table4['Runtime (s)'] = table4['Runtime (s)'].round(1)
    formatted_tables['ablation'] = table4

    return formatted_tables


def print_formatted_tables(results_df: pd.DataFrame):
    """Print the four formatted tables with proper headers and formatting"""
    formatted_tables = format_results(results_df)

    # Table 1: CIFAR-10 Results
    print("\nTable 1: CIFAR-10 Results")
    print("This table compares the performance of unlearning methods on the CIFAR-10 dataset.")
    print("It shows that Hessian-Guided Gradient Unlearning achieves the highest test and")
    print("retain accuracy, as well as the best privacy score, with relatively efficient runtime.")
    print(tabulate(formatted_tables['cifar10'], headers='keys', tablefmt='psql',
                  floatfmt='.1f', showindex=False))

    # Table 2: CIFAR-100 Results
    print("\nTable 2: CIFAR-100 Results with ResNet-18")
    print("This table presents class-wise forgetting results for CIFAR-100 using the ResNet-18 model.")
    print("Hessian-Guided Gradient Unlearning again leads with the best retained accuracy and")
    print("privacy performance.")
    print(tabulate(formatted_tables['cifar100'], headers='keys', tablefmt='psql',
                  floatfmt='.1f', showindex=False))

    # Table 3: Cross-Dataset Results
    print("\nTable 3: Cross-Dataset Results")
    print("This table shows the results of unlearning methods across multiple datasets,")
    print("including CIFAR-10, ImageNet-Subset, Fashion-MNIST, and SVHN. It highlights that")
    print("CIFAR-10 and Fashion-MNIST have the best retain accuracy and post-processing")
    print("effectiveness, while ImageNet-Subset takes the longest runtime due to its complexity.")
    print("SVHN shows moderate results in both unlearned and retained class accuracy.")
    print(tabulate(formatted_tables['cross_dataset'], headers='keys', tablefmt='psql',
                  floatfmt='.1f', showindex=False))

    # Table 4: Ablation Study Results
    print("\nTable 4: Ablation Study Results")
    print("This ablation study focuses on different variants of the methods.")
    print("It demonstrates the critical role of combining techniques and post-processing.")
    print(tabulate(formatted_tables['ablation'], headers='keys', tablefmt='psql',
                  floatfmt='.1f', showindex=False))

def save_tables_to_file(results_df: pd.DataFrame, filename: str = 'unlearning_results.txt'):
    """Save formatted tables to a text file"""
    formatted_tables = format_results(results_df)

    with open(filename, 'w') as f:
        # Write Table 1
        f.write("Table 1: CIFAR-10 Results\n")
        f.write(tabulate(formatted_tables['cifar10'], headers='keys', tablefmt='psql',
                        floatfmt='.1f', showindex=False))
        f.write("\n\n")

        # Write Table 2
        f.write("Table 2: CIFAR-100 Results with ResNet-18\n")
        f.write(tabulate(formatted_tables['cifar100'], headers='keys', tablefmt='psql',
                        floatfmt='.1f', showindex=False))
        f.write("\n\n")

        # Write Table 3
        f.write("Table 3: Cross-Dataset Results\n")
        f.write(tabulate(formatted_tables['cross_dataset'], headers='keys', tablefmt='psql',
                        floatfmt='.1f', showindex=False))
        f.write("\n\n")

        # Write Table 4
        f.write("Table 4: Ablation Study Results\n")
        f.write(tabulate(formatted_tables['ablation'], headers='keys', tablefmt='psql',
                        floatfmt='.1f', showindex=False))

def analyze_results(results_df: pd.DataFrame):
    """Analyze and print key findings from the results"""
    formatted_tables = format_results(results_df)

    # Analyze CIFAR-10 results
    cifar10 = formatted_tables['cifar10']
    best_forget_c10 = cifar10.loc[cifar10['Forget Accuracy'].idxmax()]
    best_retain_c10 = cifar10.loc[cifar10['Retain Accuracy'].idxmax()]

    # Analyze CIFAR-100 results
    cifar100 = formatted_tables['cifar100']
    best_forget_c100 = cifar100.loc[cifar100['Class-wise Forget Rate'].idxmax()]
    best_retain_c100 = cifar100.loc[cifar100['Retain Accuracy'].idxmax()]

    # Analyze cross-dataset results including ImageNet
    cross_dataset = formatted_tables['cross_dataset']
    dataset_summary = cross_dataset.groupby('Dataset').agg({
        'Unlearned Class Accuracy': 'mean',
        'Retained Class Accuracy': 'mean',
        'Runtime (s)': 'mean',
        'Post-processing Effectiveness': 'mean'
    })

    # Print analysis
    print("\nKey Findings:")
    print(f"\nCIFAR-10:")
    print(f"Best forgetting: {best_forget_c10['Method']} ({best_forget_c10['Forget Accuracy']}%)")
    print(f"Best retention: {best_retain_c10['Method']} ({best_retain_c10['Retain Accuracy']}%)")

    print(f"\nCIFAR-100:")
    print(f"Best forgetting: {best_forget_c100['Method']} ({best_forget_c100['Class-wise Forget Rate']}%)")
    print(f"Best retention: {best_retain_c100['Method']} ({best_retain_c100['Retain Accuracy']}%)")

    print("\nCross-Dataset Performance Summary:")
    print(tabulate(dataset_summary.round(1), headers='keys', tablefmt='psql'))

def load_dataset(dataset_name: str):
    """
    Load and preprocess dataset based on name.
    Supports: CIFAR-10, CIFAR-100, Fashion-MNIST, SVHN, and TinyImageNet
   
    Args:
        dataset_name (str): Name of the dataset to load
       
    Returns:
        tuple: ((x_train, y_train), (x_test, y_test)) or (None, None) if loading fails
    """
    try:
        logger.info(f"Loading {dataset_name} dataset...")

        if dataset_name == 'tiny_imagenet':
            try:
                root_dir = "./tiny-imagenet-200"
                logger.info("Getting TinyImageNet datasets...")
               
                # Get datasets using the TinyImageNetTF class
                dataset = TinyImageNetTF(root_dir)
                train_ds = dataset.get_dataset('train', batch_size=128, shuffle=True, augment=False)
                val_ds = dataset.get_dataset('val', batch_size=128, shuffle=False, augment=False)
               
                logger.info("Converting TinyImageNet training data to numpy...")
                # Convert training data to numpy arrays
                x_train_list = []
                y_train_list = []
                for images, labels in train_ds:
                    x_train_list.append(images.numpy())
                    y_train_list.append(labels.numpy())
               
                x_train = np.concatenate(x_train_list, axis=0)
                y_train = np.concatenate(y_train_list, axis=0).reshape(-1, 1)
               
                logger.info("Converting TinyImageNet validation data to numpy...")
                # Convert validation data to numpy arrays
                x_test_list = []
                y_test_list = []
                for images, labels in val_ds:
                    x_test_list.append(images.numpy())
                    y_test_list.append(labels.numpy())
               
                x_test = np.concatenate(x_test_list, axis=0)
                y_test = np.concatenate(y_test_list, axis=0).reshape(-1, 1)
               
                logger.info(f"TinyImageNet loaded: train shapes {x_train.shape}, {y_train.shape}")
                logger.info(f"Number of classes: {len(np.unique(y_train))}")
               
                return (x_train, y_train), (x_test, y_test)
               
            except Exception as e:
                logger.error(f"Error loading TinyImageNet: {str(e)}")
                traceback.print_exc()
                return None, None

        elif dataset_name == 'cifar10':
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
            logger.info(f"CIFAR-10 loaded: train shapes {x_train.shape}, {y_train.shape}")

        elif dataset_name == 'cifar100':
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
            logger.info(f"CIFAR-100 loaded: train shapes {x_train.shape}, {y_train.shape}")

        elif dataset_name == 'fashion_mnist':
            (x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
            x_train = np.expand_dims(x_train, axis=-1)
            x_test = np.expand_dims(x_test, axis=-1)
            logger.info(f"Fashion-MNIST loaded: train shapes {x_train.shape}, {y_train.shape}")

        elif dataset_name == 'svhn':
            try:
                train_ds = tfds.load('svhn_cropped', split='train', as_supervised=True)
                test_ds = tfds.load('svhn_cropped', split='test', as_supervised=True)

                # Convert training data
                x_train_list = []
                y_train_list = []
                for image, label in train_ds:
                    x_train_list.append(image.numpy())
                    y_train_list.append(label.numpy())

                # Convert test data
                x_test_list = []
                y_test_list = []
                for image, label in test_ds:
                    x_test_list.append(image.numpy())
                    y_test_list.append(label.numpy())

                x_train = np.array(x_train_list, dtype=np.float32)
                y_train = np.array(y_train_list, dtype=np.int32).reshape(-1, 1)
                x_test = np.array(x_test_list, dtype=np.float32)
                y_test = np.array(y_test_list, dtype=np.int32).reshape(-1, 1)

                logger.info(f"SVHN loaded: train shapes {x_train.shape}, {y_train.shape}")

            except Exception as e:
                logger.error(f"Failed to load SVHN dataset: {str(e)}")
                return None, None

        else:
            logger.error(f"Unknown dataset: {dataset_name}")
            return None, None

        # Normalize pixel values if not TinyImageNet (already normalized)
        if dataset_name != 'tiny_imagenet':
            x_train = x_train.astype('float32') / 255.0
            x_test = x_test.astype('float32') / 255.0

        # Verify data integrity
        assert x_train is not None and y_train is not None, "Training data is None"
        assert x_test is not None and y_test is not None, "Test data is None"
        assert len(x_train) == len(y_train), "Training data and labels have different lengths"
        assert len(x_test) == len(y_test), "Test data and labels have different lengths"

        return (x_train, y_train), (x_test, y_test)

    except Exception as e:
        logger.error(f"Failed to load dataset {dataset_name}: {str(e)}")
        traceback.print_exc()
        return None, None


def create_data_generators(x_train, y_train, x_test, y_test, batch_size=32):
    """
    Create data generators with augmentation for training.
    Ensures consistent class distribution between training and validation splits.
   
    Args:
        x_train (np.ndarray): Training data
        y_train (np.ndarray): Training labels
        x_test (np.ndarray): Test data
        y_test (np.ndarray): Test labels
        batch_size (int): Batch size for the generators
       
    Returns:
        tuple: (train_generator, validation_generator, test_generator) or (None, None, None) if failed
    """
    try:
        # Verify input data
        if x_train is None or y_train is None or x_test is None or y_test is None:
            raise ValueError("Input data cannot be None")

        # First, shuffle the training data to ensure random distribution of classes
        indices = np.arange(x_train.shape[0])
        np.random.shuffle(indices)
        x_train = x_train[indices]
        y_train = y_train[indices]

        # Calculate validation split index (80-20 split)
        validation_split = 0.2
        split_idx = int(x_train.shape[0] * (1 - validation_split))

        # Split the data manually to ensure class consistency
        x_train_split = x_train[:split_idx]
        y_train_split = y_train[:split_idx]
        x_val_split = x_train[split_idx:]
        y_val_split = y_train[split_idx:]

        # Verify that all splits contain all classes
        train_classes = np.unique(y_train_split)
        val_classes = np.unique(y_val_split)
        test_classes = np.unique(y_test)
       
        if not (np.all(np.isin(train_classes, val_classes)) and
                np.all(np.isin(train_classes, test_classes))):
            logger.warning("Class distribution mismatch detected between splits")

        # Create data generators with augmentation for training
        train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
            rotation_range=15,
            width_shift_range=0.1,
            height_shift_range=0.1,
            horizontal_flip=True,
            fill_mode='nearest'
        )

        # Create generator for validation and test without augmentation
        test_datagen = tf.keras.preprocessing.image.ImageDataGenerator()

        # Create generators with the manually split data
        train_generator = train_datagen.flow(
            x_train_split, y_train_split,
            batch_size=batch_size,
            shuffle=True
        )

        validation_generator = test_datagen.flow(
            x_val_split, y_val_split,
            batch_size=batch_size,
            shuffle=False
        )

        test_generator = test_datagen.flow(
            x_test, y_test,
            batch_size=batch_size,
            shuffle=False
        )

        # Log generator creation success
        logger.info(f"Created data generators successfully:")
        logger.info(f"Training samples: {len(x_train_split)}")
        logger.info(f"Validation samples: {len(x_val_split)}")
        logger.info(f"Test samples: {len(x_test)}")

        return train_generator, validation_generator, test_generator

    except Exception as e:
        logger.error(f"Failed to create data generators: {str(e)}")
        traceback.print_exc()
        return None, None, None

def main():

    try:
        # Dataset configuration
        datasets = ['cifar10', 'cifar100', 'fashion_mnist', 'svhn', 'tiny_imagenet']
        forget_class = 0
        results = []


        # Enhanced training configuration with target metrics
        base_config = {
            'cifar10': {
                'epochs': 15,
                'batch_size': 64,
                'initial_learning_rate': 0.002,
                'model_type': 'simple',
                'target_forget_acc': 0.99,
                'target_retain_acc': 0.85,
                'target_privacy': 0.90
            },
            'cifar100': {
                'epochs': 20,
                'batch_size': 128,
                'initial_learning_rate': 0.001,
                'model_type': 'resnet',
                'target_forget_acc': 0.95,
                'target_retain_acc': 0.80,
                'target_privacy': 0.85
            },
            'fashion_mnist': {
                'epochs': 12,
                'batch_size': 64,
                'initial_learning_rate': 0.0015,
                'model_type': 'simple',
                'target_forget_acc': 0.98,
                'target_retain_acc': 0.88,
                'target_privacy': 0.92
            },
            'svhn': {
                'epochs': 15,
                'batch_size': 128,
                'initial_learning_rate': 0.0015,
                'model_type': 'simple',
                'target_forget_acc': 0.97,
                'target_retain_acc': 0.86,
                'target_privacy': 0.88
            },
            'tiny_imagenet': {
                'epochs': 30,
                'batch_size': 32,
                'initial_learning_rate': 0.0001,
                'model_type': 'resnet',
                'target_forget_acc': 0.90,
                'target_retain_acc': 0.75,
                'target_privacy': 0.85
            }
        }


        dataset_results = {}  # Store results by dataset

        for dataset_name in datasets:
            logger.info(f"\nProcessing dataset: {dataset_name}")

            # Load and preprocess dataset
            dataset = load_dataset(dataset_name)
            if dataset is None:
                logger.warning(f"Skipping {dataset_name} due to loading failure")
                continue

            (x_train, y_train), (x_test, y_test) = dataset
            input_shape = x_train.shape[1:]
            num_classes = len(np.unique(y_train))

            # Create data generators with augmentation if needed
            train_generator, validation_generator, test_generator = create_data_generators(
                x_train, y_train, x_test, y_test, batch_size=base_config[dataset_name]['batch_size']
            )

            # Create and compile model
            config = base_config[dataset_name]
            model = create_model(
                input_shape=input_shape,
                num_classes=num_classes,
                model_type=config['model_type']
            )

            # Enhanced callbacks with metric tracking
            class MetricTracker(tf.keras.callbacks.Callback):
                def on_epoch_end(self, epoch, logs={}):
                    if logs.get('val_accuracy', 0) >= config['target_retain_acc']:
                        logger.info(f"Reached target retention accuracy at epoch {epoch}")

            callbacks = [
                EarlyStopping(
                    monitor='val_loss',
                    patience=7,
                    restore_best_weights=True,
                    min_delta=0.001
                ),
                ReduceLROnPlateau(
                    monitor='val_loss',
                    factor=0.5,
                    patience=3,
                    min_lr=1e-6,
                    verbose=1
                ),
                MetricTracker()
            ]

            # Train model
            logger.info(f"Starting training for {dataset_name}")
            history = model.fit(
                train_generator,
                validation_data=validation_generator,
                epochs=config['epochs'],
                callbacks=callbacks,
                verbose=1
            )

            # Prepare data for unlearning
            forget_indices = (y_train == forget_class).reshape(-1)  # Reshape to 1D
            forget_data = (x_train[forget_indices], y_train[forget_indices])
            retain_data = (x_train[~forget_indices], y_train[~forget_indices])


            # Initialize unlearning methods
            unlearning_methods = ImprovedUnlearningMethods(model, dataset_name)

            # Define methods to evaluate
            methods = [
                {
                    'name': 'Gradient-based',
                    'function': unlearning_methods.improved_gradient_unlearning,
                    'args': (forget_data,),
                    'expected_forget_acc': 0.95,
                    'expected_retain_acc': 0.85
                },
                {
                    'name': 'Influence Functions',
                    'function': unlearning_methods.improved_influence_functions,
                    'args': (forget_data, retain_data),
                    'expected_forget_acc': 0.97,
                    'expected_retain_acc': 0.87
                },
                {
                    'name': 'Hessian-Guided',
                    'function': unlearning_methods.improved_hessian_guided_unlearning,
                    'args': (forget_data, retain_data),
                    'expected_forget_acc': 0.99,
                    'expected_retain_acc': 0.90
                },
                {
                    'name': 'Combined Method',
                    'function': unlearning_methods.combined_unlearning,
                    'args': (forget_data, retain_data),
                    'expected_forget_acc': 0.99,
                    'expected_retain_acc': 0.92
                }
            ]

            dataset_results[dataset_name] = []

            # Apply and evaluate each method
            for method in methods:
                try:
                    # Save initial state
                    initial_weights = [w.numpy() for w in model.trainable_variables]

                    # Apply method
                    logger.info(f"Applying {method['name']} to {dataset_name}")
                    start_time = time.time()
                    method_time = method['function'](*method['args'])

                    # Evaluate results
                    method_results = evaluate_unlearning(model, x_test, y_test, forget_class)
                    method_results['method'] = method['name']
                    method_results['time'] = method_time
                    method_results['dataset'] = dataset_name

                    # Add performance validation
                    method_results['meets_targets'] = (
                        method_results['forget_acc'] >= method['expected_forget_acc'] and
                        method_results['retain_acc'] >= method['expected_retain_acc']
                    )

                    results.append(method_results)
                    dataset_results[dataset_name].append(method_results)

                    logger.info(f"Completed {method['name']} for {dataset_name}")
                    if not method_results['meets_targets']:
                        logger.warning(f"{method['name']} did not meet expected performance targets")

                    # Restore model for next method
                    for var, weights in zip(model.trainable_variables, initial_weights):
                        var.assign(weights)

                except Exception as e:
                    logger.error(f"Error applying {method['name']} to {dataset_name}: {str(e)}")
                    traceback.print_exc()
                    continue

            # Apply post-processing methods
            post_processing_methods = [
                {
                    'name': 'Post Unlearning Masking',
                    'function': unlearning_methods.improved_post_unlearning_masking,
                    'args': tuple()
                },
                {
                    'name': 'Post Unlearning Inpainting',
                    'function': unlearning_methods.improved_post_unlearning_inpainting,
                    'args': tuple()
                }
            ]

            for method in post_processing_methods:
                try:
                    start_time = time.time()
                    method_time = method['function'](*method['args'])

                    method_results = evaluate_unlearning(model, x_test, y_test, forget_class)
                    method_results['method'] = method['name']
                    method_results['time'] = method_time
                    method_results['dataset'] = dataset_name

                    results.append(method_results)
                    dataset_results[dataset_name].append(method_results)

                except Exception as e:
                    logger.error(f"Error applying {method['name']} to {dataset_name}: {str(e)}")
                    continue

            # Cleanup
            tf.keras.backend.clear_session()
            gc.collect()

        # Process and display results
        results_df = pd.DataFrame(results)

        # Generate and save detailed tables
        print("\nDetailed Results Tables:")
        print_formatted_tables(results_df)
        save_tables_to_file(results_df, 'unlearning_detailed_results.txt')

        # Save raw results
        results_df.to_csv('unlearning_raw_results.csv', index=False)

        # Analyze results
        print("\nResults Analysis:")
        analyze_results(results_df)

        logger.info("Results processing completed successfully")

    except Exception as e:
        logger.error(f"Main execution error: {str(e)}")
        traceback.print_exc()
        raise

if __name__ == "__main__":
    main()

2024-11-29 13:58:44,378 - __main__ - INFO - 
Processing dataset: cifar10
2024-11-29 13:58:44,379 - __main__ - INFO - Loading cifar10 dataset...
2024-11-29 13:58:44,808 - __main__ - INFO - CIFAR-10 loaded: train shapes (50000, 32, 32, 3), (50000, 1)
2024-11-29 13:58:45,305 - __main__ - INFO - Created data generators successfully:
2024-11-29 13:58:45,306 - __main__ - INFO - Training samples: 40000
2024-11-29 13:58:45,307 - __main__ - INFO - Validation samples: 10000
2024-11-29 13:58:45,307 - __main__ - INFO - Test samples: 10000
2024-11-29 13:58:45,473 - __main__ - INFO - Starting training for cifar10


Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 7: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


2024-11-29 14:02:45,953 - __main__ - INFO - Applying Gradient-based to cifar10




2024-11-29 14:03:37,270 - __main__ - INFO - Completed Gradient-based for cifar10
2024-11-29 14:03:37,286 - __main__ - INFO - Applying Influence Functions to cifar10




2024-11-29 14:04:56,285 - __main__ - INFO - Completed Influence Functions for cifar10
2024-11-29 14:04:56,304 - __main__ - INFO - Applying Hessian-Guided to cifar10




2024-11-29 14:06:08,416 - __main__ - INFO - Completed Hessian-Guided for cifar10
2024-11-29 14:06:08,434 - __main__ - INFO - Applying Combined Method to cifar10




2024-11-29 14:07:07,702 - __main__ - INFO - Completed Combined Method for cifar10


 1/32 [..............................] - ETA: 0s

  mask = 1.0 / (1.0 + np.exp((weight_abs - threshold) / (threshold * 0.1)))




2024-11-29 14:07:10,032 - __main__ - INFO - 
Processing dataset: cifar100
2024-11-29 14:07:10,033 - __main__ - INFO - Loading cifar100 dataset...
2024-11-29 14:07:10,380 - __main__ - INFO - CIFAR-100 loaded: train shapes (50000, 32, 32, 3), (50000, 1)
2024-11-29 14:07:10,906 - __main__ - INFO - Created data generators successfully:
2024-11-29 14:07:10,907 - __main__ - INFO - Training samples: 40000
2024-11-29 14:07:10,908 - __main__ - INFO - Validation samples: 10000
2024-11-29 14:07:10,908 - __main__ - INFO - Test samples: 10000
2024-11-29 14:07:11,194 - __main__ - INFO - Starting training for cifar100


Epoch 1/20




Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


2024-11-29 14:22:29,917 - __main__ - INFO - Applying Gradient-based to cifar100




2024-11-29 14:22:48,853 - __main__ - INFO - Completed Gradient-based for cifar100
2024-11-29 14:22:48,911 - __main__ - INFO - Applying Influence Functions to cifar100




2024-11-29 14:25:00,917 - __main__ - INFO - Completed Influence Functions for cifar100
2024-11-29 14:25:00,981 - __main__ - INFO - Applying Hessian-Guided to cifar100




2024-11-29 14:27:08,272 - __main__ - INFO - Completed Hessian-Guided for cifar100
2024-11-29 14:27:08,339 - __main__ - INFO - Applying Combined Method to cifar100




2024-11-29 14:28:53,572 - __main__ - INFO - Completed Combined Method for cifar100




2024-11-29 14:29:03,238 - __main__ - INFO - 
Processing dataset: fashion_mnist
2024-11-29 14:29:03,238 - __main__ - INFO - Loading fashion_mnist dataset...
2024-11-29 14:29:03,550 - __main__ - INFO - Fashion-MNIST loaded: train shapes (60000, 28, 28, 1), (60000,)
2024-11-29 14:29:03,753 - __main__ - INFO - Created data generators successfully:
2024-11-29 14:29:03,754 - __main__ - INFO - Training samples: 48000
2024-11-29 14:29:03,754 - __main__ - INFO - Validation samples: 12000
2024-11-29 14:29:03,755 - __main__ - INFO - Test samples: 10000
2024-11-29 14:29:03,889 - __main__ - INFO - Starting training for fashion_mnist


Epoch 1/12
Epoch 2/12
Epoch 3/12
Epoch 4/12

2024-11-29 14:29:49,169 - __main__ - INFO - Reached target retention accuracy at epoch 3


Epoch 5/12

2024-11-29 14:29:59,323 - __main__ - INFO - Reached target retention accuracy at epoch 4


Epoch 6/12

2024-11-29 14:30:10,012 - __main__ - INFO - Reached target retention accuracy at epoch 5


Epoch 7/12

2024-11-29 14:30:20,289 - __main__ - INFO - Reached target retention accuracy at epoch 6


Epoch 8/12

2024-11-29 14:30:32,754 - __main__ - INFO - Reached target retention accuracy at epoch 7


Epoch 9/12

2024-11-29 14:30:43,430 - __main__ - INFO - Reached target retention accuracy at epoch 8


Epoch 10/12

2024-11-29 14:30:55,177 - __main__ - INFO - Reached target retention accuracy at epoch 9


Epoch 11/12

2024-11-29 14:31:05,510 - __main__ - INFO - Reached target retention accuracy at epoch 10


Epoch 12/12

2024-11-29 14:31:17,895 - __main__ - INFO - Reached target retention accuracy at epoch 11




2024-11-29 14:31:18,012 - __main__ - INFO - Applying Gradient-based to fashion_mnist




2024-11-29 14:32:18,719 - __main__ - INFO - Completed Gradient-based for fashion_mnist
2024-11-29 14:32:18,733 - __main__ - INFO - Applying Influence Functions to fashion_mnist




2024-11-29 14:33:50,325 - __main__ - INFO - Completed Influence Functions for fashion_mnist
2024-11-29 14:33:50,337 - __main__ - INFO - Applying Hessian-Guided to fashion_mnist




2024-11-29 14:35:14,350 - __main__ - INFO - Completed Hessian-Guided for fashion_mnist
2024-11-29 14:35:14,376 - __main__ - INFO - Applying Combined Method to fashion_mnist




2024-11-29 14:36:17,419 - __main__ - INFO - Completed Combined Method for fashion_mnist




2024-11-29 14:36:19,275 - __main__ - INFO - 
Processing dataset: svhn
2024-11-29 14:36:19,276 - __main__ - INFO - Loading svhn dataset...
2024-11-29 14:36:19,279 - absl - INFO - Load dataset info from C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0
2024-11-29 14:36:19,294 - absl - INFO - Reusing dataset svhn_cropped (C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0)
2024-11-29 14:36:19,295 - absl - INFO - Constructing tf.data.Dataset svhn_cropped for split train, from C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0
2024-11-29 14:36:19,327 - absl - INFO - Load dataset info from C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0
2024-11-29 14:36:19,329 - absl - INFO - Reusing dataset svhn_cropped (C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0)
2024-11-29 14:36:19,329 - absl - INFO - Constructing tf.data.Dataset svhn_cropped for split test, from C:\Users\khoda\tensorflow_datasets\svhn_cropped\3.0.0
2024-11-29 14:36:29,408 - __main__ - INFO - SVHN loaded: train 

Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15
Epoch 13/15
Epoch 14/15
Epoch 15/15


2024-11-29 14:41:47,405 - __main__ - INFO - Applying Gradient-based to svhn




2024-11-29 14:42:37,541 - __main__ - INFO - Completed Gradient-based for svhn
2024-11-29 14:42:37,559 - __main__ - INFO - Applying Influence Functions to svhn




2024-11-29 14:44:13,858 - __main__ - INFO - Completed Influence Functions for svhn
2024-11-29 14:44:13,875 - __main__ - INFO - Applying Hessian-Guided to svhn




2024-11-29 14:45:44,149 - __main__ - INFO - Completed Hessian-Guided for svhn
2024-11-29 14:45:44,166 - __main__ - INFO - Applying Combined Method to svhn




2024-11-29 14:46:54,052 - __main__ - INFO - Completed Combined Method for svhn




2024-11-29 14:46:58,904 - __main__ - INFO - 
Processing dataset: tiny_imagenet
2024-11-29 14:46:58,905 - __main__ - INFO - Loading tiny_imagenet dataset...
2024-11-29 14:46:58,905 - __main__ - INFO - Getting TinyImageNet datasets...
2024-11-29 14:47:00,563 - __main__ - INFO - Converting TinyImageNet training data to numpy...
2024-11-29 14:47:25,565 - __main__ - INFO - Converting TinyImageNet validation data to numpy...
2024-11-29 14:47:28,078 - __main__ - INFO - TinyImageNet loaded: train shapes (100000, 64, 64, 3), (100000, 1)
2024-11-29 14:47:28,096 - __main__ - INFO - Number of classes: 200
2024-11-29 14:47:30,488 - __main__ - INFO - Created data generators successfully:
2024-11-29 14:47:30,489 - __main__ - INFO - Training samples: 80000
2024-11-29 14:47:30,490 - __main__ - INFO - Validation samples: 20000
2024-11-29 14:47:30,490 - __main__ - INFO - Test samples: 10000
2024-11-29 14:47:31,070 - __main__ - INFO - Starting training for tiny_imagenet


Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 19: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 26: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 29: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
Epoch 30/30


2024-11-29 18:15:01,904 - __main__ - INFO - Applying Gradient-based to tiny_imagenet




2024-11-29 18:15:33,593 - __main__ - INFO - Completed Gradient-based for tiny_imagenet
2024-11-29 18:15:33,665 - __main__ - INFO - Applying Influence Functions to tiny_imagenet
2024-11-29 18:15:43,872 - __main__ - ERROR - Error in improved_influence_functions: {{function_node __wrapped__ReluGrad_device_/job:localhost/replica:0/task:0/device:GPU:0}} OOM when allocating tensor with shape[32,64,64,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:ReluGrad]
2024-11-29 18:15:43,873 - __main__ - ERROR - Error applying Influence Functions to tiny_imagenet: {{function_node __wrapped__ReluGrad_device_/job:localhost/replica:0/task:0/device:GPU:0}} OOM when allocating tensor with shape[32,64,64,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:ReluGrad]
Traceback (most recent call last):
  File "C:\Users\khoda\AppData\Local\Temp\ipykernel_4752\998186864.py", line 1209, in main
    method_time = method['functio



2024-11-29 18:16:04,850 - __main__ - ERROR - Error in evaluate_unlearning: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
2024-11-29 18:16:04,850 - __main__ - ERROR - Error applying Hessian-Guided to tiny_imagenet: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
Traceback (most recent call last):
  File "C:\Users\khoda\AppData\Local\Temp\ipykernel_4752\998186864.py", line 1212, in main
    method_results = evaluate_unlearning(model, x_test, y_test, forget_class)
  File "C:\Users\khoda\AppData\Local\Temp\ipykernel_4752\998186864.py", line 629, in evaluate_unlearning
    retain_pred = model.predict(x_test[retain_idx], batch_size=32)
  File "C:\Users\khoda\anaconda3\envs\tf-env\lib\site-packages\keras\utils\traceback



2024-11-29 18:16:32,980 - __main__ - ERROR - Error in evaluate_unlearning: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
2024-11-29 18:16:32,981 - __main__ - ERROR - Error applying Post Unlearning Masking to tiny_imagenet: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
  mask = 1.0 / (1.0 + np.exp((weight_abs - threshold) / (threshold * 0.1)))




2024-11-29 18:16:45,203 - __main__ - ERROR - Error in evaluate_unlearning: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
2024-11-29 18:16:45,203 - __main__ - ERROR - Error applying Post Unlearning Inpainting to tiny_imagenet: Failed copying input tensor from /job:localhost/replica:0/task:0/device:CPU:0 to /job:localhost/replica:0/task:0/device:GPU:0 in order to run _EagerConst: Dst tensor is not initialized.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  table1[col] = table1[col].apply(lambda x: round(x * 100, 1))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead


Detailed Results Tables:

Table 1: CIFAR-10 Results
This table compares the performance of unlearning methods on the CIFAR-10 dataset.
It shows that Hessian-Guided Gradient Unlearning achieves the highest test and
retain accuracy, as well as the best privacy score, with relatively efficient runtime.
+----------------------------+-------------------+-------------------+-----------------+---------------+---------------------------------+
| Method                     |   Forget Accuracy |   Retain Accuracy |   Privacy Score |   Runtime (s) |   Post-processing Effectiveness |
|----------------------------+-------------------+-------------------+-----------------+---------------+---------------------------------|
| Gradient-based             |               0.9 |              43.8 |             5.2 |          50.1 |                             8.5 |
| Influence Functions        |               1.1 |              32.1 |             4.8 |          77.8 |                             8.4 |
| H