In [44]:
from pathlib import Path
from functools import partial
from math import ceil, pi, sqrt, degrees

import torch
from torch import nn, Tensor, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import autocast

from torchtyping import TensorType

from pytorch_custom_utils import save_load

from beartype import beartype
from beartype.typing import Union, Tuple, Callable, Optional, List, Dict, Any

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from einx import get_at

from x_transformers import Decoder
from x_transformers.attend import Attend
from x_transformers.x_transformers import RMSNorm, FeedForward, LayerIntermediates

from x_transformers.autoregressive_wrapper import (
    eval_decorator,
    top_k,
    top_p,
)

from local_attention import LocalMHA

from vector_quantize_pytorch import (
    ResidualVQ,
    ResidualLFQ
)

from net.data import derive_face_edges_from_faces, myderive_face_edges_from_faces
from net.version import __version__

from taylor_series_linear_attention import TaylorSeriesLinearAttn

from classifier_free_guidance_pytorch import (
    classifier_free_guidance,
    TextEmbeddingReturner
)

from torch_geometric.nn.conv import SAGEConv

from gateloop_transformer import SimpleGateLoopLayer

from tqdm import tqdm

# helper functions

def exists(v):
    """Return True when `v` is anything other than None."""
    return not (v is None)

def default(v, d):
    """Return `v` when it is not None, otherwise the fallback `d`."""
    if exists(v):
        return v
    return d

def first(it):
    """Return the element stored at index 0 of the indexable `it`."""
    head = it[0]
    return head

def divisible_by(num, den):
    """Return whether `num` leaves no remainder when divided by `den`."""
    remainder = num % den
    return remainder == 0

def is_odd(n):
    """Return whether `n` is not divisible by 2."""
    even = divisible_by(n, 2)
    return not even

def is_empty(l):
    """Return whether the sized container `l` has length zero."""
    return not len(l)

def is_tensor_empty(t: Tensor):
    """Return whether tensor `t` holds zero elements."""
    element_count = t.numel()
    return element_count == 0

def set_module_requires_grad_(
    module: Module,
    requires_grad: bool
):
    """Set `requires_grad` on every parameter yielded by `module.parameters()`.

    The parameters are modified in place; nothing is returned.
    """
    for weight in module.parameters():
        weight.requires_grad = requires_grad

def l1norm(t):
    """Normalise `t` along its last dimension with `F.normalize` using the L1 norm (p = 1)."""
    return F.normalize(t, p = 1, dim = -1)

def l2norm(t):
    """Normalise `t` along its last dimension with `F.normalize` using the L2 norm (p = 2)."""
    return F.normalize(t, p = 2, dim = -1)

def safe_cat(tensors, dim):
    """Concatenate the non-None entries of `tensors` along `dim`.

    Returns None when no entry is left after dropping the None ones, the single
    remaining entry itself when exactly one is left, and otherwise the result
    of `torch.cat` over the remaining entries.
    """
    present = [t for t in tensors if exists(t)]

    if not present:
        return None

    if len(present) == 1:
        return first(present)

    return torch.cat(present, dim = dim)

def pad_at_dim(t, padding, dim = -1, value = 0):
    """Pad tensor `t` along a single dimension `dim` with the constant `value`.

    `padding` is unpacked as the padding amounts for `dim` (a two-element
    (before, after) pair in the way `pad_to_length` calls it). `F.pad` lists
    padding starting from the last dimension, so a (0, 0) pair is emitted for
    every dimension that comes after `dim`.
    """
    if dim >= 0:
        trailing_dims = t.ndim - dim - 1
    else:
        trailing_dims = -dim - 1

    untouched = (0, 0) * trailing_dims
    return F.pad(t, (*untouched, *padding), value = value)

def pad_to_length(t, length, dim = -1, value = 0, right = True):
    """Pad `t` along `dim` with `value` until that dimension has size `length`.

    When the dimension is already at least `length` long, `t` is returned
    unchanged. The padding goes after the existing entries when `right` is
    true and before them otherwise.
    """
    shortfall = length - t.shape[dim]

    # nothing to add: hand the input back untouched
    if shortfall <= 0:
        return t

    if right:
        padding = (0, shortfall)
    else:
        padding = (shortfall, 0)

    return pad_at_dim(t, padding, dim = dim, value = value)

# small tolerance constant; a later cell uses it as the clipping margin before arccos
eps = 1e-5

In [45]:
def derive_angle(x, y, eps = 1e-5):
    """Angle, in radians, between `x` and `y` along their last dimension.

    Both inputs are normalised with `l2norm` (each vector is divided by its
    Euclidean length, so the squares of its components sum to 1). Their dot
    product over the last dimension is then the cosine of the angle, and
    `arccos` turns that cosine into the angle.

    `x` and `y` are treated as direction vectors. Passing vertex positions
    (for example a face's coordinates and the same coordinates rotated by one
    vertex) projects every position onto the unit sphere, so the result is the
    angle the two points subtend at the origin, not an interior angle of the
    triangle. For the angle between two triangle edges, pass edge vectors
    (differences of vertex positions) instead.

    Args:
        x, y: tensors whose last dimension holds the vector components; the
            einsum pattern '... d, ... d -> ...' contracts only that dimension.
        eps: margin kept away from -1 and 1 when clipping the cosine.

    Returns:
        Tensor of angles; the last dimension of the inputs is reduced away.
    """
    cosine = einsum('... d, ... d -> ...', l2norm(x), l2norm(y))

    # keep the cosine strictly inside (-1, 1) before taking arccos
    return cosine.clip(-1 + eps, 1 - eps).arccos()

In [46]:
# l2norm of a unit vector along z; printed so the result can be compared with the input.
# Earlier attempts (commented out in the original cell) used plain Python lists and
# integer-valued tensors such as torch.tensor([0,0,1]); the float literal below replaced them.
# NOTE(review): presumably because F.normalize needs a floating-point tensor - confirm.
x = torch.tensor([0.,0.,1.])
print(l2norm(x))

tensor([0., 0., 1.])


In [47]:
# same direction as before but twice the length: the normalised result should not change
x = torch.tensor([0.,0.,2.])
print(l2norm(x))

tensor([0., 0., 1.])


In [48]:
# two equal non-zero components: each should become 1/sqrt(2) after normalising
x = torch.tensor([1.,0.,1.])
print(l2norm(x))

tensor([0.7071, 0.0000, 0.7071])


In [49]:
# dot product of two normalised axis vectors that are perpendicular to each other
x = torch.tensor([0., 0., 1.])
y = torch.tensor([0., 1., 0.])
print(
    einsum('... d, ... d -> ...', l2norm(x), l2norm(y))
)

tensor(0.)


In [50]:
# two sample faces, shape (1, 2, 3, 3): three vertices per face, xyz per vertex
face_coords = torch.tensor([[
    [[-0.4410, -0.0583, -0.1358],
     [-0.4377, -0.0619, -0.1303],
     [-0.4430, -0.0572, -0.1290]],
    [[ 0.4457, -0.0392,  0.0039],
     [ 0.4444, -0.0349,  0.0022],
     [ 0.4439, -0.0353,  0.0067]],
]])

# every face's vertices rotated by one slot (last vertex moves to the front)
shifted_face_coords = torch.roll(face_coords, shifts = 1, dims = 2)

# vertex *positions* are normalised here, so this is the cosine of the angle
# that neighbouring vertices subtend at the origin
z = einsum('... d, ... d -> ...', l2norm(face_coords), l2norm(shifted_face_coords))
print(z)

tensor([[[0.9999, 0.9999, 0.9999],
         [0.9999, 0.9999, 0.9999]]])


In [51]:
# angle (radians) for the clipped cosines computed from vertex positions
torch.arccos(z.clip(-1 + eps, 1 - eps))


tensor([[[0.0154, 0.0130, 0.0128],
         [0.0105, 0.0101, 0.0102]]])

In [52]:
# two sample faces, shape (1, 2, 3, 3): three vertices per face, xyz per vertex
face_coords = torch.tensor([[
    [[-0.4410, -0.0583, -0.1358],
     [-0.4377, -0.0619, -0.1303],
     [-0.4430, -0.0572, -0.1290]],
    [[ 0.4457, -0.0392,  0.0039],
     [ 0.4444, -0.0349,  0.0022],
     [ 0.4439, -0.0353,  0.0067]],
]])

# every face's vertices rotated by one slot (last vertex moves to the front)
shifted_face_coords = torch.roll(face_coords, shifts = 1, dims = 2)

# the three edge vectors of each face: vertex i minus the vertex before it
edge_vector = face_coords - shifted_face_coords
print(edge_vector)

tensor([[[[ 0.0020, -0.0011, -0.0068],
          [ 0.0033, -0.0036,  0.0055],
          [-0.0053,  0.0047,  0.0013]],

         [[ 0.0018, -0.0039, -0.0028],
          [-0.0013,  0.0043, -0.0017],
          [-0.0005, -0.0004,  0.0045]]]])


In [53]:
normv = l2norm(edge_vector) # unit-length (normalised) vectors of the three edges
print(normv)

tensor([[[[ 0.2788, -0.1534, -0.9480],
          [ 0.4487, -0.4894,  0.7478],
          [-0.7359,  0.6526,  0.1805]],

         [[ 0.3511, -0.7606, -0.5461],
          [-0.2707,  0.8953, -0.3539],
          [-0.1100, -0.0880,  0.9900]]]])


In [54]:
# normalised edge vectors of the first sample face, shape (1, 1, 3, 3)
normv1 = torch.tensor([[[
    [ 0.2788, -0.1534, -0.9480],
    [ 0.4487, -0.4894,  0.7478],
    [-0.7359,  0.6526,  0.1805],
]]])

# 'abcd,abed->abe' sums over both c and d, so entry e is
# (sum of all rows) . row_e - not a dot product between one pair of rows
dot_products = torch.einsum('abcd,abed->abe', normv1, normv1)
print(dot_products)

# Euclidean length of every row along the xyz axis
magnitudes = normv1.norm(dim = 3)
print(magnitudes)

tensor([[[ 0.0148, -0.0233,  0.0090]]])
tensor([[[1.0000, 1.0000, 1.0000]]])


In [55]:
# three mutually perpendicular vectors with lengths 1, 2 and 3
normv0 = torch.tensor([[[
    [0., 0., 1.],
    [0., 2., 0.],
    [3., 0., 0.],
]]])

# same over-contracted pattern: entry e is (sum of all rows) . row_e, which for
# these perpendicular rows equals |row_e|^2 -> tensor([[[1., 4., 9.]]])
torch.einsum('abcd,abed->abe', normv0, normv0)


tensor([[[1., 4., 9.]]])

In [56]:
# rotate the three rows by one slot so that row i lines up with row i-1
normv00 = torch.roll(normv0, shifts = 1, dims = 2)
print(normv00)
# expected print:
# tensor([[[[3., 0., 0.],
#           [0., 0., 1.],
#           [0., 2., 0.]]]])

# row-by-row dot product (only d is contracted). tensor([[[0., 0., 0.]]]) is the
# correct answer here because every pair of rows is perpendicular
torch.einsum('abcd,abcd->abc', normv0, normv00)

tensor([[[[3., 0., 0.],
          [0., 0., 1.],
          [0., 2., 0.]]]])


tensor([[[0., 0., 0.]]])

In [57]:
# counter-example: with 'abcd,abed->abe' the index c is summed away as well,
# so this is not a row-by-row dot product
torch.einsum('abcd,abed->abe', normv0, normv00) # tensor([[[9., 1., 4.]]]) - wrong result

tensor([[[9., 1., 4.]]])

In [58]:
# self-contained rerun of the row-by-row dot product on the perpendicular vectors
normv0 = torch.tensor([[[
    [0., 0., 1.],
    [0., 2., 0.],
    [3., 0., 0.],
]]])
# rows rotated by one slot (last row first)
normv00 = torch.roll(normv0, shifts = 1, dims = 2)

torch.einsum('abcd,abcd->abc', normv0, normv00)

tensor([[[0., 0., 0.]]])

In [59]:
# same computation spelled with different index letters: the letters themselves are arbitrary
normv0 = torch.tensor([[[
    [0., 0., 1.],
    [0., 2., 0.],
    [3., 0., 0.],
]]])
# rows rotated by one slot (last row first)
normv00 = torch.roll(normv0, shifts = 1, dims = 2)

torch.einsum('efgh,efgh->efg', normv0, normv00)

tensor([[[0., 0., 0.]]])

In [60]:
# first row is no longer perpendicular to the last one, so its dot product is non-zero (1 * 3)
normv0 = torch.tensor([[[
    [1., 0., 1.],
    [0., 2., 0.],
    [3., 0., 0.],
]]])
# rows rotated by one slot (last row first)
normv00 = torch.roll(normv0, shifts = 1, dims = 2)

torch.einsum('abcd,abcd->abc', normv0, normv00)

tensor([[[3., 0., 0.]]])

In [61]:
# ellipsis form of the row-by-row dot product: only the last axis is named and contracted
normv0 = torch.tensor([[[
    [1., 0., 1.],
    [0., 2., 0.],
    [3., 0., 0.],
]]])
# rows rotated by one slot (last row first)
normv00 = torch.roll(normv0, shifts = 1, dims = 2)

torch.einsum('...d,...d->...', normv0, normv00)

tensor([[[3., 0., 0.]]])

In [63]:
# two sample faces, shape (1, 2, 3, 3): three vertices per face, xyz per vertex
face_coords = torch.tensor([[
    [[-0.4410, -0.0583, -0.1358],
     [-0.4377, -0.0619, -0.1303],
     [-0.4430, -0.0572, -0.1290]],
    [[ 0.4457, -0.0392,  0.0039],
     [ 0.4444, -0.0349,  0.0022],
     [ 0.4439, -0.0353,  0.0067]],
]])

# every face's vertices rotated by one slot (last vertex moves to the front)
shifted_face_coords = torch.roll(face_coords, shifts = 1, dims = 2)

# the three edge vectors of each face: vertex i minus the vertex before it
edge_vector = face_coords - shifted_face_coords
print('边矢量',edge_vector)

# unit-length direction of every edge
normv = l2norm(edge_vector)
print('归一化边矢量',normv)

# dot product of each unit edge with the previous unit edge of the same face
guiyiprodot = torch.einsum('abcd,abcd->abc', normv, torch.roll(normv, shifts = 1, dims = 2))
print('归一化边矢量之间的点积',guiyiprodot)

# consecutive edges are chained head-to-tail, so this is the angle between their
# directions (the exterior angle of the triangle), in radians
torch.arccos(guiyiprodot)

边矢量 tensor([[[[ 0.0020, -0.0011, -0.0068],
          [ 0.0033, -0.0036,  0.0055],
          [-0.0053,  0.0047,  0.0013]],

         [[ 0.0018, -0.0039, -0.0028],
          [-0.0013,  0.0043, -0.0017],
          [-0.0005, -0.0004,  0.0045]]]])
归一化边矢量 tensor([[[[ 0.2788, -0.1534, -0.9480],
          [ 0.4487, -0.4894,  0.7478],
          [-0.7359,  0.6526,  0.1805]],

         [[ 0.3511, -0.7606, -0.5461],
          [-0.2707,  0.8953, -0.3539],
          [-0.1100, -0.0880,  0.9900]]]])
归一化边矢量之间的点积 tensor([[[-0.4764, -0.5087, -0.5146],
         [-0.5123, -0.5827, -0.3994]]])


tensor([[[2.0673, 2.1045, 2.1113],
         [2.1087, 2.1928, 1.9817]]])

In [64]:
# flip the sign of the dot products: cos(pi - t) = -cos(t), so this is the cosine
# of the supplementary angle
-guiyiprodot

tensor([[[0.4764, 0.5087, 0.5146],
         [0.5123, 0.5827, 0.3994]]])

In [65]:
# Angles in degrees for the negated dot products.
# clamp keeps rounding from pushing a cosine outside arccos' domain (which would give NaN);
# rad2deg converts the whole tensor at once and works for any number of dimensions,
# unlike a nested Python loop over exactly three levels.
torch.rad2deg((-guiyiprodot).clamp(-1., 1.).arccos())

tensor([[[61.5503, 59.4203, 59.0294],
         [59.1812, 54.3605, 66.4582]]])

In [66]:
# Interior angles of every triangle, computed from its edge vectors.

# two sample faces, shape (1, 2, 3, 3): three vertices per face, xyz per vertex
face_coords = torch.tensor([[
    [[-0.4410, -0.0583, -0.1358],
     [-0.4377, -0.0619, -0.1303],
     [-0.4430, -0.0572, -0.1290]],
    [[ 0.4457, -0.0392,  0.0039],
     [ 0.4444, -0.0349,  0.0022],
     [ 0.4439, -0.0353,  0.0067]],
]])

# every face's vertices rotated by one slot (last vertex moves to the front);
# derived from face_coords so the two tensors cannot drift apart
shifted_face_coords = torch.roll(face_coords, shifts = 1, dims = -2)

# the three edge vectors of each face: vertex i minus the vertex before it
edge_vector = face_coords - shifted_face_coords
print('边矢量：',edge_vector)

# unit-length direction of every edge
normv = l2norm(edge_vector)
print('归一化边矢量：',normv)

# consecutive edges are chained head-to-tail, so their dot product is the cosine of
# the exterior angle; negating it gives the cosine of the interior angle.
# '... d' contracts only the xyz axis, whatever the number of leading dimensions.
guiyiprodot = -torch.einsum('... d, ... d -> ...', normv, torch.roll(normv, shifts = 1, dims = -2))
print('归一化边矢量之间的锐角点积：',guiyiprodot)

# clamp so rounding can never push a cosine outside arccos' domain (which would give NaN)
hudu = guiyiprodot.clamp(-1., 1.).arccos()
print("弧度结果：",hudu)

# radians -> degrees for the whole tensor at once
jiaodu = torch.rad2deg(hudu)
print("角度结果：",jiaodu)

边矢量 tensor([[[[ 0.0020, -0.0011, -0.0068],
          [ 0.0033, -0.0036,  0.0055],
          [-0.0053,  0.0047,  0.0013]],

         [[ 0.0018, -0.0039, -0.0028],
          [-0.0013,  0.0043, -0.0017],
          [-0.0005, -0.0004,  0.0045]]]])
归一化边矢量 tensor([[[[ 0.2788, -0.1534, -0.9480],
          [ 0.4487, -0.4894,  0.7478],
          [-0.7359,  0.6526,  0.1805]],

         [[ 0.3511, -0.7606, -0.5461],
          [-0.2707,  0.8953, -0.3539],
          [-0.1100, -0.0880,  0.9900]]]])
归一化边矢量之间的点积 tensor([[[0.4764, 0.5087, 0.5146],
         [0.5123, 0.5827, 0.3994]]])
弧度结果： tensor([[[1.0743, 1.0371, 1.0303],
         [1.0329, 0.9488, 1.1599]]])
角度结果： tensor([[[61.5503, 59.4203, 59.0294],
         [59.1812, 54.3605, 66.4582]]])
