In [1]:
def copy_attention_weights(custom_attn, torch_attn):
    # QKV weights: PyTorch uses one big linear layer for qkv
    qkv_weight = torch_attn.self_attn.in_proj_weight
    qkv_bias = torch_attn.self_attn.in_proj_bias
    d_model = qkv_weight.shape[1]
    d_k = d_model // custom_attn.h

    # Split into Q, K, V
    custom_attn.w_q.weight.data.copy_(qkv_weight[:d_model])
    custom_attn.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
    custom_attn.w_v.weight.data.copy_(qkv_weight[2*d_model:])
    
    # Output projection
    custom_attn.w_o.weight.data.copy_(torch_attn.self_attn.out_proj.weight)

def copy_ffn_weights(custom_ffn, torch_ffn):
    custom_ffn.linear_1.weight.data.copy_(torch_ffn.linear1.weight)
    custom_ffn.linear_1.bias.data.copy_(torch_ffn.linear1.bias)
    custom_ffn.linear_2.weight.data.copy_(torch_ffn.linear2.weight)
    custom_ffn.linear_2.bias.data.copy_(torch_ffn.linear2.bias)

def copy_layer_norm(custom_ln, torch_ln):
    custom_ln.weight.data.copy_(torch_ln.weight)
    custom_ln.bias.data.copy_(torch_ln.bias)

def copy_encoder_weights(custom_encoder, torch_encoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_encoder.encoder_layers, torch_encoder.layers)):
        copy_attention_weights(c_layer.multi_headed_self_attention, t_layer)
        copy_ffn_weights(c_layer.ffn, t_layer)  # Pass layer1 for dimension reference
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)

def copy_decoder_weights(custom_decoder, torch_decoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_decoder.decoder_layers, torch_decoder.layers)):
        # Masked self-attention
        qkv_weight = t_layer.self_attn.in_proj_weight
        d_model = qkv_weight.shape[1]
        c_layer.masked_multi_headed_self_attention.w_q.weight.data.copy_(qkv_weight[:d_model])
        c_layer.masked_multi_headed_self_attention.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
        c_layer.masked_multi_headed_self_attention.w_v.weight.data.copy_(qkv_weight[2*d_model:])
        c_layer.masked_multi_headed_self_attention.w_o.weight.data.copy_(t_layer.self_attn.out_proj.weight)

        # Cross-attention
        cross_qkv_weight = t_layer.multihead_attn.in_proj_weight
        c_layer.multi_headed_cross_attention.w_q.weight.data.copy_(cross_qkv_weight[:d_model])
        c_layer.multi_headed_cross_attention.w_k.weight.data.copy_(cross_qkv_weight[d_model:2*d_model])
        c_layer.multi_headed_cross_attention.w_v.weight.data.copy_(cross_qkv_weight[2*d_model:])
        c_layer.multi_headed_cross_attention.w_o.weight.data.copy_(t_layer.multihead_attn.out_proj.weight)

        # FFN and LayerNorm
        copy_ffn_weights(c_layer.ffn, t_layer.linear1)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    copy_encoder_weights(custom_transformer.encoder, torch_transformer.encoder)
    copy_decoder_weights(custom_transformer.decoder, torch_transformer.decoder)
def copy_attention_weights(custom_attn, torch_attn):
    # QKV weights: PyTorch uses one big linear layer for qkv
    qkv_weight = torch_attn.self_attn.in_proj_weight
    qkv_bias = torch_attn.self_attn.in_proj_bias
    d_model = qkv_weight.shape[1]
    d_k = d_model // custom_attn.h

    # Split into Q, K, V
    custom_attn.w_q.weight.data.copy_(qkv_weight[:d_model])
    custom_attn.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
    custom_attn.w_v.weight.data.copy_(qkv_weight[2*d_model:])
    
    # Output projection
    custom_attn.w_o.weight.data.copy_(torch_attn.self_attn.out_proj.weight)

def copy_ffn_weights(custom_ffn, torch_ffn):
    custom_ffn.linear_1.weight.data.copy_(torch_ffn.linear1.weight)
    custom_ffn.linear_1.bias.data.copy_(torch_ffn.linear1.bias)
    custom_ffn.linear_2.weight.data.copy_(torch_ffn.linear2.weight)
    custom_ffn.linear_2.bias.data.copy_(torch_ffn.linear2.bias)

def copy_layer_norm(custom_ln, torch_ln):
    custom_ln.weight.data.copy_(torch_ln.weight)
    custom_ln.bias.data.copy_(torch_ln.bias)

def copy_encoder_weights(custom_encoder, torch_encoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_encoder.encoder_layers, torch_encoder.layers)):
        copy_attention_weights(c_layer.multi_headed_self_attention, t_layer)
        copy_ffn_weights(c_layer.ffn, t_layer)  # Pass layer1 for dimension reference
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)

def copy_decoder_weights(custom_decoder, torch_decoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_decoder.decoder_layers, torch_decoder.layers)):
        # Masked self-attention
        qkv_weight = t_layer.self_attn.in_proj_weight
        d_model = qkv_weight.shape[1]
        c_layer.masked_multi_headed_self_attention.w_q.weight.data.copy_(qkv_weight[:d_model])
        c_layer.masked_multi_headed_self_attention.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
        c_layer.masked_multi_headed_self_attention.w_v.weight.data.copy_(qkv_weight[2*d_model:])
        c_layer.masked_multi_headed_self_attention.w_o.weight.data.copy_(t_layer.self_attn.out_proj.weight)

        # Cross-attention
        cross_qkv_weight = t_layer.multihead_attn.in_proj_weight
        c_layer.multi_headed_cross_attention.w_q.weight.data.copy_(cross_qkv_weight[:d_model])
        c_layer.multi_headed_cross_attention.w_k.weight.data.copy_(cross_qkv_weight[d_model:2*d_model])
        c_layer.multi_headed_cross_attention.w_v.weight.data.copy_(cross_qkv_weight[2*d_model:])
        c_layer.multi_headed_cross_attention.w_o.weight.data.copy_(t_layer.multihead_attn.out_proj.weight)

        # FFN and LayerNorm
        copy_ffn_weights(c_layer.ffn, t_layer)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    copy_encoder_weights(custom_transformer.encoder, torch_transformer.encoder)
    copy_decoder_weights(custom_transformer.decoder, torch_transformer.decoder)


In [None]:
!cd /content/Transformer-from-scratch-in-pytorch
import os
import sys
sys.path.append('/content/Transformer-from-scratch-in-pytorch')
os.environ['PYTHONPATH'] = "/content/Transformer-from-scratch-in-pytorch"
from configs.transformer_tiny import Config
from tokenizer import TokenizerWrapper
from data import WMTENDE, CustomBatchSampler, pad_collate_fn
import torch
from torch.utils.data import DataLoader
from model import build_transformer

def copy_attention_weights(custom_attn, torch_attn):
    # QKV weights: PyTorch uses one big linear layer for qkv
    qkv_weight = torch_attn.self_attn.in_proj_weight
    qkv_bias = torch_attn.self_attn.in_proj_bias
    d_model = qkv_weight.shape[1]
    d_k = d_model // custom_attn.h

    # Split into Q, K, V
    custom_attn.w_q.weight.data.copy_(qkv_weight[:d_model])
    custom_attn.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
    custom_attn.w_v.weight.data.copy_(qkv_weight[2*d_model:])

    # Output projection
    custom_attn.w_o.weight.data.copy_(torch_attn.self_attn.out_proj.weight)

def copy_ffn_weights(custom_ffn, torch_ffn):
    custom_ffn.linear_1.weight.data.copy_(torch_ffn.linear1.weight)
    custom_ffn.linear_1.bias.data.copy_(torch_ffn.linear1.bias)
    custom_ffn.linear_2.weight.data.copy_(torch_ffn.linear2.weight)
    custom_ffn.linear_2.bias.data.copy_(torch_ffn.linear2.bias)

def copy_layer_norm(custom_ln, torch_ln):
    custom_ln.weight.data.copy_(torch_ln.weight)
    custom_ln.bias.data.copy_(torch_ln.bias)

def copy_encoder_weights(custom_encoder, torch_encoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_encoder.encoder_layers, torch_encoder.layers)):
        copy_attention_weights(c_layer.multi_headed_self_attention, t_layer)
        copy_ffn_weights(c_layer.ffn, t_layer)  # Pass layer1 for dimension reference
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)

def copy_decoder_weights(custom_decoder, torch_decoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_decoder.decoder_layers, torch_decoder.layers)):
        # Masked self-attention
        qkv_weight = t_layer.self_attn.in_proj_weight
        d_model = qkv_weight.shape[1]
        c_layer.masked_multi_headed_self_attention.w_q.weight.data.copy_(qkv_weight[:d_model])
        c_layer.masked_multi_headed_self_attention.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
        c_layer.masked_multi_headed_self_attention.w_v.weight.data.copy_(qkv_weight[2*d_model:])
        c_layer.masked_multi_headed_self_attention.w_o.weight.data.copy_(t_layer.self_attn.out_proj.weight)

        # Cross-attention
        cross_qkv_weight = t_layer.multihead_attn.in_proj_weight
        c_layer.multi_headed_cross_attention.w_q.weight.data.copy_(cross_qkv_weight[:d_model])
        c_layer.multi_headed_cross_attention.w_k.weight.data.copy_(cross_qkv_weight[d_model:2*d_model])
        c_layer.multi_headed_cross_attention.w_v.weight.data.copy_(cross_qkv_weight[2*d_model:])
        c_layer.multi_headed_cross_attention.w_o.weight.data.copy_(t_layer.multihead_attn.out_proj.weight)

        # FFN and LayerNorm
        copy_ffn_weights(c_layer.ffn, t_layer.linear1)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    copy_encoder_weights(custom_transformer.encoder, torch_transformer.encoder)
    copy_decoder_weights(custom_transformer.decoder, torch_transformer.decoder)
def copy_attention_weights(custom_attn, torch_attn):
    # QKV weights: PyTorch uses one big linear layer for qkv
    qkv_weight = torch_attn.self_attn.in_proj_weight
    qkv_bias = torch_attn.self_attn.in_proj_bias
    d_model = qkv_weight.shape[1]
    d_k = d_model // custom_attn.h

    # Split into Q, K, V
    custom_attn.w_q.weight.data.copy_(qkv_weight[:d_model])
    custom_attn.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
    custom_attn.w_v.weight.data.copy_(qkv_weight[2*d_model:])

    # Output projection
    custom_attn.w_o.weight.data.copy_(torch_attn.self_attn.out_proj.weight)

def copy_ffn_weights(custom_ffn, torch_ffn):
    custom_ffn.linear_1.weight.data.copy_(torch_ffn.linear1.weight)
    custom_ffn.linear_1.bias.data.copy_(torch_ffn.linear1.bias)
    custom_ffn.linear_2.weight.data.copy_(torch_ffn.linear2.weight)
    custom_ffn.linear_2.bias.data.copy_(torch_ffn.linear2.bias)

def copy_layer_norm(custom_ln, torch_ln):
    custom_ln.weight.data.copy_(torch_ln.weight)
    custom_ln.bias.data.copy_(torch_ln.bias)

def copy_encoder_weights(custom_encoder, torch_encoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_encoder.encoder_layers, torch_encoder.layers)):
        copy_attention_weights(c_layer.multi_headed_self_attention, t_layer)
        copy_ffn_weights(c_layer.ffn, t_layer)  # Pass layer1 for dimension reference
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)

def copy_decoder_weights(custom_decoder, torch_decoder):
    for i, (c_layer, t_layer) in enumerate(zip(custom_decoder.decoder_layers, torch_decoder.layers)):
        # Masked self-attention
        qkv_weight = t_layer.self_attn.in_proj_weight
        d_model = qkv_weight.shape[1]
        c_layer.masked_multi_headed_self_attention.w_q.weight.data.copy_(qkv_weight[:d_model])
        c_layer.masked_multi_headed_self_attention.w_k.weight.data.copy_(qkv_weight[d_model:2*d_model])
        c_layer.masked_multi_headed_self_attention.w_v.weight.data.copy_(qkv_weight[2*d_model:])
        c_layer.masked_multi_headed_self_attention.w_o.weight.data.copy_(t_layer.self_attn.out_proj.weight)

        # Cross-attention
        cross_qkv_weight = t_layer.multihead_attn.in_proj_weight
        c_layer.multi_headed_cross_attention.w_q.weight.data.copy_(cross_qkv_weight[:d_model])
        c_layer.multi_headed_cross_attention.w_k.weight.data.copy_(cross_qkv_weight[d_model:2*d_model])
        c_layer.multi_headed_cross_attention.w_v.weight.data.copy_(cross_qkv_weight[2*d_model:])
        c_layer.multi_headed_cross_attention.w_o.weight.data.copy_(t_layer.multihead_attn.out_proj.weight)

        # FFN and LayerNorm
        copy_ffn_weights(c_layer.ffn, t_layer)
        copy_layer_norm(c_layer.norm1, t_layer.norm1)
        copy_layer_norm(c_layer.norm2, t_layer.norm2)
        copy_layer_norm(c_layer.norm3, t_layer.norm3)

def copy_transformer_weights(custom_transformer, torch_transformer):
    copy_encoder_weights(custom_transformer.encoder, torch_transformer.encoder)
    copy_decoder_weights(custom_transformer.decoder, torch_transformer.decoder)

config = Config()
tokenizer = TokenizerWrapper(config)
data = WMTENDE(config, tokenizer, 'test')
data.data.shuffle(seed=20)
# data.data = data.data.select(range(2))
sampler = CustomBatchSampler(len(data), batch_size=2)
dl = DataLoader(data, batch_sampler=sampler, collate_fn=pad_collate_fn)

custom_transformer = build_transformer(config)
custom_model = custom_transformer.transformer
pytorch_model = torch.nn.Transformer(d_model=256, nhead=8, dropout=0.0, batch_first=True)
for batch in dl:
    src, tgt = batch

src = torch.cat((src, torch.zeros(2,2, dtype=torch.long)), dim=1)
tgt = torch.cat((tgt, torch.zeros(2,3, dtype=torch.long)), dim=1)
src_mask = (src == 0)
tgt_mask = (tgt == 0)
causal_mask = pytorch_model.generate_square_subsequent_mask(tgt.size(1), dtype=bool)

src = custom_transformer.src_word_embedding(src)
tgt = custom_transformer.src_word_embedding(tgt)

# src = torch.rand(1, 5, 4)
# tgt = torch.rand(1, 5, 4)
# src_mask = ~torch.tensor([1,1,1,0,0], dtype=bool).unsqueeze(0)
# tgt_mask = ~torch.tensor([1,1,1,1,0], dtype=bool).unsqueeze(0)
copy_transformer_weights(custom_model, pytorch_model)

pytorch_out = pytorch_model(src, tgt, src_key_padding_mask=src_mask, tgt_key_padding_mask=tgt_mask, memory_key_padding_mask=src_mask, tgt_mask=causal_mask, tgt_is_causal=True)
custom_out = custom_model(src, tgt, src_mask.unsqueeze(1).unsqueeze(2), tgt_mask.unsqueeze(1).unsqueeze(2) | causal_mask, src_mask.unsqueeze(1).unsqueeze(2))

The system cannot find the path specified.
Reusing dataset wmt14 (C:\Users\neetm\.cache\huggingface\datasets\wmt14\de-en\1.0.0\3d7d25048da28a2f2a8dda5ca306bdc4affcf02bbcb3d277cfe2d5d7b1d71ebc)


  0%|          | 0/4 [00:00<?, ?ba/s]

In [3]:
pytorch_out

tensor([[[ 0.3518,  0.5350, -0.1408,  ...,  0.8623, -0.5968, -0.7835],
         [ 0.2023,  0.5008, -0.2484,  ...,  1.0054, -0.6231, -0.9629],
         [ 0.0614,  0.6979,  0.1961,  ...,  1.1849, -0.6526, -1.2559],
         ...,
         [ 0.1636,  0.8781,  0.0622,  ...,  1.4749, -1.0057, -1.1822],
         [ 0.3832,  1.5135,  0.1186,  ...,  0.9393, -0.9194, -1.6305],
         [-0.0811,  1.3183,  0.2863,  ...,  1.0324, -1.3263, -1.6134]],

        [[ 0.5642,  0.8953, -0.5905,  ...,  0.3664, -0.8350, -1.1837],
         [ 0.3323,  0.7225, -0.3582,  ...,  0.4065, -0.7129, -1.5853],
         [ 0.4308,  0.7377, -0.5279,  ...,  0.6157, -0.6066, -1.5947],
         ...,
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685],
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685],
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685]]],
       grad_fn=<NativeLayerNormBackward0>)

In [4]:
custom_out

tensor([[[ 0.3518,  0.5350, -0.1408,  ...,  0.8623, -0.5968, -0.7835],
         [ 0.2023,  0.5008, -0.2484,  ...,  1.0054, -0.6231, -0.9629],
         [ 0.0614,  0.6979,  0.1961,  ...,  1.1849, -0.6526, -1.2559],
         ...,
         [ 0.1636,  0.8781,  0.0622,  ...,  1.4749, -1.0057, -1.1822],
         [ 0.3832,  1.5135,  0.1186,  ...,  0.9393, -0.9194, -1.6305],
         [-0.0811,  1.3183,  0.2863,  ...,  1.0324, -1.3263, -1.6134]],

        [[ 0.5642,  0.8953, -0.5905,  ...,  0.3664, -0.8350, -1.1837],
         [ 0.3323,  0.7225, -0.3582,  ...,  0.4065, -0.7129, -1.5853],
         [ 0.4308,  0.7377, -0.5279,  ...,  0.6157, -0.6066, -1.5947],
         ...,
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685],
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685],
         [ 0.3337,  1.0413, -0.1129,  ...,  0.7482, -0.8384, -1.6685]]],
       grad_fn=<NativeLayerNormBackward0>)