In [None]:
!pip install --upgrade diffusers[torch]
!pip install transformers[sentencepiece]

import torch

from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers import DDIMScheduler
from transformers import (
    CLIPFeatureExtractor,
    CLIPTextModel,
    CLIPTokenizer,
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    pipeline,
)

from translate import KoreanStableDiffusion

# Load every pretrained component, then assemble the Korean Stable Diffusion
# pipeline from them.
#
# Weight-bearing models are loaded in float16 and moved to DEVICE. The
# tokenizers, the feature extractor and the scheduler hold no weights, so they
# are loaded without a dtype and are not moved to a device.

DEVICE = "cuda"
DTYPE = torch.float16
TRANSLATION_MODEL_ID = "facebook/mbart-large-50-many-to-one-mmt"
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
STABLE_DIFFUSION_MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Translation model + tokenizer handed to KoreanStableDiffusion
# (presumably used to translate the Korean prompt before text encoding —
# see translate.KoreanStableDiffusion to confirm).
MBart_model = MBartForConditionalGeneration.from_pretrained(
    TRANSLATION_MODEL_ID, torch_dtype=DTYPE
).to(DEVICE)
MBart_tokenizer = MBart50TokenizerFast.from_pretrained(TRANSLATION_MODEL_ID)

# CLIP text tokenizer/encoder and image feature extractor.
clip_tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_ID)
clip_text_model = CLIPTextModel.from_pretrained(
    CLIP_MODEL_ID, torch_dtype=DTYPE
).to(DEVICE)
clip_feature_extractor = CLIPFeatureExtractor.from_pretrained(CLIP_MODEL_ID)

# Stable Diffusion v1-4 components, each loaded from its repo subfolder.
unet_model = UNet2DConditionModel.from_pretrained(
    STABLE_DIFFUSION_MODEL_ID, subfolder="unet", torch_dtype=DTYPE
).to(DEVICE)
ddim_scheduler = DDIMScheduler.from_pretrained(
    STABLE_DIFFUSION_MODEL_ID, subfolder="scheduler"
)
auto_encoder_kl = AutoencoderKL.from_pretrained(
    STABLE_DIFFUSION_MODEL_ID, subfolder="vae", torch_dtype=DTYPE
).to(DEVICE)
stable_diffusion_safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    STABLE_DIFFUSION_MODEL_ID, subfolder="safety_checker", torch_dtype=DTYPE
).to(DEVICE)

# NOTE: this rebinds the module-level name `pipeline`, shadowing the
# `transformers.pipeline` function imported at the top of the notebook.
# The name is kept because later cells call this object as `pipeline(...)`.
pipeline = KoreanStableDiffusion(
    translation_model=MBart_model,
    translation_tokenizer=MBart_tokenizer,
    vae=auto_encoder_kl,
    text_encoder=clip_text_model,
    tokenizer=clip_tokenizer,
    unet=unet_model,
    scheduler=ddim_scheduler,
    safety_checker=stable_diffusion_safety_checker,
    feature_extractor=clip_feature_extractor,
)

In [None]:
# Run the pipeline on a Korean prompt ("a small, cute cow toy").
korean_prompt = "작고 귀여운 소 장난감"
result = pipeline(prompt=korean_prompt)

In [None]:
# Show the first generated image; the last expression of a notebook cell is rendered.
first_image = result.images[0]
first_image