Pipeline combining 2 models

In [None]:
import torch
from PIL import Image
from transformers import CLIPTokenizer
from Chain_Of_Thought.COT_text_gen import ImprovedChainOfThought, load_improved_model
from diffusion_model.pipeline import generate
from diffusion_model.diffusion import Diffusion
from diffusion_model.VAen_decoder import encoder as VAEEncoder, decoder as VAEDecoder
from diffusion_model.clip import CLIP

class TextToImagePipeline:
    """Refine a user prompt with a chain-of-thought model, then render it.

    Calling an instance first asks ``text_model.generate_with_cot`` to rewrite
    the user prompt, then passes the rewritten prompt to ``generate`` (imported
    from ``diffusion_model.pipeline``) together with the diffusion, VAE
    encoder/decoder and CLIP modules held in ``self.models``.
    """

    # Token budget for the refined prompt, counted with ``self.tokenizer``
    # exactly as the length check in ``__call__`` does.
    MAX_PROMPT_TOKENS = 77

    def __init__(
        self,
        text_model: ImprovedChainOfThought,
        diffusion_model: Diffusion,
        encoder_model: VAEEncoder,
        decoder_model: VAEDecoder,
        clip_model: CLIP,
        tokenizer: CLIPTokenizer,
        device: str = "auto"
    ):
        """Move every model to the target device and store them.

        Args:
            text_model: Model exposing ``generate_with_cot``.
            diffusion_model: Stored under ``models["diffusion"]``.
            encoder_model: Stored under ``models["encoder"]``.
            decoder_model: Stored under ``models["decoder"]``.
            clip_model: Stored under ``models["clip"]``.
            tokenizer: Used to count prompt tokens and forwarded to
                ``generate``.
            device: ``"auto"`` selects CUDA when available and CPU otherwise;
                any other value is passed to ``torch.device``.
        """
        if device == "auto":
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            device = torch.device(device)
        self.device = device
        self.text_model = text_model.to(device)
        # Key names are consumed by ``generate``; keep them unchanged.
        self.models = {
            "diffusion": diffusion_model.to(device),
            "encoder": encoder_model.to(device),
            "decoder": decoder_model.to(device),
            "clip": clip_model.to(device)
        }
        self.tokenizer = tokenizer

    def _count_tokens(self, text: str) -> int:
        """Return how many token ids ``self.tokenizer`` produces for *text*."""
        return len(self.tokenizer(text, return_tensors="pt").input_ids[0])

    def _fit_to_token_limit(self, text: str) -> str:
        """Drop trailing words from *text* until it fits the token budget.

        Trimming whole words and re-counting after each removal guarantees the
        result tokenizes to at most ``MAX_PROMPT_TOKENS`` ids, which a fixed
        character slice cannot promise. Runs of whitespace are collapsed to a
        single space. Returns an empty string if even the first word alone
        exceeds the budget.
        """
        words = text.split()
        while words and self._count_tokens(" ".join(words)) > self.MAX_PROMPT_TOKENS:
            words.pop()
        return " ".join(words)

    def __call__(
        self,
        user_prompt: str,
        uncond_prompt: str = "",
        cot_type: str = "step_by_step",
        cfg_scale: float = 7.5,
        strength: float = 0.8,
        n_steps: int = 50,
        seed: int = None
    ):
        """Generate an image for *user_prompt*.

        Args:
            user_prompt: Non-empty text describing the desired image.
            uncond_prompt: Forwarded to ``generate`` as ``uncond_prompt``.
            cot_type: Forwarded to ``generate_with_cot`` as ``cot_type``.
            cfg_scale: Forwarded to ``generate`` (``do_cfg`` is always True).
            strength: Forwarded to ``generate``; ``input_image`` is always
                ``None`` here.
            n_steps: Forwarded to ``generate`` as ``n_inference_steps``.
            seed: Forwarded to ``generate``; may be ``None``.

        Returns:
            Whatever ``generate`` returns (presumably an image array, since
            callers pass it to ``Image.fromarray`` -- confirm against
            ``diffusion_model.pipeline``).

        Raises:
            ValueError: If *user_prompt* is not a string or is blank.
        """
        if not isinstance(user_prompt, str) or not user_prompt.strip():
            raise ValueError("User prompt must be a non-empty string")

        refined_prompt = self.text_model.generate_with_cot(
            question=user_prompt,
            cot_type=cot_type,
            max_length=150,
            temperature=0.7
        ).strip()
        print(f"Refined Prompt: {refined_prompt}")

        if self._count_tokens(refined_prompt) > self.MAX_PROMPT_TOKENS:
            print("Warning: Refined prompt too long, truncating")
            # Trim by tokens, not characters: the limit being enforced is a
            # token count, so a character slice neither guarantees a fit nor
            # keeps as much of the prompt as possible.
            refined_prompt = self._fit_to_token_limit(refined_prompt)

        # Name an idle device only when the models run somewhere other than
        # the CPU.
        idle_device = "cpu" if self.device.type != "cpu" else None

        return generate(
            prompt=refined_prompt,
            uncond_prompt=uncond_prompt,
            input_image=None,
            strength=strength,
            do_cfg=True,
            cfg_scale=cfg_scale,
            sampler_name="ddpm",
            n_inference_steps=n_steps,
            seed=seed,
            models=self.models,
            device=self.device,
            idle_device=idle_device,
            tokenizer=self.tokenizer
        )

if __name__ == "__main__":
    # Demo entry point: load the tokenizer and every model, build the
    # pipeline, generate one image and save it as output.png.

    # Load tokenizer
    try:
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to load tokenizer: {e}") from e

    # Load CoT model
    text_model, vocab = load_improved_model("model.pth", device="auto")
    if text_model is None:
        raise RuntimeError("Failed to load CoT model")

    # Load other models
    diffusion_model = Diffusion()
    encoder = VAEEncoder()
    decoder = VAEDecoder()
    clip_model = CLIP()

    # Load pretrained weights (replace with actual paths).
    # Each checkpoint is loaded independently so that one missing file does
    # not prevent the remaining checkpoints from being loaded.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoints = (
        (diffusion_model, "diffusion_model.pth"),
        (encoder, "vae_encoder.pth"),
        (decoder, "vae_decoder.pth"),
        (clip_model, "clip_model.pth"),
    )
    for module, weights_path in checkpoints:
        try:
            state_dict = torch.load(weights_path, map_location="cpu")
        except FileNotFoundError as e:
            print(f"Warning: Pretrained weights not found ({e}). Using uninitialized models.")
            continue
        module.load_state_dict(state_dict)

    # Create pipeline
    pipe = TextToImagePipeline(
        text_model=text_model,
        diffusion_model=diffusion_model,
        encoder_model=encoder,
        decoder_model=decoder,
        clip_model=clip_model,
        tokenizer=tokenizer,
        device="auto"
    )

    # Generate image
    try:
        img = pipe(
            user_prompt='A DG EAT HOTDOG',
            seed=42,
            cot_type="step_by_step",
            n_steps=50
        )
        Image.fromarray(img).save("output.png")
        print("Image saved as output.png")
    except Exception as e:
        # Top-level boundary of the demo: report the failure instead of crashing.
        print(f"Error during image generation: {e}")

ModuleNotFoundError: No module named 'VAen_decoder'