In [5]:
import os,sys
import json
import torch

# Seed torch's global RNG before any stochastic step (DataLoader shuffling,
# model initialisation, training) so that Restart & Run All gives repeatable
# results. Only torch is seeded here; if project code also draws from
# `random` or numpy, those would need their own seeds -- TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Make the repository root importable so the project-local `src` package
# resolves. The root is taken to be the parent of the kernel's current working
# directory, so this assumes the notebook is started from a direct
# sub-directory of the repository.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [6]:
# Build the dataset
from torch.utils.data import DataLoader
from src.dataset import build_global_vocab_and_maxcount, CLEVRMultiLabelByImage

clevr_root = "../CLEVR_v1.0"

# Global statistics -- the color list, the shape list and the maximum object
# count -- gathered over the train and val splits. The fourth return value is
# not needed in this notebook.
colors, shapes, max_objects, _ = build_global_vocab_and_maxcount(
    clevr_root,
    splits=("train", "val"),
)

print("num_colors:", len(colors), colors)
print("num_shapes:", len(shapes), shapes)
print("max_objects:", max_objects)

# Training split, labelled with the global vocabularies computed above.
train_ds = CLEVRMultiLabelByImage(
    split="train",
    clevr_root=clevr_root,
    colors=colors,
    shapes=shapes,
    max_objects=max_objects,
)

train_dl = DataLoader(
    dataset=train_ds,
    batch_size=100,
    shuffle=True,
    num_workers=0,
)

# Pull one batch as a sanity check on tensor shapes and labels.
images, color_mh, shape_mh, count_oh, img_fns = next(iter(train_dl))
print("images:", images.shape)
print("color_mh:", color_mh.shape)   # [B, num_colors]
print("shape_mh:", shape_mh.shape)   # [B, num_shapes]
print("count_oh:", count_oh.shape)   # [B, max_objects+1]
print("example fn:", img_fns[0])
print("example count one-hot argmax:", count_oh[0].argmax().item())

num_colors: 8 ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']
num_shapes: 3 ['cube', 'cylinder', 'sphere']
max_objects: 10
images: torch.Size([100, 3, 224, 224])
color_mh: torch.Size([100, 8])
shape_mh: torch.Size([100, 3])
count_oh: torch.Size([100, 11])
example fn: CLEVR_train_066291.png
example count one-hot argmax: 8


In [7]:
# Build the model
from src.model import REVAE_V1

# NOTE(review): REVAE_V1 is constructed with its default arguments. Presumably
# its output sizes have to agree with len(colors), len(shapes) and
# max_objects + 1 from the dataset cell -- confirm against src.model.
revae = REVAE_V1()


In [9]:
from src.train import TrainConfig, fit, evaluate
import torch.optim as optim

# Training configuration, grouped by concern.
cfg = TrainConfig(
    # runtime
    device="cuda" if torch.cuda.is_available() else "cpu",
    use_amp=False,
    # schedule
    epochs=10,
    lr=1e-3,
    # loss: reconstruction type, then weights for the KL term and the
    # color / shape / count loss terms
    recon_loss="bce_logits",   # images were confirmed to be in [0, 1]
    beta_kl=1.0,
    lam_color=1.0,
    lam_shape=1.0,
    lam_count=1.0,
    # logging
    log_every=100,
    # checkpointing
    save_best=False,
    ckpt_dir="../checkpoints",
    ckpt_name="revae_v1.pt",
)

optimizer = optim.Adam(params=revae.parameters(), lr=cfg.lr)

# No validation loader has been built yet, so pass val_loader=None for now.
# NOTE(review): save_best is False and there is no validation loader --
# confirm that result["ckpt_path"] is populated in this configuration.
result = fit(revae, train_dl, optimizer, cfg, val_loader=None)
print(result["ckpt_path"])


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.use_amp and device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(cfg.use_amp and device.type == "cuda")):


[train] epoch 1 step 0/700 total=101901.2031 recon=101822.8750 kl=75.6508 loss_color=0.6492 loss_shape=0.3415 loss_count=1.6900
[train] epoch 1 step 100/700 total=102189.5391 recon=102105.8047 kl=81.1955 loss_color=0.6638 loss_shape=0.3541 loss_count=1.5224
[train] epoch 1 step 200/700 total=101889.0234 recon=101806.1016 kl=80.4592 loss_color=0.6473 loss_shape=0.2287 loss_count=1.5895
[train] epoch 1 step 300/700 total=102091.9453 recon=102008.9297 kl=80.4331 loss_color=0.6521 loss_shape=0.3313 loss_count=1.6084
[train] epoch 1 step 400/700 total=101971.2422 recon=101891.7266 kl=76.7131 loss_color=0.6697 loss_shape=0.3185 loss_count=1.8148
[train] epoch 1 step 500/700 total=101859.3750 recon=101780.4141 kl=76.2682 loss_color=0.6419 loss_shape=0.3026 loss_count=1.7528
[train] epoch 1 step 600/700 total=102046.2578 recon=101961.5547 kl=82.0323 loss_color=0.6451 loss_shape=0.3105 loss_count=1.7133
[train] epoch 2 step 0/700 total=101863.1875 recon=101785.1797 kl=75.2800 loss_color=0.6564 