import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModel, AutoTokenizer

class ImprovedTranslationModel(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence translation.

    Source and target token ids are embedded (scaled by ``sqrt(d_model)``),
    combined with sinusoidal position encodings, and run through a stack of
    encoder layers and a stack of decoder layers; the decoder output is
    projected to target-vocabulary logits. All attention blocks are
    batch-first, so sequences are laid out as ``(batch, seq_len)``.

    Args:
        src_vocab_size: Number of source-side token ids (source embedding rows).
        tgt_vocab_size: Number of target-side token ids (target embedding rows
            and width of the output logits).
        d_model: Hidden width shared by embeddings and every layer.
        num_layers: Number of encoder layers and, separately, decoder layers.
        num_heads: Attention heads per attention block.
        d_ff: Hidden width of each position-wise feed-forward network.
        max_seq_len: Longest sequence the positional-encoding table covers.
        dropout: Dropout probability used throughout.
    """
    def __init__(self, 
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model=512,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 max_seq_len=128,
                 dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # One sinusoidal table (a buffer, no learned weights) serves both sides.
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)
        
        # Encoder layers: bidirectional self-attention over the whole source
        # sentence; only padding positions (src_mask) are hidden.
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Decoder layers: masked self-attention over the target prefix plus
        # cross-attention to the encoder output.
        self.decoder_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        self.output_projection = nn.Linear(d_model, tgt_vocab_size)
        # NOTE(review): not used by encode/decode (each sub-module applies its
        # own dropout); kept so existing attribute access keeps working.
        self.dropout = nn.Dropout(dropout)
        
    def encode(self, src_ids, src_mask=None):
        """Encode source token ids into contextual states.

        Args:
            src_ids: Source token ids, ``(batch, src_len)``.
            src_mask: Optional bool padding mask, ``(batch, src_len)``; ``True``
                marks positions that attention must ignore.

        Returns:
            Encoder states of shape ``(batch, src_len, d_model)``.
        """
        # Scale embeddings so they are not swamped by the unit-range position codes.
        src_emb = self.src_embedding(src_ids) * math.sqrt(self.d_model)
        src_emb = self.pos_encoding(src_emb)
        
        encoder_output = src_emb
        for layer in self.encoder_layers:
            encoder_output = layer(encoder_output, src_mask)
            
        return encoder_output
    
    def decode(self, tgt_ids, encoder_output, src_mask=None, tgt_mask=None):
        """Decode target ids against the encoder output and return vocabulary logits.

        Args:
            tgt_ids: Target-side input ids, ``(batch, tgt_len)``.
            encoder_output: States returned by :meth:`encode`.
            src_mask: Source padding mask, forwarded to cross-attention.
            tgt_mask: Attention mask for decoder self-attention; ``True`` above
                the diagonal blocks attention to future positions. If ``None``
                the decoder self-attention is NOT causal: every position can
                see later target tokens.

        Returns:
            Logits of shape ``(batch, tgt_len, tgt_vocab_size)``.
        """
        tgt_emb = self.tgt_embedding(tgt_ids) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoding(tgt_emb)
        
        decoder_output = tgt_emb
        for layer in self.decoder_layers:
            decoder_output = layer(decoder_output, encoder_output, src_mask, tgt_mask)
            
        return self.output_projection(decoder_output)
    
    def forward(self, src_ids, tgt_ids, src_mask=None, tgt_mask=None):
        """Run encode + decode (teacher forcing) and return target logits.

        See :meth:`encode` and :meth:`decode` for argument conventions.
        """
        encoder_output = self.encode(src_ids, src_mask)
        logits = self.decode(tgt_ids, encoder_output, src_mask, tgt_mask)
        return logits

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input, then dropout.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    precomputed for ``max_seq_len`` positions and stored as the ``pe`` buffer
    of shape ``(1, max_seq_len, d_model)``.

    Args:
        d_model: Feature width; odd widths are supported.
        max_seq_len: Longest sequence the table covers.
        dropout: Dropout probability applied after adding the encodings.
    """
    def __init__(self, d_model, max_seq_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        
        # One frequency per even column: ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only floor(d_model / 2) odd columns, so for an odd d_model
        # the last frequency must be dropped or the assignment fails.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the precomputed ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={self.pe.size(1)} "
                "of the positional encoding table"
            )
        x = x + self.pe[:, :seq_len]
        return self.dropout(x)

class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer: self-attention, then a feed-forward net.

    Each sub-layer's output goes through dropout, is added back onto that
    sub-layer's input (residual connection), and the sum is layer-normalised.

    Args:
        d_model: Feature width of inputs and outputs.
        num_heads: Number of attention heads (``nn.MultiheadAttention``
            requires it to divide ``d_model``).
        d_ff: Hidden width of the position-wise feed-forward network.
        dropout: Dropout probability inside attention, inside the FFN and on
            both residual branches.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        ffn_stages = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        ]
        self.ffn = nn.Sequential(*ffn_stages)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        """Apply the layer to batch-first ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, d_model)``.
            mask: Optional key padding mask ``(batch, seq_len)``; ``True``
                entries are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = self.self_attention(x, x, x, key_padding_mask=mask)[0]
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))

class TransformerDecoderLayer(nn.Module):
    """Post-norm Transformer decoder layer.

    The layer applies three sub-layers in order: masked self-attention over
    the target, cross-attention to the encoder output, and a position-wise
    feed-forward network. Each sub-layer has dropout, a residual connection
    and a LayerNorm.

    Args:
        d_model: Feature width of inputs and outputs.
        num_heads: Number of attention heads per attention block.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout probability used in every sub-layer.
    """
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.cross_attention = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def _expand_attn_mask(self, mask, batch_size):
        """Convert a per-sample ``(batch, L, S)`` mask to MHA's ``(batch*num_heads, L, S)`` layout.

        A 2-D mask is returned unchanged, as is a 3-D mask already laid out per
        head. A 3-D mask collated from per-example masks (e.g. by a DataLoader)
        has one entry per batch row, and ``nn.MultiheadAttention`` would reject it.
        """
        if mask is None or mask.dim() != 3:
            return mask
        num_heads = self.self_attention.num_heads
        if num_heads > 1 and mask.size(0) == batch_size:
            # MHA orders the flattened dim as batch * num_heads + head, so each
            # sample's mask is repeated num_heads times consecutively.
            return mask.repeat_interleave(num_heads, dim=0)
        return mask

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """Apply the layer to batch-first target states ``x``.

        Args:
            x: Target states ``(batch, tgt_len, d_model)``.
            encoder_output: Encoder states ``(batch, src_len, d_model)``.
            src_mask: Optional source key padding mask ``(batch, src_len)``;
                ``True`` entries are ignored by cross-attention.
            tgt_mask: Optional self-attention mask, either ``(tgt_len, tgt_len)``
                or per-sample ``(batch, tgt_len, tgt_len)``. For a bool mask,
                ``True`` blocks attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        tgt_mask = self._expand_attn_mask(tgt_mask, x.size(0))

        # Self-attention with causal mask
        attn_output, _ = self.self_attention(x, x, x, attn_mask=tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        
        # Cross-attention: queries from the target, keys/values from the encoder
        cross_attn_output, _ = self.cross_attention(x, encoder_output, encoder_output, key_padding_mask=src_mask)
        x = self.norm2(x + self.dropout(cross_attn_output))
        
        # Feed-forward
        ffn_output = self.ffn(x)
        x = self.norm3(x + self.dropout(ffn_output))
        
        return x

class ImprovedTranslationDataset(Dataset):
    """Map-style dataset that turns ``(src_text, tgt_text)`` pairs into model tensors.

    Both sides are tokenised to exactly ``max_len`` ids (padded/truncated).
    The target is then split into a shifted input/target pair for teacher
    forcing.

    Args:
        data_pairs: Indexable sequence of ``(src_text, tgt_text)`` pairs.
        src_tokenizer: Callable tokenizer for the source side; must return
            ``input_ids`` and ``attention_mask`` as ``(1, max_len)`` tensors when
            called with ``return_tensors='pt'``.
        tgt_tokenizer: Callable tokenizer for the target side; must return
            ``input_ids``.
        max_len: Fixed token length for both sides.
    """
    def __init__(self, data_pairs, src_tokenizer, tgt_tokenizer, max_len=128):
        self.data_pairs = data_pairs
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
    
    def __len__(self):
        return len(self.data_pairs)

    def _encode(self, tokenizer, text):
        """Tokenise ``text`` to a fixed ``max_len`` length, returning the tokenizer's mapping."""
        return tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
    
    def __getitem__(self, idx):
        """Return one example as a dict of 1-D/2-D tensors.

        The keys are:
        - ``src_ids``: source ids of shape ``(max_len,)``.
        - ``src_mask``: bool mask of the same shape, ``True`` at padding.
        - ``tgt_input`` and ``tgt_target``: target ids shifted by one, each of
          shape ``(max_len - 1,)``.
        - ``tgt_mask``: causal bool mask of shape ``(max_len - 1, max_len - 1)``.
        """
        src_text, tgt_text = self.data_pairs[idx]
        
        src_tokens = self._encode(self.src_tokenizer, src_text)
        # No BOS/EOS is added here: any start/end markers come from the
        # tokenizer itself (BERT tokenizers insert [CLS]/[SEP] by default).
        tgt_tokens = self._encode(self.tgt_tokenizer, tgt_text)
        
        # squeeze(0) removes only the batch dim; a bare squeeze() would also
        # collapse the sequence dim when max_len == 1.
        src_ids = src_tokens['input_ids'].squeeze(0)
        src_mask = src_tokens['attention_mask'].squeeze(0) == 0  # True for padding
        
        tgt_ids = tgt_tokens['input_ids'].squeeze(0)
        # Teacher forcing: the decoder reads tokens [0, n-1) and predicts [1, n).
        tgt_input = tgt_ids[:-1]
        tgt_target = tgt_ids[1:]
        
        # Causal mask: True above the diagonal blocks attention to future tokens.
        tgt_len = tgt_input.size(0)
        tgt_mask = torch.triu(torch.ones(tgt_len, tgt_len), diagonal=1).bool()
        
        return {
            'src_ids': src_ids,
            'src_mask': src_mask,
            'tgt_input': tgt_input,
            'tgt_target': tgt_target,
            'tgt_mask': tgt_mask
        }

def create_improved_model(src_tokenizer_name="bert-base-cased",
                          tgt_tokenizer_name="bert-base-multilingual-cased"):
    """Build the translation model together with its source and target tokenizers.

    Source and target use *different* pretrained tokenizers. By default the
    source uses English cased BERT and the target uses multilingual cased
    BERT, so each side gets an embedding table sized to its own tokenizer.

    Args:
        src_tokenizer_name: Hugging Face name or path of the source tokenizer.
        tgt_tokenizer_name: Hugging Face name or path of the target tokenizer.

    Returns:
        ``(model, src_tokenizer, tgt_tokenizer)``.
    """
    src_tokenizer = AutoTokenizer.from_pretrained(src_tokenizer_name)
    tgt_tokenizer = AutoTokenizer.from_pretrained(tgt_tokenizer_name)
    
    # len(tokenizer) also counts added special tokens; .vocab_size covers only
    # the base vocabulary, which would leave added-token ids outside the
    # embedding / output tables. The two agree for the default BERT tokenizers.
    model = ImprovedTranslationModel(
        src_vocab_size=len(src_tokenizer),
        tgt_vocab_size=len(tgt_tokenizer),
        d_model=512,
        num_layers=6,
        num_heads=8,
        d_ff=2048,
        max_seq_len=128,
        dropout=0.1
    )
    
    return model, src_tokenizer, tgt_tokenizer

def train_improved_model(model, train_dataloader, val_dataloader, epochs=10,
                         lr=1e-4, pad_token_id=0):
    """Train ``model`` with teacher forcing and report per-epoch losses.

    Each batch must be a dict with the tensors ``src_ids``, ``src_mask``,
    ``tgt_input``, ``tgt_target`` and ``tgt_mask``, as produced by
    ``ImprovedTranslationDataset``. The model is called as
    ``model(src_ids, tgt_input, src_mask, tgt_mask)`` and must return
    ``(batch, tgt_len, vocab)`` logits.

    Args:
        model: Module to train; it is moved to CUDA when available.
        train_dataloader: Iterable of training batches.
        val_dataloader: Iterable of validation batches, or ``None`` to skip
            validation.
        epochs: Number of passes over ``train_dataloader``.
        lr: Adam learning rate.
        pad_token_id: Target id excluded from the loss (0 is ``[PAD]`` for
            BERT tokenizers).

    An empty loader reports a ``nan`` loss instead of dividing by zero.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    def _batch_loss(batch):
        # Move one batch to the device and return its mean token loss.
        src_ids = batch['src_ids'].to(device)
        src_mask = batch['src_mask'].to(device)
        tgt_input = batch['tgt_input'].to(device)
        tgt_target = batch['tgt_target'].to(device)
        tgt_mask = batch['tgt_mask'].to(device)
        logits = model(src_ids, tgt_input, src_mask, tgt_mask)
        # reshape (not view) so non-contiguous logits/targets also work.
        return criterion(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))

    def _mean(total, count):
        # Average that tolerates an empty loader.
        return total / count if count else float('nan')
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0
        
        for batch in train_dataloader:
            optimizer.zero_grad()
            loss = _batch_loss(batch)
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            num_batches += 1
        
        avg_loss = _mean(total_loss, num_batches)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
        
        if val_dataloader is None:
            continue

        # Validation
        model.eval()
        val_loss = 0.0
        num_val_batches = 0
        with torch.no_grad():
            for batch in val_dataloader:
                val_loss += _batch_loss(batch).item()
                num_val_batches += 1
        
        avg_val_loss = _mean(val_loss, num_val_batches)
        print(f"Validation Loss: {avg_val_loss:.4f}")

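# Usage example (train_data / val_data are lists of (src_text, tgt_text) pairs)
# model, src_tokenizer, tgt_tokenizer = create_improved_model()
# train_dataset = ImprovedTranslationDataset(train_data, src_tokenizer, tgt_tokenizer)
# val_dataset = ImprovedTranslationDataset(val_data, src_tokenizer, tgt_tokenizer)
# train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# val_dataloader = DataLoader(val_dataset, batch_size=32)
# train_improved_model(model, train_dataloader, val_dataloader)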