In [1]:
def copy_attention_weights(custom_attn, torch_attn):
    """Load self-attention parameters from a PyTorch layer into a custom attention module.

    NOTE(review): this name is defined a second time later in the same cell;
    the later definition is the one in effect once the cell has run.

    Args:
        custom_attn: Object exposing ``w_q``, ``w_k``, ``w_v`` and ``w_o``
            projections, each with a ``weight`` and optionally a ``bias``.
        torch_attn: PyTorch layer whose ``self_attn`` attribute carries the
            packed ``in_proj_weight`` / ``in_proj_bias`` and an ``out_proj``.

    Returns:
        None. ``custom_attn`` is updated in place.
    """
    mha = torch_attn.self_attn
    packed_weight = mha.in_proj_weight
    packed_bias = mha.in_proj_bias
    d_model = packed_weight.shape[1]

    # The packed projection stacks the Q, K and V blocks row-wise, d_model rows each.
    projections = (custom_attn.w_q, custom_attn.w_k, custom_attn.w_v)
    for index, projection in enumerate(projections):
        rows = slice(index * d_model, (index + 1) * d_model)
        projection.weight.data.copy_(packed_weight[rows])
        # Biases are optional on both sides; copy only when both exist.
        if packed_bias is not None and getattr(projection, "bias", None) is not None:
            projection.bias.data.copy_(packed_bias[rows])

    # Output projection
    custom_attn.w_o.weight.data.copy_(mha.out_proj.weight)
    out_bias = getattr(mha.out_proj, "bias", None)
    if out_bias is not None and getattr(custom_attn.w_o, "bias", None) is not None:
        custom_attn.w_o.bias.data.copy_(out_bias)

def copy_ffn_weights(custom_ffn, torch_ffn):
    """Copy both feed-forward linear layers (weight and bias) into the custom FFN.

    NOTE(review): this name is defined again later in the same cell; the later
    definition is the one in effect once the cell has run.

    Args:
        custom_ffn: Object with ``linear_1`` and ``linear_2`` sub-layers.
        torch_ffn: Object with ``linear1`` and ``linear2`` sub-layers.

    Returns:
        None. ``custom_ffn`` is updated in place.
    """
    name_pairs = (("linear_1", "linear1"), ("linear_2", "linear2"))
    for custom_name, torch_name in name_pairs:
        target = getattr(custom_ffn, custom_name)
        source = getattr(torch_ffn, torch_name)
        target.weight.data.copy_(source.weight)
        target.bias.data.copy_(source.bias)

def copy_layer_norm(custom_ln, torch_ln):
    """Copy the ``weight`` and ``bias`` of one normalisation layer into another.

    NOTE(review): this name is defined again later in the same cell; the later
    definition is the one in effect once the cell has run.

    Args:
        custom_ln: Destination object with ``weight`` and ``bias`` tensors.
        torch_ln: Source object with ``weight`` and ``bias`` tensors.

    Returns:
        None. ``custom_ln`` is updated in place.
    """
    for param_name in ("weight", "bias"):
        source = getattr(torch_ln, param_name)
        getattr(custom_ln, param_name).data.copy_(source)

def copy_encoder_weights(custom_encoder, torch_encoder):
    """Copy per-layer encoder parameters from a PyTorch encoder into the custom one.

    Layers are paired positionally with ``zip``, so any surplus layers on either
    side are silently skipped. Only the per-layer attention, feed-forward and
    ``norm1`` / ``norm2`` parameters are handled here.

    NOTE(review): this name is defined again later in the same cell; the later
    definition is the one in effect once the cell has run.

    Args:
        custom_encoder: Object exposing an ``encoder_layers`` iterable.
        torch_encoder: Object exposing a ``layers`` iterable.

    Returns:
        None. ``custom_encoder`` is updated in place.
    """
    layer_pairs = zip(custom_encoder.encoder_layers, torch_encoder.layers)
    for custom_layer, torch_layer in layer_pairs:
        # The attention and FFN helpers receive the whole PyTorch layer and
        # pick out the sub-modules they need themselves.
        copy_attention_weights(custom_layer.multi_headed_self_attention, torch_layer)
        copy_ffn_weights(custom_layer.ffn, torch_layer)
        for norm_name in ("norm1", "norm2"):
            copy_layer_norm(getattr(custom_layer, norm_name), getattr(torch_layer, norm_name))

def copy_decoder_weights(custom_decoder, torch_decoder):
    """Copy per-layer decoder parameters from a PyTorch decoder into the custom one.

    For every positionally paired layer this loads the masked self-attention,
    the cross-attention, the feed-forward block and the three layer norms.
    Layers are paired with ``zip``, so surplus layers on either side are
    silently skipped.

    NOTE(review): this name is defined again later in the same cell; the later
    definition is the one in effect once the cell has run.

    Args:
        custom_decoder: Object exposing a ``decoder_layers`` iterable.
        torch_decoder: Object exposing a ``layers`` iterable whose items carry
            ``self_attn``, ``multihead_attn``, ``linear1``, ``linear2`` and
            ``norm1`` .. ``norm3``.

    Returns:
        None. ``custom_decoder`` is updated in place.
    """

    def _load_attention(custom_attn, torch_mha):
        # Split the packed in-projection into its Q, K and V row blocks.
        packed_weight = torch_mha.in_proj_weight
        packed_bias = torch_mha.in_proj_bias
        d_model = packed_weight.shape[1]
        projections = (custom_attn.w_q, custom_attn.w_k, custom_attn.w_v)
        for index, projection in enumerate(projections):
            rows = slice(index * d_model, (index + 1) * d_model)
            projection.weight.data.copy_(packed_weight[rows])
            # Biases are optional on both sides; copy only when both exist.
            if packed_bias is not None and getattr(projection, "bias", None) is not None:
                projection.bias.data.copy_(packed_bias[rows])
        custom_attn.w_o.weight.data.copy_(torch_mha.out_proj.weight)
        out_bias = getattr(torch_mha.out_proj, "bias", None)
        if out_bias is not None and getattr(custom_attn.w_o, "bias", None) is not None:
            custom_attn.w_o.bias.data.copy_(out_bias)

    for c_layer, t_layer in zip(custom_decoder.decoder_layers, torch_decoder.layers):
        # Masked self-attention
        _load_attention(c_layer.masked_multi_headed_self_attention, t_layer.self_attn)

        # Cross-attention
        _load_attention(c_layer.multi_headed_cross_attention, t_layer.multihead_attn)

        # FFN and LayerNorm. copy_ffn_weights reads `.linear1` / `.linear2` from
        # its second argument, so it must receive the layer, not `t_layer.linear1`.
        copy_ffn_weights(c_layer.ffn, t_layer)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    """Copy encoder weights, then decoder weights, from the PyTorch model into the custom one.

    NOTE(review): this name is defined again later in the same cell; the later
    definition is the one in effect once the cell has run.

    Args:
        custom_transformer: Destination object with ``encoder`` and ``decoder``.
        torch_transformer: Source object with ``encoder`` and ``decoder``.

    Returns:
        None. ``custom_transformer`` is updated in place.
    """
    custom_encoder, torch_encoder = custom_transformer.encoder, torch_transformer.encoder
    copy_encoder_weights(custom_encoder, torch_encoder)

    custom_decoder, torch_decoder = custom_transformer.decoder, torch_transformer.decoder
    copy_decoder_weights(custom_decoder, torch_decoder)
def copy_attention_weights(custom_attn, torch_attn):
    """Load self-attention parameters from a PyTorch layer into a custom attention module.

    The packed in-projection of ``torch_attn.self_attn`` is split into its
    query, key and value parts, and the output projection is copied as is.
    Biases are copied as well whenever both the source and the destination
    projection have one.

    Args:
        custom_attn: Object exposing ``w_q``, ``w_k``, ``w_v`` and ``w_o``
            projections, each with a ``weight`` and optionally a ``bias``.
        torch_attn: PyTorch layer whose ``self_attn`` attribute carries the
            packed ``in_proj_weight`` / ``in_proj_bias`` and an ``out_proj``.

    Returns:
        None. ``custom_attn`` is updated in place.
    """
    source = torch_attn.self_attn
    qkv_weight = source.in_proj_weight
    qkv_bias = source.in_proj_bias
    d_model = qkv_weight.shape[1]

    # Q, K and V occupy consecutive row blocks of d_model rows each.
    for block, target in enumerate((custom_attn.w_q, custom_attn.w_k, custom_attn.w_v)):
        start = block * d_model
        stop = start + d_model
        target.weight.data.copy_(qkv_weight[start:stop])
        # A bias is copied only when both the source and the target carry one.
        if qkv_bias is not None and getattr(target, "bias", None) is not None:
            target.bias.data.copy_(qkv_bias[start:stop])

    # Output projection
    out_proj = source.out_proj
    custom_attn.w_o.weight.data.copy_(out_proj.weight)
    if getattr(out_proj, "bias", None) is not None and getattr(custom_attn.w_o, "bias", None) is not None:
        custom_attn.w_o.bias.data.copy_(out_proj.bias)

def copy_ffn_weights(custom_ffn, torch_ffn):
    """Copy the two feed-forward linear layers into the custom FFN.

    Args:
        custom_ffn: Destination object with ``linear_1`` and ``linear_2``.
        torch_ffn: Source object with ``linear1`` and ``linear2``; callers in
            this file pass the whole PyTorch layer.

    Returns:
        None. ``custom_ffn`` is updated in place.
    """
    # First projection
    first_target = custom_ffn.linear_1
    first_source = torch_ffn.linear1
    first_target.weight.data.copy_(first_source.weight)
    first_target.bias.data.copy_(first_source.bias)

    # Second projection
    second_target = custom_ffn.linear_2
    second_source = torch_ffn.linear2
    second_target.weight.data.copy_(second_source.weight)
    second_target.bias.data.copy_(second_source.bias)

def copy_layer_norm(custom_ln, torch_ln):
    """Copy a normalisation layer's ``weight`` and ``bias`` into ``custom_ln``.

    Args:
        custom_ln: Destination object with ``weight`` and ``bias`` tensors.
        torch_ln: Source object with ``weight`` and ``bias`` tensors.

    Returns:
        None. ``custom_ln`` is updated in place.
    """
    scale, shift = custom_ln.weight, custom_ln.bias
    scale.data.copy_(torch_ln.weight)
    shift.data.copy_(torch_ln.bias)

def copy_encoder_weights(custom_encoder, torch_encoder):
    """Copy per-layer encoder parameters from a PyTorch encoder into the custom one.

    Layers are matched by position with ``zip``; if the two encoders differ in
    depth, the extra layers are silently left untouched. Only each layer's
    self-attention, feed-forward block and ``norm1`` / ``norm2`` are copied.

    Args:
        custom_encoder: Object exposing an ``encoder_layers`` iterable.
        torch_encoder: Object exposing a ``layers`` iterable.

    Returns:
        None. ``custom_encoder`` is updated in place.
    """
    custom_layers = custom_encoder.encoder_layers
    torch_layers = torch_encoder.layers
    for dst_layer, src_layer in zip(custom_layers, torch_layers):
        # Both helpers below take the full PyTorch layer and look up
        # `self_attn` / `linear1` / `linear2` on it themselves.
        copy_attention_weights(dst_layer.multi_headed_self_attention, src_layer)
        copy_ffn_weights(dst_layer.ffn, src_layer)
        copy_layer_norm(dst_layer.norm1, src_layer.norm1)
        copy_layer_norm(dst_layer.norm2, src_layer.norm2)

def copy_decoder_weights(custom_decoder, torch_decoder):
    """Copy per-layer decoder parameters from a PyTorch decoder into the custom one.

    For every positionally paired layer this loads the masked self-attention,
    the cross-attention, the feed-forward block and the three layer norms.
    Attention biases are copied whenever both the source and the destination
    projection have one. Layers are paired with ``zip``, so surplus layers on
    either side are silently skipped.

    Args:
        custom_decoder: Object exposing a ``decoder_layers`` iterable whose items
            carry ``masked_multi_headed_self_attention``,
            ``multi_headed_cross_attention``, ``ffn`` and ``norm1`` .. ``norm3``.
        torch_decoder: Object exposing a ``layers`` iterable whose items carry
            ``self_attn``, ``multihead_attn``, ``linear1``, ``linear2`` and
            ``norm1`` .. ``norm3``.

    Returns:
        None. ``custom_decoder`` is updated in place.
    """

    def _load_attention(custom_attn, torch_mha):
        # One routine serves both attention blocks: split the packed
        # in-projection into Q, K, V row blocks, then copy the out-projection.
        qkv_weight = torch_mha.in_proj_weight
        qkv_bias = torch_mha.in_proj_bias
        d_model = qkv_weight.shape[1]
        for block, target in enumerate((custom_attn.w_q, custom_attn.w_k, custom_attn.w_v)):
            start = block * d_model
            stop = start + d_model
            target.weight.data.copy_(qkv_weight[start:stop])
            # A bias is copied only when both the source and the target carry one.
            if qkv_bias is not None and getattr(target, "bias", None) is not None:
                target.bias.data.copy_(qkv_bias[start:stop])
        out_proj = torch_mha.out_proj
        custom_attn.w_o.weight.data.copy_(out_proj.weight)
        if getattr(out_proj, "bias", None) is not None and getattr(custom_attn.w_o, "bias", None) is not None:
            custom_attn.w_o.bias.data.copy_(out_proj.bias)

    for c_layer, t_layer in zip(custom_decoder.decoder_layers, torch_decoder.layers):
        # Masked self-attention
        _load_attention(c_layer.masked_multi_headed_self_attention, t_layer.self_attn)

        # Cross-attention
        _load_attention(c_layer.multi_headed_cross_attention, t_layer.multihead_attn)

        # FFN and LayerNorm
        copy_ffn_weights(c_layer.ffn, t_layer)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    """Copy encoder weights, then decoder weights, from the PyTorch model into the custom one.

    Args:
        custom_transformer: Destination object with ``encoder`` and ``decoder``.
        torch_transformer: Source object with ``encoder`` and ``decoder``.

    Returns:
        None. ``custom_transformer`` is updated in place.
    """
    # Encoder stack first ...
    dst_encoder = custom_transformer.encoder
    src_encoder = torch_transformer.encoder
    copy_encoder_weights(dst_encoder, src_encoder)

    # ... then the decoder stack.
    dst_decoder = custom_transformer.decoder
    src_decoder = torch_transformer.decoder
    copy_decoder_weights(dst_decoder, src_decoder)


In [2]:
from configs.transformer_tiny import Config
from tokenizer import TokenizerWrapper
from data import WMTENDE, CustomBatchSampler, pad_collate_fn
import torch
from torch.utils.data import DataLoader
from model import build_transformer

# Seed the RNG first so that the randomly initialised weights and the random
# inputs below are identical on every Restart & Run All.
torch.manual_seed(0)

config = Config()
# The tokenizer / WMT dataset / DataLoader path is intentionally not used here:
# the parity check below runs on small synthetic tensors instead of real batches.

custom_model = build_transformer(config).transformer
# Reference implementation.
# NOTE(review): d_model=4 and nhead=2 are hard-coded and every other size
# (layer count, feed-forward width) is left at the torch default. They must
# match what `config` builds, because the copy helpers pair layers with `zip`
# and copy tensors shape-for-shape. Confirm against configs.transformer_tiny.
pytorch_model = torch.nn.Transformer(d_model=4, nhead=2, dropout=0.0, batch_first=True)

# Synthetic inputs laid out as (batch=1, sequence=5, feature=4), batch first.
src = torch.rand(1, 5, 4)
tgt = torch.rand(1, 5, 4)
# Boolean masks of shape (1, 5).
# NOTE(review): torch treats True in a key_padding_mask as "ignore this
# position"; confirm that the custom model reads the masks the same way.
src_mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool).unsqueeze(0)
tgt_mask = torch.tensor([1, 1, 1, 1, 0], dtype=torch.bool).unsqueeze(0)

# Overwrite the custom model's parameters with the reference model's.
copy_transformer_weights(custom_model, pytorch_model)

In [6]:
# Reference forward pass through torch.nn.Transformer. The bare call is the
# last expression of the cell, so its result is what the notebook displays.
pytorch_model(
    src,
    tgt,
    src_key_padding_mask=src_mask,
    tgt_key_padding_mask=tgt_mask,
)

tensor([[[ 1.4844,  0.3405, -0.8253, -0.9997],
         [ 1.5232,  0.2552, -0.7600, -1.0184],
         [ 1.4834,  0.3430, -0.8295, -0.9969],
         [ 1.4837,  0.3401, -0.8123, -1.0115],
         [ 1.4748,  0.3561, -0.8098, -1.0210]]],
       grad_fn=<NativeLayerNormBackward0>)

In [7]:
# Same inputs and masks through the custom model, for a side-by-side comparison
# with the reference output.
# NOTE(review): the meaning of the fifth positional argument (None) depends on
# the custom model's forward signature, which is not visible in this notebook.
custom_model(
    src,
    tgt,
    src_mask,
    tgt_mask,
    None,
)

tensor([[[ 1.4844,  0.3405, -0.8253, -0.9997],
         [ 1.5232,  0.2552, -0.7600, -1.0184],
         [ 1.4834,  0.3430, -0.8295, -0.9969],
         [ 1.4837,  0.3402, -0.8123, -1.0116],
         [ 1.4748,  0.3561, -0.8098, -1.0210]]],
       grad_fn=<NativeLayerNormBackward0>)