# Train DiT

In [None]:
from DiT import *
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
import numpy as np
from time import time
import os
import argparse
import logging
from copy import deepcopy
from collections import OrderedDict


from diffusers.models import AutoencoderKL

from tqdm import tqdm

# Re-select the current directory as the working directory; the relative
# paths used in this notebook (log file, data file, results dir) resolve
# against it.
os.chdir("./")

# Route INFO-and-above log records to ``training.log``, each line carrying a
# timestamp, the level name, and the message.
logging.basicConfig(
    level=logging.INFO,
    filename='training.log',
    format='%(asctime)s - %(levelname)s - %(message)s',
)


@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter one step towards the matching model parameter.

    Each parameter of ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``, where ``param`` is the
    parameter of ``model`` registered under the same name. The whole update
    runs with gradient tracking disabled.

    Args:
        ema_model: Module holding the averaged weights; modified in place.
        model: Module holding the current weights; only read.
        decay: Weight kept from the previous EMA value on each step.

    Raises:
        KeyError: If ``model`` has a parameter name that ``ema_model`` lacks.
    """
    # Name -> tensor lookup so the two modules are paired by parameter name
    # rather than by iteration order.
    shadow_by_name = dict(ema_model.named_parameters())
    fresh_weight = 1 - decay

    for key, live in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        shadow = shadow_by_name[key]
        shadow.mul_(decay)
        shadow.add_(live.data, alpha=fresh_weight)


def requires_grad(model, flag=True):
    """
    Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: Object exposing ``parameters()``; its parameters are
            modified in place.
        flag: ``True`` to enable gradient tracking, ``False`` to freeze.
    """
    for weight in model.parameters():
        weight.requires_grad_(flag)


def main(args):
    """
    Train a new DiT model on a single GPU (no distributed training).

    Args:
        args: Namespace-like object. The attributes read here are
            ``results_dir``, ``model``, ``image_size``, ``num_classes``,
            ``class_dropout_prob``, ``data_file``, ``global_batch_size``,
            ``num_workers``, ``epochs`` and ``log_every``.

    Side effects:
        Creates ``args.results_dir`` (including missing parents), writes one
        checkpoint file per epoch into it, prints progress to stdout, and
        logs the running average loss every ``args.log_every`` steps.

    Raises:
        AssertionError: If CUDA is not available.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # makedirs instead of mkdir: results_dir may be nested, and mkdir fails
    # when an intermediate directory does not exist yet.
    os.makedirs(args.results_dir, exist_ok=True)

    # Create model:
    # Note: You need to define DiT_models, create_diffusion, and other required objects
    #       or import them from your implementation.
    model = DiT_models[args.model](
        input_size=args.image_size,
        num_classes=args.num_classes,
        class_dropout_prob=args.class_dropout_prob
    )
    # The EMA copy is taken before training so it starts from the same weights.
    ema = deepcopy(model).cuda()  # Create an EMA of the model for use after training
    requires_grad(ema, False)
    model = model.cuda()
    diffusion = create_diffusion(timestep_respacing="")
    # vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").cuda()

    # Setup optimizer; the learning rate is multiplied by 0.8 every 3000
    # optimizer steps (the scheduler is stepped once per batch below).
    opt = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=3000, gamma=0.8)

    # Setup data:
    # NOTE(review): torch.load unpickles the file; only load data files from
    # a trusted source.
    data = torch.load(args.data_file)
    # NOTE(review): if `data` is already a tensor, torch.tensor() copies it
    # and emits a UserWarning; kept as-is to preserve the detach-and-copy
    # semantics.
    dataset = TensorDataset(torch.tensor(data).cpu())
    loader = DataLoader(
        dataset,
        batch_size=args.global_batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )

    # Variables for monitoring/logging purposes:
    train_steps = 0
    log_steps = 0
    running_loss = 0
    start_time = time()

    print(f"Training for {args.epochs} epochs...")
    for epoch in range(args.epochs):
        print(f"Beginning epoch {epoch}...")
        for x in tqdm(loader):
            # TensorDataset yields 1-tuples; unpack the only tensor.
            x = x[0].to(torch.float32).cuda()
            # Every sample is given class label 0.
            y = torch.zeros(x.shape[0], dtype=torch.int).cuda()
            # print('input x shape: ',x.shape)
            # with torch.no_grad():
            #     x = vae.encode(x).latent_dist.sample().mul_(0.18215)
            #     print(x.shape)
            # One uniformly sampled diffusion timestep per sample.
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],)).cuda()
            model_kwargs = dict(y=y)
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            scheduler.step()
            update_ema(ema, model)

            # Log loss values:
            running_loss += loss.item()
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Wait for queued GPU work so the wall-clock timing is meaningful.
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Plain float: there is no cross-process reduction here, so a
                # CUDA tensor round-trip is unnecessary and would put a
                # tensor repr into the log file.
                avg_loss = running_loss / log_steps
                logging.info(f'Epoch {epoch}, Step {train_steps}: Loss = {avg_loss}')
                print(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                running_loss = 0
                log_steps = 0
                start_time = time()

        # Save DiT checkpoint (once per epoch). The scheduler state is
        # included so the learning-rate schedule can be restored on resume.
        checkpoint = {
                "model": model.state_dict(),
                "ema": ema.state_dict(),
                "opt": opt.state_dict(),
                "scheduler": scheduler.state_dict(),
        }
        checkpoint_path = f"{args.results_dir}/ep{epoch}_{train_steps:07d}.pt"
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved checkpoint to {checkpoint_path}")

    print("Done!")



In [None]:
# Dataset descriptors; not referenced by the configuration below.
num_videos = 10000
seq_length = 240
repeat = 0

# Run configuration handed to main(). As written in this file, main() reads
# results_dir, model, image_size, num_classes, class_dropout_prob, data_file,
# global_batch_size, num_workers, epochs and log_every; the remaining keys
# (nnodes, nproc_per_node, global_seed, vae, ckpt_every) are carried along
# but not read there.
args_dict = dict(
    nnodes=1,
    nproc_per_node="N",
    data_file="latent_data_ball_10_240_bounce.pt",
    results_dir="./video_data/video_results_testQKV_real",
    model="DiT-PD/2_N=240",
    image_size=2,
    num_classes=1000,
    epochs=1000,
    global_batch_size=256,
    global_seed=0,
    vae="ema",
    num_workers=4,
    log_every=100,
    ckpt_every=50000,
    class_dropout_prob=1,
)

# Wrap the dict so main() can use attribute access (args.epochs, ...), then train.
args = ArgsDict(args_dict)
main(args)