In [1]:
import os

import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm, trange
# from   (unfinished import left in the original cell; commented out because a bare `from` is a SyntaxError)

In [2]:
class ImageNetDataset(Dataset):
    """Unlabeled image dataset built from every image file found under a directory tree.

    Args:
        root_dir: Directory that is walked recursively for image files.
        transform: Optional callable applied to each loaded RGB PIL image.

    Indexing returns the (optionally transformed) image only; no label is produced.
    """

    # Lower-case file extensions that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.walk yields entries in arbitrary filesystem order; sorting keeps the
        # index -> file mapping reproducible between runs.
        self.image_paths = sorted(
            os.path.join(folder, name)
            for folder, _, names in os.walk(root_dir)
            for name in names
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files discovered under `root_dir`."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image `idx` as RGB and apply `transform` when one was given."""
        # Image.open keeps the underlying file open until the image is closed;
        # convert() produces an independent in-memory copy, so the file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image

# Preprocessing pipeline: fixed 256x256 resize, conversion to a tensor,
# then per-channel normalization with the mean/std constants below.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset over every image under the training root, served in shuffled
# batches of 128 by four worker processes.
dataset = ImageNetDataset(
    root_dir='/shared/imagenet/train',
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)

In [3]:
# Number of batches per epoch (a DataLoader's length counts batches, not images).
len(dataloader)

9996

In [4]:
import torch
import clip
from PIL import Image

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained CLIP ViT-B/32 model together with its PIL preprocessing pipeline.
clip_model, preprocess = clip.load("ViT-B/32", device=device)
# CLIP is used only as a fixed feature extractor, so none of its weights are trained.
for param in clip_model.parameters():
    param.requires_grad = False
# Batches are resized to 224x224 before being passed to clip_model.encode_image.
resize = transforms.Resize((224, 224))

# Sanity check 1: preprocess a single image and inspect the resulting tensor shape.
# The context manager closes the image file once preprocessing has read the pixels
# (the original left the handle open until garbage collection).
with Image.open('/shared/imagenet/train/image_0.jpg') as sample_image:
    images = preprocess(sample_image).unsqueeze(0).to(device)
print(images.shape)

# Sanity check 2: embed the first dataloader batch and inspect the embedding shape.
with torch.no_grad():
    for batch in dataloader:
        print(clip_model.encode_image(resize(batch).to(device)).shape)
        break

def viz_loss(truth_batch, output_batch):
    """Mean squared difference between the CLIP image embeddings of two image batches.

    Both batches are resized with the module-level `resize` transform, moved to
    `device`, and embedded with `clip_model.encode_image`; the result is the mean
    of the element-wise squared difference of the two embedding tensors.

    Args:
        truth_batch: Batch of reference images.
        output_batch: Batch of reconstructed images to compare against the reference.

    Returns:
        A scalar tensor holding the mean squared embedding difference.
    """
    def _embed(image_batch):
        # Resize first, then move to the device the CLIP model was loaded on.
        return clip_model.encode_image(resize(image_batch).to(device))

    target_features = _embed(truth_batch)
    predicted_features = _embed(output_batch)
    # NOTE(review): the reduction runs in whatever dtype encode_image returns;
    # if clip_model is in half precision this is half-precision arithmetic — confirm.
    residual = target_features - predicted_features
    return residual.pow(2).mean()

torch.Size([1, 3, 224, 224])
torch.Size([128, 512])


In [5]:
"""
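SD-VAE-style convolutional autoencoder (implemented below as `SANA_VAE`).

Goal: replicate f8c4p2 on 256x256 ImageNet, i.e. 256x256 -> 16x16x12.

Conv structure:
* Encoder: five stride-2 Conv2d layers with kernel sizes 11, 9, 5, 3, 3 and
  channels 3 -> 6 -> 12 -> 24 -> 48 -> 96, each followed by LeakyReLU(0.2).
* Decoder: the mirrored stack of ConvTranspose2d layers
  (96 -> 48 -> 24 -> 12 -> 6 -> 3), each followed by LeakyReLU(0.2).

NOTE: five stride-2 layers downsample by 32x, so a 256x256 input is currently
encoded to 8x8x96 rather than the 16x16x12 target stated above.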
"""

def space_2_channel(x, factor=2):
    """Fold each `factor` x `factor` spatial patch of `x` into the channel dimension.

    The original version indexed the channel/height axes instead of height/width,
    added the sub-grids together instead of stacking them, and asserted an
    impossible shape, so it failed on every input.

    Args:
        x: Tensor of shape (N, C, H, W); H and W must be divisible by `factor`.
        factor: Side length of the square patch moved into the channels.

    Returns:
        Tensor of shape (N, C * factor**2, H // factor, W // factor). Output channel
        `c * factor**2 + i * factor + j` holds `x[:, c, i::factor, j::factor]`.

    Raises:
        ValueError: If `x` is not 4-D or its spatial size is not divisible by `factor`.
    """
    if x.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got shape {tuple(x.shape)}")
    n, c, h, w = x.shape
    if factor < 1 or h % factor or w % factor:
        raise ValueError(f"spatial size {h}x{w} is not divisible by factor {factor}")

    # Split H and W into (block index, offset inside block) pairs.
    patches = x.reshape(n, c, h // factor, factor, w // factor, factor)
    # (N, C, H/f, i, W/f, j) -> (N, C, i, j, H/f, W/f): offsets move next to channels.
    patches = patches.permute(0, 1, 3, 5, 2, 4)
    output = patches.reshape(n, c * factor * factor, h // factor, w // factor)
    return output

def channel_2_space(x, factor=2):
    """Unfold groups of `factor**2` channels into `factor` x `factor` spatial patches.

    This was previously an empty stub that returned None.

    Args:
        x: Tensor of shape (N, C, H, W); C must be divisible by `factor**2`.
        factor: Side length of the square patch rebuilt from the channels.

    Returns:
        Tensor `out` of shape (N, C // factor**2, H * factor, W * factor) where
        `out[:, c, i::factor, j::factor]` is input channel
        `c * factor**2 + i * factor + j`.

    Raises:
        ValueError: If `x` is not 4-D or C is not divisible by `factor**2`.
    """
    if x.dim() != 4:
        raise ValueError(f"expected a 4-D (N, C, H, W) tensor, got shape {tuple(x.shape)}")
    n, c, h, w = x.shape
    group = factor * factor
    if factor < 1 or c % group:
        raise ValueError(f"channel count {c} is not divisible by factor**2 = {group}")

    # Split the channels into (output channel, row offset, column offset).
    patches = x.reshape(n, c // group, factor, factor, h, w)
    # (N, C', i, j, H, W) -> (N, C', H, i, W, j): offsets interleave with the grid.
    patches = patches.permute(0, 1, 4, 2, 5, 3)
    return patches.reshape(n, c // group, h * factor, w * factor)

class SANA_VAE(nn.Module):
    """Convolutional autoencoder: five stride-2 convolutions down, five transposed convolutions up.

    Encoder channels go 3 -> 6 -> 12 -> 24 -> 48 -> 96 with kernel sizes
    11, 9, 5, 3, 3; every layer halves the spatial size (256x256 -> 8x8).
    The decoder mirrors the encoder and doubles the spatial size at every layer,
    so inputs whose sides are multiples of 32 are reconstructed at their original
    size. Every convolution, including the last decoder layer, is followed by
    LeakyReLU(0.2).
    """

    def __init__(self):
        super().__init__()
        kernel_sizes = [11, 9, 5, 3, 3]
        strides = [2, 2, 2, 2, 2]
        paddings = [k // 2 for k in kernel_sizes]
        last = len(kernel_sizes) - 1
        encoder_layers = []
        decoder_layers = []
        for index, kernel_size in enumerate(kernel_sizes):
            encoder_layers.append(
                nn.Conv2d(3 * 2**index, 6 * 2**index, kernel_size, strides[index], paddings[index])
            )
            encoder_layers.append(nn.LeakyReLU(0.2))

            # The decoder is built in the reverse order of the encoder.
            mirror = last - index
            decoder_layers.append(
                nn.ConvTranspose2d(
                    6 * 2**mirror,
                    3 * 2**mirror,
                    kernel_sizes[mirror],
                    strides[mirror],
                    paddings[mirror],
                    # Without output_padding each layer maps n -> 2n - 1 (8 -> 225
                    # after five layers instead of 256); stride - 1 gives exactly 2n.
                    output_padding=strides[mirror] - 1,
                )
            )
            decoder_layers.append(nn.LeakyReLU(0.2))
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Map an image batch (N, 3, H, W) to its latent feature map."""
        return self.encoder(x)

    def decode(self, x):
        """Map a latent feature map (N, 96, h, w) back to image space."""
        return self.decoder(x)

    def forward(self, x):
        """Encode then decode `x`."""
        return self.decode(self.encode(x))

    def train(self, dataloader=True, epochs=100, *, mode=None):
        """Either switch train/eval mode or run the training loop.

        `nn.Module.eval()` calls `self.train(False)`, and parent modules call
        `child.train(mode)`. The original override accepted only a dataloader,
        so those calls crashed. This version keeps the original call form
        `model.train(dataloader, epochs=...)` and also honours the standard
        `nn.Module.train(mode)` contract.

        Args:
            dataloader: An iterable of image batches to train on, or a bool, in
                which case it is treated as the `mode` flag of `nn.Module.train`.
            epochs: Number of passes over `dataloader` (training call only).
            mode: Keyword form of the `nn.Module.train` flag.

        Returns:
            `self` when used as a mode switch; None after a training run.
        """
        if mode is not None:
            return super().train(mode)
        if isinstance(dataloader, bool):
            return super().train(dataloader)
        self.fit(dataloader, epochs=epochs)
        return None

    def fit(self, dataloader, epochs=100, lr=1e-3):
        """Optimize the autoencoder with Adam against the module-level `viz_loss`.

        Args:
            dataloader: Iterable yielding image batches.
            epochs: Number of passes over `dataloader`.
            lr: Adam learning rate.

        Returns:
            List with the mean loss of every epoch (NaN for an epoch without batches).
        """
        super().train(True)
        # Follow the device the parameters live on rather than a notebook global.
        run_device = next(self.parameters()).device
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        epoch_losses = []
        epoch_bar = trange(epochs, desc="Epochs")
        for _ in epoch_bar:
            loss_sum = torch.zeros((), device=run_device)
            num_batches = 0
            for batch in tqdm(dataloader, desc="Batches"):
                batch = batch.to(run_device)
                optimizer.zero_grad()
                output = self(batch)
                loss = viz_loss(batch, output)
                loss.backward()
                optimizer.step()
                # Accumulate on-device so no per-step host synchronization is forced.
                loss_sum += loss.detach().float().to(run_device)
                num_batches += 1
            mean_loss = loss_sum.item() / num_batches if num_batches else float("nan")
            epoch_losses.append(mean_loss)
            epoch_bar.set_postfix(loss=mean_loss)
        return epoch_losses

In [6]:
# Instantiate the autoencoder and move its parameters to the selected device.
model = SANA_VAE().to(device)
def num_params(model):
    """Return the total number of scalar elements across all parameters of `model`."""
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total
# Report the total parameter count of the model.
print(f"Number of parameters: {num_params(model)}")

Number of parameters: 134379


In [8]:
# Show the model's total parameter count as the cell's output value.
num_params(model)

134379

In [7]:
# Run the training loop for 10 epochs over the dataloader.
model.train(dataloader, epochs=10)


Batches:  20%|█▉        | 1968/9996 [05:14<21:23,  6.25it/s]
Epochs:   0%|          | 0/10 [05:14<?, ?it/s]


KeyboardInterrupt: 