In [1]:
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler, UNet2DConditionModel
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class CustomStableDiffusion:
    """Stable Diffusion text-to-image wrapper with optional custom UNet weights.

    The pipeline is loaded from ``model_id`` with a DDIM scheduler and the
    safety checker disabled. When ``unet_path`` is given, the state dict stored
    in that file is loaded into the pipeline's UNet.

    Attributes:
        device: Device the pipeline runs on, exactly as passed in (or the
            auto-detected ``"cuda"`` / ``"cpu"`` string).
        pipe: The loaded ``StableDiffusionPipeline``.
    """

    def __init__(self, model_id="runwayml/stable-diffusion-v1-5", unet_path=None, device=None):
        """Load the pipeline and, optionally, custom UNet weights.

        Args:
            model_id: Model identifier passed to ``from_pretrained`` for both
                the pipeline and its scheduler.
            unet_path: Optional path to a file readable by ``torch.load`` that
                contains a UNet state dict. Falsy values skip the replacement.
            device: Target device (string or ``torch.device``). Defaults to
                ``"cuda"`` when available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Load the pipeline. Half precision is used on any CUDA device
        # ("cuda", "cuda:0", torch.device("cuda"), ...), not only the literal
        # string "cuda".
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self._uses_cuda() else torch.float32,
            scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
            safety_checker=None,  # Can be disabled if it is not needed
        ).to(self.device)

        # Replace the UNet weights with the custom ones (if a path is given)
        if unet_path:
            print(f"üîÅ Loading custom UNet from: {unet_path}")
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from a trusted source (consider weights_only=True
            # where the installed torch version supports it).
            # Loading to CPU first avoids holding an extra copy of the
            # weights on the accelerator while they are copied into the model.
            state_dict = torch.load(unet_path, map_location="cpu")
            # Load into the UNet the pipeline already owns instead of building
            # a second float32 UNet: load_state_dict copies each tensor into
            # the existing parameters, so they keep the pipeline's dtype and
            # device and no redundant model is instantiated.
            self.pipe.unet.load_state_dict(state_dict)

    def _uses_cuda(self):
        """Return True when ``self.device`` refers to any CUDA device."""
        return torch.device(self.device).type == "cuda"

    def generate(self, prompt, output_path="output.png", height=512, width=512, num_inference_steps=50, guidance_scale=7.5):
        """Generate one image for ``prompt`` and optionally save it.

        Args:
            prompt: Text prompt forwarded to the pipeline.
            output_path: Where to save the image via its ``save`` method.
                Pass ``None`` (or an empty string) to skip saving.
            height: Forwarded to the pipeline as ``height``.
            width: Forwarded to the pipeline as ``width``.
            num_inference_steps: Forwarded to the pipeline unchanged.
            guidance_scale: Forwarded to the pipeline unchanged.

        Returns:
            The first entry of the pipeline result's ``images``.
        """
        # Autocast on CUDA devices; plain no-grad everywhere else.
        context = torch.autocast("cuda") if self._uses_cuda() else torch.no_grad()
        with context:
            image = self.pipe(
                prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
            ).images[0]

        # Honor output_path, which was previously accepted but ignored.
        if output_path:
            image.save(output_path)
        return image


In [None]:
# Settings for this run, collected in one place so they are easy to tweak.
MODEL_ID = "runwayml/stable-diffusion-v1-5"
UNET_CHECKPOINT = "checkpoints/unet_final.pt"
PROMPT = "a fantasy castle surrounded by fog and mountains"

# Build the pipeline wrapper with the custom UNet checkpoint.
generator = CustomStableDiffusion(model_id=MODEL_ID, unet_path=UNET_CHECKPOINT)

# Generate one image; the bare trailing expression lets the notebook render it.
image = generator.generate(PROMPT)
image