In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

# --- 1. 모델 클래스 정의 (이전과 동일) ---
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1), nn.BatchNorm2d(out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.BatchNorm2d(out_channels))
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.silu = nn.SiLU()
    def forward(self, x): return self.silu(self.block(x) + self.shortcut(x))

class EncoderSuperDeep(nn.Module):
    def __init__(self, in_channels=3, base_channels=128, latent_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(in_channels, base_channels, 3, 1, 1), ResBlock(base_channels, base_channels), nn.Conv2d(base_channels, base_channels*2, 3, 2, 1), ResBlock(base_channels*2, base_channels*2), nn.Conv2d(base_channels*2, base_channels*4, 3, 2, 1), ResBlock(base_channels*4, base_channels*4), nn.Conv2d(base_channels*4, base_channels*8, 3, 2, 1), ResBlock(base_channels*8, base_channels*8), nn.Conv2d(base_channels*8, base_channels*16, 3, 2, 1), ResBlock(base_channels*16, base_channels*16), nn.Conv2d(base_channels*16, 2 * latent_channels, 3, 1, 1))
    def forward(self, x):
        x = self.encoder(x)
        mu, log_var = torch.chunk(x, 2, dim=1)
        return mu, log_var

class DecoderSuperDeep(nn.Module):
    def __init__(self, out_channels=3, base_channels=128, latent_channels=1):
        super().__init__()
        self.decoder = nn.Sequential(nn.Conv2d(latent_channels, base_channels*16, 3, 1, 1), ResBlock(base_channels*16, base_channels*16), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*16, base_channels*8, 3, 1, 1), ResBlock(base_channels*8, base_channels*8), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*8, base_channels*4, 3, 1, 1), ResBlock(base_channels*4, base_channels*4), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*4, base_channels*2, 3, 1, 1), ResBlock(base_channels*2, base_channels*2), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*2, base_channels, 3, 1, 1), ResBlock(base_channels, base_channels), nn.Conv2d(base_channels, out_channels, 3, 1, 1), nn.Tanh())
    def forward(self, z): return self.decoder(z)

class InterpolationNet(nn.Module):
    def __init__(self, latent_dim=256, hidden_dim=512):
        super().__init__()
        input_dim = latent_dim * 2 + 1
        self.network = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim // 2, latent_dim))
    def forward(self, z_a, z_b, alpha):
        net_input = torch.cat([z_a, z_b, alpha], dim=1)
        return self.network(net_input)

class PathInterpolationDataset(Dataset):
    def __init__(self, latent_vectors_path, num_samples, num_alphas_per_pair=10):
        self.latent_vectors = torch.load(latent_vectors_path).to('cpu')
        self.num_latents = len(self.latent_vectors)
        self.num_samples = num_samples
        self.num_alphas = num_alphas_per_pair
    def __len__(self):
        return self.num_samples
    def __getitem__(self, idx):
        idx1, idx2 = torch.randperm(self.num_latents)[:2]
        z_a = self.latent_vectors[idx1]
        z_b = self.latent_vectors[idx2]
        alphas = torch.rand(self.num_alphas, 1)
        return z_a, z_b, alphas, idx1, idx2
        
# --- 2. 기본 설정 및 경로 ---
print("Step 1: 기본 설정 로드")
DEVICE = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
AE_CHECKPOINT_DIR = '/home/nas/data/YMG/superdeep_ae/checkpoints/'
ENCODER_PATH = os.path.join(AE_CHECKPOINT_DIR, 'encoder_superdeep_best.pth')
DECODER_PATH = os.path.join(AE_CHECKPOINT_DIR, 'decoder_superdeep_best.pth')
LATENT_VECTORS_PATH = '/home/nas/data/YMG/superdeep_ae/my_checkpoints/real_latent_vectors_20k.pt'
OUTPUT_DIR = '/home/nas/data/YMG/superdeep_ae/interpolation_network_final_sequential/'
os.makedirs(OUTPUT_DIR, exist_ok=True)
CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, 'interpolation_net_best.pth')

# --- 하이퍼파라미터 ---
LATENT_DIM = 256
HIDDEN_DIM = 512
BATCH_SIZE = 64 # 원래대로 64 또는 32 유지
EPOCHS = 100
LR = 1e-4
SAMPLES_PER_EPOCH = 20000
NUM_ALPHAS_PER_PAIR = 10
ATTRACTION_WEIGHT = 0.1

# --- 3. 모델 로드 및 설정 ---
# (이전과 동일)
print("Step 2: 모델 로드 및 설정")
encoder = EncoderSuperDeep(base_channels=128).to(DEVICE); encoder.eval()
decoder = DecoderSuperDeep(base_channels=128).to(DEVICE); decoder.eval()
for param in encoder.parameters(): param.requires_grad = False
for param in decoder.parameters(): param.requires_grad = False
print("VAE 모델을 로드하고 가중치를 동결했습니다.")

interpolation_net = InterpolationNet(latent_dim=LATENT_DIM, hidden_dim=HIDDEN_DIM).to(DEVICE)
optimizer = optim.Adam(interpolation_net.parameters(), lr=LR)
print("보간 네트워크를 초기화했습니다.")

# --- 4. 데이터 로더 준비 ---
# (이전과 동일)
print("Step 3: 데이터 로더 준비")
train_dataset = PathInterpolationDataset(LATENT_VECTORS_PATH, SAMPLES_PER_EPOCH, NUM_ALPHAS_PER_PAIR)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# --- 5. 학습 루프 ---
print("Step 4: 학습 시작")
best_loss = float('inf')
for epoch in range(EPOCHS):
    interpolation_net.train()
    total_loss = 0
    
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for z_a, z_b, alphas, idx1, idx2 in pbar:
        z_a, z_b, alphas = z_a.to(DEVICE), z_b.to(DEVICE), alphas.to(DEVICE)
        
        # ... (텐서 변환 로직 동일) ...
        z_a_flat = z_a.unsqueeze(1).expand(-1, NUM_ALPHAS_PER_PAIR, -1).reshape(-1, LATENT_DIM)
        z_b_flat = z_b.unsqueeze(1).expand(-1, NUM_ALPHAS_PER_PAIR, -1).reshape(-1, LATENT_DIM)
        alphas_flat = alphas.reshape(-1, 1)
        
        optimizer.zero_grad()
        z_interp = interpolation_net(z_a_flat, z_b_flat, alphas_flat)
        
        # <<< --- 수정: 재투영을 작은 덩어리로 나누어 실행 --- >>>
        z_hat_interp_list = []
        # 예를 들어, 한 번에 64개씩 나누어 처리
        chunk_size = 64 
        with torch.no_grad():
            for i in range(0, z_interp.size(0), chunk_size):
                # z_interp를 chunk_size 만큼씩 잘라냄
                z_chunk = z_interp[i:i+chunk_size]
                z_chunk_reshaped = z_chunk.view(-1, 1, 16, 16)
                
                # 작은 덩어리만 VAE에 통과
                recon_chunk = decoder(z_chunk_reshaped)
                z_hat_chunk_reshaped, _ = encoder(recon_chunk)
                
                z_hat_chunk = z_hat_chunk_reshaped.view(-1, LATENT_DIM)
                z_hat_interp_list.append(z_hat_chunk)
        
        # 나누어 처리된 결과들을 다시 하나로 합침
        z_hat_interp = torch.cat(z_hat_interp_list, dim=0)
        # <<< ------------------------------------------- >>>
        
        # 손실 계산 및 역전파 (이하 동일)
        loss_reprojection = nn.functional.mse_loss(z_interp, z_hat_interp)
        with torch.no_grad():
            z_linear = (1 - alphas_flat) * z_a_flat + alphas_flat * z_b_flat
        loss_attraction = nn.functional.mse_loss(z_interp, z_linear)
        total_loss_batch = loss_reprojection + ATTRACTION_WEIGHT * loss_attraction
        total_loss_batch.backward()
        optimizer.step()
        
        total_loss += total_loss_batch.item()
        pbar.set_postfix({"Total Loss": f"{total_loss_batch.item():.6f}"})
        
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1} 완료 | Avg Total Loss: {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(interpolation_net.state_dict(), CHECKPOINT_PATH)
        # 이 에포크의 쌍 인덱스를 저장하려면 current_epoch_pairs 로직을 다시 추가해야 함
        print(f"New best model found! Loss: {best_loss:.6f}. Checkpoint saved.")

print("\n학습이 모두 완료되었습니다.")

Step 1: 기본 설정 로드
Step 2: 모델 로드 및 설정
VAE 모델을 로드하고 가중치를 동결했습니다.
보간 네트워크를 초기화했습니다.
Step 3: 데이터 로더 준비
Step 4: 학습 시작


  self.latent_vectors = torch.load(latent_vectors_path).to('cpu')


Epoch 1/100:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1 완료 | Avg Total Loss: 0.098884
New best model found! Loss: 0.098884. Checkpoint saved.


Epoch 2/100:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2 완료 | Avg Total Loss: 0.071538
New best model found! Loss: 0.071538. Checkpoint saved.


Epoch 3/100:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3 완료 | Avg Total Loss: 0.070746
New best model found! Loss: 0.070746. Checkpoint saved.


Epoch 4/100:   0%|          | 0/313 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [2]:
import os
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import matplotlib.pyplot as plt
import random

# --- 1. 모든 필요한 클래스 정의 ---
class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, 1, 1), nn.BatchNorm2d(out_channels), nn.SiLU(), nn.Conv2d(out_channels, out_channels, 3, 1, 1), nn.BatchNorm2d(out_channels))
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.silu = nn.SiLU()
    def forward(self, x): return self.silu(self.block(x) + self.shortcut(x))

class EncoderSuperDeep(nn.Module):
    def __init__(self, in_channels=3, base_channels=128, latent_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(nn.Conv2d(in_channels, base_channels, 3, 1, 1), ResBlock(base_channels, base_channels), nn.Conv2d(base_channels, base_channels*2, 3, 2, 1), ResBlock(base_channels*2, base_channels*2), nn.Conv2d(base_channels*2, base_channels*4, 3, 2, 1), ResBlock(base_channels*4, base_channels*4), nn.Conv2d(base_channels*4, base_channels*8, 3, 2, 1), ResBlock(base_channels*8, base_channels*8), nn.Conv2d(base_channels*8, base_channels*16, 3, 2, 1), ResBlock(base_channels*16, base_channels*16), nn.Conv2d(base_channels*16, 2 * latent_channels, 3, 1, 1))
    def forward(self, x):
        x = self.encoder(x)
        mu, log_var = torch.chunk(x, 2, dim=1)
        return mu, log_var

class DecoderSuperDeep(nn.Module):
    def __init__(self, out_channels=3, base_channels=128, latent_channels=1):
        super().__init__()
        self.decoder = nn.Sequential(nn.Conv2d(latent_channels, base_channels*16, 3, 1, 1), ResBlock(base_channels*16, base_channels*16), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*16, base_channels*8, 3, 1, 1), ResBlock(base_channels*8, base_channels*8), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*8, base_channels*4, 3, 1, 1), ResBlock(base_channels*4, base_channels*4), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*4, base_channels*2, 3, 1, 1), ResBlock(base_channels*2, base_channels*2), nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(base_channels*2, base_channels, 3, 1, 1), ResBlock(base_channels, base_channels), nn.Conv2d(base_channels, out_channels, 3, 1, 1), nn.Tanh())
    def forward(self, z): return self.decoder(z)

class InterpolationNet(nn.Module):
    def __init__(self, latent_dim=256, hidden_dim=512):
        super().__init__()
        input_dim = latent_dim * 2 + 1
        self.network = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim, hidden_dim // 2), nn.BatchNorm1d(hidden_dim // 2), nn.LeakyReLU(0.2, inplace=True), nn.Linear(hidden_dim // 2, latent_dim))
    def forward(self, z_a, z_b, alpha):
        net_input = torch.cat([z_a, z_b, alpha], dim=1)
        return self.network(net_input)

class CustomImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.paths = sorted([os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.lower().endswith(('.jpg','.jpeg','.png'))])
        self.transform = transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        if self.transform: img = self.transform(img)
        return img, idx

def denormalize(tensor):
    return (tensor.clamp(-1, 1) * 0.5) + 0.5


def test_final_network(num_tests=5):
    # --- 2. 경로 설정 및 모델 로드 ---
    DEVICE = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
    AE_CHECKPOINT_DIR = '/home/nas/data/YMG/superdeep_ae/checkpoints/'
    ENCODER_PATH = os.path.join(AE_CHECKPOINT_DIR, 'encoder_superdeep_best.pth')
    DECODER_PATH = os.path.join(AE_CHECKPOINT_DIR, 'decoder_superdeep_best.pth')
    DATA_DIR = '/home/nas/data/YMG/datas/celeba_hq_256/'
    LATENT_VECTORS_PATH = '/home/nas/data/YMG/superdeep_ae/my_checkpoints/real_latent_vectors_20k.pt'
    INTERPOLATION_NET_DIR = '/home/nas/data/YMG/superdeep_ae/interpolation_network_final_sequential/' # **수정 필요 시**
    INTERPOLATION_NET_PATH = os.path.join(INTERPOLATION_NET_DIR, 'interpolation_net_best.pth')

    # 모든 모델 로드
    encoder = EncoderSuperDeep(base_channels=128).to(DEVICE); encoder.eval()
    decoder = DecoderSuperDeep(base_channels=128).to(DEVICE); decoder.eval()
    interpolation_net = InterpolationNet(hidden_dim=512).to(DEVICE); interpolation_net.eval()

    encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
    decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))
    interpolation_net.load_state_dict(torch.load(INTERPOLATION_NET_PATH, map_location=DEVICE))
    print("모든 모델(VAE, 최종 보간 네트워크)을 성공적으로 로드했습니다.")
    
    # 원본 이미지 및 잠재 벡터 로드
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
    dataset = CustomImageDataset(DATA_DIR, transform)
    latent_vectors = torch.load(LATENT_VECTORS_PATH).to(DEVICE)
    
    # --- 3. 비교 테스트 수행 ---
    for i in range(num_tests):
        print(f"\n--- Test Pair #{i+1} ---")
        start_idx, end_idx = random.sample(range(20000), 2)
        num_inter_steps = 5 # 원본 사이의 보간 이미지 개수
        total_steps = num_inter_steps + 2
        alphas = torch.linspace(0, 1, total_steps)

        # 원본 이미지 및 잠재 벡터 준비
        img_a_orig, _ = dataset[start_idx]
        img_b_orig, _ = dataset[end_idx]
        mu_a = latent_vectors[start_idx].unsqueeze(0) # shape: [1, 1, 16, 16]
        mu_b = latent_vectors[end_idx].unsqueeze(0) # shape: [1, 1, 16, 16]
        
        # <<< 수정: 디코더에 넣기 전에 항상 4D 텐서로 변환 >>>
        # 선형 보간
        lerp_latents = torch.cat([(1 - alpha) * mu_a + alpha * mu_b for alpha in alphas], dim=0)
        
        # 네트워크 보간
        network_latents = []
        with torch.no_grad():
            for alpha in alphas:
                mu_a_flat = mu_a.view(1, -1)
                mu_b_flat = mu_b.view(1, -1)
                alpha_tensor = torch.tensor([[alpha]], device=DEVICE)
                inter_latent_flat = interpolation_net(mu_a_flat, mu_b_flat, alpha_tensor)
                network_latents.append(inter_latent_flat.view(1, 1, 16, 16))
        network_latents = torch.cat(network_latents, dim=0)

        # 이미지 생성
        with torch.no_grad():
            lerp_images = decoder(lerp_latents).cpu()
            network_images = decoder(network_latents).cpu()

        # 시각화
        fig, axes = plt.subplots(2, total_steps, figsize=(18, 5.5))
        fig.suptitle(f"Final Interpolation Comparison (Pair: {start_idx} & {end_idx})", fontsize=16)
        axes[0, 0].set_ylabel("Linear (Lerp)", fontsize=12)
        axes[1, 0].set_ylabel("Trained Network ($I_\\theta$)", fontsize=12)

        for j in range(total_steps):
            # 1행: 선형 보간 결과
            axes[0, j].imshow(denormalize(lerp_images[j]).permute(1, 2, 0)); axes[0, j].axis('off')
            axes[0, j].set_title(f"α={alphas[j]:.2f}")
            
            # 2행: 네트워크 보간 결과
            axes[1, j].imshow(denormalize(network_images[j]).permute(1, 2, 0)); axes[1, j].axis('off')

        # <<< 수정: 3행 대신 1행의 양 끝을 원본 이미지로 교체 >>>
        axes[0, 0].imshow(denormalize(img_a_orig).permute(1, 2, 0))
        axes[0, 0].set_title(f"Origin #{start_idx}")
        axes[0, -1].imshow(denormalize(img_b_orig).permute(1, 2, 0))
        axes[0, -1].set_title(f"Origin #{end_idx}")
            
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

if __name__ == '__main__':
    test_final_network()

  encoder.load_state_dict(torch.load(ENCODER_PATH, map_location=DEVICE))
  decoder.load_state_dict(torch.load(DECODER_PATH, map_location=DEVICE))
  interpolation_net.load_state_dict(torch.load(INTERPOLATION_NET_PATH, map_location=DEVICE))


모든 모델(VAE, 최종 보간 네트워크)을 성공적으로 로드했습니다.


  latent_vectors = torch.load(LATENT_VECTORS_PATH).to(DEVICE)



--- Test Pair #1 ---


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [7, 256]