In [1]:
import argparse
import datetime
import logging
import inspect
import math
import os
import random
import gc
import copy

from typing import Dict, Optional, Tuple
from omegaconf import OmegaConf

import cv2
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision.transforms as T
import diffusers
import transformers

from torchvision import transforms
from tqdm.auto import tqdm

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed

from models.unet_3d_condition import UNet3DConditionModel
from diffusers.models import AutoencoderKL
from diffusers import DPMSolverMultistepScheduler, DDPMScheduler, TextToVideoSDPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, export_to_video
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention_processor import AttnProcessor2_0, Attention
from diffusers.models.attention import BasicTransformerBlock

from transformers import CLIPTextModel, CLIPTokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder
from utils.dataset import VideoJsonDataset, SingleVideoDataset, \
    ImageDataset, VideoFolderDataset, CachedDataset
from einops import rearrange, repeat
from utils.lora_handler import LoraHandler, LORA_VERSIONS

  from .autonotebook import tqdm as notebook_tqdm


Initializing the conversion map


In [2]:
import torch
torch.cuda.is_available()

True

In [3]:
already_printed_trainables = False

logger = get_logger(__name__, log_level="INFO")

def create_logging(logging, logger, accelerator):
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

In [4]:
accelerator = Accelerator(
    gradient_accumulation_steps=1,
    mixed_precision='fp16',
    log_with='tensorboard',
    project_dir='./outputs'
)

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


In [5]:
create_logging(logging, logger, accelerator)

11/18/2024 03:35:34 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: fp16



In [6]:
def accelerate_set_verbose(accelerator):
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()
# Initialize accelerate, transformers, and diffusers warnings
accelerate_set_verbose(accelerator)

In [7]:
set_seed(64)

In [8]:
pretrained_model_path  = "./models/model_scope_diffusers/"
output_dir = "./outputs"

In [9]:
def create_output_folders(output_dir):
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    out_dir = os.path.join(output_dir, f"train_{now}")
    
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(f"{out_dir}/samples", exist_ok=True)
    # OmegaConf.save(config, os.path.join(out_dir, 'config.yaml'))

    return out_dir


if accelerator.is_main_process:
    output_dir = create_output_folders(output_dir)

In [10]:
import os
def load_primary_models(pretrained_model_path):
    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, local_files_only = True, use_safetensors=True, subfolder = 'text_encoder')
    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    unet = UNet3DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet", local_files_only = True, use_safetensors= True)

    return noise_scheduler, tokenizer, text_encoder, vae, unet
# Load scheduler, tokenizer and models.
noise_scheduler, tokenizer, text_encoder, vae, unet = load_primary_models(pretrained_model_path)


{'rescale_betas_zero_snr', 'variance_type', 'timestep_spacing'} was not found in config. Values will be initialized to default values.
{'shift_factor', 'latents_std', 'force_upcast', 'latents_mean', 'use_post_quant_conv', 'use_quant_conv', 'mid_block_add_attention'} was not found in config. Values will be initialized to default values.
{'mid_block_scale_factor', 'downsample_padding'} was not found in config. Values will be initialized to default values.


In [11]:

# Function to calculate the model's memory size in MB
def get_model_size_in_mb(model):
    if hasattr(model, 'parameters'):
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)  # Total number of parameters
        total_size_bytes = total_params * 2  # float32 is 4 bytes
        total_size_mb = total_size_bytes / (1024 * 1024)  # Convert to MB
        return total_size_mb
    return 0  # Return 0 if the model doesn't have parameters

# Assuming these models are already loaded
def print_model_memory_usage(pretrained_model_path):
    # Load model
    
    print(f"Memory usage for each model:")
    
    # Print memory usage for models that have parameters
    print(f"Noise Scheduler (no parameters, only config): 0 MB")  # Noise scheduler has no parameters
    print(f"Tokenizer (no parameters): 0 MB")  # Tokenizer has no parameters
    
    print(f"Text Encoder: {get_model_size_in_mb(text_encoder):.2f} MB")
    print(f"VAE: {get_model_size_in_mb(vae):.2f} MB")
    print(f"UNet: {get_model_size_in_mb(unet):.2f} MB")

# Example usage
pretrained_model_path = "/home/bheeshm/nisarg/models/model_scope_diffusers"
print_model_memory_usage(pretrained_model_path)


Memory usage for each model:
Noise Scheduler (no parameters, only config): 0 MB
Tokenizer (no parameters): 0 MB
Text Encoder: 649.24 MB
VAE: 159.56 MB
UNet: 2691.71 MB


In [12]:
def freeze_models(models_to_freeze):
    for model in models_to_freeze:
        if model is not None: model.requires_grad_(False) 
freeze_models([vae, text_encoder, unet])

In [13]:
def is_attn(name):
   return ('attn1' or 'attn2' == name.split('.')[-1])

def set_processors(attentions):
    for attn in attentions: attn.set_processor(AttnProcessor2_0()) 
    
def set_torch_2_attn(unet):
    optim_count = 0
    
    for name, module in unet.named_modules():
        if is_attn(name):
            if isinstance(module, torch.nn.ModuleList):
                for m in module:
                    if isinstance(m, BasicTransformerBlock):
                        set_processors([m.attn1, m.attn2])
                        optim_count += 1
    if optim_count > 0: 
        print(f"{optim_count} Attention layers using Scaled Dot Product Attention.")

In [14]:
def handle_memory_attention(enable_xformers_memory_efficient_attention, enable_torch_2_attn, unet): 
    try:
        is_torch_2 = hasattr(F, 'scaled_dot_product_attention')
        enable_torch_2 = is_torch_2 and enable_torch_2_attn
        
        if enable_xformers_memory_efficient_attention and not enable_torch_2:
            if is_xformers_available():
                from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
                unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
            else:
                raise ValueError("xformers is not available. Make sure it is installed correctly")
        
        if enable_torch_2:
            set_torch_2_attn(unet)
            
    except:
        print("Could not enable memory efficient attention for xformers or Torch 2.0.")


handle_memory_attention(True, True, unet)

33 Attention layers using Scaled Dot Product Attention.


In [15]:
# Initialize the optimizer
optimizer_cls = torch.optim.AdamW

In [16]:
extra_unet_params = None
extra_text_encoder_params= None
train_batch_size= 1
max_train_steps= 10000
learning_rate= 5e-6
scale_lr =False
lr_scheduler = "constant"
lr_warmup_steps=0
adam_beta1=0.9
adam_beta2= 0.999
adam_weight_decay= 0
adam_epsilon = 1e-08
max_grad_norm= 1.0
gradient_accumulation_steps=1
gradient_checkpointing= True
text_encoder_gradient_checkpointing= False
checkpointing_steps= 2500
resume_from_checkpoint= None
resume_step= None
train_text_encoder: False
use_offset_noise: False
rescale_schedule: False
offset_noise_strength= 0.1
extend_dataset= False
cache_latents= True
cached_latent_dir= "./outputs/cached_latents"
lora_version= "cloneofsimo"
save_lora_for_webui= True
only_lora_for_webui= False
lora_bias= 'none'
use_unet_lora= False
use_text_lora= False
unet_lora_modules= ("UNet3DConditionModel",)
text_encoder_lora_modules= ("CLIPEncoderLayer",)
save_pretrained_model= True
lora_rank= 16
lora_path= ""
lora_unet_dropout= 0.1
lora_text_dropout= 0.1
dataset_types= ("folder",)
shuffle=True
validation_steps= 100
trainable_modules= ("all",)
trainable_text_modules= ("all",)
train_text_encoder = False

In [17]:
# Use LoRA if enabled.  
lora_manager = LoraHandler(
    version=lora_version, 
    use_unet_lora=use_unet_lora,
    use_text_lora=use_text_lora,
    save_for_webui=save_lora_for_webui,
    only_for_webui=only_lora_for_webui,
    unet_replace_modules=unet_lora_modules,
    text_encoder_replace_modules=text_encoder_lora_modules,
    lora_bias=lora_bias
)


In [18]:
unet_lora_params, unet_negation = lora_manager.add_lora_to_model(
    use_unet_lora, unet, lora_manager.unet_replace_modules, lora_unet_dropout, lora_path, r=lora_rank) 

text_encoder_lora_params, text_encoder_negation = lora_manager.add_lora_to_model(
    use_text_lora, text_encoder, lora_manager.text_encoder_replace_modules, lora_text_dropout, lora_path, r=lora_rank) 


In [19]:

extra_unet_params = extra_unet_params if extra_unet_params is not None else {}
extra_text_encoder_params = extra_unet_params if extra_unet_params is not None else {}

trainable_modules_available = trainable_modules is not None
trainable_text_modules_available = (train_text_encoder and trainable_text_modules is not None)

In [20]:
def param_optim(model, condition, extra_params=None, is_lora=False, negation=None):
    extra_params = extra_params if len(extra_params.keys()) > 0 else None
    return {
        "model": model, 
        "condition": condition, 
        'extra_params': extra_params,
        'is_lora': is_lora,
        "negation": negation
    }

In [21]:
optim_params = [
        param_optim(unet, trainable_modules_available, extra_params=extra_unet_params, negation=unet_negation),
        param_optim(text_encoder, trainable_text_modules_available, 
                        extra_params=extra_text_encoder_params, 
                        negation=text_encoder_negation
                   ),
        param_optim(text_encoder_lora_params, use_text_lora, is_lora=True, 
                        extra_params={**{"lr": learning_rate}, **extra_text_encoder_params}
                    ),
        param_optim(unet_lora_params, use_unet_lora, is_lora=True, 
                        extra_params={**{"lr": learning_rate}, **extra_unet_params}
                    )
    ]

In [22]:
def create_optim_params(name='param', params=None, lr=5e-6, extra_params=None):
    params = {
        "name": name, 
        "params": params, 
        "lr": lr
    }
    if extra_params is not None:
        for k, v in extra_params.items():
            params[k] = v
    
    return params

def create_optimizer_params(model_list, lr):
    import itertools
    optimizer_params = []

    for optim in model_list:
        model, condition, extra_params, is_lora, negation = optim.values()
        # Check if we are doing LoRA training.
        if is_lora and condition and isinstance(model, list): 
            params = create_optim_params(
                params=itertools.chain(*model), 
                extra_params=extra_params
            )
            optimizer_params.append(params)
            continue
            
        if is_lora and  condition and not isinstance(model, list):
            for n, p in model.named_parameters():
                if 'lora' in n:
                    params = create_optim_params(n, p, lr, extra_params)
                    optimizer_params.append(params)
            continue

        # If this is true, we can train it.
        if condition:
            for n, p in model.named_parameters():
                should_negate = 'lora' in n and not is_lora
                if should_negate: continue

                params = create_optim_params(n, p, lr, extra_params)
                optimizer_params.append(params)
    
    return optimizer_params

In [23]:
params = create_optimizer_params(optim_params, learning_rate)

In [24]:
optimizer = optimizer_cls(
    params,
    lr=learning_rate,
    betas=(adam_beta1, adam_beta2),
    weight_decay=adam_weight_decay,
    eps=adam_epsilon,
)

# Scheduler
lr_scheduler = get_scheduler(
    lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
    num_training_steps=max_train_steps * gradient_accumulation_steps,
)

In [25]:
extra_train_data = None
# Training data parameters
train_data= {
  # The width and height in which you want your training data to be resized to.
  'width': 320,  'height': 240,
  # This will find the closest aspect ratio to your input width and height. 
  # For example, 512x512 width and height with a video of resolution 1280x720 will be resized to 512x256
  'use_bucketing': True,
  # The start frame index where your videos should start (Leave this at one for json and folder based training).
  'sample_start_idx': 1,
  # Used for 'folder'. The rate at which your frames are sampled. Does nothing for 'json' and 'single_video' dataset.
  'fps': 24,
  # For 'single_video' and 'json'. The number of frames to "step" (1,2,3,4) (frame_step=2) -> (1,3,5,7, ...).  
  'frame_step': 2,
  # The number of frames to sample. The higher this number, the higher the VRAM (acts similar to batch size).
  'n_sample_frames': 1,
  # # 'single_video'
  # single_video_path: "path/to/single/video.mp4",
  # # The prompt when using a a single video file
  # single_video_prompt: "",
  # Fallback prompt if caption cannot be read. Enabled for 'image' and 'folder'.
  'fallback_prompt': 'human is walking',
  # 'folder'
  'path': "./test",
  # # 'json'
  # json_path: 'path/to/train/json/'
  # # 'image'
  # image_dir: 'path/to/image/directory'
  # # The prompt for all image files. Leave blank to use caption files (.txt) 
  # single_img_prompt: ""}
}

def get_train_dataset(dataset_types, train_data, tokenizer):
    train_datasets = []

    # Loop through all available datasets, get the name, then add to list of data to process.
    for DataSet in [VideoJsonDataset, SingleVideoDataset, ImageDataset, VideoFolderDataset]:
        for dataset in dataset_types:
            if dataset == DataSet.__getname__():
                train_datasets.append(DataSet(**train_data, tokenizer=tokenizer))

    if len(train_datasets) > 0:
        return train_datasets
    else:
        raise ValueError("Dataset type not found: 'json', 'single_video', 'folder', 'image'")

def extend_datasets(datasets, dataset_items, extend=False):
    biggest_data_len = max(x.__len__() for x in datasets)
    extended = []
    for dataset in datasets:
        if dataset.__len__() == 0:
            del dataset
            continue
        if dataset.__len__() < biggest_data_len:
            for item in dataset_items:
                if extend and item not in extended and hasattr(dataset, item):
                    print(f"Extending {item}")

                    value = getattr(dataset, item)
                    value *= biggest_data_len
                    value = value[:biggest_data_len]

                    setattr(dataset, item, value)

                    print(f"New {item} dataset length: {dataset.__len__()}")
                    extended.append(item)
# Get the training dataset based on types (json, single_video, image)
train_datasets = get_train_dataset(dataset_types, train_data, tokenizer)

# If you have extra train data, you can add a list of however many you would like.
# Eg: extra_train_data: [{: {dataset_types, train_data: {etc...}}}] 
try:
    if extra_train_data is not None and len(extra_train_data) > 0:
        for dataset in extra_train_data:
            d_t, t_d = dataset['dataset_types'], dataset['train_data']
            train_datasets += get_train_dataset(d_t, t_d, tokenizer)

except Exception as e:
    print(f"Could not process extra train datasets due to an error : {e}")

# Extend datasets that are less than the greatest one. This allows for more balanced training.
attrs = ['train_data', 'frames', 'image_dir', 'video_files']
extend_datasets(train_datasets, attrs, extend=extend_dataset)

In [26]:
# Process one dataset
if len(train_datasets) == 1:
    train_dataset = train_datasets[0]
    
# Process many datasets
else:
    train_dataset = torch.utils.data.ConcatDataset(train_datasets)

In [27]:
train_dataset

<utils.dataset.VideoFolderDataset at 0x7416e6b98b80>

In [28]:
len(train_dataset)

11

In [29]:
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_size=train_batch_size,
    shuffle=shuffle

)

In [30]:
len(train_dataloader)

11

In [31]:
def tensor_to_vae_latent(t, vae):
    video_length = t.shape[1]

    t = rearrange(t, "b f c h w -> (b f) c h w")
    latents = vae.encode(t).latent_dist.sample()
    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
    latents = latents * 0.18215

    return latents

In [32]:
def handle_cache_latents(

        should_cache, 
        output_dir, 
        train_dataloader, 
        train_batch_size, 
        vae, 
        cached_latent_dir=None,
        shuffle=False
    ):

    # Cache latents by storing them in VRAM. 
    # Speeds up training and saves memory by not encoding during the train loop.
    if not should_cache: return None
    vae.to('cuda', dtype=torch.float16)
    vae.enable_slicing()
    
    cached_latent_dir = (
        os.path.abspath(cached_latent_dir) if cached_latent_dir is not None else None 
        )
    print('Cached Latent Directory: ', cached_latent_dir if not None else "")
    if cached_latent_dir is None:
        cache_save_dir = f"{output_dir}/cached_latents"
        os.makedirs(cache_save_dir, exist_ok=True)

        for i, batch in enumerate(tqdm(train_dataloader, desc="Caching Latents.")):

            save_name = f"cached_{i}"
            full_out_path =  f"{cache_save_dir}/{save_name}.pt"

            pixel_values = batch['pixel_values'].to('cuda', dtype=torch.float16)
            batch['pixel_values'] = tensor_to_vae_latent(pixel_values, vae)
            for k, v in batch.items(): batch[k] = v[0]
        
            torch.save(batch, full_out_path)
            del pixel_values
            del batch

            # We do this to avoid fragmentation from casting latents between devices.
            torch.cuda.empty_cache()
    else:
        cache_save_dir = cached_latent_dir
        

    return torch.utils.data.DataLoader(
        CachedDataset(cache_dir=cache_save_dir), 
        batch_size=train_batch_size, 
        shuffle=shuffle,
        num_workers=0
    ) 

    # Latents caching
cached_data_loader = handle_cache_latents(
    cache_latents, 
    output_dir,
    train_dataloader, 
    train_batch_size, 
    vae,
    cached_latent_dir
) 

Cached Latent Directory:  /home/bheeshm/nisarg/outputs/cached_latents


In [33]:
cached_data_loader

<torch.utils.data.dataloader.DataLoader at 0x7416dd3f3520>

In [34]:
len(cached_data_loader)

6420

In [35]:
if cached_data_loader is not None: 
    train_dataloader = cached_data_loader

In [36]:
# Prepare everything with our `accelerator`.
unet, optimizer,train_dataloader, lr_scheduler, text_encoder = accelerator.prepare(
    unet, 
    optimizer, 
    train_dataloader, 
    lr_scheduler, 
    text_encoder
)


  model.forward = torch.cuda.amp.autocast(dtype=torch.float16)(model.forward)


In [37]:
def unet_and_text_g_c(unet, text_encoder, unet_enable, text_enable):
    unet._set_gradient_checkpointing(value=unet_enable)
    text_encoder._set_gradient_checkpointing(CLIPEncoder)
# Use Gradient Checkpointing if enabled.
unet_and_text_g_c(
    unet, 
    text_encoder, 
    gradient_checkpointing, 
    text_encoder_gradient_checkpointing
)

In [38]:
vae.enable_slicing()

In [39]:
def is_mixed_precision(accelerator):
    weight_dtype = torch.float32

    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16

    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    return weight_dtype

# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = is_mixed_precision(accelerator)

# Move text encoders, and VAE to GPU
models_to_cast = [text_encoder, vae]

def cast_to_gpu_and_type(model_list, accelerator, weight_dtype):
    for model in model_list:
        if model is not None: model.to(accelerator.device, dtype=weight_dtype)

cast_to_gpu_and_type(models_to_cast, accelerator, weight_dtype)

In [40]:
def enforce_zero_terminal_snr(betas):
    """
    Corrects noise in diffusion schedulers.
    From: Common Diffusion Noise Schedules and Sample Steps are Flawed
    https://arxiv.org/pdf/2305.08891.pdf
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
        alphas_bar_sqrt_0 - alphas_bar_sqrt_T
    )

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas

In [41]:
use_offset_noise = False
rescale_schedule = False
# Fix noise schedules to predcit light and dark areas if available.
if not use_offset_noise and rescale_schedule:
    noise_scheduler.betas = enforce_zero_terminal_snr(noise_scheduler.betas)
    
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
    accelerator.init_trackers("text2video-fine-tune")

In [42]:
# Train!
total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps

logger.info("***** Running training *****")
logger.info(f"  Num examples = {len(train_dataset)}")
logger.info(f"  Num Epochs = {num_train_epochs}")
logger.info(f"  Instantaneous batch size per device = {train_batch_size}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
logger.info(f"  Total optimization steps = {max_train_steps}")
global_step = 0
first_epoch = 0

11/18/2024 03:36:47 - INFO - __main__ - ***** Running training *****
11/18/2024 03:36:47 - INFO - __main__ -   Num examples = 11
11/18/2024 03:36:47 - INFO - __main__ -   Num Epochs = 2
11/18/2024 03:36:47 - INFO - __main__ -   Instantaneous batch size per device = 1
11/18/2024 03:36:47 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 1
11/18/2024 03:36:47 - INFO - __main__ -   Gradient Accumulation steps = 1
11/18/2024 03:36:47 - INFO - __main__ -   Total optimization steps = 10000


In [43]:

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")

Steps:   0%|          | 0/10000 [00:00<?, ?it/s]

In [44]:
def handle_trainable_modules(model, trainable_modules=None, is_enabled=True, negation=None):
    global already_printed_trainables
    acc = []
    unfrozen_params = 0
    
    if trainable_modules is not None:
        unlock_all = any([name == 'all' for name in trainable_modules])
        if unlock_all:
            model.requires_grad_(True)
            unfrozen_params = len(list(model.parameters()))
        else:
            model.requires_grad_(False)
            for name, param in model.named_parameters():
                for tm in trainable_modules:
                    if all([tm in name, name not in acc, 'lora' not in name]):
                        param.requires_grad_(is_enabled)
                        acc.append(name)
                        unfrozen_params += 1
                        
    if unfrozen_params > 0 and not already_printed_trainables:
        already_printed_trainables = True 
        print(f"{unfrozen_params} params have been processed.")


In [45]:
def sample_noise(latents, noise_strength, use_offset_noise=False):
    b ,c, f, *_ = latents.shape
    noise_latents = torch.randn_like(latents, device=latents.device)
    offset_noise = None

    if use_offset_noise:
        offset_noise = torch.randn(b, c, f, 1, 1, device=latents.device)
        noise_latents = noise_latents + noise_strength * offset_noise

    return noise_latents

In [46]:
def finetune_unet(batch, train_encoder=False):
        use_offset_noise = False
        rescale_schedule = False
        
        # Check if we are training the text encoder
        text_trainable = (train_text_encoder or lora_manager.use_text_lora)
        
        # Unfreeze UNET Layers
        if global_step == 0: 
            already_printed_trainables = False
            unet.train()
            handle_trainable_modules(
                unet, 
                trainable_modules, 
                is_enabled=True,
                negation=unet_negation
            )

        # Convert videos to latent space
        pixel_values = batch["pixel_values"]

        if not cache_latents:
            latents = tensor_to_vae_latent(pixel_values, vae)
        else:
            latents = pixel_values

        # Get video length
        video_length = latents.shape[2]
        print("Video Length: ", video_length)

        # Sample noise that we'll add to the latents
        use_offset_noise = use_offset_noise and not rescale_schedule
        noise = sample_noise(latents, offset_noise_strength, use_offset_noise)
        bsz = latents.shape[0]

        # Sample a random timestep for each video
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    
        # Enable text encoder training
        if text_trainable:
            text_encoder.train()

            if lora_manager.use_text_lora: 
                text_encoder.text_model.embeddings.requires_grad_(True)

            if global_step == 0 and train_text_encoder:
                handle_trainable_modules(
                    text_encoder, 
                    trainable_modules=trainable_text_modules,
                    negation=text_encoder_negation
            )
            cast_to_gpu_and_type([text_encoder], accelerator, torch.float32)
        print("1")
        # *Potentially* Fixes gradient checkpointing training.
        # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
        if True:
            unet.eval()
            text_encoder.eval()
        print("2")    
        # Encode text embeddings
        token_ids = batch['prompt_ids']

        # Assume extra batch dimnesion.
        if len(token_ids.shape) > 2:
            token_ids = token_ids[0]
            
        encoder_hidden_states = text_encoder(token_ids)[0]

        # Get the target for loss depending on the prediction type
        if noise_scheduler.prediction_type == "epsilon":
            target = noise

        elif noise_scheduler.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(latents, noise, timesteps)

        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}")

        
        # Here we do two passes for video and text training.
        # If we are on the second iteration of the loop, get one frame.
        # This allows us to train text information only on the spatial layers.
        losses = []
        should_truncate_video = (video_length > 1 and text_trainable)

        # We detach the encoder hidden states for the first pass (video frames > 1)
        # Then we make a clone of the initial state to ensure we can train it in the loop.
        detached_encoder_state = encoder_hidden_states.clone().detach()
        trainable_encoder_state = encoder_hidden_states.clone()

        for i in range(2):

            should_detach = noisy_latents.shape[2] > 1 and i == 0

            if should_truncate_video and i == 1:
                noisy_latents = noisy_latents[:,:,1,:,:].unsqueeze(2)
                target = target[:,:,1,:,:].unsqueeze(2)
                       
            encoder_hidden_states = (
                detached_encoder_state if should_detach else trainable_encoder_state
            )

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            losses.append(loss)
            
            # This was most likely single frame training or a single image.
            if video_length == 1 and i == 0: break

        loss = losses[0] if len(losses) == 1 else losses[0] + losses[1] 

        return loss, latents

In [100]:
def save_pipe(
        path, 
        global_step,
        accelerator, 
        unet, 
        text_encoder, 
        vae, 
        output_dir,
        lora_manager: LoraHandler,
        unet_target_replace_module=None,
        text_target_replace_module=None,
        is_checkpoint=False,
        save_pretrained_model=True
    ):

    if is_checkpoint:
        save_path = os.path.join(output_dir, f"checkpoint-{global_step}")
        os.makedirs(save_path, exist_ok=True)
    else:
        save_path = output_dir

    # Save the dtypes so we can continue training at the same precision.
    u_dtype, t_dtype, v_dtype = unet.dtype, text_encoder.dtype, vae.dtype 

   # Copy the model without creating a reference to it. This allows keeping the state of our lora training if enabled.
    unet_save = copy.deepcopy(unet.cpu())
    text_encoder_save = copy.deepcopy(text_encoder.cpu())

    unet_out = copy.deepcopy(accelerator.unwrap_model(unet_save, keep_fp32_wrapper=False))
    text_encoder_out = copy.deepcopy(accelerator.unwrap_model(text_encoder_save, keep_fp32_wrapper=False))

    pipeline = TextToVideoSDPipeline.from_pretrained(
        path,
        unet=unet_out,
        text_encoder=text_encoder_out,
        vae=vae,
    ).to(torch_dtype=torch.float32)
    
    lora_manager.save_lora_weights(model=pipeline, save_path=save_path, step=global_step)

    if save_pretrained_model:
        pipeline.save_pretrained(save_path)

    if is_checkpoint:
        unet, text_encoder = accelerator.prepare(unet, text_encoder)
        models_to_cast_back = [(unet, u_dtype), (text_encoder, t_dtype), (vae, v_dtype)]
        [x[0].to(accelerator.device, dtype=x[1]) for x in models_to_cast_back]

    logger.info(f"Saved model at {save_path} on step {global_step}")
    
    del pipeline
    del unet_out
    del text_encoder_out
    torch.cuda.empty_cache()
    gc.collect()


In [101]:
def should_sample(global_step, validation_steps, validation_data):
    return (global_step % validation_steps == 0 or global_step == 1)  \
    and validation_data.sample_preview


In [102]:
# Validation data parameters.
validation_data= {
  # A custom prompt that is different from your training dataset. 
  "prompt": "",
  # Whether or not to sample preview during training (Requires more VRAM).
  "sample_preview": True,
  # The number of frames to sample during validation.
  "num_frames": 4,
  # Height and width of validation sample.
  "width": 320,
  "height": 240,
  # Number of inference steps when generating the video.
  "num_inference_steps": 25,
  # CFG scale
  "guidance_scale": 9
}

In [103]:
for epoch in range(first_epoch, num_train_epochs):
    train_loss = 0.0
    
    for step, batch in enumerate(train_dataloader):
        # Skip steps until we reach the resumed step
        if resume_from_checkpoint and epoch == first_epoch and step < resume_step:
            if step % gradient_accumulation_steps == 0:
                progress_bar.update(1)
            continue
        
        with accelerator.accumulate(unet) ,accelerator.accumulate(text_encoder):

            text_prompt = batch['text_prompt'][0]
            
            with accelerator.autocast():
                loss, latents = finetune_unet(batch, train_encoder=train_text_encoder)
            
            # Gather the losses across all processes for logging (if we use distributed training).
            avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean()
            train_loss += avg_loss.item() / gradient_accumulation_steps

            # Backpropagate
            try:
                accelerator.backward(loss)

                if any([train_text_encoder, use_text_lora]):
                    params_to_clip = list(unet.parameters()) + list(text_encoder.parameters())
                else:
                    params_to_clip = unet.parameters()

                if max_grad_norm > 0:
                    if accelerator.sync_gradients:
                        if any([train_text_encoder, use_text_lora]):
                            params_to_clip = list(unet.parameters()) + list(text_encoder.parameters())
                        else:
                            params_to_clip = list(unet.parameters())
                            
                        accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)
                        
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                
            except Exception as e:
                print(f"An error has occured during backpropogation! {e}") 
                continue

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            progress_bar.update(1)
            global_step += 1
            accelerator.log({"train_loss": train_loss}, step=global_step)
            train_loss = 0.0
        
            if global_step % checkpointing_steps == 0:
                save_pipe(
                    pretrained_model_path, 
                    global_step, 
                    accelerator, 
                    unet, 
                    text_encoder, 
                    vae, 
                    output_dir, 
                    lora_manager,
                    unet_lora_modules,
                    text_encoder_lora_modules,
                    is_checkpoint=True,
                    save_pretrained_model=save_pretrained_model
                )

            if should_sample(global_step, validation_steps, validation_data):
                if global_step == 1: print("Performing validation prompt.")
                if accelerator.is_main_process:
                    
                    with accelerator.autocast():
                        unet.eval()
                        text_encoder.eval()
                        unet_and_text_g_c(unet, text_encoder, False, False)
                        lora_manager.deactivate_lora_train([unet, text_encoder], True)    

                        pipeline = TextToVideoSDPipeline.from_pretrained(
                            pretrained_model_path,
                            text_encoder=text_encoder,
                            vae=vae,
                            unet=unet
                        )

                        diffusion_scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                        pipeline.scheduler = diffusion_scheduler

                        prompt = text_prompt if len(validation_data.prompt) <= 0 else validation_data.prompt

                        curr_dataset_name = batch['dataset']
                        save_filename = f"{global_step}_dataset-{curr_dataset_name}_{prompt}"

                        out_file = f"{output_dir}/samples/{save_filename}.mp4"
                        
                        with torch.no_grad():
                            video_frames = pipeline(
                                prompt,
                                width=validation_data.width,
                                height=validation_data.height,
                                num_frames=validation_data.num_frames,
                                num_inference_steps=validation_data.num_inference_steps,
                                guidance_scale=validation_data.guidance_scale
                            ).frames
                        export_to_video(video_frames, out_file, train_data.get('fps', 8))

                        del pipeline
                        torch.cuda.empty_cache()

                logger.info(f"Saved a new sample to {out_file}")

                unet_and_text_g_c(
                    unet, 
                    text_encoder, 
                    gradient_checkpointing, 
                    text_encoder_gradient_checkpointing
                )

                lora_manager.deactivate_lora_train([unet, text_encoder], False)    

        logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
        accelerator.log({"training_loss": loss.detach().item()}, step=step)
        progress_bar.set_postfix(**logs)

        if global_step >= max_train_steps:
            break

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    save_pipe(
            pretrained_model_path, 
            global_step, 
            accelerator, 
            unet, 
            text_encoder, 
            vae, 
            output_dir, 
            lora_manager,
            unet_lora_modules,
            text_encoder_lora_modules,
            is_checkpoint=False,
            save_pretrained_model=save_pretrained_model
    )     
accelerator.end_training()

Video Length:  8


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 15.72 GiB of which 15.25 MiB is free. Including non-PyTorch memory, this process has 15.59 GiB memory in use. Of the allocated memory 15.11 GiB is allocated by PyTorch, and 274.21 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)