# Diffusion model training

#### This notebook launches the training of the main diffusion model. It does not train the classifier and regressor that are used to perform *classifier* and *regressor guidance*. The three models (diffusion model, regressor and classifier) are trained independently of one another.

In [None]:
import torch as th
import numpy as np
import matplotlib.pyplot as plt
import os
import sys

The environment variable 'TOPODIFF_LOGDIR' defines the directory where the logs and model checkpoints will be saved.

In [None]:
# Point the TOPODIFF_LOGDIR environment variable at the checkpoint/log folder.
os.environ.update({'TOPODIFF_LOGDIR': './checkpoints/3d_diff_logdir'})

The 'TRAIN_FLAGS', 'MODEL_FLAGS', 'DIFFUSION_FLAGS' and 'DATA_FLAGS' respectively set the training parameters, the model hyperparameters, the diffusion hyperparameters and the directory that contains the training data.

The default values indicated below correspond to the hyperparameters indicated in the Appendix to the paper.

In [None]:
# Each flag string is assembled from one "--option value" entry per list item
# and joined with single spaces into one command-line fragment.
TRAIN_FLAGS = " ".join([
    "--batch_size 4",
    "--save_interval 10000",
    "--use_fp16 True",
    "--microbatch 2",
])
MODEL_FLAGS = " ".join([
    "--image_size 64",
    "--num_channels 64",
    "--num_res_blocks 2",
    "--learn_sigma True",
    "--dropout 0.3",
    "--use_checkpoint True",
])
DIFFUSION_FLAGS = " ".join([
    "--diffusion_steps 1000",
    "--noise_schedule cosine",
])

In order to run the training, make sure you have placed the data folder at the root of this directory.

All the images, physical fields, and load arrays must be placed together in the same folder (this is already the case in the data directory that we provide).

In [None]:
# Command-line fragment that passes the data folder through --data_dir.
DATA_FLAGS = " ".join(["--data_dir", "/home/yeoneung/Euihyun/3D_TPMS_topoDIff/data"])

In [None]:
# Command-line fragment carrying the --dims and --volume_size options.
VOLUME_FLAGS = " ".join(["--dims 3", "--volume_size 64"])

In [None]:
# Run the training script; IPython replaces each $NAME with the value of the
# Python variable of that name, so all five flag strings end up on its
# command line.
%run scripts/image_train.py $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS $DATA_FLAGS $VOLUME_FLAGS

In [1]:
import os
# Folder that receives the logs and checkpoints of this run.
os.environ['TOPODIFF_LOGDIR'] = './checkpoints/3d_diff_logdir'

# No --dims flag is passed here because dims=3 is hard-coded in UNetModel.
TRAIN_FLAGS = "--batch_size 2 --save_interval 10000 --use_fp16 True"
MODEL_FLAGS = "--image_size 64 --num_channels 32 --num_res_blocks 2 --learn_sigma True --dropout 0.1 --use_checkpoint True"
DIFFUSION_FLAGS = "--diffusion_steps 1000 --noise_schedule cosine"
DATA_FLAGS = "--data_dir /home/yeoneung/Euihyun/3D_TPMS_topoDIff/data"

# Run the training script; IPython substitutes each $NAME with the flag string.
%run scripts/image_train.py $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS $DATA_FLAGS

Logging to ./checkpoints/3d_diff_logdir
creating model and diffusion...
creating data loader...
training...
----------------------------
| grad_norm     | 4.61     |
| lg_loss_scale | 20       |
| loss          | 1        |
| mse           | 0.998    |
| param_norm    | 98.3     |
| samples       | 2        |
| step          | 0        |
| vb            | 0.00562  |
----------------------------
saving model 0...
saving model 0.9999...
----------------------------
| grad_norm     | 4.67     |
| lg_loss_scale | 20       |
| loss          | 0.955    |
| mse           | 0.933    |
| param_norm    | 98.3     |
| samples       | 22       |
| step          | 10       |
| vb            | 0.0222   |
----------------------------
----------------------------
| grad_norm     | 4.65     |
| lg_loss_scale | 20       |
| loss          | 0.803    |
| mse           | 0.796    |
| param_norm    | 98.3     |
| samples       | 42       |
| step          | 20       |
| vb            | 0.00679  |
----------

KeyboardInterrupt: 

In [1]:
import os
# Use a separate folder so this run's logs and checkpoints do not mix with
# those written to 3d_diff_logdir.
os.environ['TOPODIFF_LOGDIR'] = './checkpoints/3d_diff_logdir2'

# Spell out every training option explicitly, one option per line.
# NOTE(review): --resume_checkpoint "" presumably means "start from scratch";
# confirm against the training script.
TRAIN_FLAGS = """
--batch_size 2 
--save_interval 1000 
--use_fp16 True 
--lr 5e-5 
--weight_decay 0.01
--ema_rate 0.9999
--log_interval 10
--microbatch 1
--schedule_sampler uniform
--resume_checkpoint ""
"""

MODEL_FLAGS = "--image_size 64 --num_channels 32 --num_res_blocks 2 --learn_sigma True --dropout 0.1 --use_checkpoint True"
DIFFUSION_FLAGS = "--diffusion_steps 1000 --noise_schedule cosine"
DATA_FLAGS = "--data_dir /home/yeoneung/Euihyun/3D_TPMS_topoDIff/data"

# Run the training script; IPython substitutes each $NAME with the flag string.
%run scripts/image_train.py $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS $DATA_FLAGS

Logging to ./checkpoints/3d_diff_logdir2
creating model and diffusion...
creating data loader...
training...
----------------------------
| grad_norm     | 7.33     |
| lg_loss_scale | 20       |
| loss          | 1.01     |
| mse           | 0.998    |
| param_norm    | 98.3     |
| samples       | 2        |
| step          | 0        |
| vb            | 0.0116   |
----------------------------
saving model 0...
saving model 0.9999...
----------------------------
| grad_norm     | 9.33     |
| lg_loss_scale | 20       |
| loss          | 0.982    |
| mse           | 0.97     |
| param_norm    | 98.3     |
| samples       | 22       |
| step          | 10       |
| vb            | 0.0129   |
----------------------------
Found NaN, decreased lg_loss_scale to 19.01700000000002
----------------------------
| grad_norm     | 9.72     |
| lg_loss_scale | 19.7     |
| loss          | 0.967    |
| mse           | 0.908    |
| param_norm    | 98.3     |
| samples       | 42       |
| step     

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import re
import numpy as np
from collections import defaultdict

def parse_log_file(log_path):
    """Parse a training log and return every logged metric as a series.

    The log is scanned line by line. A run of consecutive lines that start
    and end with ``|`` forms one record; any other line (separator rule,
    free text, blank line) or the end of the file closes the record. Inside
    a record, rows shaped like ``| key | value |`` are read as metrics.
    Records without a ``step`` row are ignored.

    Args:
        log_path: Path of the text log file.

    Returns:
        dict mapping each metric name to a list of floats with one entry per
        kept record. Every list has the same length as ``result['step']``;
        a metric that is absent from a record, or whose value cannot be
        converted to ``float``, is stored as NaN for that record so that all
        series stay aligned with the step axis. Values such as ``nan`` and
        ``inf`` are parsed as floats.
    """
    # Key is a single word; the value is the next token free of spaces/pipes.
    row_pattern = re.compile(r'^\|\s*(\w+)\s*\|\s*([^\s|]+)\s*\|')
    nan = float('nan')

    records = []
    current = {}
    in_table = False

    # Reading line by line means a missing final newline cannot drop a row.
    with open(log_path, 'r', encoding='utf-8', errors='replace') as f:
        for raw_line in f:
            line = raw_line.strip()
            if line.startswith('|') and line.endswith('|') and len(line) > 1:
                in_table = True
                match = row_pattern.match(line)
                if match:
                    key, value = match.groups()
                    try:
                        current[key] = float(value)
                    except ValueError:
                        # Unparsable value: leave the key out, padded as NaN below.
                        pass
                continue

            # Any non-table line closes the record that was being collected.
            if in_table:
                records.append(current)
                current = {}
                in_table = False

    if in_table:
        records.append(current)

    metrics = {}
    kept = 0
    for record in records:
        # Only records that carry a step number are kept.
        if 'step' not in record:
            continue

        # A metric seen for the first time is back-filled for earlier records.
        for key in record:
            if key not in metrics:
                metrics[key] = [nan] * kept

        for key, series in metrics.items():
            series.append(record.get(key, nan))
        kept += 1

    return metrics

def plot_training_metrics(log_path, save_path='training_metrics.png'):
    """Plot every metric found in a training log against the step number.

    One subplot is drawn per metric (at most three per row). Each subplot
    shows the curve plus a small box with the mean and the final value.
    After the figure is shown, a per-metric summary (initial value, final
    value, relative change) is printed.

    Args:
        log_path: Path of the log file, forwarded to ``parse_log_file``.
        save_path: Where the figure is written as an image. Defaults to
            ``'training_metrics.png'``; pass ``None`` or ``''`` to skip
            saving.

    Returns:
        None. Returns early, after printing a message, when the log has no
        step information or no plottable metric.
    """
    metrics = parse_log_file(log_path)

    if not metrics or 'step' not in metrics:
        print("❌ 로그 파일에서 step 정보를 찾을 수 없습니다.")
        return

    steps = metrics['step']

    # Keep metrics that have one numeric value per step; anything else cannot
    # be drawn against the step axis or averaged.
    available_metrics = [
        k for k in metrics.keys()
        if k != 'step'
        and len(metrics[k]) == len(steps)
        and all(isinstance(v, (int, float)) for v in metrics[k])
    ]

    print(f"📊 발견된 메트릭들: {available_metrics}")
    print(f"📈 총 {len(steps)} 스텝의 데이터")

    n_metrics = len(available_metrics)
    if n_metrics == 0:
        print("❌ 플롯할 메트릭이 없습니다.")
        return

    # Grid layout: up to three columns, as many rows as needed.
    n_cols = min(3, n_metrics)
    n_rows = (n_metrics + n_cols - 1) // n_cols

    # Explicit figure/axes objects instead of pyplot's implicit "current" ones.
    fig = plt.figure(figsize=(15, 5*n_rows))

    for i, metric in enumerate(available_metrics, 1):
        ax = fig.add_subplot(n_rows, n_cols, i)
        values = metrics[metric]
        # Statistics ignore NaN/inf entries so a single gap does not hide them.
        finite = [v for v in values if np.isfinite(v)]

        ax.plot(steps, values, 'b-', linewidth=2, alpha=0.8)
        ax.set_title(f'{metric.replace("_", " ").title()}', fontsize=14, fontweight='bold')
        ax.set_xlabel('Step')
        ax.set_ylabel(metric)
        ax.grid(True, alpha=0.3)

        # Switch to a log axis when a loss-like metric spans large values.
        if metric in ['loss', 'vb'] and finite and max(finite) > 10:
            ax.set_yscale('log')
            ax.set_ylabel(f'{metric} (log scale)')

        mean_val = np.mean(finite) if finite else float('nan')
        final_val = values[-1] if values else 0
        ax.text(0.02, 0.98, f'Mean: {mean_val:.3f}\nFinal: {final_val:.3f}',
                transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Textual summary: first value, last value and relative change.
    print("\n📋 훈련 요약:")
    print("-" * 40)
    for metric in available_metrics:
        values = metrics[metric]
        if values:
            initial = values[0]
            final = values[-1]
            # A zero initial value would divide by zero; report 0% instead.
            change = ((final - initial) / initial * 100) if initial != 0 else 0
            print(f"{metric:12}: {initial:8.4f} → {final:8.4f} ({change:+6.1f}%)")

def analyze_learning_progress(log_path):
    """Print a short diagnosis of the loss curve stored in a training log.

    The last (up to) ten logged loss values are fitted with a straight line.
    The slope per log record drives the verdict thresholds; the printed
    slope is converted to "per training step" using the spacing of the
    logged step numbers. Nothing is printed when six or fewer finite loss
    values are available.

    Args:
        log_path: Path of the log file, forwarded to ``parse_log_file``.

    Returns:
        None. All results are printed.
    """
    metrics = parse_log_file(log_path)

    if 'loss' not in metrics or 'step' not in metrics:
        print("❌ loss 또는 step 정보가 없습니다.")
        return

    steps = np.asarray(metrics['step'], dtype=float)
    loss = np.asarray(metrics['loss'], dtype=float)

    # Non-finite losses would make the line fit fail; drop them, keeping the
    # step axis aligned whenever both series have the same length.
    if len(steps) == len(loss):
        keep = np.isfinite(loss)
        steps, loss = steps[keep], loss[keep]
    else:
        loss = loss[np.isfinite(loss)]

    if len(loss) > 5:
        recent_steps = min(10, len(loss))
        recent_loss = loss[-recent_steps:]
        # Slope per log record; the thresholds below are applied to this value.
        loss_trend = np.polyfit(range(recent_steps), recent_loss, 1)[0]

        # Records are written every few training steps, so divide by the mean
        # step spacing of the window to obtain the slope per training step.
        window = steps[-recent_steps:]
        span = (window[-1] - window[0]) if len(window) > 1 else 0
        steps_per_record = span / (len(window) - 1) if span > 0 else 1.0
        trend_per_step = loss_trend / steps_per_record

        print(f"\n🔍 학습 상태 분석:")
        # Report the last logged step number; the record count is shown separately.
        print(f"- 총 스텝: {steps[-1]:.0f} (로그 {len(steps)}개)")
        print(f"- 초기 손실: {loss[0]:.4f}")
        print(f"- 현재 손실: {loss[-1]:.4f}")
        print(f"- 최근 로그 {recent_steps}개 트렌드: {'⬇️ 감소' if loss_trend < 0 else '⬆️ 증가'} ({trend_per_step:.6f}/step)")

        # Convergence verdict, only with more than 20 records.
        if len(loss) > 20:
            recent_var = np.var(recent_loss)
            if recent_var < 0.1 and abs(loss_trend) < 0.01:
                print("✅ 모델이 수렴하고 있습니다!")
            elif loss_trend > 0.1:
                print("⚠️ 손실이 증가하고 있습니다. 학습률을 낮춰보세요.")
            else:
                print("🔄 모델이 계속 학습 중입니다.")

# Run the log analysis: plot the metrics, then print the loss diagnosis.
if __name__ == "__main__":
    log_path = (
        "/home/yeoneung/Euihyun/3D_TPMS_topoDIff/topodiff"
        "/checkpoints/3d_diff_logdir/log.txt"
    )

    print("🚀 훈련 로그 분석 시작...")

    plot_training_metrics(log_path)
    analyze_learning_progress(log_path)

In [None]:
import os
import numpy as np

def convert_npz_keys_to_arr0(directory):
    """Rewrite ``.npz`` files so their array is stored under the key ``arr_0``.

    Every file in ``directory`` whose name ends with ``.npz`` is inspected:

    * if it contains ``surface_field``, the file is replaced by an archive
      holding only that array under the key ``arr_0`` (any other array in
      the original archive is not carried over);
    * otherwise, if it already contains ``arr_0``, it is left untouched;
    * otherwise a warning is printed and the file is left untouched.

    Errors are reported per file and do not stop the loop.

    Args:
        directory: Folder to scan (not recursive).

    Returns:
        None. Progress is printed, one line per ``.npz`` file.
    """
    # Sorted so that the printed report has a stable order between runs.
    for fname in sorted(os.listdir(directory)):
        if not fname.endswith(".npz"):
            continue

        path = os.path.join(directory, fname)
        try:
            # The context manager closes the archive before the file is
            # replaced; the array is fully read into memory inside it.
            with np.load(path) as data:
                keys = list(data.files)
                field = data['surface_field'] if 'surface_field' in keys else None

            if field is not None:
                # Write next to the target, then swap atomically, so a failed
                # write can never leave a truncated data file behind.
                tmp_path = path + ".tmp"
                try:
                    with open(tmp_path, "wb") as fh:
                        np.savez(fh, arr_0=field)
                    os.replace(tmp_path, path)
                finally:
                    if os.path.exists(tmp_path):
                        os.remove(tmp_path)
                print(f"[변환 완료] {fname}")
            elif 'arr_0' in keys:
                print(f"[스킵] 이미 arr_0 있음: {fname}")
            else:
                print(f"[경고] 유효한 key 없음: {fname}")
        except Exception as e:
            # Best effort: report the failing file and continue with the rest.
            print(f"[에러] {fname}: {e}")

# Example usage: converts the .npz files of the data folder in place.
convert_npz_keys_to_arr0("/home/yeoneung/Euihyun/3D_TPMS_topoDIff/data")


By the end of training, the log directory (the one set through `TOPODIFF_LOGDIR`) should contain a series of checkpoints. You can then use the last checkpoint as the diffusion model when sampling from TopoDiff (see the notebook **4_TopoDiff_sample**).