Skip to content

Commit

Permalink
Add non-leaky augmentations (Config F), mapping conditioning, and U-N…
Browse files Browse the repository at this point in the history
…et input conditioning
  • Loading branch information
crowsonkb committed Jun 24, 2022
1 parent fbf8067 commit 0e7f4f3
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch.

**This repo is a work in progress** (models may break on later versions, script options may change). Also, Config F is not currently implemented; this repo currently implements Config E.
**This repo is a work in progress** (models may break on later versions, script options may change).

Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running:

Expand Down
2 changes: 1 addition & 1 deletion k_diffusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import evaluation, gns, layers, models, sampling, utils
from . import augmentation, evaluation, gns, layers, models, sampling, utils
from .layers import Denoiser
95 changes: 95 additions & 0 deletions k_diffusion/augmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from functools import reduce
import math
import operator

import numpy as np
from skimage import transform
import torch
from torch import nn


def translate2d(tx, ty):
    """Return the 3x3 homogeneous matrix that translates by (tx, ty).

    The offsets are placed in the last column of an identity matrix and the
    result is returned as a float32 tensor.
    """
    row_x = [1, 0, tx]
    row_y = [0, 1, ty]
    row_w = [0, 0, 1]
    return torch.tensor([row_x, row_y, row_w], dtype=torch.float32)


def scale2d(sx, sy):
    """Return the 3x3 homogeneous matrix that scales x by sx and y by sy.

    The factors are placed on the diagonal and the result is returned as a
    float32 tensor. A factor of -1 on one axis gives a flip of that axis.
    """
    row_x = [sx, 0, 0]
    row_y = [0, sy, 0]
    row_w = [0, 0, 1]
    return torch.tensor([row_x, row_y, row_w], dtype=torch.float32)


def rotate2d(theta):
    """Return the 3x3 homogeneous matrix for a rotation by theta radians.

    Args:
        theta: the rotation angle in radians, either a scalar tensor or a
            plain Python number. Plain numbers are converted to a tensor
            first, so this builder accepts the same kinds of argument as
            translate2d and scale2d.

    Returns:
        A float32 tensor of shape (3, 3) laid out as
        [[cos, -sin, 0], [sin, cos, 0], [0, 0, 1]].
    """
    # torch.cos / torch.sin only accept tensors; as_tensor leaves an existing
    # tensor untouched and wraps a Python scalar.
    theta = torch.as_tensor(theta)
    mat = [[torch.cos(theta), torch.sin(-theta), 0],
           [torch.sin(theta), torch.cos(theta), 0],
           [0, 0, 1]]
    return torch.tensor(mat, dtype=torch.float32)


class KarrasAugmentationPipeline:
    """Random geometric augmentation that also reports what it applied.

    Each call samples an x-flip, y-flip, isotropic scale, rotation,
    anisotropic scale and translation, composes them into a single affine
    matrix, warps the image with it, and returns the warped image together
    with a 9-element conditioning vector describing the sampled parameters.

    Every transform except the x-flip is enabled independently with
    probability ``a_prob``; a disabled transform contributes an identity
    matrix and zeros to the conditioning vector.

    Args:
        a_prob: probability of enabling each optional transform.
        a_scale: base of the log-normal isotropic scale factor.
        a_aniso: base of the log-normal anisotropic scale factor.
        a_trans: translation standard deviation as a fraction of image size.
    """

    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
        self.a_prob = a_prob
        self.a_scale = a_scale
        self.a_aniso = a_aniso
        self.a_trans = a_trans

    def __call__(self, image):
        """Augment one image.

        Args:
            image: an image exposing a PIL-style ``size`` attribute of
                ``(width, height)`` and convertible with ``np.array``, with
                8-bit pixel values (they are divided by 255).

        Returns:
            A tuple ``(image, cond)``: ``image`` is a channels-first float
            tensor rescaled by ``x * 2 - 1`` and ``cond`` is a tensor of 9
            augmentation parameters.
        """
        # PIL reports size as (width, height); w scales the x translation and
        # h the y translation below.
        w, h = image.size
        mats = []

        # x-flip (always sampled, 50% chance)
        a0 = torch.randint(2, []).float()
        mats.append(scale2d(1 - 2 * a0, 1))
        # y-flip
        do = (torch.rand([]) < self.a_prob).float()
        a1 = torch.randint(2, []).float() * do
        mats.append(scale2d(1, 1 - 2 * a1))
        # scaling
        do = (torch.rand([]) < self.a_prob).float()
        a2 = torch.randn([]) * do
        mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
        # rotation, angle uniform in [-pi, pi)
        do = (torch.rand([]) < self.a_prob).float()
        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        mats.append(rotate2d(-a3))
        # anisotropy: rotate, stretch one axis while shrinking the other,
        # rotate back; direction angle uniform in [-pi, pi)
        do = (torch.rand([]) < self.a_prob).float()
        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
        a5 = torch.randn([]) * do
        mats.append(rotate2d(a4))
        mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
        mats.append(rotate2d(-a4))
        # translation
        do = (torch.rand([]) < self.a_prob).float()
        a6 = torch.randn([]) * do
        a7 = torch.randn([]) * do
        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))

        # form the transformation matrix and conditioning vector
        # NOTE(review): the matrices are composed about the coordinate origin,
        # not the image center -- confirm this is the intended pivot.
        mat = reduce(operator.matmul, mats)
        cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])

        # apply the transformation
        image = np.array(image, dtype=np.float32) / 255
        if image.ndim == 2:
            # single-channel input: add a channel axis so movedim(2, 0) works
            image = image[..., None]
        tf = transform.AffineTransform(mat.numpy())
        image = transform.warp(image, tf, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
        return image, cond


class KarrasAugmentWrapper(nn.Module):
    """Wrap a model so augmentation conditioning reaches it as mapping_cond.

    The wrapped model is stored as ``inner_model`` and is called with the
    augmentation conditioning placed ahead of any caller-supplied mapping
    conditioning.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
        """Call the inner model with the combined mapping conditioning.

        When ``aug_cond`` is omitted, a zero tensor of shape
        ``[batch, 9]`` is used in its place. When ``mapping_cond`` is given
        it is concatenated after the augmentation conditioning along dim 1.
        Extra keyword arguments are forwarded unchanged.
        """
        if aug_cond is None:
            batch_size = input.shape[0]
            aug_cond = input.new_zeros([batch_size, 9])
        combined = aug_cond
        if mapping_cond is not None:
            combined = torch.cat([aug_cond, mapping_cond], dim=1)
        return self.inner_model(input, sigma, mapping_cond=combined, **kwargs)
17 changes: 11 additions & 6 deletions k_diffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,13 @@ def __init__(self, feats_in, feats_out, n_layers=2):


class ImageDenoiserModel(nn.Module):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, dropout_rate=0.):
def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, mapping_cond_dim=0, unet_cond_dim=0, dropout_rate=0.):
super().__init__()
self.timestep_embed = layers.FourierFeatures(1, 256)
self.mapping = MappingNet(256, feats_in)
self.proj_in = nn.Conv2d(c_in, channels[0], 1)
self.timestep_embed = layers.FourierFeatures(1, feats_in)
if mapping_cond_dim > 0:
self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
self.mapping = MappingNet(feats_in, feats_in)
self.proj_in = nn.Conv2d(c_in + unet_cond_dim, channels[0], 1)
self.proj_out = nn.Conv2d(channels[0], c_in, 1)
nn.init.zeros_(self.proj_out.weight)
nn.init.zeros_(self.proj_out.bias)
Expand All @@ -86,11 +88,14 @@ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, dropout_r
u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > 0, self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
self.u_net = layers.UNet(d_blocks, reversed(u_blocks))

def forward(self, input, sigma):
def forward(self, input, sigma, mapping_cond=None, unet_cond=None):
c_noise = sigma.log() / 4
timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
mapping_out = self.mapping(timestep_embed)
mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
cond = {'cond': mapping_out}
if unet_cond is not None:
input = torch.cat([input, unet_cond], dim=1)
input = self.proj_in(input)
input = self.u_net(input, cond)
input = self.proj_out(input)
Expand Down
2 changes: 2 additions & 0 deletions model_configs/model_config_32x32_small.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"depths": [2, 4, 4],
"channels": [128, 256, 512],
"self_attn_depths": [false, true, true],
"dropout_rate": 0.0,
"augment_prob": 0.12,
"sigma_data": 0.5,
"sigma_min": 1e-2,
"sigma_max": 80,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ accelerate
einops
Pillow
resize-right
scikit-image
scipy
torch
torchvision
Expand Down
33 changes: 24 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def main():
help='the checkpoint to resume from')
p.add_argument('--save-every', type=int, default=10000,
help='save every this many steps')
p.add_argument('--start-method', type=str, default='spawn',
choices=['fork', 'forkserver', 'spawn'],
help='the multiprocessing start method')
p.add_argument('--train-set', type=str, required=True,
help='the training set location')
p.add_argument('--wandb-entity', type=str,
Expand All @@ -53,9 +56,10 @@ def main():
help='the wandb project name (specify this to enable wandb)')
p.add_argument('--wandb-save-model', action='store_true',
help='save model to wandb')

args = p.parse_args()

mp.set_start_method(args.start_method)

model_config = json.load(open(args.model_config))
# TODO: allow non-square input sizes
assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
Expand All @@ -71,8 +75,11 @@ def main():
model_config['mapping_out'],
model_config['depths'],
model_config['channels'],
model_config['self_attn_depths']
model_config['self_attn_depths'],
dropout_rate=model_config['dropout_rate'],
mapping_cond_dim=9,
)
inner_model = K.augmentation.KarrasAugmentWrapper(inner_model)
accelerator.print('Parameters:', K.utils.n_params(inner_model))

# If logging to wandb, initialize the run
Expand All @@ -88,17 +95,26 @@ def main():
sched = K.utils.InverseLR(opt, inv_gamma=50000, power=1/2, warmup=0.99)
ema_sched = K.utils.EMAWarmup(power=2/3, max_value=0.9999)

tf = transforms.Compose([
tf_no_aug = transforms.Compose([
transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
transforms.CenterCrop(size[0]),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
train_set_no_aug = datasets.ImageFolder(args.train_set, transform=tf_no_aug)
train_dl_no_aug = data.DataLoader(train_set_no_aug, args.batch_size, shuffle=True,
num_workers=args.num_workers)

tf = transforms.Compose([
transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
transforms.CenterCrop(size[0]),
K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']),
])
train_set = datasets.ImageFolder(args.train_set, transform=tf)
train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True,
num_workers=args.num_workers, persistent_workers=True)

inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
inner_model, opt, train_dl, train_dl_no_aug = accelerator.prepare(inner_model, opt, train_dl, train_dl_no_aug)
if use_wandb:
wandb.watch(inner_model)
if args.gns:
Expand Down Expand Up @@ -126,9 +142,9 @@ def main():
step = 0

extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
train_iter = iter(train_dl)
train_iter_no_aug = iter(train_dl_no_aug)
accelerator.print('Computing features for reals...')
reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[0], extractor, args.evaluate_n, args.batch_size)
reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter_no_aug)[0], extractor, args.evaluate_n, args.batch_size)
if accelerator.is_main_process:
metrics_log_filepath = Path(f'{args.name}_metrics.csv')
if metrics_log_filepath.exists():
Expand Down Expand Up @@ -202,10 +218,10 @@ def save():
while True:
for batch in tqdm(train_dl, disable=not accelerator.is_local_main_process):
opt.zero_grad()
reals = batch[0].to(device)
reals, aug_cond = batch[0]
noise = torch.randn_like(reals)
sigma = torch.distributions.LogNormal(sigma_mean, sigma_std).sample([reals.shape[0]]).to(device)
loss = model.loss(reals, noise, sigma).mean()
loss = model.loss(reals, noise, sigma, aug_cond=aug_cond).mean()
accelerator.backward(loss)
if args.gns:
sq_norm_small_batch, sq_norm_large_batch = accelerator.reduce(gns_stats_hook.get_stats(), 'mean').tolist()
Expand Down Expand Up @@ -250,5 +266,4 @@ def save():


if __name__ == '__main__':
mp.set_start_method('spawn')
main()

0 comments on commit 0e7f4f3

Please sign in to comment.