In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
torch.manual_seed(42)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)#d_model is the the dimension of the embedding,maxlen-is the maximum length of sequence
        position = torch.arange(0, max_len).unsqueeze(1).float() #shape('max_len,1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model)) #calculates the positional encodings using the formula
        self.encoding[:, 0::2] = torch.sin(position * div_term)#even position
        self.encoding[:, 1::2] = torch.cos(position * div_term)#odd position
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModelLogger:
  def __init__(self, is_logging=False):
    self.is_logging = is_logging
    self.logs = []

  def log(self, message):
    if self.is_logging:
      self.logs.append(message)
      print(message)

  def print_logs(self):
     for message in self.logs:

      print(message)

      self.logs = []

class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.logger = TransformerModelLogger(is_logging=True)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)

        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()#triangular matrix used to mask future positiions of the target sequence

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        self.logger.log("Tgt mask shape = " + str(tgt_mask.shape))

        src = self.src_embedding(src) + self.positional_encoding(src)
        self.logger.log(f"src (after embedding and positional encoding): {src.shape}")
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)
        self.logger.log(f"tgt (after embedding and positional encoding): {tgt.shape}")


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        self.logger.log(f"output (after transformer): {output.shape}")
        output = self.fc(output)
        print("output:{output}")

        return output



In [93]:
torch.manual_seed(42)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 2
num_decoder_layers = 2
d_ff = 20
max_seq_len = 5
dropout = 0

transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)



src_data = torch.tensor([[2], [1], [5], [4]])
tgt_data = torch.tensor([[1], [16], [5], [3]])


state_dict = transformer.state_dict()



In [3]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [4]:
transformer.train()
output = transformer(src_data, tgt_data)







Tgt mask shape = torch.Size([4, 4])
src (after embedding and positional encoding): torch.Size([4, 1, 16])
tgt (after embedding and positional encoding): torch.Size([4, 1, 16])
inital input for else  tensor([[[-1.3847,  0.1288, -0.2234,  2.7174,  0.3189,  0.5755,  0.3057,
           0.2254, -1.5576,  1.9956, -0.8798,  0.3989, -1.2742,  3.1228,
          -1.2347,  0.5121]],

        [[ 1.6423,  0.8404, -0.4974,  1.4396, -0.7581,  2.0783,  0.8008,
           2.6806,  1.2791,  2.2964,  0.6105,  2.3347, -0.2316,  1.0418,
          -0.2516,  1.8599]],

        [[ 0.0109,  0.6613, -1.3407,  0.4146,  0.5362,  1.5246,  1.1412,
           1.0516,  0.7440,  0.5184, -1.0495,  1.6039, -1.7223,  0.1722,
           1.3347,  1.4835]],

        [[ 1.4451,  1.8564,  2.2181,  1.5232,  0.3466,  0.8027, -1.0546,
           2.2780, -0.1722,  1.5238,  0.0566,  1.4263,  0.5750,  0.3583,
          -2.2064,  0.2492]]], grad_fn=<AddBackward0>)
 in_projection_packed 
shapes  =  torch.Size([4, 1, 16]) torch.Size([

In [5]:
print(output)

tensor([[[ 0.5040, -0.5643, -0.9291,  0.8937,  0.5021,  0.6562, -1.3417,
           0.5391,  0.0286,  0.4028,  0.6045,  0.0248,  0.4549, -0.3143,
          -1.1332,  0.7255,  0.1034, -0.2828, -0.1362,  0.0014]],

        [[ 0.2908, -0.4276, -1.1432,  0.8511,  0.3242,  0.2948, -1.3010,
           0.6434,  0.5933, -0.0025,  0.3320, -0.1345,  0.6465, -0.0753,
          -1.1380,  0.1948, -0.0462, -0.3189, -0.3971,  0.0195]],

        [[ 0.2088, -0.7398, -1.1280,  0.7504,  0.3378,  0.3129, -1.4142,
           0.7165,  0.6326, -0.1854,  0.2747, -0.0866,  0.7972, -0.0041,
          -1.0578,  0.4315,  0.0825, -0.2325, -0.3063,  0.0676]],

        [[ 0.4189, -0.6181, -0.7970,  1.2109,  0.7027,  0.0376, -0.7485,
           0.8167,  0.2986, -0.2507,  0.6037, -0.0616,  1.0253,  0.1367,
          -1.2094,  0.7994,  0.3362, -0.8986,  0.2187, -0.1611]]],
       grad_fn=<ViewBackward0>)


In [73]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
torch.manual_seed(42)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)#d_model is the the dimension of the embedding,maxlen-is the maximum length of sequence
        position = torch.arange(0, max_len).unsqueeze(1).float() #shape('max_len,1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model)) #calculates the positional encodings using the formula
        self.encoding[:, 0::2] = torch.sin(position * div_term)#even position
        self.encoding[:, 1::2] = torch.cos(position * div_term)#odd position
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModelLogger:
  def __init__(self, is_logging=False):
    self.is_logging = is_logging
    self.logs = []

  def log(self, message):
    if self.is_logging:
      self.logs.append(message)
      print(message)

  def print_logs(self):
     for message in self.logs:

      print(message)

      self.logs = []

class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.logger = TransformerModelLogger(is_logging=True)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)

        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()#triangular matrix used to mask future positiions of the target sequence

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        self.logger.log("Tgt mask shape = " + str(tgt_mask.shape))

        src = self.src_embedding(src) + self.positional_encoding(src)
        self.logger.log(f"src (after embedding and positional encoding): {src.shape}")
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)
        self.logger.log(f"tgt (after embedding and positional encoding): {tgt.shape}")


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        self.logger.log(f"output (after transformer): {output.shape}")
        output = self.fc(output)
        print("output:{output}")

        return output



In [94]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)#d_model is the the dimension of the embedding,maxlen-is the maximum length of sequence
        position = torch.arange(0, max_len).unsqueeze(1).float() #shape('max_len,1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model)) #calculates the positional encodings using the formula
        self.encoding[:, 0::2] = torch.sin(position * div_term)#even position
        self.encoding[:, 1::2] = torch.cos(position * div_term)#odd position
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()


In [95]:
def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):

            # Get the index for the current word token index in the sequence
            word_index = sentence[i, j].item()

            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()

    return embedding

In [96]:
def get_embedding_outputs(src_sentence,  tgt_sentence, max_seq_len, state_dict, d_model):

    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_sentence, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)

    tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

    tgt_embedding = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)


    print("Target sentence embedding")
    tgt_embedding =  look_up_table(tgt_sentence, tgt_vocab_embeds, tgt_embedding)

    pe = PositionalEncoding(d_model = d_model, dropout=0, max_len=max_seq_len)

    print("PE of src :")
    print(pe(src_sentence))
    print()
    print("PE of tgt :")
    print(pe(tgt_sentence))
    print()

    pe_src_embeds = src_embedding + pe(src_sentence)

    pe_tgt_embeds = tgt_embedding + pe(tgt_sentence)

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds

In [76]:
def get_embedding_outputs(src_sentence,  tgt_sentence, max_seq_len, state_dict, d_model):

    src_vocab_embeds = state_dict["src_embedding.weight"]

    src_embedding = torch.zeros(src_sentence.size(0), src_sentence.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_sentence, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)

    tgt_vocab_embeds = state_dict["tgt_embedding.weight"]

    tgt_embedding = torch.zeros(tgt_sentence.size(0), tgt_sentence.size(1), d_model)


    print("Target sentence embedding")
    tgt_embedding =  look_up_table(tgt_sentence, tgt_vocab_embeds, tgt_embedding)

    pe = PositionalEncoding(d_model = d_model, dropout=0, max_len=max_seq_len)

    print("PE of src :")
    print(pe(src_sentence))
    print()
    print("PE of tgt :")
    print(pe(tgt_sentence))
    print()

    pe_src_embeds = src_embedding + pe(src_sentence)

    pe_tgt_embeds = tgt_embedding + pe(tgt_sentence)

    print("PE source embeddings : \n")
    print(pe_src_embeds)
    print()

    print("PE target embeddings : \n")
    print(pe_tgt_embeds)
    print()

    return pe_src_embeds, pe_tgt_embeds

In [97]:
src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 2
num_decoder_layers = 2
d_ff = 20
max_seq_len = 5
dropout = 0

torch.manual_seed(42)
model = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)

In [99]:
transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)
state_dict = model.state_dict()

In [100]:
import numpy as np
src_data = np.array([[2], [1], [5], [4]])
tgt_data = np.array([[1], [16], [5], [3]])

src_tensor = torch.tensor(src_data, dtype=torch.long).T
tgt_tensor = torch.tensor(tgt_data, dtype=torch.long).T

In [56]:
d_model = 16  # Example dimension of the embeddings, should match d_model used in your Transformer modelReplace with your actual state_dict obtained earlier

src_data = np.array([[2], [1], [5], [4]])
tgt_data = np.array([[1], [16], [5], [3]])
# Retrieve src_embedding weights from state_dict
src_vocab_embeds = state_dict["src_embedding.weight"]

# Initialize src_embeddings array
src_embeddings = np.zeros((src_data.shape[0], d_model))

# Iterate over src_data to fill src_embeddings
for i in range(src_data.shape[0]):
    word_index = src_data[i].item()  # Extract the word index as an integer
    if word_index < 0 or word_index >= src_vocab_embeds.shape[0]:
        print(f"Invalid word index: {word_index}")
    else:
        src_embeddings[i, :] = src_vocab_embeds[word_index, :].numpy()

    # Print the word index and corresponding embedding (for verification)
    print(f"Word index: {word_index}, Embedding: {src_embeddings[i, :]}")

# Print the shape of src_embeddings
print(f"Shape of src_embeddings: {src_embeddings.shape}")


Word index: 2, Embedding: [-1.38467371 -0.87123615 -0.22336592  1.71736145  0.31888032 -0.42451897
  0.30572093 -0.77459252 -1.55757248  0.99563611 -0.87978584 -0.60114205
 -1.27415121  2.12278509 -1.23465312 -0.48791388]
Word index: 1, Embedding: [ 1.64231694 -0.15959747 -0.49739754  0.43958926 -0.75813115  1.07831764
  0.80080056  1.68062055  1.27912438  1.29642284  0.61046648  1.33473778
 -0.23162432  0.04175949 -0.25157529  0.85985851]
Word index: 5, Embedding: [ 0.01086814 -0.33874235 -1.34067953 -0.58537054  0.53618813  0.52462262
  1.14120162  0.0516436   0.74395198 -0.4815844  -1.04946613  0.60389882
 -1.72229505 -0.82776886  1.33470297  0.48353928]
Word index: 4, Embedding: [ 1.44513381  0.85641253  2.21807575  0.52316552  0.34664667 -0.19733144
 -1.05458891  1.27799559 -0.17219013  0.52378845  0.05662182  0.42629614
  0.57500505 -0.64172411 -2.20639849 -0.75080305]
Shape of src_embeddings: (4, 16)


In [9]:
import numpy as np
def positional_encoding(max_len, d_model):

    pos_encodings = np.zeros((max_len, d_model))
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            pos_encodings[pos, i] = np.sin(pos / (10000 ** ((2 * i)/d_model)))
            pos_encodings[pos, i + 1] = np.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
    return pos_encodings

In [10]:
pos_encodings=positional_encoding(4,16)
final_src=pos_encodings + src_embeddings

In [11]:
enc=final_src

In [80]:
def calculate_qkv(query, key, value ,W, b):


    E = query.size(-1)

    if key is value:
        if query is key:
            tempop1 = query@W.T
            tempop1 = tempop1.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return tempop1[0], tempop1[1], tempop1[2]
        else:
            W_q, W_kv = W.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            q_matmul = query@W_q.T
            kv_matmul = key@W_kv.T
            print(kv_matmul)

            kv_matmul = kv_matmul.unflatten(-1, (2, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
            return q_matmul, kv_matmul[0], kv_matmul[1]
    else:

        W_q, W_k, W_v = W.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        q_matmul = query@W_q.T
        k_matmul = key@W_k.T
        v_matmul = value@W_v.T
        return q_matmul, k_matmul, v_matmul

In [81]:
def attention_calculate(Q, V, K, bsz, head_dim, src_len, tgt_len, embed_dim, attn_mask, num_heads):
    Q1 = Q.view(bsz, num_heads, tgt_len, head_dim)
    K1 = K.view(bsz, num_heads, src_len, head_dim)
    V1 = V.view(bsz, num_heads, src_len, head_dim)

    L, S = Q1.size(-2), K1.size(-2)

    scale_factor = 1 / math.sqrt(Q1.size(-1))
    attn_bias = torch.zeros(L, S, dtype=Q1.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            masked_tensor = attn_mask.float().masked_fill(attn_mask, float('-inf'))
            masked_tensor = masked_tensor.masked_fill(~attn_mask, 0)
            attn_mask = masked_tensor
            attn_bias = attn_bias.unsqueeze(0).unsqueeze(0)
            attn_bias += attn_mask

        else:
            attn_bias += attn_mask
            attn_bias = attn_bias.unsqueeze(0).unsqueeze(0)
    attn_weight = Q1 @ K1.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)

    print("Attention_weights after softmax= ", attn_weight)
    print()
    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    print(attn_weight)
    attn_output = attn_weight @ V1

    print("final_attention_value after multiplication with Venc= ", attn_output)
    attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)

    print("Attention output = ")
    print(attn_weight.shape, attn_weight)

    return attn_output

def attention_calculate_if_needweightstrue(Q, K, V, bsz, tgt_len, embed_dim, attn_mask):
    B, Nt, E = Q.shape
    Q_scaled = Q / math.sqrt(E)

    if attn_mask is not None:
        temp_pdt_matrix = torch.baddbmm(attn_mask, Q_scaled, K.transpose(-2, -1))
    else:
        temp_pdt_matrix = torch.bmm(Q_scaled, K.transpose(-2, -1))

    attn_wt_matrix = torch.nn.functional.softmax(temp_pdt_matrix, dim=-1)
    attn_output = torch.bmm(attn_wt_matrix, V)
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)


    sum_last_dim = attn_wt_matrix.sum(dim=-1)
    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    print("Encoder Attention output = ")
    print(attn_output)
    print()

    return attn_output, attn_wt_matrix

In [82]:
def encoder_block_attention_output(x, state_dict, layer_num, embed_dim, num_heads, need_weights = False, src_mask = None):
    query_enc = key_enc = value_enc = x
    tgt_len, bsz, embed_dim = x.shape

    W_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_enc = state_dict["transformer.encoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]

    head_dim = embed_dim//num_heads
    Q_enc,K_enc,V_enc = calculate_qkv(query_enc, key_enc, value_enc ,W_enc, b_enc)

    Q_enc = Q_enc.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_enc = K_enc.reshape(K_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_enc = V_enc.reshape(V_enc.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    print("Query_{} = ".format(layer_num))
    print(Q_enc)
    print()

    print("Key_{} = ".format(layer_num))
    print(K_enc)
    print()

    print("Value_enc_{} = ".format(layer_num))
    print(V_enc)
    print()


    src_len = K_enc.size(1)

    attn_mask = src_mask
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)


    if need_weights is False:

        attn_output = attention_calculate(Q_enc, V_enc, K_enc, bsz, head_dim, src_len, tgt_len, embed_dim, attn_mask, num_heads=num_heads)

        return attn_output, src_len, head_dim, None

    else:

        attn_enc_output,attn_wt_matrix_enc = attention_calculate_if_needweightstrue(Q_enc, K_enc, V_enc, bsz, tgt_len, embed_dim,  attn_mask)

        return attn_enc_output, src_len, head_dim, attn_wt_matrix_enc


In [83]:
def encoder_block_post_attn_output(x, attn_enc_output, state_dict, layer_num, bsz, tgt_len):
    op_enc_1 = torch.matmul(attn_enc_output, state_dict["transformer.encoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
    attn_enc_output = op_enc_1.view(tgt_len, bsz, attn_enc_output.size(1))
    output_enc_1 = attn_enc_output + x
    linear_result_enc_1 = output_enc_1*state_dict["transformer.encoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm1.bias".format(layer_num)]

    layernorm_enc_1 = torch.nn.LayerNorm(normalized_shape=linear_result_enc_1.shape[2:])
    linear_op_enc_1 = layernorm_enc_1(linear_result_enc_1)

    x = linear_result_enc_1
    w = layernorm_enc_1.weight
    b = layernorm_enc_1.bias

    linear_result_enc_1f = w*x + b

    epsilon = 1e-05
    mean = linear_result_enc_1f.mean(dim=-1, keepdim=True)
    std = linear_result_enc_1f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_enc_1 = (linear_result_enc_1f - mean) / (std + epsilon) * w + b
    op_enc_1 = torch.matmul(normalized_result_enc_1, state_dict["transformer.encoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear1.bias".format(layer_num)]
    op_enc_1_relu = torch.nn.functional.relu(op_enc_1)
    op_enc_2 = torch.matmul(op_enc_1_relu, state_dict["transformer.encoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.encoder.layers.{}.linear2.bias".format(layer_num)]
    output_enc_2 = op_enc_2 + linear_op_enc_1
    output_enc_2_norm = output_enc_2*state_dict["transformer.encoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.encoder.layers.{}.norm2.bias".format(layer_num)]
    layernorm_enc_final = torch.nn.LayerNorm(normalized_shape=output_enc_2_norm.shape[2:])
    output_enc_final = layernorm_enc_final(output_enc_2_norm)
    x = output_enc_2_norm
    w = layernorm_enc_final.weight
    b = layernorm_enc_final.bias

    linear_result_enc_2 = w*x + b

    print(linear_result_enc_2)
    epsilon = 1e-05
    mean = linear_result_enc_2.mean(dim=-1, keepdim=True)
    std = linear_result_enc_2.std(dim=-1, unbiased=False, keepdim=True)
    output_enc_final = (linear_result_enc_2 - mean) / (std + epsilon) * w + b


    print("final_encoder_output :".format(layer_num))
    print("after feed forward and norm2 is applied")
    print(output_enc_final)
    print()
    return output_enc_final

In [84]:
def decoder_block_self_attn_output(x, state_dict, layer_num, num_heads, tgt_mask = None,need_weights = False):
    query_dec = key_dec = value_dec = x
    tgt_len, bsz, embed_dim = x.shape
    W_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_weight".format(layer_num)]
    b_dec = state_dict["transformer.decoder.layers.{}.self_attn.in_proj_bias".format(layer_num)]


    head_dim = embed_dim//num_heads
    Q_dec,K_dec,V_dec = calculate_qkv(query_dec, key_dec, value_dec ,W_dec, b_dec)
    Q_dec = Q_dec.reshape(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    K_dec = K_dec.reshape(K_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec = V_dec.reshape(V_dec.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec)
    print()
    print("K_dec_{} = ".format(layer_num))
    print(K_dec)
    print()
    print("V_dec_{} = ".format(layer_num))
    print(V_dec)
    print()

    src_len = K_dec.size(1)


    attn_mask = tgt_mask
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)



    if need_weights is False:
        attn_output = attention_calculate(Q = Q_dec, V = V_dec, K = K_dec, bsz = bsz, head_dim=head_dim, src_len=src_len, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim, num_heads=num_heads
        )

        print("decoder = ")
        print(attn_output)
        print()

        op_dec_1 = torch.matmul(attn_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_output.size(1))

        return attn_dec_output, None

    else:

        attn_dec_output,attn_wt_matrix_dec = attention_calculate_if_needweightstrue(Q=Q_dec, K=K_dec, V=V_dec, bsz=bsz, tgt_len=tgt_len, attn_mask = attn_mask, embed_dim=embed_dim)

        print("attention if needweights is true = ")
        print(attn_wt_matrix_dec)
        print()

        op_dec_1 = torch.matmul(attn_dec_output, state_dict["transformer.decoder.layers.{}.self_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.self_attn.out_proj.bias".format(layer_num)]
        attn_dec_output = op_dec_1.view(tgt_len, bsz, attn_dec_output.size(1))


        return attn_dec_output, attn_wt_matrix_dec


In [62]:
def dec_post_self_attn(self_attn_dec, x, state_dict, layer_num):
    output_dec_1 = self_attn_dec + x
    print(output_dec_1)
    linear_result_dec_1 = output_dec_1*state_dict["transformer.decoder.layers.{}.norm1.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm1.bias".format(layer_num)]

    layernorm_dec_1 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_1.shape[2:])
    linear_op_dec_1 = layernorm_dec_1(linear_result_dec_1)
    x = linear_result_dec_1

    w = layernorm_dec_1.weight
    b = layernorm_dec_1.bias

    linear_result_dec_1f = w*x + b

    epsilon = 1e-05
    mean = linear_result_dec_1f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_1f.std(dim=-1, unbiased=False, keepdim=True)

    normalized_result_dec_1 = (linear_result_dec_1f - mean) / (std + epsilon) * w + b


    print("decoder after layernorm1".format(layer_num))
    print(normalized_result_dec_1)
    print()
    return normalized_result_dec_1

In [85]:
def decoder_block_cross_attn_output(x_dec, memory, state_dict, layer_num, tgt_len, src_len, head_dim, num_heads, memory_mask = None,need_weights = False):

    print("next input from the encoder= ",memory)
    print()
    query_dec_mha = x_dec
    key_dec_mha, value_dec_mha = memory, memory
    tgt_len, bsz, embed_dim = query_dec_mha.shape

    W_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_weight".format(layer_num)]
    b_dec_mha = state_dict["transformer.decoder.layers.{}.multihead_attn.in_proj_bias".format(layer_num)]
    Q_dec_mha,K_dec_mha,V_dec_mha = calculate_qkv(query_dec_mha, key_dec_mha, value_dec_mha ,W_dec_mha, b_dec_mha)
    Q_dec_mha = Q_dec_mha.reshape(Q_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    K_dec_mha = K_dec_mha.reshape(K_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
    V_dec_mha = V_dec_mha.reshape(V_dec_mha.shape[0], bsz * num_heads, head_dim).transpose(0, 1)

    print("Q_dec_{} = ".format(layer_num))
    print(Q_dec_mha)
    print()
    print("K_dec_{} = ".format(layer_num))
    print(K_dec_mha)
    print()
    print("V_dec_{} = ".format(layer_num))
    print(V_dec_mha)
    print()


    attn_mask = memory_mask
    if attn_mask is not None:
        if attn_mask.dim() == 2:
            correct_2d_size = (tgt_len, src_len)
            if attn_mask.shape != correct_2d_size:
                raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
            attn_mask = attn_mask.unsqueeze(0)
        elif attn_mask.dim() == 3:
            correct_3d_size = (bsz * num_heads, tgt_len, src_len)
            if attn_mask.shape != correct_3d_size:
                raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
        else:
            raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
    if attn_mask is not None:
        if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
            attn_mask = attn_mask.unsqueeze(0)
        else:
            attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)



    if need_weights is False:

        attn_output_dec_mha = attention_calculate(Q =Q_dec_mha, V=V_dec_mha, K=K_dec_mha, bsz=bsz, head_dim=head_dim, src_len=src_len, tgt_len=tgt_len, attn_mask=attn_mask, embed_dim=embed_dim, num_heads=num_heads)
        print("after cross_attention_{}".format(layer_num))
        print(attn_output_dec_mha)
        print()

        op_dec_mha_1 = torch.matmul(attn_output_dec_mha, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]
        attn_dec_mha_output = op_dec_mha_1.view(tgt_len, bsz, attn_output_dec_mha.size(1))

        return attn_dec_mha_output, None

    else:
      attn_dec_mha_output ,attn_wt_matrix_dec_mha = attention_calculate_if_needweightstrue(Q=Q_dec_mha, K=K_dec_mha, V=V_dec_mha, bsz=bsz, tgt_len=tgt_len, attn_mask=attn_mask, embed_dim=embed_dim)
      print("decoder attention_output = ")
      print(attn_dec_mha_output)
      print()


      op_dec_mha = torch.matmul(attn_dec_mha_output, state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.multihead_attn.out_proj.bias".format(layer_num)]
      attn_dec_output_mha = op_dec_mha.view(tgt_len, bsz, attn_dec_mha_output.size(1))


      return attn_dec_output_mha , attn_wt_matrix_dec_mha


In [64]:
def decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num):
    output_dec_2 = attn_dec_mha_output + x_dec
    print(output_dec_2)
    linear_result_dec_2 = output_dec_2*state_dict["transformer.decoder.layers.{}.norm2.weight".format(layer_num)] + state_dict["transformer.decoder.layers.{}.norm2.bias".format(layer_num)]
    layernorm_dec_2 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_2.shape[2:])
    linear_op_dec_2 = layernorm_dec_2(linear_result_dec_2)

    x = linear_result_dec_2
    w = layernorm_dec_2.weight
    b = layernorm_dec_2.bias

    linear_result_dec_2f = w*x + b

    epsilon = 1e-05
    mean = linear_result_dec_2f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_2f.std(dim=-1, unbiased=False, keepdim=True)
    normalized_result_dec_2 = (linear_result_dec_2f - mean) / (std + epsilon) * w + b
    print((linear_result_dec_2f - mean))
    print("after norm2")
    print(normalized_result_dec_2)
    x_dec2_norm = normalized_result_dec_2

    op_dec_1 = torch.matmul(x_dec2_norm, state_dict["transformer.decoder.layers.{}.linear1.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear1.bias".format(layer_num)]

    op_dec_1_relu = torch.nn.functional.relu(op_dec_1)
    op_dec_2 = torch.matmul(op_dec_1_relu, state_dict["transformer.decoder.layers.{}.linear2.weight".format(layer_num)].t()) + state_dict["transformer.decoder.layers.{}.linear2.bias".format(layer_num)]


    ff_dec = op_dec_2

    x_dec3_unorm = x_dec2_norm + ff_dec

    linear_result_dec_3 = x_dec3_unorm*state_dict["transformer.decoder.layers.0.norm3.weight"] + state_dict["transformer.decoder.layers.0.norm3.bias"]

    layernorm_dec_3 = torch.nn.LayerNorm(normalized_shape=linear_result_dec_3.shape[2:])
    linear_op_dec_3 = layernorm_dec_3(linear_result_dec_3)
    x = linear_result_dec_3
    w = layernorm_dec_3.weight
    b = layernorm_dec_3.bias

    linear_result_dec_3f = w*x + b

    epsilon = 1e-5
    mean = linear_result_dec_3f.mean(dim=-1, keepdim=True)
    std = linear_result_dec_3f.std(dim=-1, unbiased=False, keepdim=True)
    print( (linear_result_dec_3f - mean))
    normalized_result_dec_3 = (linear_result_dec_3f - mean) / (std + epsilon) * w + b
    print("after norm3")
    print(normalized_result_dec_3)
    print()
    return normalized_result_dec_3

def feedforward(dec_output_final, state_dict):
    W_ff=state_dict["fc.weight"]
    b_ff=state_dict["fc.bias"]
    final_op = dec_output_final@W_ff.T + b_ff

    return final_op

In [86]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        print(x.shape)
        return self.encoding[:, :x.size(1)].detach()



class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)

        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        print("Tgt mask shape = ", tgt_mask.shape)

        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        output = self.fc(output)

        return output


In [66]:
torch.manual_seed(42)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 2
num_decoder_layers = 2
d_ff = 20
max_seq_len = 5
dropout = 0


transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)


src_data = torch.tensor([[2], [1], [5], [4]])
tgt_data = torch.tensor([[1], [16], [5], [3]])

state_dict = transformer.state_dict()

In [67]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [88]:
def generate_mask(src, tgt):
    src_mask = None
    # Convert tgt to a PyTorch tensor if it's not already
    if not isinstance(tgt, torch.Tensor):
        tgt = torch.tensor(tgt)
    seq_length = tgt.size(0)
    nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

    return src_mask, nopeak_mask

src_mask, tgt_mask = generate_mask(src_data, tgt_data)

In [89]:
def get_all_intermediate_outputs_mask(src_sentence, tgt_sentence,d_model, num_heads ,state_dict, num_encoder_layers , num_decoder_layers, tgt_mask, d_ff):

    pe_src_embeds, pe_tgt_embeds = get_embedding_outputs(src_sentence=src_sentence,  tgt_sentence=tgt_sentence, state_dict=state_dict, max_seq_len=max_seq_len, d_model = d_model)
    print("encoder")
    print()
    x_enc = pe_src_embeds

    for lno in range(num_encoder_layers):
        attn_enc_output, src_len, head_dim, attn_weights = encoder_block_attention_output(x_enc, state_dict, layer_num = lno, need_weights = False, embed_dim = d_model, num_heads= num_heads, src_mask=None)

        if lno == 0:
            tgt_len, bsz, embed_dim = x_enc.shape

        output_enc_final = encoder_block_post_attn_output(x_enc, attn_enc_output, state_dict, layer_num = lno , bsz = bsz, tgt_len = tgt_len)

        x_enc = output_enc_final
    print("decoder")
    print()

    x_dec = pe_tgt_embeds
    memory = x_enc

    for lno in range(num_decoder_layers):

        self_attn_dec, dec_sa_wts = decoder_block_self_attn_output(x_dec, state_dict, layer_num = lno, need_weights = False, tgt_mask=tgt_mask, num_heads = num_heads)
        x_dec = dec_post_self_attn(self_attn_dec, x_dec, state_dict, layer_num = lno)

        attn_dec_mha_output, attn_dec_mha_wts = decoder_block_cross_attn_output(x_dec, memory, state_dict, num_heads = num_heads, layer_num = lno, tgt_len = tgt_len, src_len = src_len, head_dim = head_dim, need_weights = False, memory_mask=None)
        final_op = decoder_block_post_attn_output(x_dec, attn_dec_mha_output, state_dict, layer_num = lno)

        print(pe_tgt_embeds.shape, final_op.shape)
        x_dec = final_op

    final_op = feedforward(final_op, state_dict)

    return final_op


In [90]:

src_mask, tgt_mask = generate_mask(src_data, tgt_data)

In [91]:
state_dict.keys()

odict_keys(['src_embedding.weight', 'tgt_embedding.weight', 'transformer.encoder.layers.0.self_attn.in_proj_weight', 'transformer.encoder.layers.0.self_attn.in_proj_bias', 'transformer.encoder.layers.0.self_attn.out_proj.weight', 'transformer.encoder.layers.0.self_attn.out_proj.bias', 'transformer.encoder.layers.0.linear1.weight', 'transformer.encoder.layers.0.linear1.bias', 'transformer.encoder.layers.0.linear2.weight', 'transformer.encoder.layers.0.linear2.bias', 'transformer.encoder.layers.0.norm1.weight', 'transformer.encoder.layers.0.norm1.bias', 'transformer.encoder.layers.0.norm2.weight', 'transformer.encoder.layers.0.norm2.bias', 'transformer.encoder.layers.1.self_attn.in_proj_weight', 'transformer.encoder.layers.1.self_attn.in_proj_bias', 'transformer.encoder.layers.1.self_attn.out_proj.weight', 'transformer.encoder.layers.1.self_attn.out_proj.bias', 'transformer.encoder.layers.1.linear1.weight', 'transformer.encoder.layers.1.linear1.bias', 'transformer.encoder.layers.1.linear

In [101]:
final_op = get_all_intermediate_outputs_mask(src_data, tgt_data, state_dict = state_dict1, num_heads = num_heads, num_encoder_layers = 2 , num_decoder_layers = 2, d_model=d_model,  d_ff = d_ff, tgt_mask = tgt_mask)

TypeError: 'int' object is not callable

In [None]:
final_op

tensor([[[ 0.5040, -0.5643, -0.9291,  0.8937,  0.5021,  0.6562, -1.3417,
           0.5391,  0.0286,  0.4028,  0.6045,  0.0248,  0.4549, -0.3143,
          -1.1331,  0.7255,  0.1034, -0.2828, -0.1362,  0.0014]],

        [[ 0.2908, -0.4276, -1.1431,  0.8511,  0.3242,  0.2948, -1.3010,
           0.6434,  0.5933, -0.0025,  0.3320, -0.1345,  0.6465, -0.0753,
          -1.1380,  0.1948, -0.0462, -0.3189, -0.3971,  0.0195]],

        [[ 0.2088, -0.7398, -1.1280,  0.7504,  0.3378,  0.3129, -1.4142,
           0.7165,  0.6326, -0.1854,  0.2747, -0.0866,  0.7972, -0.0041,
          -1.0578,  0.4315,  0.0825, -0.2325, -0.3063,  0.0676]],

        [[ 0.4189, -0.6181, -0.7970,  1.2109,  0.7027,  0.0376, -0.7485,
           0.8167,  0.2986, -0.2507,  0.6037, -0.0616,  1.0253,  0.1367,
          -1.2094,  0.7994,  0.3362, -0.8986,  0.2187, -0.1611]]],
       grad_fn=<AddBackward0>)

In [None]:
#verification

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy
torch.manual_seed(42)
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)#d_model is the the dimension of the embedding,maxlen-is the maximum length of sequence
        position = torch.arange(0, max_len).unsqueeze(1).float() #shape('max_len,1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model)) #calculates the positional encodings using the formula
        self.encoding[:, 0::2] = torch.sin(position * div_term)#even position
        self.encoding[:, 1::2] = torch.cos(position * div_term)#odd position
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return self.encoding[:, :x.size(1)].detach()

class TransformerModelLogger:
  def __init__(self, is_logging=False):
    self.is_logging = is_logging
    self.logs = []

  def log(self, message):
    if self.is_logging:
      self.logs.append(message)
      print(message)

  def print_logs(self):
     for message in self.logs:

      print(message)

      self.logs = []

class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_seq_len)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)
        self.logger = TransformerModelLogger(is_logging=True)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)

        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()#triangular matrix used to mask future positiions of the target sequence

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        self.logger.log("Tgt mask shape = " + str(tgt_mask.shape))

        src = self.src_embedding(src) + self.positional_encoding(src)
        self.logger.log(f"src (after embedding and positional encoding): {src.shape}")
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)
        self.logger.log(f"tgt (after embedding and positional encoding): {tgt.shape}")


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        self.logger.log(f"output (after transformer): {output.shape}")
        output = self.fc(output)
        print("output:{output}")

        return output



In [None]:
torch.manual_seed(42)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 2
num_decoder_layers = 2
d_ff = 20
max_seq_len = 5
dropout = 0

transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)



src_data = torch.tensor([[2], [1], [5], [4]])
tgt_data = torch.tensor([[1], [16], [5], [3]])


state_dict = transformer.state_dict()



In [None]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [None]:



transformer.train()



output = transformer(src_data, tgt_data)


print("output shape = ",output.shape)






Tgt mask shape = torch.Size([4, 4])
src (after embedding and positional encoding): torch.Size([4, 1, 16])
tgt (after embedding and positional encoding): torch.Size([4, 1, 16])
output (after transformer): torch.Size([4, 1, 16])
output:{output}
output shape =  torch.Size([4, 1, 20])


In [None]:
print(output)

tensor([[[ 0.5040, -0.5643, -0.9291,  0.8937,  0.5021,  0.6562, -1.3417,
           0.5391,  0.0286,  0.4028,  0.6045,  0.0248,  0.4549, -0.3143,
          -1.1332,  0.7255,  0.1034, -0.2828, -0.1362,  0.0014]],

        [[ 0.2908, -0.4276, -1.1432,  0.8511,  0.3242,  0.2948, -1.3010,
           0.6434,  0.5933, -0.0025,  0.3320, -0.1345,  0.6465, -0.0753,
          -1.1380,  0.1948, -0.0462, -0.3189, -0.3971,  0.0195]],

        [[ 0.2088, -0.7398, -1.1280,  0.7504,  0.3378,  0.3129, -1.4142,
           0.7165,  0.6326, -0.1854,  0.2747, -0.0866,  0.7972, -0.0041,
          -1.0578,  0.4315,  0.0825, -0.2325, -0.3063,  0.0676]],

        [[ 0.4189, -0.6181, -0.7970,  1.2109,  0.7027,  0.0376, -0.7485,
           0.8167,  0.2986, -0.2507,  0.6037, -0.0616,  1.0253,  0.1367,
          -1.2094,  0.7994,  0.3362, -0.8986,  0.2187, -0.1611]]],
       grad_fn=<ViewBackward0>)
