<a href="https://colab.research.google.com/github/fjadidi2001/Image_Inpaint/blob/main/Image_inpaint_New_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import kagglehub

# Kaggle handle of the CelebA-HQ face images, pre-resized to 256x256.
dataset_handle = "badasstechie/celebahq-resized-256x256"

# Fetch the latest version of the dataset; kagglehub returns the local
# directory the files were extracted into (printed below).
path = kagglehub.dataset_download(dataset_handle)
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/badasstechie/celebahq-resized-256x256?dataset_version_number=1...


100%|██████████| 283M/283M [00:14<00:00, 21.1MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/badasstechie/celebahq-resized-256x256/versions/1


In [9]:
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, Model
import cv2
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# Limit GPU memory usage
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
    except RuntimeError as e:
        print(e)

class ImageInpainting:
    """Train and evaluate a U-Net + lightweight-attention ("HINT") inpainting model.

    Images are loaded from ``data_path``, scaled to [-1, 1] and paired with
    random square binary masks (1 = known pixel, 0 = hole). The network
    receives the *masked* image plus the mask and is trained to reproduce the
    original image.

    Args:
        data_path: Directory containing the images (searched recursively).
        img_size: Side length images are resized to; must be divisible by 8
            (three 2x pooling stages in the U-Net).
        batch_size: Mini-batch size used by ``train``.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, data_path, img_size=256, batch_size=8):
        self.data_path = data_path
        self.img_size = img_size
        self.batch_size = batch_size

    # ------------------------------------------------------------------ data

    def _list_image_files(self, limit=None):
        """Return image file paths under ``data_path`` in a deterministic order.

        The directory tree is walked recursively (downloaded datasets often
        unpack into a sub-folder) and sorted so repeated runs see the same files.

        Args:
            limit: Stop after this many matches; ``None`` means no limit.
        """
        found = []
        for root, dirs, files in os.walk(self.data_path):
            dirs.sort()  # in-place sort makes os.walk descend in sorted order
            for name in sorted(files):
                if name.lower().endswith(self.IMAGE_EXTENSIONS):
                    found.append(os.path.join(root, name))
                    if limit is not None and len(found) >= limit:
                        return found
        return found

    def load_dataset(self, max_images=1000):
        """Load up to ``max_images`` RGB images resized to ``img_size``.

        Args:
            max_images: Cap on the number of images read (keeps memory bounded).

        Returns:
            Array of shape (N, img_size, img_size, 3) in RGB channel order.

        Raises:
            FileNotFoundError: if no readable image exists under ``data_path``.
        """
        images = []
        for img_path in tqdm(self._list_image_files(limit=max_images)):
            img = cv2.imread(img_path)
            if img is None:  # cv2 returns None (instead of raising) for unreadable files
                print(f"Skipping unreadable image: {img_path}")
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (self.img_size, self.img_size))
            images.append(img)
        if not images:
            raise FileNotFoundError(
                f"No readable images found under {self.data_path!r}")
        return np.array(images)

    def perform_eda(self, images, n_samples=5):
        """Show up to ``n_samples`` images and print basic array statistics."""
        n_show = min(n_samples, len(images))
        if n_show:
            _, axes = plt.subplots(1, n_show, figsize=(15, 5), squeeze=False)
            for ax, img in zip(axes[0], images[:n_show]):
                ax.imshow(img)
                ax.axis('off')
            plt.show()

        print(f"Dataset shape: {images.shape}")
        print(f"Data type: {images.dtype}")
        if len(images):
            print(f"Min value: {images.min()}, Max value: {images.max()}")

    def preprocess_images(self, images):
        """Scale 8-bit pixel values to float32 in [-1, 1] (matches the tanh output)."""
        return (images.astype('float32') - 127.5) / 127.5

    def create_masks(self, shape, mask_size=64):
        """Create one random square-hole mask per sample.

        Args:
            shape: Shape of the image batch; only ``shape[0]`` (N) is used.
            mask_size: Side length of the square hole, 0 < mask_size <= img_size.

        Returns:
            float32 array (N, img_size, img_size, 1): 1 = known pixel, 0 = hole.

        Raises:
            ValueError: if ``mask_size`` does not fit inside the image.
        """
        if not 0 < mask_size <= self.img_size:
            raise ValueError(
                f"mask_size must be in (0, {self.img_size}], got {mask_size}")
        n = shape[0]
        masks = np.ones((n, self.img_size, self.img_size, 1), dtype=np.float32)
        # +1 because randint's upper bound is exclusive; this lets the hole
        # also touch the bottom/right border.
        corners = np.random.randint(0, self.img_size - mask_size + 1, size=(n, 2))
        for m, (y1, x1) in zip(masks, corners):  # m is a view into masks
            m[y1:y1 + mask_size, x1:x1 + mask_size] = 0
        return masks

    # ---------------------------------------------------------------- models

    @staticmethod
    def _conv_block(x, filters, kernel_size=3):
        """Conv2D ('same' padding) -> BatchNorm -> ReLU; shared by both branches."""
        x = layers.Conv2D(filters, kernel_size, padding='same')(x)
        x = layers.BatchNormalization()(x)
        return layers.ReLU()(x)

    def build_unet(self):
        """Build a 3-level U-Net taking (image, mask) and returning an RGB image in [-1, 1]."""
        if self.img_size % 8:
            raise ValueError(
                f"img_size must be divisible by 8 for the U-Net, got {self.img_size}")

        inputs = layers.Input((self.img_size, self.img_size, 3))
        mask = layers.Input((self.img_size, self.img_size, 1))

        # Concatenate mask with input so the network knows which pixels are holes
        x = layers.Concatenate()([inputs, mask])

        # Encoder
        e1 = self._conv_block(x, 64)
        e2 = self._conv_block(layers.MaxPooling2D()(e1), 128)
        e3 = self._conv_block(layers.MaxPooling2D()(e2), 256)

        # Bridge
        b = self._conv_block(layers.MaxPooling2D()(e3), 512)

        # Decoder with skip connections
        d3 = self._conv_block(layers.UpSampling2D()(b), 256)
        d3 = layers.Concatenate()([d3, e3])

        d2 = self._conv_block(layers.UpSampling2D()(d3), 128)
        d2 = layers.Concatenate()([d2, e2])

        d1 = self._conv_block(layers.UpSampling2D()(d2), 64)
        d1 = layers.Concatenate()([d1, e1])

        outputs = layers.Conv2D(3, 1, activation='tanh')(d1)

        return Model([inputs, mask], outputs)

    def build_hint(self, attn_downsample=8):
        """Build a simplified HINT-style attention branch for limited resources.

        Self-attention runs on the feature map average-pooled by
        ``attn_downsample``; the result is upsampled back and added to the
        full-resolution features. Attending over every pixel of a 256x256 map
        would need a 65536 x 65536 score matrix per head (tens of GB), which
        does not fit in memory.

        Args:
            attn_downsample: Pooling factor before attention; must divide ``img_size``.
        """
        if self.img_size % attn_downsample:
            raise ValueError(
                f"attn_downsample ({attn_downsample}) must divide img_size ({self.img_size})")

        def transformer_block(x, filters):
            # Self-attention with residual connection + post-norm
            attention = layers.MultiHeadAttention(
                num_heads=4, key_dim=filters // 4)(x, x, x)
            x = layers.Add()([x, attention])
            x = layers.LayerNormalization()(x)

            # Position-wise feed-forward network with residual + post-norm
            ffn = layers.Dense(filters * 2)(x)
            ffn = layers.ReLU()(ffn)
            ffn = layers.Dense(filters)(ffn)

            x = layers.Add()([x, ffn])
            x = layers.LayerNormalization()(x)
            return x

        inputs = layers.Input((self.img_size, self.img_size, 3))
        mask = layers.Input((self.img_size, self.img_size, 1))

        x = layers.Concatenate()([inputs, mask])
        features = self._conv_block(x, 64)

        # Attention on a coarse grid: (img_size / attn_downsample)^2 tokens.
        side = self.img_size // attn_downsample
        tokens = layers.AveragePooling2D(attn_downsample)(features)
        tokens = layers.Reshape((side * side, 64))(tokens)
        tokens = transformer_block(tokens, 64)
        context = layers.Reshape((side, side, 64))(tokens)
        context = layers.UpSampling2D(attn_downsample)(context)

        x = layers.Add()([features, context])
        outputs = layers.Conv2D(3, 1, activation='tanh')(x)

        return Model([inputs, mask], outputs)

    def combined_model(self, alpha=0.7):
        """Blend the U-Net and HINT outputs as ``alpha * unet + (1 - alpha) * hint``.

        Args:
            alpha: Weight of the U-Net branch, in [0, 1].
        """
        inputs = layers.Input((self.img_size, self.img_size, 3))
        mask = layers.Input((self.img_size, self.img_size, 1))

        unet = self.build_unet()
        hint = self.build_hint()

        unet_out = unet([inputs, mask])
        hint_out = hint([inputs, mask])

        # Fixed convex blend built from standard layers rather than a
        # closure-capturing Lambda, which is not portable when saving/loading.
        outputs = layers.Add()([
            layers.Rescaling(alpha)(unet_out),
            layers.Rescaling(1.0 - alpha)(hint_out),
        ])

        return Model([inputs, mask], outputs)

    # ------------------------------------------------------------ evaluation

    @staticmethod
    def _compute_metrics(targets, predictions, masks):
        """Return reconstruction metrics for data normalised to [-1, 1].

        Returns:
            dict with 'mse' (whole image), 'psnr' in dB (peak-to-peak range 2.0;
            ``inf`` for a perfect reconstruction) and 'hole_mse' (MSE restricted
            to pixels where the mask is 0; ``nan`` if there are none).
        """
        err = (targets - predictions) ** 2
        mse = float(np.mean(err))
        psnr = float('inf') if mse == 0 else float(20 * np.log10(2.0 / np.sqrt(mse)))
        # Masks have one channel; broadcast the hole selector over RGB.
        hole = np.broadcast_to(masks == 0, err.shape)
        hole_mse = float(err[hole].mean()) if hole.any() else float('nan')
        return {'mse': mse, 'psnr': psnr, 'hole_mse': hole_mse}

    def evaluate_model(self, model, test_images, test_masks, n_show=3):
        """Inpaint the test set, plot examples, and report reconstruction metrics.

        The model gets the masked images (holes zeroed) plus the masks, i.e. the
        same inputs used in training, so the metrics measure actual hole filling.

        Args:
            model: Model mapping [masked_images, masks] -> images in [-1, 1].
            test_images: Ground-truth images in [-1, 1], shape (N, H, W, 3).
            test_masks: Masks of shape (N, H, W, 1), 1 = known, 0 = hole.
            n_show: Number of example rows to plot.

        Returns:
            dict with 'mse', 'psnr' and 'hole_mse'.
        """
        masked_images = test_images * test_masks
        predictions = model.predict([masked_images, test_masks])
        metrics = self._compute_metrics(test_images, predictions, test_masks)

        n_show = min(n_show, len(test_images))
        if n_show:
            _, axes = plt.subplots(n_show, 3, figsize=(9, 3 * n_show), squeeze=False)
            panels = (('Original', test_images),
                      ('Masked', masked_images),
                      ('Inpainted', predictions))
            for row in range(n_show):
                for ax, (title, batch) in zip(axes[row], panels):
                    # Map [-1, 1] -> [0, 1]; clip guards values marginally outside.
                    ax.imshow(np.clip((batch[row] + 1) / 2, 0, 1))
                    ax.set_title(title)
                    ax.axis('off')
            plt.show()

        print(f"MSE: {metrics['mse']:.4f}")
        print(f"PSNR: {metrics['psnr']:.2f} dB")
        print(f"Hole MSE: {metrics['hole_mse']:.4f}")
        return metrics

    # -------------------------------------------------------------- training

    def train(self, epochs=10):
        """Load data, train the combined model and evaluate it on a held-out split.

        Returns:
            (model, history) where history is the object returned by ``fit``.
        """
        print("Loading dataset...")
        images = self.load_dataset()
        self.perform_eda(images)

        print("\nPreprocessing images...")
        images = self.preprocess_images(images)
        masks = self.create_masks(images.shape)

        # Split dataset (masks stay paired with their images)
        train_images, test_images, train_masks, test_masks = train_test_split(
            images, masks, test_size=0.2, random_state=42)

        print("\nBuilding model...")
        model = self.combined_model()
        model.compile(
            optimizer=tf.keras.optimizers.Adam(1e-4),
            loss='mse',
            metrics=['mae']
        )

        print("\nTraining model...")
        # Input is the masked image: given the clean image, the network could
        # simply learn the identity mapping and never learn to inpaint.
        history = model.fit(
            [train_images * train_masks, train_masks],
            train_images,
            batch_size=self.batch_size,
            epochs=epochs,
            validation_split=0.2
        )

        print("\nEvaluating model...")
        self.evaluate_model(model, test_images, test_masks)

        return model, history

# Usage example
if __name__ == "__main__":
    path = kagglehub.dataset_download("badasstechie/celebahq-resized-256x256")

    # Train on the directory kagglehub actually downloaded to, not a placeholder.
    inpainting = ImageInpainting(path)
    model, history = inpainting.train()

ValueError: Cannot set memory growth on device when virtual devices configured