In [None]:
import torch
import torch.distributed as dist
from models import DiT_models
from download import find_model
from diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from tqdm import tqdm
import os
from PIL import Image
import numpy as np
import math
import argparse
from torchvision.utils import save_image
from timm.models.vision_transformer import Attention, Mlp

In [None]:
!nvidia-smi

In [None]:
device = 'cuda'

In [None]:
import torch
from torch import nn
from functools import partial


@torch.no_grad()
def quantize_weight_per_channel_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_weight_per_tensor_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    scales = w.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_activation_per_token_absmax(t, n_bits=8):
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t


@torch.no_grad()
def quantize_activation_per_tensor_absmax(t, n_bits=8):
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t


class W8A8Linear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        bias=True,
        act_quant="per_token",
        quantize_output=False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.register_buffer(
            "weight",
            torch.randn(
                self.out_features,
                self.in_features,
                dtype=torch.float16,
                requires_grad=False,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (1, self.out_features), dtype=torch.float16, requires_grad=False
                ),
            )
        else:
            self.register_buffer("bias", None)

        if act_quant == "per_token":
            self.act_quant_name = "per_token"
            self.act_quant = partial(quantize_activation_per_token_absmax, n_bits=8)
        elif act_quant == "per_tensor":
            self.act_quant_name = "per_tensor"
            self.act_quant = partial(quantize_activation_per_tensor_absmax, n_bits=8)
        else:
            raise ValueError(f"Invalid act_quant: {act_quant}")

        if quantize_output:
            self.output_quant_name = self.act_quant_name
            self.output_quant = self.act_quant
        else:
            self.output_quant_name = "None"
            self.output_quant = lambda x: x

    def to(self, *args, **kwargs):
        super(W8A8Linear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        q_x = self.act_quant(x)
        y = torch.functional.F.linear(q_x, self.weight, self.bias)
        q_y = self.output_quant(y)
        return q_y

    @staticmethod
    def from_float(
        module, weight_quant="per_channel", act_quant="per_token", quantize_output=False
    ):
        assert isinstance(module, torch.nn.Linear)
        new_module = W8A8Linear(
            module.in_features,
            module.out_features,
            module.bias is not None,
            act_quant=act_quant,
            quantize_output=quantize_output,
        )
        if weight_quant == "per_channel":
            new_module.weight = quantize_weight_per_channel_absmax(
                module.weight, n_bits=4
            )  # use 8-bit integer for weight
        elif weight_quant == "per_tensor":
            new_module.weight = quantize_weight_per_tensor_absmax(
                module.weight, n_bits=4
            )
        else:
            raise ValueError(f"Invalid weight_quant: {weight_quant}")
        new_module.weight_quant_name = weight_quant
        if module.bias is not None:
            new_module.bias = module.bias
        return new_module

    def __repr__(self):
        return f"W8A8Linear({self.in_features}, {self.out_features}, bias={self.bias is not None}, weight_quant={self.weight_quant_name}, act_quant={self.act_quant_name}, output_quant={self.output_quant_name})"

In [None]:
image_size = 256 #@param [256, 512]
vae_model = "stabilityai/sd-vae-ft-ema" #@param ["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"]
# Load model:
latent_size = int(image_size) // 8

def init_models():
    model = DiT_models['DiT-XL/2'](input_size=latent_size).to(device)
    # state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    state_dict = torch.load('/n/netscratch/nali_lab_seas/Everyone/mingze/models/pretrained_models/DiT-XL-2-256x256.pt', weights_only=True)
    model.load_state_dict(state_dict)
    model.eval() # important!
    vae = AutoencoderKL.from_pretrained(vae_model).to(device)

    return model, vae

In [None]:
def quantize_dit_only_block(
    model, weight_quant="per_tensor", act_quant="per_tensor", quantize_bmm_input=True
):
    for name, m in model.named_modules():
        if isinstance(m, Attention):  # quantize each linear layer within attention block
            m.qkv = W8A8Linear.from_float(
                m.qkv, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input
            )
            m.proj = W8A8Linear.from_float(
                m.proj, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input
            )
        if isinstance(m, Mlp):
            m.fc1 = W8A8Linear.from_float(
                m.fc1, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input
            )
            m.fc2 = W8A8Linear.from_float(
                m.fc2, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input
            )
            
    return model

In [None]:
def quantize_dit_all_linear_layer(
    model, weight_quant="per_tensor", act_quant="per_tensor", quantize_bmm_input=False
):
    for name, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):  # quantize each linear layer
            m = W8A8Linear.from_float(
                m, weight_quant=weight_quant, act_quant=act_quant, quantize_output=quantize_bmm_input
            )
            
    return model

In [None]:
def main_func(quant_func=lambda x: x):

    # Set user inputs:
    seed = 1 #@param {type:"number"}
    torch.manual_seed(seed)
    num_sampling_steps = 200 #@param {type:"slider", min:0, max:1000, step:1}
    cfg_scale = 2 #@param {type:"slider", min:1, max:10, step:0.1}
    class_labels = 23, 123, 324, 405 #@param {type:"raw"}
    samples_per_row = 4 #@param {type:"number"}
    
    model, vae = init_models()

    model = quant_func(model)
    
    # Create diffusion object:
    diffusion = create_diffusion(str(num_sampling_steps))
    
    # Create sampling noise:
    n = len(class_labels)
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor(class_labels, device=device)
    
    # Setup classifier-free guidance:
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)
    
    # Sample images:
    samples = diffusion.p_sample_loop(
        model.forward_with_cfg, z.shape, z, clip_denoised=False,
        model_kwargs=model_kwargs, progress=True, device=device
    )
    samples, _ = samples.chunk(2, dim=0)  # Remove null class samples
    samples = vae.decode(samples / 0.18215).sample
    
    # Save and display images:
    save_image(samples, "sample.png", nrow=int(samples_per_row),
               normalize=True, value_range=(-1, 1))
    samples = Image.open("sample.png")
    display(samples)

In [None]:
main_func()

In [None]:
main_func(quantize_dit_all_linear_layer)

In [None]:
model

In [3]:
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Dataset and Dataloader
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


# Specify the target classes
selected_class_ids = [0, 1, 2, 3]

full_dataset = ImageFolder("/n/home11/mingzeyuan/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini/train", transform=transform)
# Get indices of images belonging to the selected classes
# selected_indices = [i for i, (_, label) in enumerate(full_dataset) if label in selected_class_ids]

# Create a subset dataset with only the selected classes
# filtered_dataset = Subset(full_dataset, selected_indices)
# dataloader = DataLoader(filtered_dataset, batch_size=32, shuffle=False)
full_dataset

Dataset ImageFolder
    Number of datapoints: 34745
    Root location: /n/home11/mingzeyuan/.cache/kagglehub/datasets/ifigotin/imagenetmini-1000/versions/1/imagenet-mini/train
    StandardTransform
Transform: Compose(
               Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=True)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
           )

In [11]:
from tqdm.auto import tqdm
for i, (_, label) in tqdm(enumerate(full_dataset)):
    if label in selected_class_ids:
        print(i)

0it [00:00, ?it/s]

750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779


KeyboardInterrupt: 

In [10]:
selected_indices = [i for i, (_, label) in enumerate(full_dataset) if label in selected_class_ids]

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn

# Define LoRA Layer
class LoRALayer(nn.Module):
    def __init__(self, original_layer, r=8, alpha=1.0):
        super(LoRALayer, self).__init__()
        self.original_layer = original_layer
        self.r = r  # Low rank
        self.alpha = alpha  # Scaling factor

        # Low-rank matrices
        self.A = nn.Parameter(torch.randn(original_layer.out_features, r) * 0.01)
        self.B = nn.Parameter(torch.randn(r, original_layer.in_features) * 0.01)

        # Scaling factor to ensure initial LoRA impact is small
        self.scale = alpha / self.r

    def forward(self, x):
        lora_adjustment = (x @ self.B.T) @ self.A.T  # (batch_size, in_features) -> (batch_size, out_features)
        return self.original_layer(x) + lora_adjustment * self.scale

def add_lora_to_model(model, r=8, alpha=1.0):
    layers_to_modify = []  # Collect layers to modify first

    # Collect all linear layers in a list
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear) and ("mlp" in name or "attn" in name):
            layers_to_modify.append((name, module))

    # Replace each collected layer with a LoRA layer
    for name, module in layers_to_modify:
        # Split the name by '.' to traverse submodules and set the new layer correctly
        submodule = model
        *module_names, layer_name = name.split(".")
        for module_name in module_names:
            submodule = getattr(submodule, module_name)

        # Replace the layer with a LoRA layer
        setattr(submodule, layer_name, LoRALayer(module, r=r, alpha=alpha))

In [None]:
# Assuming `model` is the DiT model with LoRA layers added
def freeze_model_weights(model):
    for name, param in model.named_parameters():
        if not ((".A" in name) or ('.B' in name)):  # Replace with the identifier for LoRA parameters
            param.requires_grad = False  # Freeze base model weights

# Add LoRA layers (as shown in previous responses) and freeze base model weights
add_lora_to_model(model)  # Add LoRA layers to the model
freeze_model_weights(model)  # Freeze original model weights
model.to("cuda");

In [None]:
from tqdm.auto import tqdm

# Training loop
# model.train()
# requires_grad(model, True)
epochs = 500

# Optimizer
# optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.1)
for epoch in tqdm(range(epochs)):
    running_loss = 0.0
    for x, y in tqdm(dataloader):
        x, y = x.to(device), y.to(device)

        # Encode images to latent space and normalize latents
        with torch.no_grad():
          x = vae.encode(x).latent_dist.sample().mul_(0.18215)

        # Explicitly set requires_grad for the input latents
        # print(x.requires_grad_())
        # x.requires_grad_(True)

        # Sample a random timestep for each batch
        t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
        # y.requires_grad_(True)
        model_kwargs = {"y": y}

        # Compute training losses from diffusion
        # t.requires_grad_(True)
        x = torch.Tensor(x)
        x.requires_grad_(True)
        # print(x.requires_grad_())
        loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
        loss = loss_dict["loss"].mean()

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log loss
        running_loss += loss.item()

    # Print epoch loss
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(dataloader):.4f}")

print("Training completed.")