# Background replacement

Load the input image and make the project's `code` directory importable

In [None]:
from PIL import Image
import os, sys

input_path = '../inputs/background-replacement/camera.png'
# Open the file in a context manager so the handle is closed once the pixels
# are loaded (convert() forces the load). Converting to 3-channel RGB matters
# because PNGs are often RGBA or palette images, and this image is later
# passed directly to the inpainting pipeline.
with Image.open(input_path) as img:
    input_image = img.convert('RGB')

# Make the project's `code` directory importable (extract_foreground,
# operations_image).
# NOTE(review): this resolves to <cwd>/../../code (two levels above the
# notebook directory), while input_path is only one level up. Confirm this
# matches the repository layout.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(os.path.dirname(current_dir), '..'))
code_dir = os.path.join(parent_dir, 'code')
# Avoid accumulating duplicate sys.path entries when the cell is re-run.
if code_dir not in sys.path:
    sys.path.append(code_dir)

In [None]:
from matplotlib import pyplot as plt

# Preview the loaded input image with axis lines, ticks and labels hidden.
axes = plt.gca()
axes.imshow(input_image)
axes.set_axis_off()
plt.show()

Define the number of denoising steps

In [None]:
# Number of denoising iterations; passed to the pipeline as `num_inference_steps`.
steps: int = 50

Define the prompt for the background

In [None]:
# Text prompt describing the background to paint in around the foreground.
prompt: str = "savanna with wild animals"

Extract the foreground image and its mask using the RMBG-1.4 model

In [None]:
from extract_foreground import extract_foreground_image, extract_foreground_mask

# Isolate the foreground using the project's helpers, then derive the
# foreground mask from that extracted image.
# NOTE(review): the mask is computed from the extracted foreground rather than
# from input_image. Confirm this is what extract_foreground_mask expects.
foreground_image = extract_foreground_image(input_image)
foreground_mask = extract_foreground_mask(foreground_image)

In [None]:
# Show the mask in grayscale. Without an explicit colormap, matplotlib renders
# single-channel data in false colour (its default colormap); `cmap` is
# ignored if the data turns out to be RGB(A).
plt.imshow(foreground_mask, cmap='gray')
plt.axis('off')
plt.show()

Load the Stable Diffusion 2 inpainting model

In [None]:
from diffusers import StableDiffusionInpaintPipeline
import torch

# Fall back to CPU when no CUDA device is available instead of crashing.
# Half precision is kept for the GPU only; float16 inference is generally
# unsupported or very slow on CPU, so use float32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)

Invert the mask so the background becomes the area to repaint, then (optionally) dilate it

In [None]:
from operations_image import expand_white_areas_outpainting
import numpy as np

# Original (height, width), used later to resize the generated image back.
# Read from the PIL attributes instead of converting the whole image to an
# array just to inspect its shape.
size = (input_image.height, input_image.width)

# Reverse the mask for outpainting, so the region to regenerate (the
# background) is white.
# NOTE(review): assumes foreground_mask is 8-bit (0..255) with the foreground
# white. Confirm against extract_foreground_mask.
inverted_mask = 255 - np.array(foreground_mask)
reversed_mask_array = Image.fromarray(inverted_mask)
# Optional: grow the white (repaint) region.
# NOTE(review): the meaning of the `2` argument (pixels vs. iterations) is
# defined in operations_image. Confirm there.
reversed_mask_array = expand_white_areas_outpainting(reversed_mask_array, 2)  # optional

In [None]:
# Show the inverted/dilated mask in grayscale. Without an explicit colormap,
# matplotlib renders single-channel data in false colour; `cmap` is ignored
# if the data turns out to be RGB(A).
plt.imshow(reversed_mask_array, cmap='gray')
plt.axis('off')
plt.show()

Run the model on the input image and the resulting mask

In [None]:
# The pipeline is run at 512x512, so the image and the mask are both resized
# to that resolution first (image first, then mask).
resized_input_image, resized_reversed_mask_array = (
    picture.resize((512, 512)) for picture in (input_image, reversed_mask_array)
)

pipeline_kwargs = dict(
    prompt=prompt,
    image=resized_input_image,
    mask_image=resized_reversed_mask_array,
    guidance_scale=7.5,
    num_inference_steps=steps,
)
output_image = pipe(**pipeline_kwargs).images[0]

# Scale the result back to the original resolution. PIL takes
# (width, height), whereas `size` holds (height, width).
height, width = size
resized_output_image = output_image.resize((width, height))

In [None]:
# Display the final image, with the generated background, at the original
# resolution and without axes.
ax = plt.gca()
ax.imshow(resized_output_image)
ax.set_axis_off()
plt.show()