In [1]:
import os,sys
import json
import torch

def _ensure_on_sys_path(path):
    """Append `path` to sys.path unless it is already there (safe to re-run this cell)."""
    if path in sys.path:
        return
    sys.path.append(path)


# Parent directory of the notebook's working directory; added to sys.path so
# that project packages such as `src.dataset` / `src.model` can be imported.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
_ensure_on_sys_path(PROJECT_ROOT)

In [None]:
# Build the datasets
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as T  # `T` is used below; it was previously never imported
from src.dataset import build_global_vocab_and_maxcount, CLEVRMultiLabelByImage

clevr_root = "../CLEVR_v1.0"

# Global statistics (colors / shapes / max_objects), computed over the train and
# val splits only -- the test split is NOT included.
colors, shapes, max_objects, _ = build_global_vocab_and_maxcount(clevr_root, splits=("train", "val"))

print("num_colors:", len(colors), colors)
print("num_shapes:", len(shapes), shapes)
print("max_objects:", max_objects)


def make_resize_transform(size):
    """Return a transform that resizes an image to (size, size) and converts it to a tensor.

    ToTensor scales uint8 / PIL images to float values in [0, 1]. The
    bce_logits reconstruction loss used later relies on that range
    (assumes the dataset hands PIL images to the transform -- TODO confirm).
    """
    return T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
    ])


def make_train_dataset(transform, root=clevr_root, vocab_colors=colors,
                       vocab_shapes=shapes, n_max_objects=max_objects):
    """Build the CLEVR "train" split with the shared global vocabulary and `transform`."""
    return CLEVRMultiLabelByImage(
        clevr_root=root,
        split="train",
        colors=vocab_colors,
        shapes=vocab_shapes,
        max_objects=n_max_objects,
        transform=transform,
    )


tfm64 = make_resize_transform(64)
tfm128 = make_resize_transform(128)

train_ds_64 = make_train_dataset(tfm64)
train_ds_128 = make_train_dataset(tfm128)

TRAIN_BATCH_SIZE = 100
# num_workers=0: batches are loaded in the main process.
train_dl_64 = DataLoader(train_ds_64, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)
train_dl_128 = DataLoader(train_ds_128, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=0)

num_colors: 8 ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']
num_shapes: 3 ['cube', 'cylinder', 'sphere']
max_objects: 10
images: torch.Size([100, 3, 224, 224])
color_mh: torch.Size([100, 8])
shape_mh: torch.Size([100, 3])
count_oh: torch.Size([100, 11])
example fn: CLEVR_train_028897.png
example count one-hot argmax: 9


In [None]:
# Build the models: one independent REVAE_V1 instance per input resolution
# (the 64x64 instance is constructed first, then the 128x128 one).
# NOTE(review): both use REVAE_V1's default constructor arguments -- confirm the
# architecture actually accepts both 64x64 and 128x128 inputs.
from src.model import REVAE_V1

revae_64, revae_128 = REVAE_V1(), REVAE_V1()

In [None]:
# Training configuration, optimizers, and the 64x64 training run.
from src.train import TrainConfig, fit, evaluate
import torch.optim as optim

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Settings shared by both resolutions; only the checkpoint file name differs.
common_cfg_kwargs = dict(
    epochs=100,
    lr=1e-3,
    beta_kl=1,        # presumably the KL-term weight -- confirm in src.train
    # presumably weights of the color / shape / count auxiliary losses;
    # 0 should leave only recon + KL -- confirm in src.train
    lam_color=0,
    lam_shape=0,
    lam_count=0,
    recon_loss="bce_logits",   # one of bce_logits / mse / l1; bce_logits assumes images are in [0, 1]
    use_amp=False,
    log_every=100,
    save_best=False,
    ckpt_dir="../checkpoints",
    device=DEVICE,
)

cfg64 = TrainConfig(**common_cfg_kwargs, ckpt_name="revae_v64.pt")
cfg128 = TrainConfig(**common_cfg_kwargs, ckpt_name="revae_v128.pt")

optimizer64 = optim.Adam(revae_64.parameters(), lr=cfg64.lr)
optimizer128 = optim.Adam(revae_128.parameters(), lr=cfg128.lr)

# val_loader=None: no validation DataLoader has been built yet.
# The result is kept under a resolution-specific name so later runs cannot clobber it.
result_64 = fit(revae_64, train_dl_64, optimizer64, cfg64, val_loader=None)
result = result_64  # backward-compatible alias for code that reads `result`

epoch1


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.use_amp and device.type == "cuda"))
  with torch.cuda.amp.autocast(enabled=(cfg.use_amp and device.type == "cuda")):


[train] epoch 1 step 0/700 total=113803.2266 recon=113802.9609 kl=0.2647 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 100/700 total=103110.3906 recon=102987.4375 kl=122.9548 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 200/700 total=102936.2422 recon=102870.5156 kl=65.7277 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 300/700 total=102820.4062 recon=102768.8594 kl=51.5507 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 400/700 total=102676.5078 recon=102554.8203 kl=121.6905 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 500/700 total=102646.7031 recon=102590.1406 kl=56.5633 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
[train] epoch 1 step 600/700 total=102552.2734 recon=102503.6172 kl=48.6561 loss_color=0.0000 loss_shape=0.0000 loss_count=0.0000
epoch2
[train] epoch 2 step 0/700 total=102369.8359 recon=102320.3750 kl=49.4619 loss_color

In [None]:
# Train the 128x128 model; keep its result under a resolution-specific name.
result_128 = fit(revae_128, train_dl_128, optimizer128, cfg128, val_loader=None)
result = result_128  # backward-compatible alias: `result` holds the most recent run