In [None]:
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.distributed import reduce
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
from diffusers import DDIMScheduler
import argparse
import matplotlib.pyplot as plt
import os
import time

torch.manual_seed(3)

from models.dit_model import TransformerBackbone
from models.vanilla_transformer import VanillaTransformer
from dataset import load_npy_files, process_motion_tensor, MotionAudioDataset


In [None]:
# Global config
config = {
    # Model config
    "x_dim": 63,
    "a_dim": 768,
    "max_seq_length": 75,
    "hidden_size": 512,
    "num_layers": 8,
    "num_attention_heads": 8,
    # Training config
    "batch_size": 32,
    "learning_rate": 1e-3,
    "num_epochs": -1,
    "save_interval": 5,
    "num_iterations": 8500000,
    "scheduler": "none",
    "warmup_iters": 5000,
    "lr_min_scale": 0.2,
    "cos_iters": 800000,
    # Validation config
    "validate_only": False,
    "valid_batch_size": 256,
    "validate_interval": 1000,
}

# Global data parameters.
# The latent roots default to this machine's mount points; set AUDIO_LATENT_ROOT /
# MOTION_LATENT_ROOT in the environment to run elsewhere without editing the notebook.
audio_root = os.environ.get("AUDIO_LATENT_ROOT", '/mnt/e/data/live_encoder_output/audio_latent/')
motion_root = os.environ.get("MOTION_LATENT_ROOT", '/mnt/e/data/live_encoder_output/live_latent/')
start_idx = 0
end_idx = 400
output_dir = "output"
checkpoint_dir = None  # directory holding model.pth / optimizer.pth to resume from; None = fresh start
world_size = 1  # NOTE(review): not referenced in this notebook; the sampler below hard-codes num_replicas=1
model_type = "vanilla"  # "vanilla" or "dit" (see init_model)
audio_latents, motion_latents = load_npy_files(audio_root, motion_root, start_idx, end_idx)
# process_motion_tensor returns three values; only the processed motion tensor is kept here.
motion_latents, _, _ = process_motion_tensor(motion_latents)

# Global training state. model / noise_scheduler are filled in by init_model();
# optimizer / scaler / lr_scheduler are (re)created on every run() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = None
noise_scheduler = None
optimizer = None
scaler = None
lr_scheduler = None
dataset = MotionAudioDataset(motion_latents, audio_latents)
# Single-process DistributedSampler (one replica, rank 0). Its shuffle order only
# changes between epochs when sampler.set_epoch(...) is called.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
# Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
dataloader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    sampler=sampler,
    pin_memory=device.type == "cuda",
)


epoch_losses = []       # mean loss per epoch, accumulated across run() calls
iteration_losses = []   # loss per optimisation step, accumulated across run() calls
# Leading frames used as conditioning context; the remaining frames are the prediction target.
prev_seq_length = 10


In [None]:



def init_model():
    """Build the global ``model`` for the configured ``model_type``.

    Reads ``model_type``, ``config``, ``device`` and ``output_dir`` from module
    scope and makes sure ``output_dir`` exists. For ``model_type == "dit"`` it also
    creates the global DDIM ``noise_scheduler``; the "vanilla" model needs none.

    Raises:
        ValueError: if ``model_type`` is neither "dit" nor "vanilla".
    """
    global model, noise_scheduler

    os.makedirs(output_dir, exist_ok=True)

    # Hyper-parameters that both backbones accept under the same keyword names.
    shared_kwargs = {
        key: config[key]
        for key in ("x_dim", "a_dim", "max_seq_length", "hidden_size", "num_layers")
    }

    if model_type == "dit":
        network = TransformerBackbone(
            **shared_kwargs,
            num_attention_heads=config["num_attention_heads"],
            norm_type="ada_norm_zero",
            device=device,
        )
    elif model_type == "vanilla":
        # VanillaTransformer names its head-count argument ``num_heads``.
        network = VanillaTransformer(**shared_kwargs, num_heads=config["num_attention_heads"])
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    model = network.to(device)

    # Only the DiT branches of train_step / validate use a noise scheduler.
    if model_type == "dit":
        noise_scheduler = DDIMScheduler(
            num_train_timesteps=1000,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=False,
            set_alpha_to_one=False,
        )

def load_checkpoint():
    """Restore state saved by ``save_checkpoint`` from the global ``checkpoint_dir``.

    Loads ``model.pth`` into the global ``model`` when present, and ``optimizer.pth``
    into the global ``optimizer`` when present *and* an optimizer already exists
    (otherwise the optimizer file is silently skipped). Tensors are mapped onto the
    global ``device``. Prints a notice when neither file exists.
    """
    weights_file = os.path.join(checkpoint_dir, "model.pth")
    optim_file = os.path.join(checkpoint_dir, "optimizer.pth")
    has_weights = os.path.exists(weights_file)
    has_optim_state = os.path.exists(optim_file)

    if not (has_weights or has_optim_state):
        print(f"No checkpoint found in {checkpoint_dir}")
        return

    # NOTE(review): torch.load unpickles the file; only point checkpoint_dir at
    # checkpoints you produced yourself.
    if has_weights:
        model.load_state_dict(torch.load(weights_file, map_location=device))
        print(f"Model checkpoint loaded from {weights_file}")
    if has_optim_state and optimizer is not None:
        optimizer.load_state_dict(torch.load(optim_file, map_location=device))
        print(f"Checkpoint loaded from {optim_file}")
        
def get_scheduler(scheduler_type=None):
    """Build an LR scheduler for the global ``optimizer`` from ``config``.

    Args:
        scheduler_type: "none" or "cosine". Defaults to ``config["scheduler"]``.

    Returns:
        ``None`` for "none". For "cosine", a ``CosineAnnealingLR`` with
        ``T_max=config["cos_iters"]`` and
        ``eta_min=config["learning_rate"] * config["lr_min_scale"]``.

    Raises:
        ValueError: for any other scheduler type.

    NOTE(review): ``run`` builds its own per-epoch scheduler and does not call this;
    ``config["warmup_iters"]`` is not applied here.
    """
    if scheduler_type is None:
        scheduler_type = config["scheduler"]

    if scheduler_type == "none":
        return None
    if scheduler_type == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config["cos_iters"],
            eta_min=config["learning_rate"] * config["lr_min_scale"],
        )
    raise ValueError(f"Unknown scheduler type: {scheduler_type}")

def run(scheduler_type = "cosine", learning_rate = None, epochs = None):
    """Train the global ``model`` for ``epochs`` passes over the global ``dataloader``.

    Each call builds a fresh AdamW optimizer (published as the global ``optimizer``).
    If ``checkpoint_dir`` is set, it then restores model/optimizer state from there and
    forces the optimizer back to ``learning_rate``. It then attaches a per-epoch LR
    scheduler (global ``lr_scheduler``) and calls ``train_step`` on every mini-batch.
    Per-step losses are appended to ``iteration_losses`` and epoch means to
    ``epoch_losses``; both lists persist across calls.

    Args:
        scheduler_type: "cosine" (anneal to 10% of ``learning_rate`` over ``epochs``),
            "plateau" (ReduceLROnPlateau on the epoch-mean loss, factor 0.5,
            patience 3), "linear" (decay linearly to 10% over ``epochs``). Any
            other value disables LR scheduling.
        learning_rate: initial LR; ``None`` falls back to ``config["learning_rate"]``.
        epochs: number of epochs; ``None`` falls back to ``config["num_epochs"]``.

    Raises:
        ValueError: if the resolved ``epochs`` is negative, or the dataloader yields
            no batches while ``epochs > 0``.
    """
    global optimizer, lr_scheduler, scaler

    if learning_rate is None:
        learning_rate = config["learning_rate"]
    if epochs is None:
        epochs = config["num_epochs"]
    if epochs < 0:
        raise ValueError(f"epochs must be >= 0, got {epochs}; pass epochs=... explicitly")
    num_batches = len(dataloader)
    if epochs > 0 and num_batches == 0:
        raise ValueError("dataloader yields no batches; nothing to train on")

    model.train()

    # Create the optimizer BEFORE loading the checkpoint so a saved optimizer.pth is
    # restored into the optimizer that is actually used for training.
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    if checkpoint_dir:
        load_checkpoint()
        # The restored param_groups carry the old run's lr (and initial_lr if a
        # scheduler was attached); reset them so this call's learning_rate and the
        # scheduler created below start from the requested value.
        for group in optimizer.param_groups:
            group["lr"] = learning_rate
            group.pop("initial_lr", None)

    # Per-epoch LR scheduler (stepped once at the end of every epoch).
    if scheduler_type == "cosine":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=epochs, eta_min=learning_rate * 0.1
        )
    elif scheduler_type == "plateau":
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=0.5, patience=3
        )
    elif scheduler_type == "linear":
        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=1.0, end_factor=0.1, total_iters=epochs
        )
    else:
        lr_scheduler = None

    # NOTE(review): not used by train_step while mixed precision is disabled.
    scaler = GradScaler()

    epoch_pbar = tqdm(range(epochs), desc="Training Epochs")
    iteration = 0
    for epoch in epoch_pbar:
        # DistributedSampler only reshuffles when told the epoch. Use the cumulative
        # epoch count so repeated run() calls also get fresh orderings.
        sampler = getattr(dataloader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(len(epoch_losses))

        epoch_loss_sum = torch.zeros(1, device=device)
        mini_batch_pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)
        for mini_batch in mini_batch_pbar:
            loss = train_step(mini_batch)
            # Detach so the running sum does not chain every step's autograd graph.
            epoch_loss_sum += loss.detach()

            iteration_losses.append(loss.item())
            iteration += 1
            if iteration % 10 == 0:
                epoch_pbar.set_postfix({"Loss": f"{loss.item():.2e}"})

        avg_loss = epoch_loss_sum.item() / num_batches
        epoch_losses.append(avg_loss)
        print(f"Loss at epoch {epoch+1}: {avg_loss:.2e}")

        if lr_scheduler is not None:
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(avg_loss)
            else:
                lr_scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Current learning rate: {current_lr:.2e}")

def train_step(batch):
    """Run one optimisation step on a mini-batch and return the loss tensor.

    Each sequence is split at ``prev_seq_length``. The prefix (motion and audio) is
    passed to the model as context, and the remaining motion frames are the
    regression target (indexing assumes time on dim 1 — TODO confirm for audio).

    - "vanilla": the model receives zeros in place of the target motion.
    - "dit": the target motion is noised via ``noise_scheduler.add_noise`` at random
      timesteps, and the model output is sliced past the prefix.

    Both branches minimise MSE between the prediction and the clean target motion,
    matching ``validate``. NOTE(review): the target is the clean motion, not the
    added noise — confirm this matches how the DiT is sampled.

    Args:
        batch: mapping with 'motion_latent' and 'audio_latent' tensors.

    Returns:
        Scalar loss tensor (still attached to the autograd graph).

    Raises:
        ValueError: if ``model_type`` is neither "dit" nor "vanilla".
    """
    optimizer.zero_grad()

    x = batch['motion_latent'].to(device)
    a = batch['audio_latent'].to(device)

    x_prev, x_gt = x[:, :prev_seq_length], x[:, prev_seq_length:]
    a_prev, a_train = a[:, :prev_seq_length], a[:, prev_seq_length:]

    if model_type == "dit":
        noise = torch.randn_like(x_gt)
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (x_gt.shape[0],), device=device
        ).long()
        x_noisy = noise_scheduler.add_noise(x_gt, noise, timesteps)
        x_pred = model(x_noisy, x_prev, a_train, a_prev, timesteps)[:, prev_seq_length:]
    elif model_type == "vanilla":
        x_input = torch.zeros_like(x_gt)  # target motion is hidden from the model
        x_pred = model(x_input, x_prev, a_train, a_prev)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Loss, backward and step are shared by both branches (previously only the
    # vanilla branch computed a loss, so the DiT path crashed on `return loss`).
    loss = torch.nn.functional.mse_loss(x_pred, x_gt)
    loss.backward()
    optimizer.step()

    return loss

def save_checkpoint(epoch):
    """Write the model and optimizer state dicts for ``epoch``.

    Files go to ``<output_dir>/checkpoint_epoch_<epoch>/model.pth`` and
    ``.../optimizer.pth``; the directory is created if missing. The module-level
    ``checkpoint_dir`` (used for resuming) is not modified.
    """
    target_dir = os.path.join(output_dir, f"checkpoint_epoch_{epoch}")
    os.makedirs(target_dir, exist_ok=True)

    # Saved in this order: model first, then optimizer.
    for component, filename in ((model, "model.pth"), (optimizer, "optimizer.pth")):
        torch.save(component.state_dict(), os.path.join(target_dir, filename))

    print(f"Checkpoint saved at epoch {epoch}")

def plot_and_save_loss(full_range_max=5e-4, zoom_max=1e-6):
    """Plot ``iteration_losses`` and dump both loss histories into ``output_dir``.

    Writes ``loss_plot.png`` (two panels), ``epoch_losses.txt`` and
    ``iteration_losses.txt`` (one value per line). ``output_dir`` is created if missing.

    Args:
        full_range_max: left panel keeps losses <= this value on a linear axis
            spanning [0, full_range_max]. Default 5e-4.
        zoom_max: right panel keeps positive losses <= this value on a log axis
            whose top is zoom_max. Default 1e-6.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 5))

    # Outliers above each threshold are dropped, not clipped. Non-positive values
    # are excluded from the zoomed panel because a log axis cannot show them.
    full_range_losses = [(i, loss) for i, loss in enumerate(iteration_losses) if loss <= full_range_max]
    zoomed_losses = [(i, loss) for i, loss in enumerate(iteration_losses) if 0 < loss <= zoom_max]

    if full_range_losses:
        iterations, losses = zip(*full_range_losses)
        ax1.plot(iterations, losses)
        ax1.set_title("Iteration Loss (Full Range)")
        ax1.set_xlabel("Iteration")
        ax1.set_ylabel("Loss")
        ax1.set_ylim(0, full_range_max)

    if zoomed_losses:
        iterations, losses = zip(*zoomed_losses)
        ax2.plot(iterations, losses)
        ax2.set_title("Iteration Loss (Zoomed)")
        ax2.set_xlabel("Iteration")
        ax2.set_ylabel("Loss")
        # Switch to log scale first and pin only the top: a bottom limit of 0 is
        # invalid on a log axis.
        ax2.set_yscale('log')
        ax2.set_ylim(top=zoom_max)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "loss_plot.png"))
    plt.close(fig)

    with open(os.path.join(output_dir, "epoch_losses.txt"), "w") as f:
        f.writelines(f"{loss}\n" for loss in epoch_losses)

    with open(os.path.join(output_dir, "iteration_losses.txt"), "w") as f:
        f.writelines(f"{loss}\n" for loss in iteration_losses)
    print("Loss plot and data saved")

def validate(require_checkpoint=True):
    """Return the mean per-batch MSE of ``model`` over the in-memory latents.

    NOTE(review): the dataset is built from the same ``motion_latents`` /
    ``audio_latents`` used for training, so this measures training-set fit rather
    than held-out performance.

    The model is put in eval mode (and left there). If ``checkpoint_dir`` is set,
    its weights are loaded first. Every batch loss is printed.

    Args:
        require_checkpoint: when True (default), raise if ``checkpoint_dir`` is not
            set. When False, evaluate the in-memory model instead.

    Returns:
        float: mean of the per-batch MSE losses.

    Raises:
        ValueError: no checkpoint while ``require_checkpoint`` is True, unknown
            ``model_type``, or an empty validation set.
    """
    model.eval()
    if checkpoint_dir:
        load_checkpoint()
    elif require_checkpoint:
        raise ValueError("No checkpoint provided for validation")

    if model_type not in ("dit", "vanilla"):
        raise ValueError(f"Unknown model type: {model_type}")

    valid_dataset = MotionAudioDataset(motion_latents, audio_latents)
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=config["valid_batch_size"],
        shuffle=False,
        pin_memory=torch.device(device).type == "cuda",  # pinning only helps GPU copies
    )

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in valid_dataloader:
            x = batch['motion_latent'].to(device)
            a = batch['audio_latent'].to(device)

            x_prev, x_gt = x[:, :prev_seq_length], x[:, prev_seq_length:]
            a_prev, a_train = a[:, :prev_seq_length], a[:, prev_seq_length:]

            if model_type == "dit":
                # Same noising as training; timesteps are random, so this is not deterministic.
                noise = torch.randn_like(x_gt)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps, (x_gt.shape[0],), device=device
                ).long()
                x_noisy = noise_scheduler.add_noise(x_gt, noise, timesteps)
                x_pred = model(x_noisy, x_prev, a_train, a_prev, timesteps)[:, prev_seq_length:]
            else:  # "vanilla"
                x_input = torch.zeros_like(x_gt)  # target motion is hidden from the model
                x_pred = model(x_input, x_prev, a_train, a_prev)

            loss = torch.nn.functional.mse_loss(x_pred, x_gt)
            total_loss += loss.item()
            num_batches += 1
            print(f"Batch validation Loss: {loss.item():.10f}")

    if num_batches == 0:
        raise ValueError("Validation set is empty")
    avg_loss = total_loss / num_batches

    print(f"Avg validation Loss: {avg_loss:.10f}")
    return avg_loss


In [None]:
init_model()

In [None]:
run(scheduler_type="none", learning_rate = 2e-4, epochs = 2000)

In [None]:
run(dataloader=dataloader, learning_rate = 1e-3, epochs = 1)

In [None]:
run(dataloader=dataloader, learning_rate = 4e-4, epochs = 2)

In [None]:
run(dataloader=dataloader, learning_rate = 3e-4, epochs =1)

In [None]:
run(dataloader=dataloader, learning_rate = 2e-4, epochs = 1)

In [None]:
run(dataloader=dataloader, learning_rate = 1e-4, epochs = 4)

In [None]:
run(dataloader=dataloader, learning_rate = 5e-5, epochs = 2)

In [None]:
run(dataloader=dataloader, learning_rate = 5e-5, epochs = 20)

In [None]:
run(dataloader=dataloader, learning_rate = 5e-5, epochs = 2)