In [None]:
import os

# ---------------------------------------------------------------------------
# Run identity
# ---------------------------------------------------------------------------
# TODO: change ID
ID = "vae-v1-ddp"

# ---------------------------------------------------------------------------
# Distributed setup: rank / world size are read from the environment and fall
# back to a single-process run (rank 0 of 1) when the variables are unset.
# ---------------------------------------------------------------------------
RANK = int(os.environ.get("RANK", "0"))
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))
DISTRIBUTED = WORLD_SIZE > 1

# ---------------------------------------------------------------------------
# Hardware / reproducibility
# ---------------------------------------------------------------------------
MIXED_PRECISION = True

# TODO: change device
DEVICE = "cuda"
DEVICE_ID = 1  # None for CPU
DEVICE_IDS = [0, 1, 2, 3]
OMP_NUM_THREADS = 10
SEED = 42

# ---------------------------------------------------------------------------
# Dataset: sampling options and on-disk locations (all paths hang off DS_PATH)
# ---------------------------------------------------------------------------
DATASET_REPETITIONS = 1
IMAGE_SIZE = 128

DS_PATH = "./dataset"
DS_IMAGE_PATH = DS_PATH
DS_ID2WORD_PATH = os.path.join(DS_PATH, "dictionary/id2Word.npy")
DS_VOCAB_PATH = os.path.join(DS_PATH, "dictionary/vocab.npy")
DS_WORD2ID_PATH = os.path.join(DS_PATH, "dictionary/word2Id.npy")
DS_TEXT2IMG_PATH = os.path.join(DS_PATH, "dataset/text2ImgData.pkl")
DS_TEST_DATA_PATH = os.path.join(DS_PATH, "dataset/testData.pkl")

# ---------------------------------------------------------------------------
# Model architecture
# ---------------------------------------------------------------------------
LATENT_DIM = 4
NUM_HEADS = 8
VAE_SCALE = 2
VAE_BETA = 0.5

# ---------------------------------------------------------------------------
# Training schedule and optimiser settings
# ---------------------------------------------------------------------------
# TODO: change epochs
START_EPOCH = 0
EPOCHS = 100
PLOT_EVERY = 1
# The batch size grows linearly with the number of participating processes.
BATCH_SIZE = 8 * WORLD_SIZE

LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5

# ---------------------------------------------------------------------------
# Artefacts: checkpoints and plots are stored in per-run folders named after ID
# ---------------------------------------------------------------------------
CHECKPOINT_DIR = os.path.join("./ckpts/", ID)
CHECKPOINT_NAME = "ckpt"
OUTPUT_DIR = os.path.join("./outputs/", ID)
SAVE_PLOTS = True

In [None]:
import os
import random
import torch
import warnings
from IPython import get_ipython


# Export thread / allocator settings before the first CUDA call in this cell,
# so they are already in the environment when CUDA state is first touched.
os.environ["OMP_NUM_THREADS"] = str(OMP_NUM_THREADS)
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# torch has already been imported by the time this line runs, so the
# environment variable alone may be read too late for this process; apply the
# CPU thread limit through the PyTorch API as well.
torch.set_num_threads(OMP_NUM_THREADS)

# Resolve the device. DEVICE_ID = None is the documented way to request CPU,
# so it must not reach the `DEVICE_ID < gpus` comparison (None < int raises
# TypeError).
if torch.cuda.is_available() and DEVICE_ID is not None:
    gpus = torch.cuda.device_count()
    # Fall back to GPU 0 when the requested index does not exist.
    DEVICE_ID = DEVICE_ID if DEVICE_ID < gpus else 0
    torch.cuda.set_device(DEVICE_ID)
    if DISTRIBUTED:
        # Distributed runs silently drop device ids that are not present.
        DEVICE_IDS = [device_id for device_id in DEVICE_IDS if device_id < gpus]
    else:
        # Single-process runs reject a DEVICE_IDS list naming a missing GPU.
        for device_id in DEVICE_IDS:
            if device_id >= gpus:
                raise ValueError(f"GPU {device_id} is not available.")
else:
    DEVICE = "cpu"
    DEVICE_ID = None

if DISTRIBUTED:
    print(f"Rank {RANK} Using distributed training with devices: {DEVICE_IDS}")
else:
    print(f"Using device id: {DEVICE_ID}")

# Seed the Python and PyTorch random number generators.
random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): this silences every warning for the rest of the session.
warnings.filterwarnings("ignore")

# Dataset

In [None]:
from IPython import get_ipython

from utils.dataset import DatasetToolConfig, DatasetTool


def test_datasets():
    """Preview one training batch as a 2 x 8 grid of captioned images.

    Does nothing unless it runs inside an IPython shell (``get_ipython()``
    returns a shell object), because the preview relies on inline matplotlib
    rendering.

    Returns:
        None.
    """
    ipy = get_ipython()
    if ipy is None:
        return

    # `InteractiveShell.magic` is deprecated; `run_line_magic` is the
    # supported way to run `%matplotlib inline` programmatically.
    ipy.run_line_magic("matplotlib", "inline")
    import matplotlib.pyplot as plt

    grid_rows, grid_cols = 2, 8
    preview_batch_size = grid_rows * grid_cols  # 16 images, one per grid cell

    dt_cfg = DatasetToolConfig(
        id2word_path=DS_ID2WORD_PATH,
        vocab_path=DS_VOCAB_PATH,
        word2id_path=DS_WORD2ID_PATH,
        image_path=DS_IMAGE_PATH,
        text2img_path=DS_TEXT2IMG_PATH,
        test_data_path=DS_TEST_DATA_PATH,
    )
    dt = DatasetTool(dt_cfg)

    train_loader = dt.get_train_loader(
        preview_batch_size,
        IMAGE_SIZE,
        repeats=DATASET_REPETITIONS,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
        rank=RANK,
        world_size=WORLD_SIZE,
        distributed=DISTRIBUTED,
    )

    images, captions = next(iter(train_loader))
    plt.figure(figsize=(16, 4))
    for i, image in enumerate(images):
        plt.subplot(grid_rows, grid_cols, i + 1)
        # Move the first axis last for imshow
        # (assumes images are channel-first tensors; TODO confirm).
        plt.imshow(image.permute(1, 2, 0))
        plt.title(captions[i][:12])
        plt.axis("off")
    # Lay out the figure once, after every subplot exists, instead of once
    # per image.
    plt.tight_layout()
    plt.show()


# test_datasets()

In [None]:
import time
from datetime import datetime
import logging

import torch
import torch.distributed as dist

from utils.dataset import DatasetToolConfig, DatasetTool
from models.sd import VAEModel, VAEModelConfig
from utils.checkpoint import load_checkpoint, save_checkpoint
from utils.distributed import init_distributed, cleanup_distributed

# Model configuration. Every value comes from the notebook-level constants.
cfg = VAEModelConfig(
    # Geometry: square images, latent grid at 1/8 of the image side.
    image_height=IMAGE_SIZE,
    image_width=IMAGE_SIZE,
    latent_height=IMAGE_SIZE // 8,
    latent_width=IMAGE_SIZE // 8,
    # Architecture.
    latent_dim=LATENT_DIM,
    num_heads=NUM_HEADS,
    vae_scale=VAE_SCALE,
    vae_beta=VAE_BETA,
    # Optimisation.
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    mixed_precision=MIXED_PRECISION,
    # Placement: anything other than "cpu" is reported as device type "cuda".
    device=DEVICE,
    device_type="cpu" if DEVICE == "cpu" else "cuda",
    device_ids=DEVICE_IDS,
    # Distributed topology.
    rank=RANK,
    world_size=WORLD_SIZE,
    distributed=DISTRIBUTED,
)

# Dataset configuration: vocabulary files, image root and pickled splits.
ds_cfg = DatasetToolConfig(
    image_path=DS_IMAGE_PATH,
    vocab_path=DS_VOCAB_PATH,
    id2word_path=DS_ID2WORD_PATH,
    word2id_path=DS_WORD2ID_PATH,
    text2img_path=DS_TEXT2IMG_PATH,
    test_data_path=DS_TEST_DATA_PATH,
)

In [None]:
def _mean_loss_over_ranks(loss_sum, num_batches):
    """Return the per-batch mean loss, averaged over all ranks when distributed.

    Args:
        loss_sum: Sum of the per-batch losses seen by this process.
        num_batches: Number of batches that contributed to ``loss_sum``.

    Returns:
        The mean loss as a Python float. In distributed mode every rank must
        call this function, since it performs a collective ``all_reduce``.
    """
    # max(..., 1) keeps an empty loader from dividing by zero.
    mean = torch.tensor([loss_sum / max(num_batches, 1)], device=DEVICE)
    if DISTRIBUTED:
        dist.all_reduce(mean, op=dist.ReduceOp.AVG)
    return mean.item()


def _run_training():
    """Build the model and data loaders, then run the epoch loop."""
    start_ts = time.perf_counter()
    model = VAEModel(cfg)
    print(f"Model created in {time.perf_counter() - start_ts:.2f}s")

    # Resume from the checkpoint for START_EPOCH when one is returned.
    checkpoint = load_checkpoint(
        CHECKPOINT_DIR, CHECKPOINT_NAME, epoch=START_EPOCH, device=DEVICE
    )
    if checkpoint is not None:
        model.load_checkpoint(checkpoint)

    ds_tool = DatasetTool(ds_cfg)
    train_loader, val_loader = ds_tool.get_train_val_loader(
        BATCH_SIZE,
        IMAGE_SIZE,
        repeats=DATASET_REPETITIONS,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
        rank=RANK,
        world_size=WORLD_SIZE,
        distributed=DISTRIBUTED,
    )

    last_epoch = EPOCHS + START_EPOCH
    for epoch in range(START_EPOCH + 1, last_epoch + 1):
        print("{}, epoch {:3d}/{:3d}".format(datetime.now(), epoch, last_epoch))
        start_ts = time.perf_counter()

        # ---- Training pass ------------------------------------------------
        train_loss_sum = 0.0
        if DISTRIBUTED:
            train_loader.sampler.set_epoch(epoch)
        for idx, (images, _) in enumerate(train_loader, start=1):
            train_metrics = model.train_step(images)
            # Assumes the "loss" entry is a scalar (float or 0-d tensor).
            train_loss_sum += float(train_metrics["loss"])
            if idx % 20 == 0:
                print(
                    "rank {:2d}, train epoch {:3d}/{:3d}, batch {:4d}/{:4d}, loss: {:.4f}".format(
                        RANK,
                        epoch,
                        last_epoch,
                        idx,
                        len(train_loader),
                        train_metrics["loss"],
                    )
                )

        # Only rank 0 writes the checkpoint.
        if RANK == 0:
            save_checkpoint(
                epoch,
                model.checkpoint(),
                CHECKPOINT_DIR,
                CHECKPOINT_NAME,
            )

        # ---- Validation pass ----------------------------------------------
        val_loss_sum = 0.0
        if DISTRIBUTED:
            val_loader.sampler.set_epoch(epoch)
        for images, _ in val_loader:
            val_metrics = model.test_step(images)
            val_loss_sum += float(val_metrics["loss"])

        # Both reductions are collectives, so they run on every rank; only
        # rank 0 prints. The validation loss used to be accumulated and then
        # discarded; it is now reduced and reported next to the training loss.
        avg_train_loss = _mean_loss_over_ranks(train_loss_sum, len(train_loader))
        avg_val_loss = _mean_loss_over_ranks(val_loss_sum, len(val_loader))

        if RANK == 0:
            print(
                "rank {:2d}, epoch {:3d}/{:3d}, loss: {:.4f}, val_loss: {:.4f}, time: {:.2f}s".format(
                    RANK,
                    epoch,
                    last_epoch,
                    avg_train_loss,
                    avg_val_loss,
                    time.perf_counter() - start_ts,
                )
            )

        # Plot one validation batch on rank 0 every PLOT_EVERY epochs.
        if RANK == 0 and epoch % PLOT_EVERY == 0:
            if DISTRIBUTED:
                val_loader.sampler.set_epoch(epoch)
            images, _ = next(iter(val_loader))
            model.plot_images(
                images,
                save=SAVE_PLOTS,
                output_dir=OUTPUT_DIR,
                epoch=epoch,
            )

        # Keep the ranks in step before the next epoch starts.
        if DISTRIBUTED:
            dist.barrier()


def train():
    """Train the VAE for EPOCHS epochs, starting after START_EPOCH.

    In distributed mode the process group is initialised first and is always
    cleaned up afterwards, even when training raises.

    Returns:
        None.
    """
    if DISTRIBUTED:
        init_distributed(RANK, WORLD_SIZE)
    try:
        if DISTRIBUTED:
            # Bind this process to the GPU listed for its rank.
            torch.cuda.set_device(DEVICE_IDS[RANK])
        _run_training()
    finally:
        if DISTRIBUTED:
            cleanup_distributed()


if EPOCHS > 0:
    train()