In [1]:
# To use the files in the parent directory run this cell
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
from src.models.vqvae import VQVAE
from src.models.var import VAR
from src.datasets.hugging_face_dataset import HuggingFaceDataset

In [3]:
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import gc


model_params = {
    "mnist": {
        "VQVAE_DIM": 64,
        "VOCAB_SIZE": 32,
        "PATCH_SIZES": [1, 2, 3, 4, 8],
        "VAR_DIM": 64,
        "N_HEADS": 4,
        "N_LAYERS": 6,
        "channels": 1,
    },
    "cifar": {
        "VQVAE_DIM": 512,
        "VOCAB_SIZE": 2048,
        "PATCH_SIZES": [1, 2, 3, 4, 6, 8],
        "VAR_DIM": 512,
        "N_HEADS": 16,
        "N_LAYERS": 12,
        "channels": 3,
    },
    "small": {
        "VQVAE_DIM": 512,
        "VOCAB_SIZE": 1024,
        "PATCH_SIZES": [1, 2, 3, 4, 6, 8],
        "VAR_DIM": 512,
        "N_HEADS": 16,
        "N_LAYERS": 16,
        "channels": 3,
    },
    "medium": {
        "VQVAE_DIM": 512,
        "VOCAB_SIZE": 2048,
        "PATCH_SIZES": [1, 2, 3, 4, 6, 8],
        "VAR_DIM": 512,
        "N_HEADS": 32,
        "N_LAYERS": 20,
        "channels": 3,
    },
    "large": {
        "VQVAE_DIM": 512,
        "VOCAB_SIZE": 4096,
        "PATCH_SIZES": [1, 2, 3, 4, 6, 8],
        "VAR_DIM": 512,
        "N_HEADS": 64,
        "N_LAYERS": 24,
        "channels": 3,
    },
}

training_params = {
    "mnist": {
        "VQVAE": {
            "batch_size": 2048,
            "lr": 3e-4,
            "epochs": 40,
        },
        "VAR": {
            "batch_size": 1024,
            "lr": 1e-3,
            "epochs": 100,
        },
    },
    "cifar": {
        "VQVAE": {
            "batch_size": 512,
            "lr": 3e-4,
            "epochs": 100,
        },
        "VAR": {
            "batch_size": 64,
            "lr": 1e-4,
            "epochs": 100,
        },
    },
}


def get_data(batch_size=1024, dataset="mnist"):
    if dataset == "cifar":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.RandomCrop(32),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        train_ds = datasets.CIFAR10(
            root="./data", train=True, download=True, transform=transform
        )
        test_ds = datasets.CIFAR10(
            root="./data", train=False, download=True, transform=transform
        )
    elif dataset == "mnist":
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Pad(2),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        train_ds = datasets.MNIST(
            root="./data", train=True, download=True, transform=transform
        )
        test_ds = datasets.MNIST(
            root="./data", train=False, download=True, transform=transform
        )
    else:
        transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((32, 32)),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )
        # Use HuggingFace datasets
        train_ds = HuggingFaceDataset(dataset_path=dataset, split="train", transform=transform)
        test_ds = HuggingFaceDataset(dataset_path=dataset, split="val", transform=transform)


    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=False, drop_last=False
    )
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False, drop_last=True
    )

    print(len(train_loader), len(test_loader))

    return train_loader, test_loader


def plot_images(pred, original=None):
    n = pred.size(0)
    pred = pred * 0.5 + 0.5
    pred = pred.clamp(0, 1)
    img = pred.cpu().detach()

    if original is not None:
        original = original * 0.5 + 0.5
        original = original.clamp(0, 1)
        original = original.cpu().detach()
        img = torch.cat([original, img], dim=0)

    img_grid = make_grid(img, nrow=n)
    img_grid = img_grid.permute(1, 2, 0).numpy()
    img_grid = (img_grid * 255).astype("uint8")
    plt.imshow(img_grid)
    plt.axis("off")


In [4]:
dataset = "zzsi/afhq64_16k"
model_params = model_params["small"]
training_params = training_params["cifar"]

# Training VQVAE

In [4]:
from tqdm import tqdm

print("=" * 10 + "Training VQVAE" + "=" * 10)
vq_model = VQVAE(
    model_params["VQVAE_DIM"],
    model_params["VOCAB_SIZE"],
    model_params["PATCH_SIZES"],
    num_channels=model_params["channels"],
)
optimizer = torch.optim.AdamW(
    vq_model.parameters(), lr=training_params["VQVAE"]["lr"]
)

train_loader, test_loader = get_data(
    batch_size=training_params["VQVAE"]["batch_size"], dataset=dataset
)
vq_model = vq_model.to("cuda")

# All epochs
for epoch in tqdm(range(training_params["VQVAE"]["epochs"])):
    epoch_loss = 0
    epoch_recon_loss = 0
    # Single epochs
    for i, (x, c) in enumerate(train_loader):
        x, c = x.cuda(), c.cuda()
        optimizer.zero_grad()
        xhat, r_maps, idxs, scales, q_loss = vq_model(x)
        recon_loss = F.mse_loss(xhat, x)
        loss = recon_loss + q_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_recon_loss += recon_loss.item()

    epoch_loss /= len(train_loader)
    epoch_recon_loss /= len(train_loader)
    print(f"Epoch: {epoch}, Loss: {epoch_loss}, Recon Loss: {epoch_recon_loss}")

    if epoch % 5 == 0:
        with torch.no_grad():
            total_loss = 0
            total_recon_loss = 0
            for i, (x, c) in enumerate(test_loader):
                x, c = x.cuda(), c.cuda()
                xhat, r_maps, idxs, scales, q_loss = vq_model(x)
                recon_loss = F.mse_loss(xhat, x)
                loss = recon_loss + q_loss
                total_loss += loss.item()
                total_recon_loss += recon_loss.item()

            total_loss /= len(test_loader)
            total_recon_loss /= len(test_loader)

            print(
                f"Epoch: {epoch}, Test Loss: {total_loss}, Test Recon Loss: {total_recon_loss}"
            )

            x = x[:10, :].cuda()
            x_hat = vq_model(x)[0]

            plot_images(pred=x_hat, original=x)
            plt.savefig(f"vqvae_{epoch}.png")
            plt.close()

torch.save(vq_model.state_dict(), "vqvae.pth")
del vq_model, optimizer, x, x_hat, train_loader, test_loader
gc.collect()
torch.cuda.empty_cache()


29 2


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch: 0, Loss: 4.866057824471901, Recon Loss: 0.2037651019877401


  2%|▏         | 1/50 [02:15<1:50:34, 135.40s/it]

Epoch: 0, Test Loss: 3.8543347120285034, Test Recon Loss: 0.2917487621307373


  4%|▍         | 2/50 [04:28<1:47:02, 133.80s/it]

Epoch: 1, Loss: 1.544010181365342, Recon Loss: 0.16000377489575024


  6%|▌         | 3/50 [06:40<1:44:22, 133.25s/it]

Epoch: 2, Loss: 0.2832418890862629, Recon Loss: 0.11356585344363904


  8%|▊         | 4/50 [08:53<1:42:03, 133.12s/it]

Epoch: 3, Loss: 0.3485085064994878, Recon Loss: 0.08586812404723003


 10%|█         | 5/50 [11:06<1:39:44, 132.98s/it]

Epoch: 4, Loss: 0.18769233982110844, Recon Loss: 0.0627873841801594
Epoch: 5, Loss: 0.14358470614614158, Recon Loss: 0.055644100596164835


 12%|█▏        | 6/50 [13:22<1:38:16, 134.01s/it]

Epoch: 5, Test Loss: 0.16564584523439407, Test Recon Loss: 0.055470388382673264


 14%|█▍        | 7/50 [15:35<1:35:49, 133.72s/it]

Epoch: 6, Loss: 0.13170330534721242, Recon Loss: 0.05064002925465847


 16%|█▌        | 8/50 [17:48<1:33:27, 133.52s/it]

Epoch: 7, Loss: 0.10611220963042357, Recon Loss: 0.046990403070532046


 18%|█▊        | 9/50 [20:01<1:31:05, 133.30s/it]

Epoch: 8, Loss: 0.0958550356585404, Recon Loss: 0.04487122261318667


 20%|██        | 10/50 [22:14<1:28:49, 133.23s/it]

Epoch: 9, Loss: 0.09425708677234321, Recon Loss: 0.04185084870149349
Epoch: 10, Loss: 0.09060800178297634, Recon Loss: 0.04025830020164621


 22%|██▏       | 11/50 [24:30<1:27:04, 133.96s/it]

Epoch: 10, Test Loss: 0.10135117545723915, Test Recon Loss: 0.04018284194171429


 24%|██▍       | 12/50 [26:43<1:24:41, 133.73s/it]

Epoch: 11, Loss: 0.0886174812912941, Recon Loss: 0.03811735891062638


 26%|██▌       | 13/50 [28:56<1:22:19, 133.49s/it]

Epoch: 12, Loss: 0.08969448326990523, Recon Loss: 0.038376129264461586


 28%|██▊       | 14/50 [31:09<1:20:07, 133.53s/it]

Epoch: 13, Loss: 0.08487800277512648, Recon Loss: 0.03569597988550005


 30%|███       | 15/50 [33:22<1:17:49, 133.42s/it]

Epoch: 14, Loss: 0.08499745279550552, Recon Loss: 0.040501814464042926
Epoch: 15, Loss: 0.08797449053361497, Recon Loss: 0.03595935752422645


 32%|███▏      | 16/50 [35:39<1:16:05, 134.27s/it]

Epoch: 15, Test Loss: 0.09106981009244919, Test Recon Loss: 0.035018209367990494


 34%|███▍      | 17/50 [37:52<1:13:42, 134.01s/it]

Epoch: 16, Loss: 0.07861711958359027, Recon Loss: 0.032968056677230476


 36%|███▌      | 18/50 [40:06<1:11:24, 133.90s/it]

Epoch: 17, Loss: 0.07785250593362184, Recon Loss: 0.03151722939620758


 38%|███▊      | 19/50 [42:19<1:09:02, 133.62s/it]

Epoch: 18, Loss: 0.07636784325385916, Recon Loss: 0.030136421065906


 40%|████      | 20/50 [44:32<1:06:43, 133.45s/it]

Epoch: 19, Loss: 0.07501227524259994, Recon Loss: 0.029429074248363232
Epoch: 20, Loss: 0.07364063093374515, Recon Loss: 0.027916064421678412


 42%|████▏     | 21/50 [46:48<1:04:51, 134.20s/it]

Epoch: 20, Test Loss: 0.08149045333266258, Test Recon Loss: 0.02744843065738678


 44%|████▍     | 22/50 [49:01<1:02:30, 133.96s/it]

Epoch: 21, Loss: 0.07409323829001394, Recon Loss: 0.03050791658461094


 46%|████▌     | 23/50 [51:14<1:00:08, 133.66s/it]

Epoch: 22, Loss: 0.07717355443485852, Recon Loss: 0.027840690045007343


 48%|████▊     | 24/50 [53:28<57:53, 133.61s/it]  

Epoch: 23, Loss: 0.07298342934970198, Recon Loss: 0.026266125062930173


 50%|█████     | 25/50 [55:41<55:36, 133.44s/it]

Epoch: 24, Loss: 0.07263677862697634, Recon Loss: 0.02677988993196652
Epoch: 25, Loss: 0.07230470689206288, Recon Loss: 0.025101954310104764


 52%|█████▏    | 26/50 [57:57<53:41, 134.22s/it]

Epoch: 25, Test Loss: 0.07872357964515686, Test Recon Loss: 0.02432861551642418


 54%|█████▍    | 27/50 [1:00:10<51:20, 133.91s/it]

Epoch: 26, Loss: 0.07201809281932897, Recon Loss: 0.02500430719348891


 56%|█████▌    | 28/50 [1:02:23<49:02, 133.77s/it]

Epoch: 27, Loss: 0.07150217312677153, Recon Loss: 0.0235540849509938


 58%|█████▊    | 29/50 [1:04:36<46:45, 133.58s/it]

Epoch: 28, Loss: 0.07005925116867855, Recon Loss: 0.02338059856716929


 60%|██████    | 30/50 [1:06:50<44:29, 133.45s/it]

Epoch: 29, Loss: 0.06976968615219512, Recon Loss: 0.022256459848120295
Epoch: 30, Loss: 0.06969840387845862, Recon Loss: 0.022317839657952046


 62%|██████▏   | 31/50 [1:09:06<42:29, 134.19s/it]

Epoch: 30, Test Loss: 0.07672495022416115, Test Recon Loss: 0.0212257606908679


 64%|██████▍   | 32/50 [1:11:19<40:10, 133.90s/it]

Epoch: 31, Loss: 0.06945623264744363, Recon Loss: 0.022399662126754892


 66%|██████▌   | 33/50 [1:13:32<37:52, 133.66s/it]

Epoch: 32, Loss: 0.06941864819362245, Recon Loss: 0.020924362257636827


 68%|██████▊   | 34/50 [1:15:45<35:37, 133.62s/it]

Epoch: 33, Loss: 0.06871920323063588, Recon Loss: 0.020831638240608675


 70%|███████   | 35/50 [1:17:58<33:21, 133.42s/it]

Epoch: 34, Loss: 0.06914223101118515, Recon Loss: 0.02090210313427037
Epoch: 35, Loss: 0.06984086899921813, Recon Loss: 0.020252555544520247


 72%|███████▏  | 36/50 [1:20:14<31:18, 134.21s/it]

Epoch: 35, Test Loss: 0.07590556517243385, Test Recon Loss: 0.01940724067389965


 74%|███████▍  | 37/50 [1:22:28<29:00, 133.90s/it]

Epoch: 36, Loss: 0.06964375065832303, Recon Loss: 0.02327565116615131


 76%|███████▌  | 38/50 [1:24:41<26:43, 133.62s/it]

Epoch: 37, Loss: 0.0730013571165759, Recon Loss: 0.020531707860786338


 78%|███████▊  | 39/50 [1:26:54<24:28, 133.48s/it]

Epoch: 38, Loss: 0.06979389989684368, Recon Loss: 0.019396698680417292


 80%|████████  | 40/50 [1:29:07<22:14, 133.43s/it]

Epoch: 39, Loss: 0.06858309563891642, Recon Loss: 0.019280468916584706
Epoch: 40, Loss: 0.06866673482903118, Recon Loss: 0.019353121262172174


 82%|████████▏ | 41/50 [1:31:23<20:07, 134.16s/it]

Epoch: 40, Test Loss: 0.0747237391769886, Test Recon Loss: 0.01894175447523594


 84%|████████▍ | 42/50 [1:33:36<17:51, 133.96s/it]

Epoch: 41, Loss: 0.06869632853516217, Recon Loss: 0.020744971160230965


 86%|████████▌ | 43/50 [1:35:49<15:35, 133.67s/it]

Epoch: 42, Loss: 0.07161692696912535, Recon Loss: 0.01933481345145867


 88%|████████▊ | 44/50 [1:38:03<13:21, 133.57s/it]

Epoch: 43, Loss: 0.06914105831549086, Recon Loss: 0.01909165110053687


 90%|█████████ | 45/50 [1:40:16<11:07, 133.43s/it]

Epoch: 44, Loss: 0.06926265885603838, Recon Loss: 0.01853909157216549
Epoch: 45, Loss: 0.06853570442261367, Recon Loss: 0.018559810994514103


 92%|█████████▏| 46/50 [1:42:32<08:57, 134.40s/it]

Epoch: 45, Test Loss: 0.07530079782009125, Test Recon Loss: 0.01822421234101057


 94%|█████████▍| 47/50 [1:44:45<06:41, 134.00s/it]

Epoch: 46, Loss: 0.06955526486552994, Recon Loss: 0.020134666356547124


 96%|█████████▌| 48/50 [1:46:59<04:27, 133.76s/it]

Epoch: 47, Loss: 0.06947716338367298, Recon Loss: 0.01871033449625147


 98%|█████████▊| 49/50 [1:49:12<02:13, 133.68s/it]

Epoch: 48, Loss: 0.06802729388763165, Recon Loss: 0.018576490981825466


100%|██████████| 50/50 [1:51:25<00:00, 133.72s/it]

Epoch: 49, Loss: 0.06796855762087066, Recon Loss: 0.0177432169238555





# Training VAR

In [5]:
with torch.no_grad():
    torch.cuda.empty_cache()

In [6]:
print("=" * 10 + "Training VAR" + "=" * 10)
vqvae = VQVAE(
    model_params["VQVAE_DIM"],
    model_params["VOCAB_SIZE"],
    model_params["PATCH_SIZES"],
    num_channels=model_params["channels"],
)
vqvae.load_state_dict(torch.load("small_vqvae.pth"))  # LOADS the trained VQVAE model
vqvae = vqvae.to("cuda")
vqvae.eval()

for param in vqvae.parameters():
    param.requires_grad = False

var_model = VAR(
    vqvae=vqvae,
    dim=model_params["VAR_DIM"],
    n_heads=model_params["N_HEADS"],
    n_layers=model_params["N_LAYERS"],
    patch_sizes=model_params["PATCH_SIZES"],
    n_classes=10,
)
optimizer = torch.optim.AdamW(
    var_model.parameters(), lr=training_params["VAR"]["lr"]
)

print(f"VQVAE Parameters: {sum(p.numel() for p in vqvae.parameters())/1e6:.2f}M")
print(f"VAR Parameters: {sum(p.numel() for p in var_model.parameters())/1e6:.2f}M")

train_loader, test_loader = get_data(
    batch_size=training_params["VAR"]["batch_size"], dataset=dataset
)
var_model = var_model.to("cuda")
for epoch in range(training_params["VAR"]["epochs"]):
    epoch_loss = 0
    for i, (x, c) in enumerate(train_loader):
        x, c = x.cuda(), c.cuda()
        optimizer.zero_grad()

        _, _, idxs_R_BL, scales_BlC, _ = vqvae(x)
        idx_BL = torch.cat(idxs_R_BL, dim=1)
        scales_BlC = scales_BlC.cuda()
        logits_BLV = var_model(scales_BlC, cond=c)
        loss = F.cross_entropy(
            logits_BLV.view(-1, logits_BLV.size(-1)), idx_BL.view(-1)
        )

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)
    print(f"Epoch: {epoch}, Loss: {epoch_loss}")

    if epoch % 5 == 0:
        with torch.no_grad():

            cond = torch.arange(10).cuda()
            out_B3HW = var_model.generate(cond, 0)
            plot_images(pred=out_B3HW)

            plt.savefig(f"var_{epoch}.png")
            plt.close()

torch.save(var_model.state_dict(), "var.pth")



  vqvae.load_state_dict(torch.load("small_vqvae.pth"))  # LOADS the trained VQVAE model


VQVAE Parameters: 85.00M
VAR Parameters: 178.97M
229 23
Epoch: 0, Loss: 4.210374780096862
Epoch: 1, Loss: 4.026469395150264
Epoch: 2, Loss: 3.9341810374280772
Epoch: 3, Loss: 3.851914644241333
Epoch: 4, Loss: 3.745109129160252
Epoch: 5, Loss: 3.6351813031075824
Epoch: 6, Loss: 3.5295720475209333
Epoch: 7, Loss: 3.427367805914067
Epoch: 8, Loss: 3.3203730312497335
Epoch: 9, Loss: 3.2066679937870743
Epoch: 10, Loss: 3.081625181514623
Epoch: 11, Loss: 2.9439822003310425
Epoch: 12, Loss: 2.7930768155635186
Epoch: 13, Loss: 2.6147650490681675
Epoch: 14, Loss: 2.43959974930276
Epoch: 15, Loss: 2.241192906704532
Epoch: 16, Loss: 2.0920631901145503
Epoch: 17, Loss: 1.9529445371773566
Epoch: 18, Loss: 1.8152784655708414
Epoch: 19, Loss: 1.6868888712345773
Epoch: 20, Loss: 1.5457673895306983
Epoch: 21, Loss: 1.4383269316764897
Epoch: 22, Loss: 1.3347437722714186
Epoch: 23, Loss: 1.2131889936184779
Epoch: 24, Loss: 1.1350385648194359
Epoch: 25, Loss: 1.047056037265661
Epoch: 26, Loss: 0.966550968