In [34]:
import torch
from torchinfo import summary
from torchvision import utils as vutils
from torchvision.transforms.v2.functional import to_pil_image
from tqdm import tqdm

In [3]:
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects
# (needed if these files hold whole Dataset objects) — only load trusted files.
train_data, test_data = (
    torch.load(path, weights_only=False)
    for path in ("../data/train_data.pth", "../data/test_data.pth")
)

# Shuffle only the training split; evaluation order does not matter.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=shuffle)
    for dataset, shuffle in ((train_data, True), (test_data, False))
)

In [4]:
class RF:
    """Rectified-flow training / sampling wrapper around a velocity model.

    The wrapped ``model`` is called as ``model(z, t, cond)`` and is trained to
    predict the velocity ``z1 - x`` along the straight path
    ``z_t = (1 - t) * x + t * z1`` between data ``x`` (t = 0) and Gaussian
    noise ``z1`` (t = 1).

    Args:
        model: callable ``(z, t, cond) -> velocity`` returning a tensor that
            broadcasts against ``z``.
        ln: if True, training timesteps are drawn logit-normally
            (``sigmoid(N(0, 1))``, concentrated around 0.5); otherwise they are
            drawn uniformly from [0, 1).
    """

    def __init__(self, model, ln=True) -> None:
        self.model = model
        self.ln = ln

    def forward(self, x, cond):
        """Compute the rectified-flow regression loss for a batch.

        Args:
            x: clean data batch; dim 0 is the batch dimension.
            cond: conditioning passed to the model unchanged.

        Returns:
            ``(loss, ttloss)`` where ``loss`` is the scalar mean MSE (with
            grad) and ``ttloss`` is a list of ``(t, per_sample_mse)`` pairs of
            Python floats, one per batch element, for loss-vs-timestep logging.
        """
        b = x.size(0)
        if self.ln:
            t = torch.sigmoid(torch.randn((b, 1), device=x.device))
        else:
            t = torch.rand((b, 1), device=x.device)

        # Reshape t to (b, 1, ..., 1) so it broadcasts over every non-batch dim of x.
        texp = t.view([b, *([1] * (x.dim() - 1))])
        z1 = torch.randn_like(x)
        zt = (1 - texp) * x + texp * z1
        # NOTE(review): t is passed as (b, 1) here but as (b,) in sample();
        # confirm the model's timestep embedding handles both shapes identically.
        vtheta = self.model(zt, t, cond)
        batchwise_mse = ((z1 - x - vtheta) ** 2).mean(dim=list(range(1, x.dim())))

        # Detach both sides to plain floats so the log does not hold device tensors.
        tlist = batchwise_mse.detach().cpu().reshape(-1).tolist()
        tvals = t.detach().cpu().reshape(-1).tolist()
        ttloss = list(zip(tvals, tlist))
        return batchwise_mse.mean(), ttloss

    @torch.no_grad()
    def sample(self, z, cond, null_cond=None, sample_steps=50, cfg=2.0):
        """Integrate the learned flow from t = 1 to t = 0 with Euler steps.

        Args:
            z: starting latent batch (noise at t = 1); dim 0 is the batch.
            cond: conditioning for the conditional velocity.
            null_cond: if given, classifier-free guidance is applied as
                ``v = v_uncond + cfg * (v_cond - v_uncond)``.
            sample_steps: number of Euler steps; must be >= 1.
            cfg: guidance scale (ignored when ``null_cond`` is None).

        Returns:
            List of ``sample_steps + 1`` tensors: the initial ``z`` followed by
            the state after each step (the last entry is the final sample).

        Raises:
            ValueError: if ``sample_steps < 1``.
        """
        if sample_steps < 1:
            raise ValueError(f"sample_steps must be >= 1, got {sample_steps}")

        b = z.size(0)
        dt = 1.0 / sample_steps
        images = [z]
        for i in range(sample_steps, 0, -1):
            # Current time t = i / sample_steps, one value per batch element.
            t = torch.full((b,), i / sample_steps, device=z.device)

            vc = self.model(z, t, cond)
            if null_cond is not None:
                vu = self.model(z, t, null_cond)
                vc = vu + cfg * (vc - vu)

            z = z - dt * vc
            images.append(z)
        return images

In [None]:
from digitdreamer import DiT

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook
# still runs on machines without an MPS backend.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model = DiT().to(device)
rf = RF(model)
# NOTE(review): with input_size, torchinfo builds one random tensor and calls
# model(x); RF calls model(z, t, cond) — confirm DiT.forward accepts a lone input.
summary(model, input_size=(128, 8, 2, 2), depth=2, device=device)

In [6]:
# AdamW over every DiT parameter; only the learning rate departs from defaults.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=6e-4)

In [None]:
# Conditioning scheme: class labels are shifted by +1 so that label 0 can act as
# the null (unconditional) label; during training each label is replaced by 0
# with probability LABEL_DROP_PROB so the model also learns the unconditional
# velocity used for classifier-free guidance in rf.sample().
NUM_EPOCHS = 10
LABEL_DROP_PROB = 0.1

val_loss = 0.0

for epoch in range(NUM_EPOCHS):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for x, y in pbar:
        x = x.to(device)
        # Out-of-place shift: `y += 1` would mutate the loader's batch tensor
        # whenever .to(device) is a no-op (e.g. on CPU).
        y = y.to(device) + 1

        # Build the drop mask on the same device as the labels it indexes.
        drop = torch.rand(x.shape[0], device=device)
        y[drop < LABEL_DROP_PROB] = 0

        optimizer.zero_grad()
        loss, _ = rf.forward(x, y)
        loss.backward()
        optimizer.step()

        # val_loss shown here is the previous epoch's value (0 during the first).
        pbar.set_postfix_str(f"loss: {loss.item():.4f}, val_loss: {val_loss:.4f}")

    # Validation: no label dropout. The loss is stochastic because rf.forward
    # samples fresh timesteps and noise on every call.
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device) + 1
            loss, _ = rf.forward(x, y)
            val_loss += loss.item()

    val_loss /= len(test_loader)
    # Report each epoch's validation loss; otherwise the last epoch's is never shown.
    tqdm.write(f"Epoch {epoch+1}: val_loss {val_loss:.4f}")

In [None]:
from digitdreamer import Autoencoder

autoencoder = Autoencoder()
# map_location="cpu": load the weights onto CPU regardless of which device they
# were saved from, then move the whole module with .to(device) below.
autoencoder.encoder.load_state_dict(
    torch.load("../models/encoder.pth", map_location="cpu", weights_only=True),
)
autoencoder.decoder.load_state_dict(
    torch.load("../models/decoder.pth", map_location="cpu", weights_only=True),
)

autoencoder.to(device)
autoencoder.eval()

# generate samples: one GIF frame per sampling step, one grid column per class label
NUM_CLASSES = 10
model.eval()
with torch.no_grad():
    noise = torch.randn(NUM_CLASSES, 8, 2, 2, device=device)
    # Labels are shifted by +1 to match training; 0 is the null (unconditional) label.
    cond = torch.arange(NUM_CLASSES, device=device, dtype=torch.long) + 1
    uncond = torch.zeros(NUM_CLASSES, device=device, dtype=torch.long)

    # Full trajectory: sample_steps + 1 latent batches, from the initial noise to the final sample.
    samples = rf.sample(noise, cond, uncond, sample_steps=12, cfg=2.25)
    samples = torch.cat(samples, dim=0)

    # NOTE(review): to_pil_image expects float images in [0, 1]; confirm the
    # decoder's output range (clamp/rescale here if it is not).
    imgs = autoencoder.decoder(samples).cpu()
    imgs = imgs.view(-1, NUM_CLASSES, 1, 32, 32)

    pil_imgs = [to_pil_image(vutils.make_grid(img, nrow=NUM_CLASSES)) for img in imgs]
    # Play the trajectory forwards, then in reverse, so the looping GIF has no jump.
    pil_imgs = pil_imgs + pil_imgs[::-1]

    # save as gif
    pil_imgs[0].save(
        "../assets/samples.gif",
        save_all=True,
        append_images=pil_imgs[1:],
        duration=100,
        loop=0,
    )

In [12]:
# Persist only the diffusion (DiT) weights as a state_dict, not the pickled module.
torch.save(obj=model.state_dict(), f="../models/diffusion.pth")

In [None]:
# Reload the diffusion (DiT) weights; encoder.pth holds the autoencoder's
# encoder weights and does not match this model's state_dict.
model.load_state_dict(
    torch.load("../models/diffusion.pth", map_location=device, weights_only=True)
)