<a href="https://colab.research.google.com/github/karam-koujan/Transformer/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
import torch
import torch.nn as nn
from dataclasses import dataclass

def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_device()

@dataclass
class Config :
    model_d: int
    max_sequence_len: int
    ger_vocab_size: int
    eng_vocab_size : int
    n_layers : int
    n_heads : int
    eng_to_index : dict
    ger_to_index : dict
    start_token : str
    end_token : str
    pad_token: str
    dropout_p : int = 0.1
    hidden_size : int = 2048


In [13]:
class Tokenizer(nn.Module) :
      def __init__(self,config,layer_type) :
            super(Tokenizer,self).__init__()
            self.vocab_size = config.eng_vocab_size if layer_type == "encoder" else config.ger_vocab_size
            self.embedding = nn.Embedding(self.vocab_size,config.model_d)
            self.max_sequence_length = config.max_sequence_len
            self.language_to_index =  config.eng_to_index if layer_type == "encoder" else config.ger_to_index
            self.start_token = config.start_token
            self.end_token = config.end_token
            self.pad_token = config.pad_token
            self.dropout = nn.Dropout(p=config.dropout_p)
            self.positional_embedding = Positional_embedding(config)
      def batch_tokenization(self,batch,is_starttoken,is_endtoken) :
              def sentence_tokenize(sentence,is_starttoken,is_endtoken) :
                      sentence_to_index = []
                      for tokenIdx,token in enumerate(list(sentence)) :
                              if tokenIdx + int(is_starttoken) + int(is_endtoken) + 1 >= self.max_sequence_length :
                                          break
                              if token in self.language_to_index:
                                      sentence_to_index.append(self.language_to_index[token] )
                              else :
                                      sentence_to_index.append(self.language_to_index['<unk>'])
                      if is_starttoken :
                            sentence_to_index.insert(0,self.language_to_index[self.start_token])
                      if is_endtoken :
                            sentence_to_index.append(self.language_to_index[self.end_token])

                      for _ in range(len(sentence_to_index),self.max_sequence_length):
                                  sentence_to_index.append(self.language_to_index[self.pad_token])
                      return torch.tensor(sentence_to_index).to(device)

              sentence_batch = []
              for sentence_idx in range(len(batch)) :
                    sentence_batch.append(sentence_tokenize(batch[sentence_idx],is_starttoken,is_endtoken))

              sentence_batch = torch.stack(sentence_batch)
              return sentence_batch.to(device)
      def forward(self,x,is_starttoken,is_endtoken) :
            # (batch,vocab_size,embed_d)
            print("============== Tokenization ===============")
            x = self.batch_tokenization(x,is_starttoken,is_endtoken)
            x = self.embedding(x)
            pos = self.positional_embedding(x)
            x = self.dropout(x+pos)
            return x


In [14]:
class Positional_embedding(nn.Module):
      def __init__(self,config) :
            super(Positional_embedding,self).__init__()
            self.model_d = config.model_d
            self.max_len = config.max_sequence_len
            self.positional_embedding = torch.zeros((self.max_len,self.model_d)).to(device)
            pos = torch.arange(0,self.max_len,dtype=torch.float).unsqueeze(1)
            div = torch.pow(1000,torch.arange(0,self.model_d,2,dtype=torch.float)/self.model_d)
            self.positional_embedding[:,0::2] = torch.sin(pos/div)
            self.positional_embedding[:,1::2] = torch.cos(pos/div)
      def forward(self,x) :
            print("============ positional embedding ==========",x.size(),self.positional_embedding.size())
            x = x +  self.positional_embedding
            return x


In [15]:
class LayerNorm(nn.Module):
      def __init__(self,epsilon=1e-6):
              super(LayerNorm,self).__init__()
              self.epsilon = epsilon

      def forward(self,x) :
              print("============== LayerNormalization ===============", x.size())
              batch_size,seq_length,model_d = x.shape
              gamma = nn.Parameter(torch.ones(model_d).to(device))
              beta = nn.Parameter(torch.zeros(model_d).to(device))
              mean = x.mean(-1,keepdim=True)
              var = x.var(-1,keepdim=True)
              x_normalized = (x - mean) / torch.sqrt(var + self.epsilon)

              x = gamma * x_normalized + beta
              return x

In [16]:
class MultiHeadAttention(nn.Module) :
      def __init__(self,config) :
                  super(MultiHeadAttention,self).__init__()
                  assert config.model_d % config.n_heads == 0  , f"{config.model_d} should be divisible by {config.n_heads}"
                  self.max_sequence_len = config.max_sequence_len
                  self.model_d = config.model_d
                  self.heads_num = config.n_heads
                  self.qkv_d = self.model_d // self.heads_num
                  self.queryP = nn.Linear(self.model_d,self.model_d)
                  self.valueP = nn.Linear(self.model_d,self.model_d)
                  self.keyP = nn.Linear(self.model_d,self.model_d)
                  self.out = nn.Linear(self.model_d,self.model_d)
      def attention(self,q,k,v,mask=None) :
             dk = torch.tensor(k.shape[-1],dtype=torch.float32)
             energy = torch.matmul(q,k.transpose(-2,-1)) / torch.sqrt(dk)
             if mask != None  :
                        print("________Masking__________")
                        energy = torch.permute(energy,(1,0,2,3))
                        print("energy shape",energy.shape)
                        print("mask shape",mask.shape)
                        energy =  energy.masked_fill(mask != 0, float('-1e9'))
                        energy = torch.permute(energy,(1,0,2,3))

             return torch.matmul(torch.softmax(energy,dim=-1),v)
      def forward(self,x,mask=None):
              """
               first we create key,query and value using a linear projection using a 1 fully connected layer
               The size of these tensors is (input_sequence_length,model_d)
              """
              print("============== MultiHeadAttention ===============")
              query= None
              key = None
              value = None
              batch_size = None
              if isinstance(x,(list,tuple)) :
                   query,key,value = x
                   batch_size = query.size(0)
                   query = self.queryP(query)
                   key = self.keyP(key)
                   value = self.valueP(value)

              else :
                   batch_size = x.size(0)
                   query = self.queryP(x)
                   key = self.keyP(x)
                   value = self.valueP(x)

              """
               we add another dimension for heads now the tensors size is (heads_num,input_sequence,model_d)
               calculte attention for each head independently and in parallel
              """
              query = query.view(batch_size,self.heads_num,self.max_sequence_len,self.qkv_d)
              key = key.view(batch_size, self.heads_num,self.max_sequence_len, self.qkv_d)
              value = value.view(batch_size, self.heads_num,self.max_sequence_len, self.qkv_d)
              attention = self.attention(query, key, value, mask)
              attention = attention.view(batch_size, self.max_sequence_len, self.model_d)
              out = self.out(attention)
              return out



In [17]:
import torch.nn.functional as F
class FeedForward(nn.Module) :
        def __init__(self,input_size,output_size,hidden_size,dropout_p=0.1) :
                  super(FeedForward,self).__init__()
                  self.fc1 =  nn.Linear(input_size,hidden_size)
                  self.fc2 = nn.Linear(hidden_size,output_size)
                  self.dropout = nn.Dropout(p=dropout_p)
        def forward(self,x) :
                  print("============== FeedForward NN  ===============", x.size())
                  x = F.relu(self.fc1(x))
                  x = self.dropout(x)
                  x = self.fc2(x)
                  return x

In [18]:
class EncoderLayer(nn.Module) :
     def __init__(self,config) :
        super(EncoderLayer,self).__init__()
        self.model_d = config.model_d
        self.max_length = config.max_sequence_len
        self.hidden_size = config.hidden_size
        self.attention_heads = config.n_heads
        self.dropout1 = nn.Dropout(p=config.dropout_p)
        self.dropout2 = nn.Dropout(p=config.dropout_p)
        self.multi_head_attention =  MultiHeadAttention(config)
        self.layernorm = LayerNorm()
        self.fc = FeedForward(self.model_d,self.model_d,self.hidden_size)

     def forward(self,x,mask=None) :
              res_x = x.clone()
              x = self.multi_head_attention(x,mask)
              x = self.dropout1(x)
              x = self.layernorm(x + res_x )
              res_x = x.clone()
              x = self.fc(x)
              x = self.dropout2(x)
              x =  self.layernorm(x + res_x )
              return x


In [19]:
class EncoderLayers(nn.Sequential) :

          def forward(self, x,mask=None):
                     for module in self._modules.values():
                          x = module(x,mask)
                     return x

In [20]:
class Encoder(nn.Module):
           def __init__(self,config) :
                super(Encoder,self).__init__()
                self.layers  = EncoderLayers(*[ EncoderLayer(config) for _ in range(config.n_layers)])
                self.tokenizer = Tokenizer(config,"encoder")
           def forward(self,x,mask,is_starttoken,is_endtoken) :
                   x = self.tokenizer(x,is_starttoken,is_endtoken)
                   x = self.layers(x,mask)
                   return x



In [21]:
class DecoderLayer(nn.Module) :
        def __init__(self,config) :
            super(DecoderLayer,self).__init__()
            self.model_d = config.model_d
            self.max_sequence_length = config.max_sequence_len
            self.hidden_size = config.hidden_size
            self.attention_heads = config.n_heads
            self.dropout1 = nn.Dropout(p=config.dropout_p)
            self.dropout2 = nn.Dropout(p=config.dropout_p)
            self.dropout3 = nn.Dropout(p=config.dropout_p)
            self.multi_head_attention =  MultiHeadAttention(config)
            self.layernorm = LayerNorm()
            self.fc = FeedForward(self.model_d,self.model_d,self.hidden_size)
        def forward(self,x,encoder_out,att_mask,pad_mask) :
                    res_x = x.clone()
                    x = self.multi_head_attention(x,att_mask)
                    x = self.dropout1(x)
                    x = self.layernorm(x + res_x )
                    res_x = x.clone()
                    x  = self.multi_head_attention((x,encoder_out,encoder_out),pad_mask)
                    x = self.dropout2(x)
                    x = self.layernorm(x + res_x )
                    res_x = x.clone()
                    x =  self.fc(x)
                    x = self.dropout3(x)
                    x = self.layernorm(x + res_x )
                    return x


In [22]:
class DecoderLayers(nn.Sequential) :

          def forward(self, x, encoder_out,att_mask,pad_mask):
                     for module in self._modules.values():
                          x = module(x,encoder_out,att_mask,pad_mask)
                     return x

In [23]:
class Decoder(nn.Module):
           def __init__(self,config) :
                super(Decoder,self).__init__()
                self.layers  = DecoderLayers(*[ DecoderLayer(config) for _ in range(config.n_layers)])
                self.tokenizer = Tokenizer(config,"decoder")

           def forward(self,x,encoder_out,att_mask,pad_mask,is_starttoken,is_endtoken) :
                   x = self.tokenizer(x,is_starttoken,is_endtoken)
                   x = self.layers(x,encoder_out,att_mask,pad_mask)
                   return  x



In [24]:
class Transformer(nn.Module):
    def __init__(self,config):
        super().__init__()
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        self.linear = nn.Linear(config.model_d, config.ger_vocab_size)
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    def forward(self,
                x,
                y,
                encoder_pad_mask=None,
                decoder_att_mask=None,
                decoder_pad_mask=None,
                enc_start_token=False,
                enc_end_token=False,
                dec_start_token=False, # We should make this true
                dec_end_token=False): # x, y are batch of sentences
        x = self.encoder(x,encoder_pad_mask ,is_starttoken=enc_start_token, is_endtoken=enc_end_token)
        out = self.decoder(y,x,decoder_att_mask,decoder_pad_mask, is_starttoken=dec_start_token, is_endtoken=dec_end_token)
        out = self.linear(out)
        return out

In [25]:
import numpy as np


file_path = "./deu.txt"
start_token = '<sof>'
end_token = '<eof>'
pad_token = '<pad>'
english_vocabulary = [start_token,'<unk>', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/','`','’',
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@','[', '\\', ']', '^', '_', '`',
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
                        'y', 'z',
                        '{', '|', '}', '~', pad_token, end_token]

german_vocabulary = [start_token,'<unk>',' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/','`','’',
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                      ':', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                      'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x',
                      'y', 'z', 'ä', 'ö', 'ü', 'ß',
                      '{', '|', '}', '~', pad_token, end_token]


In [26]:
index_to_german = {k : v  for k,v in enumerate(german_vocabulary)}
german_to_index = {v : k  for k,v in enumerate(german_vocabulary)}
index_to_english = {k : v  for k,v in enumerate(english_vocabulary)}
english_to_index = {v : k  for k,v in enumerate(english_vocabulary)}

In [27]:
with open(file_path, 'r') as file:
    raw_data = file.readlines()

In [28]:
sentences =  [ (sentence.rstrip("\n").split("\t")[0].lower(),sentence.rstrip("\n").split("\t")[1].lower()) for sentence in raw_data]
sentences

[('go.', 'geh.'),
 ('hi.', 'hallo!'),
 ('hi.', 'grüß gott!'),
 ('run!', 'lauf!'),
 ('run.', 'lauf!'),
 ('wow!', 'potzdonner!'),
 ('wow!', 'donnerwetter!'),
 ('fire!', 'feuer!'),
 ('help!', 'hilfe!'),
 ('help!', 'zu hülf!'),
 ('stop!', 'stopp!'),
 ('wait!', 'warte!'),
 ('wait.', 'warte.'),
 ('begin.', 'fang an.'),
 ('go on.', 'mach weiter.'),
 ('hello!', 'hallo!'),
 ('hurry!', 'beeil dich!'),
 ('hurry!', 'schnell!'),
 ('i hid.', 'ich versteckte mich.'),
 ('i hid.', 'ich habe mich versteckt.'),
 ('i ran.', 'ich rannte.'),
 ('i see.', 'ich verstehe.'),
 ('i see.', 'aha.'),
 ('i try.', 'ich probiere es.'),
 ('i won!', 'ich hab gewonnen!'),
 ('i won!', 'ich habe gewonnen!'),
 ('relax.', 'entspann dich.'),
 ('shoot!', 'feuer!'),
 ('shoot!', 'schieß!'),
 ('smile.', 'lächeln!'),
 ('ask me.', 'frag mich!'),
 ('ask me.', 'fragt mich!'),
 ('ask me.', 'fragen sie mich!'),
 ('attack!', 'angriff!'),
 ('attack!', 'attacke!'),
 ('cheers!', 'zum wohl!'),
 ('eat it.', 'iss es.'),
 ('eat up.', 'iss auf.'

In [29]:
import numpy as np
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x[0]) for x in sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length German: {np.percentile([len(x[1]) for x in sentences], PERCENTILE)}" )



97th percentile length English: 40.0
97th percentile length German: 52.0


In [30]:
def is_token_exist(sentence,vocab):
     for token in sentence :
           if token not in vocab :
                 return False
     return True

def  is_valid_length(sentence,max_sequence_length) :
           return len(sentence) < max_sequence_length - 1

is_token_exist('sie geht zu fuß.',german_vocabulary)


True

In [31]:
max_sequence_length = 80
valid_sentence_indicies = []
for index in range(len(sentences)):
    german_sentence, english_sentence = sentences[index][1], sentences[index][0]
    if is_valid_length(german_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_token_exist(german_sentence, german_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 180050
Number of valid sentences: 179470


In [32]:
model_d = 512
batch_size = 30
n_heads = 8
n_layers = 1
max_sequence_len = 80
ger_vocab_size = len(german_vocabulary)
eng_vocab_size = len(english_vocabulary)
config = Config(model_d=model_d, max_sequence_len=max_sequence_len, ger_vocab_size=ger_vocab_size, eng_vocab_size=eng_vocab_size, n_layers=n_layers, n_heads=n_heads, eng_to_index=english_to_index, ger_to_index=german_to_index, start_token=start_token, end_token=end_token, pad_token=pad_token)
transformer = Transformer(config)

In [33]:
from torch.utils.data import DataLoader,Dataset
class TextDataset(Dataset) :
    def __init__(self,sentences) :
         self.sentences = sentences


    def __len__(self) :
         return len(self.sentences)
    def __getitem__(self,idx):
           german = self.sentences[idx][1]
           english = self.sentences[idx][0]
           return english,german


In [34]:
dataset = TextDataset(sentences)

In [35]:
print(dataset[3],len(dataset))

('run!', 'lauf!') 180050


In [36]:
criterian = nn.CrossEntropyLoss(ignore_index=german_to_index[pad_token])
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)
optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)


In [37]:
import numpy as np

def create_masks(eng_batch, ger_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.zeros((max_sequence_len, max_sequence_len))
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.zeros((num_sentences, max_sequence_len, max_sequence_len))
    decoder_padding_mask_self_attention = torch.zeros((num_sentences, max_sequence_len, max_sequence_len))
    decoder_padding_mask_cross_attention = torch.zeros((num_sentences, max_sequence_len, max_sequence_len))

    for idx in range(num_sentences):
      eng_sentence_length, ger_sentence_length = len(eng_batch[idx]), len(ger_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length  , max_sequence_len)
      ger_chars_to_padding_mask = np.arange(ger_sentence_length , max_sequence_len)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = 1
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = 1
      decoder_padding_mask_self_attention[idx, :, ger_chars_to_padding_mask] = 1
      decoder_padding_mask_self_attention[idx, ger_chars_to_padding_mask, :] = 1
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = 1
      decoder_padding_mask_cross_attention[idx, ger_chars_to_padding_mask, :] = 1
    decoder_self_attention_mask =   decoder_padding_mask_self_attention + look_ahead_mask
    return encoder_padding_mask, decoder_self_attention_mask, decoder_padding_mask_cross_attention,

In [38]:
train_loader = DataLoader(dataset,batch_size)

In [39]:
!pip install lightning


Collecting lightning
  Downloading lightning-2.2.2-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities<2.0,>=0.8.0 (from lightning)
  Downloading lightning_utilities-0.11.2-py3-none-any.whl (26 kB)
Collecting torchmetrics<3.0,>=0.7.0 (from lightning)
  Downloading torchmetrics-1.3.2-py3-none-any.whl (841 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m841.5/841.5 kB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.2.2-py3-none-any.whl (801 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m801.9/801.9 kB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch<4.0,>=1.13.0->lightning)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1

In [40]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m27.4 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.45.0-py2.py3-none-any.whl (267 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m267.1/267.1 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->w

In [41]:
import wandb

In [42]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [43]:
import lightning as L


class LitTransformer(L.LightningModule):
     def __init__(self, Transformer,config) -> None:
            super().__init__()
            self.save_hyperparameters()
            self.transformer = Transformer(config)

     def training_step(self,batch,batch_idx) :
              eng_batch, ger_batch = batch
              encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, ger_batch)
              ger_predictions = self.transformer(eng_batch,
                                     ger_batch,
                                     encoder_self_attention_mask.to(device),
                                     decoder_self_attention_mask.to(device),
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
              labels = transformer.decoder.tokenizer.batch_tokenization(ger_batch, is_starttoken=False, is_endtoken=True)
              loss = criterian(
                  ger_predictions.view(-1, ger_vocab_size).to(device),
                    labels.view(-1).to(device)
                 ).to(device)
              valid_indicies = torch.where(labels.view(-1) == german_to_index[pad_token], False, True)
              loss = loss.sum() / valid_indicies.sum()
              self.log('train/loss',loss)
              return loss
     def configure_optimizers(self) :
                     optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
                     return optim

In [44]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer
model = LitTransformer(Transformer,config)
wandblogger = WandbLogger(project="transformer",log_model="all")
trainer = Trainer(max_epochs=1,logger=wandblogger)
trainer.fit(model,train_loader)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mkaramkaku2000[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name        | Type        | Params
--------------------------------------------
0 | transformer | Transformer | 6.4 M 
--------------------------------------------
6.4 M     Trainable params
0         Non-trainable params
6.4 M     Total params
25.674    Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name        | Type        | Params
--------------------------------------------
0 | transformer | Transformer | 6.4 M 
--------------------------------------------
6.4 M     Trainable params
0         Non-trainable params
6.4 M     Total params
25.674    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask shape torch.Size([30, 80, 80])
________Masking__________
energy shape torch.Size([8, 30, 80, 80])
mask sha

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
