In [1]:
import torch

class Embed(torch.nn.Module):

    def __init__(self):
        super().__init__()
        
        
        self.embed = torch.nn.Embedding(30001, 768)
        self.pos_embed = torch.nn.Embedding(4, 768)

        self.register_buffer('pos_ids', torch.arange(4).unsqueeze(dim=0))

    def forward(self, input_ids):
        #input_ids -> [b, 77]
    
        #[b, 77] -> [b, 77, 768]
        embed = self.embed(input_ids)

        #[1, 77] -> [1, 77, 768]
        pos_embed = self.pos_embed(self.pos_ids)

        #[b, 77, 768]
        return embed + pos_embed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Atten(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(768, 768)
        self.k = torch.nn.Linear(768, 768)
        self.v = torch.nn.Linear(768, 768)
        self.out = torch.nn.Linear(768, 768)

    def forward(self, x):
        #x -> [b, 77, 768]

        b = x.shape[0]

        #维度不变
        #[b, 77, 768]
        q = self.q(x) * 0.125
        k = self.k(x)
        v = self.v(x)

        #拆分注意力头
        #[b, 77, 768] -> [b, 77, 12, 64] -> [b, 12, 77, 64] -> [b*12, 77, 64]
        q = q.reshape(b, 4, 12, 64).transpose(1, 2).reshape(b * 12, 4, 64)
        k = k.reshape(b, 4, 12, 64).transpose(1, 2).reshape(b * 12, 4, 64)
        v = v.reshape(b, 4, 12, 64).transpose(1, 2).reshape(b * 12, 4, 64)

        #计算qk乘积
        #[b*12, 77, 64] * [b*12, 64, 77] -> [b*12, 77, 77]
        attn = torch.bmm(q, k.transpose(1, 2))

        #[b*12, 77, 77] -> [b, 12, 77, 77]
        attn = attn.reshape(b, 12, 4, 4)

        #覆盖mask
        def get_mask(b):
            mask = torch.empty(b, 4, 4)

            #上三角的部分置为负无穷
            mask.fill_(-float('inf'))

            #对角线和以下的位置为0
            mask.triu_(1)

            return mask.unsqueeze(1)

        #[b, 12, 77, 77] + [b, 1, 77, 77] -> [b, 12, 77, 77]
        attn = attn + get_mask(attn.shape[0]).to(attn.device)

        #[b, 12, 77, 77] -> [b*12, 77, 77]
        attn = attn.reshape(b * 12, 4, 4)

        #计算softmax,被mask的部分值为0
        attn = attn.softmax(dim=-1)

        #计算和v的乘积
        #[b*12, 77, 77] * [b*12, 77, 64] -> [b*12, 77, 64]
        attn = torch.bmm(attn, v)

        #[b*12, 77, 64] -> [b, 12, 77, 64] -> [b, 77, 12, 64] -> [b, 77, 768]
        attn = attn.reshape(b, 12, 4, 64).transpose(1, 2).reshape(b, 4, 768)

        #线性输出,维度不变
        #[b, 77, 768]
        return self.out(attn)


Atten()(torch.randn(2, 4, 768)).shape

torch.Size([2, 4, 768])

In [3]:
class ClipEncoder(torch.nn.Module):

    def __init__(self):
        super().__init__()

        self.s1 = torch.nn.Sequential(
            torch.nn.LayerNorm(768),
            Atten(),
        )

        self.s2 = torch.nn.Sequential(
            torch.nn.LayerNorm(768),
            torch.nn.Linear(768, 768),
        )

        self.s3 = torch.nn.Linear(768, 768)

    def forward(self, x):
        #x -> [2, 77, 768]

        #维度不变
        #[2, 77, 768]
        x = x + self.s1(x)

        #[2, 77, 768]
        res = x

        #[2, 77, 768] -> [2, 77, 3072]
        x = self.s2(x)

        #维度不变
        #[2, 77, 3072]
        x = x * (x * 1.702).sigmoid()

        #[2, 77, 3072] -> [2, 77, 768]
        return res + self.s3(x)


ClipEncoder()(torch.randn(2, 4, 768)).shape

torch.Size([2, 4, 768])

In [4]:
#经过优化之后的代码量少得吓人...
encoder = torch.nn.Sequential(
    Embed(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    ClipEncoder(),
    torch.nn.LayerNorm(768),
)

encoder(torch.ones(2, 4).long()).shape

torch.Size([2, 4, 768])

In [5]:
# 加载模型权重，并调整为单GPU或CPU格式
def load_model(model, model_path):
    # 加载保存的状态字典
    state_dict = torch.load(model_path)


    # 加载调整后的状态字典
    model.load_state_dict(state_dict)

# 加载模型并进行推断
model_path = "/data/run01/scz0ruj/model/new_encoder_model_parameters18800last.pth"
load_model(encoder, model_path)
encoder.eval()

Sequential(
  (0): Embed(
    (embed): Embedding(30001, 768)
    (pos_embed): Embedding(4, 768)
  )
  (1): ClipEncoder(
    (s1): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Atten(
        (q): Linear(in_features=768, out_features=768, bias=True)
        (k): Linear(in_features=768, out_features=768, bias=True)
        (v): Linear(in_features=768, out_features=768, bias=True)
        (out): Linear(in_features=768, out_features=768, bias=True)
      )
    )
    (s2): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Linear(in_features=768, out_features=768, bias=True)
    )
    (s3): Linear(in_features=768, out_features=768, bias=True)
  )
  (2): ClipEncoder(
    (s1): Sequential(
      (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (1): Atten(
        (q): Linear(in_features=768, out_features=768, bias=True)
        (k): Linear(in_features=768, out_features=768, bias=True)
        (v): L

In [6]:
text =[300,200,0.2,0.2]
text = torch.tensor(text, dtype=torch.float32)  # 确保数据类型一致

# 定义条件均值和标准差
condition_mean = torch.tensor([3.84952491e+02 ,2.43981134e+02, 1.12512536e+00 ,1.84848683e-01]).to(device)
condition_std = torch.tensor([1.19021432e+02, 1.23790469e+02, 1.13191379e+00 ,2.54632684e-02]).to(device)

text = text.to(device)
text=(text - condition_mean) / condition_std
print(text)

num_buckets = 30000
bins = torch.linspace(-2, 2, steps=num_buckets + 1, device='cuda')
#
# 映射数据到桶
indices = torch.bucketize(text, bins) - 1
indices[indices < 0] = 0 
print(indices)


#[1, 77, 768]
pos = encoder(indices)
print(pos)

NameError: name 'device' is not defined