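# %% [markdown]
# # AutoTrainAI 基础使用示例
# ## Basic Usage Example
# 流程：加载配置 → 生成并预处理示例数据 → 构建自编码器 → 训练 → 绘制损失曲线 → 保存模型权重。
# Workflow: load config → generate & preprocess demo data → build the AutoEncoder → train → plot loss curves → save model weights.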

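# %% [markdown]
# ### 1. 环境准备 / Environment Setup
# 安装依赖库（取消注释后运行一次；`%pip` 会安装到当前内核所用的环境）| Install dependencies (uncomment and run once; `%pip` installs into the environment of the running kernel):
# %pip install -r requirements.txt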

# %% [markdown]
# ### 2. 配置加载 / Load Configuration
# %%
import yaml
from addict import Dict

# Path of the YAML config, relative to the notebook's working directory.
# NOTE(review): assumes the notebook is run from a directory that sits next to
# `configs/` — confirm against the repository layout.
CONFIG_PATH = '../configs/default.yaml'

# Read the config explicitly as UTF-8: without `encoding`, open() uses the locale
# default (e.g. cp936 on Chinese-locale Windows), which can fail to decode a UTF-8 file.
# addict.Dict gives attribute access to nested keys (cfg.model.input_dim, ...).
with open(CONFIG_PATH, encoding='utf-8') as f:
    cfg = Dict(yaml.safe_load(f))

# allow_unicode=True prints non-ASCII values as-is instead of \u escapes;
# sort_keys=False keeps the key order from the file.
print("当前配置:\n", yaml.dump(cfg.to_dict(), allow_unicode=True, sort_keys=False))

# %% [markdown]
# ### 3. 数据预处理 / Data Preprocessing
# %%
import numpy as np
from pathlib import Path
from autotrain.utils.preprocess import DataProcessor

# Demo-data settings.
N_SAMPLES = 1000  # number of random rows to generate
SEED = 42         # fixed seed so Restart & Run All reproduces the same data

# Generate demo data | 生成示例数据 (seeded generator for reproducibility).
rng = np.random.default_rng(SEED)
raw_data = rng.random((N_SAMPLES, cfg.model.input_dim))

# Make sure the target directory exists before writing.
input_path = Path(cfg.data.input_path)
input_path.parent.mkdir(parents=True, exist_ok=True)

# Write through an open file handle: np.save(<str path>) appends ".npy" when the path
# lacks that suffix, which would make load_data() below read a different file.
# NOTE(review): this overwrites whatever already exists at cfg.data.input_path —
# point the config at a scratch location if it references real data.
with open(input_path, 'wb') as f:
    np.save(f, raw_data)

# Initialise the processor with the preprocessing switches from the config.
processor = DataProcessor(
    normalize=cfg.data.preprocessing.normalize,
    augment=cfg.data.preprocessing.augment
)

# Load and preprocess the data that was just written.
dataset = processor.load_data(cfg.data.input_path)
print("处理后的数据形状:", dataset.shape)

# %% [markdown]
# ### 4. 模型初始化 / Model Initialization
# %%
import torch
from autotrain.core import AutoEncoder

# Seed torch (CPU and CUDA generators) so the model's initial weights are reproducible.
torch.manual_seed(42)

# Check the hardware: use the configured device when CUDA is available, otherwise CPU.
# NOTE: when CUDA is absent, any configured device (including e.g. 'mps') is replaced by CPU.
device = torch.device(cfg.hardware.device if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Build the autoencoder from the model section of the config and move it to the device.
model = AutoEncoder(
    input_dim=cfg.model.input_dim,
    latent_dim=cfg.model.latent_dim,
    encoder_dims=cfg.model.encoder_layers,
    decoder_dims=cfg.model.decoder_layers
).to(device)

# %% [markdown]
# ### 5. 训练流程 / Training Process
# %%
from autotrain.utils.trainer import SelfTrainer

# Collect the trainer's inputs: the model, the `train` section of the config,
# the target device and the logging directory.
trainer_kwargs = dict(
    model=model,
    train_cfg=cfg.train,
    device=device,
    log_dir=cfg.logging.log_dir,
)
trainer = SelfTrainer(**trainer_kwargs)

# Run training on the preprocessed dataset; the returned `history` is indexed by
# 'train_loss' / 'val_loss' in the visualization step.
history = trainer.fit(dataset)

# %% [markdown]
# ### 6. 结果可视化 / Visualization
# %%
import matplotlib.pyplot as plt

# Plot the loss curves returned by trainer.fit(). Epochs are numbered from 1 so the
# x axis reads "epoch 1 … N" rather than starting at 0.
train_loss = history['train_loss']

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_loss) + 1), train_loss, label='Training Loss')

# Guard: a history without validation losses (e.g. no validation split) would
# otherwise raise KeyError here. NOTE(review): confirm when SelfTrainer omits this key.
if 'val_loss' in history:
    val_loss = history['val_loss']
    ax.plot(range(1, len(val_loss) + 1), val_loss, label='Validation Loss')

ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Progress')
ax.legend()
plt.show()

# %% [markdown]
# ### 7. 保存模型 / Save Model
# %%
from pathlib import Path

# torch.save does not create missing directories, so create the checkpoint
# directory (and any parents) first; exist_ok keeps re-runs idempotent.
checkpoint_dir = Path(cfg.logging.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Save only the model weights (state_dict), not the whole module object.
torch.save(model.state_dict(), checkpoint_dir / "final_model.pth")
print(f"模型已保存至 {cfg.logging.checkpoint_dir}")