# model_hps.py

VDVAE의 encoder, decoder 내부 layer 수 및 이미지 설정 등 모델 구조에 관련된 하이퍼파라미터를 여기서 관리합니다.

In [1]:
# Architecture-only hyperparameters (training/runtime settings are kept in hps.py).

# Encoder layout: "<res>x<n>" = n blocks at resolution <res>;
# "<res>d<k>" = one block at <res> tagged with down-rate k.
block_str: str = "32x11,32d2,16x6,16d2,8x6,8d2,4x3,4d4,1x3"
# Optional per-resolution channel overrides, e.g. "16:64,8:64" (empty = base_width everywhere).
custom_width_str: str = ""
# Decoder layout: (resolution, mixin resolution or None), one entry per decoder block.
res_list: list = [(1, None), (4, 1), (8, 4), (16, 8), (32, 16)]

# Image geometry.
image_size: int = 32
image_channels: int = 3

# Widths and latent / output-head sizes.
base_width: int = 384
bottleneck_multiple: float = 0.25
zdim: int = 16
num_mixtures: int = 10
n_blocks: int = 3

# hps.py

모델의 학습 관련 및 저장소, 데이터 루트 등의 하이퍼파라미터는 모두 여기 모아서 관리합니다.

In [2]:
# Named hyperparameter presets (name -> Hyperparams).
HPARAMS_REGISTRY = {}

class Hyperparams(dict):
    """A dict whose entries can also be read and written as attributes.

    Reading an attribute that is not a key returns None instead of raising,
    so optional settings can be probed with ``H.some_flag``. Dunder names
    (``__getstate__``, ``__deepcopy__``, ...) are excluded from that fallback
    and raise AttributeError. Otherwise pickle/copy/hasattr protocol probes
    would receive None and could try to call it.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails.
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        try:
            return self[attr]
        except KeyError:
            return None

    def __setattr__(self, attr, value):
        # Attribute assignment stores into the dict itself.
        self[attr] = value


# CIFAR-10 is the only registered preset; these values become argparse defaults.
cifar10 = Hyperparams(
    dataset='cifar10',
    lr=0.0002,
    wd=0.01,
    n_batch=32,
    ema_rate=0.9998,
    warmup_iters=100,
    skip_threshold=400.0,
    # Per the original note, training stops at whichever of max_iters /
    # num_epochs is longer (enforced by the training loop, not shown here).
    max_iters=1563,
    num_epochs=10,
)
HPARAMS_REGISTRY['cifar10'] = cifar10


def parse_args_and_update_hparams(H, parser, s=None):
    """Parse arguments using the CIFAR-10 preset as defaults and merge them into H.

    Every key of the preset must correspond to an argument already registered
    on ``parser``; otherwise ValueError is raised for the first offending key.
    Explicit command-line values still take precedence over preset defaults.

    Args:
        H: Hyperparams (or any dict) updated in place with the parsed values.
        parser: argparse.ArgumentParser with all arguments registered.
        s: optional argument list; None means ``sys.argv[1:]``.
    """
    registered = set(vars(parser.parse_args(s)))
    preset = HPARAMS_REGISTRY['cifar10']

    unknown = [key for key in preset if key not in registered]
    if unknown:
        raise ValueError(f"{unknown[0]} not in default args")

    parser.set_defaults(**preset)
    H.update(vars(parser.parse_args(s)))


# we manage all parameters here except for model structure.
def add_vae_arguments(parser):
    """Register every non-architectural hyperparameter on ``parser`` and return it.

    Arguments are registered in a fixed order, which is also the ``--help`` order.
    """
    def add(name, kind, default, **extra):
        parser.add_argument(f'--{name}', type=kind, default=default, **extra)

    # Run identity and filesystem locations.
    add('seed', int, 0)
    add('port', int, 29500)
    add('save_dir', str, './saved_models')
    add('data_root', str, '../content')

    # Experiment name and checkpoint restoration (e.g. './saved_models/test/latest').
    add('desc', str, 'test')
    add('restore_path', str, None,
        help="checkpoint prefix를 지정하면 그 지점부터 학습을 복원!")
    add('restore_ema_path', str, None)        # e.g. './saved_models/test/latest'
    add('restore_log_path', str, None)        # e.g. './saved_models/test/latest-log.jsonl'
    add('restore_optimizer_path', str, None)  # e.g. './saved_models/test/latest-opt.th'
    add('dataset', str, 'cifar10')

    add('ema_rate', float, 0.999)

    parser.add_argument('--test_eval', action="store_true")
    add('warmup_iters', float, 0)

    # Optimisation.
    add('grad_clip', float, 200.0)
    add('skip_threshold', float, 400.0)
    add('lr', float, 0.00015)
    add('lr_prior', float, 0.00015)
    add('wd', float, 0.0)
    add('wd_prior', float, 0.0)
    add('num_epochs', int, 10)                 # 10000 was the previous maximum
    add('n_batch', int, 32)
    add('adam_beta1', float, 0.9)
    add('adam_beta2', float, 0.9)

    add('temperature', float, 1.0)

    # Logging / checkpoint / evaluation cadence.
    add('iters_per_ckpt', int, 25000)
    add('iters_per_print', int, 1000)
    add('iters_per_save', int, 1500)           # previously 10000
    add('iters_per_images', int, 10000)
    add('epochs_per_eval', int, 10)            # measured in epochs
    add('epochs_per_probe', int, None)
    add('epochs_per_eval_save', int, 20)
    add('num_images_visualize', int, 8)
    add('num_variables_visualize', int, 6)
    add('num_temperatures_visualize', int, 3)
    add('max_iters', int, 3125)                # maximum number of iterations
    return parser

# data.py

In [3]:
import numpy as np
import pickle
import os
import torch
from torch.utils.data import TensorDataset
from sklearn.model_selection import train_test_split


def set_up_data(H):
    """Load the dataset and build datasets plus a GPU preprocessing function.

    Sets ``H.image_size`` / ``H.image_channels`` for the chosen dataset.

    Returns:
        (H, train_data, valid_data, preprocess_func). ``valid_data`` is the
        test split when ``H.test_eval`` is set, otherwise the validation split.
        ``preprocess_func(batch)`` moves ``batch[0]`` to the GPU as float and
        returns ``(inp, out)``. ``inp`` is normalised with the dataset
        shift/scale. ``out`` is mapped from [0, 255] to [-1, 1] via
        (x - 127.5) / 127.5, the range the discretised-logistic loss expects.

    Raises:
        ValueError: if ``H.dataset`` is not supported.
    """
    shift_loss = -127.5
    scale_loss = 1. / 127.5
    if H.dataset == 'cifar10':
        (trX, _), (vaX, _), (teX, _) = load_cifar10_data(H.data_root, one_hot=False)
        H.image_size = 32
        H.image_channels = 3
        # Presumably the CIFAR-10 pixel mean / std — TODO confirm the source.
        shift = -120.63838
        scale = 1. / 64.16736
    else:
        # Format the message: passing two positional args rendered it as a tuple.
        raise ValueError(f'unknown dataset: {H.dataset}')

    if H.test_eval:
        print('DOING TEST')
        eval_dataset = teX
    else:
        eval_dataset = vaX

    def _broadcastable(value):
        # Shape (1, 1, 1, 1) so it broadcasts over (B, H, W, C) batches on the GPU.
        return torch.tensor([value]).cuda().view(1, 1, 1, 1)

    shift = _broadcastable(shift)
    scale = _broadcastable(scale)
    shift_loss = _broadcastable(shift_loss)
    scale_loss = _broadcastable(scale_loss)

    train_data = TensorDataset(torch.as_tensor(trX))
    valid_data = TensorDataset(torch.as_tensor(eval_dataset))

    def preprocess_func(x):
        # The closure only reads shift/scale, so no `nonlocal` is needed.
        inp = x[0].cuda(non_blocking=True).float()
        out = inp.clone()
        inp.add_(shift).mul_(scale)
        out.add_(shift_loss).mul_(scale_loss)
        return inp, out

    return H, train_data, valid_data, preprocess_func


def unpickle_cifar10(file):
    """Load one CIFAR-10 python batch file and decode its byte-string keys to str.

    NOTE: this uses pickle, which can execute arbitrary code while loading.
    Only call it on trusted files such as the official CIFAR-10 download.

    Args:
        file: path to a pickled batch (e.g. ``data_batch_1``).

    Returns:
        dict mapping str keys (e.g. 'data', 'labels') to the stored values.
    """
    # `with` closes the file even if unpickling raises.
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return {k.decode(): v for k, v in data.items()}


def load_cifar10_data(data_root, one_hot=True):
    """Load CIFAR-10 from ``<data_root>/cifar-10-batches-py``.

    Images are returned as float32 NHWC arrays (N, 32, 32, 3) with raw 0-255
    values. 5000 training images are split off as a validation set with a
    fixed random_state, so the split is reproducible.

    Args:
        data_root: directory containing ``cifar-10-batches-py``.
        one_hot: if True, labels are (N, 10) float32 one-hot rows; otherwise
            (N, 1) int64 columns.

    Returns:
        ((trX, trY), (vaX, vaY), (teX, teY)).
    """
    root = os.path.join(data_root, 'cifar-10-batches-py')

    # Five training batches, concatenated to (50000, 3072) / (50000,).
    batches = [unpickle_cifar10(os.path.join(root, f'data_batch_{i}')) for i in range(1, 6)]
    trX = np.concatenate([b['data'] for b in batches], axis=0).astype(np.float32)
    trY = np.concatenate([b['labels'] for b in batches], axis=0).astype(np.int64)

    # Test batch: (10000, 3072) / (10000,).
    test_batch = unpickle_cifar10(os.path.join(root, 'test_batch'))
    teX = np.array(test_batch['data'], dtype=np.uint8).astype(np.float32)
    teY = np.array(test_batch['labels'], dtype=np.int64)

    def to_nhwc(flat):
        # Rows are stored channel-major (3 x 32 x 32); move channels last.
        return flat.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

    trX, teX = to_nhwc(trX), to_nhwc(teX)
    trX, vaX, trY, vaY = train_test_split(trX, trY, test_size=5000, random_state=11172018)

    if one_hot:
        eye = np.eye(10, dtype=np.float32)
        trY, vaY, teY = eye[trY], eye[vaY], eye[teY]
    else:
        trY, vaY, teY = (np.reshape(y, [-1, 1]) for y in (trY, vaY, teY))
    return (trX, trY), (vaX, vaY), (teX, teY)

# utils.py

In [5]:
!pip install mpi4py

Collecting mpi4py
  Downloading mpi4py-4.0.3.tar.gz (466 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 466.3/466.3 kB 24.8 MB/s eta 0:00:00
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: mpi4py
  Building wheel for mpi4py (pyproject.toml) ... done
  Created wheel for mpi4py: filename=mpi4py-4.0.3-cp311-cp311-linux_x86_64.whl size=4441849 sha256=831424c6b9e310c22847972a4ccb19f31693f21e3c2ed19d9a66bc344675c00f
  Stored in directory: /root/.cache/pip/wheels/5c/56/17/bf6ba37aa971a191a8b9eaa188bf5ec855b8911c1c56fb1f84
Successfully built mpi4py
Installing collected packages: mpi4py
Successfully installed mpi4py-4.0.3

In [6]:
from mpi4py import MPI
import os
import json
import socket
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import AdamW
from collections import defaultdict
import argparse
import time
import numpy as np
import subprocess


def parse_layer_string(s):
    """Expand a comma-separated layer spec into (resolution, extra) tuples.

    Token forms:
        "<res>x<n>" -> n copies of (res, None)
        "<res>m<k>" -> (res, k)   mixin resolution
        "<res>d<k>" -> (res, k)   down-rate
        "<res>"     -> (res, None)
    """
    layers = []
    for token in s.split(','):
        if 'x' in token:
            res, count = token.split('x')
            layers.extend((int(res), None) for _ in range(int(count)))
        elif 'm' in token:
            res, mixin = map(int, token.split('m'))
            layers.append((res, mixin))
        elif 'd' in token:
            res, down_rate = map(int, token.split('d'))
            layers.append((res, down_rate))
        else:
            layers.append((int(token), None))
    return layers


def pad_channels(t, width):
    """Zero-pad dim 1 (channels) of a 4-D tensor up to ``width``.

    The result keeps ``t``'s dtype and device. Previously it was always
    float32, which silently up/down-cast half- or double-precision inputs.
    The first ``t.shape[1]`` channels hold ``t``; the rest are zeros.
    ``width`` must be >= ``t.shape[1]``.
    """
    d1, d2, d3, d4 = t.shape
    padded = t.new_zeros(d1, width, d3, d4)
    padded[:, :d2, :, :] = t
    return padded


def get_width_settings(width, s):
    """Map resolution -> channel width.

    ``s`` is an optional "res:width,res:width" override string. Resolutions
    not listed (or all of them when ``s`` is empty) map to ``width``.
    """
    widths = defaultdict(lambda: width)
    for entry in (s.split(',') if s else []):
        res, w = entry.split(':')
        widths[int(res)] = int(w)
    return widths


@torch.jit.script
def gaussian_analytical_kl(mu1, mu2, logsigma1, logsigma2):
    """Element-wise KL( N(mu1, exp(logsigma1)^2) || N(mu2, exp(logsigma2)^2) )."""
    numerator = logsigma1.exp() ** 2 + (mu1 - mu2) ** 2
    denominator = logsigma2.exp() ** 2
    return -0.5 + logsigma2 - logsigma1 + 0.5 * numerator / denominator


@torch.jit.script
def draw_gaussian_diag_samples(mu, logsigma):
    """Reparameterised sample mu + exp(logsigma) * eps with eps ~ N(0, 1) element-wise."""
    noise = torch.empty_like(mu).normal_(0., 1.)
    std = torch.exp(logsigma)
    return std * noise + mu


def discretized_mix_logistic_loss(x, l, low_bit=False):
    """Negative log-likelihood of x under a discretised mixture of logistics.

    Assumes the data has been rescaled to the [-1, 1] interval.

    Args:
        x: target image, channels-last; the code indexes channels 0..2, so 3
           channels are assumed (e.g. (B, 32, 32, 3)).
        l: predicted parameters, channels-last, with last dim 10 * nr_mix:
           nr_mix mixture logits followed by 3 * 3 * nr_mix values that
           reshape to (..., 3, 3 * nr_mix) = means, log-scales and tanh'd
           autoregressive channel coefficients.
        low_bit: if True, use bin half-width 1/31 instead of 1/255
           (presumably for 5-bit data — TODO confirm).

    Returns:
        Tensor of shape (B,): the per-example negative log-likelihood summed
        over pixels and channels, divided by H * W * C (i.e. per sub-pixel).
    """
    # Adapted from https://github.com/openai/pixel-cnn/blob/master/pixel_cnn_pp/nn.py
    xs = [s for s in x.shape]  # true image (i.e. labels) to regress to, e.g. (B,32,32,3)
    ls = [s for s in l.shape]  # predicted distribution, e.g. (B,32,32,100)
    nr_mix = int(ls[-1] / 10)  # here and below: unpacking the params of the mixture of logistics
    logit_probs = l[:, :, :, :nr_mix]
    l = torch.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3])
    means = l[:, :, :, :, :nr_mix]
    # Floor log-scales at -7 to keep exp(-log_scales) bounded.
    log_scales = const_max(l[:, :, :, :, nr_mix:2 * nr_mix], -7.)
    coeffs = torch.tanh(l[:, :, :, :, 2 * nr_mix:3 * nr_mix])
    x = torch.reshape(x, xs + [1]) + torch.zeros(xs + [nr_mix]).to(x.device)  # here and below: getting the means and adjusting them based on preceding sub-pixels
    # Channel 1's mean depends linearly on channel 0; channel 2's on channels 0 and 1.
    m2 = torch.reshape(means[:, :, :, 1, :] + coeffs[:, :, :, 0, :] * x[:, :, :, 0, :], [xs[0], xs[1], xs[2], 1, nr_mix])
    m3 = torch.reshape(means[:, :, :, 2, :] + coeffs[:, :, :, 1, :] * x[:, :, :, 0, :] + coeffs[:, :, :, 2, :] * x[:, :, :, 1, :], [xs[0], xs[1], xs[2], 1, nr_mix])
    means = torch.cat([torch.reshape(means[:, :, :, 0, :], [xs[0], xs[1], xs[2], 1, nr_mix]), m2, m3], dim=3)
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    # Evaluate the logistic CDF at the upper and lower edges of x's bin.
    if low_bit:
        plus_in = inv_stdv * (centered_x + 1. / 31.)
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_x - 1. / 31.)
    else:
        plus_in = inv_stdv * (centered_x + 1. / 255.)
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_x - 1. / 255.)
    cdf_min = torch.sigmoid(min_in)
    log_cdf_plus = plus_in - F.softplus(plus_in)  # log probability for edge case of 0 (before scaling)
    log_one_minus_cdf_min = -F.softplus(min_in)  # log probability for edge case of 255 (before scaling)
    cdf_delta = cdf_plus - cdf_min  # probability for all other cases
    mid_in = inv_stdv * centered_x
    log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)  # log probability in the center of the bin, to be used in extreme cases (not actually used in our code)

    # now select the right output: left edge case, right edge case, normal case, extremely low prob case (doesn't actually happen for us)

    # this is what we are really doing, but using the robust version below for extreme cases in other applications and to avoid NaN issue with tf.select()
    # log_probs = tf.select(x < -0.999, log_cdf_plus, tf.select(x > 0.999, log_one_minus_cdf_min, tf.log(cdf_delta)))

    # robust version, that still works if probabilities are below 1e-5 (which never happens in our code)
    # tensorflow backpropagates through tf.select() by multiplying with zero instead of selecting: this requires use to use some ugly tricks to avoid potential NaNs
    # the 1e-12 in tf.maximum(cdf_delta, 1e-12) is never actually used as output, it's purely there to get around the tf.select() gradient issue
    # The final fallback approximates the bin mass as pdf * bin width, i.e. subtracts log(half the number of bins).
    if low_bit:
        log_probs = torch.where(x < -0.999,
                                log_cdf_plus,
                                torch.where(x > 0.999,
                                            log_one_minus_cdf_min,
                                            torch.where(cdf_delta > 1e-5,
                                                        torch.log(const_max(cdf_delta, 1e-12)),
                                                        log_pdf_mid - np.log(15.5))))
    else:
        log_probs = torch.where(x < -0.999,
                                log_cdf_plus,
                                torch.where(x > 0.999,
                                            log_one_minus_cdf_min,
                                            torch.where(cdf_delta > 1e-5,
                                                        torch.log(const_max(cdf_delta, 1e-12)),
                                                        log_pdf_mid - np.log(127.5))))
    # Sum over the channel axis, add the log mixture weights, then marginalise over components.
    log_probs = log_probs.sum(dim=3) + log_prob_from_logits(logit_probs)
    mixture_probs = torch.logsumexp(log_probs, -1)
    return -1. * mixture_probs.sum(dim=[1, 2]) / np.prod(xs[1:])


def const_max(t, constant):
    """Element-wise max of ``t`` and a scalar.

    Uses torch.max against a filled tensor (not clamp), which keeps torch.max's
    gradient behaviour at ties.
    """
    return torch.max(t, constant * torch.ones_like(t))


def const_min(t, constant):
    """Element-wise min of ``t`` and a scalar (counterpart of const_max)."""
    return torch.min(t, constant * torch.ones_like(t))


def sample_from_discretized_mix_logistic(l, nr_mix):
    """Draw one sample per pixel from a discretised mixture of logistics.

    ``l`` is the raw channels-last head output. Its first ``nr_mix`` entries
    on the last axis are component logits. The rest reshape to
    (..., 3, 3 * nr_mix): means, log-scales and autoregressive channel
    coefficients. Presumably ``l`` is (B, H, W, 10 * nr_mix) as produced by
    ``DmolNet.forward`` — TODO confirm for other callers.

    Returns:
        Tensor of shape (B, H, W, 3) clipped to [-1, 1]. Values are continuous
        and are not rounded to 8-bit levels.
    """
    ls = [s for s in l.shape]
    xs = ls[:-1] + [3]
    # unpack parameters
    logit_probs = l[:, :, :, :nr_mix]
    l = torch.reshape(l[:, :, :, nr_mix:], xs + [nr_mix * 3])
    # sample mixture indicator from softmax
    # (Gumbel-max trick: argmax(logits - log(-log(U))) is a categorical sample)
    eps = torch.empty(logit_probs.shape, device=l.device).uniform_(1e-5, 1. - 1e-5)
    amax = torch.argmax(logit_probs - torch.log(-torch.log(eps)), dim=3)
    sel = F.one_hot(amax, num_classes=nr_mix).float()
    sel = torch.reshape(sel, xs[:-1] + [1, nr_mix])
    # select logistic parameters
    # (multiply by the one-hot mask and sum over components to pick the chosen one)
    means = (l[:, :, :, :, :nr_mix] * sel).sum(dim=4)
    log_scales = const_max((l[:, :, :, :, nr_mix:nr_mix * 2] * sel).sum(dim=4), -7.)
    coeffs = (torch.tanh(l[:, :, :, :, nr_mix * 2:nr_mix * 3]) * sel).sum(dim=4)
    # sample from logistic & clip to interval
    # we don't actually round to the nearest 8bit value when sampling
    # (inverse-CDF sampling: mean + scale * logit(U))
    u = torch.empty(means.shape, device=means.device).uniform_(1e-5, 1. - 1e-5)
    x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
    # Channels are generated in order; channels 1 and 2 are shifted by the
    # already-sampled (clipped) earlier channels.
    x0 = const_min(const_max(x[:, :, :, 0], -1.), 1.)
    x1 = const_min(const_max(x[:, :, :, 1] + coeffs[:, :, :, 0] * x0, -1.), 1.)
    x2 = const_min(const_max(x[:, :, :, 2] + coeffs[:, :, :, 1] * x0 + coeffs[:, :, :, 2] * x1, -1.), 1.)
    return torch.cat([torch.reshape(x0, xs[:-1] + [1]), torch.reshape(x1, xs[:-1] + [1]), torch.reshape(x2, xs[:-1] + [1])], dim=3)


class DmolNet(nn.Module):
    """Output head that maps decoder features to discretised-logistic-mixture parameters.

    A 1x1 convolution produces ``num_mixtures * 10`` channels per pixel, and
    the result is returned channels-last.
    """

    def __init__(self, width, num_mixtures, low_bit=False):
        super().__init__()
        self.width = width
        self.num_mixtures = num_mixtures
        self.low_bit = low_bit
        self.out_conv = nn.Conv2d(width, num_mixtures * 10, kernel_size=1, stride=1, padding=0)

    def nll(self, px_z, x):
        """Per-example negative log-likelihood of target ``x`` given features ``px_z``."""
        params = self.forward(px_z)
        return discretized_mix_logistic_loss(x=x, l=params, low_bit=self.low_bit)

    def forward(self, px_z):
        """Apply the 1x1 conv to NCHW features and return NHWC parameters.

        A numpy array is converted to a contiguous tensor on the conv's device
        and dtype first.
        """
        if isinstance(px_z, np.ndarray):
            weight = self.out_conv.weight
            px_z = torch.from_numpy(px_z).to(device=weight.device, dtype=weight.dtype).contiguous()
        return self.out_conv(px_z).permute(0, 2, 3, 1)

    def sample(self, px_z):
        """Sample images as a uint8 numpy array in [0, 255], channels-last."""
        unit_range = sample_from_discretized_mix_logistic(self.forward(px_z), self.num_mixtures)
        pixels = ((unit_range + 1.0) * 127.5).detach().cpu().numpy()
        return np.minimum(np.maximum(0.0, pixels), 255.0).astype(np.uint8)


def log_prob_from_logits(x):
    """Log-softmax over the last dimension.

    The per-row max is subtracted before exponentiating so large logits do not overflow.
    """
    last = x.dim() - 1
    peak = x.max(dim=last, keepdim=True).values
    shifted = x - peak
    return shifted - torch.log(torch.exp(shifted).sum(dim=last, keepdim=True))


def mpi_size():
    """Number of processes in MPI.COMM_WORLD."""
    world = MPI.COMM_WORLD
    return world.Get_size()


def mpi_rank():
    """Rank of this process within MPI.COMM_WORLD."""
    world = MPI.COMM_WORLD
    return world.Get_rank()


def compute_mpi_topology():
    """Return (world_size, local_rank, global_rank) for this process.

    Assumes at most 8 processes per node and an even split of processes
    across nodes; local_rank is the global rank modulo processes-per-node.
    """
    world_size = mpi_size()
    global_rank = mpi_rank()
    num_nodes = (world_size + 7) // 8  # ceil(world_size / 8)
    gpus_per_node = max(world_size // num_nodes, 1) if world_size > 1 else 1
    local_rank = global_rank % gpus_per_node
    return world_size, local_rank, global_rank


def setup_mpi(H):
    """Record MPI topology on H, export torch.distributed env vars, and init NCCL once.

    Sets H.mpi_size, H.local_rank and H.rank. The process group is created
    (after selecting this rank's GPU) only if it is not already initialised.
    """
    H.mpi_size, H.local_rank, H.rank = compute_mpi_topology()
    os.environ.update({
        "RANK": str(H.rank),
        "WORLD_SIZE": str(H.mpi_size),
        "MASTER_PORT": str(H.port),
    })
    # os.environ["NCCL_LL_THRESHOLD"] = "0"
    # Every rank uses rank 0's hostname as the rendezvous address.
    os.environ["MASTER_ADDR"] = MPI.COMM_WORLD.bcast(socket.gethostname(), root=0)
    if dist.is_initialized():
        return
    torch.cuda.set_device(H.local_rank)
    dist.init_process_group(backend='nccl', init_method="env://")


def mkdir_p(path):
    """Create ``path`` and any missing parents; do nothing if it already exists (like ``mkdir -p``)."""
    os.makedirs(name=path, exist_ok=True)


def setup_save_dirs(H):
    """Point H.save_dir at ``<save_dir>/<desc>``, create it, and set H.logdir inside it.

    H.logdir is a file prefix (``<save_dir>/<desc>/log``), not a directory.
    """
    run_dir = os.path.join(H.save_dir, H.desc)
    H.save_dir = run_dir
    mkdir_p(run_dir)
    H.logdir = os.path.join(run_dir, 'log')


def logger(log_prefix):
    """Return a ``log(*args, pprint=False, **kwargs)`` function.

    Only MPI rank 0 writes. Each call builds a record with 'time', an optional
    'message' (the positional args joined by spaces) and the keyword args.
    The record is printed to stdout and appended to ``<log_prefix>.txt`` (as
    text) and ``<log_prefix>.jsonl`` (as JSON). numpy arrays and scalars are
    converted to plain Python values. With ``pprint=True`` keys are sorted
    and the text form is indented JSON.
    """
    jsonl_path = f'{log_prefix}.jsonl'
    txt_path = f'{log_prefix}.txt'

    def to_builtin(value):
        # Make numpy values JSON-serialisable.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        return value

    def log(*args, pprint=False, **kwargs):
        if mpi_rank() != 0:
            return
        record = {'time': time.ctime()}
        if args:
            record['message'] = ' '.join(str(a) for a in args)
        record.update(kwargs)

        pieces = []
        for key in (sorted(record) if pprint else record):
            value = to_builtin(record[key])
            record[key] = value
            shown = f'{value:.5f}' if isinstance(value, float) else value
            pieces.append(f'{key}: {shown}')
        txt_str = ', '.join(pieces)

        if pprint:
            json_str = json.dumps(record, sort_keys=True)
            txt_str = json.dumps(record, sort_keys=True, indent=4)
        else:
            json_str = json.dumps(record)
        print(txt_str, flush=True)

        for path, line in ((txt_path, txt_str), (jsonl_path, json_str)):
            with open(path, "a+") as fh:
                print(line, file=fh, flush=True)
    return log


def set_up_hyperparams(s=None):
    """Parse hyperparameters, set up MPI and save dirs, seed RNGs, and create the logger.

    Args:
        s: optional argument list passed to argparse (None means sys.argv).

    Returns:
        (H, logprint): the populated Hyperparams and the logging function.
        Every hyperparameter is logged once, in sorted key order.
    """
    H = Hyperparams()
    parser = argparse.ArgumentParser()
    parser = add_vae_arguments(parser)
    parse_args_and_update_hparams(H, parser, s=s)
    setup_mpi(H)
    setup_save_dirs(H)
    logprint = logger(H.logdir)
    for k in sorted(H):
        logprint(type='hparam', key=k, value=H[k])
    np.random.seed(H.seed)
    torch.manual_seed(H.seed)
    torch.cuda.manual_seed(H.seed)
    # Fixed "traning" typo in the log message.
    logprint('training model', H.desc, 'on', H.dataset)
    return H, logprint


def linear_warmup(warmup_iters):
    """Return an LR multiplier that rises linearly from 0 to 1 over ``warmup_iters`` steps.

    Once ``iteration >= warmup_iters`` the multiplier is 1.0. A warmup of 0
    therefore gives 1.0 immediately, with no division by zero.
    """
    def warmup_factor(iteration):
        if iteration < warmup_iters:
            return iteration / warmup_iters
        return 1.0
    return warmup_factor


def load_vaes(encoder, decoder, image_size, logprint):
    """Build the trainable VAE and a frozen EMA copy on this rank's GPU.

    The EMA model is a deep copy of the trainable model. Building a second
    ``VAE(encoder, decoder, image_size)`` reuses the same encoder/decoder
    instances; if VAE stores them as submodules (TODO confirm in VAE), both
    models share parameters and the EMA is just an alias. Freezing such an
    alias would also freeze training. A deep copy is independent in every
    case, so it can safely have ``requires_grad`` disabled.

    NOTE(review): with a genuinely independent EMA, a high ``ema_rate`` over a
    short run leaves the EMA close to the initial weights. Check ema_rate
    against max_iters.

    When more than one MPI process is running, the trainable model is
    wrapped in DistributedDataParallel. Logs the total parameter count.

    Returns:
        (vae, ema_vae)
    """
    import copy

    mpi_size, local_rank, _ = compute_mpi_topology()
    torch.cuda.set_device(local_rank)

    vae = VAE(encoder, decoder, image_size).cuda(local_rank)
    ema_vae = copy.deepcopy(vae)
    ema_vae.requires_grad_(False)

    if mpi_size > 1:
        vae = DistributedDataParallel(vae, device_ids=[local_rank], output_device=local_rank)

    # validate parameter names
    named = list(vae.named_parameters())
    if len(named) != len(list(vae.parameters())):
        raise ValueError("Some parameters are unnamed-DDP requires all params to be named")
    # numel() is a plain int (np.prod of a 0-d shape would return the float 1.0).
    total_params = sum(p.numel() for _, p in named)
    logprint(total_params=total_params, readable=f'{total_params:,}')
    return vae, ema_vae


def load_opt(H, vae, logprint):
    """Create an AdamW optimiser with a linear-warmup LambdaLR schedule.

    Uses H.lr, H.wd, H.adam_beta1, H.adam_beta2 and H.warmup_iters.

    Returns:
        (optimizer, scheduler, starting_epoch=0, iterate=0, cur_eval_loss=inf)
    """
    betas = (H.adam_beta1, H.adam_beta2)
    optimizer = AdamW(vae.parameters(), weight_decay=H.wd, lr=H.lr, betas=betas)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup(H.warmup_iters))

    starting_epoch, iterate, cur_eval_loss = 0, 0, float('inf')
    logprint('optimizer & scheduler initialized', epoch=starting_epoch, iterate=iterate, eval_loss=cur_eval_loss)
    return optimizer, scheduler, starting_epoch, iterate, cur_eval_loss


def allreduce(x, average):
    """Sum ``x`` in place across ranks (when there is more than one).

    Returns the sum, or the mean over ranks when ``average`` is true.
    """
    world = mpi_size()
    if world > 1:
        dist.all_reduce(x, dist.ReduceOp.SUM)
    if average:
        return x / world
    return x


def get_cpu_stats_over_ranks(stat_dict):
    """Average each scalar statistic across ranks and return them as Python floats.

    Values are stacked (in sorted key order) into one float32 CUDA tensor so
    a single all-reduce covers every statistic.
    """
    keys = sorted(stat_dict)
    stacked = torch.stack([torch.as_tensor(stat_dict[k]).detach().cuda().float() for k in keys])
    averaged = allreduce(stacked, average=True).cpu()
    return {k: averaged[i].item() for i, k in enumerate(keys)}


def save_model(path, vae, ema_vae, optimizer, H):
    """Save model, EMA and optimizer state dicts under prefix ``path``, and copy the log.

    Writes ``<path>-model.th``, ``<path>-model-ema.th``, ``<path>-opt.th``,
    and a copy of ``<H.save_dir>/log.jsonl`` as ``<path>-log.jsonl``.

    Raises:
        FileNotFoundError: if ``<H.save_dir>/log.jsonl`` does not exist.
    """
    import shutil

    torch.save(vae.state_dict(), f'{path}-model.th')
    torch.save(ema_vae.state_dict(), f'{path}-model-ema.th')
    torch.save(optimizer.state_dict(), f'{path}-opt.th')
    from_log = os.path.join(H.save_dir, 'log.jsonl')
    # Use the same prefix as the other files. Rebuilding it as
    # f'{dirname}/{basename}' turned a bare prefix such as 'latest' into
    # '/latest-log.jsonl' at the filesystem root. shutil also avoids
    # shelling out to `cp`.
    shutil.copyfile(from_log, f'{path}-log.jsonl')


def accumulate_stats(stats, frequency):
    """Summarise the last ``frequency`` stat dicts, using the keys of the newest one.

    Count-like keys are summed. 'grad_norm' is the max finite value (0.0 if
    none). 'elbo' produces both the raw mean and 'elbo_filtered' (mean over
    finite values). 'iter_time' is the latest value until ``frequency``
    entries exist, then the mean. Every other key is averaged.
    """
    window = stats[-frequency:]
    summed_keys = ('distortion_nans', 'rate_nans', 'skipped_updates', 'gcskip')

    def window_values(key):
        return [entry[key] for entry in window]

    z = {}
    for key in stats[-1]:
        if key in summed_keys:
            z[key] = np.sum(window_values(key))
        elif key == 'grad_norm':
            values = window_values(key)
            finite = np.array(values)[np.isfinite(values)]
            z[key] = np.max(finite) if len(finite) else 0.0
        elif key == 'elbo':
            values = window_values(key)
            finite = np.array(values)[np.isfinite(values)]
            z['elbo'] = np.mean(values)
            z['elbo_filtered'] = np.mean(finite)
        elif key == 'iter_time':
            z[key] = stats[-1][key] if len(stats) < frequency else np.mean(window_values(key))
        else:
            z[key] = np.mean(window_values(key))
    return z


# Block.py

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    """Bottleneck conv unit: 1x1 -> 3x3 -> 3x3 -> 1x1, each conv preceded by GELU.

    Args:
        in_ch, mid_ch, out_ch: input, bottleneck and output channel counts.
        downsample: if true, apply 2x2 average pooling to the output.
        residual: if true, add the input to the conv output (needs in_ch == out_ch).
        zero_last: if true, zero-initialise the last conv. A residual block
            then starts as the identity, and a non-residual block outputs zeros.
    """

    def __init__(self, in_ch, mid_ch, out_ch, downsample=False, residual=True, zero_last=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=1)
        self.conv2 = nn.Conv2d(mid_ch, mid_ch, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(mid_ch, mid_ch, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(mid_ch, out_ch, kernel_size=1)

        if zero_last:
            nn.init.zeros_(self.conv4.weight)
            if self.conv4.bias is not None:
                nn.init.zeros_(self.conv4.bias)

        self.use_residual = residual
        self.use_downsample = downsample
        self.down = nn.AvgPool2d(kernel_size=2) if downsample else nn.Identity()

    def forward(self, x):
        h = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            h = conv(F.gelu(h))
        if self.use_residual:
            h = h + x
        return self.down(h)

# Encoder.py

In [8]:
class Encoder(nn.Module):
    """Bottom-up network producing per-resolution activations for the decoder.

    ``block_str`` is parsed with ``parse_layer_string``: each (res, down_rate)
    entry becomes one width-preserving Block at that resolution. A non-None
    ``down_rate`` average-pools the block output by that factor.

    ``forward`` takes channels-last images (B, H, W, C) and returns a dict
    mapping spatial size -> the last activation at that size. Keys are the
    *actual* spatial sizes after pooling, so they line up with the
    decoder's resolutions.

    Fix: the block's declared resolution was used as the key even after it
    downsampled, and Block always pooled 2x regardless of ``down_rate``
    (e.g. "4d4"). Every resolution therefore received the activation one
    level coarser than expected (and there was no 1x1 activation).
    """

    def __init__(self, image_channels, base_width, custom_width_str, block_str, bottleneck_multiple):
        super().__init__()

        self.in_conv = nn.Conv2d(image_channels, base_width, kernel_size=3, padding=1)
        self.widths = get_width_settings(base_width, custom_width_str)
        block_config = parse_layer_string(block_str)

        enc_blocks = []
        for res, _ in block_config:
            width = self.widths[res]
            mid_width = int(width * bottleneck_multiple)
            # Every block keeps in_ch == out_ch == width. Downsampling is done in
            # forward() with the configured rate, because Block's built-in
            # pooling is fixed at 2x.
            enc_blocks.append(Block(
                in_ch=width,
                mid_ch=mid_width,
                out_ch=width,
                downsample=False,
                residual=True,
            ))

        self.enc_blocks = nn.ModuleList(enc_blocks)
        self.block_resolutions = [res for res, _ in block_config]
        self.block_down_rates = [down_rate for _, down_rate in block_config]

    def forward(self, x):
        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.in_conv(x)

        feats = {x.shape[2]: x}  # activation at the input resolution
        for block, res, down_rate in zip(self.enc_blocks, self.block_resolutions, self.block_down_rates):
            # Match the block's expected channel count before applying it.
            if x.shape[1] != self.widths[res]:
                x = pad_channels(x, self.widths[res])
            x = block(x)
            if down_rate is not None:
                x = F.avg_pool2d(x, kernel_size=down_rate, stride=down_rate)
            # Key by the real spatial size so later blocks overwrite earlier ones
            # only when they share a resolution.
            feats[x.shape[2]] = x

        return feats

# CrossAttention.py

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Multi-head cross-attention between two feature maps.

    Queries come from ``x_q`` (B, C, H, W) and keys/values from ``x_kv``
    (B, C, H', W'), with C == ``dim_qk``. The two maps may have different
    spatial sizes. The output has the shape of ``x_q``.

    Fix: keys/values used to be reshaped with the query's token count, so an
    ``x_kv`` with a different spatial size raised a shape error.
    """

    def __init__(self, dim_qk, heads=4):
        super().__init__()
        if dim_qk % heads != 0:
            raise ValueError(f"dim_qk ({dim_qk}) must be divisible by heads ({heads})")
        self.dim = dim_qk
        self.heads = heads
        self.scale = (dim_qk // heads) ** -0.5  # 1 / sqrt(per-head dim)

        self.to_q = nn.Linear(dim_qk, dim_qk)
        self.to_k = nn.Linear(dim_qk, dim_qk)
        self.to_v = nn.Linear(dim_qk, dim_qk)
        self.to_out = nn.Linear(dim_qk, dim_qk)

    def forward(self, x_q, x_kv):
        B, C, H, W = x_q.shape
        head_dim = C // self.heads

        q_tokens = x_q.flatten(2).permute(0, 2, 1)    # (B, N_q, C)
        kv_tokens = x_kv.flatten(2).permute(0, 2, 1)  # (B, N_kv, C)

        # (B, heads, N, head_dim); -1 lets queries and keys have different lengths.
        q = self.to_q(q_tokens).reshape(B, -1, self.heads, head_dim).transpose(1, 2)
        k = self.to_k(kv_tokens).reshape(B, -1, self.heads, head_dim).transpose(1, 2)
        v = self.to_v(kv_tokens).reshape(B, -1, self.heads, head_dim).transpose(1, 2)

        attn = torch.softmax(torch.matmul(q, k.transpose(-1, -2)) * self.scale, dim=-1)
        out = torch.matmul(attn, v)                               # (B, heads, N_q, head_dim)
        out = out.transpose(1, 2).reshape(B, H * W, C)            # (B, N_q, C)
        return self.to_out(out).permute(0, 2, 1).reshape(B, C, H, W)

# DecBlock.py

In [13]:
class DecBlock(nn.Module):
    """One top-down decoder stage at spatial resolution ``res``.

    Each stage:
    1. optionally cross-attends to the state at ``mixin_res``;
    2. draws a latent z (from q(z | x, acts) in ``forward``, from the prior in
       ``forward_uncond``) and adds the prior block's residual to x;
    3. adds a 1x1 projection of z and applies a residual Block;
    4. writes the result back to ``xs[res]``.

    Fix: ``forward`` used to print a debug line with probability 1% on every
    call. That consumed the global RNG (changing sampling under a fixed seed)
    and claimed cross-attention was applied even when ``mixin`` was None.
    """

    def __init__(self, res, width, zdim, bottleneck_multiple, mixin_res=None, n_blocks=1):
        super().__init__()
        self.res = res
        self.width = width
        self.zdim = zdim
        self.mixin = mixin_res

        cond_width = int(width * bottleneck_multiple)

        # Approximate-posterior parameters (qm, qv) from [decoder state, encoder activation].
        self.enc = Block(width * 2, cond_width, zdim * 2, residual=False)
        # Prior parameters (pm, pv) plus a residual update for x. The last conv is
        # zero-initialised, so the prior starts at mean 0 / log-std 0 with no update.
        self.prior = Block(width, cond_width, zdim * 2 + width, residual=False, zero_last=True)

        # Down-scale init by sqrt(1 / n_blocks) to keep the residual stream's variance in check.
        self.z_proj = nn.Conv2d(zdim, width, kernel_size=1)
        self.z_proj.weight.data *= np.sqrt(1 / n_blocks)

        self.resnet = Block(width, cond_width, width, residual=True)
        self.resnet.conv4.weight.data *= np.sqrt(1 / n_blocks)

        # Cross-attention to the mixin resolution, if configured.
        self.cross_attn = CrossAttention(dim_qk=width) if mixin_res is not None else None

    def z_fn(self, z):
        """Project a latent sample to ``width`` channels."""
        return self.z_proj(z)

    def get_inputs(self, xs, activations):
        """Return (decoder state, encoder activation) for this resolution.

        The activation is resized to the state's spatial size if they differ,
        and the state is tiled over the batch if its batch dim differs.
        """
        acts = activations[self.res]
        x = xs.get(self.res, torch.zeros_like(acts))

        # Safeguard: interpolate acts if the spatial shapes mismatch.
        if acts.shape[2:] != x.shape[2:]:
            acts = F.interpolate(acts, size=x.shape[2:], mode='nearest')

        if acts.shape[0] != x.shape[0]:
            x = x.repeat(acts.shape[0], 1, 1, 1)

        return x, acts

    def _mix_in(self, x, xs):
        """Add cross-attention to the (resized, channel-truncated) mixin state, if any."""
        if self.mixin is None:
            return x
        mix = xs[self.mixin][:, :x.shape[1], ...]
        mix = F.interpolate(mix, size=x.shape[2:], mode='nearest')  # match resolution
        return x + self.cross_attn(x, mix)

    def sample(self, x, acts):
        """Sample z from q(z | x, acts); return (z, x updated by the prior residual, KL(q || p))."""
        qm, qv = self.enc(torch.cat([x, acts], dim=1)).chunk(2, dim=1)
        feats = self.prior(x)
        pm, pv, xpp = feats[:, :self.zdim], feats[:, self.zdim:self.zdim*2], feats[:, self.zdim*2:]
        x = x + xpp
        z = draw_gaussian_diag_samples(qm, qv)
        kl = gaussian_analytical_kl(qm, pm, qv, pv)
        return z, x, kl

    def sample_uncond(self, x, t=None, lvs=None):
        """Sample z from the prior (or use ``lvs`` if given); return (z, updated x).

        A temperature ``t`` adds log(t) to the prior log-scale, i.e. multiplies
        its std by t.
        """
        feats = self.prior(x)
        pm, pv, xpp = feats[:, :self.zdim], feats[:, self.zdim:self.zdim*2], feats[:, self.zdim*2:]
        x = x + xpp
        if lvs is not None:
            z = lvs
        else:
            if t is not None:
                pv = pv + torch.ones_like(pv) * np.log(t)
            z = draw_gaussian_diag_samples(pm, pv)
        return z, x

    def forward(self, xs, activations, get_latents=False):
        """Posterior pass. Returns (xs, {'kl': ...}), plus a detached 'z' when get_latents is set."""
        x, acts = self.get_inputs(xs, activations)
        x = self._mix_in(x, xs)

        z, x, kl = self.sample(x, acts)
        x = x + self.z_fn(z)
        x = self.resnet(x)
        xs[self.res] = x

        if get_latents:
            return xs, dict(z=z.detach(), kl=kl)
        return xs, dict(kl=kl)

    def forward_uncond(self, xs, t=None, lvs=None):
        """Prior (generation) pass; returns xs with ``xs[res]`` updated."""
        if self.res in xs:
            x = xs[self.res]
        else:
            ref = xs[list(xs.keys())[0]]
            x = torch.zeros(ref.shape[0], self.width, self.res, self.res, device=ref.device)

        x = self._mix_in(x, xs)

        z, x = self.sample_uncond(x, t, lvs=lvs)
        x = x + self.z_fn(z)
        x = self.resnet(x)
        xs[self.res] = x
        return xs

# Decoder.py

In [12]:
import itertools

class Decoder(nn.Module):
    """Top-down decoder: a stack of ``DecBlock``s plus a ``DmolNet`` output head.

    Args:
        res_list: sequence of ``(res, mixin)`` pairs, one ``DecBlock`` per pair,
            in execution order. ``res`` is the block's resolution; ``mixin`` is
            passed to the block as ``mixin_res`` (None for no mixing).
        width_map: mapping resolution -> channel width (indexed by every ``res``
            and by ``output_res``).
        zdim: latent channels, forwarded to every block.
        bottleneck_multiple: forwarded to every block.
        output_res: resolution whose final state is fed to the output head.
        n_blocks: forwarded to every block.
        num_mixtures: forwarded to ``DmolNet``.
        low_bit: forwarded to ``DmolNet``.
    """

    def __init__(self, res_list, width_map, zdim, bottleneck_multiple,
                 output_res, n_blocks, num_mixtures, low_bit=False):
        super().__init__()
        self.output_res = output_res
        self.width_map = width_map

        self.blocks = nn.ModuleList([
            DecBlock(
                res=res,
                width=width_map[res],
                zdim=zdim,
                bottleneck_multiple=bottleneck_multiple,
                mixin_res=mixin,
                n_blocks=n_blocks,
            )
            for res, mixin in res_list
        ])

        # One learned starting state per resolution. ParameterDict keys must be
        # strings, so resolutions are stored as str and turned back into int
        # in _init_xs.
        self.bias_xs = nn.ParameterDict({
            str(res): nn.Parameter(torch.zeros(1, width_map[res], res, res))
            for res, _ in res_list
        })

        out_width = width_map[output_res]
        # Learned per-channel affine applied to the output-resolution state.
        self.gain = nn.Parameter(torch.ones(1, out_width, 1, 1))
        self.bias = nn.Parameter(torch.zeros(1, out_width, 1, 1))

        # Output distribution head (discretised mixture); used for sampling
        # below and for ``out_net.nll`` by the VAE.
        self.out_net = DmolNet(width=out_width, num_mixtures=num_mixtures, low_bit=low_bit)

    def final_fn(self, x):
        """Apply the learned per-channel gain and bias."""
        return x * self.gain + self.bias

    def _init_xs(self, n):
        """Return a fresh ``{res: learned bias repeated to batch size n}`` dict."""
        return {
            int(res): bias.repeat(n, 1, 1, 1)
            for res, bias in self.bias_xs.items()
        }

    @staticmethod
    def _block_temperature(t, idx):
        """Pick block ``idx``'s temperature: ``t[idx]`` if ``t`` is a list, else ``t``."""
        return t[idx] if isinstance(t, list) else t

    def forward(self, activations, get_latents=False):
        """Decode using encoder activations (training / reconstruction path).

        Args:
            activations: dict of encoder activations; the batch size is read
                from its first value.
            get_latents: forwarded to every block (adds ``z`` to its stats).

        Returns:
            ``(out, stats)``. ``out`` is the affine-transformed output-resolution
            state and has NOT been passed through ``self.out_net``: callers apply
            ``out_net.nll`` / ``out_net.sample`` themselves. ``stats`` holds one
            dict per block.
        """
        batch = next(iter(activations.values())).shape[0]
        xs = self._init_xs(batch)

        stats = []
        for block in self.blocks:
            xs, block_stat = block(xs, activations, get_latents=get_latents)
            stats.append(block_stat)

        out = self.final_fn(xs[self.output_res])
        return out, stats

    def forward_uncond(self, n, t=None):
        """Sample ``n`` images from the prior.

        ``t`` is a scalar temperature for every block, or a list with one
        temperature per block.
        """
        xs = self._init_xs(n)
        for idx, block in enumerate(self.blocks):
            xs = block.forward_uncond(xs, t=self._block_temperature(t, idx))

        out = self.final_fn(xs[self.output_res])
        return self.out_net.sample(out)

    def forward_manual_latents(self, n, latents, t=None):
        """Sample ``n`` images with the first ``len(latents)`` latents fixed.

        Blocks past the end of ``latents`` sample their own latent
        (``lvs=None``). ``t`` is a scalar or, as in ``forward_uncond``, a list
        with one temperature per block.

        Raises:
            ValueError: if more latents than decoder blocks are supplied.
        """
        latents = list(latents)
        if len(latents) > len(self.blocks):
            raise ValueError(
                f"forward_manual_latents: got {len(latents)} latents for "
                f"{len(self.blocks)} decoder blocks")

        xs = self._init_xs(n)
        # zip_longest pads the missing tail of ``latents`` with None.
        for idx, (block, lvs) in enumerate(itertools.zip_longest(self.blocks, latents)):
            xs = block.forward_uncond(xs, t=self._block_temperature(t, idx), lvs=lvs)

        out = self.final_fn(xs[self.output_res])
        return self.out_net.sample(out)


# VAE

In [14]:
class VAE(nn.Module):
    """Encoder + decoder pair that scores a batch with a per-pixel ELBO.

    Args:
        encoder: module mapping an input batch to the activations the decoder
            consumes.
        decoder: module returning ``(px_z, stats)`` and exposing ``out_net``
            (with ``nll``), ``forward_uncond`` and ``forward_manual_latents``.
        image_size: stored as ``self.image_size`` (e.g. 32 for CIFAR); not
            otherwise used by this class.
    """

    def __init__(self, encoder, decoder, image_size):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.image_size = image_size

    def forward(self, x, x_target):
        """Return ``{'elbo', 'distortion', 'rate'}``, each averaged over the batch.

        distortion: per-example ``out_net.nll`` of ``x_target``.
        rate: per-example KL summed over every block and over its dims 1-3,
        divided by the number of elements in one example of ``x``.
        """
        feats = self.encoder(x)
        px_z, block_stats = self.decoder(feats, get_latents=True)

        distortion = self.decoder.out_net.nll(px_z, x_target)

        # Accumulate each block's KL into a per-example total.
        rate = torch.zeros_like(distortion)
        for block_stat in block_stats:
            rate += block_stat['kl'].sum(dim=(1, 2, 3))
        per_example_dims = np.prod(x.shape[1:])
        rate /= per_example_dims

        return {
            'elbo': (distortion + rate).mean(),
            'distortion': distortion.mean(),
            'rate': rate.mean(),
        }

    def forward_get_latents(self, x):
        """Return the decoder's per-block stats (each including the latent ``z``)."""
        _, latent_stats = self.decoder(self.encoder(x), get_latents=True)
        return latent_stats

    def forward_uncond_samples(self, n_batch, t=None):
        """Draw ``n_batch`` unconditional samples at temperature ``t``."""
        return self.decoder.forward_uncond(n_batch, t=t)

    def forward_samples_set_latents(self, n_batch, latents, t=None):
        """Draw ``n_batch`` samples with the leading block latents fixed to ``latents``."""
        return self.decoder.forward_manual_latents(n_batch, latents, t=t)

# Train.py

In [15]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import imageio

In [16]:
!wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
!tar -xf cifar-10-python.tar.gz

--2025-06-08 13:42:26--  https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2025-06-08 13:42:42 (11.4 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]



In [17]:
def training_step(H, data_input, target, vae, ema_vae, optimizer, iterate):
    """Run one optimisation step; on an accepted step also update the EMA weights.

    The update is skipped (no ``optimizer.step()``, no EMA update) if any rank
    reported a NaN distortion or rate. It is also skipped if the gradient norm
    returned by ``clip_grad_norm_`` is not below ``H.skip_threshold``; a
    threshold of -1 disables that check.

    Args:
        H: hyperparameters (reads ``grad_clip``, ``skip_threshold``, ``ema_rate``).
        data_input, target: batch passed to ``vae.forward``.
        vae: model being trained.
        ema_vae: EMA copy, updated as ``ema = ema * rate + p * (1 - rate)``.
        optimizer: optimizer over ``vae``'s parameters.
        iterate: current iteration (not used here).

    Returns:
        Stats from ``get_cpu_stats_over_ranks`` plus ``skipped_updates``,
        ``iter_time`` and ``grad_norm``.
    """
    started_at = time.time()
    vae.zero_grad()
    stats = vae.forward(data_input, target)

    stats['elbo'].backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(vae.parameters(), H.grad_clip).item()

    # Reduce NaN counts to 0/1 flags so they can be aggregated across ranks.
    distortion_nan_count = torch.isnan(stats['distortion']).sum()
    rate_nan_count = torch.isnan(stats['rate']).sum()
    stats.update(
        rate_nans=int(bool(rate_nan_count != 0)),
        distortion_nans=int(bool(distortion_nan_count != 0)),
    )
    stats = get_cpu_stats_over_ranks(stats)

    no_nans = stats['distortion_nans'] == 0 and stats['rate_nans'] == 0
    grad_ok = H.skip_threshold == -1 or grad_norm < H.skip_threshold
    if no_nans and grad_ok:
        optimizer.step()
        skipped = 0
        for live_param, ema_param in zip(vae.parameters(), ema_vae.parameters()):
            ema_param.data.mul_(H.ema_rate)
            ema_param.data.add_(live_param.data * (1 - H.ema_rate))
    else:
        skipped = 1

    stats.update(skipped_updates=skipped, iter_time=time.time() - started_at, grad_norm=grad_norm)
    return stats


def eval_step(data_input, target, ema_vae):
    """Score one batch with the EMA model (autograd disabled).

    Returns:
        The model's stats dict after ``get_cpu_stats_over_ranks``.
    """
    with torch.no_grad():
        batch_stats = ema_vae.forward(data_input, target)
    return get_cpu_stats_over_ranks(batch_stats)


def get_sample_for_visualization(data, preprocess_func, batch_size):
    """Return the first batch of ``data`` in raw and preprocessed form.

    Args:
        data: dataset; each batch yielded by ``DataLoader`` is indexable and
            element 0 is taken as the raw images.
        preprocess_func: callable applied to the whole batch; element 0 of its
            result is returned as the preprocessed images.
        batch_size: number of examples to take (no shuffling).

    Returns:
        ``(orig_image, preprocessed)``.

    Raises:
        ValueError: if ``data`` yields no batch at all (the loop-and-break
            form this replaces crashed with UnboundLocalError instead).
    """
    try:
        first_batch = next(iter(DataLoader(data, batch_size=batch_size)))
    except StopIteration:
        raise ValueError("get_sample_for_visualization: dataset is empty") from None
    orig_image = first_batch[0]
    preprocessed = preprocess_func(first_batch)[0]
    return orig_image, preprocessed


def train_loop(H, data_train, data_valid, preprocess_func, vae, ema_vae,
               optimizer, scheduler, starting_epoch, iterate, cur_eval_loss, logprint):
    """Train, log, write sample grids, checkpoint and periodically evaluate.

    Runs epochs ``starting_epoch .. H.num_epochs - 1``. It returns early as soon
    as ``iterate`` reaches ``H.max_iters`` when that is positive.

    Side effects: replaces ``H.ema_rate`` with a tensor on the EMA model's
    device. Rank 0 writes sample grids and checkpoints under ``H.save_dir``.
    ``cur_eval_loss`` is accepted for interface compatibility but unused.
    """
    train_sampler = DistributedSampler(data_train, num_replicas=H.mpi_size, rank=H.rank)
    viz_batch_original, viz_batch_processed = get_sample_for_visualization(
        data_valid, preprocess_func, H.num_images_visualize)
    # Extra logging / image dumps at iterations 1, 8, 16, ..., 8192 after (re)start.
    early_evals = {1} | {2 ** exp for exp in range(3, 14)}
    stats = []
    iters_since_starting = 0
    # The EMA decay multiplies the EMA weights in training_step, so it must live
    # on their device. This also works for CPU-only runs (main falls back to CPU).
    ema_device = next(ema_vae.parameters()).device
    H.ema_rate = torch.as_tensor(H.ema_rate, device=ema_device)

    for epoch in range(starting_epoch, H.num_epochs):
        train_sampler.set_epoch(epoch)
        loader = DataLoader(data_train, batch_size=H.n_batch, drop_last=True,
                            pin_memory=True, sampler=train_sampler)
        for x in loader:
            if H.max_iters > 0 and iterate >= H.max_iters:
                logprint(f"Reached max_iters={H.max_iters}, stopping training.")
                return
            data_input, target = preprocess_func(x)
            training_stats = training_step(H, data_input, target, vae, ema_vae, optimizer, iterate)
            stats.append(training_stats)
            scheduler.step()

            is_early = iters_since_starting in early_evals
            if iterate % H.iters_per_print == 0 or is_early:
                logprint(model=H.desc, type='train_loss', lr=scheduler.get_last_lr()[0],
                         epoch=epoch, step=iterate, **accumulate_stats(stats, H.iters_per_print))

            # Both triggers are grouped so the rank-0 guard applies to each of
            # them. ``and`` binds tighter than ``or``, so without the parentheses
            # the periodic trigger would fire on every rank.
            if (iterate % H.iters_per_images == 0
                    or (is_early and H.dataset != 'ffhq_1024')) and H.rank == 0:
                write_images(H, ema_vae, viz_batch_original, viz_batch_processed,
                             f'{H.save_dir}/samples-{iterate}.png', logprint)

            iterate += 1
            iters_since_starting += 1

            # Periodically refresh the "latest" checkpoint (only if the last ELBO was finite).
            if iterate % H.iters_per_save == 0 and H.rank == 0:
                if np.isfinite(stats[-1]['elbo']):
                    logprint(model=H.desc, type='train_loss', epoch=epoch, step=iterate,
                             **accumulate_stats(stats, H.iters_per_print))
                    fp = os.path.join(H.save_dir, 'latest')
                    logprint(f'Saving model@ {iterate} to {fp}')
                    save_model(fp, vae, ema_vae, optimizer, H)

            if iterate % H.iters_per_ckpt == 0 and H.rank == 0:
                save_model(os.path.join(H.save_dir, f'iter-{iterate}'), vae, ema_vae, optimizer, H)

        if epoch % H.epochs_per_eval == 0:
            valid_stats = evaluate(H, ema_vae, data_valid, preprocess_func)
            logprint(model=H.desc, type='eval_loss', epoch=epoch, step=iterate, **valid_stats)


def evaluate(H, ema_vae, data_valid, preprocess_func):
    """Average EMA-model stats over this rank's shard of ``data_valid``.

    Returns:
        dict with ``n_batches``, ``filtered_elbo`` (mean over the finite
        per-batch ELBOs, nan if none are finite) and the per-batch mean of
        every key reported by ``eval_step``.

    Raises:
        ValueError: if this rank gets no complete batch (``drop_last=True``).
    """
    stats_valid = []
    valid_sampler = DistributedSampler(data_valid, num_replicas=H.mpi_size, rank=H.rank)
    loader = DataLoader(data_valid, batch_size=H.n_batch, drop_last=True,
                        pin_memory=True, sampler=valid_sampler)
    for x in loader:
        data_input, target = preprocess_func(x)
        stats_valid.append(eval_step(data_input, target, ema_vae))

    if not stats_valid:
        raise ValueError(
            f"evaluate: no complete batch of size {H.n_batch} on rank {H.rank}; "
            "the evaluation set is too small for drop_last=True")

    vals = [a['elbo'] for a in stats_valid]
    finites = np.array(vals)[np.isfinite(vals)]
    # Explicit nan instead of numpy's empty-mean RuntimeWarning when every batch diverged.
    filtered_elbo = np.mean(finites) if finites.size else np.nan
    stats = dict(n_batches=len(vals), filtered_elbo=filtered_elbo,
                 **{k: np.mean([a[k] for a in stats_valid]) for k in stats_valid[-1]})
    return stats


def write_images(H, ema_vae, viz_batch_original, viz_batch_processed, fname, logprint):
    """Write a grid of reconstructions and samples to ``fname``.

    Rows, top to bottom:
      * ``viz_batch_original``;
      * one row per cut point ``i``: the first ``i`` block latents are taken
        from the real images and the rest are sampled, all at t=0.1;
      * one row per temperature in ``[1.0, 0.9, 0.8, 0.7]`` (the first
        ``H.num_temperatures_visualize``): fully unconditional samples.

    NOTE(review): the final reshape assumes ``viz_batch_processed`` is
    channels-last ``(N, H, W, 3)`` and that the decoder samples come back as
    numpy arrays of the same layout — confirm against preprocess/DmolNet.
    """
    # Pure inference: no autograd graph is needed for any of these forwards.
    with torch.no_grad():
        # Latents stay on the device the model produced them on (no hard-coded
        # .cuda()), so this also works for CPU runs.
        zs = [s['z'] for s in ema_vae.forward_get_latents(viz_batch_processed)]
        batches = [viz_batch_original.numpy()]
        mb = viz_batch_processed.shape[0]
        # H.num_variables_visualize evenly spaced cut points strictly inside (0, len(zs)).
        lv_points = np.floor(np.linspace(0, 1, H.num_variables_visualize + 2) * len(zs)).astype(int)[1:-1]
        for i in lv_points:
            batches.append(ema_vae.decoder.forward_manual_latents(mb, zs[:i], t=0.1))
        for t in [1.0, 0.9, 0.8, 0.7][:H.num_temperatures_visualize]:
            batches.append(ema_vae.decoder.forward_uncond(mb, t=t))

    n_rows = len(batches)
    img_h, img_w = viz_batch_processed.shape[1], viz_batch_processed.shape[2]
    # (rows*mb, H, W, C) -> (rows, H, mb, W, C) -> one (rows*H, mb*W, 3) image.
    im = (np.concatenate(batches, axis=0)
          .reshape((n_rows, mb, *viz_batch_processed.shape[1:]))
          .transpose([0, 2, 1, 3, 4])
          .reshape([n_rows * img_h, mb * img_w, 3])
          .astype(np.uint8))  # imageio expects uint8 pixels
    logprint(f'printing samples to {fname}')
    imageio.imwrite(fname, im)


def run_test_eval(H, ema_vae, data_test, preprocess_func, logprint):
    """Evaluate ``ema_vae`` on ``data_test``, print every metric, then log them."""
    print('evaluating')
    results = evaluate(H, ema_vae, data_test, preprocess_func)
    print('test results')
    for metric_name, metric_value in results.items():
        print(metric_name, metric_value)
    logprint(type='test_loss', **results)


def main():
    """Build the VAE, optionally restore checkpoints, then train or run test eval.

    Model-structure settings come from the module-level constants
    (``image_channels``, ``base_width``, ``block_str``, ``res_list``, ...).
    Everything else comes from ``set_up_hyperparams``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build Encoder / Decoder (turned into the vae / ema_vae pair by load_vaes).
    encoder = Encoder(
        image_channels=image_channels,
        base_width=base_width,
        custom_width_str=custom_width_str,
        block_str=block_str,
        bottleneck_multiple=bottleneck_multiple,
    )
    decoder = Decoder(
        res_list=res_list,
        width_map=encoder.widths,
        zdim=zdim,
        bottleneck_multiple=bottleneck_multiple,
        output_res=image_size,
        n_blocks=n_blocks,
        num_mixtures=num_mixtures,
        low_bit=False,
    )

    H, logprint = set_up_hyperparams(s=[])
    H.device = device
    H, data_train, data_valid_or_test, preprocess_func = set_up_data(H)
    vae, ema_vae = load_vaes(encoder, decoder, image_size, logprint)

    # Resuming from a checkpoint: restore the model weights, and the EMA weights if given.
    if H.restore_path is not None:
        model_ckpt = f"{H.restore_path}-model.th"
        vae.load_state_dict(torch.load(model_ckpt, map_location=H.device))
        if H.restore_ema_path is not None:
            ema_ckpt = f"{H.restore_ema_path}-model-ema.th"
            ema_vae.load_state_dict(torch.load(ema_ckpt, map_location=H.device))
            print(f">> Loaded model from {H.restore_path} and EMA from {H.restore_ema_path}.")
        else:
            # No EMA checkpoint: seed the EMA with the restored weights rather
            # than whatever load_vaes produced (presumably a fresh init), since
            # evaluation and sampling use ema_vae.
            ema_vae.load_state_dict(vae.state_dict())
            print(f">> Loaded model from {H.restore_path}; EMA initialised from it.")

    # Move both models on every path, including restore_path without restore_ema_path.
    vae = vae.to(device)
    ema_vae = ema_vae.to(device)

    # Create the optimizer / scheduler, then optionally restore optimizer state.
    optimizer, scheduler, starting_epoch, iterate, cur_eval_loss = load_opt(H, vae, logprint)
    if H.restore_optimizer_path is not None:
        optimizer_ckpt = torch.load(H.restore_optimizer_path, map_location=H.device)
        optimizer.load_state_dict(optimizer_ckpt)
        print(f">> Loaded optimizer state from {H.restore_optimizer_path}")

    # Either evaluate on the held-out split or run the training loop.
    if H.test_eval:
        run_test_eval(H, ema_vae, data_valid_or_test, preprocess_func, logprint)
    else:
        train_loop(H, data_train, data_valid_or_test, preprocess_func, vae, ema_vae,
                   optimizer, scheduler, starting_epoch, iterate, cur_eval_loss, logprint)


# Entry point: run main() when executed as a script. In a notebook kernel
# __name__ is also "__main__", so running this cell calls main() as well.
if __name__ == "__main__":
    main()

time: Sun Jun  8 13:43:18 2025, type: hparam, key: adam_beta1, value: 0.90000
time: Sun Jun  8 13:43:18 2025, type: hparam, key: adam_beta2, value: 0.90000
time: Sun Jun  8 13:43:18 2025, type: hparam, key: data_root, value: ../content
time: Sun Jun  8 13:43:18 2025, type: hparam, key: dataset, value: cifar10
time: Sun Jun  8 13:43:18 2025, type: hparam, key: desc, value: test
time: Sun Jun  8 13:43:18 2025, type: hparam, key: ema_rate, value: 0.99980
time: Sun Jun  8 13:43:18 2025, type: hparam, key: epochs_per_eval, value: 10
time: Sun Jun  8 13:43:18 2025, type: hparam, key: epochs_per_eval_save, value: 20
time: Sun Jun  8 13:43:18 2025, type: hparam, key: epochs_per_probe, value: None
time: Sun Jun  8 13:43:18 2025, type: hparam, key: grad_clip, value: 200.00000
time: Sun Jun  8 13:43:18 2025, type: hparam, key: iters_per_ckpt, value: 25000
time: Sun Jun  8 13:43:18 2025, type: hparam, key: iters_per_images, value: 10000
time: Sun Jun  8 13:43:18 2025, type: hparam, key: iters_per_