# 安裝套件、匯入所需函式庫

In [None]:
!pip install torch torchvision matplotlib diffusers

# 匯入函式庫
import torch
from diffusers import AutoencoderKL
from torchvision import transforms
from torchvision.transforms import ToPILImage
from PIL import Image
import matplotlib.pyplot as plt


# 下載預訓練模型

In [None]:

# Load the pretrained VAE and switch it to evaluation mode.
model_id = "shi-labs/versatile-diffusion"
print("Loading VAE model from Hugging Face...")
# The weights live in the "vae" subfolder of the model repository.
# `.eval()` returns the module itself, so it can be chained onto the load.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").eval()
print("Model loaded successfully!")

# Image preprocessing pipeline shared by every input image.
transform = transforms.Compose(
    [
        # Fixed 256x256 working resolution used for all inputs.
        transforms.Resize(size=(256, 256)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # (x - 0.5) / 0.5: rescales [0, 1] values to [-1, 1].
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)



# 讀取、顯示圖片

In [None]:
def load_image(image_path):
    """Load an image from disk and preprocess it into a batched tensor.

    Args:
        image_path: Path of the image file to open.

    Returns:
        A ``(tensor, original_size)`` tuple. ``tensor`` is the result of the
        module-level ``transform`` with a leading batch dimension added;
        ``original_size`` is the PIL size of the image before preprocessing.
    """
    # Use a context manager so the underlying file handle is released
    # deterministically instead of whenever the image object is collected.
    # `convert` returns a fully loaded, independent image, so it stays valid
    # after the source file is closed.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    original_size = image.size  # remembered so results can be resized back
    return transform(image).unsqueeze(0), original_size

def show_image(tensor, title=""):
    """Display a normalised image tensor with matplotlib.

    Args:
        tensor: Image tensor whose leading batch dimension (if of size 1) is
            squeezed away and whose remaining axes are permuted from
            channel-first to channel-last before plotting.
        title: Optional title drawn above the image.
    """
    pixels = tensor.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    # Undo the 0.5 / 0.5 normalisation, then clamp: decoded or interpolated
    # tensors can overshoot [0, 1], which would otherwise make imshow warn
    # and clip implicitly.
    pixels = (pixels * 0.5 + 0.5).clip(0.0, 1.0)
    plt.imshow(pixels)
    plt.axis('off')
    plt.title(title)
    plt.show()

# Locations of the two input images.
content_image_path = "/content/orig.jpeg"  # image whose content is kept
style_image_path = "/content/starry_night.jpg"  # image supplying the style

# Preprocess both inputs; only the content image's original size is kept.
(content_image, content_size) = load_image(content_image_path)
(style_image, _) = load_image(style_image_path)

# Preview the content image and the style image.
show_image(content_image, "Content Image")
show_image(style_image, "Style Image")



# 風格轉換、顯示風格轉換結果

In [None]:
# Blend the two images in the VAE latent space, then decode the blend.
alpha = 0.5  # weight given to the style latents; content gets 1 - alpha

with torch.no_grad():
    # Encode content first, then style, sampling each latent distribution.
    content_latents = vae.encode(content_image).latent_dist.sample()
    style_latents = vae.encode(style_image).latent_dist.sample()

    # Linear interpolation between the content and style latents.
    mixed_latents = (1 - alpha) * content_latents + alpha * style_latents

    # Decode the mixed latents back to image space.
    generated_image = vae.decode(mixed_latents).sample

# Resize the result to the content image's original dimensions.
# `content_size` is (width, height) while `interpolate` expects
# (height, width), hence the reversed tuple.
resized_image = torch.nn.functional.interpolate(
    generated_image,
    size=content_size[::-1],
    mode="bilinear",
    align_corners=False,
)

# Show the generated image.
show_image(resized_image)

# 保存生成圖片
#output_image = ToPILImage()((resized_image.squeeze(0) * 0.5 + 0.5).clamp(0, 1))
#output_image.save("/content/generated_image.jpg")
#print("Generated image saved as 'generated_image.jpg'")