In [149]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import re
import random
from typing import List, Tuple, Dict, Optional

In [150]:
class TokenVocabulary:
    def __init__(self):
        # Special tokens first (PAD should be 0)
        self.special_tokens = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
        
        # Mathematical operators and symbols
        self.operators = ['+', '-', '*', '/', '^', '(', ')']
        self.digits = []
        self.variables = ['a', 'b', 'c']
        
        # Build vocabulary
        self.vocab = self.special_tokens + self.operators + self.digits + self.variables
        self.token_to_id = {token: i for i, token in enumerate(self.vocab)}
        self.id_to_token = {i: token for i, token in enumerate(self.vocab)}
        
        self.pad_id = self.token_to_id['<PAD>']
        self.sos_id = self.token_to_id['<SOS>']
        self.eos_id = self.token_to_id['<EOS>']
        self.unk_id = self.token_to_id['<UNK>']
        print(f'vocab: {self.vocab}')
    
    def tokenize(self, expression: str) -> List[str]:
        """Tokenize mathematical expression"""
        expression = expression.replace(' ', '')
        tokens = []
        i = 0
        while i < len(expression):
            char = expression[i]
            if char.isdigit():
                # Handle multi-digit numbers
                num = ''
                while i < len(expression) and expression[i].isdigit():
                    num += expression[i]
                    i += 1
                tokens.append(num)
            elif char in self.operators or char.isalpha():
                tokens.append(char)
                i += 1
            else:
                i += 1  # Skip unknown characters
        return tokens
    
    def encode(self, tokens: List[str]) -> List[int]:
        """Convert tokens to IDs"""
        return [self.token_to_id.get(token, self.unk_id) for token in tokens]
    
    def decode(self, ids: List[int]) -> List[str]:
        """Convert IDs to tokens"""
        return [self.id_to_token[id] for id in ids if id != self.pad_id]

class InfixToPolishConverter:
    """Convert infix notation to Polish notation using Shunting Yard algorithm"""
    
    def __init__(self):
        self.precedence = {'+': 1, '-': 1, '*': 2, '/': 2, '^': 3}
        self.right_associative = {'^'}
    
    def infix_to_polish(self, infix_tokens: List[str]) -> List[str]:
        """Convert infix tokens to Polish notation"""
        output = []
        operator_stack = []
        
        # Reverse input and flip parentheses for Polish notation
        reversed_tokens = []
        for token in reversed(infix_tokens):
            if token == '(':
                reversed_tokens.append(')')
            elif token == ')':
                reversed_tokens.append('(')
            else:
                reversed_tokens.append(token)
        
        # Modified Shunting Yard for prefix notation
        for token in reversed_tokens:
            if token.isdigit() or token.isalpha():
                output.append(token)
            elif token == '(':
                operator_stack.append(token)
            elif token == ')':
                while operator_stack and operator_stack[-1] != '(':
                    output.append(operator_stack.pop())
                if operator_stack:
                    operator_stack.pop()
            elif token in self.precedence:
                while (operator_stack and 
                       operator_stack[-1] != '(' and
                       operator_stack[-1] in self.precedence and
                       (self.precedence[operator_stack[-1]] > self.precedence[token] or
                        (self.precedence[operator_stack[-1]] == self.precedence[token] and 
                         token not in self.right_associative))):
                    output.append(operator_stack.pop())
                operator_stack.append(token)
        
        while operator_stack:
            output.append(operator_stack.pop())
        
        return list(reversed(output))

In [151]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
    
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class InfixPolishTransformer(nn.Module):
    def __init__(self, vocab_size: int, d_model: int = 128, nhead: int = 4, 
                 num_encoder_layers: int = 3, num_decoder_layers: int = 3,
                 dim_feedforward: int = 256, max_len: int = 100):
        super().__init__()
        
        self.d_model = d_model
        self.vocab_size = vocab_size
        
        # Embeddings
        self.encoder_embedding = nn.Embedding(vocab_size, d_model)
        self.decoder_embedding = nn.Embedding(vocab_size, d_model)
        
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        
        # Transformer
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            batch_first=True
        )
        
        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        initrange = 0.1
        self.encoder_embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder_embedding.weight.data.uniform_(-initrange, initrange)
        self.output_projection.bias.data.zero_()
        self.output_projection.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, src, tgt, src_mask=None, tgt_mask=None, 
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        # Embeddings
        src_emb = self.pos_encoding(self.encoder_embedding(src) * math.sqrt(self.d_model))
        tgt_emb = self.pos_encoding(self.decoder_embedding(tgt) * math.sqrt(self.d_model))
        
        # Transformer forward pass
        output = self.transformer(
            src_emb, tgt_emb,
            src_mask=src_mask, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        
        return self.output_projection(output)
    
    def generate_square_subsequent_mask(self, sz):
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask



In [152]:
class InfixPolishDataset:
    def __init__(self, vocab: TokenVocabulary, converter: InfixToPolishConverter):
        self.vocab = vocab
        self.converter = converter
    
    def generate_simple_expression(self) -> str:
        """Generate simple infix expressions"""
        operators = ['+', '-', '*', '/']
        # operands = ['a', 'b', 'c'] + ['1', '2', '3', '4', '5']
        operands = ['a', 'b', 'c']
        
        # Generate different types of expressions
        expr_type = random.choice(['binary', 'chain', 'paren'])
        
        if expr_type == 'binary':
            # Simple: a + b
            op1 = random.choice(operands)
            operator = random.choice(operators)
            op2 = random.choice(operands)
            return f"{op1}{operator}{op2}"
        
        elif expr_type == 'chain':
            # Chain: a + b * c
            op1 = random.choice(operands)
            op1_op = random.choice(operators)
            op2 = random.choice(operands)
            op2_op = random.choice(operators)
            op3 = random.choice(operands)
            return f"{op1}{op1_op}{op2}{op2_op}{op3}"
        
        else:  # paren
            # Parentheses: (a + b) * c
            op1 = random.choice(operands)
            op1_op = random.choice(operators)
            op2 = random.choice(operands)
            op2_op = random.choice(operators)
            op3 = random.choice(operands)
            return f"({op1}{op1_op}{op2}){op2_op}{op3}"
    
    def create_training_pair(self, infix_expr: str) -> Tuple[List[int], List[int]]:
        """Create training pair from infix expression"""
        infix_tokens = self.vocab.tokenize(infix_expr)
        polish_tokens = self.converter.infix_to_polish(infix_tokens)
        
        infix_ids = self.vocab.encode(infix_tokens)
        polish_ids = [self.vocab.sos_id] + self.vocab.encode(polish_tokens) + [self.vocab.eos_id]
        
        return infix_ids, polish_ids

def create_padding_mask(sequences, pad_id):
    """Create padding mask for sequences"""
    return (sequences == pad_id)

def collate_fn(batch, pad_id):
    """Collate function for DataLoader"""
    src_sequences, tgt_sequences = zip(*batch)
    
    max_src_len = max(len(seq) for seq in src_sequences)
    max_tgt_len = max(len(seq) for seq in tgt_sequences)
    
    padded_src = []
    padded_tgt = []
    
    for src, tgt in zip(src_sequences, tgt_sequences):
        padded_src.append(src + [pad_id] * (max_src_len - len(src)))
        padded_tgt.append(tgt + [pad_id] * (max_tgt_len - len(tgt)))
    
    return torch.tensor(padded_src), torch.tensor(padded_tgt)


In [153]:
def train_model():
    print("Initializing components...")
    vocab = TokenVocabulary()
    converter = InfixToPolishConverter()
    dataset = InfixPolishDataset(vocab, converter)
    
    print(f"Vocabulary size: {len(vocab.vocab)}")
    print(f"Sample vocab: {vocab.vocab[:15]}")
    
    # Test converter first
    test_expr = "b+c*a"
    test_tokens = vocab.tokenize(test_expr)
    test_polish = converter.infix_to_polish(test_tokens)
    print(f"Test conversion: {test_expr} -> {test_tokens} -> {test_polish}")
    
    model = InfixPolishTransformer(vocab_size=len(vocab.vocab))
    
    # Generate training data
    print("Generating training data...")
    training_pairs = []
    
    for i in range(3000):
        try:
            expr = dataset.generate_simple_expression()
            pair = dataset.create_training_pair(expr)
            
            # Validate the pair
            if len(pair[0]) > 0 and len(pair[1]) > 2:  # Must have content
                training_pairs.append(pair)
                
                if i < 5:  # Print first few examples
                    infix_tokens = vocab.tokenize(expr)
                    polish_tokens = converter.infix_to_polish(infix_tokens)
                    print(f"Example {i+1}: {expr}")
                    print(f"  Infix: {infix_tokens} -> {pair[0]}")
                    print(f"  Polish: {polish_tokens} -> {pair[1]}")
                    
        except Exception as e:
            if i < 10:
                print(f"Error with expression: {e}")
            continue
    
    print(f"Generated {len(training_pairs)} valid training pairs")
    
    if len(training_pairs) < 100:
        print("Warning: Very few training pairs generated!")
        return None, None, None
    
    # Training setup
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.pad_id)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    
    model.train()
    
    print("Starting training...")
    for epoch in range(25):
        total_loss = 0
        num_batches = 0
        
        random.shuffle(training_pairs)
        batch_size = 16
        
        for i in range(0, len(training_pairs), batch_size):
            batch = training_pairs[i:i+batch_size]
            if len(batch) < 2:
                continue
                
            try:
                src_batch, tgt_batch = collate_fn(batch, vocab.pad_id)
                
                tgt_input = tgt_batch[:, :-1]
                tgt_output = tgt_batch[:, 1:]
                
                if tgt_input.size(1) == 0:
                    continue
                
                # Create masks
                src_key_padding_mask = create_padding_mask(src_batch, vocab.pad_id)
                tgt_key_padding_mask = create_padding_mask(tgt_input, vocab.pad_id)
                tgt_mask = model.generate_square_subsequent_mask(tgt_input.size(1))
                
                optimizer.zero_grad()
                
                output = model(src_batch, tgt_input, 
                             tgt_mask=tgt_mask,
                             src_key_padding_mask=src_key_padding_mask,
                             tgt_key_padding_mask=tgt_key_padding_mask)
                
                loss = criterion(output.reshape(-1, len(vocab.vocab)), tgt_output.reshape(-1))
                
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                
                total_loss += loss.item()
                num_batches += 1
                
            except Exception as e:
                continue
        
        avg_loss = total_loss / num_batches if num_batches > 0 else float('inf')
        scheduler.step(avg_loss)
        
        print(f"Epoch {epoch+1:2d}, Avg Loss: {avg_loss:.4f}, Batches: {num_batches}")
        
        # Test every 5 epochs
        if epoch % 5 == 0 and epoch > 0:
            model.eval()
            test_expr = "a+b"
            predicted = convert_infix_to_polish(model, vocab, converter, test_expr)
            actual = converter.infix_to_polish(vocab.tokenize(test_expr))
            print(f"  Test: {test_expr} -> Pred: '{predicted}' | Actual: '{' '.join(actual)}'")
            model.train()
    
    return model, vocab, converter




In [154]:
def convert_infix_to_polish(model, vocab, converter, infix_expr: str, max_length: int = 15):
    """Convert infix expression to Polish notation using trained model"""
    model.eval()
    
    with torch.no_grad():
        infix_tokens = vocab.tokenize(infix_expr)
        if len(infix_tokens) == 0:
            return "ERROR: Empty input"
            
        src = torch.tensor([vocab.encode(infix_tokens)])
        tgt = torch.tensor([[vocab.sos_id]])
        
        generated_tokens = []
        
        for step in range(max_length):
            if tgt.size(1) > 20:  # Prevent infinite loops
                break
                
            tgt_mask = model.generate_square_subsequent_mask(tgt.size(1))
            
            try:
                output = model(src, tgt, tgt_mask=tgt_mask)
                logits = output[0, -1, :]
                
                # Use top-k sampling for better results
                top_k = 5
                top_logits, top_indices = torch.topk(logits, top_k)
                probs = F.softmax(top_logits, dim=-1)
                next_token_idx = torch.multinomial(probs, 1).item()
                next_token = top_indices[next_token_idx].item()
                
                if next_token == vocab.eos_id:
                    break
                if next_token == vocab.pad_id:
                    continue
                    
                generated_tokens.append(next_token)
                tgt = torch.cat([tgt, torch.tensor([[next_token]])], dim=1)
                
            except Exception as e:
                break
        
        # Decode result
        result_tokens = [vocab.id_to_token[token_id] for token_id in generated_tokens 
                        if token_id in vocab.id_to_token]
        
        return ' '.join(result_tokens) if result_tokens else "ERROR: No output"


In [155]:
# Example usage
if __name__ == "__main__":
    print("Training Infix to Polish Notation Transformer...")
    model, vocab, converter = train_model()
    
    if model is None:
        print("Training failed!")
    else:
        print("\nTesting conversions:")
        test_expressions = ["a+b", "a+b*c", "(a+b)*c", "a*b+c"]
        
        for expr in test_expressions:
            try:
                predicted = convert_infix_to_polish(model, vocab, converter, expr)
                actual_tokens = converter.infix_to_polish(vocab.tokenize(expr))
                actual = ' '.join(actual_tokens)
                
                print(f"Infix: {expr}")
                print(f"Predicted: {predicted}")
                print(f"Actual:    {actual}")
                print(f"Match: {predicted.strip() == actual.strip()}")
                print("-" * 40)
            except Exception as e:
                print(f"Error with {expr}: {e}")

Training Infix to Polish Notation Transformer...
Initializing components...
vocab: ['<PAD>', '<SOS>', '<EOS>', '<UNK>', '+', '-', '*', '/', '^', '(', ')', 'a', 'b', 'c']
Vocabulary size: 14
Sample vocab: ['<PAD>', '<SOS>', '<EOS>', '<UNK>', '+', '-', '*', '/', '^', '(', ')', 'a', 'b', 'c']
Test conversion: b+c*a -> ['b', '+', 'c', '*', 'a'] -> ['+', 'b', '*', 'c', 'a']
Generating training data...
Example 1: (a*b)+c
  Infix: ['(', 'a', '*', 'b', ')', '+', 'c'] -> [9, 11, 6, 12, 10, 4, 13]
  Polish: ['+', '*', 'a', 'b', 'c'] -> [1, 4, 6, 11, 12, 13, 2]
Example 2: c+a+a
  Infix: ['c', '+', 'a', '+', 'a'] -> [13, 4, 11, 4, 11]
  Polish: ['+', 'c', '+', 'a', 'a'] -> [1, 4, 13, 4, 11, 11, 2]
Example 3: (c-c)+b
  Infix: ['(', 'c', '-', 'c', ')', '+', 'b'] -> [9, 13, 5, 13, 10, 4, 12]
  Polish: ['+', '-', 'c', 'c', 'b'] -> [1, 4, 5, 13, 13, 12, 2]
Example 4: a+a
  Infix: ['a', '+', 'a'] -> [11, 4, 11]
  Polish: ['+', 'a', 'a'] -> [1, 4, 11, 11, 2]
Example 5: (b-b)+c
  Infix: ['(', 'b', '-', 'b