Modeli fp16 yüklediğimizde hız ne durumda?

In [None]:
import PIL
import torch
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image, ImageDraw, ImageFilter
from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler

In [None]:
# Load the inpainting pipeline directly in half precision. Passing torch_dtype avoids
# materialising a full fp32 copy of the weights first and then casting it, which the
# previous `from_pretrained(...)` + `.to("cuda", torch.float16)` sequence did.
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16,
)
# Swap in DPMSolverMultistepScheduler, built from the original scheduler's config.
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to("cuda")

In [None]:
# Tiling geometry in pixels: square tile edge fed to the pipeline, and the
# overlap between neighbouring tiles.
tile_size, overlap = 768, 256

# Half-height of the horizontal mask band drawn by generate_mask.
brush_size = 70

# Sampler settings forwarded to the inpainting pipeline.
num_inference_steps, strength = 50, 0.99
prompt = " "  # a single space, i.e. an effectively empty prompt

# Seeded (CPU) generator so that repeated runs produce the same result.
generator = torch.Generator().manual_seed(231_874_959)

In [None]:
def generate_mask(image, brush_size, *, radius=70):
    """Build a full-width rounded-rectangle mask band centred vertically.

    The band spans the whole image width and extends ``brush_size`` pixels
    above and below the horizontal centre line, so it is ``2 * brush_size``
    pixels tall.

    Args:
        image: Object exposing ``width`` and ``height`` (a PIL image in this
            notebook); only its size is used.
        brush_size: Half-height of the band, in pixels.
        radius: Corner radius of the rounded rectangle. Defaults to 70, the
            value that was previously hard-coded.

    Returns:
        A mode ``'L'`` PIL image of the same size: 255 inside the band,
        0 elsewhere.
    """
    width, height = image.width, image.height
    centre_y = height * 0.5
    box = [0, centre_y - brush_size, width, centre_y + brush_size]

    mask = Image.new('L', (width, height), 0)
    ImageDraw.Draw(mask).rounded_rectangle(box, fill=255, radius=radius)
    return mask

In [None]:
# Load the source photo as 3-channel RGB at the 1280x768 working size, and
# build the matching inpainting mask.
pil_image = Image.open('../1_media/input_images/cat.jpg').convert('RGB').resize((1280,768))
pil_mask = generate_mask(pil_image, brush_size)

# Preview both inputs.
for preview in (pil_image, pil_mask):
    plt.imshow(preview)
    plt.show()

# HWC uint8 image -> CHW float32 tensor scaled to [-1, 1].
input_image = torch.from_numpy(np.array(pil_image, np.float32).transpose(2,0,1)) / 127.5 - 1.0
# HW uint8 mask -> 1xHxW float32 tensor scaled to [0, 1].
input_mask = torch.from_numpy(np.array(pil_mask,np.float32)).unsqueeze(0) / 255.0

for label, value in (("input_image.shape: ", input_image.shape),
                     ("input_mask.shape: ", input_mask.shape),
                     ("input_image.dtype: ", input_image.dtype),
                     ("input_mask.dtype: ", input_mask.dtype)):
    print(label, value)

In [None]:
# Tile boxes as (left, top, right, bottom), derived from tile_size/overlap instead of
# being written out by hand. Consecutive tiles start `tile_size - overlap` pixels apart,
# and if the stride does not land exactly on the far edge, one extra tile is snapped
# to it. For the 1280x768 input with tile_size=768 and overlap=256 this gives
# [(0, 0, 768, 768), (512, 0, 1280, 768)].
_, img_height, img_width = input_image.shape
tile_stride = tile_size - overlap
if tile_stride <= 0:
    raise ValueError(f"overlap ({overlap}) must be smaller than tile_size ({tile_size})")


def _tile_starts(length):
    """Return tile start offsets covering [0, length) along one axis."""
    starts = list(range(0, max(length - tile_size, 0) + 1, tile_stride))
    if starts[-1] + tile_size < length:
        starts.append(length - tile_size)
    return starts


tile_coords = [
    (left, top, left + tile_size, top + tile_size)
    for top in _tile_starts(img_height)
    for left in _tile_starts(img_width)
]

# Preview the tile layout on top of the input. Round (not truncate) when going back
# to uint8, so float round-off does not darken pixels by one level.
preview = ((input_image.numpy() + 1) * 127.5).transpose(1, 2, 0)
p_im = Image.fromarray(np.clip(np.rint(preview), 0, 255).astype(np.uint8))
draw = ImageDraw.Draw(p_im)
for left, top, right, bottom in tile_coords:
    draw.rectangle((left, top, right, bottom), outline='red', width=10)

    print("tile_width: ", right - left)
    print("tile_height: ", bottom - top)

p_im.save('p_im.png')
plt.imshow(p_im)
plt.show()

In [None]:
def apply_gauss_filter(mask_image: "PIL.Image.Image | np.ndarray | torch.Tensor",
                       kernel_size: int = 15,
                       num_of_fill_pixels_x: int = 15,
                       num_of_fill_pixels_y: int = 15) -> torch.Tensor:
    """Feather a single-channel mask with a Gaussian blur.

    The first ``num_of_fill_pixels_x`` columns and the first
    ``num_of_fill_pixels_y`` rows are forced to 0 before blurring. The blurred
    mask therefore fades in from the tile's left and top edges instead of
    ending in a hard seam there.

    Args:
        mask_image: 2-D mask, or 1xHxW with a singleton leading channel.
            A PIL image or integer array is read as 0..255. A float array or
            tensor is read as 0..1 (the tiling loop passes
            ``tile_mask.squeeze(0)``, a float tensor in [0, 1]). The caller's
            object is never modified.
        kernel_size: Radius passed to ``ImageFilter.GaussianBlur``.
        num_of_fill_pixels_x: Number of leading columns zeroed before blurring.
        num_of_fill_pixels_y: Number of leading rows zeroed before blurring.

    Returns:
        float32 tensor of shape (1, H, W) with values in [0, 1].
    """
    if torch.is_tensor(mask_image):
        # np.array copies, so the in-place zeroing below cannot touch the
        # caller's tensor. .cpu()/.float() also admit CUDA and fp16 masks.
        np_mask_image = np.array(mask_image.detach().cpu().float(), dtype=np.float32)
    else:
        np_mask_image = np.array(mask_image)
        if np.issubdtype(np_mask_image.dtype, np.integer):
            # PIL 'L' images / uint8 arrays are 0..255. Rescale them to 0..1 so
            # the *255 below cannot overflow and wrap in the uint8 cast.
            np_mask_image = np_mask_image.astype(np.float32) / 255.0
        else:
            np_mask_image = np_mask_image.astype(np.float32)
    if np_mask_image.ndim == 3 and np_mask_image.shape[0] == 1:
        np_mask_image = np_mask_image[0]

    np_mask_image[:, :num_of_fill_pixels_x] = 0
    np_mask_image[:num_of_fill_pixels_y, :] = 0

    # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
    pil_mask_image = Image.fromarray(np.clip(np_mask_image * 255.0, 0, 255).astype(np.uint8))
    blurred_mask_image = pil_mask_image.filter(ImageFilter.GaussianBlur(radius=kernel_size))
    blurred_mask_image = torch.from_numpy(np.array(blurred_mask_image)).unsqueeze(0).float()
    return blurred_mask_image / 255.0

In [None]:
import time

# Inpaint each tile in turn and feather the result back into `input_image`.
# Tiles overlap, so a later tile starts from pixels an earlier tile already
# rewrote. Each tile's pipeline wall-clock time is recorded to answer the
# notebook's fp16 speed question.
tile_seconds = []
for left, top, right, bottom in tile_coords:
    # Views into the full tensors: CHW image in [-1, 1], 1xHxW mask in [0, 1].
    tile = input_image[:, top:bottom, left:right]
    tile_mask = input_mask[:, top:bottom, left:right]

    start = time.perf_counter()
    tile_output = pipeline(image=tile,
                           mask_image=tile_mask,
                           width=tile_size,
                           height=tile_size,
                           prompt=prompt,
                           num_inference_steps=num_inference_steps,
                           strength=strength,
                           generator=generator,
                           output_type="pt").images[0]
    # A blocking device-to-host copy waits for the queued GPU work, so the timing
    # covers the whole run. The pipeline is fp16; cast explicitly to float32 to
    # match the CPU-side tensors.
    tile_output = tile_output.cpu().float()
    tile_seconds.append(time.perf_counter() - start)

    # Soft blend weight: ~1 inside the blurred mask, ~0 outside.
    m_g = apply_gauss_filter(mask_image=tile_mask.squeeze(0), kernel_size=15)

    # Map the pipeline output from [0, 1] to [-1, 1] before blending with the tile.
    output = (tile_output * 2.0 - 1) * m_g + tile * (1 - m_g)

    input_image[:, top:bottom, left:right] = output

# NOTE: the first tile may include one-time CUDA warm-up cost.
for tile_index, seconds in enumerate(tile_seconds):
    print(f"tile {tile_index}: {seconds:.2f} s")
print(f"total pipeline time: {sum(tile_seconds):.2f} s")
    

In [None]:
# CHW tensor in [-1, 1] -> HWC uint8 image. Round instead of truncating, so pixels
# that only made the float round trip come back unchanged rather than one level
# darker. Clamp so any overshoot saturates instead of wrapping in the uint8 cast.
image_pt = ((input_image.permute(1, 2, 0) + 1) * 127.5).round().clamp(0, 255)
image_pt = image_pt.cpu().numpy().astype(np.uint8)
shifted_image = Image.fromarray(image_pt)
shifted_image.save('output.png')
plt.imshow(shifted_image)
plt.show()