In [1]:
import os,sys
import json
import torch

# Resolve the parent of the current working directory to an absolute path and
# expose it as PROJECT_ROOT. Appending it to sys.path lets later cells import
# the `src` package; the membership check keeps re-runs of this cell from
# adding the same entry twice.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
# Build the dataset
from torch.utils.data import DataLoader
from src.dataset import build_global_vocab_and_maxcount, CLEVRMultiLabelByImage

clevr_root = "../CLEVR_v1.0"

# Global statistics (colors / shapes / max_objects) collected over the "train"
# and "val" splits only -- "test" is NOT part of `splits` below.
# The fourth return value is not used in this notebook. It is bound to a named
# throwaway instead of `_`, because a user assignment to `_` makes IPython stop
# updating its automatic last-output variables (`_`, `__`, `___`).
colors, shapes, max_objects, _unused = build_global_vocab_and_maxcount(
    clevr_root, splits=("train", "val")
)

print("num_colors:", len(colors), colors)
print("num_shapes:", len(shapes), shapes)
print("max_objects:", max_objects)

# Training dataset, built with the vocabulary and object-count bound computed above.
train_ds = CLEVRMultiLabelByImage(
    clevr_root=clevr_root,
    split="train",
    colors=colors,
    shapes=shapes,
    max_objects=max_objects,
)

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=0)

# Pull a single mini-batch as a smoke test and print the shape of each tensor.
# shuffle=True and no seed is set, so the sampled batch differs between runs.
images, color_mh, shape_mh, count_oh, img_fns = next(iter(train_dl))
print("images:", images.shape)
print("color_mh:", color_mh.shape)   # [B, num_colors]
print("shape_mh:", shape_mh.shape)   # [B, num_shapes]
print("count_oh:", count_oh.shape)   # [B, max_objects+1]
print("example fn:", img_fns[0])
print("example count one-hot argmax:", int(count_oh[0].argmax()))

num_colors: 8 ['blue', 'brown', 'cyan', 'gray', 'green', 'purple', 'red', 'yellow']
num_shapes: 3 ['cube', 'cylinder', 'sphere']
max_objects: 10
images: torch.Size([4, 3, 224, 224])
color_mh: torch.Size([4, 8])
shape_mh: torch.Size([4, 3])
count_oh: torch.Size([4, 11])
example fn: CLEVR_train_014489.png
example count one-hot argmax: 8


In [4]:
# Build the model
from src.model import REVAE_V1
import torch.optim as optim

# Pick the compute device: CUDA when PyTorch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the model with its default constructor arguments and move it to `device`.
revae = REVAE_V1()
revae.to(device)

# Adam over all of the model's parameters, with the optimizer's default hyperparameters.
optimizer = optim.Adam(revae.parameters())