In [None]:
# https://github.com/phizaz/diffae/blob/master/model/nn.py#L123

# https://github.com/phizaz/diffae/blob/master/model/unet_autoenc.py#L155
# _t_emb = timestep_embedding(t, self.conf.model_channels).to(dtype=cond.dtype)
# _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels).to(dtype=cond.dtype)

# https://github.com/phizaz/diffae/blob/master/experiment.py#L860
# data = target_dict[key].data
# target_dict[key].data.copy_(data * decay + source_dict[key].data.to(dtype=data.dtype, device=data.device) * (1 - decay))

In [None]:
import os, sys
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

sys.path.append(os.path.abspath('./diffae/'))

from templates import ffhq256_autoenc, LitModel, WarmupLR, ema
from choices import ModelMeanType, LossType, OptimizerType, TrainMode
from model.nn import mean_flat

In [None]:
import numpy as np
import torch
from torchvision.transforms import functional as VF
from torchvision.transforms import RandomResizedCrop
from PIL import Image
from matplotlib import pyplot
import json
from tqdm.auto import tqdm

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Image dataset yielding randomly flipped, square-cropped tensors in [-1, 1].

    For each item the image is loaded as RGB, optionally cropped to a fixed
    box, flipped horizontally with probability 0.5, randomly square-cropped
    and resized to ``target_size`` with antialiased bicubic interpolation.

    Args:
        file_paths: Paths of the image files, one per item.
        target_size: Output side length in pixels.
        crops: Optional per-item crop boxes aligned with ``file_paths``. A box
            with exactly four values is passed to ``PIL.Image.crop``; any
            other entry (e.g. ``[]``) means "no fixed crop" for that item.
            ``None`` disables fixed crops for every item.

    Raises:
        ValueError: If ``crops`` is given and its length differs from
            ``file_paths`` (otherwise this would surface later as an
            IndexError inside a DataLoader worker).
    """

    def __init__(self, file_paths, target_size, crops=None):
        if crops is not None and len(crops) != len(file_paths):
            raise ValueError(
                f"crops has {len(crops)} entries but file_paths has {len(file_paths)}"
            )
        self.file_paths = file_paths
        self.target_size = target_size
        self.crops = crops

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        file_path = self.file_paths[index]
        crop = self.crops[index] if self.crops is not None else None

        # The context manager closes the file handle deterministically;
        # convert() returns a new, fully loaded image that outlives the `with`.
        with Image.open(file_path) as img:
            image = img.convert('RGB')
        if crop is not None and len(crop) == 4:
            image = image.crop(crop)

        # Draw from torch's RNG (as torchvision's RandomHorizontalFlip does) so
        # the flip shares the per-worker generator that DataLoader seeds and
        # that RandomResizedCrop.get_params below also uses.
        if torch.rand(1).item() < 0.5:
            image = VF.hflip(image)

        # Lower bound on the crop's area fraction. For a square image,
        # (target_size / short_side) ** 2 is the fraction covered by a
        # target_size x target_size crop; the bound never exceeds 0.5.
        scale = min((self.target_size / min(image.size)) ** 2, .5)

        image = VF.to_tensor(image)

        params = RandomResizedCrop.get_params(image, scale=(scale, 1.), ratio=(1., 1.))
        image = VF.resized_crop(image, *params, size=self.target_size,
                                interpolation=VF.InterpolationMode.BICUBIC, antialias=True)
        # Bicubic resampling can overshoot [0, 1]; clamp before rescaling.
        image = torch.clamp(image, 0., 1.)

        # Map [0, 1] -> [-1, 1].
        pixel_values = image * 2. - 1.
        return pixel_values

In [None]:
# Training-loop settings.
num_epochs = 10
gradient_accumulation_steps = 1

# Seed torch and numpy before the model is built so weight initialisation and
# the random draws made in this process are reproducible on Restart & Run All.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# https://github.com/phizaz/diffae/blob/master/experiment.py#L877

dtype = torch.float32
# dtype = torch.bfloat16

device = 'cuda'
conf = ffhq256_autoenc()

conf.batch_size = 16
# conf.model_conf.use_checkpoint = True

model = LitModel(conf)

In [None]:
# Put the trained network and its EMA copy on the training device and dtype.
# A bfloat16 experiment was tried here and left disabled: it set the `.dtype`
# attribute of model.model / model.ema_model and of their `.encoder` to
# torch.bfloat16, and cast every GroupNorm32 module back to torch.float32.
for _net in (model.model, model.ema_model):
    _net.to(device=device, dtype=dtype)
del _net

In [None]:
def list_subdir_files(root):
    """Return the paths of the files inside each immediate subdirectory of ``root``.

    Entries directly under ``root`` that are not directories are skipped, as are
    nested directories and hidden files (names starting with '.', e.g.
    ``.DS_Store``), which PIL could not open. Both listings are sorted because
    ``os.listdir`` order is arbitrary, which would make the dataset order
    irreproducible.
    """
    paths = []
    for dname in sorted(os.listdir(root)):
        subdir = os.path.join(root, dname)
        if not os.path.isdir(subdir):
            continue
        for fname in sorted(os.listdir(subdir)):
            path = os.path.join(subdir, fname)
            if fname.startswith('.') or not os.path.isfile(path):
                continue
            paths.append(path)
    return paths


roots = [
    '/opt/dataset/Super_Metroid_Redesign/',
    '/opt/dataset/Metroid_Fusion/',
]

file_paths = list()
crops = list()

for root in roots:
    file_paths.extend(list_subdir_files(root))

# These datasets have no fixed crop box: one (distinct) empty entry per file.
crops.extend([] for _ in file_paths)


# roots = [
#     '/opt/dataset/cg/new/',
# ]
# for root in roots:
    
#     info = json.load(open(os.path.join(root, 'info.json')))
    
#     for item in tqdm(info, leave=False):
#         dname = item['name']
#         crop = item['crop']
#         paths = [os.path.join(root, 'keyframes', dname, i) for i in os.listdir(os.path.join(root, 'keyframes', dname))]
#         file_paths.extend(paths)
#         crops.extend([crop] * len(paths))

# An empty list would give a zero-length DataLoader and the training loop
# would silently do nothing.
assert file_paths, "no training images found under the dataset roots"

dataset = Dataset(file_paths=file_paths, target_size=conf.model_conf.image_size, crops=crops)

In [None]:
# Shuffled mini-batches of conf.batch_size; the incomplete last batch is
# dropped so every training step sees a full batch.
loader_options = dict(
    shuffle=True,
    batch_size=conf.batch_size,
    num_workers=conf.num_workers,
    drop_last=True,
)
train_dataloader = torch.utils.data.DataLoader(dataset, **loader_options)

In [None]:
# https://github.com/phizaz/diffae/blob/master/experiment.py#L633

# Both supported optimizers take the same hyper-parameters, so pick the class
# from a table. `fused=True` selects the fused implementation; the parameters
# are expected to already be on a supported device (here CUDA).
optimizer_classes = {
    OptimizerType.adam: torch.optim.Adam,
    OptimizerType.adamw: torch.optim.AdamW,
}
optimizer_cls = optimizer_classes.get(conf.optimizer)
if optimizer_cls is None:
    raise NotImplementedError()

optim = optimizer_cls(
    model.model.parameters(),
    lr=conf.lr,
    weight_decay=conf.weight_decay,
    fused=True,
    # foreach=False
)

# Warm-up schedule. `sched` exists only when warm-up is enabled; the training
# loop tests the same condition before calling sched.step().
if conf.warmup > 0:
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=WarmupLR(conf.warmup))

In [None]:
# Compile the diffusion network once. torch.compile wraps the module and keeps
# the original under `_orig_mod`; skipping when that attribute already exists
# makes this cell safe to re-run (re-compiling would nest wrappers, so
# `model.model._orig_mod` would no longer be the plain network used for EMA).
if not hasattr(model.model, '_orig_mod'):
    model.model = torch.compile(model.model)

In [None]:
def train_step(model, x_start):
    """Compute per-sample diffusion training losses for one batch (port of diffae).

    Samples timesteps with ``model.T_sampler``, noises ``x_start`` into ``x_t``
    via ``model.sampler.q_sample``, runs ``model.model`` on ``x_t`` (with
    ``x_start`` as the autoencoder input) and compares its ``.pred`` with the
    sampled noise.

    Args:
        model: the LitModel providing ``model``, ``sampler`` and ``T_sampler``.
        x_start: batch of clean images; presumably (N, C, H, W) in [-1, 1] as
            produced by ``Dataset`` — TODO confirm.

    Returns:
        dict with ``'x_t'`` (noised input), ``'mse'`` (per-sample loss from
        ``mean_flat``; squared error for LossType.mse, absolute error for
        LossType.l1) and ``'loss'`` (identical to ``'mse'``).

    Raises:
        NotImplementedError: if the configured train mode requires dataset
            inference, or for mean/loss types other than eps with MSE/L1.
    """
    # https://github.com/phizaz/diffae/blob/master/experiment.py#L374
    # Only the direct diffusion path is ported; fail loudly instead of silently
    # training the wrong objective.
    if conf.train_mode.require_dataset_infer():
        raise NotImplementedError(
            f"train mode {conf.train_mode} requires dataset inference, which train_step does not implement"
        )

    # Sample on the batch's own device rather than the notebook-global one.
    # The returned weights are unused, as in diffae's training step.
    t, _weight = model.T_sampler.sample(x_start.shape[0], device=x_start.device)

    # https://github.com/phizaz/diffae/blob/master/diffusion/base.py#L100
    noise = torch.randn_like(x_start)
    x_t = model.sampler.q_sample(x_start, t, noise=noise).to(dtype=x_start.dtype)

    terms = {'x_t': x_t}

    # Call the module itself (not `.forward`) so nn.Module hooks and the
    # torch.compile wrapper's normal call path are used.
    model_forward = model.model(x=x_t.detach(),
                                t=model.sampler._scale_timesteps(t),
                                x_start=x_start.detach())
    model_output = model_forward.pred

    if model.sampler.model_mean_type != ModelMeanType.eps:
        raise NotImplementedError()
    # eps-prediction: the network predicts the injected noise.
    target = noise
    assert model_output.shape == target.shape == x_start.shape

    loss_type = model.sampler.loss_type
    if loss_type == LossType.mse:
        # (n, c, h, w) => (n, )
        per_sample = mean_flat((target - model_output) ** 2)
    elif loss_type == LossType.l1:
        # (n, c, h, w) => (n, )
        per_sample = mean_flat((target - model_output).abs())
    else:
        raise NotImplementedError()

    # The key stays "mse" even for L1, matching diffae's naming.
    terms["mse"] = per_sample
    # This port never computes a variational-bound ("vb") term, so the loss is
    # just the per-sample reconstruction term.
    terms["loss"] = terms["mse"]

    return terms

In [None]:
# Run name: prefix of the log file and of the checkpoint files.
name = 'diffae_metroid'

In [None]:
# Main training loop.
# `step` counts micro-batches: `num_epochs` full passes over the loader. Every
# `gradient_accumulation_steps` micro-batches form one optimizer update, which
# is what the progress bar and the log count.
max_train_steps = len(train_dataloader) * num_epochs
num_update_steps = max_train_steps // gradient_accumulation_steps

progress_bar = tqdm(total=num_update_steps, desc="Steps")

model.train()

epoch = 1
step = 0
data_iter = iter(train_dataloader)

losses = list()

# Truncate the log left by any previous run.
with open(f'./{name}.log', 'w') as f:
    pass

# Plain (uncompiled) network for EMA. torch.compile keeps it under `_orig_mod`;
# fall back to model.model itself if the compile cell was skipped. (Attribute
# access on the compiled wrapper is forwarded to this module, so its
# `.latent_net` has no `_orig_mod` of its own.)
base_model = getattr(model.model, '_orig_mod', model.model)

# Start from clean gradients; they are cleared again after every update.
optim.zero_grad()

while step < max_train_steps:

    try:
        batch = next(data_iter)
    except StopIteration:
        # End of an epoch: checkpoint it, then start the next pass. Only
        # exhaustion is caught, so worker/data errors still propagate.
        # NOTE(review): state_dict keys of the compiled network carry the
        # `_orig_mod.` prefix added by torch.compile.
        torch.save(model.state_dict(), f"{name}_{epoch}.pth")
        epoch += 1
        data_iter = iter(train_dataloader)
        batch = next(data_iter)

    x_start = batch.to(device=device, dtype=dtype)

    terms = train_step(model, x_start)
    loss = terms['loss'].mean()

    # Scale so the accumulated gradient is the mean over micro-batches, not
    # their sum.
    (loss / gradient_accumulation_steps).backward()

    # Log the unscaled loss.
    losses.append(loss.item())

    step += 1

    if step % gradient_accumulation_steps == 0:

        # https://github.com/phizaz/diffae/blob/master/experiment.py#L433
        if hasattr(model, 'on_before_optimizer_step'):
            model.on_before_optimizer_step(optim, 0)

        optim.step()
        # Clear only after an update so gradients accumulate across the
        # micro-batches of one update.
        optim.zero_grad()

        if conf.warmup > 0:
            sched.step()

        # https://github.com/phizaz/diffae/blob/master/experiment.py#L415
        # EMA is applied only on iterations that ran optimizer.step().
        if conf.train_mode == TrainMode.latent_diffusion:
            # Only the latent network is trained, so only its EMA changes.
            ema(base_model.latent_net, model.ema_model.latent_net, conf.ema_decay)
        else:
            ema(base_model, model.ema_model, conf.ema_decay)

        with open(f'./{name}.log', 'a') as f:
            f.write(f"Epoch: {epoch}, Step: {step}, Loss: {np.mean(losses)}\n")
        losses = list()

        progress_bar.update(1)

# Any trailing micro-batches that did not fill a whole accumulation window are
# not applied (never the case with gradient_accumulation_steps == 1).
progress_bar.close()

In [None]:
# Checkpoint of the last epoch. The loop saves earlier epochs when their data
# iterator runs out, so the last one must be saved here.
checkpoint_path = f"{name}_{epoch}.pth"
torch.save(model.state_dict(), checkpoint_path)