In [8]:
import torch
import wandb
import numpy as np
from tqdm import trange
from src.models.base_texture_nca import BaseTextureNCA
from src.loss.vgg_loss import VGGLoss
from src.loss.mse_loss import MSELoss
from src.trainer.nca_trainer import NCATrainer
from src.help.utils import set_seed, load_target_image

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Authenticate with Weights & Biases up front, before any run is started.
wandb.login()

True

# config

In [9]:
# 1. Default configuration.
# The whole dict is handed to wandb.init(config=...) by model_pipeline and is
# then read back through wandb.config by setup_case / train / test.
# Keys not read in this notebook (chn, hidden_n, batch_size, step_min,
# step_max, pool_reset_freq, use_gradient_checkpoint) are presumably consumed
# by BaseTextureNCA / NCATrainer, which both receive the config object
# -- TODO confirm against those classes.
default_config = {
    'project_name': 'base_texture_nca',
    'run_name': 'test_mse_baseline',  # [edit here] custom W&B run name
    'loss_type': 'mse',               # [edit here] 'mse' selects MSELoss; any other value selects VGGLoss
    'chn': 12,
    'hidden_n': 96,
    'batch_size': 8,
    'step_min': 32,
    'step_max': 96,
    'pool_reset_freq': 8,
    'pool_size': 1024,  # number of states requested from model.seed() for the pool
    'lr': 2e-3,  # Adam learning rate
    'seed': 42,  # passed to set_seed()
    # NOTE(review): absolute, machine-specific Windows path; setup_case falls
    # back to a random target image when this file is not found.
    'target_img_path': r'C:\Users\GAI\Desktop\Scott\NCA_Research\dataset\lizard.png',
    'img_size': 128,  # used for the target image, the pool and the test rollout
    'use_gradient_checkpoint': False
}

# define logic

## setup case

In [10]:
def setup_case(config):
    """Create every object the training run needs.

    Args:
        config: Attribute-style configuration. This function reads ``seed``,
            ``target_img_path``, ``img_size``, ``loss_type``, ``lr`` and
            ``pool_size``; the whole object is also forwarded to
            ``BaseTextureNCA``.

    Returns:
        A ``(model, optimizer, scheduler, pool, loss_fn)`` tuple.
    """
    # Apply the configured seed before any other object is built.
    set_seed(config.seed)

    # Target image: load it from disk, or substitute uniform noise (with a
    # warning) when the file is missing so the pipeline can still be exercised.
    try:
        target = load_target_image(config.target_img_path, config.img_size, device)
    except FileNotFoundError:
        print(f"Warning: {config.target_img_path} not found. Using random target.")
        target = torch.rand(1, 3, config.img_size, config.img_size).to(device)

    # Loss: only the exact string 'mse' picks the MSE loss; every other value
    # takes the VGG branch.
    wants_mse = config.loss_type == 'mse'
    if not wants_mse:
        print("Using VGG / OT Loss")
        criterion = VGGLoss(target, device=device)
    else:
        print("Using Simple MSE Loss")
        criterion = MSELoss(target)

    # Model, placed on the selected device.
    net = BaseTextureNCA(config).to(device)

    # Adam optimiser plus a step schedule that multiplies the learning rate
    # by 0.3 once scheduler step 2000 is reached.
    adam = torch.optim.Adam(net.parameters(), lr=config.lr)
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        adam,
        milestones=[2000],
        gamma=0.3,
    )

    # Initial state pool, produced by the model's own seed() method.
    state_pool = net.seed(config.pool_size, sz=config.img_size)

    return net, adam, lr_schedule, state_pool, criterion

## train

In [11]:
def train(model, optimizer, scheduler, pool, loss_fn, config, steps=5000, log_every=100):
    """Run the training loop and stream metrics to Weights & Biases.

    Args:
        model: Network to train; forwarded to ``NCATrainer``.
        optimizer: Optimiser forwarded to ``NCATrainer``.
        scheduler: Learning-rate scheduler forwarded to ``NCATrainer``.
        pool: State pool forwarded to ``NCATrainer``.
        loss_fn: Loss object forwarded to ``NCATrainer``.
        config: Configuration object forwarded to ``NCATrainer``.
        steps: Number of training iterations. Defaults to 5000, the value
            that used to be hard-coded in the loop.
        log_every: Attach a preview image to the log every ``log_every``
            iterations (starting with iteration 0). Defaults to 100, the
            previous hard-coded interval. ``0`` or ``None`` disables image
            logging instead of raising ``ZeroDivisionError``.

    Returns:
        None. One ``wandb.log`` call is issued per iteration.
    """
    trainer = NCATrainer(
        model=model,
        pool=pool,
        optimizer=optimizer,
        scheduler=scheduler,
        loss_fn=loss_fn,
        config=config,
        device=device
    )

    for i in trange(steps):
        # One optimisation step; the trainer returns a dict from which
        # 'loss', 'lr' and 'batch_x' are read below.
        stats = trainer.train_step(current_step=i)

        log_dict = {
            "train/loss": stats['loss'],
            "train/lr": stats['lr'],
            "step": i
        }

        # Periodically attach a preview image of the first batch element.
        if log_every and i % log_every == 0:
            # assumes stats['batch_x'][0] is laid out as [C, H, W] -- TODO confirm
            img_tensor = stats['batch_x'][0].detach().cpu()
            # Keep the first three channels as RGB and move channels last: [H, W, 3].
            img_rgb = img_tensor[:3].permute(1, 2, 0).numpy()
            img_rgb = np.clip(img_rgb, 0, 1)

            log_dict["train/image"] = wandb.Image(img_rgb, caption=f"Step {i}")

        wandb.log(log_dict)

## test

In [12]:
def test(model, config):
    """Roll the model forward from a fresh seed and log the rollout as a video.

    A single seed state of side ``config.img_size`` is advanced for 300 model
    steps with gradients disabled. After every step the first three channels
    are clipped to [0, 1], scaled to 8-bit, and stored as one frame. The
    frames are logged to Weights & Biases under ``"test/video"`` as a 30 fps
    mp4.

    Args:
        model: Object providing ``seed(n, sz=...)`` and a callable forward step.
        config: Configuration object; only ``img_size`` is read here.
    """
    print("Testing (Generating Video)...")

    def _as_uint8_frame(state):
        # First batch element, first three channels, channels moved last: [H, W, 3].
        rgb = state[0, :3].permute(1, 2, 0).cpu().numpy()
        rgb = np.clip(rgb, 0, 1)
        # uint8 in [0, 255] so wandb can encode the frames.
        return (rgb * 255).astype(np.uint8)

    with torch.no_grad():
        state = model.seed(1, sz=config.img_size)
        rollout = []

        remaining = 300  # number of evolution steps
        while remaining > 0:
            state = model(state)
            rollout.append(_as_uint8_frame(state))
            remaining -= 1

        # Stack to [T, H, W, C], then reorder to [T, C, H, W] for wandb.Video.
        video = np.stack(rollout, axis=0).transpose(0, 3, 1, 2)
        wandb.log({"test/video": wandb.Video(video, fps=30, format="mp4")})

## pipeline

In [13]:
def model_pipeline(hyperparams):
    """Run setup, training and the test rollout inside a single W&B run.

    Args:
        hyperparams: Mapping of hyperparameters. ``'project_name'`` is
            required; ``'run_name'`` is optional and names the run when
            present. The whole mapping is registered as the run's config.

    Returns:
        The model object produced by ``setup_case`` after training and testing.
    """
    wandb_run_name = hyperparams.get('run_name')
    wandb_project = hyperparams['project_name']

    # Handing the hyperparameters to wandb.init lets wandb own the config.
    with wandb.init(project=wandb_project, config=hyperparams, name=wandb_run_name):
        # Read settings back through wandb.config so that values overridden
        # by a sweep are the ones actually used.
        config = wandb.config

        # 1. Setup
        components = setup_case(config)
        model, optimizer, scheduler, pool, loss_fn = components

        # 2. Train
        train(model, optimizer, scheduler, pool, loss_fn, config)

        # 3. Test
        test(model, config)

    return model

# run case

In [14]:
# Run the whole pipeline (setup -> train -> test) with the default
# configuration and keep the model object that model_pipeline returns.
final_model = model_pipeline(default_config)

Using Simple MSE Loss


 22%|██▏       | 1079/5000 [04:36<16:43,  3.91it/s]
Traceback (most recent call last):
  File "C:\Users\GAI\AppData\Local\Temp\ipykernel_30860\3639408828.py", line 13, in model_pipeline
    train(model, optim, sched, pool, loss_fn, config)
  File "C:\Users\GAI\AppData\Local\Temp\ipykernel_30860\3682625580.py", line 19, in train
    stats = trainer.train_step(current_step=i)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\users\gai\desktop\scott\nca_research\src\trainer\nca_trainer.py", line 102, in train_step
    loss.backward()
  File "c:\Users\GAI\miniconda3\envs\pytorch-py311\Lib\site-packages\torch\_tensor.py", line 647, in backward
    torch.autograd.backward(
  File "c:\Users\GAI\miniconda3\envs\pytorch-py311\Lib\site-packages\torch\autograd\__init__.py", line 354, in backward
    _engine_run_backward(
  File "c:\Users\GAI\miniconda3\envs\pytorch-py311\Lib\site-packages\torch\autograd\graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run

0,1
step,▁▁▁▂▂▃▃▃▄▄▅▅▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇█████
train/loss,▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
step,1078.0
train/loss,0.05061
train/lr,0.002


KeyboardInterrupt: 