In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import tqdm as tqdm
from ml_zoo import CelebAHQDataModule, CelebAHQDataModuleConfig, MNISTDataModule, MNISTDataModuleConfig

In [2]:
# Alternative dataset (disabled): CelebA-HQ resized to the same 32x32 resolution.
# dm = CelebAHQDataModule(
#     CelebAHQDataModuleConfig(
#         data_dir="data",
#         batch_size=64,
#         num_workers=2,
#         persistent_workers=True,
#         transforms=[
#             torchvision.transforms.Resize((32, 32)),
#             torchvision.transforms.ToTensor(),
#         ],
#     )
# )

# MNIST, resized to 32x32 so the flattened image matches the model's 32*32 input.
mnist_config = MNISTDataModuleConfig(
    data_dir="data",
    batch_size=64,
    num_workers=2,
    persistent_workers=True,
    transforms=[
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
    ],
)
dm = MNISTDataModule(mnist_config)
dm.prepare_data()
dm.setup()

train_loader, val_loader = dm.train_dataloader(), dm.val_dataloader()

In [3]:
import diffusers

# DDIM noise schedule. num_train_timesteps=1000 is the library default, written out
# because it must agree with the 1000-row time embedding in Model.
scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000)

In [24]:
class Model(nn.Module):
    """Class- and timestep-conditioned MLP over flattened 1x32x32 images.

    The network is trained (see the training cells) to predict the added noise
    with an MSE loss. Its output is the flattened input minus an MLP correction,
    reshaped back to ``(batch, 1, 32, 32)``.

    forward args:
        x: image batch; anything that flattens to ``(batch, 32*32)``.
        y: integer class indices in ``[0, 11)``.
        t: integer diffusion timesteps in ``[0, 1000)``.
    """

    def __init__(self):
        super().__init__()
        # Conditioning tables: one 256-d vector per timestep and per class index.
        # Construction order is kept fixed so seeded initialisation and state_dict
        # keys stay stable.
        self.time_emb = nn.Embedding(1000, 256)
        self.class_emb = nn.Embedding(11, 256)

        # Project the combined conditioning vector to the two hidden widths it is added to.
        self.mlp1 = nn.Linear(256, 512)
        self.mlp2 = nn.Linear(256, 256)

        # Pixel MLP: 1024 -> 512 -> 256 -> 256 -> 512 -> 1024.
        self.fc1 = nn.Linear(32 * 32, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 512)
        self.fc5 = nn.Linear(512, 32 * 32)

    def forward(self, x, y, t):
        cond = self.class_emb(y) + self.time_emb(t)
        cond_wide = F.relu(self.mlp1(cond))    # added after the 512-wide layers
        cond_narrow = F.relu(self.mlp2(cond))  # added after the 256-wide layers

        batch = x.size(0)
        flat = x.view(batch, -1)

        hidden = flat
        for layer, shift in (
            (self.fc1, cond_wide),
            (self.fc2, cond_narrow),
            (self.fc3, cond_narrow),
            (self.fc4, cond_wide),
        ):
            hidden = F.relu(layer(hidden)) + shift

        # Residual form: output = input - learned correction.
        correction = self.fc5(hidden)
        return (flat - correction).view(batch, 1, 32, 32)

In [25]:
# Instantiate the network on the Apple-silicon (MPS) device and report its size.
model = Model()
model = model.to("mps")

num_params = sum(param.numel() for param in model.parameters())
print(f"Model has {num_params:,} parameters")

Model has 1,835,008 parameters



In [26]:
@torch.no_grad()
def generate(steps=50, net=None):
    """Sample a class-conditional image grid with DDIM and save it to ``samples.png``.

    Draws 3 samples for each class index 0..10 (33 images) and writes them as a
    grid with 11 columns. The network's train/eval mode is restored afterwards.

    Args:
        steps: Number of DDIM inference steps.
        net: Noise-prediction network to sample from. Defaults to the global
            ``model``; pass e.g. ``ema.ema_model`` to sample from averaged weights.
    """
    net = model if net is None else net
    device = next(net.parameters()).device
    was_training = net.training
    net.eval()
    try:
        # 11 class indices x 3 repeats -> each grid row holds one sample per class.
        y = torch.arange(11, device=device).repeat(3)
        sample = torch.randn((y.shape[0], 1, 32, 32), device=device) * scheduler.init_noise_sigma

        # Use the scheduler's own timestep sequence so it matches the step size that
        # step() assumes internally (a hand-built arange drifts when 1000 % steps != 0).
        scheduler.set_timesteps(steps)
        for t in scheduler.timesteps:
            t_batch = torch.full((y.shape[0],), int(t), dtype=torch.long, device=device)
            eps = net(sample, y, t_batch)
            # Carry the partially denoised sample into the next step.
            sample = scheduler.step(eps, int(t), sample, eta=0).prev_sample
    finally:
        net.train(was_training)

    torchvision.utils.save_image(sample, "samples.png", nrow=11)

@torch.no_grad()
def validate(net=None, loader=None):
    """Return the noise-prediction MSE averaged over every validation sample.

    Each batch gets fresh random noise and random timesteps, so the result is a
    stochastic estimate that varies slightly between calls. The network's
    train/eval mode is restored afterwards.

    Args:
        net: Network to evaluate; defaults to the global ``model``.
        loader: Iterable of ``(x, y)`` batches; defaults to the global ``val_loader``.

    Returns:
        float: Mean MSE, weighted by batch size so a short final batch does not skew it.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    device = next(net.parameters()).device
    was_training = net.training
    net.eval()

    total = torch.tensor(0.0, device=device)
    count = 0
    try:
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            # NOTE(review): a label shift (y += 1) was tried here and left disabled.

            noise = torch.randn_like(x)
            # Same timestep range the training loop samples from.
            t = torch.randint(0, scheduler.config.num_train_timesteps, (x.shape[0],), device=device)
            noisy_x = scheduler.add_noise(x, noise, t)

            out = net(noisy_x, y, t)
            # mse_loss averages within the batch; re-weight by batch size for a per-sample mean.
            total += F.mse_loss(out, noise) * x.shape[0]
            count += x.shape[0]
    finally:
        net.train(was_training)

    if count == 0:
        raise ValueError("validation loader yielded no samples")
    return total.item() / count

In [27]:
from ema_pytorch import EMA

# AdamW over the live model's parameters.
optimizer = optim.AdamW(model.parameters(), lr=6e-4)

# Moving-average copy of the model weights (ema_pytorch); advanced by calling
# ema.update() after optimizer steps in the training cells.
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=1)

In [29]:
# Single-batch run: one batch from train_loader, tiled REPEATS times, trained on
# repeatedly (looks like an overfitting sanity check of the pipeline).
REPEATS = 64
NUM_ITERS = 240_000
SAMPLE_EVERY = 500

device = next(model.parameters()).device
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)

x, y = x.repeat(REPEATS, 1, 1, 1), y.repeat(REPEATS)
print(x.shape, y.shape)

model.train()
pbar = tqdm.trange(NUM_ITERS)
for i in pbar:
    noise = torch.randn_like(x)
    # Same timestep range as the scheduler (and the epoch training loop).
    t = torch.randint(0, scheduler.config.num_train_timesteps, (x.shape[0],), device=x.device)

    noisy_x = scheduler.add_noise(x, noise, t)

    optimizer.zero_grad()
    out = model(noisy_x, y, t)
    loss = F.mse_loss(out, noise)
    loss.backward()
    optimizer.step()
    ema.update()

    pbar.set_postfix(loss=loss.item())

    if i % SAMPLE_EVERY == 0:
        generate()
        # generate() may leave the model in eval mode; switch back before training continues.
        model.train()


torch.Size([4096, 1, 32, 32]) torch.Size([4096])


  0%|          | 821/240000 [00:28<2:18:47, 28.72it/s, loss=0.112]


KeyboardInterrupt: 

In [None]:
NUM_EPOCHS = 10
device = next(model.parameters()).device

# Baseline before training; afterwards val_loss is refreshed at the end of every
# epoch, so the weights from the final epoch are evaluated too.
val_loss = validate()
for epoch in range(NUM_EPOCHS):
    generate()

    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, y in pbar:
        x, y = x.to(device), y.to(device)

        # Label dropout for classifier-free guidance (disabled).
        # NOTE: F.dropout(y.float(), 0.1).long() is not a label dropout. It maps dropped
        # labels to 0 (a real class) and rescales kept labels by 1/0.9 before truncation.
        # To enable, swap dropped labels for a dedicated null index instead. That is
        # presumably 10, the class_emb row MNIST labels never use; confirm.
        # y = torch.where(torch.rand(y.shape, device=device) < 0.1, torch.full_like(y, 10), y)
        noise = torch.randn_like(x)

        t = torch.randint(0, scheduler.config.num_train_timesteps, (x.shape[0],), device=x.device)
        noisy_x = scheduler.add_noise(x, noise, t)

        optimizer.zero_grad()
        out = model(noisy_x, y, t)
        loss = F.mse_loss(out, noise)
        loss.backward()
        optimizer.step()

        # ema.update()

        # val_loss here reflects the weights from before this epoch started.
        pbar.set_postfix(loss=loss.item(), val_loss=val_loss)

    val_loss = validate()
    print(f"Epoch {epoch}: val_loss={val_loss:.4f}")

In [None]:
# Final sample grid from the trained weights (written to samples.png).
generate(steps=50);