# Test Load Encoder and Decoder

In [None]:
import os

import pandas as pd
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from PIL import Image
from tqdm import tqdm
from transformers import CLIPTextModel  # CLIPTextModel lives in transformers, not diffusers
import t2v_metrics



def load_pipeline(
    unet_ckpt,
    text_encoder_ckpt,
    device,
    *,
    base_model=None,
    torch_dtype=torch.float16,
):
    """
    Build a Stable Diffusion pipeline, optionally swapping in fine-tuned parts.

    Any overriding UNet / text encoder is loaded first and handed directly to
    ``StableDiffusionPipeline.from_pretrained``. As a result, the base model's
    copies of those components are never loaded, and both UNets are never
    resident on ``device`` at the same time. The assembled pipeline is moved
    to ``device`` once at the end. The safety checker is disabled and is not
    loaded at all.

    Args:
        unet_ckpt: Path or hub id of a checkpoint that contains a ``unet``
            subfolder, or None to keep the base model's UNet.
        text_encoder_ckpt: Path or hub id of a checkpoint that contains a
            ``text_encoder`` subfolder, or None to keep the base text encoder.
            The tokenizer is always taken from the base model.
        device: Target passed to ``pipe.to`` (e.g. ``"cuda"``).
        base_model: Hub id or path of the base pipeline. When None, falls back
            to the module-level ``base_model_name`` (presumably defined in
            another notebook cell).
        torch_dtype: dtype used for every loaded component (float16 by
            default, matching the previous behavior).

    Returns:
        The ``StableDiffusionPipeline`` on ``device``.
    """
    # Resolve the global lazily, so callers passing base_model don't need it.
    model_id = base_model if base_model is not None else base_model_name

    # Components given here replace the ones stored under model_id. Any
    # component left out is loaded from the base model as usual.
    overrides = {"safety_checker": None, "requires_safety_checker": False}

    if unet_ckpt is not None:
        overrides["unet"] = UNet2DConditionModel.from_pretrained(
            unet_ckpt, subfolder="unet", torch_dtype=torch_dtype
        )

    if text_encoder_ckpt is not None:
        overrides["text_encoder"] = CLIPTextModel.from_pretrained(
            text_encoder_ckpt, subfolder="text_encoder", torch_dtype=torch_dtype
        )

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch_dtype, **overrides
    )
    return pipe.to(device)


