In [None]:
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====

# Checkpoint file names; they are opened relative to the working directory below
# (nothing in this cell downloads them).
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'

# NOTE(review): the flag name suggests a 512px configuration, yet it is enabled for
# depth 16 and combined with the 16x16 patch schedule below — confirm this is intended.
FOR_512_px = MODEL_DEPTH == 16

# Multi-scale schedule (side length of the token map at each scale).
# A larger alternative that was tried here: (1, 2, 3, 4, 6, 9, 13, 18, 24, 32).
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _skip_reset_parameters(self):
    """No-op stand-in for a layer's ``reset_parameters`` (skips the default init)."""


def _snapshot_reset_parameters():
    """Map every layer class exported by ``torch.nn`` to the ``reset_parameters`` it defines itself (or None)."""
    return {
        obj: obj.__dict__.get('reset_parameters')
        for obj in vars(torch.nn).values()
        if isinstance(obj, type) and issubclass(obj, torch.nn.Module)
    }


def _restore_reset_parameters(snapshot):
    """Undo any ``reset_parameters`` override installed since ``snapshot`` was taken."""
    for layer_cls, original in snapshot.items():
        if original is not None:
            setattr(layer_cls, 'reset_parameters', original)
        elif 'reset_parameters' in layer_cls.__dict__:
            # The override was added to a class that only inherited the method.
            delattr(layer_cls, 'reset_parameters')


# Skipping the default init speeds up construction of the big pretrained models, but a
# permanent patch would leave every layer created later in this kernel (e.g. a model
# trained from scratch) with uninitialised `torch.empty` weights. Keep the patch scoped
# to the build call and restore the real initialisers afterwards.
# NOTE(review): the snapshot covers all torch.nn layer classes because build_vae_var may
# disable further initialisers itself (the upstream VAR implementation does) — confirm.
_reset_snapshot = _snapshot_reset_parameters()
torch.nn.Linear.reset_parameters = _skip_reset_parameters
torch.nn.LayerNorm.reset_parameters = _skip_reset_parameters
try:
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=FOR_512_px,
    )
finally:
    _restore_reset_parameters(_reset_snapshot)

# Load the VAE weights only and freeze the VAE.
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
# The VAR checkpoint is intentionally not loaded here, so `var` keeps whatever weights
# build_vae_var left it with:
# var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval()
for p in vae.parameters():
    p.requires_grad_(False)
print('prepare finished.')

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeuralVAR(nn.Module):
    """Coarse-to-fine feature-map generator driven by a single condition vector.

    The model walks through the scales listed in ``patch_nums``. At every scale it

    1. builds a token sequence (a learnable start token at the first scale, otherwise
       the running feature map resized to the current resolution),
    2. adds a learnable per-position embedding and a learnable per-scale embedding,
    3. runs the shared stack of ``TransformerBlock`` modules, conditioned on the
       projected condition vector,
    4. maps the tokens through ``output_head`` and reshapes them into a
       ``[B, feature_dim, size, size]`` map, and
    5. adds that map to a learned up-sampling of the previous running map.

    Args:
        cond_dim: width of the condition vector.
        feature_dim: channel count of the feature map; it is also the token width
            used by the transformer blocks.
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: hidden-width multiplier of each block's MLP.
        dropout: dropout probability passed to every block.
        patch_nums: side length of the square token map at each scale, in generation
            order.

    Raises:
        ValueError: if ``patch_nums`` is empty.

    Note:
        All ``nn.Linear`` / ``nn.LayerNorm`` / ``nn.Conv2d`` / ``nn.GroupNorm`` layers
        rely on PyTorch's default ``reset_parameters`` for their initial values. If
        those hooks have been disabled globally in the running process, the layers are
        left with uninitialised memory; only ``init_token``, ``level_embed`` and
        ``pos_embeddings`` are initialised explicitly here.
        ``nn.GroupNorm(8, feature_dim)`` requires ``feature_dim`` to be divisible by 8.
    """

    def __init__(
        self,
        cond_dim=1024,          # width of the condition vector
        feature_dim=32,         # channels of the feature map / token width
        depth=12,               # number of transformer blocks
        num_heads=8,            # attention heads
        mlp_ratio=4.0,          # MLP expansion ratio
        dropout=0.1,            # dropout probability
        patch_nums=(1, 2, 3, 4, 6, 9, 13, 18, 24, 32),  # per-scale side lengths
    ):
        super().__init__()

        patch_nums = tuple(patch_nums)
        if not patch_nums:
            raise ValueError("patch_nums must contain at least one scale")

        # Hyper-parameters
        self.cond_dim = cond_dim
        self.feature_dim = feature_dim
        self.depth = depth
        self.num_heads = num_heads
        self.patch_nums = patch_nums
        self.num_levels = len(patch_nums)

        # Condition projection (keeps the width at cond_dim)
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim, 4 * cond_dim),
            nn.GELU(),
            nn.Linear(4 * cond_dim, cond_dim),
            nn.LayerNorm(cond_dim)
        )

        # Per-scale and per-position embeddings
        self.init_level_embeddings()

        # Transformer stack, shared by all scales
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(
                dim=feature_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                cond_dim=cond_dim
            ) for _ in range(depth)
        ])

        # Token -> feature prediction head
        self.output_head = nn.Sequential(
            nn.LayerNorm(feature_dim),
            nn.Linear(feature_dim, feature_dim * 2),
            nn.GELU(),
            nn.Linear(feature_dim * 2, feature_dim)
        )

        # One up-sampling stage per scale transition. The target size is given
        # explicitly: a fractional ``scale_factor`` makes the output size depend on
        # floating-point rounding of ``input_size * scale_factor``.
        self.upsample_layers = nn.ModuleList([
            nn.Sequential(
                nn.Upsample(size=(target_size, target_size), mode='bilinear', align_corners=False),
                nn.Conv2d(feature_dim, feature_dim, 3, padding=1),
                nn.GroupNorm(8, feature_dim),
                nn.GELU()
            )
            for target_size in patch_nums[1:]
        ])

        # Learnable start token used at the first scale
        self.init_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
        nn.init.normal_(self.init_token, std=0.02)

    def init_level_embeddings(self):
        """Create the per-scale embedding table and one positional table per scale."""
        # One vector per scale
        self.level_embed = nn.Embedding(self.num_levels, self.feature_dim)
        nn.init.normal_(self.level_embed.weight, std=0.02)

        # One [1, size*size, feature_dim] positional table per scale
        self.pos_embeddings = nn.ParameterList()
        for size in self.patch_nums:
            pos = torch.zeros(1, size*size, self.feature_dim)
            nn.init.trunc_normal_(pos, std=0.02)
            self.pos_embeddings.append(nn.Parameter(pos))

    def forward(self, cond_vector):
        """Generate the final feature map from a batch of condition vectors.

        :param cond_vector: condition vectors, shape [B, cond_dim]
        :return: feature map of shape [B, feature_dim, S, S] with S = patch_nums[-1]
        :raises ValueError: if ``cond_vector`` is not a [B, cond_dim] tensor
        """
        if cond_vector.dim() != 2 or cond_vector.size(-1) != self.cond_dim:
            raise ValueError(
                f"cond_vector must have shape [B, {self.cond_dim}], got {tuple(cond_vector.shape)}"
            )

        # 1. Project the condition vector
        cond = self.cond_proj(cond_vector)  # [B, cond_dim]

        # 2. Running feature map, created at the first scale
        B = cond.size(0)
        f_hat_current = None

        # 3. Scale-by-scale generation
        for level_idx, current_size in enumerate(self.patch_nums):
            if level_idx == 0:
                # The first scale starts from the learnable token
                tokens = self.init_token.expand(B, 1, -1)
            else:
                # Later scales read their tokens from the running feature map
                tokens = self.get_level_tokens(f_hat_current, level_idx)

            # Positional + scale embeddings (both broadcast over the batch)
            level_emb = self.level_embed.weight[level_idx].view(1, 1, -1)
            tokens = tokens + self.pos_embeddings[level_idx] + level_emb

            tokens = self.apply_transformer(tokens, cond)
            pred_tokens = self.output_head(tokens)

            # [B, S*S, C] -> [B, C, S, S]
            new_features = pred_tokens.reshape(
                B, current_size, current_size, -1
            ).permute(0, 3, 1, 2)

            if level_idx == 0:
                f_hat_current = new_features
            else:
                # The up-sampling stage already yields exactly current_size x current_size
                upsampled = self.upsample_layers[level_idx-1](f_hat_current)
                f_hat_current = upsampled + new_features

        return f_hat_current

    def get_level_tokens(self, f_hat, level_idx):
        """Resize the running feature map to a scale and flatten it into tokens.

        :param f_hat: running feature map [B, C, H, W]
        :param level_idx: index into ``patch_nums`` of the target scale
        :return: token sequence [B, S*S, C] with S = patch_nums[level_idx]
        """
        current_size = self.patch_nums[level_idx]

        # Bilinear resize to the target resolution
        resized = F.interpolate(
            f_hat,
            size=(current_size, current_size),
            mode='bilinear',
            align_corners=False
        )

        # [B, C, S, S] -> [B, S*S, C]
        return resized.permute(0, 2, 3, 1).reshape(
            resized.size(0), current_size*current_size, -1
        )

    def apply_transformer(self, tokens, cond):
        """Run the tokens through every transformer block in order."""
        for block in self.transformer_blocks:
            tokens = block(tokens, cond)
        return tokens

class TransformerBlock(nn.Module):
    """Transformer block whose MLP branch is modulated by a condition vector.

    The attention branch uses a plain pre-LayerNorm; the MLP branch normalises its
    input with ``AdaLN`` driven by ``cond``. Both branches are residual and share one
    dropout module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, cond_dim=1024):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)

        # Attention branch
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadAttention(dim, num_heads, dropout)

        # Condition-modulated MLP branch
        self.cond_norm = AdaLN(dim, cond_dim)
        self.mlp = MLP(dim, hidden_dim, dim, dropout)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, cond):
        """Apply self-attention, then the conditioned MLP, each with a residual add."""
        # Residual self-attention on the normalised tokens
        x = x + self.dropout(self.attn(self.norm1(x)))

        # Residual MLP on the condition-modulated tokens
        modulated = self.cond_norm(x, cond)
        return x + self.dropout(self.mlp(modulated))


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over the token axis.

    Args:
        dim: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        assert dim % num_heads == 0, "dim必须能被num_heads整除"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend every token to every token of the same sequence.

        :param x: tokens, shape [B, N, dim]
        :return: tensor of shape [B, N, dim]
        """
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # Move the head axis in front of the token axis so that the matrix products
        # below run over tokens. Leaving q/k/v as [B, N, heads, head_dim] would make
        # each token attend over its own heads instead of over the other tokens.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each [B, heads, N, head_dim]

        # Scaled dot-product attention weights, [B, heads, N, N]
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Weighted values, heads merged back: [B, heads, N, head_dim] -> [B, N, dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)


class AdaLN(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a condition vector.

    The last linear layer of the predictor is zero-initialised, so a freshly built
    module returns the plain (non-affine) layer norm of its input.
    """

    def __init__(self, dim, cond_dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim, elementwise_affine=False)

        modulation_width = dim * 2  # one scale and one shift value per channel
        self.ada_lin = nn.Sequential(
            nn.Linear(cond_dim, modulation_width),
            nn.SiLU(),
            nn.Linear(modulation_width, modulation_width)
        )

        # Zero the final projection: scale = shift = 0 at construction time
        final_linear = self.ada_lin[-1]
        for tensor in (final_linear.weight, final_linear.bias):
            nn.init.zeros_(tensor)

    def forward(self, x, cond):
        """Return ``norm(x) * (1 + scale) + shift`` with scale/shift computed from ``cond``."""
        # Split the prediction in two halves and insert an axis at position 1 so the
        # halves broadcast against x along that axis
        scale, shift = (half.unsqueeze(1) for half in self.ada_lin(cond).chunk(2, dim=-1))

        normalized = self.norm(x)
        return normalized * (1 + scale) + shift


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden_dim),   # expand
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),  # project back
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the feed-forward stack to ``x``."""
        out = self.net(x)
        return out

In [None]:
# Build a NeuralVAR with all constructor defaults (including its default patch_nums).
Neuralvar = NeuralVAR()

In [None]:
# Random stand-in for a batch of 10 RGB images of 256x256, moved to the device chosen in
# the setup cell (a hard-coded 'cuda' would fail on CPU-only machines).
image = torch.randn((10, 3, 256, 256)).to(device)

In [None]:
# Encode the image batch with the VAE. Later cells index the result per scale and pass it
# to torch.cat, so it is presumably a list with one index tensor per scale — TODO confirm.
gt_ms_idx_Bl = vae.img_to_idxBl(image)

In [None]:
# The other cells treat `gt_ms_idx_Bl` as a sequence of per-scale tensors (it is indexed
# with [si] and handed to torch.cat), and a list has no `.shape`; show one shape per scale.
[idx_Bl.shape for idx_Bl in gt_ms_idx_Bl]

In [None]:
# Concatenate the per-scale index tensors along dim 1.
gt_BL = torch.cat(gt_ms_idx_Bl, dim=1)
# Scale schedule used by the scratch cells below; same values as `patch_nums` in the setup cell.
v_patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [None]:
# Build the teacher-forcing input from the ground-truth multi-scale indices.
# Fixes: `Ten` and `gt_idx_Bl` were never defined in this notebook (NameError); the
# indices computed above are called `gt_ms_idx_Bl`.
# NOTE(review): assumes the quantizer is exposed as `vae.quantize`, as in the upstream VAR
# code; the previous `var.quantize_local` is not set anywhere in this notebook — confirm
# against the local `models` package.
x_BLCv_wo_first_l: torch.Tensor = vae.quantize.idxBl_to_var_input(gt_ms_idx_Bl)


In [None]:
# Scratch, step-by-step version of a quantizer method body: accumulate the embedded
# per-scale indices into a full-resolution feature map and record, after each scale, that
# map down-sampled to the next scale.
# `self` does not exist at notebook top level, so the quantizer is bound explicitly.
# NOTE(review): assumes the quantizer is exposed as `vae.quantize` and provides
# `.embedding` and `.quant_resi` (as in the upstream VAR code) — confirm.
quantizer = vae.quantize

next_scales = []
B = gt_ms_idx_Bl[0].shape[0]
C = 32                      # must equal the `Cvae` passed to build_vae_var
SN = len(v_patch_nums)      # number of scales
H = W = v_patch_nums[-1]    # final (largest) resolution

f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
pn_next: int = v_patch_nums[0]
for si in range(SN-1):
    # Embed the indices of scale `si`, reshape to [B, C, pn, pn] and up-sample to H x W.
    h_BChw = F.interpolate(quantizer.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
    # The true division is kept from the original: `quant_resi` is indexed with a fraction
    # in [0, 1) here — NOTE(review): confirm that it accepts a float index.
    f_hat.add_(quantizer.quant_resi[si/(SN-1)](h_BChw))
    pn_next = v_patch_nums[si+1]
    # Running map down-sampled to the next scale, flattened to [B, pn*pn, C].
    next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))

In [None]:
from typing import List, Optional  # `List` was used in the signature without being imported


def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> Optional[torch.Tensor]:
    """Concatenate the per-scale inputs collected in the global ``next_scales``.

    Scratch fragment: both parameters are currently unused; the result comes entirely
    from the notebook-level ``next_scales`` list, which must already be filled.

    :param self: unused
    :param gt_ms_idx_Bl: unused
    :return: the entries of ``next_scales`` concatenated along dim 1, or ``None`` when
        the list is empty
    """
    return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32

In [None]:
# Class ids to condition on (eight entries, four distinct ids each listed twice).
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

# Label tensor created on the device chosen in the setup cell.
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)


In [None]:
# Embed the labels followed by an equally long run of `var.num_classes` ids, i.e. an index
# tensor of length 2 * len(label_B). `sos` and `cond_BD` are two names for the same tensor.
# NOTE(review): `var.num_classes` presumably is the id of an extra "no class" embedding — confirm.
sos = cond_BD = var.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=var.num_classes)), dim=0))


In [None]:
# Small label batch for the shape checks below; follows the device chosen in the setup
# cell instead of a hard-coded 'cuda'.
label_B = torch.tensor((1, 2, 3, 4)).to(device)

In [None]:
# Shape of the class embedding for the labels plus the same number of `var.num_classes` ids.
var.class_emb(torch.cat((label_B, torch.full_like(label_B, fill_value=var.num_classes)), dim=0)).shape

In [None]:
# Display a tensor holding 0..3, built through a NumPy array.
torch.tensor(np.array(range(4)))

In [None]:
# Embedding of four `var.num_classes` ids. The index tensor is created directly on the
# device chosen in the setup cell, replacing the NumPy round trip and the hard-coded 'cuda'.
a = var.class_emb(torch.full((4,), var.num_classes, dtype=torch.long, device=device))

In [None]:
# Random (4, 1024) matrix on the device chosen in the setup cell (was a hard-coded 'cuda').
b = torch.randn((4, 1024)).to(device)

In [None]:
# Shape after stacking `a` on top of `b` along the first axis.
torch.vstack((a, b)).shape

In [None]:
# Batch size for the label-dropping experiment below (rebinds the `B` set in an earlier cell).
B = 4

In [None]:
# Randomly replace each of the ids 0..B-1 by `var.num_classes` with probability
# `var.cond_drop_rate`. The id tensor is created on the same device as the random mask;
# a CPU-side `torch.tensor(range(B))` clashes with a mask that lives on the GPU.
label_B = torch.where(
    torch.rand(B, device=label_B.device) < var.cond_drop_rate,
    var.num_classes,
    torch.arange(B, device=label_B.device),
)


In [None]:
# Display the current label tensor.
label_B

In [None]:
# Positions whose label equals `var.num_classes`, the value written by the dropping cell
# (the setup cell builds the model with num_classes=1000, so this matches the old literal).
drift_index = torch.where(label_B == var.num_classes)[0]

In [None]:
# Rebinds `label_B` to a random float matrix of shape (10, 1024); from here on the name no
# longer refers to a tensor of class ids.
label_B = torch.randn((10, 1024))

In [None]:
# Overwrite rows 1, 2, 4 and 5 in place with one random (1, 1024) row, broadcast to all four.
label_B[[1, 2, 4, 5], :] = torch.randn((1, 1024))

In [None]:
# Number of positions stored in `drift_index`.
len(drift_index)