In [1]:
import sys
import os

# Expose the directory one level above the working directory on sys.path.
# The later `from src...` imports rely on this lookup path.
cwd = os.getcwd()
parent_dir = os.path.abspath(os.path.join(cwd, os.pardir))

# Append at most once so that re-running the cell does not grow sys.path
if not any(entry == parent_dir for entry in sys.path):
    sys.path.append(parent_dir)


In [2]:
import os
import gc
import lpips
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision
import transformers
from torchvision.transforms.functional import crop
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from glob import glob
from einops import rearrange

import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler

import wandb

from src.model import Difix, load_ckpt_from_state_dict, save_ckpt
from src.dataset import PairedDataset
from src.loss import gram_loss

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


## Arguments

In [3]:
from types import SimpleNamespace

In [4]:
# Baseline hyper-parameters and options (one entry per setting).
_default_args = dict(
    # loss weights
    lambda_lpips=1.0,
    lambda_l2=1.0,
    lambda_gram=1.0,
    gram_loss_warmup_steps=2000,
    # data
    dataset_path="data/converted_dataset_fixed.json",
    train_image_prep="resized_crop_512",
    test_image_prep="resized_crop_512",
    prompt=None,
    # evaluation / visualisation / tracking
    eval_freq=100,
    num_samples_eval=100,
    viz_freq=100,
    tracker_project_name="difix",
    tracker_run_name="train",
    # model
    pretrained_model_name_or_path=None,
    revision=None,
    variant=None,
    tokenizer_name=None,
    lora_rank_vae=4,
    timestep=199,
    mv_unet=False,
    # run bookkeeping
    output_dir="./outputs/difix/train",
    cache_dir=None,
    seed=None,
    resolution=512,
    train_batch_size=1,
    num_training_epochs=10,
    max_train_steps=10000,
    checkpointing_steps=500,
    gradient_accumulation_steps=1,
    gradient_checkpointing=False,
    # optimisation
    learning_rate=5e-6,
    lr_scheduler="constant",
    lr_warmup_steps=500,
    lr_num_cycles=1,
    lr_power=1.0,
    dataloader_num_workers=4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_weight_decay=1e-2,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
    allow_tf32=False,
    report_to="wandb",
    mixed_precision="bf16",
    enable_xformers_memory_efficient_attention=True,
    set_grads_to_none=False,
    resume=None,
)

# Settings chosen for this particular run; these take precedence over the
# defaults above (several simply restate the default value).
_run_overrides = dict(
    output_dir='./outputs/difix/train',
    dataset_path="./data/converted_dataset_fixed.json",
    max_train_steps=10000,
    resolution=512,
    learning_rate=2e-5,
    train_batch_size=1,
    dataloader_num_workers=4,
    enable_xformers_memory_efficient_attention=True,
    checkpointing_steps=1000,
    eval_freq=1000,
    viz_freq=100,
    lambda_lpips=1.0,
    lambda_l2=1.0,
    lambda_gram=1.0,
    gram_loss_warmup_steps=2000,
    report_to="wandb",
    tracker_project_name="difix",
    tracker_run_name="train",
    timestep=199,
)

# Attribute-style access (args.learning_rate, ...) for the rest of the notebook
args = SimpleNamespace(**{**_default_args, **_run_overrides})

In [5]:
# Build the Accelerator from the run configuration
accelerator = Accelerator(
    log_with=args.report_to,
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
)

# Only the local main process keeps warning/info logging from the libraries;
# every other process is reduced to errors.
if not accelerator.is_local_main_process:
    transformers.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
else:
    transformers.utils.logging.set_verbosity_warning()
    diffusers.utils.logging.set_verbosity_info()

# Seeding is optional: nothing is seeded when args.seed is None
if args.seed is not None:
    set_seed(args.seed)

# The main process creates the output sub-folders used later on
if accelerator.is_main_process:
    for subdir in ("checkpoints", "eval"):
        os.makedirs(os.path.join(args.output_dir, subdir), exist_ok=True)

### Set Up Model ###
net_difix = Difix(
    timestep=args.timestep,
    mv_unet=args.mv_unet,
    lora_rank_vae=args.lora_rank_vae,
)
net_difix.set_train()

# Memory-efficient attention is opt-in and fails loudly if xformers is missing
if args.enable_xformers_memory_efficient_attention:
    if not is_xformers_available():
        raise ValueError("xformers is not available, please install it by running `pip install xformers`")
    net_difix.unet.enable_xformers_memory_efficient_attention()

if args.gradient_checkpointing:
    net_difix.unet.enable_gradient_checkpointing()

if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True

  torch.utils._pytree._register_pytree_node(
{'variance_type', 'clip_sample_range', 'thresholding', 'rescale_betas_zero_snr', 'dynamic_thresholding_ratio'} was not found in config. Values will be initialized to default values.


Initializing model with random weights
Number of trainable parameters in UNet: 865.91M
Number of trainable parameters in VAE: 0.52M


In [6]:
### Set up metrics
# Frozen LPIPS network (VGG trunk) used as the perceptual loss. It is placed on
# the accelerator's device rather than via a hard-coded `.cuda()`, so placement
# is consistent with the rest of the notebook.
net_lpips = lpips.LPIPS(net='vgg').to(accelerator.device)
net_lpips.requires_grad_(False)

# Frozen VGG16 convolutional trunk, later handed to `gram_loss`.
# `weights=IMAGENET1K_V1` is the non-deprecated spelling of `pretrained=True`
# and loads the same ImageNet checkpoint.
net_vgg = torchvision.models.vgg16(
    weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1
).features
for param in net_vgg.parameters():
    param.requires_grad_(False)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/cx24957/miniconda3/envs/difix_env_fixed/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


In [7]:
#### make the optimizer
# Trainable parameters, in order: the whole UNet, the VAE-skip LoRA weights,
# then the four decoder skip convolutions.
layers_to_opt = list(net_difix.unet.parameters())

for param_name, param in net_difix.vae.named_parameters():
    if "lora" in param_name and "vae_skip" in param_name:
        # These adapters are expected to be trainable already
        assert param.requires_grad
        layers_to_opt.append(param)

vae_decoder = net_difix.vae.decoder
for skip_conv in (
    vae_decoder.skip_conv_1,
    vae_decoder.skip_conv_2,
    vae_decoder.skip_conv_3,
    vae_decoder.skip_conv_4,
):
    layers_to_opt.extend(skip_conv.parameters())

optimizer = torch.optim.AdamW(
    layers_to_opt,
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

# Warm-up and total step counts are scaled by the number of processes
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
    num_training_steps=args.max_train_steps * accelerator.num_processes,
    num_cycles=args.lr_num_cycles,
    power=args.lr_power,
)

  return disable_fn(*args, **kwargs)


In [8]:
### Set up dataset
# Locate the dataset JSON from `args.dataset_path` instead of a hard-coded,
# machine-specific absolute path. A relative path is tried against the current
# working directory first and then against `parent_dir` (the directory put on
# sys.path in the first cell -- presumably the repository root; verify).
_dataset_candidates = [
    os.path.abspath(args.dataset_path),
    os.path.abspath(os.path.join(parent_dir, args.dataset_path)),
]
dataset_path = next((p for p in _dataset_candidates if os.path.isfile(p)), None)
if dataset_path is None:
    raise FileNotFoundError(
        f"Dataset file {args.dataset_path!r} not found; tried: {_dataset_candidates}"
    )

# Image size follows args.resolution so the config cell is the single source of truth
dataset_train = PairedDataset(dataset_path=dataset_path,
                              height=args.resolution,
                              width=args.resolution,
                              split="train",
                              tokenizer=net_difix.tokenizer)
dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)

dataset_val = PairedDataset(dataset_path=dataset_path,
                            height=args.resolution,
                            width=args.resolution,
                            split="test",
                            tokenizer=net_difix.tokenizer)
# Fixed-seed shuffle: the validation order is the same on every run
random.Random(42).shuffle(dataset_val.img_ids)
# Validation runs one sample at a time in the main process (the eval loop asserts B == 1)
dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

In [9]:
### Set up components on accelerator
# Pick the weight dtype that matches the accelerator's mixed-precision mode;
# any mode other than fp16/bf16 keeps full precision.
weight_dtype = {
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}.get(accelerator.mixed_precision, torch.float32)

# Move all networks to the device and cast them to weight_dtype
for network in (net_difix, net_lpips, net_vgg):
    network.to(accelerator.device, dtype=weight_dtype)

# Hand the training objects to the accelerator first, then the two loss networks
net_difix, optimizer, dl_train, lr_scheduler = accelerator.prepare(
    net_difix, optimizer, dl_train, lr_scheduler
)
net_lpips, net_vgg = accelerator.prepare(net_lpips, net_vgg)

# Re-normalisation with ImageNet channel statistics, applied before the VGG gram loss
t_vgg_renorm = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)



In [10]:
# tracker and progress bar
# Trackers are initialised by the main process only; the full argument
# namespace is recorded as the run configuration.
if accelerator.is_main_process:
    tracker_config = dict(vars(args))
    init_kwargs = {
        "wandb": {"name": args.tracker_run_name, "dir": args.output_dir},
    }
    accelerator.init_trackers(
        args.tracker_project_name,
        config=tracker_config,
        init_kwargs=init_kwargs,
    )

# One tick per optimizer step; hidden on every process but the local main one
progress_bar = tqdm(
    range(args.max_train_steps),
    initial=0,
    desc="Steps",
    disable=not accelerator.is_local_main_process,
)

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Steps:   0%|          | 0/10000 [00:00<?, ?it/s]

In [11]:
# Count of optimizer steps completed so far. The training loop increments it
# and uses it for the gram-loss warm-up, logging, checkpoint file names and the
# viz/eval schedule. It lives in its own cell, so re-running only the training
# cell continues from the current value instead of restarting at 0.
global_step = 0

In [None]:
# start the training loop
#
# Trains until `args.max_train_steps` optimizer steps have been taken or
# `args.num_training_epochs` epochs are exhausted, whichever comes first.
# Whenever global_step % N == 1 (so also on the very first step) the main
# process uploads sample images (N = viz_freq), saves a checkpoint
# (N = checkpointing_steps) and runs validation (N = eval_freq).
for epoch in tqdm(range(0, args.num_training_epochs), desc="Epochs",
                  disable=not accelerator.is_local_main_process):
    # The step budget may already be used up (e.g. when this cell is re-run)
    if global_step >= args.max_train_steps:
        break

    # No tqdm wrapper here: `progress_bar` already tracks optimizer steps, and a
    # total-less bar around enumerate() only garbles the cell output.
    for step, batch in enumerate(dl_train):
        l_acc = [net_difix]
        with accelerator.accumulate(*l_acc):
            x_src = batch["conditioning_pixel_values"]
            x_tgt = batch["output_pixel_values"]
            B, V, C, H, W = x_src.shape

            # forward pass
            x_tgt_pred = net_difix(x_src, prompt_tokens=batch["input_ids"])

            # Fold the view axis into the batch axis for the image-level losses
            x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w')
            x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w')

            # Reconstruction loss
            loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
            loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
            loss = loss_l2 + loss_lpips

            # Gram matrix loss (only after the warm-up period)
            if args.lambda_gram > 0:
                if global_step > args.gram_loss_warmup_steps:
                    # Clamp the crop to the image size so that inputs smaller
                    # than 400 px do not make random.randint raise.
                    crop_h, crop_w = min(400, H), min(400, W)
                    top, left = random.randint(0, H - crop_h), random.randint(0, W - crop_w)

                    # Same random crop for prediction and target
                    x_tgt_pred_renorm = t_vgg_renorm(x_tgt_pred * 0.5 + 0.5)
                    x_tgt_pred_renorm = crop(x_tgt_pred_renorm, top, left, crop_h, crop_w)

                    x_tgt_renorm = t_vgg_renorm(x_tgt * 0.5 + 0.5)
                    x_tgt_renorm = crop(x_tgt_renorm, top, left, crop_h, crop_w)

                    loss_gram = gram_loss(x_tgt_pred_renorm.to(weight_dtype), x_tgt_renorm.to(weight_dtype), net_vgg) * args.lambda_gram
                    loss += loss_gram
                else:
                    # Placeholder so the logging code below always has a value
                    loss_gram = torch.tensor(0.0).to(weight_dtype)

            accelerator.backward(loss, retain_graph=False)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # Restore the (batch, view) layout for visualisation
            x_tgt = rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V)
            x_tgt_pred = rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V)

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1

            if accelerator.is_main_process:
                logs = {}
                # log all the losses
                logs["loss_l2"] = loss_l2.detach().item()
                logs["loss_lpips"] = loss_lpips.detach().item()
                if args.lambda_gram > 0:
                    logs["loss_gram"] = loss_gram.detach().item()
                progress_bar.set_postfix(**logs)

                # viz some images
                if global_step % args.viz_freq == 1:
                    print('uploading viz images')
                    log_dict = {
                        "train/source": [wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                        "train/target": [wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                        "train/model_output": [wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                    }
                    for k in log_dict:
                        logs[k] = log_dict[k]

                # checkpoint the model
                if global_step % args.checkpointing_steps == 1:
                    outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                    # accelerator.unwrap_model(net_difix).save_model(outf)
                    save_ckpt(accelerator.unwrap_model(net_difix), optimizer, outf)

                # compute validation set L2, LPIPS
                if args.eval_freq > 0 and global_step % args.eval_freq == 1:
                    l_l2, l_lpips = [], []
                    log_dict = {"sample/source": [], "sample/target": [], "sample/model_output": []}
                    # Validation uses its own variable names so that it never
                    # clobbers the training batch, its shape, or the train losses.
                    for val_step, batch_val in enumerate(dl_val):
                        if val_step >= args.num_samples_eval:
                            break
                        x_src_val = batch_val["conditioning_pixel_values"].to(accelerator.device, dtype=weight_dtype)
                        x_tgt_val = batch_val["output_pixel_values"].to(accelerator.device, dtype=weight_dtype)
                        assert x_src_val.shape[0] == 1, "Use batch size 1 for eval."
                        with torch.no_grad():
                            # forward pass (tokens go to the accelerator's device, not a hard-coded .cuda())
                            x_pred_val = accelerator.unwrap_model(net_difix)(
                                x_src_val, prompt_tokens=batch_val["input_ids"].to(accelerator.device)
                            )

                            # Log every 10th validation sample
                            if val_step % 10 == 0:
                                # Compute the caption index once so source, target
                                # and output of the same sample share one idx.
                                sample_idx = len(log_dict["sample/source"])
                                log_dict["sample/source"].append(wandb.Image(rearrange(x_src_val, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={sample_idx}"))
                                log_dict["sample/target"].append(wandb.Image(rearrange(x_tgt_val, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={sample_idx}"))
                                log_dict["sample/model_output"].append(wandb.Image(rearrange(x_pred_val, "b v c h w -> b c (v h) w")[0].float().detach().cpu(), caption=f"idx={sample_idx}"))

                            x_tgt_val = x_tgt_val[:, 0] # take the input view
                            x_pred_val = x_pred_val[:, 0] # take the input view
                            # compute the reconstruction losses
                            val_loss_l2 = F.mse_loss(x_pred_val.float(), x_tgt_val.float(), reduction="mean")
                            val_loss_lpips = net_lpips(x_pred_val.float(), x_tgt_val.float()).mean()

                            l_l2.append(val_loss_l2.item())
                            l_lpips.append(val_loss_lpips.item())

                    # Skip the averages on an empty validation set (np.mean([]) is NaN)
                    if l_l2:
                        logs["val/l2"] = np.mean(l_l2)
                        logs["val/lpips"] = np.mean(l_lpips)
                    for k in log_dict:
                        logs[k] = log_dict[k]
                    gc.collect()
                    torch.cuda.empty_cache()
                accelerator.log(logs, step=global_step)

        # Enforce the step budget that the progress bar and LR scheduler were sized for
        if global_step >= args.max_train_steps:
            break

# Save the final weights unless the last step was already checkpointed above
if accelerator.is_main_process and global_step > 0 and global_step % args.checkpointing_steps != 1:
    outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
    save_ckpt(accelerator.unwrap_model(net_difix), optimizer, outf)



[A



uploading viz images



Steps:   0%|          | 2/10000 [01:54<183:19:50, 66.01s/it, loss_gram=0, loss_l2=0.0309, loss_lpips=0.318]
Steps:   0%|          | 3/10000 [02:02<109:33:28, 39.45s/it, loss_gram=0, loss_l2=0.00927, loss_lpips=0.363]
Steps:   0%|          | 4/10000 [02:11<75:58:16, 27.36s/it, loss_gram=0, loss_l2=0.0105, loss_lpips=0.314]  
Steps:   0%|          | 5/10000 [02:20<57:46:20, 20.81s/it, loss_gram=0, loss_l2=0.016, loss_lpips=0.304] 
Steps:   0%|          | 6/10000 [02:29<46:10:44, 16.63s/it, loss_gram=0, loss_l2=0.0233, loss_lpips=0.29]
Steps:   0%|          | 7/10000 [02:38<39:17:21, 14.15s/it, loss_gram=0, loss_l2=0.015, loss_lpips=0.484]
Steps:   0%|          | 8/10000 [02:46<34:28:40, 12.42s/it, loss_gram=0, loss_l2=0.00506, loss_lpips=0.314]
Steps:   0%|          | 9/10000 [02:55<31:16:56, 11.27s/it, loss_gram=0, loss_l2=0.0241, loss_lpips=0.428] 
Steps:   0%|          | 10/10000 [03:04<28:49:29, 10.39s/it, loss_gram=0, loss_l2=0.00846, loss_lpips=0.239]
Steps:   0%|          | 11/10

uploading viz images



Steps:   1%|          | 102/10000 [26:05<57:02:15, 20.75s/it, loss_gram=0, loss_l2=0.00172, loss_lpips=0.134]
Steps:   1%|          | 103/10000 [26:25<57:10:25, 20.80s/it, loss_gram=0, loss_l2=0.0134, loss_lpips=0.125] 
Steps:   1%|          | 104/10000 [26:46<57:09:27, 20.79s/it, loss_gram=0, loss_l2=0.0194, loss_lpips=0.149]
Steps:   1%|          | 105/10000 [27:07<57:38:42, 20.97s/it, loss_gram=0, loss_l2=0.00184, loss_lpips=0.157]
Steps:   1%|          | 106/10000 [27:27<56:17:20, 20.48s/it, loss_gram=0, loss_l2=0.0142, loss_lpips=0.142] 
Steps:   1%|          | 107/10000 [27:48<57:06:40, 20.78s/it, loss_gram=0, loss_l2=0.00433, loss_lpips=0.147]
Steps:   1%|          | 108/10000 [28:09<56:54:11, 20.71s/it, loss_gram=0, loss_l2=0.00706, loss_lpips=0.101]
Steps:   1%|          | 109/10000 [28:28<55:41:17, 20.27s/it, loss_gram=0, loss_l2=0.0111, loss_lpips=0.148] 
Steps:   1%|          | 110/10000 [28:47<54:40:51, 19.90s/it, loss_gram=0, loss_l2=0.00485, loss_lpips=0.196]
Steps:   1

uploading viz images



Steps:   2%|▏         | 202/10000 [1:00:41<52:37:03, 19.33s/it, loss_gram=0, loss_l2=0.00697, loss_lpips=0.151]
Steps:   2%|▏         | 203/10000 [1:00:50<44:01:52, 16.18s/it, loss_gram=0, loss_l2=0.00725, loss_lpips=0.188]
Steps:   2%|▏         | 204/10000 [1:00:58<37:44:18, 13.87s/it, loss_gram=0, loss_l2=0.00578, loss_lpips=0.143]
Steps:   2%|▏         | 205/10000 [1:01:07<33:36:13, 12.35s/it, loss_gram=0, loss_l2=0.0208, loss_lpips=0.147] 
Steps:   2%|▏         | 206/10000 [1:01:16<30:20:24, 11.15s/it, loss_gram=0, loss_l2=0.00121, loss_lpips=0.0914]
Steps:   2%|▏         | 207/10000 [1:01:24<28:14:39, 10.38s/it, loss_gram=0, loss_l2=0.00926, loss_lpips=0.0989]
Steps:   2%|▏         | 208/10000 [1:01:33<26:40:19,  9.81s/it, loss_gram=0, loss_l2=0.00221, loss_lpips=0.0842]
Steps:   2%|▏         | 209/10000 [1:01:41<25:46:22,  9.48s/it, loss_gram=0, loss_l2=0.0203, loss_lpips=0.221]  
Steps:   2%|▏         | 210/10000 [1:01:50<24:56:28,  9.17s/it, loss_gram=0, loss_l2=0.00979, loss_

uploading viz images



Steps:   3%|▎         | 302/10000 [1:21:38<65:36:30, 24.35s/it, loss_gram=0, loss_l2=0.0113, loss_lpips=0.106]
Steps:   3%|▎         | 303/10000 [1:22:03<66:18:47, 24.62s/it, loss_gram=0, loss_l2=0.0122, loss_lpips=0.124]
Steps:   3%|▎         | 304/10000 [1:22:27<65:46:38, 24.42s/it, loss_gram=0, loss_l2=0.0216, loss_lpips=0.197]
Steps:   3%|▎         | 305/10000 [1:22:51<65:39:25, 24.38s/it, loss_gram=0, loss_l2=0.0114, loss_lpips=0.0959]
Steps:   3%|▎         | 306/10000 [1:23:16<66:14:54, 24.60s/it, loss_gram=0, loss_l2=0.00889, loss_lpips=0.0916]
Steps:   3%|▎         | 307/10000 [1:23:41<66:32:55, 24.72s/it, loss_gram=0, loss_l2=0.00231, loss_lpips=0.116] 
Steps:   3%|▎         | 308/10000 [1:24:06<65:59:21, 24.51s/it, loss_gram=0, loss_l2=0.0111, loss_lpips=0.14]  
Steps:   3%|▎         | 309/10000 [1:24:29<65:32:32, 24.35s/it, loss_gram=0, loss_l2=0.0254, loss_lpips=0.196]
Steps:   3%|▎         | 310/10000 [1:24:53<64:44:16, 24.05s/it, loss_gram=0, loss_l2=0.00878, loss_lpips=

uploading viz images



Steps:   4%|▍         | 402/10000 [1:56:29<54:06:25, 20.29s/it, loss_gram=0, loss_l2=0.00849, loss_lpips=0.137]
Steps:   4%|▍         | 403/10000 [1:56:49<54:07:30, 20.30s/it, loss_gram=0, loss_l2=0.0114, loss_lpips=0.139] 
Steps:   4%|▍         | 404/10000 [1:57:09<53:42:46, 20.15s/it, loss_gram=0, loss_l2=0.00394, loss_lpips=0.132]
Steps:   4%|▍         | 405/10000 [1:57:28<53:29:05, 20.07s/it, loss_gram=0, loss_l2=0.00924, loss_lpips=0.102]
Steps:   4%|▍         | 406/10000 [1:57:49<53:53:38, 20.22s/it, loss_gram=0, loss_l2=0.00873, loss_lpips=0.161]
Steps:   4%|▍         | 407/10000 [1:58:11<55:18:39, 20.76s/it, loss_gram=0, loss_l2=0.0123, loss_lpips=0.139] 
Steps:   4%|▍         | 408/10000 [1:58:31<54:39:53, 20.52s/it, loss_gram=0, loss_l2=0.013, loss_lpips=0.169] 
Steps:   4%|▍         | 409/10000 [1:58:50<53:40:19, 20.15s/it, loss_gram=0, loss_l2=0.00148, loss_lpips=0.114]
Steps:   4%|▍         | 410/10000 [1:59:10<53:26:33, 20.06s/it, loss_gram=0, loss_l2=0.00911, loss_lpips

uploading viz images



Steps:   5%|▌         | 502/10000 [2:30:18<55:19:19, 20.97s/it, loss_gram=0, loss_l2=0.0101, loss_lpips=0.0968]
Steps:   5%|▌         | 503/10000 [2:30:39<55:36:26, 21.08s/it, loss_gram=0, loss_l2=0.00536, loss_lpips=0.102]
Steps:   5%|▌         | 504/10000 [2:31:01<55:55:49, 21.20s/it, loss_gram=0, loss_l2=0.0105, loss_lpips=0.108] 
Steps:   5%|▌         | 505/10000 [2:31:21<55:18:42, 20.97s/it, loss_gram=0, loss_l2=0.0135, loss_lpips=0.134]
Steps:   5%|▌         | 506/10000 [2:31:42<55:03:01, 20.87s/it, loss_gram=0, loss_l2=0.00497, loss_lpips=0.116]
Steps:   5%|▌         | 507/10000 [2:32:03<55:29:21, 21.04s/it, loss_gram=0, loss_l2=0.029, loss_lpips=0.192]  
Steps:   5%|▌         | 508/10000 [2:32:24<55:01:38, 20.87s/it, loss_gram=0, loss_l2=0.0169, loss_lpips=0.164]
Steps:   5%|▌         | 509/10000 [2:32:45<55:10:39, 20.93s/it, loss_gram=0, loss_l2=0.0024, loss_lpips=0.0906]
Steps:   5%|▌         | 510/10000 [2:33:04<53:31:58, 20.31s/it, loss_gram=0, loss_l2=0.00974, loss_lpips=

uploading viz images



Steps:   6%|▌         | 602/10000 [3:04:11<54:32:58, 20.90s/it, loss_gram=0, loss_l2=0.000744, loss_lpips=0.133]
Steps:   6%|▌         | 603/10000 [3:04:32<54:34:26, 20.91s/it, loss_gram=0, loss_l2=0.0051, loss_lpips=0.122]  
Steps:   6%|▌         | 604/10000 [3:04:53<54:33:44, 20.91s/it, loss_gram=0, loss_l2=0.00702, loss_lpips=0.0881]
Steps:   6%|▌         | 605/10000 [3:05:13<54:07:30, 20.74s/it, loss_gram=0, loss_l2=0.0158, loss_lpips=0.219]  
Steps:   6%|▌         | 606/10000 [3:05:33<53:08:36, 20.37s/it, loss_gram=0, loss_l2=0.00644, loss_lpips=0.0994]
Steps:   6%|▌         | 607/10000 [3:05:52<52:10:47, 20.00s/it, loss_gram=0, loss_l2=0.0126, loss_lpips=0.147]  
Steps:   6%|▌         | 608/10000 [3:06:13<52:43:38, 20.21s/it, loss_gram=0, loss_l2=0.00821, loss_lpips=0.176]
Steps:   6%|▌         | 609/10000 [3:06:32<52:21:01, 20.07s/it, loss_gram=0, loss_l2=0.00315, loss_lpips=0.137]
Steps:   6%|▌         | 610/10000 [3:06:53<52:22:46, 20.08s/it, loss_gram=0, loss_l2=0.0584, loss

uploading viz images



Steps:   7%|▋         | 702/10000 [3:37:58<54:39:03, 21.16s/it, loss_gram=0, loss_l2=0.00621, loss_lpips=0.12]
Steps:   7%|▋         | 703/10000 [3:38:18<53:45:51, 20.82s/it, loss_gram=0, loss_l2=0.0508, loss_lpips=0.139]
Steps:   7%|▋         | 704/10000 [3:38:39<54:13:56, 21.00s/it, loss_gram=0, loss_l2=0.00759, loss_lpips=0.0854]
Steps:   7%|▋         | 705/10000 [3:38:59<53:08:28, 20.58s/it, loss_gram=0, loss_l2=0.00674, loss_lpips=0.0718]
Steps:   7%|▋         | 706/10000 [3:39:18<51:56:00, 20.12s/it, loss_gram=0, loss_l2=0.0116, loss_lpips=0.115]  
Steps:   7%|▋         | 707/10000 [3:39:39<52:40:38, 20.41s/it, loss_gram=0, loss_l2=0.00586, loss_lpips=0.0691]
Steps:   7%|▋         | 708/10000 [3:40:00<52:43:59, 20.43s/it, loss_gram=0, loss_l2=0.0088, loss_lpips=0.104]  
Steps:   7%|▋         | 709/10000 [3:40:20<52:27:23, 20.33s/it, loss_gram=0, loss_l2=0.00809, loss_lpips=0.111]
Steps:   7%|▋         | 710/10000 [3:40:41<53:45:27, 20.83s/it, loss_gram=0, loss_l2=0.00248, loss_l

uploading viz images



Steps:   8%|▊         | 802/10000 [4:11:47<52:45:50, 20.65s/it, loss_gram=0, loss_l2=0.00252, loss_lpips=0.0762]
Steps:   8%|▊         | 803/10000 [4:12:07<52:22:17, 20.50s/it, loss_gram=0, loss_l2=0.0451, loss_lpips=0.139]  
Steps:   8%|▊         | 804/10000 [4:12:26<51:23:17, 20.12s/it, loss_gram=0, loss_l2=0.000784, loss_lpips=0.0836]
Steps:   8%|▊         | 805/10000 [4:12:45<50:31:04, 19.78s/it, loss_gram=0, loss_l2=0.00126, loss_lpips=0.0699] 
Steps:   8%|▊         | 806/10000 [4:13:04<49:49:57, 19.51s/it, loss_gram=0, loss_l2=0.00695, loss_lpips=0.104] 
Steps:   8%|▊         | 807/10000 [4:13:24<49:54:56, 19.55s/it, loss_gram=0, loss_l2=0.000927, loss_lpips=0.182]
Steps:   8%|▊         | 808/10000 [4:13:45<51:02:31, 19.99s/it, loss_gram=0, loss_l2=0.00367, loss_lpips=0.146] 
Steps:   8%|▊         | 809/10000 [4:14:06<51:54:33, 20.33s/it, loss_gram=0, loss_l2=0.00711, loss_lpips=0.128]
Steps:   8%|▊         | 810/10000 [4:14:25<51:13:46, 20.07s/it, loss_gram=0, loss_l2=0.0133, l

uploading viz images



Steps:   9%|▉         | 902/10000 [4:45:23<50:42:46, 20.07s/it, loss_gram=0, loss_l2=0.00677, loss_lpips=0.118]  
Steps:   9%|▉         | 903/10000 [4:45:43<50:42:11, 20.07s/it, loss_gram=0, loss_l2=0.00637, loss_lpips=0.161]
Steps:   9%|▉         | 904/10000 [4:46:04<51:08:51, 20.24s/it, loss_gram=0, loss_l2=0.000977, loss_lpips=0.19]
Steps:   9%|▉         | 905/10000 [4:46:23<50:31:52, 20.00s/it, loss_gram=0, loss_l2=0.00149, loss_lpips=0.109]
Steps:   9%|▉         | 906/10000 [4:46:42<49:52:19, 19.74s/it, loss_gram=0, loss_l2=0.0165, loss_lpips=0.11]  
Steps:   9%|▉         | 907/10000 [4:47:06<52:30:44, 20.79s/it, loss_gram=0, loss_l2=0.0123, loss_lpips=0.2] 
Steps:   9%|▉         | 908/10000 [4:47:26<51:52:02, 20.54s/it, loss_gram=0, loss_l2=0.00809, loss_lpips=0.114]
Steps:   9%|▉         | 909/10000 [4:47:46<51:46:10, 20.50s/it, loss_gram=0, loss_l2=0.00203, loss_lpips=0.0841]
Steps:   9%|▉         | 910/10000 [4:48:07<51:51:19, 20.54s/it, loss_gram=0, loss_l2=0.0116, loss_lpip

uploading viz images



Steps:  10%|█         | 1002/10000 [5:23:11<232:34:48, 93.05s/it, loss_gram=0, loss_l2=0.00581, loss_lpips=0.113] 
Steps:  10%|█         | 1003/10000 [5:23:34<179:49:40, 71.96s/it, loss_gram=0, loss_l2=0.00333, loss_lpips=0.126]
Steps:  10%|█         | 1004/10000 [5:23:55<141:43:06, 56.71s/it, loss_gram=0, loss_l2=0.0114, loss_lpips=0.149] 
Steps:  10%|█         | 1005/10000 [5:24:15<114:08:12, 45.68s/it, loss_gram=0, loss_l2=0.0202, loss_lpips=0.125]
Steps:  10%|█         | 1006/10000 [5:24:39<97:38:21, 39.08s/it, loss_gram=0, loss_l2=0.00219, loss_lpips=0.072]
Steps:  10%|█         | 1007/10000 [5:24:58<83:01:50, 33.24s/it, loss_gram=0, loss_l2=0.0045, loss_lpips=0.0736]
Steps:  10%|█         | 1008/10000 [5:25:21<75:16:03, 30.13s/it, loss_gram=0, loss_l2=0.0139, loss_lpips=0.109] 
Steps:  10%|█         | 1009/10000 [5:25:40<67:07:59, 26.88s/it, loss_gram=0, loss_l2=0.00344, loss_lpips=0.134]
Steps:  10%|█         | 1010/10000 [5:26:02<63:18:25, 25.35s/it, loss_gram=0, loss_l2=0.004

uploading viz images



Steps:  11%|█         | 1102/10000 [5:57:02<50:38:11, 20.49s/it, loss_gram=0, loss_l2=0.0101, loss_lpips=0.215] 
Steps:  11%|█         | 1103/10000 [5:57:20<49:00:40, 19.83s/it, loss_gram=0, loss_l2=0.00368, loss_lpips=0.0731]
Steps:  11%|█         | 1104/10000 [5:57:41<49:56:55, 20.21s/it, loss_gram=0, loss_l2=0.00714, loss_lpips=0.0823]
Steps:  11%|█         | 1105/10000 [5:58:00<48:51:57, 19.78s/it, loss_gram=0, loss_l2=0.00874, loss_lpips=0.104] 
Steps:  11%|█         | 1106/10000 [5:58:20<49:05:32, 19.87s/it, loss_gram=0, loss_l2=0.00442, loss_lpips=0.128]
Steps:  11%|█         | 1107/10000 [5:58:37<46:29:52, 18.82s/it, loss_gram=0, loss_l2=0.0105, loss_lpips=0.123] 
Steps:  11%|█         | 1108/10000 [5:58:58<48:19:51, 19.57s/it, loss_gram=0, loss_l2=0.00335, loss_lpips=0.111]
Steps:  11%|█         | 1109/10000 [5:59:14<46:14:47, 18.73s/it, loss_gram=0, loss_l2=0.0038, loss_lpips=0.059] 
Steps:  11%|█         | 1110/10000 [5:59:36<47:47:27, 19.35s/it, loss_gram=0, loss_l2=0.0071

uploading viz images



Steps:  12%|█▏        | 1202/10000 [6:28:47<47:05:54, 19.27s/it, loss_gram=0, loss_l2=0.00856, loss_lpips=0.127]
Steps:  12%|█▏        | 1203/10000 [6:29:05<46:13:40, 18.92s/it, loss_gram=0, loss_l2=0.00131, loss_lpips=0.0953]
Steps:  12%|█▏        | 1204/10000 [6:29:26<47:57:30, 19.63s/it, loss_gram=0, loss_l2=0.0087, loss_lpips=0.098]  
Steps:  12%|█▏        | 1205/10000 [6:29:44<46:17:44, 18.95s/it, loss_gram=0, loss_l2=0.0107, loss_lpips=0.0836]
Steps:  12%|█▏        | 1206/10000 [6:30:04<46:59:56, 19.24s/it, loss_gram=0, loss_l2=0.0255, loss_lpips=0.226] 
Steps:  12%|█▏        | 1207/10000 [6:30:21<45:47:05, 18.75s/it, loss_gram=0, loss_l2=0.00478, loss_lpips=0.151]
Steps:  12%|█▏        | 1208/10000 [6:30:40<46:01:11, 18.84s/it, loss_gram=0, loss_l2=0.00343, loss_lpips=0.108]
Steps:  12%|█▏        | 1209/10000 [6:30:58<45:34:58, 18.67s/it, loss_gram=0, loss_l2=0.00424, loss_lpips=0.14] 
Steps:  12%|█▏        | 1210/10000 [6:31:19<46:39:41, 19.11s/it, loss_gram=0, loss_l2=0.00448

uploading viz images



Steps:  13%|█▎        | 1302/10000 [7:00:27<47:13:34, 19.55s/it, loss_gram=0, loss_l2=0.00632, loss_lpips=0.15] 
Steps:  13%|█▎        | 1303/10000 [7:00:44<46:03:26, 19.06s/it, loss_gram=0, loss_l2=0.0181, loss_lpips=0.13] 
Steps:  13%|█▎        | 1304/10000 [7:01:05<46:59:42, 19.46s/it, loss_gram=0, loss_l2=0.0345, loss_lpips=0.255]
Steps:  13%|█▎        | 1305/10000 [7:01:23<45:55:23, 19.01s/it, loss_gram=0, loss_l2=0.014, loss_lpips=0.146] 
Steps:  13%|█▎        | 1306/10000 [7:01:43<46:49:32, 19.39s/it, loss_gram=0, loss_l2=0.00258, loss_lpips=0.0883]
Steps:  13%|█▎        | 1307/10000 [7:02:00<45:09:03, 18.70s/it, loss_gram=0, loss_l2=0.00914, loss_lpips=0.132] 
Steps:  13%|█▎        | 1308/10000 [7:02:20<45:48:45, 18.97s/it, loss_gram=0, loss_l2=0.00322, loss_lpips=0.0793]
Steps:  13%|█▎        | 1309/10000 [7:02:38<45:09:34, 18.71s/it, loss_gram=0, loss_l2=0.0159, loss_lpips=0.129]  
Steps:  13%|█▎        | 1310/10000 [7:02:59<46:53:12, 19.42s/it, loss_gram=0, loss_l2=0.0044, 

uploading viz images



Steps:  14%|█▍        | 1402/10000 [7:34:33<49:57:20, 20.92s/it, loss_gram=0, loss_l2=0.00342, loss_lpips=0.0685]
Steps:  14%|█▍        | 1403/10000 [7:34:54<50:04:41, 20.97s/it, loss_gram=0, loss_l2=0.00739, loss_lpips=0.0974]
Steps:  14%|█▍        | 1404/10000 [7:35:15<50:07:07, 20.99s/it, loss_gram=0, loss_l2=0.0852, loss_lpips=0.313]  
Steps:  14%|█▍        | 1405/10000 [7:35:34<48:51:09, 20.46s/it, loss_gram=0, loss_l2=0.00828, loss_lpips=0.0757]
Steps:  14%|█▍        | 1406/10000 [7:35:57<50:33:48, 21.18s/it, loss_gram=0, loss_l2=0.00438, loss_lpips=0.128] 
Steps:  14%|█▍        | 1407/10000 [7:36:15<48:28:05, 20.31s/it, loss_gram=0, loss_l2=0.00717, loss_lpips=0.08] 
Steps:  14%|█▍        | 1408/10000 [7:36:38<50:00:23, 20.95s/it, loss_gram=0, loss_l2=0.00552, loss_lpips=0.113]
Steps:  14%|█▍        | 1409/10000 [7:36:56<48:07:39, 20.17s/it, loss_gram=0, loss_l2=0.00375, loss_lpips=0.0929]
Steps:  14%|█▍        | 1410/10000 [7:37:19<49:57:36, 20.94s/it, loss_gram=0, loss_l2=0.0

uploading viz images



Steps:  15%|█▌        | 1502/10000 [8:07:04<45:02:46, 19.08s/it, loss_gram=0, loss_l2=0.00108, loss_lpips=0.135] 
Steps:  15%|█▌        | 1503/10000 [8:07:22<44:01:33, 18.65s/it, loss_gram=0, loss_l2=0.00357, loss_lpips=0.0981]
Steps:  15%|█▌        | 1504/10000 [8:07:43<45:42:26, 19.37s/it, loss_gram=0, loss_l2=0.00954, loss_lpips=0.0823]
Steps:  15%|█▌        | 1505/10000 [8:08:01<44:44:11, 18.96s/it, loss_gram=0, loss_l2=0.00442, loss_lpips=0.0828]
Steps:  15%|█▌        | 1506/10000 [8:08:22<46:05:58, 19.54s/it, loss_gram=0, loss_l2=0.0108, loss_lpips=0.105]  
Steps:  15%|█▌        | 1507/10000 [8:08:40<45:09:04, 19.14s/it, loss_gram=0, loss_l2=0.00376, loss_lpips=0.0852]
Steps:  15%|█▌        | 1508/10000 [8:08:58<44:26:44, 18.84s/it, loss_gram=0, loss_l2=0.0114, loss_lpips=0.0814] 
Steps:  15%|█▌        | 1509/10000 [8:09:16<44:02:31, 18.67s/it, loss_gram=0, loss_l2=0.00578, loss_lpips=0.11] 
Steps:  15%|█▌        | 1510/10000 [8:09:36<44:53:59, 19.04s/it, loss_gram=0, loss_l2=0.

uploading viz images



Steps:  16%|█▌        | 1602/10000 [8:38:42<44:07:15, 18.91s/it, loss_gram=0, loss_l2=0.0153, loss_lpips=0.128]
Steps:  16%|█▌        | 1603/10000 [8:39:00<43:21:03, 18.59s/it, loss_gram=0, loss_l2=0.00172, loss_lpips=0.153]
Steps:  16%|█▌        | 1604/10000 [8:39:19<43:33:25, 18.68s/it, loss_gram=0, loss_l2=0.00107, loss_lpips=0.169]
Steps:  16%|█▌        | 1605/10000 [8:39:36<42:25:55, 18.20s/it, loss_gram=0, loss_l2=0.0056, loss_lpips=0.0565]
Steps:  16%|█▌        | 1606/10000 [8:39:56<43:46:03, 18.77s/it, loss_gram=0, loss_l2=0.00278, loss_lpips=0.0632]
Steps:  16%|█▌        | 1607/10000 [8:40:13<42:41:48, 18.31s/it, loss_gram=0, loss_l2=0.00914, loss_lpips=0.113] 
Steps:  16%|█▌        | 1608/10000 [8:40:32<42:55:38, 18.41s/it, loss_gram=0, loss_l2=0.016, loss_lpips=0.203]  
Steps:  16%|█▌        | 1609/10000 [8:40:49<42:17:48, 18.15s/it, loss_gram=0, loss_l2=0.0021, loss_lpips=0.106]
Steps:  16%|█▌        | 1610/10000 [8:41:10<43:22:40, 18.61s/it, loss_gram=0, loss_l2=0.0247, l

uploading viz images



Steps:  17%|█▋        | 1702/10000 [9:10:50<50:32:00, 21.92s/it, loss_gram=0, loss_l2=0.00307, loss_lpips=0.0676]
Steps:  17%|█▋        | 1703/10000 [9:11:09<48:26:45, 21.02s/it, loss_gram=0, loss_l2=0.0326, loss_lpips=0.118]  
Steps:  17%|█▋        | 1704/10000 [9:11:33<50:21:40, 21.85s/it, loss_gram=0, loss_l2=0.00132, loss_lpips=0.148]
Steps:  17%|█▋        | 1705/10000 [9:11:52<48:33:10, 21.07s/it, loss_gram=0, loss_l2=0.00346, loss_lpips=0.0977]
Steps:  17%|█▋        | 1706/10000 [9:12:12<47:36:02, 20.66s/it, loss_gram=0, loss_l2=0.000924, loss_lpips=0.0972]
Steps:  17%|█▋        | 1707/10000 [9:12:28<44:51:42, 19.47s/it, loss_gram=0, loss_l2=0.00329, loss_lpips=0.152]  
Steps:  17%|█▋        | 1708/10000 [9:12:51<47:02:08, 20.42s/it, loss_gram=0, loss_l2=0.0064, loss_lpips=0.109] 
Steps:  17%|█▋        | 1709/10000 [9:13:10<45:51:15, 19.91s/it, loss_gram=0, loss_l2=0.0136, loss_lpips=0.0836]
Steps:  17%|█▋        | 1710/10000 [9:13:30<46:03:02, 20.00s/it, loss_gram=0, loss_l2=0.

uploading viz images



Steps:  18%|█▊        | 1802/10000 [9:44:50<45:46:24, 20.10s/it, loss_gram=0, loss_l2=0.0126, loss_lpips=0.0818]
Steps:  18%|█▊        | 1803/10000 [9:45:08<44:20:07, 19.47s/it, loss_gram=0, loss_l2=0.00188, loss_lpips=0.127]
Steps:  18%|█▊        | 1804/10000 [9:45:31<46:05:19, 20.24s/it, loss_gram=0, loss_l2=0.00832, loss_lpips=0.119]
Steps:  18%|█▊        | 1805/10000 [9:45:51<46:22:12, 20.37s/it, loss_gram=0, loss_l2=0.0044, loss_lpips=0.0984]
Steps:  18%|█▊        | 1806/10000 [9:46:14<47:39:58, 20.94s/it, loss_gram=0, loss_l2=0.00435, loss_lpips=0.0951]
Steps:  18%|█▊        | 1807/10000 [9:46:31<45:31:15, 20.00s/it, loss_gram=0, loss_l2=0.0109, loss_lpips=0.13]   
Steps:  18%|█▊        | 1808/10000 [9:46:53<46:46:23, 20.55s/it, loss_gram=0, loss_l2=0.0166, loss_lpips=0.114]
Steps:  18%|█▊        | 1809/10000 [9:47:12<45:39:03, 20.06s/it, loss_gram=0, loss_l2=0.00587, loss_lpips=0.136]
Steps:  18%|█▊        | 1810/10000 [9:47:32<45:39:39, 20.07s/it, loss_gram=0, loss_l2=0.0179, 

uploading viz images



Steps:  19%|█▉        | 1902/10000 [10:17:25<44:05:08, 19.60s/it, loss_gram=0, loss_l2=0.00388, loss_lpips=0.12]  
Steps:  19%|█▉        | 1903/10000 [10:17:43<43:41:15, 19.42s/it, loss_gram=0, loss_l2=0.00643, loss_lpips=0.0698]
Steps:  19%|█▉        | 1904/10000 [10:18:05<45:18:24, 20.15s/it, loss_gram=0, loss_l2=0.0099, loss_lpips=0.0708] 
Steps:  19%|█▉        | 1905/10000 [10:18:24<44:20:16, 19.72s/it, loss_gram=0, loss_l2=0.00439, loss_lpips=0.0732]
Steps:  19%|█▉        | 1906/10000 [10:18:45<45:22:21, 20.18s/it, loss_gram=0, loss_l2=0.00573, loss_lpips=0.0648]
Steps:  19%|█▉        | 1907/10000 [10:19:03<43:54:46, 19.53s/it, loss_gram=0, loss_l2=0.00337, loss_lpips=0.0647]
Steps:  19%|█▉        | 1908/10000 [10:19:24<44:41:15, 19.88s/it, loss_gram=0, loss_l2=0.0149, loss_lpips=0.0977] 
Steps:  19%|█▉        | 1909/10000 [10:19:42<43:31:22, 19.37s/it, loss_gram=0, loss_l2=0.00228, loss_lpips=0.0646]
Steps:  19%|█▉        | 1910/10000 [10:20:01<43:21:05, 19.29s/it, loss_gram=0, 

uploading viz images



Steps:  20%|██        | 2002/10000 [10:53:49<190:08:11, 85.58s/it, loss_gram=1.04, loss_l2=0.0013, loss_lpips=0.0667]
Steps:  20%|██        | 2003/10000 [10:54:11<147:30:10, 66.40s/it, loss_gram=2.28, loss_l2=0.00569, loss_lpips=0.0696]
Steps:  20%|██        | 2004/10000 [10:54:34<119:11:59, 53.67s/it, loss_gram=61.5, loss_l2=0.0204, loss_lpips=0.134]  
Steps:  20%|██        | 2005/10000 [10:54:51<94:37:20, 42.61s/it, loss_gram=1.62, loss_l2=0.00826, loss_lpips=0.0868]
Steps:  20%|██        | 2006/10000 [10:55:17<83:14:54, 37.49s/it, loss_gram=11.8, loss_l2=0.0101, loss_lpips=0.0878] 
Steps:  20%|██        | 2007/10000 [10:55:34<70:01:28, 31.54s/it, loss_gram=0.773, loss_l2=0.00532, loss_lpips=0.101]
Steps:  20%|██        | 2008/10000 [10:55:59<65:44:18, 29.61s/it, loss_gram=0.641, loss_l2=0.00126, loss_lpips=0.158]
Steps:  20%|██        | 2009/10000 [10:56:16<57:08:50, 25.75s/it, loss_gram=1, loss_l2=0.00132, loss_lpips=0.0586]   
Steps:  20%|██        | 2010/10000 [10:56:39<55:09:00

uploading viz images



Steps:  21%|██        | 2102/10000 [11:29:10<48:53:10, 22.28s/it, loss_gram=0.977, loss_l2=0.00277, loss_lpips=0.0842]
Steps:  21%|██        | 2103/10000 [11:29:28<46:05:03, 21.01s/it, loss_gram=1.4, loss_l2=0.0103, loss_lpips=0.0869]   
Steps:  21%|██        | 2104/10000 [11:29:52<48:19:28, 22.03s/it, loss_gram=3.16, loss_l2=0.00646, loss_lpips=0.107]
Steps:  21%|██        | 2105/10000 [11:30:09<44:53:17, 20.47s/it, loss_gram=0.441, loss_l2=0.00182, loss_lpips=0.0961]
Steps:  21%|██        | 2106/10000 [11:30:34<47:44:00, 21.77s/it, loss_gram=5.75, loss_l2=0.00673, loss_lpips=0.187]  
Steps:  21%|██        | 2107/10000 [11:30:52<45:07:41, 20.58s/it, loss_gram=2.66, loss_l2=0.0114, loss_lpips=0.0984]
Steps:  21%|██        | 2108/10000 [11:31:16<47:15:41, 21.56s/it, loss_gram=10.9, loss_l2=0.0157, loss_lpips=0.126] 
Steps:  21%|██        | 2109/10000 [11:31:32<44:08:19, 20.14s/it, loss_gram=0.185, loss_l2=0.00403, loss_lpips=0.101]
Steps:  21%|██        | 2110/10000 [11:31:57<47:08:54,

uploading viz images



Steps:  22%|██▏       | 2202/10000 [12:38:47<546:39:20, 252.37s/it, loss_gram=6.16, loss_l2=0.037, loss_lpips=0.123]  
Steps:  22%|██▏       | 2203/10000 [12:38:55<387:31:12, 178.92s/it, loss_gram=20.5, loss_l2=0.0435, loss_lpips=0.236]
Steps:  22%|██▏       | 2204/10000 [12:39:13<282:56:47, 130.66s/it, loss_gram=2.73, loss_l2=0.0117, loss_lpips=0.0969]
Steps:  22%|██▏       | 2205/10000 [12:39:20<202:32:00, 93.54s/it, loss_gram=31.6, loss_l2=0.0265, loss_lpips=0.119]  
Steps:  22%|██▏       | 2206/10000 [12:39:32<149:03:43, 68.85s/it, loss_gram=3.88, loss_l2=0.0189, loss_lpips=0.122]
Steps:  22%|██▏       | 2207/10000 [12:39:45<113:41:03, 52.52s/it, loss_gram=3.31, loss_l2=0.0254, loss_lpips=0.134]
Steps:  22%|██▏       | 2208/10000 [12:40:32<88:23:34, 40.84s/it, loss_gram=3.92, loss_l2=0.0149, loss_lpips=0.113] 
Steps:  22%|██▏       | 2209/10000 [12:44:28<236:50:52, 109.44s/it, loss_gram=0.151, loss_l2=0.0022, loss_lpips=0.14]
[A

In [12]:
# Sanity check: show which device the first parameter of `net_lpips` lives on.
print(next(net_lpips.parameters()).device)


cuda:0


In [None]:
import torch
import torch.nn.functional as F
import os, random, time, gc
import numpy as np
import wandb
from einops import rearrange
from torchvision.transforms.functional import crop

# Device setup: use CUDA when it is available, otherwise fall back to the CPU.
# Functions defined later in this cell read this module-level `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_losses(x_tgt_pred, x_tgt, args, net_lpips, weight_dtype):
    """Return the weighted reconstruction objective together with its two terms.

    Both inputs are cast to float32 before any loss is evaluated.

    Args:
        x_tgt_pred: Predicted image tensor.
        x_tgt: Reference image tensor compared against ``x_tgt_pred``.
        args: Namespace providing the weights ``lambda_l2`` and ``lambda_lpips``.
        net_lpips: Callable invoked as ``net_lpips(pred, target)``; its output is
            averaged with ``.mean()``.
        weight_dtype: Not used by this function; kept so the signature matches
            the other loss helpers.

    Returns:
        Tuple ``(total_loss, loss_l2, loss_lpips)`` where ``total_loss`` is the
        sum of the two weighted terms.
    """
    pred_fp32 = x_tgt_pred.float()
    target_fp32 = x_tgt.float()

    # Pixel-space term: mean squared error scaled by its weight.
    mse_term = F.mse_loss(pred_fp32, target_fp32, reduction="mean") * args.lambda_l2

    # Perceptual term: averaged LPIPS output scaled by its weight.
    perceptual_term = net_lpips(pred_fp32, target_fp32).mean() * args.lambda_lpips

    return mse_term + perceptual_term, mse_term, perceptual_term

def compute_gram_loss(x_tgt_pred, x_tgt, args, global_step, t_vgg_renorm, net_vgg, weight_dtype):
    """Weighted Gram loss evaluated on one random crop shared by prediction and target.

    The loss is skipped (a scalar zero is returned) when ``args.lambda_gram`` is 0
    or while ``global_step`` has not passed ``args.gram_loss_warmup_steps``.

    Args:
        x_tgt_pred: Predicted images; the last two dimensions are height and width.
        x_tgt: Reference images with the same spatial size as ``x_tgt_pred``.
        args: Namespace providing ``lambda_gram`` and ``gram_loss_warmup_steps``.
        global_step: Current training step, compared against the warm-up length.
        t_vgg_renorm: Callable applied to ``x * 0.5 + 0.5`` before cropping
            (presumably the VGG input normalisation -- TODO confirm).
        net_vgg: Feature network forwarded to ``gram_loss``.
        weight_dtype: dtype the crops are cast to before ``gram_loss`` is called,
            and dtype of the zero returned when the loss is skipped.

    Returns:
        Scalar tensor: ``gram_loss(...) * args.lambda_gram``, or zero when skipped.
    """
    if args.lambda_gram == 0 or global_step <= args.gram_loss_warmup_steps:
        # Build the zero on the prediction's own device instead of relying on a
        # notebook-level `device` global, so it can always be added to the other losses.
        return torch.tensor(0.0, device=x_tgt_pred.device, dtype=weight_dtype)

    x_tgt_pred_renorm = t_vgg_renorm(x_tgt_pred * 0.5 + 0.5)
    x_tgt_renorm = t_vgg_renorm(x_tgt * 0.5 + 0.5)

    H, W = x_tgt_pred.shape[-2:]
    # Crop at most 400x400, but never more than the image itself: without the
    # clamp, random.randint(0, H - 400) raises ValueError for inputs under 400 px.
    crop_h, crop_w = min(400, H), min(400, W)
    top = random.randint(0, H - crop_h)
    left = random.randint(0, W - crop_w)

    # The same window is taken from both tensors so the comparison stays aligned.
    x_tgt_pred_renorm = crop(x_tgt_pred_renorm, top, left, crop_h, crop_w)
    x_tgt_renorm = crop(x_tgt_renorm, top, left, crop_h, crop_w)

    return gram_loss(x_tgt_pred_renorm.to(weight_dtype), x_tgt_renorm.to(weight_dtype), net_vgg) * args.lambda_gram

def normalize_img_for_logging(img_tensor):
    """Convert an image tensor to uint8 values in [0, 255] for ``wandb.Image``.

    The value range is detected from the data:

    * all values in [0, 1]  -> used as-is;
    * otherwise, all values in [-1, 1] -> mapped to [0, 1] with ``(x + 1) / 2``;
    * anything else -> a warning is printed and values are clamped to [0, 1].

    The [0, 1] check has to come first: every [0, 1] image also lies inside
    [-1, 1], so testing [-1, 1] first would make the [0, 1] case unreachable.
    Note the detection is a heuristic -- a [-1, 1] image that happens to contain
    no negative value is indistinguishable from a [0, 1] image and is treated as such.

    Args:
        img_tensor: Image tensor (the caller is expected to pass shape (C, H, W)).
            It is cloned, so the input is never modified.

    Returns:
        A CPU ``torch.uint8`` tensor with the same shape as ``img_tensor``.
    """
    img = img_tensor.clone().detach().cpu().float()

    min_val, max_val = img.min().item(), img.max().item()

    if min_val >= 0.0 and max_val <= 1.0:
        # Already in [0, 1]
        img = img.clamp(0, 1)
    elif min_val >= -1.0 and max_val <= 1.0:
        # Handle [-1, 1] range
        img = ((img + 1) * 0.5).clamp(0, 1)
    else:
        # Unexpected range
        print(f"[normalize_img_for_logging] ⚠️ Image values outside expected ranges: min={min_val}, max={max_val}. Clamping to [0, 1].")
        img = img.clamp(0, 1)

    # .byte() truncates toward zero, e.g. 127.5 -> 127.
    img = (img * 255).byte()
    return img


def log_images(x_src, x_tgt, x_tgt_pred, step, prefix="train"):
    """Log source, target and predicted images of a batch to wandb.

    Each tensor is rearranged with ``"b v c h w -> b c (v h) w"`` so that the
    ``v`` views of a sample are stacked along the height axis, then every sample
    is converted with ``normalize_img_for_logging`` and wrapped in ``wandb.Image``.

    Args:
        x_src: Source images, shape ``(b, v, c, h, w)``.
        x_tgt: Target images, same layout.
        x_tgt_pred: Model outputs, same layout.
        step: Step value forwarded to ``wandb.log``.
        prefix: Key prefix; images are logged under ``{prefix}/source``,
            ``{prefix}/target`` and ``{prefix}/model_output``.
    """
    B = x_src.size(0)

    def _to_wandb_images(batch):
        # Rearrange once per tensor rather than once per sample.
        stacked = rearrange(batch, "b v c h w -> b c (v h) w")
        return [wandb.Image(normalize_img_for_logging(stacked[i]), caption=f"idx={i}") for i in range(B)]

    log_dict = {
        f"{prefix}/source": _to_wandb_images(x_src),
        f"{prefix}/target": _to_wandb_images(x_tgt),
        f"{prefix}/model_output": _to_wandb_images(x_tgt_pred),
    }

    # Log exactly once; the original issued the identical call twice.
    wandb.log(log_dict, step=step)

def save_checkpoint(model, optimizer, path):
    """Save the model and optimizer state dicts to ``path`` with ``torch.save``.

    Missing parent directories of ``path`` are created first, so a fresh
    output directory does not make the first checkpoint fail.

    Args:
        model: Object exposing ``state_dict()``; stored under ``'model_state_dict'``.
        optimizer: Object exposing ``state_dict()``; stored under ``'optimizer_state_dict'``.
        path: Destination. When it is a string or path-like object its parent
            directory is created if needed; anything else (e.g. an open file
            object) is handed to ``torch.save`` untouched.
    """
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(path))
        if parent_dir:  # empty when `path` is a bare file name in the cwd
            os.makedirs(parent_dir, exist_ok=True)

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

@torch.no_grad()
def evaluate(dl_val, model, args, net_lpips, weight_dtype, step):
    """Run a validation pass and log mean L2 / LPIPS plus sample images to wandb.

    At most ``args.num_samples_eval`` batches are read from ``dl_val``. Metrics
    are computed on view index 0 of the target and of the prediction only.
    Every 10th batch also contributes source / target / output images, which
    are passed to ``wandb.Image`` as raw float tensors.

    The model is switched to eval mode for the pass and is always put back into
    training mode afterwards, even when the loop raises or is interrupted.
    Nothing is logged when no batch was evaluated.

    Args:
        dl_val: Iterable of batches with keys ``"conditioning_pixel_values"``,
            ``"output_pixel_values"`` and ``"input_ids"``; batch size must be 1.
        model: Called as ``model(x_src, prompt_tokens=...)``.
        args: Namespace providing ``num_samples_eval``.
        net_lpips: Callable invoked as ``net_lpips(pred, target)``.
        weight_dtype: dtype the image tensors are cast to when moved to ``device``.
        step: Step value forwarded to ``wandb.log``.
    """
    model.eval()
    l_l2, l_lpips = [], []
    log_dict = {"sample/source": [], "sample/target": [], "sample/model_output": []}

    try:
        for i, batch in enumerate(dl_val):
            if i >= args.num_samples_eval:
                break
            x_src = batch["conditioning_pixel_values"].to(device, dtype=weight_dtype)
            x_tgt = batch["output_pixel_values"].to(device, dtype=weight_dtype)
            B, V, C, H, W = x_src.shape  # also enforces the 5-D (B, V, C, H, W) layout
            assert B == 1, "Use batch size 1 for eval."

            x_tgt_pred = model(x_src, prompt_tokens=batch["input_ids"].to(device))

            if i % 10 == 0:
                # Stack the views of the single sample along the height axis.
                log_dict["sample/source"].append(wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[0].float().cpu()))
                log_dict["sample/target"].append(wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[0].float().cpu()))
                log_dict["sample/model_output"].append(wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[0].float().cpu()))

            # Metrics use the first view only.
            x_tgt = x_tgt[:, 0]
            x_tgt_pred = x_tgt_pred[:, 0]
            # Local names avoid shadowing the imported `lpips` module.
            l2_value = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean").item()
            lpips_value = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean().item()

            l_l2.append(l2_value)
            l_lpips.append(lpips_value)
    finally:
        # Never leave the model stuck in eval mode, e.g. after a kernel interrupt.
        model.train()

    if not l_l2:
        # np.mean([]) would yield NaN (with a RuntimeWarning); skip logging instead.
        return

    logs = {
        "val/l2": np.mean(l_l2),
        "val/lpips": np.mean(l_lpips),
        **log_dict
    }
    wandb.log(logs, step=step)

def train_one_epoch(dl_train, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype, global_step, dl_val=None):
    """Train for one pass over ``dl_train`` and return the updated step counter.

    For every batch: forward pass, L2 + LPIPS loss (plus the Gram loss once its
    warm-up is over), backward pass, gradient clipping, optimizer and scheduler
    step. Image logging, checkpointing and validation fire when ``global_step``
    is 1 modulo their respective frequency. The scalar losses are logged to
    wandb on every step.

    Args:
        dl_train: Iterable of batches with keys ``"conditioning_pixel_values"``,
            ``"output_pixel_values"`` and ``"input_ids"``.
        model: Called as ``model(x_src, prompt_tokens=input_ids)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch.
        args: Namespace providing ``set_grads_to_none``, ``max_grad_norm``,
            ``viz_freq``, ``checkpointing_steps``, ``output_dir``, ``eval_freq``,
            ``lambda_gram`` and the fields read by the loss helpers.
        net_lpips: Forwarded to ``compute_losses`` and ``evaluate``.
        t_vgg_renorm: Forwarded to ``compute_gram_loss``.
        net_vgg: Forwarded to ``compute_gram_loss``.
        weight_dtype: Forwarded to the loss helpers and to ``evaluate``.
        global_step: Step counter value at the start of the epoch.
        dl_val: Validation loader handed to ``evaluate``. When omitted, a
            notebook-level ``dl_val`` is used if one exists (the previous
            behaviour); if there is none, validation is skipped.

    Returns:
        ``global_step`` after the epoch (incremented once per batch).
    """
    if dl_val is None:
        # Backward compatibility: this function used to read `dl_val` implicitly
        # from the notebook's global namespace.
        dl_val = globals().get("dl_val")

    model.train()

    for batch in dl_train:
        x_src = batch["conditioning_pixel_values"].to(device)
        x_tgt = batch["output_pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)

        # Only the view count is needed later; the unpacking also enforces the
        # 5-D (B, V, C, H, W) layout.
        _, V, _, _, _ = x_src.shape

        # Forward pass
        x_tgt_pred = model(x_src, prompt_tokens=input_ids)

        # Fold the view axis into the batch axis for the loss functions.
        x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w')
        x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w')

        # Compute losses
        loss, loss_l2, loss_lpips = compute_losses(x_tgt_pred, x_tgt, args, net_lpips, weight_dtype)
        loss_gram = compute_gram_loss(x_tgt_pred, x_tgt, args, global_step, t_vgg_renorm, net_vgg, weight_dtype)
        loss = loss + loss_gram

        # Backward
        optimizer.zero_grad(set_to_none=args.set_grads_to_none)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Logging (the view axis is restored for visualisation)
        if global_step % args.viz_freq == 1:
            log_images(x_src, rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V), rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V), global_step)

        if global_step % args.checkpointing_steps == 1:
            ckpt_path = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
            save_checkpoint(model, optimizer, ckpt_path)

        if args.eval_freq > 0 and dl_val is not None and global_step % args.eval_freq == 1:
            evaluate(dl_val, model, args, net_lpips, weight_dtype, global_step)

        wandb.log({
            "loss_l2": loss_l2.item(),
            "loss_lpips": loss_lpips.item(),
            "loss_gram": loss_gram.item() if args.lambda_gram > 0 else 0.0
        }, step=global_step)

        global_step += 1
    return global_step

def train_loop(dl_train, dl_val, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype):
    """Run ``args.num_training_epochs`` epochs of ``train_one_epoch``.

    The step counter starts at 0 and is carried from one epoch to the next.
    ``dl_val`` is forwarded explicitly so that validation uses the loader given
    to this function rather than whatever ``dl_val`` exists at notebook level.

    Args:
        dl_train: Training loader, iterated once per epoch.
        dl_val: Validation loader used for the periodic evaluation.
        model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg,
        weight_dtype: Forwarded unchanged to ``train_one_epoch``.
    """
    global_step = 0
    for _ in range(args.num_training_epochs):
        global_step = train_one_epoch(
            dl_train, model, optimizer, scheduler, args, net_lpips,
            t_vgg_renorm, net_vgg, weight_dtype, global_step, dl_val=dl_val,
        )
