In [1]:
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm, trange

import os
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from accelerate import Accelerator, notebook_launcher

In [2]:
class ImageNetDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        
        # Recursively find all image files
        for root, _, files in os.walk(root_dir):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(root, file))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        return image

# Define transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                       std=[0.229, 0.224, 0.225])
])

# Create dataset and dataloader
dataset = ImageNetDataset(root_dir='/shared/imagenet/train', 
                         transform=transform)
dataloader = DataLoader(dataset, 
                       batch_size=2048,
                       shuffle=True,
                       num_workers=4)

In [3]:
len(dataloader)

625

In [4]:
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)
for param in clip_model.parameters():
    param.requires_grad = False
resize = transforms.Resize((224, 224))

images = preprocess(Image.open('/shared/imagenet/train/image_0.jpg')).unsqueeze(0).to(device)
print(images.shape)

with torch.no_grad():
    for batch in dataloader:
        print(clip_model.encode_image(resize(batch).to(device)).shape)
        break

def viz_loss(truth_batch, output_batch):
    truth_scores = clip_model.encode_image(resize(truth_batch).to(device))
    output_scores = clip_model.encode_image(resize(output_batch).to(device))
    return (truth_scores - output_scores).pow(2).mean()

torch.Size([1, 3, 224, 224])
torch.Size([2048, 512])


In [5]:
"""
SD-VAE

Replicates f8c4p2 on 256x256 ImageNet, i.e. 256x256 -> 16x16x12.

Conv structure:
* 
"""

def space_2_channel(x):
    output = torch.cat((x[:, :, 0::2, 0::2] + x[:, :, 0::2, 1::2], x[:, :, 1::2, 0::2] + x[:, :, 1::2, 1::2]), dim=1) / 2
    assert(output.shape == (x.shape[0], x.shape[1] * 2, x.shape[2] // 2, x.shape[3] // 2))
    return output

def channel_2_space(x):
    output = torch.zeros((x.shape[0], x.shape[1] // 2, x.shape[2] * 2, x.shape[3] * 2), device=x.device)
    half_C = x.shape[1] // 2
    output[:, :, ::2, ::2] = x[:, :half_C, :, :]
    output[:, :, ::2, 1::2] = x[:, half_C:, :, :]
    output[:, :, 1::2, ::2] = x[:, half_C:, :, :]
    output[:, :, 1::2, 1::2] = x[:, :half_C, :, :]
    return output

kernel_sizes = [11, 9, 5, 3, 3]
strides = [2, 2, 2, 2, 2]
paddings = [k//2 for k in kernel_sizes]

class SANA_Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Conv2d(3 * 2**index, 6 * 2**index, kernel_size, strides[index], paddings[index]) for index, kernel_size in enumerate(kernel_sizes)])
        self.activation = nn.LeakyReLU(0.2)
        self.norms = nn.ModuleList([nn.LayerNorm(
            (
                int(3 * 2**(index+1)),
                int(256 / 2**(index+1)),
                int(256 / 2**(index+1))
            )
        ) for index in range(len(kernel_sizes))])

    def forward(self, x):
        for layer, norm in zip(self.layers, self.norms):
            x = layer(x) + space_2_channel(x)
            x = self.activation(x)
            x = norm(x)
        return x

class SANA_Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.ConvTranspose2d(6 * 2**index, 3 * 2**index, kernel_size, strides[index], paddings[index], output_padding=strides[index]-1) for index, kernel_size in enumerate(kernel_sizes)][::-1])
        self.activation = nn.LeakyReLU(0.2)
        self.norms = nn.ModuleList([nn.LayerNorm(
            (
                int(3 * 2**index),
                int(256 / 2**index),
                int(256 / 2**index)
            )
        ) for index in range(len(kernel_sizes))][::-1])

    def forward(self, x):
        for layer, norm in zip(self.layers, self.norms):
            x = layer(x) + channel_2_space(x)
            x = self.activation(x)
            x = norm(x)
        return x

class SANA_VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = SANA_Encoder()
        self.decoder = SANA_Decoder()

    def encode(self, x):
        return self.encoder(x)

    def decode(self, x):
        return self.decoder(x)

    def forward(self, x):
        return self.decode(self.encode(x))

    def train(self, dataloader, optimizer, epochs=100):
        best_loss = float('inf')
        for _ in trange(epochs, desc="Epochs", leave=True):
            epoch_loss = 0
            for batch in tqdm(dataloader, desc="Batches", leave=True):
                batch = batch.to(device)
                optimizer.zero_grad()
                output = self.forward(batch)
                loss = viz_loss(batch, output)
                loss.backward()
                optimizer.step()
                epoch_loss += loss.item()
            print(f"Loss: {epoch_loss / len(dataloader)}")
            if epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(self.state_dict(), "sana_vae.pt")
        print(f"Best loss: {best_loss}")

In [6]:
# model = SANA_VAE().to(device)
def num_params(model):
    return sum(p.numel() for p in model.parameters())
# print(f"Number of parameters: {num_params(model)}")

In [7]:
model = SANA_VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train(dataloader, optimizer, epochs=10)

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/625 [00:00<?, ?it/s]

In [7]:
def training_function():
    accelerator = Accelerator()
    model = SANA_VAE()
    print(f"Number of parameters: {num_params(model)}")
    dataloader = DataLoader(dataset, batch_size=2048, shuffle=True, num_workers=4)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
    for _ in trange(10, desc="Epochs", leave=True):
        for batch in tqdm(dataloader, desc="Batches", leave=True):
            optimizer.zero_grad()
            output = model(batch)
            loss = viz_loss(batch, output)
            loss.backward()
            optimizer.step()
        print(f"Loss: {loss.item()}")

In [8]:
if __name__ == "__main__":
    notebook_launcher(training_function, num_processes=8)

Launching training on 8 GPUs.


E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732] failed (exitcode: 1) local_rank: 0 (pid: 200201) of fn: training_function (start_method: fork)
E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732] Traceback (most recent call last):
E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/alex-zhao/.venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 687, in _poll
E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732]     self._pc.join(-1)
E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732]   File "/home/alex-zhao/.venv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 203, in join
E1224 08:44:31.473000 199646 torch/distributed/elastic/multiprocessing/api.py:732]     raise ProcessRaisedException(msg, error_index, failed_process.pid)
E1224 08:44:31.473000 199646 torch/distr

ChildFailedError: 
============================================================
training_function FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-12-24_08:44:31
  host      : worker-7.etched-slurm-worker-svc.etched-slurm.svc.cluster.local
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 200201)
  error_file: /tmp/torchelastic_7_9x3ers/none_9wcxsstm/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
      return f(*args, **kwargs)
    File "/tmp/ipykernel_199646/2341702742.py", line 2, in training_function
      accelerator = Accelerator()
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 425, in __init__
      self.state = AcceleratorState(
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/accelerate/state.py", line 861, in __init__
      PartialState(cpu, **kwargs)
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/accelerate/state.py", line 276, in __init__
      self.set_device()
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/accelerate/state.py", line 791, in set_device
      device_module.set_device(self.device)
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 478, in set_device
      torch._C._cuda_setDevice(device)
    File "/home/alex-zhao/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 305, in _lazy_init
      raise RuntimeError(
  RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
  
============================================================

Epochs:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/625 [00:00<?, ?it/s]

KeyboardInterrupt: 