In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import string
from typing import Dict

# 1.) Configuration
CHARS = string.ascii_lowercase + " "
SOS_TOKEN = "["
EOS_TOKEN = "]"
VOCAB = SOS_TOKEN + EOS_TOKEN + CHARS
char_to_idx: Dict[str,int] = {char: i for i, char in enumerate(VOCAB)}
idx_to_char: Dict[int, str] = {i: char for i, char in enumerate(VOCAB)}
VOCAB_SIZE = len(VOCAB)

# 2.) Dataset Logic
class MirrorDataset(Dataset):
    # Priority 1: Custom Dataset for the Mirror (String Reveras) Task.
    def __init__(self, num_samples: int = 10000, max_length: int = 10):
        self.num_samples = num_samples
        self.max_length = max_length
    
    def __len__(self) -> int:
        return self.num_samples
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        length = random.randint(3, self.max_length)
        original = ''.join(random.choice(string.ascii_lowercase) for _ in range(length))

        # Format: [DSOS] + original + [EOS] + reversed
        full_seq = SOS_TOKEN + original + EOS_TOKEN + original[::-1]
        indices = [char_to_idx[c] for c in full_seq]

        # Padding to fixed length (22) for batch stability.
        padded = torch.zeros(self.max_length * 2 + 2, dtype=torch.long)
        padded[:len(indices)] = torch.tensor(indices)
        return padded
    
dataset = MirrorDataset(num_samples=10000)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
print(f"Block 1 complete: Dataset ready with {VOCAB_SIZE} characters in vocab.")

  cpu = _conversion_method_template(device=torch.device("cpu"))


Block 1 complete: Dataset ready with 29 characters in vocab.


In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderOnlyTransformer(nn.Module):
    # Priority 3: Decoder-Only Transformer Architecture.
    #   Includes Causal Masking to ensure the model learns next-token prediction.
    def __init__(self, vocab_size: int, d_model: int = 256, nhead: int = 8, num_layers: int = 4):
        super().__init__()
        self.d_model = d_model
        
        # Priority 1: Pytorch API - Standard Layers
        self.embedding = nn.Embedding(vocab_size, d_model)

        # Positional Encoding: Fixed Sinusoidal math.
        # This replaces learned zeros with a deterministic coordinate system.
        pe = torch.zeros(100, d_model)
        position = torch.arange(0, 100, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        # register_buffer ensures 'pe' is moved to GPU with the model but not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

        # Core Transformer Blocks
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            batch_first=True,
            norm_first=True # Helps with training stability.
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        # Final output layer to map back to vocabulary characters.
        self.fc_out = nn.Linear(d_model, vocab_size)

    def generate_causal_mask(self, sz: int) -> torch.Tensor:
        # Priority 3: The Blindfold (Causal Masking).
        # Creates a triangular matrix that prevents looking at future tokens.
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
        return mask
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # seq_len is needed for the mask and positional encoding.
        seq_len = x.size(1)

        # Move mask to the same device as input (CPU/GPU).
        mask = self.generate_causal_mask(seq_len).to(x.device)

        # 1.) Embed and add fixed sinusoidal position info.
        #  Slicing self.pe to match the current input sequence length.
        x = self.embedding(x) + self.pe[:, :seq_len, :]

        # 2.) Pass through Transformer Decoder.
        #  Since this is decoder only, x is passed as both tgt and memory.
        output = self.transformer_decoder(tgt=x, memory=x, tgt_mask=mask, memory_mask=mask)

        # 3.) Output logits for each character position.
        return self.fc_out(output)
    
# Instantiate the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DecoderOnlyTransformer(VOCAB_SIZE).to(device)

print(f"Block 2 complete: Model initialized on {device} with Sinusoidal Encodings.")



Block 2 complete: Model initialized on cpu with Sinusoidal Encodings.


In [3]:
from ignite.engine import Engine,Events

# 1.) Optimizer and Loss Function.
#  Learning rate is 0.001 as a standard starting point.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0) # Ignore padding (index 0) in loss calculation.

# 2.) The Core Training Step.
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()

    # Move batch to device (CPU/GPU).
    x = batch.to(device)
    y_pred = model(x)

    # 1.) SHIFT: We want to predict the next character.
    #  Remove the last prediction, remove the first input character.
    logits = y_pred[:, :-1, :].contiguous()
    targets = x[:, 1:].contiguous()

    # 2.) FLATTEN: CrossEntropyLoss needs (N, C) where N is total characters.
    # We use .view(-1, ...) to squash the batch and sequence dimensions together.
    loss = criterion(
        logits.view(-1, logits.size(-1)),
        targets.view(-1)
    )

    loss.backward()
    optimizer.step()

    return loss.item()
# 3.) Creating the Ignite Engine.
trainer = Engine(train_step)

print("Block 3 complete: Optimizer and Ignite Engine initialized.")

Block 3 complete: Optimizer and Ignite Engine initialized.


In [4]:
from ignite.metrics import Loss, Accuracy
from ignite.handlers import ProgressBar
from tqdm.autonotebook import tqdm
# 1.) Define the Validation Step.
#  This engine only checks the models work, it does not train the model.
def validation_step(engine, batch):
    model.eval()
    with torch.no_grad():
        x = batch.to(device)
        y_pred = model(x)

        # This allows the Accuracy metric to compare character by character.
        logits = y_pred[:, :-1, :].reshape(-1, VOCAB_SIZE)
        targets = x[:, 1:].reshape(-1)

        return logits, targets
    
evaluator = Engine(validation_step)

# 2.) Attaching Professional Metrics.
Accuracy().attach(evaluator, "accuracy")
Loss(criterion).attach(evaluator, "loss")

# 3.) Adding a Logger at the end of each Epoch.
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(dataloader)
    metrics = evaluator.state.metrics
    print(f"Epoch[{trainer.state.epoch}] Validation - "f"Accuracy: {metrics['accuracy']:.2f}, "f"Loss: {metrics['loss']:.2f}")

# 4.) Progress Bar for both Trainer and Evaluator.
pbar = ProgressBar()
pbar.attach(evaluator)

pbar.attach(trainer)

print("Block 4 complete: Evaluator and Metrics are attached.")  

Block 4 complete: Evaluator and Metrics are attached.


  from tqdm.autonotebook import tqdm


In [5]:
print("Starting the Mirror Transformer Training...")

# I will run for 100 epochs using the data.
trainer.run(dataloader, max_epochs=100)

print("Training is completed.")

Starting the Mirror Transformer Training...


                                                    

Epoch[1] Validation - Accuracy: 0.17, Loss: 2.41


                                                    

Epoch[2] Validation - Accuracy: 0.33, Loss: 1.83


                                                    

Epoch[3] Validation - Accuracy: 0.36, Loss: 1.73


                                                    

Epoch[4] Validation - Accuracy: 0.36, Loss: 1.70


                                                    

Epoch[5] Validation - Accuracy: 0.36, Loss: 1.70


                                                    

Epoch[6] Validation - Accuracy: 0.36, Loss: 1.70


                                                    

Epoch[7] Validation - Accuracy: 0.36, Loss: 1.70


                                                    

Epoch[8] Validation - Accuracy: 0.36, Loss: 1.70


                                                    

Epoch[9] Validation - Accuracy: 0.36, Loss: 1.69


                                                     

Epoch[10] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[11] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[12] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[13] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[14] Validation - Accuracy: 0.36, Loss: 1.70


                                                     

Epoch[15] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[16] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[17] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[18] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[19] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[20] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[21] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[22] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[23] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[24] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[25] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[26] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[27] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[28] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[29] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[30] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[31] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[32] Validation - Accuracy: 0.36, Loss: 1.69


                                                     

Epoch[33] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[34] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[35] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[36] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[37] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[38] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[39] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[40] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[41] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[42] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[43] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[44] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[45] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[46] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[47] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[48] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[49] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[50] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[51] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[52] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[53] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[54] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[55] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[56] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[57] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[58] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[59] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[60] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[61] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[62] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[63] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[64] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[65] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[66] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[67] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[68] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[69] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[70] Validation - Accuracy: 0.37, Loss: 1.67


                                                     

Epoch[71] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[72] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[73] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[74] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[75] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[76] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[77] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[78] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[79] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[80] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[81] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[82] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[83] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[84] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[85] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[86] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[87] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[88] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[89] Validation - Accuracy: 0.36, Loss: 1.68


                                                     

Epoch[90] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[91] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[92] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[93] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[94] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[95] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[96] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[97] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[98] Validation - Accuracy: 0.36, Loss: 1.67


                                                     

Epoch[99] Validation - Accuracy: 0.36, Loss: 1.67


                                                      

Epoch[100] Validation - Accuracy: 0.36, Loss: 1.67
Training is completed.




In [6]:
def test_model(word):
    model.eval()
    # Converting your word into the format the model expects.
    test_seq = SOS_TOKEN + word.lower() + EOS_TOKEN
    x = torch.tensor([char_to_idx[c] for c in test_seq]).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(x)
        # Take the most probable character at each position.
        pred_indices = torch.argmax(output, dim=-1).squeeze(0)
        # Convert indices back to characters.
        result = ''.join([idx_to_char[idx.item()] for idx in pred_indices])

    print(f"\n --- MODEL TEST ---")
    print(f"Input: {word}")
    print(f"Output: {result}")

# Testing the model with some example words.
test_model("HELLO")


 --- MODEL TEST ---
Input: HELLO
Output: nng]]]o


In [7]:
# Saving the Final Model Weights.
torch.save(model.state_dict(), "mirror_transformer_model.pth")
print("Success. Trained weights saved as 'mirror_transformer_model.pth'.")

Success. Trained weights saved as 'mirror_transformer_model.pth'.
