In [1]:
import sys
import os

# Get the current working directory
cwd = os.getcwd()

# Go one level up
parent_dir = os.path.abspath(os.path.join(cwd, ".."))

# Add parent directory to Python path
if parent_dir not in sys.path:
    sys.path.append(parent_dir)


In [2]:
import os
import gc
import lpips
import random
import argparse
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision
import transformers
from torchvision.transforms.functional import crop
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from glob import glob
from einops import rearrange

import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler

import wandb

from src.model import Difix, load_ckpt_from_state_dict, save_ckpt
from src.dataset import PairedDataset
from src.loss import gram_loss

from types import SimpleNamespace

  from .autonotebook import tqdm as notebook_tqdm
  torch.utils._pytree._register_pytree_node(


## Arguments

In [3]:
args = {
    "lambda_lpips": 1.0,
    "lambda_l2": 1.0,
    "lambda_gram": 1.0,
    "gram_loss_warmup_steps": 2000,
    "dataset_path": "data/converted_dataset_fixed.json",
    "train_image_prep": "resized_crop_512",
    "test_image_prep": "resized_crop_512",
    "prompt": None,
    "eval_freq": 100,
    "num_samples_eval": 100,
    "viz_freq": 100,
    "tracker_project_name": "difix",
    "tracker_run_name": "train",
    "pretrained_model_name_or_path": None,
    "revision": None,
    "variant": None,
    "tokenizer_name": None,
    "lora_rank_vae": 4,
    "timestep": 199,
    "mv_unet": False,
    "output_dir": "./outputs/difix/train",
    "cache_dir": None,
    "seed": None,
    "resolution": 512,
    "train_batch_size": 1,
    "num_training_epochs": 10,
    "max_train_steps": 10000,
    "checkpointing_steps": 500,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": False,
    "learning_rate": 5e-6,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 500,
    "lr_num_cycles": 1,
    "lr_power": 1.0,
    "dataloader_num_workers": 4,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "adam_weight_decay": 1e-2,
    "adam_epsilon": 1e-8,
    "max_grad_norm": 1.0,
    "allow_tf32": False,
    "report_to": "wandb",
    "mixed_precision": "bf16",
    "enable_xformers_memory_efficient_attention": True,
    "set_grads_to_none": False,
    "resume": None,
}



args['output_dir'] = './outputs/difix/train'
args['dataset_path'] = "./data/converted_dataset_fixed.json"
args['max_train_steps'] = 10000 
args['resolution'] = 512 
args['learning_rate'] = 2e-5 
args['train_batch_size'] = 1 
args['dataloader_num_workers'] = 4 
args['enable_xformers_memory_efficient_attention'] = True
args['checkpointing_steps'] = 1000 
args['eval_freq'] = 1000 
args['viz_freq'] = 100 
args['lambda_lpips'] = 1.0 
args['lambda_l2'] = 1.0 
args['lambda_gram'] =  1.0 
args['gram_loss_warmup_steps'] = 2000
args['report_to']= "wandb" 
args['tracker_project_name'] = "difix" 
args['tracker_run_name'] = "train" 
args['timestep'] = 199

args = SimpleNamespace(**args)

In [4]:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()

if args.seed is not None:
    set_seed(args.seed)


os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

### Set Up Model ###
net_difix = Difix(
    lora_rank_vae=args.lora_rank_vae, 
    timestep=args.timestep,
    mv_unet=args.mv_unet,
)
net_difix.set_train()

if args.enable_xformers_memory_efficient_attention:
    if is_xformers_available():
        net_difix.unet.enable_xformers_memory_efficient_attention()
    else:
        raise ValueError("xformers is not available, please install it by running `pip install xformers`")

if args.gradient_checkpointing:
    net_difix.unet.enable_gradient_checkpointing()

if args.allow_tf32:
    torch.backends.cuda.matmul.allow_tf32 = True



Initializing model with random weights
Number of trainable parameters in UNet: 865.91M
Number of trainable parameters in VAE: 0.52M


In [5]:
### Set up metrics
net_lpips = lpips.LPIPS(net='vgg').cuda()

net_lpips.requires_grad_(False)

net_vgg = torchvision.models.vgg16(pretrained=True).features
for param in net_vgg.parameters():
    param.requires_grad_(False)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /home/cx24957/miniconda3/envs/difix_env_fixed/lib/python3.10/site-packages/lpips/weights/v0.1/vgg.pth


In [6]:
#### make the optimizer
layers_to_opt = []
layers_to_opt += list(net_difix.unet.parameters())

for n, _p in net_difix.vae.named_parameters():
    if "lora" in n and "vae_skip" in n:
        assert _p.requires_grad
        layers_to_opt.append(_p)
layers_to_opt = layers_to_opt + list(net_difix.vae.decoder.skip_conv_1.parameters()) + \
    list(net_difix.vae.decoder.skip_conv_2.parameters()) + \
    list(net_difix.vae.decoder.skip_conv_3.parameters()) + \
    list(net_difix.vae.decoder.skip_conv_4.parameters())

optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,)
lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=args.max_train_steps,
    num_cycles=args.lr_num_cycles, power=args.lr_power,)

  return disable_fn(*args, **kwargs)


In [13]:
### Set up dataset
dataset_train = PairedDataset(dataset_path='/mnt/e/Difix3d/data/converted_dataset_fixed.json',
                              height=512,
                              width=512,
                              split="train", tokenizer=net_difix.tokenizer)
dl_train = torch.utils.data.DataLoader(dataset_train,
                                       batch_size=args.train_batch_size,
                                       shuffle=True, num_workers=args.dataloader_num_workers)
dataset_val = PairedDataset(dataset_path='/mnt/e/Difix3d/data/converted_dataset_fixed.json',
                            height=512,
                            width=512,
                            split="test", tokenizer=net_difix.tokenizer)
random.Random(42).shuffle(dataset_val.img_ids)
dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

In [14]:
dataset_train[0]['output_pixel_values'].shape

torch.Size([2, 3, 512, 512])

In [15]:
device = 'cuda'

### Set up components on accelerator
weight_dtype = torch.float32
# if accelerator.mixed_precision == "fp16":
#     weight_dtype = torch.float16
# elif accelerator.mixed_precision == "bf16":
#     weight_dtype = torch.bfloat16

# Move al networksr to device and cast to weight_dtype
net_difix.to(device, dtype=weight_dtype)
net_lpips.to(device, dtype=weight_dtype)
net_vgg.to(device, dtype=weight_dtype)

# renorm with image net statistics
t_vgg_renorm =  transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))



In [16]:
import torch
import torch.nn.functional as F
import os, random, time, gc
import numpy as np
import wandb
from einops import rearrange
from torchvision.transforms.functional import crop

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def compute_losses(x_tgt_pred, x_tgt, args, net_lpips, weight_dtype):
    print('l2')
    print(x_tgt_pred.float().shape)
    loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
    print('lpips')
    loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
    total_loss = loss_l2 + loss_lpips
    return total_loss, loss_l2, loss_lpips

def compute_gram_loss(x_tgt_pred, x_tgt, args, global_step, t_vgg_renorm, net_vgg, weight_dtype):
    if args.lambda_gram == 0 or global_step <= args.gram_loss_warmup_steps:
        return torch.tensor(0.0, device=device, dtype=weight_dtype)

    x_tgt_pred_renorm = t_vgg_renorm(x_tgt_pred * 0.5 + 0.5)
    x_tgt_renorm = t_vgg_renorm(x_tgt * 0.5 + 0.5)

    crop_h, crop_w = 400, 400
    H, W = x_tgt_pred.shape[-2:]
    top, left = random.randint(0, H - crop_h), random.randint(0, W - crop_w)
    x_tgt_pred_renorm = crop(x_tgt_pred_renorm, top, left, crop_h, crop_w)
    x_tgt_renorm = crop(x_tgt_renorm, top, left, crop_h, crop_w)

    return gram_loss(x_tgt_pred_renorm.to(weight_dtype), x_tgt_renorm.to(weight_dtype), net_vgg) * args.lambda_gram

def normalize_img_for_logging(img_tensor):
    """
    Normalize an image tensor to [0, 255] and convert to uint8 for wandb.Image.
    Handles input ranges of [-1, 1] or [0, 1]. Warns if out-of-range.
    Assumes shape is (C, H, W).
    """
    img = img_tensor.clone().detach().cpu().float()

    min_val, max_val = img.min().item(), img.max().item()

    if min_val >= -1.0 and max_val <= 1.0:
        # Handle [-1, 1] range
        img = ((img + 1) * 0.5).clamp(0, 1)
    elif min_val >= 0.0 and max_val <= 1.0:
        # Already in [0, 1]
        img = img.clamp(0, 1)
    else:
        # Unexpected range
        print(f"[normalize_img_for_logging] ⚠️ Image values outside expected ranges: min={min_val}, max={max_val}. Clamping to [0, 1].")
        img = img.clamp(0, 1)

    img = (img * 255).byte()
    return img


def log_images(x_src, x_tgt, x_tgt_pred, step, prefix="train"):
    B = x_src.size(0)

    log_dict = {
        f"{prefix}/source": [wandb.Image(normalize_img_for_logging(rearrange(x_src, "b v c h w -> b c (v h) w")[i]), caption=f"idx={i}") for i in range(B)],
        f"{prefix}/target": [wandb.Image(normalize_img_for_logging(rearrange(x_tgt, "b v c h w -> b c (v h) w")[i]), caption=f"idx={i}") for i in range(B)],
        f"{prefix}/model_output": [wandb.Image(normalize_img_for_logging(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[i]), caption=f"idx={i}") for i in range(B)],
    }

    wandb.log(log_dict, step=step)

    wandb.log(log_dict, step=step)

def save_checkpoint(model, optimizer, path):
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

@torch.no_grad()
def evaluate(dl_val, model, args, net_lpips, weight_dtype, step):
    model.eval()
    l_l2, l_lpips = [], []
    log_dict = {"sample/source": [], "sample/target": [], "sample/model_output": []}

    for i, batch in enumerate(dl_val):
        if i >= args.num_samples_eval:
            break
        x_src = batch["conditioning_pixel_values"].to(device, dtype=weight_dtype)
        x_tgt = batch["output_pixel_values"].to(device, dtype=weight_dtype)
        B, V, C, H, W = x_src.shape
        assert B == 1, "Use batch size 1 for eval."

        x_tgt_pred = model(x_src, prompt_tokens=batch["input_ids"].to(device))

        if i % 10 == 0:
            log_dict["sample/source"].append(wandb.Image(rearrange(x_src, "b v c h w -> b c (v h) w")[0].float().cpu()))
            log_dict["sample/target"].append(wandb.Image(rearrange(x_tgt, "b v c h w -> b c (v h) w")[0].float().cpu()))
            log_dict["sample/model_output"].append(wandb.Image(rearrange(x_tgt_pred, "b v c h w -> b c (v h) w")[0].float().cpu()))

        x_tgt = x_tgt[:, 0]
        x_tgt_pred = x_tgt_pred[:, 0]
        l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean").item()
        lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean().item()

        l_l2.append(l2)
        l_lpips.append(lpips)

    logs = {
        "val/l2": np.mean(l_l2),
        "val/lpips": np.mean(l_lpips),
        **log_dict
    }
    wandb.log(logs, step=step)
    model.train()

def train_one_epoch(dl_train, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype, global_step):
    model.train()

    for step, batch in tqdm(enumerate(dl_train)):
        x_src = batch["conditioning_pixel_values"].to(device)
        x_tgt = batch["output_pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)

        B, V, C, H, W = x_src.shape

        # Forward pass
        x_tgt_pred = model(x_src, prompt_tokens=input_ids)

        x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w')
        x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w')

        # Compute losses
        loss, loss_l2, loss_lpips = compute_losses(x_tgt_pred, x_tgt, args, net_lpips, weight_dtype)
        loss_gram = compute_gram_loss(x_tgt_pred, x_tgt, args, global_step, t_vgg_renorm, net_vgg, weight_dtype)
        loss += loss_gram

        # Backward
        optimizer.zero_grad(set_to_none=args.set_grads_to_none)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        # Logging
        if global_step % args.viz_freq == 1:
            log_images(x_src, rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V), rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V), global_step)

        if global_step % args.checkpointing_steps == 1:
            ckpt_path = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
            save_checkpoint(model, optimizer, ckpt_path)

        if args.eval_freq > 0 and global_step % args.eval_freq == 1:
            evaluate(dl_val, model, args, net_lpips, weight_dtype, global_step)

        wandb.log({
            "loss_l2": loss_l2.item(),
            "loss_lpips": loss_lpips.item(),
            "loss_gram": loss_gram.item() if args.lambda_gram > 0 else 0.0
        }, step=global_step)

        global_step += 1
    return global_step

def train_loop(dl_train, dl_val, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype):
    global_step = 0
    for epoch in tqdm(range(args.num_training_epochs)):
        global_step = train_one_epoch(dl_train, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype, global_step)


In [None]:
train_loop(dl_train,
           dl_val,
           net_difix,
           optimizer,
           lr_scheduler,
           args, 
           net_lpips, 
           t_vgg_renorm, 
           net_vgg, 
           weight_dtype)

  0%|          | 0/10 [00:00<?, ?it/s]

In [17]:
def train_one_epoch(dl_train, model, optimizer, scheduler, args, net_lpips, t_vgg_renorm, net_vgg, weight_dtype, global_step):
    model.train()

    for step, batch in tqdm(enumerate(dl_train)):
        print('loading data to device')
        x_src = batch["conditioning_pixel_values"].to(device)
        x_tgt = batch["output_pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)

        B, V, C, H, W = x_src.shape

        # Forward pass
        print('forward pass')
        x_tgt_pred = model(x_src, prompt_tokens=input_ids)

        print('rearranging')
        x_tgt = rearrange(x_tgt, 'b v c h w -> (b v) c h w')
        x_tgt_pred = rearrange(x_tgt_pred, 'b v c h w -> (b v) c h w')

        # Compute losses
        print('computing first losses')
        loss, loss_l2, loss_lpips = compute_losses(x_tgt_pred, x_tgt, args, net_lpips, weight_dtype)
        print('computing gram loss')
        loss_gram = compute_gram_loss(x_tgt_pred, x_tgt, args, global_step, t_vgg_renorm, net_vgg, weight_dtype)
        loss += loss_gram

        # Backward
        print('gradient comp')
        optimizer.zero_grad(set_to_none=args.set_grads_to_none)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        optimizer.step()
        scheduler.step()

        # # Logging
        # if global_step % args.viz_freq == 1:
        #     log_images(x_src, rearrange(x_tgt, '(b v) c h w -> b v c h w', v=V), rearrange(x_tgt_pred, '(b v) c h w -> b v c h w', v=V), global_step)

        # if global_step % args.checkpointing_steps == 1:
        #     ckpt_path = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
        #     save_checkpoint(model, optimizer, ckpt_path)

        # if args.eval_freq > 0 and global_step % args.eval_freq == 1:
        #     evaluate(dl_val, model, args, net_lpips, weight_dtype, global_step)

        # wandb.log({
        #     "loss_l2": loss_l2.item(),
        #     "loss_lpips": loss_lpips.item(),
        #     "loss_gram": loss_gram.item() if args.lambda_gram > 0 else 0.0
        # }, step=global_step)

        global_step += 1
    return global_step

In [None]:
train_one_epoch(dl_train, 
                net_difix, 
                optimizer, 
                lr_scheduler, 
                args, 
                net_lpips, t_vgg_renorm, net_vgg, weight_dtype, 0)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
0it [00:00, ?it/s]

loading data to device
forward pass
rearranging
computing first losses
l2
torch.Size([2, 3, 512, 512])
lpips
computing gram loss
gradient comp


1it [00:13, 13.08s/it]

loading data to device
forward pass
rearranging
computing first losses
l2
torch.Size([2, 3, 512, 512])
lpips
computing gram loss
