In [1]:
import os,sys
import json
import torch

# Seed PyTorch's global RNG once, up front, so that random draws taken from it
# by later cells repeat across "Restart Kernel -> Run All".
# NOTE(review): some CUDA kernels are non-deterministic even with a fixed seed,
# so GPU runs may still differ slightly between executions.
SEED = 42
torch.manual_seed(SEED)

# Resolve the parent of the notebook's working directory and make it importable,
# so that the `src` package used by later cells can be found.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
# Build the datasets and data loaders
from torch.utils.data import DataLoader,Subset
from src.dataset import build_global_vocab_and_maxcount, CLEVRMultiLabelByImage
import torchvision.transforms as T


clevr_root = "../CLEVR_v1.0"
BATCH_SIZE = 100  # mini-batch size shared by both training loaders

# Global statistics (colors / shapes / max_objects), gathered over the splits
# passed below, i.e. train and val only.
# NOTE(review): an earlier version of this comment also listed "test", but the
# test split is not passed to the call — confirm that excluding it is intended.
colors, shapes, max_objects, _ = build_global_vocab_and_maxcount(clevr_root, splits=("train","val"))

print("num_colors:", len(colors), colors)
print("num_shapes:", len(shapes), shapes)
print("max_objects:", max_objects)


def _resize_to_tensor(side):
    """Return a transform that resizes an image to ``side`` x ``side`` and converts it to a tensor."""
    return T.Compose([
        T.Resize((side, side)),
        T.ToTensor(),
    ])


def _make_train_dataset(transform):
    """Return the "train" split dataset built with the global vocab above and the given transform."""
    return CLEVRMultiLabelByImage(
        clevr_root=clevr_root,
        split="train",
        colors=colors,
        shapes=shapes,
        max_objects=max_objects,
        transform=transform,
    )


# Two views of the same training split that differ only in image resolution.
tfm64 = _resize_to_tensor(64)
tfm128 = _resize_to_tensor(128)

train_ds_64 = _make_train_dataset(tfm64)
train_ds_128 = _make_train_dataset(tfm128)

train_dl_64 = DataLoader(train_ds_64, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
train_dl_128 = DataLoader(train_ds_128, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

num_colors: 8 ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']
num_shapes: 3 ['cube', 'cylinder', 'sphere']
max_objects: 10


In [3]:
# Build the models: one instance of each class imported below
# (presumably the 64- and 128-pixel variants, going by the class names).
from src.model import REVAE_V64, REVAE_V128

revae_64, revae_128 = REVAE_V64(), REVAE_V128()

In [4]:
from src.train import TrainConfig, fit, evaluate
import torch.optim as optim

# Hyper-parameters shared by both training runs. The two configs below differ
# only in the checkpoint file name, so everything else is declared once here.
_common_train_kwargs = dict(
    epochs=100,
    lr=1e-3,
    beta_kl=1,
    lam_color=0,
    lam_shape=0,
    lam_count=0,
    recon_loss="bce_logits",   # images were confirmed to lie in [0, 1]; options noted by the author: bce_logits / mse / l1
    use_amp=False,
    log_every=100,
    save_best=False,
    ckpt_dir="../checkpoints",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

cfg64 = TrainConfig(**_common_train_kwargs, ckpt_name="revae_v64.pt")
cfg128 = TrainConfig(**_common_train_kwargs, ckpt_name="revae_v128.pt")

# One Adam optimizer per model, each using the learning rate from its own config.
optimizer64 = optim.Adam(revae_64.parameters(), lr=cfg64.lr)
optimizer128 = optim.Adam(revae_128.parameters(), lr=cfg128.lr)

# If no validation loader has been built yet, val_loader=None can be passed for now.
result_64 = fit(revae_64, train_dl_64, optimizer64, cfg64, val_loader=None)
# Keep the original generic name bound to the same object for any code that
# still reads `result`; `result_64` is the name that is not reused elsewhere.
result = result_64

epoch1


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.use_amp and device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(cfg.use_amp and device.type == "cuda")):


[train] epoch 1 step 0/700 total=9230.7988 recon=9230.4375 kl=0.3617 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 100/700 total=8371.2178 recon=8361.0195 kl=10.1978 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 200/700 total=8347.8906 recon=8336.1846 kl=11.7064 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 300/700 total=8346.2080 recon=8335.2510 kl=10.9572 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 400/700 total=8350.2139 recon=8338.2588 kl=11.9550 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 500/700 total=8344.2939 recon=8332.3945 kl=11.8995 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 600/700 total=8348.4209 recon=8336.2822 kl=12.1390 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
epoch2
[train] epoch 2 step 0/700 total=8345.2646 recon=8333.5488 kl=11.7155 loss_color=0.0000 loss_shape=0.0000 loss_cou

In [5]:
# Train the 128x128 model with the optimizer and config created in the previous cell.
# The return value gets its own name so it is not confused with the 64x64 run.
result_128 = fit(revae_128, train_dl_128, optimizer128, cfg128, val_loader=None)
# Keep the original generic name bound to the same object for any code that still reads `result`.
result = result_128

epoch1
[train] epoch 1 step 0/700 total=37296.2734 recon=37295.9297 kl=0.3446 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 100/700 total=33578.3633 recon=33549.5586 kl=28.8055 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 200/700 total=33534.2773 recon=33511.2109 kl=23.0683 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 300/700 total=33467.2734 recon=33446.8828 kl=20.3925 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 400/700 total=33481.5586 recon=33460.2383 kl=21.3199 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 500/700 total=33453.1680 recon=33431.4961 kl=21.6711 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 600/700 total=33434.0195 recon=33412.0078 kl=22.0104 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
epoch2
[train] epoch 2 step 0/700 total=33442.7422 recon=33420.8203 kl=21.9226 loss_color=0.0000 los