In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from typing import Any
from tokenizers import Tokenizer

In [2]:
class BiDataset(Dataset):
  """Sentence-pair dataset that yields fixed-length transformer inputs.

  Each element of `ds` is indexed as `item['translation'][lang]` to obtain the
  source and target sentences. Sentences are tokenized with the matching
  tokenizer, framed with special tokens and right-padded to `seq_len`.

  Special-token ids are looked up per side: the encoder sequence uses the ids
  from `tokenizer_src`, while the decoder input and the label use the ids from
  `tokenizer_tgt`. Both tokenizers must therefore define "[SOS]", "[EOS]" and
  "[PAD]". When both tokenizers map these tokens to the same ids the output is
  the same as if a single set of ids were used.

  Args:
    ds: indexable collection of sentence pairs.
    tokenizer_src: tokenizer for the source language (needs `encode(...).ids`
      and `token_to_id`).
    tokenizer_tgt: tokenizer for the target language (same requirements).
    lang_src: key of the source sentence inside `item['translation']`.
    lang_tgt: key of the target sentence inside `item['translation']`.
    seq_len: length every produced sequence is padded to.
  """

  def __init__(self, ds: list, tokenizer_src: Tokenizer, tokenizer_tgt: Tokenizer, lang_src, lang_tgt, seq_len):
    super().__init__()
    self.seq_len = seq_len
    self.ds = ds
    self.tokenizer_src = tokenizer_src
    self.tokenizer_tgt = tokenizer_tgt
    self.src_lang = lang_src
    self.tgt_lang = lang_tgt

    # Source-side special tokens (attribute names kept for existing callers).
    self.sos_token = self._special_token(tokenizer_src, "[SOS]")
    self.eos_token = self._special_token(tokenizer_src, "[EOS]")
    self.pad_token = self._special_token(tokenizer_src, "[PAD]")

    # Target-side special tokens: decoder input and label contain ids from the
    # target vocabulary, so their framing/padding ids must come from it too.
    self.tgt_sos_token = self._special_token(tokenizer_tgt, "[SOS]")
    self.tgt_eos_token = self._special_token(tokenizer_tgt, "[EOS]")
    self.tgt_pad_token = self._special_token(tokenizer_tgt, "[PAD]")

  @staticmethod
  def _special_token(tokenizer, token):
    """Return the id of `token` in `tokenizer` as a 1-element int64 tensor."""
    return torch.tensor([tokenizer.token_to_id(token)], dtype=torch.int64)

  def __len__(self):
    """Number of sentence pairs in the underlying collection."""
    return len(self.ds)

  def __getitem__(self, index):
    """Build the model inputs for the sentence pair at `index`.

    Returns:
      dict with
        "encoder_input": int64 (seq_len,)  = [SOS] src ids [EOS] [PAD]...
        "decoder_input": int64 (seq_len,)  = [SOS] tgt ids [PAD]...
        "label":         int64 (seq_len,)  = tgt ids [EOS] [PAD]...
        "encoder_mask":  int32 (1, 1, seq_len), 1 on non-padding positions
        "decoder_mask":  padding mask AND-ed with `causal_mask(seq_len)`
        "src_text", "tgt_text": the raw sentences.

    Raises:
      ValueError: if either sentence does not fit into `seq_len` once its
        special tokens are added.
    """
    pair = self.ds[index]
    src_text = pair['translation'][self.src_lang]
    tgt_text = pair['translation'][self.tgt_lang]

    src_ids = torch.tensor(self.tokenizer_src.encode(src_text).ids, dtype=torch.int64)
    tgt_ids = torch.tensor(self.tokenizer_tgt.encode(tgt_text).ids, dtype=torch.int64)

    # Encoder adds [SOS] and [EOS]; decoder input / label add one token each.
    enc_padding = self.seq_len - src_ids.size(0) - 2
    dec_padding = self.seq_len - tgt_ids.size(0) - 1

    if enc_padding < 0 or dec_padding < 0:
      raise ValueError('Do dai cau dang co van de')

    # repeat() builds the padding run directly instead of converting a Python
    # list of 1-element tensors element by element.
    src_pad = self.pad_token.repeat(enc_padding)
    tgt_pad = self.tgt_pad_token.repeat(dec_padding)

    encoder_input = torch.cat([self.sos_token, src_ids, self.eos_token, src_pad])
    decoder_input = torch.cat([self.tgt_sos_token, tgt_ids, tgt_pad])
    label = torch.cat([tgt_ids, self.tgt_eos_token, tgt_pad])

    assert encoder_input.size(0) == self.seq_len
    assert decoder_input.size(0) == self.seq_len
    assert label.size(0) == self.seq_len

    encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()
    # causal_mask is a module-level helper of this notebook, resolved at call time.
    decoder_mask = (decoder_input != self.tgt_pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(self.seq_len)

    return {
        "encoder_input": encoder_input,
        "decoder_input": decoder_input,
        "encoder_mask": encoder_mask,
        "decoder_mask": decoder_mask,
        "label": label,
        "src_text": src_text,
        "tgt_text": tgt_text,
    }




In [3]:
def causal_mask(seq_len):
  """Return a boolean (1, seq_len, seq_len) mask for causal attention.

  Entry [0, i, j] is True when j <= i (position i may look at position j)
  and False for every position strictly above the diagonal.
  """
  allowed = torch.tril(torch.ones(1, seq_len, seq_len))
  return allowed.bool()