In [1]:
import os
import sys

In [2]:
paths = ['../src/', '../guided-diffusion/']
# Make the project sources and the local guided-diffusion checkout importable.
# Each directory is prepended (so it shadows other copies on sys.path) at most
# once, which keeps this cell safe to re-run; newly added entries are echoed.
for rel_dir in paths:
    abs_dir = os.path.abspath(rel_dir)
    if abs_dir in sys.path:
        continue
    sys.path.insert(0, abs_dir)
    print(abs_dir)

/nfs/data/andrewbai/Natural-Disaster-Image-Generation-to-raise-Environmental-Awareness/src
/nfs/data/andrewbai/Natural-Disaster-Image-Generation-to-raise-Environmental-Awareness/guided-diffusion


In [3]:
# Order GPUs by PCI bus id and restrict this process to devices 0-3.
# NOTE(review): presumably must run before CUDA is first initialised to take effect.
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",
    CUDA_VISIBLE_DEVICES="0,1,2,3",
)

In [4]:
import argparse
from importlib import reload
import numpy as np
import functools
import datetime

import torch
import torchvision
import torchvision.transforms as transforms

In [5]:
from glide_text2im.model_creation import model_and_diffusion_defaults, create_model
from glide_text2im.gaussian_diffusion import get_named_beta_schedule

from guided_diffusion import gaussian_diffusion as gd
from guided_diffusion.train_util import TrainLoop
from guided_diffusion.respace import space_timesteps, SpacedDiffusion
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion import dist_util, logger

In [6]:
from data import get_data

In [7]:
# Start from glide-text2im's default model/diffusion hyperparameters; later
# cells override individual entries and merge training settings into this dict.
options = model_and_diffusion_defaults()

In [8]:
# Train at 256x256. The same value sets the Resize/crop size of the data
# transforms built further down.
options.update(image_size=256)

In [9]:
# Training-loop and diffusion-loss hyperparameters. They are merged into
# `options`: create_model_and_diffusion picks the diffusion flags by name (and
# ignores the rest via **kwargs), while the TrainLoop cell reads the others.
train_options = {
    # Timestep sampler name.
    "schedule_sampler": "uniform",
    # Optimizer.
    "lr": 1e-4,
    "weight_decay": 0.0,
    "lr_anneal_steps": 0,
    # Batching; -1 disables microbatches.
    "batch_size": 4,
    "microbatch": -1,
    # Comma-separated list of EMA values.
    "ema_rate": "0.9999",
    # Logging / checkpointing cadence.
    "log_interval": 10,
    "save_interval": 10000,
    "resume_checkpoint": "",
    # Diffusion parameterisation, forwarded to create_gaussian_diffusion.
    "learn_sigma": True,
    "sigma_small": False,
    "use_kl": False,
    "predict_xstart": False,
    "rescale_timesteps": False,
    "rescale_learned_sigmas": False,
    # fp16 loss-scale growth (cf. `lg_loss_scale` in the training log).
    "fp16_scale_growth": 1e-3,
}

In [10]:
# Merge the training hyperparameters into the model/diffusion options so one
# dict can be splatted into create_model_and_diffusion and read by TrainLoop.
# On key collisions the values from train_options win.
options.update(train_options)

In [11]:
def create_gaussian_diffusion(
    *,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",
):
    '''
    Build a guided-diffusion ``SpacedDiffusion`` using a glide-text2im beta schedule.

    Adapted from
    `https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/script_util.py`.
    The betas come from glide-text2im's ``get_named_beta_schedule`` because some
    beta schedules available in glide-text2im are not supported by guided-diffusion.

    :param steps: number of timesteps in the base beta schedule.
    :param learn_sigma: select ``ModelVarType.LEARNED_RANGE``; overrides ``sigma_small``.
    :param sigma_small: with a fixed variance, select FIXED_SMALL instead of FIXED_LARGE.
    :param noise_schedule: schedule name passed to ``get_named_beta_schedule``.
    :param use_kl: use ``LossType.RESCALED_KL``; takes precedence over
        ``rescale_learned_sigmas``.
    :param predict_xstart: select ``ModelMeanType.START_X`` instead of ``EPSILON``.
    :param rescale_timesteps: forwarded unchanged to ``SpacedDiffusion``.
    :param rescale_learned_sigmas: use ``LossType.RESCALED_MSE`` instead of ``MSE``.
    :param timestep_respacing: spec for ``space_timesteps``; falsy means ``[steps]``.
    :return: the configured ``SpacedDiffusion``.
    '''
    beta_schedule = get_named_beta_schedule(noise_schedule, steps)

    # Training objective: KL > rescaled MSE > plain MSE.
    if use_kl:
        loss_kind = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_kind = gd.LossType.RESCALED_MSE
    else:
        loss_kind = gd.LossType.MSE

    # An empty respacing spec keeps all `steps` timesteps.
    respacing = timestep_respacing or [steps]

    mean_kind = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

    # A learned variance wins over either fixed variant.
    if learn_sigma:
        var_kind = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_kind = gd.ModelVarType.FIXED_SMALL
    else:
        var_kind = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, respacing),
        betas=beta_schedule,
        model_mean_type=mean_kind,
        model_var_type=var_kind,
        loss_type=loss_kind,
        rescale_timesteps=rescale_timesteps,
    )

In [12]:
def create_model_and_diffusion(
    image_size,
    num_channels,
    num_res_blocks,
    channel_mult,
    num_heads,
    num_head_channels,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    text_ctx,
    xf_width,
    xf_layers,
    xf_heads,
    xf_final_ln,
    xf_padding,
    diffusion_steps,
    noise_schedule,

    learn_sigma,
    sigma_small,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,

    timestep_respacing,
    use_scale_shift_norm,
    resblock_updown,
    use_fp16,
    cache_text_emb,
    inpaint,
    super_res,
    **kwargs
):
    '''
    Build the GLIDE text-conditional model together with its diffusion process.

    Based on glide-text2im's ``create_model_and_diffusion``
    (https://github.com/openai/glide-text2im/blob/9cc8e563851bd38f5ddb3e305127192cb0f02f5c/glide_text2im/model_creation.py#L54).
    The difference is that the diffusion comes from the local
    ``create_gaussian_diffusion`` (a guided-diffusion ``SpacedDiffusion``). The
    variance/loss flags (``learn_sigma`` ... ``rescale_learned_sigmas``) are
    accepted explicitly and forwarded to it.

    Any other keyword arguments (e.g. training-loop settings kept in the same
    options dict) are accepted through ``**kwargs`` and ignored.

    :return: ``(model, diffusion)``.
    '''
    # Keyword arguments for glide's create_model; image_size, num_channels and
    # num_res_blocks are passed positionally.
    unet_kwargs = dict(
        channel_mult=channel_mult,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_head_channels=num_head_channels,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,
        text_ctx=text_ctx,
        xf_width=xf_width,
        xf_layers=xf_layers,
        xf_heads=xf_heads,
        xf_final_ln=xf_final_ln,
        xf_padding=xf_padding,
        resblock_updown=resblock_updown,
        use_fp16=use_fp16,
        cache_text_emb=cache_text_emb,
        inpaint=inpaint,
        super_res=super_res,
    )
    diffusion_kwargs = dict(
        steps=diffusion_steps,
        noise_schedule=noise_schedule,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,
    )

    net = create_model(image_size, num_channels, num_res_blocks, **unet_kwargs)
    process = create_gaussian_diffusion(**diffusion_kwargs)
    return net, process

In [13]:
# Instantiate the model and diffusion from the merged options. Training-only
# keys (lr, batch_size, ema_rate, ...) land in create_model_and_diffusion's
# **kwargs and are not used there.
model, diffusion = create_model_and_diffusion(**options)

In [14]:
# TODO: check preprocessing correct or not
def _glide_image_transform(crop_cls):
    '''
    Resize to ``options['image_size']``, crop to a square of that size with
    ``crop_cls``, convert to a tensor, then normalize with mean/std 0.5 per
    channel, i.e. (x - 0.5) / 0.5, mapping [0, 1] to [-1, 1].
    '''
    side = options['image_size']
    steps = [
        transforms.Resize(side),
        crop_cls(side),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return transforms.Compose(steps)


# Random crops for training, deterministic center crops for validation.
train_transform = _glide_image_transform(transforms.RandomCrop)
valid_transform = _glide_image_transform(transforms.CenterCrop)

# Arguments consumed by get_data(). The notebook always parses an empty argv,
# so the defaults below are what is actually used.
parser = argparse.ArgumentParser()

# dataloader args
parser.add_argument('--data-dir', type=str, default="../data")
parser.add_argument('--train-filename', type=str, default="../data/train.txt")
parser.add_argument('--valid-filename', type=str, default="../data/validation.txt")
parser.add_argument('--img-key', type=str, default="images")
parser.add_argument('--caption-key', type=str, default="captions")
parser.add_argument("--csv-separator", type=str, default=" ")
# Spelled with dashes like every other flag; the old '--train-batch_size'
# spelling is kept as an alias. Both map to args.train_batch_size (argparse
# derives dest from the first long option and turns '-' into '_').
parser.add_argument('--train-batch-size', '--train-batch_size', type=int, default=options['batch_size'])
parser.add_argument('--valid-batch-size', type=int, default=options['batch_size'])
parser.add_argument('--permute', action='store_true')

# Parse an explicit empty argument list so argparse never falls back to
# sys.argv (which, inside a notebook kernel, holds the kernel's own arguments).
args = parser.parse_args([])

def tokenize(text):
    '''
    Convert a caption into GLIDE text-conditioning inputs.

    The caption is encoded with ``model.tokenizer`` and then passed through
    ``padded_tokens_and_mask`` with length ``options['text_ctx']``. Presumably
    this pads (and truncates) to that length; confirm against glide-text2im.

    :param text: caption string.
    :return: ``{'tokens': ..., 'mask': ...}`` exactly as returned by
        ``padded_tokens_and_mask``.
    '''
    tokenizer = model.tokenizer
    padded_ids, padding_mask = tokenizer.padded_tokens_and_mask(
        tokenizer.encode(text), options['text_ctx']
    )
    return dict(tokens=padded_ids, mask=padding_mask)

def _endless_batches(loader):
    '''
    Yield batches from ``loader`` forever, starting a new pass after each epoch.

    guided-diffusion's ``TrainLoop.run_loop`` pulls every batch with
    ``next(data)`` for the whole run. A plain ``iter(loader)`` would raise
    ``StopIteration`` after the first epoch and abort training.

    Each pass re-iterates the loader, so shuffling samplers reshuffle. The
    batches are not cached the way ``itertools.cycle`` would cache them.

    :raises ValueError: if a full pass over ``loader`` yields nothing; without
        this check the generator would spin forever.
    '''
    while True:
        yielded_any = False
        for batch in loader:
            yielded_any = True
            yield batch
        if not yielded_any:
            raise ValueError("training dataloader produced no batches")


data_module = get_data(args, (train_transform, valid_transform), 'glide', tokenize=tokenize)
data_module.setup()
# TrainLoop consumes `data` as an infinite stream of training batches.
data = _endless_batches(data_module.train_dataloader())

In [15]:
# Use the sampler named in the options rather than a hard-coded duplicate, so
# editing train_options['schedule_sampler'] actually takes effect.
schedule_sampler = create_named_schedule_sampler(options['schedule_sampler'], diffusion)

In [16]:
# Initialise guided-diffusion's distributed helpers and logger, then move the
# model to the device chosen by dist_util. Each run logs to a fresh timestamped
# directory under $GLIDE_LOG_DIR (default: ~/glide_logs) rather than a
# hard-coded, user-specific absolute path.
dist_util.setup_dist()
timestamp = datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")
log_root = os.environ.get(
    "GLIDE_LOG_DIR", os.path.join(os.path.expanduser("~"), "glide_logs")
)
logger.configure(dir=os.path.join(log_root, timestamp))

model = model.to(dist_util.dev())

Logging to /home/andrewbai/glide_logs/openai-2022-03-02-16-53-45-029863


In [17]:
# Hyperparameters that TrainLoop takes verbatim from the merged options dict.
_TRAIN_LOOP_OPTION_KEYS = (
    'batch_size',
    'microbatch',
    'lr',
    'ema_rate',
    'log_interval',
    'save_interval',
    'resume_checkpoint',
    'use_fp16',
    'fp16_scale_growth',
    'weight_decay',
    'lr_anneal_steps',
)

train_loop = TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    schedule_sampler=schedule_sampler,
    **{key: options[key] for key in _TRAIN_LOOP_OPTION_KEYS},
)

In [18]:
# Start training. With lr_anneal_steps=0 there appears to be no step limit, so
# this presumably runs until interrupted; the recorded run below was stopped
# manually (KeyboardInterrupt) after step 130.
train_loop.run_loop()

----------------------------
| grad_norm     | 3.55     |
| lg_loss_scale | 20       |
| loss          | 1.01     |
| loss_q0       | 1.01     |
| loss_q1       | 1        |
| loss_q3       | 1.01     |
| mse           | 0.998    |
| mse_q0        | 0.999    |
| mse_q1        | 0.997    |
| mse_q3        | 0.998    |
| param_norm    | 5.09e+03 |
| samples       | 4        |
| step          | 0        |
| vb            | 0.00942  |
| vb_q0         | 0.0138   |
| vb_q1         | 0.00452  |
| vb_q3         | 0.0097   |
----------------------------
saving model 0...
saving model 0.9999...
----------------------------
| grad_norm     | 3.54     |
| lg_loss_scale | 20       |
| loss          | 0.886    |
| loss_q0       | 0.952    |
| loss_q1       | 0.833    |
| loss_q2       | 0.894    |
| loss_q3       | 0.873    |
| mse           | 0.871    |
| mse_q0        | 0.923    |
| mse_q1        | 0.828    |
| mse_q2        | 0.889    |
| mse_q3        | 0.858    |
| param_norm    | 5.09e+03 |
| 

----------------------------
| grad_norm     | 0.494    |
| lg_loss_scale | 19.1     |
| loss          | 0.0752   |
| loss_q0       | 0.193    |
| loss_q1       | 0.0416   |
| loss_q2       | 0.0299   |
| loss_q3       | 0.0219   |
| mse           | 0.074    |
| mse_q0        | 0.19     |
| mse_q1        | 0.0414   |
| mse_q2        | 0.0298   |
| mse_q3        | 0.0208   |
| param_norm    | 5.09e+03 |
| samples       | 524      |
| step          | 130      |
| vb            | 0.00118  |
| vb_q0         | 0.00298  |
| vb_q1         | 0.000198 |
| vb_q2         | 0.000155 |
| vb_q3         | 0.00111  |
----------------------------


KeyboardInterrupt: 