In [2]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention with a single fused Q/K/V projection.

    ``F.scaled_dot_product_attention`` is used when the installed PyTorch
    provides it; otherwise an explicit softmax(QK^T / sqrt(d)) V implementation
    is used. Both paths follow the same mask convention as
    ``F.scaled_dot_product_attention``: a boolean mask marks positions that MAY
    be attended to with True, a floating-point mask is added to the scores.
    """

    def __init__(self, emb_dim, num_heads=4, dropout=0.1, at_mask=False):
        """
        Args:
            emb_dim: Embedding size (must be divisible by ``num_heads``).
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
            at_mask: If True, apply a causal mask (decoder self-attention).
        """
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.at_mask = at_mask

        self.QKV_linear = nn.Linear(emb_dim, emb_dim * 3, bias=False)
        self.out_linear = nn.Linear(emb_dim, emb_dim)
        # Kept as a plain float: it is handed to the functional attention API.
        self.dropout = dropout

    @staticmethod
    def _merge_causal(mask, seq_len):
        """Fold a lower-triangular causal constraint into an explicit mask."""
        allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=mask.device).tril()
        if mask.dtype == torch.bool:
            return mask & allowed
        causal_bias = torch.zeros(seq_len, seq_len, dtype=mask.dtype, device=mask.device)
        return mask + causal_bias.masked_fill(~allowed, float("-inf"))

    def _manual_attention(self, q, k, v, mask=None, causal=False):
        """Reference attention used when the fused kernel is unavailable.

        Args:
            q, k, v: (batch_size, num_heads, seq_len, head_dim)
            mask: optional mask broadcastable to (batch_size, num_heads, L, S);
                boolean (True = attend) or additive float.
            causal: if True, hide keys that lie after the query position.
        Returns:
            (batch_size, num_heads, seq_len, head_dim)
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))

        if causal:
            q_len, k_len = scores.shape[-2:]
            future = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        if mask is not None:
            if mask.dtype == torch.bool:
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask

        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        return torch.matmul(attn, v)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, emb_dim) input tensor.
            mask: optional attention mask, boolean (True = attend) or additive
                float. Accepted shapes: (seq_len, seq_len),
                (batch_size, seq_len, seq_len) or anything broadcastable to
                (batch_size, num_heads, seq_len, seq_len).
        Returns:
            (batch_size, seq_len, emb_dim)
        """
        batch_size, seq_len, emb_dim = x.size()
        qkv = self.QKV_linear(x)  # (batch_size, seq_len, emb_dim * 3)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # (3, batch_size, num_heads, seq_len, head_dim) -> three tensors
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        causal = self.at_mask
        if mask is not None:
            if mask.dim() == 3:
                # (batch, L, S) -> (batch, 1, L, S) so it broadcasts over heads.
                mask = mask.unsqueeze(1)
            if causal:
                # The fused kernel rejects an explicit mask combined with
                # is_causal=True on some backends, so merge the two up front.
                mask = self._merge_causal(mask, seq_len)
                causal = False

        if hasattr(F, "scaled_dot_product_attention"):
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=causal
            )  # (batch_size, num_heads, seq_len, head_dim)
        else:
            attn_output = self._manual_attention(q, k, v, mask=mask, causal=causal)

        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()  # (batch_size, seq_len, num_heads, head_dim)
        attn_output = attn_output.reshape(batch_size, seq_len, emb_dim)  # (batch_size, seq_len, emb_dim)
        return self.out_linear(attn_output)  # (batch_size, seq_len, emb_dim)


class Embedding(nn.Module):
    """
    Token embedding plus fixed sinusoidal positional encoding.

    The positional table is computed once in ``__init__`` and stored with
    ``register_buffer`` so it follows the module across devices and is not
    trained. It is added to the token embeddings by broadcasting over the batch
    dimension, and dropout is applied to the sum.
    """

    def __init__(self, vocab_size, emb_dim=512, max_seq_len=5000, dropout=0.1):
        """
        Args:
            vocab_size: Vocabulary size.
            emb_dim: Embedding size (odd values are supported).
            max_seq_len: Longest sequence the positional table covers.
            dropout: Dropout probability.
        """
        super().__init__()
        self.emb_dim = emb_dim
        self.vocab_size = vocab_size
        self.dropout = nn.Dropout(dropout)

        # Token embedding
        self.token_emb = nn.Embedding(vocab_size, emb_dim)

        # Positional table, shape (1, max_seq_len, emb_dim); buffer, not a parameter.
        self.register_buffer('pe', self._create_positional_encoding(max_seq_len, emb_dim))

    def _create_positional_encoding(self, max_seq_len, emb_dim):
        """Build the sinusoidal table of shape (1, max_seq_len, emb_dim)."""
        pe = torch.zeros(max_seq_len, emb_dim)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)

        div_term = torch.exp(
            torch.arange(0, emb_dim, 2).float() * (-math.log(10000.0) / emb_dim)
        )
        angles = position * div_term  # (max_seq_len, ceil(emb_dim / 2))

        pe[:, 0::2] = torch.sin(angles)
        # For odd emb_dim there is one fewer cosine column than sine columns,
        # so drop the last frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : emb_dim // 2])

        return pe.unsqueeze(0)  # (1, max_seq_len, emb_dim)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len) token indices.
        Returns:
            (batch_size, seq_len, emb_dim)
        """
        _, seq_len = x.size()

        token_emb = self.token_emb(x)  # (batch_size, seq_len, emb_dim)
        pos_emb = self.pe[:, :seq_len, :]  # (1, seq_len, emb_dim), broadcast over batch

        return self.dropout(token_emb + pos_emb)


class Feedforward(nn.Module):
    """
    Position-wise feed-forward network using a SwiGLU gate.

    SwiGLU(x) = W2(SiLU(W1 x) * W3 x), where SiLU(z) = z * sigmoid(z).
    Dropout is applied to the gated hidden activation and to the output.
    """

    def __init__(self, d_model=512, d_ff=None, dropout=0.1, bias=False):
        """
        Args:
            d_model: Input/output embedding size.
            d_ff: Hidden size. When None it is int(8 * d_model / 3) rounded UP
                to the next multiple of 256.
            dropout: Dropout probability.
            bias: Whether the linear layers carry a bias term.
        """
        super().__init__()

        if d_ff is None:
            # 8/3 * d_model keeps the parameter count close to a 4x ReLU MLP
            # despite the extra gate projection.
            hidden = int(8 * d_model / 3)
            full_blocks, remainder = divmod(hidden, 256)
            d_ff = 256 * (full_blocks + (1 if remainder else 0))

        # Registration order (w1, w2, w3) is kept for state_dict compatibility.
        self.w1 = nn.Linear(d_model, d_ff, bias=bias)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=bias)  # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=bias)  # up projection
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: (batch_size, seq_len, d_model)
        Returns:
            (batch_size, seq_len, d_model)
        """
        gated = F.silu(self.w1(x)) * self.w3(x)  # element-wise gate
        hidden = self.dropout(gated)
        return self.dropout(self.w2(hidden))


class CrossAttention(nn.Module):
    """
    Multi-head cross-attention: queries come from one sequence, keys and values
    from another.

    ``F.scaled_dot_product_attention`` is used when available, with an explicit
    softmax implementation as fallback. ``attn_mask`` follows the convention of
    ``F.scaled_dot_product_attention`` on both paths: boolean True means the
    position may be attended to, a floating-point mask is added to the scores.
    """

    def __init__(self, emb_dim, num_heads=8, dropout=0.1, bias=True):
        """
        Args:
            emb_dim: Embedding size (must be divisible by ``num_heads``).
            num_heads: Number of attention heads.
            dropout: Dropout probability applied to the attention weights.
            bias: Whether the projection layers carry a bias term.
        """
        super().__init__()
        assert emb_dim % num_heads == 0, "emb_dim must be divisible by num_heads"

        self.emb_dim = emb_dim
        self.num_heads = num_heads
        self.head_dim = emb_dim // num_heads
        self.dropout_p = dropout

        # Projections
        self.query_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.key_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.value_proj = nn.Linear(emb_dim, emb_dim, bias=bias)
        self.out_proj = nn.Linear(emb_dim, emb_dim, bias=bias)

    def _manual_attention(self, Q, K, V, attn_mask=None, key_padding_mask=None):
        """Explicit attention used when the fused kernel is unavailable.

        Args:
            Q: (batch_size, num_heads, seq_len_q, head_dim)
            K, V: (batch_size, num_heads, seq_len_kv, head_dim)
            attn_mask: optional mask broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_kv); boolean
                (True = attend) or additive float.
            key_padding_mask: optional (batch_size, seq_len_kv) boolean mask,
                True for key positions to ignore.
        Returns:
            (batch_size, num_heads, seq_len_q, head_dim)
        """
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))

        if key_padding_mask is not None:
            # (batch, S) -> (batch, 1, 1, S): broadcast over heads and queries.
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), float('-inf')
            )

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Adding a boolean mask would only shift scores by 0/1; mask
                # out the disallowed positions instead.
                scores = scores.masked_fill(~attn_mask, float('-inf'))
            else:
                scores = scores + attn_mask

        attn = F.softmax(scores, dim=-1)
        attn = F.dropout(attn, p=self.dropout_p, training=self.training)
        return torch.matmul(attn, V)

    def forward(self, query, key_value, attn_mask=None, key_padding_mask=None):
        """
        Args:
            query: (batch_size, seq_len_q, emb_dim), e.g. decoder states.
            key_value: (batch_size, seq_len_kv, emb_dim), e.g. encoder output.
            attn_mask: optional mask, boolean (True = attend) or additive
                float, of shape (seq_len_q, seq_len_kv),
                (batch_size, seq_len_q, seq_len_kv), or broadcastable to
                (batch_size, num_heads, seq_len_q, seq_len_kv).
            key_padding_mask: (batch_size, seq_len_kv), True for positions to
                ignore. NOTE: as before, this argument is only consulted on the
                fallback path (PyTorch without scaled_dot_product_attention);
                with the fused kernel, express padding through ``attn_mask``.
        Returns:
            (batch_size, seq_len_q, emb_dim)
        """
        batch_size, seq_len_q, _ = query.size()
        seq_len_kv = key_value.size(1)

        def split_heads(projected, length):
            # (batch, length, emb_dim) -> (batch, num_heads, length, head_dim)
            return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        Q = split_heads(self.query_proj(query), seq_len_q)
        K = split_heads(self.key_proj(key_value), seq_len_kv)
        V = split_heads(self.value_proj(key_value), seq_len_kv)

        if attn_mask is not None and attn_mask.dim() == 3:
            # (batch, L, S) -> (batch, 1, L, S) so it broadcasts over heads.
            attn_mask = attn_mask.unsqueeze(1)

        if hasattr(F, 'scaled_dot_product_attention'):
            out = F.scaled_dot_product_attention(
                Q, K, V,
                attn_mask=attn_mask,
                dropout_p=self.dropout_p if self.training else 0.0,
                is_causal=False  # Cross-attention is not causal
            )
        else:
            out = self._manual_attention(
                Q, K, V, attn_mask=attn_mask, key_padding_mask=key_padding_mask
            )

        # Merge heads and project
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len_q, self.emb_dim)
        return self.out_proj(out)


class Encoder(nn.Module):
    """
    One encoder block (post-norm layout):

        x -> self-attention -> dropout -> + x -> LayerNorm
          -> feed-forward   -> dropout -> + x -> LayerNorm

    Each sub-layer output is added to its input before normalisation.
    """

    def __init__(self, dmodel=512, num_heads=4, d_ff=None, dropout=0.1):
        """
        Args:
            dmodel: Embedding size.
            num_heads: Number of attention heads.
            d_ff: Hidden size of the feed-forward network (None = its default).
            dropout: Dropout probability.
        """
        super().__init__()
        self.mha = MultiHeadAttention(emb_dim=dmodel, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(d_model=dmodel, dropout=dropout, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(dmodel)
        self.norm2 = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: (batch_size, seq_len, dmodel)
            mask: optional attention mask handed to the self-attention layer.
        Returns:
            (batch_size, seq_len, dmodel)
        """
        # Self-attention sub-layer: residual add, then normalise.
        attended = self.dropout(self.mha(x, mask=mask))
        x = self.norm1(x + attended)

        # Feed-forward sub-layer: residual add, then normalise.
        transformed = self.dropout(self.ffn(x))
        return self.norm2(x + transformed)


class Decoder(nn.Module):
    """
    One decoder block (post-norm layout):

        x -> causal self-attention        -> dropout -> + x -> LayerNorm
          -> cross-attention over encoder -> dropout -> + x -> LayerNorm
          -> feed-forward                 -> dropout -> + x -> LayerNorm
    """

    def __init__(self, dmodel=512, num_heads=4, d_ff=None, dropout=0.1):
        """
        Args:
            dmodel: Embedding size.
            num_heads: Number of attention heads.
            d_ff: Hidden size of the feed-forward network (None = its default).
            dropout: Dropout probability.
        """
        super().__init__()
        self.mha = MultiHeadAttention(emb_dim=dmodel, num_heads=num_heads, dropout=dropout, at_mask=True)
        self.cross_attn = CrossAttention(emb_dim=dmodel, num_heads=num_heads, dropout=dropout)
        self.ffn = Feedforward(d_model=dmodel, dropout=dropout, d_ff=d_ff)
        self.norm1 = nn.LayerNorm(dmodel)
        self.norm2 = nn.LayerNorm(dmodel)
        self.norm3 = nn.LayerNorm(dmodel)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, tgt_mask=None, src_mask=None):
        """
        Args:
            x: (batch_size, tgt_seq_len, dmodel) decoder input.
            enc_output: (batch_size, src_seq_len, dmodel) encoder output.
            tgt_mask: optional mask for the decoder self-attention (over target
                positions); passed to the self-attention layer as ``mask``.
            src_mask: optional mask over the encoder positions; passed to the
                cross-attention layer as ``attn_mask``.
        Returns:
            (batch_size, tgt_seq_len, dmodel)
        """
        # Masked self-attention with residual connection
        attn_output = self.mha(x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Cross-attention with residual connection. Only the source mask is
        # relevant here: the keys are encoder positions, so the target-side
        # mask must not be applied to them.
        cross_attn_output = self.cross_attn(x, enc_output, attn_mask=src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))

        # Feed-forward network with residual connection
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))

        return x



In [3]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence modelling.

    Structure: source/target ``Embedding`` -> stack of ``Encoder`` blocks ->
    LayerNorm -> stack of ``Decoder`` blocks -> LayerNorm -> linear output
    head whose weight is tied to the target token embedding.

    Mask convention: the padding masks built by this class are boolean with
    True at REAL tokens and False at padding, which is what
    ``F.scaled_dot_product_attention`` expects for a boolean ``attn_mask``.
    Optional gradient checkpointing trades compute for activation memory.
    """

    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model=512,
                 num_heads=8,
                 num_encoder_layers=6,
                 num_decoder_layers=6,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_len=5000,
                 pad_idx=(0, 1),
                 use_gradient_checkpointing=False):
        """
        Args:
            src_vocab_size (int): Source vocabulary size.
            tgt_vocab_size (int): Target vocabulary size.
            d_model (int): Embedding / model width.
            num_heads (int): Number of attention heads.
            num_encoder_layers (int): Number of encoder blocks.
            num_decoder_layers (int): Number of decoder blocks.
            d_ff (int): Hidden size of the feed-forward layers.
            dropout (float): Dropout probability.
            max_seq_len (int): Longest supported sequence.
            pad_idx (tuple): (source pad id, target pad id).
            use_gradient_checkpointing (bool): Recompute block activations in
                the backward pass instead of storing them.
        """
        super().__init__()
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Embeddings
        self.src_embedding = Embedding(src_vocab_size, d_model, max_seq_len, dropout)
        self.tgt_embedding = Embedding(tgt_vocab_size, d_model, max_seq_len, dropout)
        # Encoder stack
        self.encoder = nn.ModuleList([
            Encoder(d_model, num_heads, d_ff, dropout)
            for _ in range(num_encoder_layers)
        ])
        self.encoder_norm = nn.LayerNorm(d_model)
        # Decoder stack
        self.decoder = nn.ModuleList([
            Decoder(d_model, num_heads, d_ff, dropout)
            for _ in range(num_decoder_layers)
        ])
        self.decoder_norm = nn.LayerNorm(d_model)
        # Prediction head, weight-tied with the target token embedding
        self.output_layer = nn.Linear(d_model, tgt_vocab_size, bias=False)
        self.output_layer.weight = self.tgt_embedding.token_emb.weight

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _use_checkpointing(self):
        """Checkpointing only pays off when a backward pass will follow."""
        return self.use_gradient_checkpointing and torch.is_grad_enabled()

    def forward(self, src, tgt, mask_src=None, mask_tgt=None):
        """
        Args:
            src: (batch_size, src_len) source token ids.
            tgt: (batch_size, tgt_len) decoder input token ids.
            mask_src: optional source mask; built from ``src`` when None.
            mask_tgt: optional target mask; built from ``tgt`` when None.
        Returns:
            (batch_size, tgt_len, tgt_vocab_size) logits.
        """
        if mask_src is None:
            mask_src = self.make_src_mask(src)
        if mask_tgt is None:
            mask_tgt = self.make_tgt_mask(tgt)

        encoder_out = self.encode(src, mask_src=mask_src)
        decoder_out = self.decode(tgt=tgt, enc_src=encoder_out, mask_tgt=mask_tgt, mask_src=mask_src)
        return self.output_layer(decoder_out)

    def encode(self, src, mask_src=None):
        """Encode the source sequence.

        Args:
            src: (batch_size, src_len) token ids.
            mask_src: (batch_size, 1, 1, src_len) boolean, True = real token.
        Returns:
            (batch_size, src_len, d_model)
        """
        src = self.src_embedding(src)  # (batch_size, src_len, d_model)
        checkpointed = self._use_checkpointing()
        for layer in self.encoder:
            if checkpointed:
                src = torch.utils.checkpoint.checkpoint(layer, src, mask_src)
            else:
                src = layer(src, mask_src)
        return self.encoder_norm(src)

    def decode(self, tgt, enc_src, mask_tgt=None, mask_src=None):
        """Decode the target sequence against the encoder output.

        Args:
            tgt: (batch_size, tgt_len) token ids.
            enc_src: (batch_size, src_len, d_model) encoder output.
            mask_tgt: (batch_size, 1, 1, tgt_len) boolean, True = real token,
                or None.
            mask_src: (batch_size, 1, 1, src_len) boolean, True = real token,
                or None.
        Returns:
            (batch_size, tgt_len, d_model)
        """
        tgt = self.tgt_embedding(tgt)  # (batch_size, tgt_len, d_model)
        checkpointed = self._use_checkpointing()
        for layer in self.decoder:
            if checkpointed:
                tgt = torch.utils.checkpoint.checkpoint(
                    layer, tgt, enc_src, mask_tgt, mask_src
                )
            else:
                tgt = layer(tgt, enc_src, mask_tgt, mask_src)
        return self.decoder_norm(tgt)

    def make_src_mask(self, src):
        """Source padding mask: (batch_size, 1, 1, src_len), True = real token."""
        # A boolean attn_mask marks positions that take part in attention, so
        # padding has to be False.
        return (src != self.pad_idx[0]).unsqueeze(1).unsqueeze(2)

    def make_tgt_mask(self, tgt):
        """Target padding mask: (batch_size, 1, 1, tgt_len), True = real token."""
        return (tgt != self.pad_idx[1]).unsqueeze(1).unsqueeze(2)

    @staticmethod
    def _sample_next_token(logits, temperature=1.0, top_k=None, top_p=None):
        """Pick one token id per row of ``logits`` (batch_size, vocab_size).

        ``temperature <= 0`` selects the arg-max (greedy). Otherwise logits are
        divided by ``temperature``, optionally restricted to the ``top_k``
        highest-scoring tokens and/or the smallest set whose cumulative
        probability reaches ``top_p``, and a token is sampled.

        Returns:
            (batch_size, 1) tensor of token ids.
        """
        if temperature is None or temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)

        logits = logits / temperature

        if top_k is not None and top_k > 0:
            k = min(int(top_k), logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[..., -1, None]
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            # Probability mass that precedes each token in the ranking; the
            # best token always has 0 here and is therefore always kept.
            mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
            sorted_logits = sorted_logits.masked_fill(mass_before > top_p, float("-inf"))
            logits = torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1)

    @torch.no_grad()
    def generate(self, src, max_len=50, start_token=1, end_token=2,
                 temperature=1.0, top_k=None, top_p=None):
        """Autoregressively generate target sequences.

        Note: switches the module to eval mode.

        Args:
            src: (src_len,) or (batch_size, src_len) source token ids.
            max_len: maximum number of tokens to generate.
            start_token: id of the start-of-sequence token.
            end_token: id of the end-of-sequence token.
            temperature: sampling temperature; ``<= 0`` means greedy decoding.
            top_k: keep only the k most likely tokens before sampling.
            top_p: nucleus sampling threshold in (0, 1).
        Returns:
            (generated_len,) for a single source sequence, otherwise
            (batch_size, generated_len). Rows that finish early are filled
            with ``end_token``.
        """
        self.eval()

        if src.dim() == 1:
            src = src.unsqueeze(0)
        device = src.device
        batch_size = src.size(0)

        src_mask = self.make_src_mask(src)
        enc_src = self.encode(src, src_mask)

        tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
        for _ in range(max_len):
            # No target padding mask: the prefix contains no padding, the
            # decoder's causal mask already hides later positions, and the
            # start token may share its id with the pad token.
            dec_out = self.decode(tgt, enc_src, mask_src=src_mask, mask_tgt=None)  # (batch_size, tgt_len, d_model)
            logits = self.output_layer(dec_out[:, -1, :])  # (batch_size, vocab_size)
            next_token = self._sample_next_token(logits, temperature, top_k, top_p)

            # Rows that already emitted end_token keep emitting it.
            next_token = next_token.masked_fill(finished.unsqueeze(1), end_token)
            tgt = torch.cat([tgt, next_token], dim=1)

            finished = finished | (next_token.squeeze(1) == end_token)
            if finished.all():
                break

        return tgt.squeeze(0)


In [4]:
class METTDataset(Dataset):
    """
    English-Vietnamese parallel dataset holding pre-tokenized sentence pairs.

    On construction every pair is tokenized once; pairs that are empty or
    whose token count reaches ``max_length`` are dropped. When a cache
    directory is given the processed result is pickled there and reused on
    later runs with the same tokenizer names and ``max_length``.
    """

    def __init__(
            self,
            data: List[Dict[str, str]],
            tokenizer_eng: str = "bert-base-uncased",
            tokenizer_vie: str = "vinai/phobert-base",
            max_length: int = 75,
            cache_dir: Optional[str] = None,
            use_cache: bool = True
    ):
        self.max_length = max_length
        self.cache_dir = cache_dir
        self.use_cache = use_cache

        logger.info("Loading tokenizers...")
        self.tokenizer_eng = AutoTokenizer.from_pretrained(tokenizer_eng)
        self.tokenizer_vie = AutoTokenizer.from_pretrained(tokenizer_vie)

        # The cache file name encodes both tokenizer names and max_length.
        cache_file = None
        if cache_dir and use_cache:
            os.makedirs(cache_dir, exist_ok=True)
            eng_tag = tokenizer_eng.replace('/', '_')
            vie_tag = tokenizer_vie.replace('/', '_')
            cache_file = os.path.join(
                cache_dir, f"cached_data_{eng_tag}_{vie_tag}_{max_length}.pkl"
            )

        if cache_file and os.path.exists(cache_file):
            logger.info(f"Loading cached data from {cache_file}")
            # NOTE: unpickling executes arbitrary code; only point cache_dir at
            # locations you trust.
            with open(cache_file, 'rb') as handle:
                payload = pickle.load(handle)
                self.data = payload['data']
                self.en_tokens = payload['en_tokens']
                self.vi_tokens = payload['vi_tokens']
            logger.info(f"Loaded {len(self.data)} samples from cache")
        else:
            logger.info("Processing and filtering data...")
            processed = self._process_data(data)
            self.data, self.en_tokens, self.vi_tokens = processed

            if cache_file:
                logger.info(f"Saving to cache: {cache_file}")
                payload = {
                    'data': self.data,
                    'en_tokens': self.en_tokens,
                    'vi_tokens': self.vi_tokens
                }
                with open(cache_file, 'wb') as handle:
                    pickle.dump(payload, handle)

        logger.info(f"Dataset ready with {len(self.data)} samples")

    def _process_data(self, raw_data: List[Dict[str, str]]) -> Tuple[List[Dict], List[List[int]], List[List[int]]]:
        """Tokenize every pair and keep only the usable ones.

        Returns:
            (kept items, English token-id lists, Vietnamese token-id lists),
            index-aligned with each other.
        """

        def to_ids(tokenizer, text):
            # Special tokens included, no truncation: length is filtered below.
            encoded = tokenizer(
                text,
                add_special_tokens=True,
                truncation=False,
                return_attention_mask=False
            )
            return encoded["input_ids"]

        kept_items = []
        kept_en = []
        kept_vi = []
        skipped = 0
        failures = 0

        for position, record in enumerate(tqdm(raw_data, desc="Tokenizing")):
            try:
                english = record.get("en", "")
                vietnamese = record.get("vi", "")

                if not (english and vietnamese):
                    skipped += 1
                    continue

                en_ids = to_ids(self.tokenizer_eng, english)
                vi_ids = to_ids(self.tokenizer_vie, vietnamese)

                if len(en_ids) < self.max_length and len(vi_ids) < self.max_length:
                    kept_items.append(record)
                    kept_en.append(en_ids)
                    kept_vi.append(vi_ids)
                else:
                    skipped += 1

            except Exception as e:
                failures += 1
                # Only the first five failures are logged to keep the log short.
                if failures <= 5:
                    logger.warning(f"Error processing item {position}: {e}")
                continue

        logger.info(f"Filtered {skipped} samples (too long or empty)")
        logger.info(f"Errors: {failures} samples")
        logger.info(f"Valid samples: {len(kept_items)}")

        return kept_items, kept_en, kept_vi

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (English ids, Vietnamese ids) pair at ``idx`` as long tensors."""
        source_ids, target_ids = self.en_tokens[idx], self.vi_tokens[idx]
        source = torch.tensor(source_ids, dtype=torch.long)
        target = torch.tensor(target_ids, dtype=torch.long)
        return source, target

    def decode(self, input_ids, language: str = 'eng') -> str:
        """Turn token ids back into text (special tokens are kept)."""
        if language == 'vi':
            return self.tokenizer_vie.decode(input_ids, skip_special_tokens=False)
        if language == 'eng':
            return self.tokenizer_eng.decode(input_ids, skip_special_tokens=False)
        raise ValueError("language must be 'eng' or 'vi'")

    def get_vocab_size(self, language: str = 'vi') -> int:
        """Return ``len()`` of the tokenizer for ``language``."""
        if language == 'vi':
            return len(self.tokenizer_vie)
        if language == 'eng':
            return len(self.tokenizer_eng)
        raise ValueError("language must be 'eng' or 'vi'")

def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]], pad_idx_eng: int, pad_idx_vie: int) -> Dict[
    str, torch.Tensor]:
    """Right-pad the English and Vietnamese sequences of a batch.

    Each side is padded to the longest sequence on that side using its own pad
    id. Returns a dict with keys ``'en_input_ids'`` and ``'vi_input_ids'``,
    each of shape (batch_size, longest_len).
    """
    english, vietnamese = zip(*batch)
    sides = (
        ('en_input_ids', english, pad_idx_eng),
        ('vi_input_ids', vietnamese, pad_idx_vie),
    )
    return {
        key: pad_sequence(sequences, batch_first=True, padding_value=pad_value)
        for key, sequences, pad_value in sides
    }


In [5]:
class CrossEntropyLoss(nn.Module):
    """
    Token-level cross-entropy for sequence outputs.

    Wraps ``nn.CrossEntropyLoss`` with an ignored (padding) index and label
    smoothing, and can additionally report accuracy, perplexity and the number
    of non-ignored tokens.
    """

    def __init__(
            self,
            vocab_size: int,
            pad_idx: int = 0,
            label_smoothing: float = 0.1,
            reduction: str = 'mean',
            ignore_index: Optional[int] = None
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.pad_idx = pad_idx
        self.label_smoothing = label_smoothing
        self.reduction = reduction
        # Targets equal to this index do not contribute; defaults to pad_idx.
        self.ignore_index = pad_idx if ignore_index is None else ignore_index

        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=self.ignore_index,
            label_smoothing=label_smoothing,
            reduction=reduction
        )

    def forward(
            self,
            predict: torch.Tensor,
            target: torch.Tensor,
            return_metrics: bool = False
    ) -> torch.Tensor:
        """
        Args:
            predict: (batch_size, seq_len, vocab_size) logits.
            target: (batch_size, seq_len) token ids.
            return_metrics: when True, also return a metrics dict.
        Returns:
            loss, or (loss, metrics_dict) when ``return_metrics`` is True.
        """
        _, _, vocab_size = predict.size()

        # Flatten batch and time so each token is one classification example.
        flat_logits = predict.reshape(-1, vocab_size)
        flat_target = target.reshape(-1)

        loss = self.loss_fn(flat_logits, flat_target)
        if not return_metrics:
            return loss

        with torch.no_grad():
            metrics = self._compute_metrics(predict, target, flat_logits, flat_target)
        return loss, metrics

    def _compute_metrics(
            self,
            predict: torch.Tensor,
            target: torch.Tensor,
            predict_flat: torch.Tensor,
            target_flat: torch.Tensor
    ) -> Dict[str, float]:
        """Accuracy, perplexity and token count over the non-ignored targets."""
        valid = target_flat != self.ignore_index
        num_tokens = valid.sum().item()

        # Accuracy over valid positions; 0.0 when there are none.
        hits = ((predict_flat.argmax(dim=-1) == target_flat) & valid).sum().item()
        accuracy = hits / num_tokens if num_tokens > 0 else 0.0

        # Perplexity = exp(mean NLL without label smoothing), computed in
        # float32 with autocast disabled.
        with torch.cuda.amp.autocast(enabled=False):
            log_probs = F.log_softmax(predict_flat.float(), dim=-1)
            mean_nll = F.nll_loss(
                log_probs,
                target_flat,
                ignore_index=self.ignore_index,
                reduction='mean'
            )
            perplexity = torch.exp(mean_nll).item()

        return {
            'accuracy': accuracy,
            'perplexity': perplexity,
            'num_tokens': num_tokens
        }

In [6]:
# Training configuration constants.
# Device for the model and batches: CUDA when available, otherwise CPU.
DEVICES = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 4  # DataLoader batch size (used in main)
LEARNING_RATE = 1e-4  # AdamW learning rate (used in main)
EPOCHS = 10  # number of training epochs (used in main)
PIN_MEMORY = True  # NOTE(review): not referenced by the code visible in this file
SMOOTHING = 0.1  # label-smoothing factor passed to CrossEntropyLoss (used in main)
USE_AMP = False  # NOTE(review): not referenced by the code visible in this file
ACCUMULATION_STEPS = 2  # NOTE(review): not referenced by the code visible in this file
MAX_GRAD_NORM = 1.0  # NOTE(review): not referenced by the code visible in this file
MAX_LEN = 100  # passed as METTDataset max_length and Transformer max_seq_len (used in main)
NUM_WORKERS = 4  # NOTE(review): only appears in a commented-out DataLoader argument


In [7]:
# Configure warnings and logging.
warnings.filterwarnings("ignore", category=UserWarning, message=".*torch.utils.checkpoint:*")
# Configure the root logger exactly once. logging.basicConfig() does nothing
# when the root logger already has handlers, so a second call would neither
# attach the file handler nor change the format (and would leak the open
# 'training.log' handle). Console and file output are therefore set up together.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def train(ct_model, loss, train_loader, optimizer, device, epoch=None):
    """Run one training epoch with teacher forcing.

    For every batch the target sequence is split into the decoder input (all
    tokens but the last) and the prediction target (all tokens but the first),
    so the logits at position t are scored against token t + 1. Feeding the
    same unshifted sequence as both input and target would let the decoder
    copy the token it was just given.

    Args:
        ct_model: model called as ``ct_model(src=..., tgt=...)`` returning
            logits of shape (batch_size, tgt_len, vocab_size).
        loss: criterion called as ``loss(predict=..., target=...)``.
        train_loader: iterable of dicts with ``'en_input_ids'`` and
            ``'vi_input_ids'`` tensors.
        optimizer: optimizer stepping the model parameters.
        device: device the batches are moved to.
        epoch: optional epoch number, used only for the progress bar.
    Returns:
        Mean loss over the processed batches (0.0 if there were none).
    """
    ct_model.train()
    total_loss = 0.0
    num_batches = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch}" if epoch is not None else "Training")
    for idx, batch in enumerate(loop):
        en_input_ids = batch['en_input_ids'].to(device)
        vi_input_ids = batch['vi_input_ids'].to(device)

        # Teacher forcing: shift the target by one position.
        decoder_input = vi_input_ids[:, :-1]
        decoder_target = vi_input_ids[:, 1:]

        optimizer.zero_grad()

        outputs = ct_model(
            src=en_input_ids,
            tgt=decoder_input
        )

        loss_value = loss(
            predict=outputs,
            target=decoder_target
        )

        loss_value.backward()
        optimizer.step()

        total_loss += loss_value.item()
        num_batches += 1
        if (idx + 1) % 10 == 0:
            loop.set_postfix(epoch=epoch, loss=total_loss / (idx + 1), idx=idx)

        # Periodically release cached GPU blocks and collect Python garbage.
        if idx % 25 == 0:
            torch.cuda.empty_cache()
            gc.collect()

    if num_batches == 0:
        return 0.0
    return total_loss / num_batches


def main():
    """Build the dataset, model, loss and optimizer, then train for EPOCHS epochs.

    A checkpoint of the model's state dict is written every fifth epoch.
    """
    raw_split = load_dataset('hiimbach/mtet', cache_dir="/datasets")["train"]
    corpus = METTDataset(raw_split, cache_dir="./cache", max_length=MAX_LEN, use_cache=True)

    # (source pad id, target pad id) as reported by the two tokenizers.
    pad_idx = (corpus.tokenizer_eng.pad_token_id, corpus.tokenizer_vie.pad_token_id)
    src_pad, tgt_pad = pad_idx

    def pad_batch(samples):
        return collate_fn(samples, pad_idx_eng=src_pad, pad_idx_vie=tgt_pad)

    train_loader = DataLoader(
        corpus,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=pad_batch
    )
    print("device:", DEVICES)

    src_vocab = corpus.get_vocab_size(language='eng')
    tgt_vocab = corpus.get_vocab_size(language='vi')
    model = Transformer(
        src_vocab_size=src_vocab,
        tgt_vocab_size=tgt_vocab,
        d_model=512,
        num_heads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        d_ff=2048,
        max_seq_len=MAX_LEN,
        dropout=0.1,
        use_gradient_checkpointing=True,
        pad_idx=pad_idx
    )
    model.to(DEVICES)

    criterion = CrossEntropyLoss(
        vocab_size=tgt_vocab,
        label_smoothing=SMOOTHING,
        pad_idx=tgt_pad,
    )
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)

    epoch = 1
    while epoch <= EPOCHS:
        avg_loss = train(model, criterion, train_loader, optimizer, DEVICES, epoch)
        logger.info(f"Epoch [{epoch}/{EPOCHS}], Loss: {avg_loss:.4f}")
        if not epoch % 5:
            torch.save(model.state_dict(), f"transformer_epoch_{epoch}.pth")

        # Free cached GPU memory and Python garbage between epochs.
        torch.cuda.empty_cache()
        gc.collect()
        epoch += 1


if __name__ == "__main__":
    main()


INFO:__main__:Loading tokenizers...
INFO:__main__:Loading cached data from ./cache\cached_data_bert-base-uncased_vinai_phobert-base_100.pkl
INFO:__main__:Loaded 3557930 samples from cache
INFO:__main__:Dataset ready with 3557930 samples


device: cuda


Epoch 1:   0%|          | 111/889483 [01:06<147:38:39,  1.67it/s, epoch=1, idx=109, loss=9.27]


KeyboardInterrupt: 