In [2]:
import torch
import torch.nn as nn

from safetensors.torch import load_file

# Read the T3 checkpoint from disk; the result is used below as a mapping
# from parameter name to tensor (iterated with .items()).
tensor_dict = load_file("weights/t3_cfg.safetensors")

In [3]:
# List every entry of the checkpoint as "<parameter name> <shape>".
for param_name, param_tensor in tensor_dict.items():
    print(param_name, param_tensor.shape)

cond_enc.emotion_adv_fc.weight torch.Size([1024, 1])
cond_enc.perceiver.attn.norm.bias torch.Size([1024])
cond_enc.perceiver.attn.norm.weight torch.Size([1024])
cond_enc.perceiver.attn.proj_out.bias torch.Size([1024])
cond_enc.perceiver.attn.proj_out.weight torch.Size([1024, 1024])
cond_enc.perceiver.attn.to_k.bias torch.Size([1024])
cond_enc.perceiver.attn.to_k.weight torch.Size([1024, 1024])
cond_enc.perceiver.attn.to_q.bias torch.Size([1024])
cond_enc.perceiver.attn.to_q.weight torch.Size([1024, 1024])
cond_enc.perceiver.attn.to_v.bias torch.Size([1024])
cond_enc.perceiver.attn.to_v.weight torch.Size([1024, 1024])
cond_enc.perceiver.pre_attention_query torch.Size([1, 32, 1024])
cond_enc.spkr_enc.bias torch.Size([1024])
cond_enc.spkr_enc.weight torch.Size([1024, 256])
speech_emb.weight torch.Size([8194, 1024])
speech_head.weight torch.Size([8194, 1024])
speech_pos_emb.emb.weight torch.Size([4100, 1024])
text_emb.weight torch.Size([704, 1024])
text_head.weight torch.Size([704, 1024])


In [4]:
# Total number of entries (tensors) stored in the checkpoint.
len(tensor_dict)

292

In [5]:
# Dump every checkpoint key with its shape to t3_cfg.txt, one "key: shape" entry per line.
import os
with open("t3_cfg.txt", "wb") as out_file:
    # Binary mode + explicit encode keeps the line endings identical on every platform.
    out_file.writelines(
        f"{name}: {tensor.shape}\n".encode()
        for name, tensor in tensor_dict.items()
    )

In [12]:
# Hyper-parameters used to build the transformer backbone.
# d_model and the two vocabulary sizes match the embedding shapes printed from
# the checkpoint above (text_emb: [704, 1024], speech_emb: [8194, 1024]).
config = dict(
    d_model=1024,
    # Not given in the paper; 16 was chosen here, 8 or 32 can also be tried.
    n_head=16,
    n_layer=30,
    mlp_inner_dim=4096,
    text_vocab_size=704,
    text_max_pos=2050,
    speech_vocab_size=8194,
)

In [13]:
class MultiHeadAttention(nn.Module):
    """Bias-free multi-head self-attention.

    Queries, keys and values are all projected from the same input. No
    attention mask is applied, so every position attends to every position.

    Args:
        d_model: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        AssertionError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model=1024, num_heads=16):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert self.head_dim * num_heads == d_model, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, C) with C equal to ``d_model``.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Split the feature dimension into heads: (B, T, C) -> (B, heads, T, head_dim).
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention, computed independently per head.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)
        attn_out = torch.matmul(attn, v)

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, C).
        # transpose() returns a non-contiguous tensor, on which .view() raises
        # for T > 1 with several heads; reshape() copies when it has to.
        attn_out = attn_out.transpose(1, 2).reshape(B, T, C)

        return self.o_proj(attn_out)

In [14]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: self-attention, then a SiLU-gated MLP.

    Each sub-layer reads a LayerNorm'd copy of its input and adds its result
    back onto the residual stream.

    Args:
        d_model: Feature size of the residual stream.
        num_heads: Number of heads handed to the attention module.
        mlp_inner_dim: Hidden size of the gated MLP.
    """

    def __init__(self, d_model=1024, num_heads=16, mlp_inner_dim=4096):
        super().__init__()
        # Attention sub-layer.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

        # Gated-MLP sub-layer.
        self.ln2 = nn.LayerNorm(d_model)
        self.gate_proj = nn.Linear(d_model, mlp_inner_dim)
        self.up_proj = nn.Linear(d_model, mlp_inner_dim)
        self.down_proj = nn.Linear(mlp_inner_dim, d_model)

    def forward(self, x):
        """Run attention and the gated MLP, each with a residual connection."""
        after_attn = x + self.attn(self.ln1(x))

        normed = self.ln2(after_attn)
        gate_act = torch.nn.functional.silu(self.gate_proj(normed))
        mlp_out = self.down_proj(gate_act * self.up_proj(normed))

        return after_attn + mlp_out


class TransformerBackbone(nn.Module):
    """Text embeddings, a stack of TransformerBlocks, and a final LayerNorm.

    Args:
        n_layers: Number of TransformerBlocks in the stack.
        d_model: Feature size used throughout the model.
        mlp_inner_dim: Hidden size of each block's MLP.
        text_vocab_size: Number of rows in the text token embedding table.
        max_pos_text: Number of rows in the text position embedding table.
        num_heads: Number of attention heads per block.
    """

    def __init__(self, n_layers=30, d_model=1024, mlp_inner_dim=4096, text_vocab_size=704, max_pos_text=2050, num_heads=16):
        super().__init__()
        self.text_emb = nn.Embedding(text_vocab_size, d_model)
        self.text_pos_emb = nn.Embedding(max_pos_text, d_model)

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                TransformerBlock(d_model, num_heads=num_heads, mlp_inner_dim=mlp_inner_dim)
            )
        self.layers = nn.ModuleList(blocks)

        self.final_ln = nn.LayerNorm(d_model)

    def forward_text(self, token_ids):
        """Encode a 2-D batch of text token ids.

        Token and position embeddings are summed, passed through every block
        in order, and normalised by the final LayerNorm.
        """
        batch_size, seq_len = token_ids.shape
        # Position indices 0..seq_len-1, repeated for each row of the batch.
        positions = torch.arange(seq_len, device=token_ids.device)[None, :].expand(batch_size, seq_len)

        hidden = self.text_emb(token_ids) + self.text_pos_emb(positions)
        for block in self.layers:
            hidden = block(hidden)
        return self.final_ln(hidden)

In [15]:
# Build the backbone entirely from `config`. n_head is passed explicitly so
# that editing the config actually changes the model instead of silently
# falling back to the constructor's default head count.
t3_transformer_part = TransformerBackbone(
    n_layers=config["n_layer"],
    d_model=config["d_model"],
    mlp_inner_dim=config["mlp_inner_dim"],
    text_vocab_size=config["text_vocab_size"],
    max_pos_text=config["text_max_pos"],
    num_heads=config["n_head"],
)

In [16]:
# Show the module tree of the freshly constructed backbone.
print(t3_transformer_part)

TransformerBackbone(
  (text_emb): Embedding(704, 1024)
  (text_pos_emb): Embedding(2050, 1024)
  (layers): ModuleList(
    (0-29): 30 x TransformerBlock(
      (ln1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (v_proj): Linear(in_features=1024, out_features=1024, bias=False)
        (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
      )
      (ln2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (gate_proj): Linear(in_features=1024, out_features=4096, bias=True)
      (up_proj): Linear(in_features=1024, out_features=4096, bias=True)
      (down_proj): Linear(in_features=4096, out_features=1024, bias=True)
    )
  )
  (final_ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)


In [17]:
# List each parameter name and shape of the newly built backbone, in the same
# "<name> <shape>" format used for the checkpoint listing above.
for k, v in t3_transformer_part.state_dict().items():
    print(k, v.shape)

text_emb.weight torch.Size([704, 1024])
text_pos_emb.weight torch.Size([2050, 1024])
layers.0.ln1.weight torch.Size([1024])
layers.0.ln1.bias torch.Size([1024])
layers.0.attn.q_proj.weight torch.Size([1024, 1024])
layers.0.attn.k_proj.weight torch.Size([1024, 1024])
layers.0.attn.v_proj.weight torch.Size([1024, 1024])
layers.0.attn.o_proj.weight torch.Size([1024, 1024])
layers.0.ln2.weight torch.Size([1024])
layers.0.ln2.bias torch.Size([1024])
layers.0.gate_proj.weight torch.Size([4096, 1024])
layers.0.gate_proj.bias torch.Size([4096])
layers.0.up_proj.weight torch.Size([4096, 1024])
layers.0.up_proj.bias torch.Size([4096])
layers.0.down_proj.weight torch.Size([1024, 4096])
layers.0.down_proj.bias torch.Size([1024])
layers.1.ln1.weight torch.Size([1024])
layers.1.ln1.bias torch.Size([1024])
layers.1.attn.q_proj.weight torch.Size([1024, 1024])
layers.1.attn.k_proj.weight torch.Size([1024, 1024])
layers.1.attn.v_proj.weight torch.Size([1024, 1024])
layers.1.attn.o_proj.weight torch.Size