In [None]:
import torch
import torch.nn as nn
from datetime import datetime
import numpy as np
import os

from vae_slim import PCAPipeline, PCAModel

from huggingface_hub import login
from diffusers import (
    AutoencoderKL
)
from dotenv import load_dotenv
# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# Environment-driven configuration.
Token = os.getenv("HUGGINGFACE_TOKEN", "")
cache_dir = os.getenv("HF_CACHE_DIR", "/root/autodl-tmp/cache_dir/huggingface/hub/")
# NOTE(review): no visible cell reads this URL; `ckpt_path` is rebound to a
# local checkpoint file in a later cell before its first use.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q8_0.gguf"

# Only authenticate when a token was actually supplied. The getenv default is
# an empty string, which is not a usable token and should not be sent to login().
if Token:
    login(token=Token)
else:
    print("HUGGINGFACE_TOKEN is not set; skipping Hugging Face login.")


# Hub repository that hosts the VAE weights.
model_path = "black-forest-labs/FLUX.1-dev"

print("loading vae from:", model_path)

# Load the "vae" subfolder of the repository in bfloat16. The repo id comes
# from `model_path` so the message printed above always matches what is loaded.
vae = AutoencoderKL.from_pretrained(
    model_path,
    subfolder="vae",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_dir,
    # NOTE(review): hard-coded local HTTP proxy; confirm it is still required
    # in the environment where this notebook runs.
    proxies={'http': '127.0.0.1:7890'}
)
import pdb  # NOTE(review): leftover debugging import; unused in the visible cells.
vae.to("cuda")

# Locations of the CSV files holding the PCA components and the PCA mean.
pca_components_add = "/workspace/DiffBrush/VIS/pca3d_pca_components.csv"
pca_mean_add = "/workspace/DiffBrush/VIS/pca3d_pca_mean.csv"

# Read both files as float16 arrays before handing them to the model.
_pca_components = np.loadtxt(pca_components_add, delimiter=',', dtype=np.float16)  # expected [3, 16]
_pca_mean = np.loadtxt(pca_mean_add, delimiter=',', dtype=np.float16)  # expected [16]

pca_model = PCAModel(
    pca_components_freeze=_pca_components,
    pca_mean=_pca_mean,
    device="cuda",
)

from dataloader import image_dataloader


# Loader over the images stored in "eval_images".
_loader_options = dict(
    data_dir="eval_images",
    batch_size=4,   # adjust to the available GPU memory
    shuffle=True,
    num_workers=4,  # adjust as needed
)
train_loader = image_dataloader(**_loader_options)

# Seed torch's default RNG; manual_seed returns that default generator object,
# which is passed explicitly to the pipeline calls in later cells.
generator = torch.manual_seed(42)
from tqdm import tqdm

device = "cuda"

# Assemble the pipeline from the VAE and PCA model loaded above.
_pipeline_options = {
    "vae": vae,
    "pca_model": pca_model,
    "residual_detail": True,  # whether to use the residual detail predictor
    "device": device,
}
pipe = PCAPipeline(**_pipeline_options)

# Switch the pipeline to eval mode before running inference.
pipe.eval()

# Directory prefix used for the figure paths passed as `save_path` further down.
vis_add = 'vis/'
# exist_ok=True creates the directory when missing and is a no-op otherwise,
# without the race inherent in a separate exists() check.
os.makedirs(vis_add, exist_ok=True)

In [None]:
# Load the pipeline state from a local checkpoint file.
# NOTE(review): this rebinds `ckpt_path` (previously set to a GGUF URL in the
# setup cell) and hard-codes an absolute, machine-specific path.
ckpt_path = "/workspace/VAE_SLIM/ckpt/pca_pipeline_20251120_105700.pth"
pipe.load(ckpt_path)

In [None]:
# Pull one batch from the loader, move it to `device` and cast it to bfloat16
# (the dtype the VAE was loaded with).
test_batch = next(iter(train_loader)).to(device).bfloat16()

In [None]:
# Count model parameters.
def count_parameters(model):
    """Return the total number of elements across all parameters of ``model``.

    ``model`` must expose a ``parameters()`` iterable whose items provide
    ``numel()``. Every parameter is counted; there is no filtering (for
    example on ``requires_grad``). Returns 0 when there are no parameters.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
# Print the parameter count of each pipeline component, one line per component.
# Each sub-module is looked up lazily, right before its line is printed, so the
# components are resolved in the same order as the output.
for _label, _get_module in (
    ("Model parameters of encoder:", lambda: pipe.vae.encoder),
    ("Model parameters of decoder:", lambda: pipe.vae.decoder),
    ("Model parameters of pca_predictor:", lambda: pipe.pca_predictor),
    ("Model parameters of residual_detail_predictor:", lambda: pipe.residual_detail_predictor),
):
    print(_label, count_parameters(_get_module()))


In [None]:
from utils import plot_images, batch_edge_analysis, comprehensive_rgb_edge_analysis, rgb_color_gradient_vis
# Display the input batch; it is cast from bfloat16 to float32 before plotting.
plot_images(test_batch.float())

In [None]:
# Run the pipeline on the test batch with the seeded generator and unpack the
# three returned values, then display the first of them (`recon`).
# NOTE(review): judging by the plot titles used in the next cell, `recon` is
# the predicted reconstruction and `x_recon` the PCA reconstruction — confirm.
recon, z_pca_pred, x_recon = pipe.generate_for_comparsion(test_batch, generator=generator, x_recon=True)
plot_images(recon.float())

In [None]:
# NOTE(review): this repeats the call made in the previous cell with the same
# generator object; if the pipeline draws from it, `recon` here may differ from
# the one plotted above. Confirm the recomputation is intentional.
recon, z_pca_pred, x_recon = pipe.generate_for_comparsion(
    test_batch,
    generator=generator,
    x_recon=True,
)

# Sample 0 side by side: input image, `recon` and `x_recon`.
_comparison_images = [test_batch[0].float(), recon[0].float(), x_recon[0].float()]
_comparison_titles = ["Original", "Pred Recon", "PCA Recon"]
plot_images(_comparison_images, titles=_comparison_titles)

In [None]:
# Fetch the three latent tensors compared in the next cell (see the notes there
# for what each one is expected to hold).
z_x, z_pca_pred, z_pred = pipe.diff_between_pca_and_latents(test_batch, generator=generator)

In [None]:
# Analyze the distribution of the latent space.
# z_x:        B x 16 x 64 x 64, latents of the original images
# z_pca_pred: B x 16 x 64 x 64, latents from the PCA prediction
# z_pred:     B x 16 x 64 x 64, latents from the PCA prediction plus the high-frequency prediction

# Round-trip the original latents through the pipeline's PCA transform and its inverse.
z_pca_x = pipe.pca_transform_batch(z_x)
z_pca_x_recon = pipe.pca_inverse_transform_batch(z_pca_x)

# Element-wise gap between the predicted PCA latents and the PCA round-trip of the original latents.
pca_diff = z_pca_pred - z_pca_x_recon

# Cell output: (mean, std) of that gap.
pca_diff.mean(), pca_diff.std()


In [None]:
# Largest element of the PCA latent difference.
# The original cell referenced `diff`, which is never defined in this notebook
# and raises NameError on a fresh kernel; `pca_diff` is the only difference
# tensor in scope. NOTE(review): confirm `pca_diff` is the intended tensor.
pca_diff.max()

In [None]:
# Edge analysis over the whole batch: inputs vs. `recon`, both cast to float32.
batch_results = batch_edge_analysis(test_batch.float(), recon.float())

In [None]:
# RGB edge analysis for sample 1: input image vs. `recon`; the output path under `vis_add` is passed as save_path.
result = comprehensive_rgb_edge_analysis(test_batch[1].float(), recon[1].float(), save_path=vis_add + "rgb_analysis_between_x&pred_recon.png")

In [None]:
# Quick range check: largest value in sample 1 of `recon`.
recon[1].float().max()

In [None]:
# Same analysis as for x vs. pred_recon, with the two arguments swapped (`recon` first, input second).
# NOTE(review): `result` is rebound by every analysis cell; only the latest run is kept.
result = comprehensive_rgb_edge_analysis(recon[1].float(), test_batch[1].float(), save_path=vis_add + "rgb_analysis_between_pred_recon&x.png")

In [None]:
# RGB edge analysis for sample 1: input image vs. `x_recon`.
result = comprehensive_rgb_edge_analysis(test_batch[1].float(), x_recon[1].float(), save_path=vis_add + "rgb_analysis_between_x&pca_recon.png")

In [None]:
# RGB edge analysis for sample 1: `recon` vs. `x_recon`.
result = comprehensive_rgb_edge_analysis(recon[1].float(), x_recon[1].float(), save_path=vis_add + "rgb_analysis_between_pred_recon&pca_recon.png")

In [None]:
# Gradient visualisation of the input batch, requested with the 'sobel' method,
# magnitude output and normalisation enabled.
grad_map = rgb_color_gradient_vis(test_batch.float(), method='sobel', return_magnitude=True, normalized=True)

In [None]:
# Quick range check: largest value in the input batch.
test_batch.max()

In [None]:
# Plot up to four gradient maps. The original indexed entries 0..3
# unconditionally, which raises IndexError whenever the batch holds fewer than
# four items (a smaller batch_size, or fewer images on disk). Capping at
# len(grad_map) keeps the output identical for batches of four or more.
_num_shown = min(4, len(grad_map))
plot_images(
    [grad_map[i].float() for i in range(_num_shown)],
    titles=[f"Grad Map {i + 1}" for i in range(_num_shown)],
)