In [6]:
import torch

# Pick the best accelerator that is actually usable, falling back to CPU.
if torch.cuda.is_available():
    device = "cuda"
# Use is_available() rather than is_built(): is_built() only says that this
# PyTorch binary was compiled with MPS support, not that the current machine
# has a usable MPS device.
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'mps'

In [29]:
from torch.utils.data import DataLoader

from notebook.mnist_dataset import MnistVideoDataset

# Data pipeline settings: image_size is forwarded to the dataset, the other
# two values configure the DataLoader.
image_size = 64
batch_size, num_workers = 1, 4

# Dataset of MNIST videos wrapped in a DataLoader for the training loop.
dataset = MnistVideoDataset(image_size=image_size)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)


In [7]:
# Create TensorBoard writer

from torch.utils.tensorboard import SummaryWriter
import os

def create_tensorboard_writer(exp_dir):
    """Return a SummaryWriter that logs into ``<exp_dir>/tensorboard``.

    The directory is created first if it does not exist yet.
    """
    log_dir = f"{exp_dir}/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    return SummaryWriter(log_dir)

# The helper appends its own "tensorboard" subdirectory, so events are written
# under local/tensorboard/tensorboard.
writer = create_tensorboard_writer(exp_dir="local/tensorboard")

In [10]:
from notebook.vae import VideoAutoencoderKL

# Load the pretrained VAE on the selected device, then query the latent size
# it reports for a (3, image_size, image_size) input.
vae = VideoAutoencoderKL(
    from_pretrained="stabilityai/sd-vae-ft-ema",
    device=device,
)
latent_size = vae.get_latent_size((3, image_size, image_size))

In [11]:
# Alternative text encoder (disabled): T5Encoder loaded from 'DeepFloyd/t5-v1_1-xxl'.
# from notebook.t5 import T5Encoder

# text_encoder = T5Encoder(from_pretrained='DeepFloyd/t5-v1_1-xxl', shardformer=False, model_max_length=1, device=device)

# Active text encoder: MNistEncoder, constructed with the selected device.
# Later cells read its `output_dim` and `model_max_length` and call `encode(...)`.
from notebook.t5 import MNistEncoder
text_encoder = MNistEncoder(device)

In [12]:
from notebook.stdit import STDiT

# Build the STDiT denoiser. Input/caption dimensions are taken from the VAE
# and text encoder created above; the remaining values are architecture
# hyperparameters chosen for this notebook.
model = STDiT(
    # Shapes derived from the other components.
    input_size=latent_size,
    in_channels=vae.out_channels,
    caption_channels=text_encoder.output_dim,
    model_max_length=text_encoder.model_max_length,
    dtype="bf16",
    # Architecture hyperparameters.
    depth=16,
    hidden_size=32,
    patch_size=(1, 2, 2),
    num_heads=16,
)
model = model.to(device)

In [13]:
# Total number of parameter elements, and the subset that requires gradients.
num_params = sum(p.numel() for p in model.parameters())
num_params_trainable = sum(
    p.numel() for p in model.parameters() if p.requires_grad
)
num_params, num_params_trainable

(358144, 358144)

In [14]:
from notebook.iddpm import IDDPM

# Diffusion scheduler, constructed with no arguments. The training loop reads
# its `num_timesteps` and calls its `training_losses(...)`.
scheduler = IDDPM()

In [15]:
import torch.optim as optim

# Adam over the parameters of `model` that require gradients, with a fixed
# learning rate of 2e-4 and no weight decay.
optimizer = optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=2e-4,
    weight_decay=0,
)

In [30]:
import logging
from tqdm import tqdm

# Training configuration.
dtype = torch.float32
n_epoch = 100
num_steps_per_epoch = len(dataloader)

# Bookkeeping counters, all starting from zero.
start_epoch = 0
start_step = 0
log_step = 0
sampler_start_idx = 0
running_loss = 0.0

logger = logging.getLogger(__name__)

def all_reduce_mean(tensor: torch.Tensor):
    """Average ``tensor`` in place across all distributed ranks and return it.

    If ``torch.distributed`` is unavailable or no process group has been
    initialised (the single-process case), there is nothing to reduce over
    and the tensor is returned unchanged.

    Args:
        tensor: Tensor to average; modified in place when running distributed.

    Returns:
        The same tensor object that was passed in.
    """
    # Local import: `dist` was referenced here without ever being imported,
    # which raised NameError on the first call.
    import torch.distributed as dist

    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor

# Training loop: one optimizer step per batch; the per-step loss is shown on
# the progress bar and written to TensorBoard.
#
# NOTE(review): the saved output of this cell shows a TypeError about
# converting an MPS tensor to float64. No tensor created directly in this loop
# is float64 (`dtype` is float32, `t` comes from torch.randint), so the
# conversion presumably happens inside vae / text_encoder / scheduler — confirm
# there before running on "mps".
for epoch in range(start_epoch, n_epoch):
    dataloader_iter = iter(dataloader)
    logger.info(f"Beginning epoch {epoch}...")

    with tqdm(
        range(start_step, num_steps_per_epoch),
        desc=f"Epoch {epoch}",
        total=num_steps_per_epoch,
        initial=start_step,
    ) as pbar:
        for step in pbar:
            batch = next(dataloader_iter)
            x = batch["video"].to(device, dtype)  # [B, C, T, H, W]
            y = batch["text"]

            # Encoding runs without gradient tracking; only `model` is optimised.
            with torch.no_grad():
                # Prepare visual inputs
                x = vae.encode(x)  # [B, C, T, H/P, W/P]
                # Prepare text inputs
                model_args = text_encoder.encode(y)

            # Diffusion: draw one random timestep per sample in the batch.
            t = torch.randint(0, scheduler.num_timesteps, (x.shape[0],), device=device)
            loss_dict = scheduler.training_losses(model, x, t, model_args)

            # Backward & update. Plain PyTorch backward: the previous
            # `booster.backward(...)` referenced a `booster` object that is
            # never created in this notebook.
            loss = loss_dict["loss"].mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # TODO: EMA update is not wired up in this notebook.

            # Log loss values. This notebook runs as a single process, so no
            # cross-rank reduction is performed. Read the scalar once to avoid
            # a second device-to-host sync.
            loss_value = loss.item()
            global_step = epoch * num_steps_per_epoch + step
            pbar.set_postfix({"loss": loss_value, "step": step, "global_step": global_step})
            writer.add_scalar("loss", loss_value, global_step)

            # TODO: periodic checkpointing (e.g. every 100 global steps) is not
            # implemented here; the earlier draft relied on booster / ema /
            # lr_scheduler / coordinator objects that do not exist in this notebook.

    # Only the first epoch may begin from a non-zero step.
    start_step = 0

Epoch 0:   0%|                                                                                                                                                      | 0/60000 [00:02<?, ?it/s]


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.