In [None]:
import torch
import torchvision
from torchvision import transforms
import PIL
import numpy as np
import matplotlib.pyplot as plt

In [None]:
device = torch.device(torch.accelerator.current_accelerator() if torch.accelerator.is_available() else 'cpu')
device

In [None]:
# Training-time preprocessing: random horizontal flip, conversion to a tensor,
# then per-channel normalisation with the mean / std tuples below.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),(0.247, 0.243, 0.261))
])

# download=True fetches the archive into ./data only when it is not already
# there, so the notebook also runs from a fresh checkout.
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

In [None]:
# Pull one batch for a visual sanity check. `images` is kept untouched
# (still normalised) because later cells reuse it.
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)

# The loader yields normalised tensors; undo transforms.Normalize and clamp to
# [0, 1] so that imshow renders the true colours instead of clipping them.
norm_mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, -1, 1, 1)
norm_std = torch.tensor((0.247, 0.243, 0.261)).view(1, -1, 1, 1)
preview = (images * norm_std + norm_mean).clamp(0, 1)

grid = torchvision.utils.make_grid(preview, nrow=16)
grid = grid.permute(1,2,0)  # CHW -> HWC, the layout imshow expects

plt.figure(figsize=(10,10))
plt.imshow(grid)

In [None]:
from src.engine.trainer import RectifiedFlowTrainer
from tqdm import tqdm
import itertools
import wandb
from torchinfo import summary
from src.models.rope_dit_modelling import RoPEDiT


In [None]:

def _cycle_batches(loader):
    """Yield batches from ``loader`` forever, re-iterating it once exhausted.

    Unlike ``itertools.cycle`` this keeps no copy of earlier batches and calls
    ``iter(loader)`` again for every pass, so a shuffling loader produces a
    new order (and fresh augmentations) each epoch.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("train_loader yielded no batches")


def train_loop(trainer, train_loader, num_iter, device, run=None):
    """Run ``num_iter`` training steps of ``trainer`` on batches from ``train_loader``.

    Args:
        trainer: Object exposing ``model`` (a ``torch.nn.Module``) and
            ``step(images, labels)``; the value returned by ``step`` is treated
            as the loss for logging and display.
        train_loader: Re-iterable yielding ``(images, labels)`` batches. It is
            iterated again from the start whenever it runs out.
        num_iter: Total number of training steps.
        device: Device the model and every batch are moved to.
        run: Optional tracking run providing ``log``, ``name`` and ``finish``
            (for example the object returned by ``wandb.init``). When ``None``
            nothing is logged and no weights are written to disk.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    import os  # local import: only needed to create the checkpoint directories

    trainer.model.to(device)
    trainer.model.train()

    if run is not None:
        # torch.save does not create missing directories; this also creates
        # the parent "trained/" used for the final weights.
        os.makedirs("trained/checkpoints", exist_ok=True)

    # Roughly ten evenly spaced checkpoints; disabled for very short runs.
    checkpoint_every = num_iter // 10 if num_iter > 10 else 0

    data_iter = _cycle_batches(train_loader)

    progress_bar = tqdm(range(1, num_iter+1))
    for step in progress_bar:

        images, labels = next(data_iter)
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        loss = trainer.step(images, labels)

        if step%100==0 and (run is not None):
            run.log({
                "loss":loss,
                "step":step
            })

        # Checkpoint file names embed run.name, so checkpoints need a run;
        # this matches the guard on the final save below.
        if run is not None and checkpoint_every and step%checkpoint_every == 0:
            trainer.model.eval()
            torch.save(trainer.model.state_dict(), f"trained/checkpoints/{run.name}_step_{step}_loss_{loss:.4f}.pt")
            trainer.model.train()

        progress_bar.set_description(f"step : {step} | loss : {loss}")

    trainer.model.eval()

    if run is not None:
        torch.save(trainer.model.state_dict(), f"trained/{run.name}_final.pt")
        run.finish()


In [None]:
############ hyperparameters #############

# Architecture settings, forwarded verbatim to RoPEDiT(**model_config).
model_config = {
    "model_dim": 256,
    "num_dit_blocks": 6,
    "num_attn_heads": 8,
    "patch_size": 4,
    "num_classes": len(train_set.classes),  # number of labels in the dataset
    "in_channels": 3,
    "use_cfg": True,
}

learning_rate = 3e-4
drop_prob = 0.2  # third argument of RectifiedFlowTrainer; see that class for its exact meaning
num_iterations = 5_000
batch_size = 512

##########################################

DiT = RoPEDiT(**model_config)

# Print a layer summary by tracing the model on the preview batch fetched
# earlier. The third input has the same shape/dtype as `labels`; what it means
# is defined by RoPEDiT.forward, which is not visible in this notebook.
print(
    summary(
        DiT,
        input_data=[images, labels, torch.ones_like(labels)],
    )
)
DiT.compile(fullgraph=True)

optimizer = torch.optim.AdamW(DiT.parameters(), lr=learning_rate)
trainer = RectifiedFlowTrainer(DiT, optimizer, drop_prob)

# Rebuild the loader with the training batch size; this replaces the earlier
# loader of the same name.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

# No tracking run is attached until wandb.init is called in a later cell.
run = None

In [None]:
# Display the preprocessing pipeline currently bound to `transform`.
transform

In [None]:

# Start a Weights & Biases run and attach the experiment settings to it:
# model architecture, optimisation hyperparameters and data pipeline.
run = wandb.init(
    config={
        "model_config": model_config,
        "learning_rate": learning_rate,
        "drop_prob": drop_prob,
        "num_iterations": num_iterations,
        "data": {
            "dataset": train_set.filename,
            "transform": transform,
            "batch_size": batch_size,
        },
    },
    entity="divyanshukla",
    project="customDiT",
)

In [None]:
# Train for `num_iterations` steps on `device`; `run` is passed through so the
# loop can log to it and name the saved weights after it.
train_loop(trainer, train_loader, num_iterations, device, run)

In [None]:
# trainer.model.eval()
# torch.save(trainer.model.state_dict(), f"models/{run.name}_final.pt")

In [None]:
# DiT.load_state_dict(torch.load("models/playful-aardvark-14_final.pt", map_location=device))
# DiT.eval().to(device)

In [None]:
import src.utils.sampler as sampler
from src.utils.sampler import euler_sampler

num_samples = 4

# Start from Gaussian noise shaped like the training images, with random labels.
x0 = torch.randn((num_samples, *images.shape[1:]), device=device)
y = torch.randint(len(train_set.classes), size=(num_samples,), device=device)
h = 1e-2                 # step size handed to the sampler
num_steps = int(1/h)
cfg_scale = 4.0

# Sampling needs no gradients: disabling autograd avoids building a graph
# across every step and guarantees the result can be converted for plotting.
with torch.no_grad():
    x = euler_sampler(DiT, x0, y, h, num_steps, with_traj=False, cfg_scale=cfg_scale).cpu()

# Undo the training-time normalisation (same mean / std as the transform).
mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(1, -1, 1, 1) 
std = torch.tensor((0.247, 0.243, 0.261)).view(1, -1, 1, 1)

# Generated pixels can fall outside the displayable range; clamp explicitly
# rather than relying on imshow's implicit clipping.
x = (x*std + mean).clamp(0, 1)

grid = torchvision.utils.make_grid(x, nrow=(num_samples//2))
grid = grid.permute(1, 2, 0)  # CHW -> HWC, the layout imshow expects

plt.figure(figsize=(10,10))
plt.imshow(grid)

In [None]:
# Class names corresponding to the sampled label indices in `y`.
[train_set.classes[label_idx] for label_idx in y.cpu().tolist()]