<a href="https://colab.research.google.com/github/ketanmewara/Image-Generation-and-Background-Replacement-using-Stable-Diffusion/blob/main/image_generation_and_background_replacement.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Generation and Background Replacement using Stable Diffusion

## Introduction

This project demonstrates how to leverage various image generation models to create and modify images using Stable Diffusion. The notebook covers:
- **Image Generation with Text-to-Image and Image-to-Image models.**
- **Using Stable Diffusion 3.5 and 3 models for image manipulation.**
- **Implementing object replacement and background modification through automatic masking.**
- **Exploring the capabilities of model quantization for efficiency.**

The notebook will guide you through model selection, image generation, and techniques to replace the background of images using AI models.

## Goals
- Explore the functionalities of different Stable Diffusion versions (3 and 3.5).
- Implement a background replacement feature using mask-based inpainting, an image-to-image technique.
- Apply adapter fusion (e.g., face/style adapters) for enhanced image generation.
- Demonstrate automatic mask generation for object replacement.


## Model Selection

### Stable Diffusion Models Overview

- **Text-to-Image:** Generates images from textual descriptions.
- **Image-to-Image:** Transforms an existing input image under the guidance of a text prompt.
- **Inpainting:** Regenerates only a masked region of an image (e.g., to replace the background) while keeping the rest largely intact.

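The practical difference between image-to-image and inpainting is which pixels are allowed to change. The short NumPy sketch below illustrates only the final blending rule behind inpainting: keep the original where the mask is 0 and use newly generated content where it is 1. It does not model how a diffusion model produces that content, and `composite_inpaint` is a hypothetical name. In diffusers' inpainting pipelines the same convention appears as white (repaint) and black (keep) mask pixels.

```python
import numpy as np

def composite_inpaint(original, generated, mask):
    """Keep `original` where mask is 0 and take `generated` where mask is 1.

    original, generated: float arrays of shape (H, W, C)
    mask: array of shape (H, W) with values in [0, 1]; 1 marks the region to replace
    """
    m = mask.astype(np.float32)[..., None]  # add a channel axis so the mask broadcasts over C
    return m * generated + (1.0 - m) * original

# Toy 2x2 RGB example: only the top-left pixel is replaced
original = np.zeros((2, 2, 3), dtype=np.float32)
generated = np.ones((2, 2, 3), dtype=np.float32)
mask = np.array([[1, 0], [0, 0]])
out = composite_inpaint(original, generated, mask)
print(out[..., 0])
```

Because the mask is broadcast over the color channels, a single (H, W) mask controls all three channels at once.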
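### Key Differences: Stable Diffusion 3 vs. 3.5
- Architecture: Both versions build on the Multimodal Diffusion Transformer (MMDiT), which attends jointly over text and image tokens. Flux, by contrast, is a separate model family from Black Forest Labs and is not part of Stable Diffusion 3.5.
- Image Quality: Stable Diffusion 3.5 generally offers higher quality and better fine-tuning options.
- Speed and Resource Usage: This depends on the variant rather than the version number. Stable Diffusion 3.5 Large (about 8B parameters) needs considerably more memory than Stable Diffusion 3 Medium (about 2B), while Stable Diffusion 3.5 Large Turbo is distilled to produce images in only a few sampling steps.

In short, Stable Diffusion 3.5 is a refined successor to Stable Diffusion 3 with better image quality and a wider choice of model sizes, though not every 3.5 variant is faster or lighter.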

## Adapter Fusion
Adapters such as face or style adapters extend Stable Diffusion models without retraining the base model. They add targeted control over the output, for example preserving facial details or steering the result toward a particular artistic style.

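The notebook does not show adapter code. One concrete form of adapter fusion, used here as an assumed example, is merging low-rank (LoRA) adapter weights into the base layer. diffusers exposes this as `fuse_lora()` on pipelines with loaded LoRA weights. The NumPy sketch below shows only the linear-algebra identity behind fusion, using made-up sizes. Image-prompt adapters such as IP-Adapter work differently: they add extra cross-attention for image features and are not merged this way.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, rank = 8, 6, 2  # illustrative sizes only

W = rng.standard_normal((d_out, d_in))  # frozen base weight
A = rng.standard_normal((rank, d_in))   # LoRA down-projection
B = rng.standard_normal((d_out, rank))  # LoRA up-projection
scale = 0.8                             # adapter strength

x = rng.standard_normal(d_in)

# Unfused: base layer plus a separate low-rank adapter branch
y_unfused = W @ x + scale * (B @ (A @ x))

# Fused: fold the adapter into the base weight once, then run a single matmul
W_fused = W + scale * (B @ A)
y_fused = W_fused @ x

print(np.allclose(y_unfused, y_fused))  # True
```

Fusing trades flexibility for speed. The adapter no longer costs an extra matmul per layer, but changing its strength requires undoing the merge first (diffusers provides `unfuse_lora()` for this).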
## Quantization
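Quantization stores model weights at lower precision (for example 4-bit instead of 16-bit) to reduce memory use. This lets larger models run on hardware with limited resources, although inference can be faster or slightly slower than at full precision depending on the backend. For Stable Diffusion models, diffusers can load individual components, such as the UNet, with quantized weights through its bitsandbytes integration.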

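To make the idea concrete, here is a simplified NumPy sketch of blockwise 4-bit quantization. It uses evenly spaced integer levels and one absmax scale per block of 64 weights. bitsandbytes' NF4 format (used later in the notebook) also scales per block, but it places its 16 levels at quantiles of a normal distribution. Treat this as an illustration of the principle rather than the library's implementation; the function names are hypothetical.

```python
import numpy as np

def quantize_blockwise_4bit(w, block_size=64):
    """Symmetric absmax quantization to 15 integer levels (-7..7), one scale per block."""
    flat = w.reshape(-1, block_size)  # assumes w.size is a multiple of block_size
    scales = np.abs(flat).max(axis=1, keepdims=True) / 7.0
    scales[scales == 0] = 1.0  # all-zero block: any scale works, avoid dividing by zero
    q = np.clip(np.round(flat / scales), -7, 7).astype(np.int8)  # values fit in 4 bits
    return q, scales

def dequantize_blockwise(q, scales, shape):
    """Recover approximate float weights from 4-bit codes and per-block scales."""
    return (q.astype(np.float32) * scales).reshape(shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((128, 64)).astype(np.float32)  # stand-in for a weight matrix

q, scales = quantize_blockwise_4bit(w)
w_hat = dequantize_blockwise(q, scales, w.shape)

# Rounding error per weight is at most half of its block's scale
err = np.abs(w.reshape(-1, 64) - w_hat.reshape(-1, 64))
print("max abs error:", float(err.max()))

# Storage if codes were packed two per byte: 4 bits per weight + one fp32 scale per 64 weights
bits_per_weight = 4 + 32 / 64
print("bits per weight:", bits_per_weight, "vs 16 for fp16")
```

The extra half bit per weight comes from storing one 32-bit scale per 64 weights. bitsandbytes can shrink this overhead further by also quantizing the scales (the `bnb_4bit_use_double_quant` option).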
## Object Replacement Feature
  1. Load the quantized model.
  2. Automatically generate a mask.
  3. Replace the object with a new one.
  4. Display the results.

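## Load the quantized model

The code below loads the Stable Diffusion 2.1 checkpoint (`stabilityai/stable-diffusion-2-1`) into diffusers' inpainting pipeline with 4-bit NF4 quantization via bitsandbytes.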

In [None]:
!pip install torch torchvision supervision diffusers transformers accelerate bitsandbytes pillow matplotlib

In [None]:
import torch
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from PIL import Image
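from diffusers import BitsAndBytesConfig, UNet2DConditionModel  # diffusers' own quantization config

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

In [None]:
# 4-bit NF4 settings; bitsandbytes 4-bit loading generally needs a CUDA GPU
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,  # run the matmuls in half precision
)

In [None]:
model_id = "stabilityai/stable-diffusion-2-1"
# The quantization config is applied when loading a model, so quantize the UNet (the largest component)
unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", quantization_config=nf4_config, torch_dtype=torch.float16
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    unet=unet,  # plug in the quantized UNet
    torch_dtype=torch.float16,  # half precision for the text encoder and VAE
).to(DEVICE)  # move the pipeline to the GPU if available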

## Automatically generate a mask

In [None]:
# Cloning the GitHub repository for Segment Anything model from Facebook Research
!git clone "https://github.com/facebookresearch/segment-anything"

In [None]:
# Downloading a pre-trained model file (weights) from the provided URL
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

In [None]:
# Importing the necessary library for image handling
from PIL import Image
from io import BytesIO
import requests
import os
import numpy as np

# Defining the URL of an image to process
image_path = "https://img.freepik.com/free-photo/adorable-cat-lifestyle_23-2151593320.jpg"

# Fetching the image from the URL and reading it into memory as bytes
response = requests.get(image_path, timeout=30)
response.raise_for_status()  # stop on HTTP errors instead of handing an error page to PIL
image_bytes = BytesIO(response.content)

# Opening the image using PIL from the downloaded bytes
image = Image.open(image_bytes)

# Converting to 3-channel RGB (SAM expects an HWC uint8 RGB array)
image = image.convert("RGB")

# Resizing the image to a fixed 512x512 resolution
image = image.resize((512, 512))

In [None]:
model_type = "vit_b"
model_path = '/content/sam_vit_b_01ec64.pth'
print(model_path, "exist:", os.path.isfile(model_path))

In [None]:
import sys
sys.path.append('/content/segment-anything')

In [None]:
# Importing necessary components from the Segment Anything module
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Initializing the model using a checkpoint (weights) and moving it to the desired device (GPU/CPU)
sam = sam_model_registry[model_type](checkpoint=model_path).to(device=DEVICE)

# Creating a mask generator object from the initialized model
mask_generator = SamAutomaticMaskGenerator(sam)

# Generating the segmentation results from the image (as a NumPy array)
result = mask_generator.generate(np.array(image))

# Sorting the masks once, largest area first
sorted_result = sorted(result, key=lambda x: x['area'], reverse=True)

# Extracting segmentation masks, bounding boxes (XYWH) and the point prompts that produced them
segmentations = [segment['segmentation'] for segment in sorted_result]
bboxs = [segment['bbox'] for segment in sorted_result]
points = [segment['point_coords'] for segment in sorted_result]

print(bboxs)
print(points)
# Checking the number of segmentations (masks) generated
len(segmentations)

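Taking the largest mask (index 0) assumes the subject is the biggest segment, which depends on the photo. A hedged alternative is to select by a pixel you know lies on the subject. The sketch keeps masks that contain that pixel, drops any that cover nearly the whole frame, and takes the largest remaining one so that sub-parts (an eye, a paw) lose to the whole object. `pick_mask_at_point` and the `max_fraction` threshold are hypothetical. The code only relies on the `segmentation` and `area` keys that SAM's automatic mask generator returns. As another option, `SamPredictor`, already imported above, can be prompted with a point directly.

```python
import numpy as np

def pick_mask_at_point(records, x, y, max_fraction=0.9):
    """Return the largest mask that contains pixel (x, y) but is not (nearly) the whole frame.

    `records` uses the SamAutomaticMaskGenerator.generate output format: a list of
    dicts with a boolean 'segmentation' array of shape (H, W) and an 'area' in pixels.
    """
    hits = [
        r for r in records
        # NumPy masks are indexed [row, col], i.e. [y, x]
        if r["segmentation"][y, x] and r["area"] <= max_fraction * r["segmentation"].size
    ]
    if not hits:
        return None
    return max(hits, key=lambda r: r["area"])["segmentation"]

# Synthetic stand-ins for SAM output on a 4x4 image
frame = np.ones((4, 4), dtype=bool)   # background covering the whole image
cat = np.zeros((4, 4), dtype=bool)
cat[1:3, 1:3] = True                  # the subject
eye = np.zeros((4, 4), dtype=bool)
eye[2, 2] = True                      # a small part of the subject
records = [{"segmentation": m, "area": int(m.sum())} for m in (frame, cat, eye)]

print(int(pick_mask_at_point(records, x=2, y=2).sum()))  # 4: the subject, not the eye or frame
print(pick_mask_at_point(records, x=0, y=0))             # None: only the full-frame mask covers it
```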
In [None]:
# Importing a library for visualization
import supervision as sv

# Displaying the first 10 segmentations as images in a grid of size (2 rows, 5 columns)
sv.plot_images_grid(
    images=segmentations[0:10],
    grid_size=(2,5)
)

In [None]:
import numpy as np
# Inverting the largest segmentation mask (index 0 after sorting by area)
# This assumes that mask covers the object to keep: `True` marks the object and `False` the background.
# np.logical_not() flips the booleans, so the background becomes `True` (white, repainted) and the object `False` (black, kept).
mask = np.logical_not(segmentations[0])
mask = Image.fromarray(mask).convert('L')

In [None]:
mask

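diffusers' inpainting pipelines treat white mask pixels as the area to repaint and black pixels as the area to keep, which is why the cell above inverts the object mask. SAM's boundaries can leave a thin fringe of the old background around the subject. One optional refinement, an assumption rather than part of the original notebook, is to grow the white region by a few pixels before inpainting. `prepare_inpaint_mask` is a hypothetical helper built only on NumPy and Pillow's `ImageFilter.MaxFilter`.

```python
import numpy as np
from PIL import Image, ImageFilter

def prepare_inpaint_mask(repaint, grow_px=0):
    """Convert a boolean 'repaint here' array into an 8-bit mask for inpainting.

    White (255) marks pixels the pipeline should repaint, black (0) pixels to keep.
    With grow_px > 0, MaxFilter expands the white region by grow_px pixels, so any
    leftover background fringe is repainted too (at the cost of letting the model
    redraw the object's outermost edge).
    """
    mask = Image.fromarray(repaint.astype(np.uint8) * 255)  # 2-D uint8 array -> mode "L"
    if grow_px > 0:
        mask = mask.filter(ImageFilter.MaxFilter(2 * grow_px + 1))  # kernel size must be odd
    return mask

# Synthetic example: a 9x9 frame where a 3x3 "object" must be kept
repaint = np.ones((9, 9), dtype=bool)
repaint[3:6, 3:6] = False
grown = np.array(prepare_inpaint_mask(repaint, grow_px=1))
print(grown[4, 4], grown[3, 3])  # 0 255: only the object's centre pixel stays protected
```

In the notebook this would be called as `prepare_inpaint_mask(np.logical_not(segmentations[0]), grow_px=4)`, with `grow_px` tuned by eye.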
## Replace the object with a new one

In [None]:
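def inpaint_with_mask(image, mask, prompt, size=(512, 512)):
    """Repaint the white region of `mask` in `image` according to `prompt`."""
    image = image.convert("RGB").resize(size)
    # NEAREST keeps the mask strictly black/white (white = repaint, black = keep)
    mask = mask.convert("L").resize(size, Image.Resampling.NEAREST)

    # Pass the size explicitly; otherwise the pipeline uses the model's default resolution
    inpainted_image = pipe(
        prompt=prompt, image=image, mask_image=mask, height=size[1], width=size[0]
    ).images[0]

    return inpainted_image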

In [None]:
prompt = "background beach"
background_change_image = inpaint_with_mask(image, mask, prompt)

## Display the results

In [None]:
# Plot the original image, mask image, and background replacement image
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Display the images
axes[0].imshow(image)
axes[0].set_title('Original Image')
axes[0].axis('off')

axes[1].imshow(mask, cmap='gray')
axes[1].set_title('Mask Image')
axes[1].axis('off')

axes[2].imshow(background_change_image)
axes[2].set_title('After Background Replacement')
axes[2].axis('off')

# Show the plot
plt.tight_layout()
plt.show()