In [1]:
import os
import sys

In [2]:
# Make the project sources and the guided-diffusion checkout importable.
paths = ['../src/', '../guided-diffusion/']
for path in paths:
    resolved = os.path.abspath(path)
    if resolved in sys.path:
        # Already registered (e.g. the cell was re-run); nothing to do.
        continue
    # Prepend so this directory is searched before the existing entries.
    sys.path.insert(0, resolved)
    print(resolved)

/nfs/data/andrewbai/Natural-Disaster-Image-Generation-to-raise-Environmental-Awareness/src
/nfs/data/andrewbai/Natural-Disaster-Image-Generation-to-raise-Environmental-Awareness/guided-diffusion


In [3]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [4]:
import argparse
from importlib import reload
import numpy as np
import functools

import torch
import torchvision
import torchvision.transforms as transforms

In [6]:
from glide_text2im.gaussian_diffusion import GaussianDiffusion
from glide_text2im.model_creation import model_and_diffusion_defaults, create_model
from glide_text2im.gaussian_diffusion import get_named_beta_schedule
from glide_text2im.respace import space_timesteps
from glide_text2im.nn import timestep_embedding, zero_module
from glide_text2im.text2im_model import (
    Text2ImUNet, 
    SuperResInpaintText2ImUnet,
    InpaintText2ImUNet,
    SuperResText2ImUNet
)
from glide_text2im.tokenizer.bpe import get_encoder
from glide_text2im.xf import convert_module_to_f16

from guided_diffusion.train_util import TrainLoop
from guided_diffusion.gaussian_diffusion import ModelMeanType, ModelVarType, LossType
from guided_diffusion.nn import mean_flat
from guided_diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion import dist_util, logger

from guided_diffusion.respace import space_timesteps, SpacedDiffusion

In [7]:
from data import get_data

In [8]:
options = model_and_diffusion_defaults()

In [9]:
options['image_size'] = 256

In [10]:
options

{'image_size': 256,
 'num_channels': 192,
 'num_res_blocks': 3,
 'channel_mult': '',
 'num_heads': 1,
 'num_head_channels': 64,
 'num_heads_upsample': -1,
 'attention_resolutions': '32,16,8',
 'dropout': 0.1,
 'text_ctx': 128,
 'xf_width': 512,
 'xf_layers': 16,
 'xf_heads': 8,
 'xf_final_ln': True,
 'xf_padding': True,
 'diffusion_steps': 1000,
 'noise_schedule': 'squaredcos_cap_v2',
 'timestep_respacing': '',
 'use_scale_shift_norm': True,
 'resblock_updown': True,
 'use_fp16': True,
 'cache_text_emb': False,
 'inpaint': False,
 'super_res': False}

In [11]:
# Training hyperparameters; most of them are forwarded to TrainLoop further down.
defaults = {
    "data_dir": "",
    "schedule_sampler": "uniform",
    "lr": 1e-4,
    "weight_decay": 0.0,
    # 0 is falsy: per the run_loop source quoted later in this notebook, the
    # loop condition `not self.lr_anneal_steps or ...` then never ends by itself.
    "lr_anneal_steps": 0,
    "batch_size": 1,
    "microbatch": -1,  # -1 disables microbatches
    "ema_rate": "0.9999",  # comma-separated list of EMA values
    "log_interval": 10,
    "save_interval": 10000,
    "resume_checkpoint": "",
    # Mirror the model option so the model and the trainer agree on precision.
    "use_fp16": options['use_fp16'],
    "fp16_scale_growth": 1e-3,
}

In [12]:
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    '''
    Build the diffusion object used for training.

    Adapted from `https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py`.
    The stock `SpacedDiffusion` is returned; swapping in a custom
    `TrainableSpacedDiffusion` was planned but is not done here.

    All arguments are keyword-only:
      steps                  -- number of diffusion steps for the beta schedule.
      learn_sigma            -- if true, use ModelVarType.LEARNED_RANGE.
      sigma_small            -- without learn_sigma, pick FIXED_SMALL instead of FIXED_LARGE.
      noise_schedule         -- name handed to `get_named_beta_schedule`.
      use_kl                 -- select LossType.RESCALED_KL (wins over rescale_learned_sigmas).
      predict_xstart         -- ModelMeanType.START_X instead of EPSILON.
      rescale_timesteps      -- forwarded unchanged to `SpacedDiffusion`.
      rescale_learned_sigmas -- select LossType.RESCALED_MSE when use_kl is false.
      timestep_respacing     -- respacing spec for `space_timesteps`; a falsy
                                value is replaced by `[steps]`.

    Returns the constructed `SpacedDiffusion` instance.
    '''
    betas = get_named_beta_schedule(noise_schedule, steps)

    # Loss: default to plain MSE, then apply the overrides in priority order.
    loss_type = LossType.MSE
    if use_kl:
        loss_type = LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = LossType.RESCALED_MSE

    # What the network output is interpreted as.
    if predict_xstart:
        mean_type = ModelMeanType.START_X
    else:
        mean_type = ModelMeanType.EPSILON

    # Variance parameterisation: learned takes precedence over the fixed choices.
    if learn_sigma:
        var_type = ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = ModelVarType.FIXED_SMALL
    else:
        var_type = ModelVarType.FIXED_LARGE

    # No respacing requested: fall back to a single section of `steps`
    # (presumably this keeps every timestep -- confirm against space_timesteps).
    if not timestep_respacing:
        timestep_respacing = [steps]

    return SpacedDiffusion(  # TrainableSpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        rescale_timesteps=rescale_timesteps,
    )

In [21]:
def create_model_and_diffusion(
    # network architecture
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    # text encoder
    text_ctx,
    xf_width,
    xf_layers,
    xf_heads,
    xf_final_ln,
    xf_padding,
    # diffusion schedule
    diffusion_steps,
    noise_schedule,
    # diffusion flags handled by `create_gaussian_diffusion`
    learn_sigma,
    sigma_small,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    # remaining model options
    timestep_respacing,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    cache_text_emb,
    inpaint,
    super_res,
):
    '''
    Create the text-to-image model together with its diffusion process.

    Modelled on
    https://github.com/openai/glide-text2im/blob/9cc8e563851bd38f5ddb3e305127192cb0f02f5c/glide_text2im/model_creation.py#L54

    Every parameter is required. The architecture, text-encoder and model
    option values are passed to `create_model`; `diffusion_steps`,
    `noise_schedule`, `timestep_respacing` and the six diffusion flags are
    passed to `create_gaussian_diffusion`.

    Returns a `(model, diffusion)` tuple.
    '''
    model_kwargs = dict(
        channel_mult=channel_mult,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        text_ctx=text_ctx,
        xf_width=xf_width,
        xf_layers=xf_layers,
        xf_heads=xf_heads,
        xf_final_ln=xf_final_ln,
        xf_padding=xf_padding,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        cache_text_emb=cache_text_emb,
        inpaint=inpaint,
        super_res=super_res,
    )
    # The first three settings are positional in `create_model`.
    model = create_model(image_size, num_channels, num_res_blocks, **model_kwargs)

    diffusion_kwargs = dict(
        steps=diffusion_steps,  # note the rename: diffusion_steps -> steps
        noise_schedule=noise_schedule,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )
    diffusion = create_gaussian_diffusion(**diffusion_kwargs)

    return model, diffusion

In [22]:
# `options` (from model_and_diffusion_defaults) does not contain the six
# diffusion flags below, so they are supplied explicitly here.
# learn_sigma=True makes create_gaussian_diffusion pick ModelVarType.LEARNED_RANGE.
model, diffusion = create_model_and_diffusion(
    **options, learn_sigma=True, sigma_small=False, use_kl=False, 
    predict_xstart=False, rescale_timesteps=False, rescale_learned_sigmas=False
)

In [18]:
# if options['use_fp16']:
#     model.convert_to_fp16()

In [23]:
# NOTE(review): same value as options['image_size']; presumably the two must
# stay in sync -- confirm before changing either.
img_size = 256

# Both pipelines share resize -> crop -> tensor -> normalise; only the crop
# differs (random for training, centred for validation). A fresh Compose with
# its own transform instances is built for each.
train_transform, valid_transform = (
    transforms.Compose([
        transforms.Resize(img_size),
        crop_cls(img_size),
        transforms.ToTensor(),
        # per-channel (x - 0.5) / 0.5
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    for crop_cls in (transforms.RandomCrop, transforms.CenterCrop)
)

# Command-line style configuration consumed by `get_data`. Nothing is read from
# the real command line: the parser is only used to produce a namespace holding
# the defaults declared here.
parser = argparse.ArgumentParser()

# dataloader args
parser.add_argument('--data-dir', type=str, default="../data")
parser.add_argument('--train-filename', type=str, default="../data/train.txt")
parser.add_argument('--valid-filename', type=str, default="../data/validation.txt")
parser.add_argument('--img-key', type=str, default="images")
parser.add_argument('--caption-key', type=str, default="captions")
parser.add_argument("--csv-separator", type=str, default=" ")
# '--train-batch-size' is the canonical spelling, consistent with
# '--valid-batch-size'. The previous mixed-separator '--train-batch_size' is
# kept as an alias; both store into `args.train_batch_size`.
parser.add_argument(
    '--train-batch-size', '--train-batch_size',
    dest='train_batch_size', type=int, default=defaults['batch_size'],
)
parser.add_argument('--valid-batch-size', type=int, default=defaults['batch_size'])
# transforms.ToTensor() converts images from H x W x C to C x H x W.
# Pass --permute only when the consumer needs H x W x C back (e.g. vqgan_jax);
# leave it off when C x H x W is the required input format.
parser.add_argument('--permute', action='store_true')

# An explicit empty argument list yields a namespace made purely of defaults
# and keeps argparse away from the notebook kernel's own sys.argv.
args = parser.parse_args([])

def tokenize(text):
    '''
    Turn a caption string into the conditioning dict for the model.

    Uses the notebook-level `model` (for its tokenizer) and `options`
    (for `text_ctx`, the padded length); both are looked up at call time.

    Returns a dict with two entries: 'tokens' (the padded token sequence) and
    'mask' (the matching mask produced by `padded_tokens_and_mask`).
    '''
    tokenizer = model.tokenizer
    padded, mask = tokenizer.padded_tokens_and_mask(
        tokenizer.encode(text), options['text_ctx']
    )
    return {'tokens': padded, 'mask': mask}

# Build the project's data object for the 'glide' setup, passing the
# train/valid transforms and the caption tokenizer defined above.
data = get_data(args, (train_transform, valid_transform), 'glide', tokenize=tokenize)
data.setup()
# NOTE(review): `data` is rebound here from the data object to an iterator over
# its training dataloader. TrainLoop pulls batches with `next(self.data)`, so
# re-running this cell without the two lines above would fail.
data = iter(data.train_dataloader())

In [24]:
schedule_sampler = create_named_schedule_sampler("uniform", diffusion)

In [25]:
dist_util.setup_dist()
logger.configure()

model = model.to(dist_util.dev())

Logging to /tmp/openai-2022-03-02-16-32-32-734142


In [26]:
# Entries of `defaults` that TrainLoop receives unchanged. `data_dir` and the
# string-valued `schedule_sampler` entry are not forwarded; the sampler object
# built earlier is passed instead.
_forwarded = (
    'batch_size', 'microbatch', 'lr', 'ema_rate', 'log_interval',
    'save_interval', 'resume_checkpoint', 'use_fp16', 'fp16_scale_growth',
    'weight_decay', 'lr_anneal_steps',
)
train_loop = TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    schedule_sampler=schedule_sampler,
    **{key: defaults[key] for key in _forwarded},
)

In [27]:
train_loop.run_loop()

----------------------------
| grad_norm     | 4.32     |
| lg_loss_scale | 20       |
| loss          | 1.01     |
| loss_q2       | 1.01     |
| mse           | 1        |
| mse_q2        | 1        |
| param_norm    | 5.09e+03 |
| samples       | 1        |
| step          | 0        |
| vb            | 0.00454  |
| vb_q2         | 0.00454  |
----------------------------
saving model 0...
saving model 0.9999...
----------------------------
| grad_norm     | 4        |
| lg_loss_scale | 20       |
| loss          | 0.867    |
| loss_q0       | 0.893    |
| loss_q1       | 0.867    |
| loss_q2       | 0.869    |
| loss_q3       | 0.848    |
| mse           | 0.859    |
| mse_q0        | 0.878    |
| mse_q1        | 0.862    |
| mse_q2        | 0.865    |
| mse_q3        | 0.835    |
| param_norm    | 5.09e+03 |
| samples       | 11       |
| step          | 10       |
| vb            | 0.00717  |
| vb_q0         | 0.0155   |
| vb_q1         | 0.00439  |
| vb_q2         | 0.00439  |
| 

KeyboardInterrupt: 

In [None]:
'''
def run_loop(self):
    while (
        not self.lr_anneal_steps
        or self.step + self.resume_step < self.lr_anneal_steps
    ):
        batch, cond = next(self.data)
        self.run_step(batch, cond)
        if self.step % self.log_interval == 0:
            logger.dumpkvs()
        if self.step % self.save_interval == 0:
            self.save()
            # Run for a finite amount of time in integration tests.
            if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
                return
        self.step += 1
    # Save the last checkpoint if it wasn't already saved.
    if (self.step - 1) % self.save_interval != 0:
        self.save()
        
def run_step(self, batch, cond):
    self.forward_backward(batch, cond)
    took_step = self.mp_trainer.optimize(self.opt)
    if took_step:
        self._update_ema()
    self._anneal_lr()
    self.log_step()

def forward_backward(self, batch, cond):
    self.mp_trainer.zero_grad()
    for i in range(0, batch.shape[0], self.microbatch):
        micro = batch[i : i + self.microbatch].to(dist_util.dev())
        micro_cond = {
            k: v[i : i + self.microbatch].to(dist_util.dev())
            for k, v in cond.items()
        }
        last_batch = (i + self.microbatch) >= batch.shape[0]
        t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())

        compute_losses = functools.partial(
            self.diffusion.training_losses,
            self.ddp_model,
            micro,
            t,
            model_kwargs=micro_cond,
        )

        if last_batch or not self.use_ddp:
            losses = compute_losses()
        else:
            with self.ddp_model.no_sync():
                losses = compute_losses()

        if isinstance(self.schedule_sampler, LossAwareSampler):
            self.schedule_sampler.update_with_local_losses(
                t, losses["loss"].detach()
            )

        loss = (losses["loss"] * weights).mean()
        log_loss_dict(
            self.diffusion, t, {k: v * weights for k, v in losses.items()}
        )
        self.mp_trainer.backward(loss)
'''

In [47]:
batch, cond = next(train_loop.data)

In [48]:
batch = batch.to(dist_util.dev())
cond = {
    k: v.to(dist_util.dev())
    for k, v in cond.items()
}

In [49]:
train_loop.run_step(batch, cond)

In [44]:
train_loop.forward_backward(batch, cond)

In [45]:
took_step = train_loop.mp_trainer.optimize(train_loop.opt)

In [31]:
t, weights = train_loop.schedule_sampler.sample(batch.shape[0], dist_util.dev())

In [32]:
compute_losses = functools.partial(
    train_loop.diffusion.training_losses,
    train_loop.ddp_model,
    batch,
    t,
    model_kwargs=cond,
)

In [33]:
losses = compute_losses()

In [34]:
losses

{'vb': tensor([0.0667], device='cuda:0', grad_fn=<SWhereBackward>),
 'mse': tensor([0.9988], device='cuda:0', grad_fn=<MeanBackward1>),
 'loss': tensor([1.0654], device='cuda:0', grad_fn=<AddBackward0>)}

In [35]:
loss = (losses["loss"] * weights).mean()

In [37]:
train_loop.mp_trainer.backward(loss)

In [39]:
took_step = train_loop.mp_trainer.optimize(train_loop.opt)

In [40]:
took_step

True

In [276]:
noise = torch.randn_like(batch)
x_t = train_loop.diffusion.q_sample(batch, t, noise=noise)

In [277]:
x_t.shape

torch.Size([1, 3, 256, 256])

In [279]:
scaled_t = train_loop.diffusion._scale_timesteps(t)

In [280]:
print(scaled_t, t)

tensor([164], device='cuda:0') tensor([164], device='cuda:0')


In [None]:
def forward(self, x, timesteps, tokens=None, mask=None):
    '''
    Stand-alone copy of a UNet-style forward pass, kept here as a reference for
    the step-by-step debugging cells (it appears to mirror the model's own
    `forward`). Call it as `forward(model, x, timesteps, tokens=..., mask=...)`.

    self      -- object providing `time_embed`, `model_channels`, `xf_width`,
                 `get_text_emb`, `dtype`, `input_blocks`, `middle_block`,
                 `output_blocks` and `out`.
    x         -- input tensor; the result is cast back to `x.dtype` before `out`.
    timesteps -- timestep tensor passed to `timestep_embedding`.
    tokens, mask -- forwarded to `self.get_text_emb`; only used when
                 `self.xf_width` is truthy.

    Returns the output of `self.out` applied to the final activations.
    '''
    hs = []  # activations of the input blocks, consumed in reverse order below
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
    if self.xf_width:
        text_outputs = self.get_text_emb(tokens, mask)
        xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
        # `.to(emb)` aligns the text projection's dtype/device with the time embedding.
        emb = emb + xf_proj.to(emb)
    else:
        xf_out = None
    h = x.type(self.dtype)
    for module in self.input_blocks:
        h = module(h, emb, xf_out)
        hs.append(h)
    h = self.middle_block(h, emb, xf_out)
    for module in self.output_blocks:
        # Skip connection: concatenate the most recent saved activation along dim 1.
        # (The pasted source used the alias `th`, which this notebook never
        # defines; the library is imported here as `torch`.)
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, xf_out)
    h = h.type(x.dtype)
    h = self.out(h)
    return h

In [281]:
emb = train_loop.model.time_embed(timestep_embedding(scaled_t, train_loop.model.model_channels))
emb.shape

torch.Size([1, 768])

In [283]:
xf_in = train_loop.model.token_embedding(cond['tokens'].long())

In [284]:
train_loop.model.padding_embedding[None].dtype

torch.float16

In [135]:
xf_in.dtype

torch.float16

In [286]:
xf_in = torch.where(cond['mask'][..., None], xf_in, train_loop.model.padding_embedding[None].to(torch.float16))

In [288]:
xf_out = train_loop.model.transformer(xf_in.to(train_loop.model.dtype))

In [290]:
xf_out = train_loop.model.final_ln(xf_out)

In [292]:
xf_proj = train_loop.model.transformer_proj(xf_out[:, -1])

In [295]:
xf_out = xf_out.permute(0, 2, 1)

In [296]:
text_outputs = dict(xf_proj=xf_proj, xf_out=xf_out)

In [297]:
xf_proj, xf_out = text_outputs["xf_proj"], text_outputs["xf_out"]
emb = emb + xf_proj.to(emb)

In [304]:
emb

tensor([[-6.8044e-01,  1.0896e+00, -2.1173e-01, -6.7466e-01,  1.2109e-01,
         -1.0182e+00,  3.3557e-01,  1.2053e-01,  4.7135e-01,  4.6891e-01,
         -2.3416e-01, -2.4081e-01,  2.1960e-01,  7.8206e-01,  5.8388e-01,
         -5.7195e-01, -1.7002e-01,  1.6693e-01, -1.9188e-01,  2.0491e-01,
         -2.0205e-01, -1.2121e-01, -4.7672e-01,  7.2354e-01, -1.1115e-01,
         -1.8986e-01, -3.9922e-01, -1.2227e+00, -3.5668e-01, -5.0049e-01,
         -4.0665e-01,  3.5411e-01,  5.6386e-01, -7.2206e-01, -3.0089e-01,
          1.1960e-01,  8.4087e-01, -9.3811e-01,  2.3384e-01,  8.7181e-01,
          2.2939e-01,  4.2666e-01,  1.4032e-01,  1.0607e+00, -5.5759e-01,
         -6.1131e-01, -1.3434e-01, -8.0868e-01,  8.8383e-02, -9.3230e-02,
          4.0284e-01,  3.6635e-01,  4.4648e-01,  5.7132e-02,  6.7303e-01,
          3.2172e-01, -3.0984e-01, -8.0393e-02, -1.8968e-01, -8.5095e-03,
         -7.9837e-01, -7.4838e-02, -1.1586e+00,  1.7473e-01,  2.9347e-02,
          2.0359e-01,  3.8169e-01, -5.

In [338]:
h = batch.type(train_loop.model.dtype)

In [339]:
hs = []

In [340]:
for module in train_loop.model.input_blocks:
    h = module(h, emb, xf_out)
    hs.append(h)

In [341]:
h = train_loop.model.middle_block(h, emb, xf_out)

In [342]:
for module in train_loop.model.output_blocks:
    h = torch.cat([h, hs.pop()], dim=1)
    h = module(h, emb, xf_out)

In [343]:
h.dtype

torch.float16

In [344]:
h = h.type(batch.dtype)

In [345]:
h.dtype

torch.float32

In [346]:
h

tensor([[[[ 0.0950,  0.1895,  0.1876,  ..., -0.2178, -0.1205, -0.0816],
          [ 0.1599,  0.2959,  0.2949,  ..., -0.1941, -0.0786, -0.0690],
          [ 0.1593,  0.3059,  0.2979,  ..., -0.0363,  0.0039,  0.0089],
          ...,
          [ 0.1442,  0.2322,  0.2067,  ...,  0.0978,  0.0929,  0.0916],
          [ 0.1343,  0.2094,  0.2346,  ...,  0.0778,  0.1064,  0.0715],
          [ 0.0995,  0.1720,  0.1879,  ...,  0.0553,  0.0612,  0.0157]],

         [[ 0.1814,  0.3132,  0.3059,  ..., -0.2178, -0.1222, -0.0175],
          [ 0.2476,  0.4126,  0.4036,  ..., -0.2080, -0.1019, -0.0074],
          [ 0.2529,  0.4202,  0.4089,  ..., -0.0117,  0.0224,  0.0509],
          ...,
          [ 0.2236,  0.3542,  0.3059,  ...,  0.1700,  0.1682,  0.0909],
          [ 0.2048,  0.3564,  0.3418,  ...,  0.1592,  0.1703,  0.0845],
          [ 0.1147,  0.2013,  0.2026,  ...,  0.0837,  0.0842,  0.0389]],

         [[ 0.0740,  0.1656,  0.1617,  ..., -0.2262, -0.1884, -0.1265],
          [ 0.0842,  0.2595,  

In [347]:
h * torch.sigmoid(h)

tensor([[[[ 0.0497,  0.1037,  0.1026,  ..., -0.0971, -0.0566, -0.0391],
          [ 0.0863,  0.1697,  0.1690,  ..., -0.0877, -0.0377, -0.0333],
          [ 0.0860,  0.1762,  0.1709,  ..., -0.0178,  0.0019,  0.0044],
          ...,
          [ 0.0773,  0.1295,  0.1140,  ...,  0.0513,  0.0486,  0.0479],
          [ 0.0716,  0.1156,  0.1310,  ...,  0.0404,  0.0561,  0.0370],
          [ 0.0522,  0.0934,  0.1027,  ...,  0.0284,  0.0316,  0.0079]],

         [[ 0.0989,  0.1809,  0.1762,  ..., -0.0971, -0.0574, -0.0087],
          [ 0.1390,  0.2483,  0.2420,  ..., -0.0932, -0.0483, -0.0037],
          [ 0.1424,  0.2536,  0.2457,  ..., -0.0058,  0.0113,  0.0261],
          ...,
          [ 0.1243,  0.2082,  0.1762,  ...,  0.0922,  0.0912,  0.0475],
          [ 0.1129,  0.2097,  0.1998,  ...,  0.0859,  0.0924,  0.0440],
          [ 0.0607,  0.1107,  0.1115,  ...,  0.0436,  0.0439,  0.0198]],

         [[ 0.0384,  0.0897,  0.0874,  ..., -0.1004, -0.0853, -0.0592],
          [ 0.0439,  0.1465,  

In [353]:
zero_module(torch.nn.Conv2d(3, 4, 3)).weight

Parameter containing:
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],


        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], requires_grad=True)

In [336]:
h = train_loop.model.out(h)

In [337]:
h

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          

In [302]:
model_output = train_loop.model(x_t, scaled_t, **cond)

In [303]:
model_output

tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, 

In [299]:
model_output = train_loop.ddp_model(x_t, scaled_t, **cond)

In [300]:
model_output

tensor([[[[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          ...,
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, nan, nan]],

         [[nan, nan, nan,  ..., nan, nan, nan],
          [nan, nan, nan,  ..., nan, 

In [90]:
B, C = x_t.shape[:2]
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
model_output, model_var_values = torch.split(model_output, C, dim=1)

In [93]:
terms = {}

In [94]:
frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
terms["vb"] = train_loop.diffusion._vb_terms_bpd(
    model=lambda *args, r=frozen_out: r,
    x_start=batch,
    x_t=x_t,
    t=t,
    clip_denoised=False,
)["output"]

In [96]:
terms['vb']

tensor([nan], device='cuda:0', grad_fn=<SWhereBackward>)

In [86]:
batch.shape

torch.Size([1, 3, 256, 256])

In [91]:
model_output.shape

torch.Size([1, 3, 256, 256])

In [88]:
noise.shape

torch.Size([1, 3, 256, 256])

In [97]:
train_loop.diffusion.loss_type

<LossType.MSE: 1>

In [98]:
target = noise

In [99]:
terms["mse"] = mean_flat((target - model_output) ** 2)

In [100]:
terms["mse"]

tensor([nan], device='cuda:0', grad_fn=<MeanBackward1>)

In [102]:
terms["loss"] = terms["mse"] + terms["vb"]

In [49]:
losses = compute_losses()

AssertionError: 

In [None]:
train_loop.run_loop()