In [7]:
from dataclasses import dataclass
from functools import lru_cache, reduce
import math
from typing import Union

from einops import rearrange
import torch
import torch.nn as nn
import torch._dynamo
from torch.nn import functional as F

try:
    import natten
except ImportError:
    natten = None

try:
    import flash_attn
except ImportError:
    flash_attn = None

In [151]:
class LinearGEGLU(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features * 2, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        x = x @ weight.mT
        if bias is not None:
            x = x + bias
        x, gate = x.chunk(2, dim=-1)

        return x * F.gelu(gate)

class AdaRMSNorm(nn.Module):
    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.linear = nn.Linear(cond_features, features, bias=False)
        self.init_parameter()

    def init_parameter(self):
        nn.init.zeros_(self.linear.weight)
        if self.linear.bias is not None:
            nn.init.zeros_(self.linear.bias)

    def forward(self, x, cond):
        """
        x: (B, n_heads, tile_H*tile_W, d_head) 
        """
        scale = self.linear(cond)[:, None, None, :] + 1

        dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
        mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
        scale = scale.to(dtype) * torch.rsqrt(mean_sq + self.eps) 

        return x * scale.to(x.dtype)

class AxialRoPE(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()
        log_min = math.log(math.pi)
        log_max = math.log(10.0 * math.pi)
        freqs = torch.linspace(log_min, log_max, n_heads * dim // 4 + 1)[:-1].exp() 
        self.register_buffer("freqs", freqs.view(dim // 4, n_heads).T.contiguous()) # freqs는 학습하지 않는 데이터

    def forward(self, pos):
        # pos : (tile_H*tile_W, 2) 타일 좌표값
        # self.freq: (n_heads, dim//4)

        # (tile_H*tile_W, n_heads, dim//4)
        theta_h = pos[..., None, 0:1] * self.freqs.to(pos.dtype) # 행 좌표값에 freq 곱
        theta_w = pos[..., None, 1:2] * self.freqs.to(pos.dtype) # 열 좌표값에 freq 곱

        return torch.cat((theta_h, theta_w), dim=-1) # (tile_H*tile_W, n_heads, dim//2) 

In [152]:
def _apply_rotary_emb_inplace(x, theta, conj):
    dtype = reduce(torch.promote_types, (x.dtype, theta.dtype, torch.float32))
    d = theta.shape[-1]
    assert d * 2 <= x.shape[-1]
    x1, x2 = x[..., :d], x[..., d : d * 2]
    x1_, x2_, theta = x1.to(dtype), x2.to(dtype), theta.to(dtype)
    cos, sin = torch.cos(theta), torch.sin(theta)
    sin = -sin if conj else sin
    y1 = x1_ * cos - x2_ * sin
    y2 = x2_ * cos + x1_ * sin
    x1.copy_(y1)
    x2.copy_(y2)


class ApplyRotaryEmbeddingInplace(torch.autograd.Function):
    @staticmethod
    def forward(x, theta, conj):
        _apply_rotary_emb_inplace(x, theta, conj=conj)
        return x

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, theta, conj = inputs
        ctx.save_for_backward(theta)
        ctx.conj = conj

    @staticmethod
    def backward(ctx, grad_output):
        theta, = ctx.saved_tensors
        _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj)
        return grad_output, None, None


def apply_rotary_emb_(x, theta):
    return ApplyRotaryEmbeddingInplace.apply(x, theta, False)

In [169]:
class SelfAttentionBlock(nn.Module):
    def __init__(self, d_model, d_head, cond_features, dropout=0.0):
        super().__init__()
        self.d_head = d_head
        self.n_heads = d_model // d_head
        self.norm = AdaRMSNorm(d_model, cond_features)
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
        self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.init_parameter()

    def init_parameter(self):
        nn.init.zeros_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.zeros_(self.out_proj.bias) 

    def scale_for_cosine_sim(self, q, k, scale, eps):
        dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
        sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
        sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True)
        sqrt_scale = torch.sqrt(scale.to(dtype))
        scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
        scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
        return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)

    def forward(self, x, pos, cond):
        #x : (B, tile_H, tile_W, d_model)
        skip = x
        x = self.norm(x, cond) 
        qkv = self.qkv_proj(x) # (B, tile_H, tile_W, 3*d_model)
        pos = rearrange(pos, "... h w e -> ... (h w) e").to(qkv.dtype) # tile의 위치정보를 2차원에서 1차원으로
        theta = self.pos_emb(pos) # (tile_H*tile_W, n_head, d_head//4)
        #theta의 차원이 d_head//4이므로 apply_rotary_emb_가 앞 절반 차원에만 적용된다.

        # q,k,v: (B, n_heads, tile_H*tile_W, d_head)
        q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh (h w) e", t=3, e=self.d_head) # nh: number of head
        q, k = self.scale_for_cosine_sim(q, k, self.scale[:, None, None], 1e-6)
        theta = theta.movedim(-2, -3) # (n_head, tile_H*tile_W, d_head//4)
        q = apply_rotary_emb_(q, theta)
        k = apply_rotary_emb_(k, theta)
        x = F.scaled_dot_product_attention(q, k, v, scale=1.0) # (B, n_heads, tile_H*tile_W, d_head)
        x = rearrange(x, "n nh (h w) e -> n h w (nh e)", h=skip.shape[-3], w=skip.shape[-2]) # (B, tile_H, tile_W, d_model)

        x = self.dropout(x)
        x = self.out_proj(x)
        return x + skip

In [183]:
emb = AxialRoPE(8, 32)
pos = rearrange(make_axial_pos(2, 4).view(2, 4, 2), "... h w e -> ... (h w) e")

In [184]:
emb(pos).shape

torch.Size([8, 32, 4])

In [172]:
torch.sum(torch.randn(10, 8, 16, 32), dim=-1).shape

torch.Size([10, 8, 16])

In [173]:
def bounding_box(h, w, pixel_aspect_ratio=1.0):
    # Adjusted dimensions
    w_adj = w
    h_adj = h * pixel_aspect_ratio

    # Adjusted aspect ratio
    ar_adj = w_adj / h_adj

    # Determine bounding box based on the adjusted aspect ratio
    y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
    if ar_adj > 1:
        y_min, y_max = -1 / ar_adj, 1 / ar_adj
    elif ar_adj < 1:
        x_min, x_max = -ar_adj, ar_adj

    return y_min, y_max, x_min, x_max

In [174]:
bounding_box(30, 60)

(-0.5, 0.5, -1.0, 1.0)

In [175]:
def make_grid(h_pos, w_pos):
    grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
    h, w, d = grid.shape # (h, w)에 해당하는 변환된 d dim의 좌표
    return grid.view(h * w, d)


def bounding_box(h, w, pixel_aspect_ratio=1.0):
    '''
    긴 축을 -1 ~ 1 로 변환했을 때의 최소/최댓값
    h = 30, w = 60일 때 (-0.5, 0.5, -1.0, 1.0)
    '''

    # Adjusted dimensions
    w_adj = w
    h_adj = h * pixel_aspect_ratio

    # Adjusted aspect ratio
    ar_adj = w_adj / h_adj

    # Determine bounding box based on the adjusted aspect ratio
    y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
    if ar_adj > 1:
        y_min, y_max = -1 / ar_adj, 1 / ar_adj
    elif ar_adj < 1:
        x_min, x_max = -ar_adj, ar_adj

    return y_min, y_max, x_min, x_max


def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
    y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
    if align_corners:
        h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
        w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
    else:
        # 중앙 정렬된 h, w개수 만큼 bounding box 내의 값을 나눔
        h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
        w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
    return make_grid(h_pos, w_pos)

In [182]:
make_axial_pos(2, 4) # (B, 2, 4, C)의 데이터일때

tensor([[-0.2500, -0.7500],
        [-0.2500, -0.2500],
        [-0.2500,  0.2500],
        [-0.2500,  0.7500],
        [ 0.2500, -0.7500],
        [ 0.2500, -0.2500],
        [ 0.2500,  0.2500],
        [ 0.2500,  0.7500]])

In [177]:
class NeighborhoodSelfAttentionBlock(nn.Module):
    def __init__(self, d_model, d_head, cond_features, kernel_size, dropout=0.0):
        super().__init__()
        self.d_head = d_head
        self.n_heads = d_model // d_head
        self.kernel_size = kernel_size
        self.norm = AdaRMSNorm(d_model, cond_features)
        self.qkv_proj = nn.Linear(d_model, d_model * 3, bias=False)
        self.scale = nn.Parameter(torch.full([self.n_heads], 10.0))
        self.pos_emb = AxialRoPE(d_head // 2, self.n_heads)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.init_parameter()

    def init_parameter(self):
        nn.init.zeros_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.zeros_(self.out_proj.bias) 

    def scale_for_cosine_sim(self, q, k, scale, eps):
        dtype = reduce(torch.promote_types, (q.dtype, k.dtype, scale.dtype, torch.float32))
        sum_sq_q = torch.sum(q.to(dtype)**2, dim=-1, keepdim=True)
        sum_sq_k = torch.sum(k.to(dtype)**2, dim=-1, keepdim=True)
        sqrt_scale = torch.sqrt(scale.to(dtype))
        scale_q = sqrt_scale * torch.rsqrt(sum_sq_q + eps)
        scale_k = sqrt_scale * torch.rsqrt(sum_sq_k + eps)
        return q * scale_q.to(q.dtype), k * scale_k.to(k.dtype)

    def forward(self, x, pos, cond):
        skip = x
        x = self.norm(x, cond)
        qkv = self.qkv_proj(x)
        if natten is None:
            raise ModuleNotFoundError("natten is required for neighborhood attention")

        q, k, v = rearrange(qkv, "n h w (t nh e) -> t n nh h w e", t=3, e=self.d_head)
        q, k = scale_for_cosine_sim(q, k, self.scale[:, None, None, None], 1e-6)
        theta = self.pos_emb(pos).movedim(-2, -4)
        q = apply_rotary_emb_(q, theta)
        k = apply_rotary_emb_(k, theta)
        qk = natten.functional.na2d_qk(q, k, self.kernel_size)
        a = torch.softmax(qk, dim=-1).to(v.dtype)
        x = natten.functional.na2d_av(a, v, self.kernel_size)
        x = rearrange(x, "n nh h w e -> n h w (nh e)")

        x = self.dropout(x)
        x = self.out_proj(x)
        return x + skip

In [178]:
down = SelfAttentionBlock(64, 8, 64)
test = torch.randn(10, 4, 4, 64)
down(test, torch.randn(4, 4, 2), torch.randn(10, 64))

TypeError: scaled_dot_product_attention() got an unexpected keyword argument 'scale'

In [180]:
torch.randn((1,2,3,4)).movedim(-3, -1).shape

torch.Size([1, 3, 4, 2])

In [None]:
class LinearGEGLU(nn.Linear):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features * 2, bias=bias)
        self.out_features = out_features

    def forward(self, x):
        x = x @ weight.mT
        if bias is not None:
            x = x + bias
        x, gate = x.chunk(2, dim=-1)

        return x * F.gelu(gate)

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model, d_ff, cond_features, dropout=0.0):
        super().__init__()
        self.norm = AdaRMSNorm(d_model, cond_features)
        self.up_proj = LinearGEGLU(d_model, d_ff, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.init_parameter()

    def init_parameter(self):
        nn.init.zeros_(self.down_proj.weight)
        if self.down_proj.bias is not None:
            nn.init.zeros_(self.down_proj.bias) 

    def forward(self, x, cond):
        skip = x
        x = self.norm(x, cond)
        x = self.up_proj(x)
        x = self.dropout(x)
        x = self.down_proj(x)
        return x + skip