In [53]:
import clip
import math
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data


def train(model, train_loader, z_loader, optimizer, scheduler, model_depth):
    model.train()
    train_losses = []
    for (x, y), z in tqdm(zip(train_loader, z_loader), total=len(train_loader)):
        # batch = [b.cuda() for b in batch]
        # print(batch)
        # print(*batch)
        loss = model.loss(x, y, z, model_depth)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_losses.append(loss.item())
    return train_losses


@torch.no_grad()
def eval_loss(model, data_loader, model_depth):
    model.eval()
    total_loss, total = 0, 0
    for batch in data_loader:
        batch = [b.cuda() for b in batch]
        loss = model.loss(*batch, model_depth)
        total_loss += loss.item() * batch[0].shape[0]
        total += batch[0].shape[0]
    avg_loss = total_loss / total
    return avg_loss


def get_lr(step, total_steps, warmup_steps, use_cos_decay):
    if step < warmup_steps:
        mul = (step + 1) / warmup_steps
        return mul
    else:
        if use_cos_decay:
            progress = (step - warmup_steps) / (total_steps - warmup_steps)
            return (1 + math.cos(math.pi * progress)) / 2
        else:
            return 1


def train_epochs(model, train_loader, test_loader, z_loader, train_args):
    epochs, lr = train_args["epochs"], train_args["lr"]
    warmup_steps = train_args.get("warmup", 0)
    use_cos_decay = train_args.get("use_cos_decay", False)
    optimizer = optim.Adam(model.parameters(), lr)
    total_steps = epochs * len(train_loader)
    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: get_lr(step, total_steps, warmup_steps, use_cos_decay),
    )

    train_losses = []
    # test_losses = [eval_loss(model, test_loader)]
    for epoch in tqdm(list(range(epochs))):
        train_loss = train(
            model,
            train_loader,
            z_loader,
            optimizer,
            scheduler,
            train_args["model_depth"],
        )
        train_losses.extend(train_loss)
        # test_loss = eval_loss(model, test_loader)
        test_losses.append(test_loss)
        print(
            f"Epoch {epoch}, Test loss {test_loss:.4f}, Train loss {np.mean(train_loss):.4f}"
        )

    return np.array(train_losses)  # , np.array(test_losses)


class Diffusion:
    def __init__(self, model, data_shape, encode_fn=None, decode_fn=None):
        self.model = model
        self.data_shape = data_shape
        self.encode_fn = encode_fn
        self.decode_fn = decode_fn
        self.clip, self.transform = clip.load("ViT-B/32", device="cuda")

    def _get_alpha_sigma(self, t):
        return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)

    def _expand(self, t):
        for _ in range(len(self.data_shape)):
            t = t[..., None]
        return t

    def _noise(self, x, eps=None, t=None):
        if t is None:
            t = torch.rand(x.shape[0], dtype=torch.float32, device=x.device)
        if eps is None:
            eps = torch.randn_like(x)
        alpha_t, sigma_t = self._get_alpha_sigma(self._expand(t))
        x_t = alpha_t * x + sigma_t * eps
        return x_t, eps, t

    def _x_hat(self, x_t, eps_hat, t):
        alpha_t, sigma_t = self._get_alpha_sigma(self._expand(t))
        return (x_t - sigma_t * eps_hat) / alpha_t

    def loss(self, x, y=None, z=None, model_depth=0):
        z = x if model_depth == 0 else z
        # if self.encode_fn is not None:
        #     x = self.encode_fn(x)
        print("in loss")
        print(x)
        print(y)
        print(z)
        print(model_depth)
        print(x.shape)
        print(y.shape)
        y = self.clip.encode_text(y)
        x_t, eps, t = self._noise(x)
        if y is not None:
            eps_hat = self.model.forward_step(x_t, y, t, model_depth)
        else:
            eps_hat = self.model.forward_step(x_t, t, model_depth)
        return torch.mean((eps_hat - eps) ** 2)

    @torch.no_grad()
    def sample(self, n, num_steps, clip_denoised=False, model_fn=None, cfg_val=None):
        model_fn = model_fn or self.model

        ts = np.linspace(1 - 1e-4, 1e-4, num_steps + 1, dtype=np.float32)
        x = torch.randn(n, *self.data_shape, dtype=torch.float32).cuda()
        for i in range(num_steps):
            t_cur = torch.full((n,), ts[i], dtype=torch.float32).cuda()
            t_next = torch.full((n,), ts[i + 1], dtype=torch.float32).cuda()

            alpha_cur, sigma_cur = self._get_alpha_sigma(self._expand(t_cur))
            alpha_next, sigma_next = self._get_alpha_sigma(self._expand(t_next))
            ddim_sigma = (sigma_next / sigma_cur) * torch.sqrt(
                1 - alpha_cur**2 / alpha_next**2
            )

            if cfg_val is None:
                eps_hat = model_fn(x, t_cur)
            else:
                eps_hat_cond = model_fn(x, t_cur)
                eps_hat_uncond = model_fn(x, t_cur, dropout_cond=True)
                eps_hat = eps_hat_uncond + cfg_val * (eps_hat_cond - eps_hat_uncond)

            x_hat = self._x_hat(x, eps_hat, t_cur)
            if clip_denoised:
                x_hat = torch.clamp(x_hat, -1, 1)
            x = (
                alpha_next * x_hat
                + torch.sqrt((sigma_next**2 - ddim_sigma**2).clamp(min=0)) * eps_hat
                + ddim_sigma * torch.randn_like(eps_hat)
            )
        # if self.decode_fn is not None:
        #     x = self.decode_fn(x)
        return x

    def __getattr__(self, name):
        if name in ["train", "eval", "parameters", "state_dict", "load_state_dict"]:
            return getattr(self.model, name)
        return self.__getattribute__(name)


class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, n_hidden_layers, timestep_dim=1):
        super().__init__()
        self.timestep_dim = timestep_dim
        prev_dim = input_dim + timestep_dim
        net = []
        dims = [hidden_dim] * n_hidden_layers + [input_dim]
        for i, dim in enumerate(dims):
            net.append(nn.Linear(prev_dim, dim))
            if i < len(dims) - 1:
                net.append(nn.ReLU())
            prev_dim = dim
        self.net = nn.Sequential(*net)

    def forward(self, x, t):
        x = torch.cat([x, t[:, None]], dim=1)
        return self.net(x)

In [54]:
# !if [ -d deepul ]; then rm -Rf deepul; fi
# !git clone https://github.com/rll/deepul.git
# !pip install ./deepul
# !pip install scikit-learn

In [55]:
from deepul.hw4_helper import *
import warnings

warnings.filterwarnings("ignore")

In [56]:
def timestep_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(0,
                                             half, dtype=torch.float32) / half
    ).cuda()
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], axis=-1)
    if dim % 2:
        embedding = torch.cat(
            [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding


class Attention(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.wq = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wk = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wv = nn.Linear(hidden_size, hidden_size, bias=False)
        self.wo = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q, k, v = map(lambda x: x.view(
            *x.shape[:-1], self.num_heads, -1), (q, k, v))
        attn_weights = torch.einsum(
            "bqhd,bkhd->bhqk", q, k) * q.shape[-1] ** -0.5
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_outputs = torch.einsum("bhqk,bkhd->bqhd", attn_weights, v)
        attn_outputs = attn_outputs.reshape(*attn_outputs.shape[:-2], -1)
        return self.wo(attn_outputs)


class MLP(nn.Module):
    def __init__(self, hidden_size, expand=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * expand),
            nn.SiLU(),
            nn.Linear(hidden_size * expand, hidden_size),
        )

    def forward(self, x):
        return self.net(x)


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(
        embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


def get_2d_sincos_pos_embed(embed_dim, grid_size):
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed


class DiTBlock(nn.Module):
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attention = Attention(hidden_size, num_heads)
        self.mlp = MLP(hidden_size)
        self.attention_norm = nn.LayerNorm(
            hidden_size, elementwise_affine=False)
        self.mlp_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size)
        )

    def forward(self, x, c):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN_modulation(c).chunk(6, dim=1)
        )
        x = x + gate_msa.unsqueeze(1) * self.attention(
            modulate(self.attention_norm(x), shift_msa, scale_msa)
        )
        x = x + gate_mlp.unsqueeze(1) * self.mlp(
            modulate(self.mlp_norm(x), shift_mlp, scale_mlp)
        )
        return x


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DiT(nn.Module):
    def __init__(
        self,
        input_shape,
        patch_size,
        hidden_size,
        num_heads,
        num_layers,
        num_classes=10,
        frequency_embedding_size=64,
        cfg_dropout_prob=0.1,
    ):
        super().__init__()
        self.cfg_dropout_prob = cfg_dropout_prob
        self.frequency_embedding_size = frequency_embedding_size
        self.num_classes = num_classes
        self.input_shape = input_shape
        self.patch_size = patch_size

        self.time_embedding = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        # batch_size, num_blocks, hidden_size
        self.label_embedder = nn.Embedding(num_classes + 1, hidden_size)

        C, H, W = input_shape
        n_patches = (H // patch_size) * (W // patch_size)
        self.patchify = nn.Conv2d(
            C, hidden_size, patch_size, stride=patch_size)
        self.register_buffer(
            "pos_embed",
            torch.FloatTensor(get_2d_sincos_pos_embed(
                hidden_size, H // patch_size)),
        )
        self.blocks = nn.ModuleList(
            [DiTBlock(hidden_size, num_heads) for _ in range(num_layers)]
        )
        self.out_layer = FinalLayer(hidden_size, patch_size, C)

    def unpatchify(self, x):
        C, H, W = self.input_shape
        P = self.patch_size
        x = x.view(x.shape[0], H // P, W // P, P, P, C)
        x.shape[0], H // P, P, W // P, P, C
        x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
        x = x.view(x.shape[0], C, H, W)
        return x

    def forward(self, x, y, t):
        # x: BCHW
        x = self.patchify(x).movedim(1, -1)
        x = x.reshape(x.shape[0], -1, x.shape[-1])
        x = x + self.pos_embed

        t = self.time_embedding(timestep_embedding(
            t, self.frequency_embedding_size))
        if self.training:
            drop_ids = torch.rand(
                y.shape[0], device=y.device) < self.cfg_dropout_prob
            y = torch.where(drop_ids, self.num_classes, y)
        y = self.label_embedder(y)
        c = t + y
        for block in self.blocks:
            x = block(x, c)
        x = self.out_layer(x, c)
        return self.unpatchify(x)

In [57]:
import clip
import os
import sys
import torch

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../../data")))


def load_q3_data():
    import multimnist

    ds = 1000
    images_train, text_descs_train = multimnist.generate(ds)
    images_test, text_descs_test = multimnist.generate(ds)
    text_descs_train = torch.cat(
        [clip.tokenize(txt).to("cuda") for txt in text_descs_train]
    )
    text_descs_test = torch.cat(
        [clip.tokenize(txt).to("cuda") for txt in text_descs_test]
    )
    train_data = {"images": images_train, "texts": text_descs_train}
    test_data = {"images": images_test, "texts": text_descs_test}
    return train_data, test_data
    # return dataset[:500], dataset[500:]
    # train_data = torchvision.datasets.CIFAR10(
    #     "./data", transform=torchvision.transforms.ToTensor(), download=True, train=True
    # )
    # test_data = torchvision.datasets.CIFAR10(
    #     "./data",
    #     transform=torchvision.transforms.ToTensor(),
    #     download=True,
    #     train=False,
    # )
    # return train_data, test_data


def show_samples(
    samples: np.ndarray, fname: str = None, nrow: int = 10, title: str = "Samples"
):
    import torch
    from torchvision.utils import make_grid

    samples = (torch.FloatTensor(samples) / 255).permute(0, 3, 1, 2)
    grid_img = make_grid(samples, nrow=nrow)
    plt.figure()
    plt.title(title)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.axis("off")

    if fname is not None:
        savefig(fname)
    else:
        plt.show()

In [58]:
# Hyperparameters.
seed = 0
batch_size = 1  # wondering trained on some fixed slot makes difference
num_slots = 7
num_iterations = 3
hid_dim = 64
resolution = (128, 128)
online = False
results_dir = "/shared/rzhang/slot_att/results/full-token-compressor"
model_name = "dummy"

In [59]:
from model import *
import wandb
from glob import glob


def q3_b(train_data, train_texts, test_data, test_texts, vae):
    """
    train_data: A (50000, 32, 32, 3) numpy array of images in [0, 1]
    train_labels: A (50000,) numpy array of class labels
    test_data: A (10000, 32, 32, 3) numpy array of images in [0, 1]
    test_labels: A (10000,) numpy array of class labels
    vae: a pretrained VAE

    Returns
    - a (# of training iterations,) numpy array of train losses evaluated every minibatch
    - a (# of num_epochs + 1,) numpy array of test losses evaluated at the start of training and the end of every epoch
    - a numpy array of size (10, 10, 32, 32, 3) of samples in [0, 1] drawn from your model.
      The array represents a 10 x 10 grid of generated samples. Each row represents 10 samples generated
      for a specific class (i.e. row 0 is class 0, row 1 class 1, ...). Use 512 diffusion timesteps
    """
    # train_data = 2 * np.transpose(train_data, (0, 3, 1, 2)) - 1
    # test_data = 2 * np.transpose(test_data, (0, 3, 1, 2)) - 1
    image_data = torch.FloatTensor(train_data)
    train_data = data.TensorDataset(torch.FloatTensor(train_data), train_texts)
    test_data = data.TensorDataset(torch.FloatTensor(test_data), test_texts)
    train_loader = data.DataLoader(train_data, batch_size=256, shuffle=True)
    test_loader = data.DataLoader(train_data, batch_size=256)
    z_dataloader = data.DataLoader(
        torch.ones(len(train_data)).to("cpu"), batch_size=batch_size, shuffle=True
    )

    scale_factor = 1.3101

    def encode_fn(x):
        z = vae.encode(x)
        z /= scale_factor
        return z

    def decode_fn(z):
        z *= scale_factor
        x = vae.decode(z)
        return x

    resolution = (128, 128)
    model_slotattention = SlotAttentionAutoEncoder(
        resolution,
        num_slots,
        num_iterations,
        hid_dim,
        cnn_depth=4,
        use_trfmr=False,
        use_transformer_encoder=False,
        use_transformer_decoder=False,
    ).to(device)
    model_direct = SlotAttentionCompressionAutoencoderDirect(
        model_slotattention, num_slots, hid_dim
    ).to(device)
    model = SlotAttentionCompressionDiffusion(model_direct, hid_dim).to(device)
    image_shape = (3, 128, 128)
    # DiT((4, 8, 8), patch_size=2, hidden_size=512, num_heads=8, num_layers=12).cuda()
    model = Diffusion(model, image_shape,
                      encode_fn=encode_fn, decode_fn=decode_fn)

    for model_depth in range(num_slots):
        # Make results folder (holds all experiment subfolders)
        if online:
            os.makedirs(results_dir, exist_ok=True)
            exp_index = len(glob(f"{results_dir}/*"))
            model_filename = f"{exp_index:03d}-{model_name}-slots{num_slots}-layer{num_slots-model_depth}-"
            wandb.init(
                dir=os.path.abspath(results_dir),
                project=f"slot_att_pretrained",
                name=model_filename,
                job_type="train",
                mode="online",
            )
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                },
                results_dir + f"/{model_filename}.ckpt",
            )

        train_losses = train_epochs(
            model,
            train_loader,
            test_loader,  # TODO: use z_dataloader,
            z_dataloader,
            dict(
                epochs=20,
                lr=1e-3,
                warmup=100,
                use_cos_decay=True,
                model_depth=model_depth,
            ),
        )
        test_losses = None

        # we want to replace the dataloader with the compressed data
        z_new = []
        with torch.no_grad():
            model.eval()
            for (image, text), z in tqdm(
                zip(train_loader, z_dataloader), total=len(z_dataloader)
            ):
                z = x["image"].to(device) if model_depth == 0 else z.to(device)
                z_fwd = model.get_compressed(z, model_depth).detach().clone()
                z_new.append(z_fwd)
            z_new = torch.cat(z_new, dim=0).cpu()

        # train_data = data.TensorDataset(z_new, train_texts, image_data)
        z_dataloader = torch.utils.data.DataLoader(
            z_new, batch_size=batch_size, shuffle=True
        )

        if online:
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                },
                opt.results_dir + f"/{model_filename}.ckpt",
            )
            wandb.finish()

    # torch.save(model.state_dict(), "q3_model.pt")

    text = "dummy"
    text_emb = clip.tokenize(text).cuda()

    # def model_fn(x, t):
    #     return model.model(x, text_emb, t)

    # samples = model.sample(
    #     labels.shape[0], 512, clip_denoised=False, model_fn=model_fn)
    samples = samples.movedim(1, -1).cpu().numpy()
    samples = samples * 0.5 + 0.5
    samples = samples.reshape(10, 10, *samples.shape[1:])
    return train_losses, test_losses, samples


# import os
# def get_data_dir(hw_number: int):
#     return os.path.join("deepul", "homeworks", f"hw{hw_number}", "data")


# def load_pretrain_vae():
#     data_dir = get_data_dir(4)
#     vae = VAE()
#     vae.load_state_dict(torch.load(os.path.join(data_dir, f"vae_cifar10.pth")))
#     vae.eval()
#     return vae.cuda()

In [60]:
from torchvision.utils import make_grid
import torch

# opt = parse_args()
train_data, test_data = load_q3_data()
train_images = train_data["images"] / 255.0
test_images = test_data["images"] / 255.0
vae = load_pretrain_vae()
train_losses, test_losses, samples = q3_b(
    train_images, train_data["texts"], test_images, test_data["texts"], vae
)

print(f"Final Test Loss: {test_losses[-1]:.4f}")
save_training_plot(
    train_losses, test_losses, "Q3(b) Train Plot", "results/q3_b_train_plot.png"
)

samples = samples.reshape(-1, *samples.shape[2:])
show_samples(
    samples * 255.0,
    fname=f"results/q3_b_samples.png",
    title=f"Q3(b) CIFAR-10 generated samples",
)

Creating MultiMNIST dataset...


 73%|███████▎  | 733/1000 [00:00<00:00, 1251.19it/s]

100%|██████████| 1000/1000 [00:00<00:00, 1231.25it/s]


Creating MultiMNIST dataset...


100%|██████████| 1000/1000 [00:00<00:00, 1233.99it/s]


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

in loss
tensor([[[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         ...,

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]],

         [[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.]]],


        [[[-1., -1., -1.,  ..., -1., -1., -1.],
          [-1., -1., -1.,  ..., -1., -1., -1.],
   

TypeError: forward_step() takes 3 positional arguments but 5 were given

In [None]:
# q3b_save_results(q3_b)

In [None]:
import os

os._exit(00)

: 