# 基于扩散模型的商品背景美化demo

### 一、环境准备

In [None]:
# 安装huggingface库以及第三方包
!pip install diffusers
!pip install transformers scipy ftfy accelerate
!pip install onnxruntime-gpu

# 解压模型权重
!unzip /root/autodl-fs/epoch19.zip /root/autodl-tmp/
!unzip /root/autodl-fs/reg.zip /root/autodl-tmp/

In [None]:
from PIL import Image
import torch
from reg import get_img_mask
from BGInpaintPipeline import BGInpaintPipeline

### 二、初始化模型管道

In [None]:
# Load the fine-tuned diffusion weights.
# NOTE(review): the archive unpacked in the setup cell is epoch19.zip while this
# path is epoch_19 — confirm the directory name inside the archive.
sd2_id = "/root/autodl-tmp/epoch_19"

# Place the pipeline explicitly instead of device_map="auto": recent diffusers
# releases reject "auto" for pipeline-level device maps.
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only practical on GPU; fall back to float32 when running on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = BGInpaintPipeline.from_pretrained(sd2_id, torch_dtype=dtype, safety_checker=None).to(device)

### 三、定义处理函数

In [None]:
def resize_img_mask(img, mask, img_size=768):
    """Resize an image and its mask so the longer side is ``img_size``, snapped down to a multiple of 8.

    The target size is computed from ``img``'s dimensions (aspect ratio preserved)
    and applied to both ``img`` and ``mask``.

    Args:
        img: PIL image to resize.
        mask: PIL mask image associated with ``img``; resized to the same target size.
        img_size: desired length of the longer side before snapping to a multiple of 8.

    Returns:
        Tuple ``(resized_img, resized_mask)``.
    """
    w, h = img.size
    long_side = max(w, h)
    # Integer arithmetic avoids float round-off: e.g. 512 / 98 * 98 evaluates to
    # 511.99999999999994, which a float version would snap down to 504.
    # max(8, ...) keeps very thin images from collapsing to a zero-sized axis.
    new_w = max(8, int(w * img_size // long_side) // 8 * 8)
    new_h = max(8, int(h * img_size // long_side) // 8 * 8)
    target = (new_w, new_h)
    # NOTE(review): bicubic resampling softens a binary mask's edges; NEAREST would
    # keep it binary. Kept as bicubic to preserve the current output.
    return (img.resize(target, Image.Resampling.BICUBIC),
            mask.resize(target, Image.Resampling.BICUBIC))


def gen_img(prompt, negative_prompt, num_inference_steps, strength, guidance_scale, img, mask):
    """Resize an image/mask pair and run the background-inpainting pipeline on it.

    Args:
        prompt: text prompt passed to the pipeline.
        negative_prompt: negative text prompt passed to the pipeline.
        num_inference_steps: forwarded to the pipeline unchanged.
        strength: forwarded to the pipeline unchanged.
        guidance_scale: forwarded to the pipeline unchanged.
        img: input image; resized by ``resize_img_mask`` before generation.
        mask: mask for ``img``; resized alongside it.

    Returns:
        Tuple ``(generated_image, resized_mask)``: the first image of the pipeline
        output and the mask at the generation resolution.
    """
    resized_img, resized_mask = resize_img_mask(img, mask)
    # Generate at exactly the resized resolution.
    width, height = resized_img.size
    result = pipe(
        prompt,
        resized_img,
        resized_mask,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=guidance_scale,
        negative_prompt=negative_prompt,
    )
    generated = result.images[0]
    return generated, resized_mask


def processing_img(img, prompt, negative_prompt, num_inference_steps, guidance_scale, strength):
    """Compute a mask for ``img`` with ``get_img_mask`` and regenerate it via ``gen_img``.

    The parameter order here (guidance_scale before strength) differs from
    ``gen_img``, so the call below passes everything by keyword to keep the
    mapping explicit.

    Returns:
        Tuple ``(generated_image, resized_mask)`` as returned by ``gen_img``.
    """
    # get_img_mask returns a pair; only the second element (the mask) is used.
    _unused, product_mask = get_img_mask(img)
    generated, resized_mask = gen_img(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        strength=strength,
        guidance_scale=guidance_scale,
        img=img,
        mask=product_mask,
    )
    return generated, resized_mask

### 四、定义参数

In [None]:
# Comma-separated tags passed to the pipeline as the negative prompt
# (people and common low-quality artifacts).
negative_prompt = 'face, human, badhand, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale, watermark'
# Number of denoising steps forwarded to the pipeline.
num_inference_steps = 50
# Forwarded to the pipeline's `guidance_scale` argument.
guidance_scale = 7
# Forwarded to the pipeline's `strength` argument.
# NOTE(review): presumably 1 means the masked region is regenerated fully — confirm against BGInpaintPipeline.
strength = 1

In [None]:
# Example input: load the product photo.
# convert('RGB') reads the pixels immediately (so the file can be closed by the
# `with` block) and guarantees a 3-channel image even for grayscale/CMYK JPEGs.
with Image.open('./img_data/cup.jpg') as raw_image:
    image = raw_image.convert('RGB')
image

In [None]:
prompt = 'window desk'
# Keywords guard against the guidance_scale/strength order differing between helpers.
img_after, mask = processing_img(
    image,
    prompt,
    negative_prompt,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    strength=strength,
)
# Show the generated image as the cell output.
img_after