<a href="https://colab.research.google.com/github/emredeveloper/Transformers--General-AI/blob/main/DyT_vs_RMSNorm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm  # İlerleme çubuğu için tqdm ekleniyor

# 1. RMSNorm class
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable affine.

    The input is divided by ``sqrt(mean(x**2) + eps)``, where the mean is taken
    along the last dimension, and the result is scaled by ``gamma`` and shifted
    by ``beta``.

    Args:
        dim: Length of the ``gamma`` and ``beta`` vectors (size of the last axis
            of the inputs this layer is applied to).
        eps: Constant added to the mean square before taking the square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # The affine starts as the identity: scale of ones, shift of zeros.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``gamma * x / rms(x) + beta`` with the RMS taken over the last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        root_mean_square = (mean_square + self.eps).sqrt()
        normalized = x / root_mean_square
        return self.gamma * normalized + self.beta

# 2. DyT class
class DyT(nn.Module):
    """Element-wise ``tanh(alpha * x)`` followed by a learnable affine.

    ``alpha`` is a single learnable scalar shared by all elements; ``gamma`` and
    ``beta`` are vectors of length ``dim`` applied after the tanh.

    Args:
        dim: Length of the ``gamma`` and ``beta`` vectors.
        init_alpha: Starting value of the learnable scalar ``alpha``.
    """

    def __init__(self, dim, init_alpha=0.5):
        super().__init__()
        # One-element tensor so alpha broadcasts against any input shape.
        self.alpha = nn.Parameter(init_alpha * torch.ones(1))
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``gamma * tanh(alpha * x) + beta``."""
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta

# 3. TransformerBlock class
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block with a selectable normalization layer.

    The block applies ``norm1`` -> multi-head self-attention -> residual add,
    then ``norm2`` -> two-layer GELU MLP (hidden width ``4 * dim``) -> residual
    add.

    Args:
        dim: Embedding size of the tokens.
        num_heads: Number of attention heads.
        norm_layer: ``'RMSNorm'`` or ``'DyT'``; selects the class used for both
            normalization layers.
        init_alpha: Initial ``alpha`` forwarded to ``DyT``; unused for
            ``'RMSNorm'``.

    Raises:
        ValueError: If ``norm_layer`` is neither ``'RMSNorm'`` nor ``'DyT'``.
    """

    def __init__(self, dim, num_heads, norm_layer, init_alpha=0.5):
        super().__init__()
        if norm_layer == 'RMSNorm':
            self.norm1 = RMSNorm(dim)
            self.norm2 = RMSNorm(dim)
        elif norm_layer == 'DyT':
            self.norm1 = DyT(dim, init_alpha)
            self.norm2 = DyT(dim, init_alpha)
        else:
            raise ValueError("Geçersiz norm_layer. 'RMSNorm' veya 'DyT' seçin.")
        # batch_first is left at its default, so the attention layer uses
        # PyTorch's default sequence-first input layout.
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each wrapped in a residual.

        Args:
            x: Token tensor in the layout expected by ``self.attn``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Normalize once and share the result as query, key and value.
        # Calling norm1 separately for each argument would triple the
        # normalization work and the autograd graph for identical values.
        normed = self.norm1(x)
        attn_output, _ = self.attn(normed, normed, normed)
        x = x + attn_output
        ffn_output = self.ffn(self.norm2(x))
        x = x + ffn_output
        return x

# 4. SimpleViT class
class SimpleViT(nn.Module):
    """Small Vision Transformer classifier built from ``TransformerBlock`` layers.

    A strided convolution turns a 3-channel image into one ``dim``-sized token
    per patch, a learnable class token is prepended, learnable positional
    embeddings are added, the tokens pass through ``depth`` blocks, and a linear
    head maps the class token to ``num_classes`` logits.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each patch; must divide ``img_size``.
        num_classes: Number of output logits.
        dim: Token embedding size.
        depth: Number of ``TransformerBlock`` layers.
        heads: Attention heads per block.
        norm_layer: Normalization choice forwarded to each block.
        init_alpha: Initial DyT ``alpha`` forwarded to each block.
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=10, dim=256, depth=3, heads=4, norm_layer='RMSNorm', init_alpha=0.5):
        super().__init__()
        assert img_size % patch_size == 0, "Görüntü boyutu yama boyutuna bölünebilir olmalı"
        patches_per_side = img_size // patch_size
        token_count = patches_per_side ** 2 + 1  # every patch plus the class token

        # Kernel size equal to stride: each patch is projected independently.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, token_count, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        encoder_layers = [
            TransformerBlock(dim=dim, num_heads=heads, norm_layer=norm_layer, init_alpha=init_alpha)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        batch_size = x.shape[0]

        # (B, dim, H', W') -> (B, dim, N) -> (B, N, dim)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)

        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, patches), dim=1) + self.pos_embed

        # The blocks run on sequence-first tensors, so swap the batch and token
        # axes on the way in and swap them back on the way out.
        tokens = tokens.transpose(0, 1)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = tokens.transpose(0, 1)

        # Classify from the class token, which sits at position 0.
        return self.head(tokens[:, 0])

# 5. Training and evaluation function
def train_model(model, dataloader, criterion, optimizer, num_epochs, device):
    """Train ``model`` on ``dataloader`` and report timing and training accuracy.

    After each epoch the mean per-batch loss and the accuracy are printed. The
    accuracy is accumulated over the training batches while the weights are
    being updated; it is not a held-out evaluation.

    Args:
        model: Module to train; moved to ``device`` in place.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device that the model and every batch are moved to.

    Returns:
        Tuple ``(training_time, accuracy)``: elapsed wall-clock seconds and the
        last epoch's accuracy in percent. ``accuracy`` is ``0.0`` when no epoch
        ran or the last epoch saw no samples.
    """
    model.to(device)
    start_time = time.time()

    # Defined up front so the return below is valid even when num_epochs == 0.
    accuracy = 0.0

    # Outer progress bar over epochs
    for epoch in tqdm(range(num_epochs), desc="Epochs", unit="epoch"):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        # Counted by hand so loaders without __len__ work too.
        num_batches = 0

        # Inner progress bar over batches
        for inputs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

        # Guard both averages so an empty dataloader reports zeros instead of
        # raising ZeroDivisionError.
        accuracy = 100 * correct / total if total else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{num_epochs} tamamlandı. Kayıp: {avg_loss:.4f}, Doğruluk: {accuracy:.2f}%")

    end_time = time.time()
    training_time = end_time - start_time
    return training_time, accuracy

# Dataset and DataLoader (CIFAR-10)
# Images are resized to 224x224 (SimpleViT's default img_size) and every channel
# is normalized with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Only the training split is used; no held-out split is loaded in this cell.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Device and training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 1

# RMSNorm model
# NOTE(review): no random seed is set in this cell, so the two models presumably
# start from different random weights and see differently shuffled batches;
# part of any accuracy gap may be run-to-run noise. Confirm before drawing
# conclusions.
model_rms = SimpleViT(norm_layer='RMSNorm')
optimizer_rms = optim.Adam(model_rms.parameters(), lr=0.001)
# The same loss object is shared by both training runs.
criterion = nn.CrossEntropyLoss()

# DyT model
model_dyt = SimpleViT(norm_layer='DyT', init_alpha=0.5)
optimizer_dyt = optim.Adam(model_dyt.parameters(), lr=0.001)

# Training and comparison
# The models are trained one after the other on the same loader. The accuracy
# returned by train_model is measured on the training batches themselves.
print("RMSNorm Modeli Eğitiliyor...")
time_rms, acc_rms = train_model(model_rms, train_loader, criterion, optimizer_rms, num_epochs, device)
print(f"RMSNorm Eğitim Süresi: {time_rms:.2f} saniye, Son Doğruluk: {acc_rms:.2f}%")

print("\nDyT Modeli Eğitiliyor...")
time_dyt, acc_dyt = train_model(model_dyt, train_loader, criterion, optimizer_dyt, num_epochs, device)
print(f"DyT Eğitim Süresi: {time_dyt:.2f} saniye, Son Doğruluk: {acc_dyt:.2f}%")

# Comparison results
print("\nKarşılaştırma:")
print(f"RMSNorm - Süre: {time_rms:.2f}s, Doğruluk: {acc_rms:.2f}%")
print(f"DyT - Süre: {time_dyt:.2f}s, Doğruluk: {acc_dyt:.2f}%")

RMSNorm Modeli Eğitiliyor...


Epochs:   0%|          | 0/1 [00:00<?, ?epoch/s]
Epoch 1/1:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 1/1:   0%|          | 1/1563 [00:00<03:23,  7.69batch/s][A
Epoch 1/1:   0%|          | 2/1563 [00:00<03:22,  7.71batch/s][A
Epoch 1/1:   0%|          | 3/1563 [00:00<03:11,  8.13batch/s][A
Epoch 1/1:   0%|          | 4/1563 [00:00<03:02,  8.53batch/s][A
Epoch 1/1:   0%|          | 5/1563 [00:00<02:53,  8.99batch/s][A
Epoch 1/1:   0%|          | 7/1563 [00:00<02:46,  9.32batch/s][A
Epoch 1/1:   1%|          | 8/1563 [00:00<02:51,  9.09batch/s][A
Epoch 1/1:   1%|          | 9/1563 [00:01<02:53,  8.97batch/s][A
Epoch 1/1:   1%|          | 10/1563 [00:01<02:54,  8.92batch/s][A
Epoch 1/1:   1%|          | 11/1563 [00:01<02:56,  8.78batch/s][A
Epoch 1/1:   1%|          | 12/1563 [00:01<02:58,  8.70batch/s][A
Epoch 1/1:   1%|          | 13/1563 [00:01<02:57,  8.72batch/s][A
Epoch 1/1:   1%|          | 14/1563 [00:01<03:03,  8.44batch/s][A
Epoch 1/1:   1%|          | 15

Epoch 1/1 tamamlandı. Kayıp: 1.8101, Doğruluk: 33.27%
RMSNorm Eğitim Süresi: 154.52 saniye, Son Doğruluk: 33.27%

DyT Modeli Eğitiliyor...


Epochs:   0%|          | 0/1 [00:00<?, ?epoch/s]
Epoch 1/1:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 1/1:   0%|          | 1/1563 [00:00<03:43,  6.99batch/s][A
Epoch 1/1:   0%|          | 2/1563 [00:00<03:03,  8.49batch/s][A
Epoch 1/1:   0%|          | 3/1563 [00:00<02:53,  9.00batch/s][A
Epoch 1/1:   0%|          | 5/1563 [00:00<02:35, 10.02batch/s][A
Epoch 1/1:   0%|          | 7/1563 [00:00<02:29, 10.44batch/s][A
Epoch 1/1:   1%|          | 9/1563 [00:00<02:32, 10.19batch/s][A
Epoch 1/1:   1%|          | 11/1563 [00:01<02:29, 10.40batch/s][A
Epoch 1/1:   1%|          | 13/1563 [00:01<02:25, 10.67batch/s][A
Epoch 1/1:   1%|          | 15/1563 [00:01<02:24, 10.70batch/s][A
Epoch 1/1:   1%|          | 17/1563 [00:01<02:22, 10.84batch/s][A
Epoch 1/1:   1%|          | 19/1563 [00:01<02:21, 10.88batch/s][A
Epoch 1/1:   1%|▏         | 21/1563 [00:02<02:23, 10.76batch/s][A
Epoch 1/1:   1%|▏         | 23/1563 [00:02<02:21, 10.87batch/s][A
Epoch 1/1:   2%|▏         | 

Epoch 1/1 tamamlandı. Kayıp: 1.6889, Doğruluk: 38.04%
DyT Eğitim Süresi: 148.68 saniye, Son Doğruluk: 38.04%

Karşılaştırma:
RMSNorm - Süre: 154.52s, Doğruluk: 33.27%
DyT - Süre: 148.68s, Doğruluk: 38.04%





In [None]:
!pip install tqdm tabulate reportlab

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm  # İlerleme çubuğu için tqdm ekleniyor

# 1. RMSNorm class
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis with a learnable affine.

    The input is divided by ``sqrt(mean(x**2) + eps)``, where the mean is taken
    along the last dimension, and the result is scaled by ``gamma`` and shifted
    by ``beta``.

    Args:
        dim: Length of the ``gamma`` and ``beta`` vectors (size of the last axis
            of the inputs this layer is applied to).
        eps: Constant added to the mean square before taking the square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        # The affine starts as the identity: scale of ones, shift of zeros.
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``gamma * x / rms(x) + beta`` with the RMS taken over the last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        root_mean_square = (mean_square + self.eps).sqrt()
        normalized = x / root_mean_square
        return self.gamma * normalized + self.beta

# 2. DyT class
class DyT(nn.Module):
    """Element-wise ``tanh(alpha * x)`` followed by a learnable affine.

    ``alpha`` is a single learnable scalar shared by all elements; ``gamma`` and
    ``beta`` are vectors of length ``dim`` applied after the tanh.

    Args:
        dim: Length of the ``gamma`` and ``beta`` vectors.
        init_alpha: Starting value of the learnable scalar ``alpha``.
    """

    def __init__(self, dim, init_alpha=0.5):
        super().__init__()
        # One-element tensor so alpha broadcasts against any input shape.
        self.alpha = nn.Parameter(init_alpha * torch.ones(1))
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        """Return ``gamma * tanh(alpha * x) + beta``."""
        squashed = torch.tanh(self.alpha * x)
        return self.gamma * squashed + self.beta

# 3. TransformerBlock class
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder block with a selectable normalization layer.

    The block applies ``norm1`` -> multi-head self-attention -> residual add,
    then ``norm2`` -> two-layer GELU MLP (hidden width ``4 * dim``) -> residual
    add.

    Args:
        dim: Embedding size of the tokens.
        num_heads: Number of attention heads.
        norm_layer: ``'RMSNorm'`` or ``'DyT'``; selects the class used for both
            normalization layers.
        init_alpha: Initial ``alpha`` forwarded to ``DyT``; unused for
            ``'RMSNorm'``.

    Raises:
        ValueError: If ``norm_layer`` is neither ``'RMSNorm'`` nor ``'DyT'``.
    """

    def __init__(self, dim, num_heads, norm_layer, init_alpha=0.5):
        super().__init__()
        if norm_layer == 'RMSNorm':
            self.norm1 = RMSNorm(dim)
            self.norm2 = RMSNorm(dim)
        elif norm_layer == 'DyT':
            self.norm1 = DyT(dim, init_alpha)
            self.norm2 = DyT(dim, init_alpha)
        else:
            raise ValueError("Geçersiz norm_layer. 'RMSNorm' veya 'DyT' seçin.")
        # batch_first is left at its default, so the attention layer uses
        # PyTorch's default sequence-first input layout.
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each wrapped in a residual.

        Args:
            x: Token tensor in the layout expected by ``self.attn``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Normalize once and share the result as query, key and value.
        # Calling norm1 separately for each argument would triple the
        # normalization work and the autograd graph for identical values.
        normed = self.norm1(x)
        attn_output, _ = self.attn(normed, normed, normed)
        x = x + attn_output
        ffn_output = self.ffn(self.norm2(x))
        x = x + ffn_output
        return x

# 4. SimpleViT class
class SimpleViT(nn.Module):
    """Small Vision Transformer classifier built from ``TransformerBlock`` layers.

    A strided convolution turns a 3-channel image into one ``dim``-sized token
    per patch, a learnable class token is prepended, learnable positional
    embeddings are added, the tokens pass through ``depth`` blocks, and a linear
    head maps the class token to ``num_classes`` logits.

    Args:
        img_size: Side length of the (square) input images.
        patch_size: Side length of each patch; must divide ``img_size``.
        num_classes: Number of output logits.
        dim: Token embedding size.
        depth: Number of ``TransformerBlock`` layers.
        heads: Attention heads per block.
        norm_layer: Normalization choice forwarded to each block.
        init_alpha: Initial DyT ``alpha`` forwarded to each block.
    """

    def __init__(self, img_size=224, patch_size=16, num_classes=10, dim=256, depth=3, heads=4, norm_layer='RMSNorm', init_alpha=0.5):
        super().__init__()
        assert img_size % patch_size == 0, "Görüntü boyutu yama boyutuna bölünebilir olmalı"
        patches_per_side = img_size // patch_size
        token_count = patches_per_side ** 2 + 1  # every patch plus the class token

        # Kernel size equal to stride: each patch is projected independently.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.randn(1, token_count, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        encoder_layers = [
            TransformerBlock(dim=dim, num_heads=heads, norm_layer=norm_layer, init_alpha=init_alpha)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(encoder_layers)

        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        batch_size = x.shape[0]

        # (B, dim, H', W') -> (B, dim, N) -> (B, N, dim)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)

        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, patches), dim=1) + self.pos_embed

        # The blocks run on sequence-first tensors, so swap the batch and token
        # axes on the way in and swap them back on the way out.
        tokens = tokens.transpose(0, 1)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = tokens.transpose(0, 1)

        # Classify from the class token, which sits at position 0.
        return self.head(tokens[:, 0])

# 5. Training and evaluation function
def train_model(model, dataloader, criterion, optimizer, num_epochs, device):
    """Train ``model`` on ``dataloader`` and report timing and training accuracy.

    After each epoch the mean per-batch loss and the accuracy are printed. The
    accuracy is accumulated over the training batches while the weights are
    being updated; it is not a held-out evaluation.

    Args:
        model: Module to train; moved to ``device`` in place.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of passes over ``dataloader``.
        device: Device that the model and every batch are moved to.

    Returns:
        Tuple ``(training_time, accuracy)``: elapsed wall-clock seconds and the
        last epoch's accuracy in percent. ``accuracy`` is ``0.0`` when no epoch
        ran or the last epoch saw no samples.
    """
    model.to(device)
    start_time = time.time()

    # Defined up front so the return below is valid even when num_epochs == 0.
    accuracy = 0.0

    # Outer progress bar over epochs
    for epoch in tqdm(range(num_epochs), desc="Epochs", unit="epoch"):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        # Counted by hand so loaders without __len__ work too.
        num_batches = 0

        # Inner progress bar over batches
        for inputs, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch", leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

        # Guard both averages so an empty dataloader reports zeros instead of
        # raising ZeroDivisionError.
        accuracy = 100 * correct / total if total else 0.0
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch {epoch+1}/{num_epochs} tamamlandı. Kayıp: {avg_loss:.4f}, Doğruluk: {accuracy:.2f}%")

    end_time = time.time()
    training_time = end_time - start_time
    return training_time, accuracy

# Dataset and DataLoader (CIFAR-10)
# Images are resized to 224x224 (SimpleViT's default img_size) and every channel
# is normalized with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Only the training split is used; no held-out split is loaded in this cell.
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Device and training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_epochs = 2

# RMSNorm model
# NOTE(review): no random seed is set in this cell, so the two models presumably
# start from different random weights and see differently shuffled batches;
# part of any accuracy gap may be run-to-run noise. Confirm before drawing
# conclusions.
model_rms = SimpleViT(norm_layer='RMSNorm')
optimizer_rms = optim.Adam(model_rms.parameters(), lr=0.001)
# The same loss object is shared by both training runs.
criterion = nn.CrossEntropyLoss()

# DyT model
model_dyt = SimpleViT(norm_layer='DyT', init_alpha=0.5)
optimizer_dyt = optim.Adam(model_dyt.parameters(), lr=0.001)

# Training and comparison
# The models are trained one after the other on the same loader. The accuracy
# returned by train_model is measured on the training batches themselves.
print("RMSNorm Modeli Eğitiliyor...")
time_rms, acc_rms = train_model(model_rms, train_loader, criterion, optimizer_rms, num_epochs, device)
print(f"RMSNorm Eğitim Süresi: {time_rms:.2f} saniye, Son Doğruluk: {acc_rms:.2f}%")

print("\nDyT Modeli Eğitiliyor...")
time_dyt, acc_dyt = train_model(model_dyt, train_loader, criterion, optimizer_dyt, num_epochs, device)
print(f"DyT Eğitim Süresi: {time_dyt:.2f} saniye, Son Doğruluk: {acc_dyt:.2f}%")

# Comparison results
print("\nKarşılaştırma:")
print(f"RMSNorm - Süre: {time_rms:.2f}s, Doğruluk: {acc_rms:.2f}%")
print(f"DyT - Süre: {time_dyt:.2f}s, Doğruluk: {acc_dyt:.2f}%")

RMSNorm Modeli Eğitiliyor...


Epochs:   0%|          | 0/2 [00:00<?, ?epoch/s]
Epoch 1/2:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 1/2:   0%|          | 1/1563 [00:00<03:24,  7.64batch/s][A
Epoch 1/2:   0%|          | 2/1563 [00:00<03:22,  7.72batch/s][A
Epoch 1/2:   0%|          | 3/1563 [00:00<03:08,  8.29batch/s][A
Epoch 1/2:   0%|          | 4/1563 [00:00<03:00,  8.65batch/s][A
Epoch 1/2:   0%|          | 5/1563 [00:00<02:59,  8.70batch/s][A
Epoch 1/2:   0%|          | 7/1563 [00:00<02:42,  9.57batch/s][A
Epoch 1/2:   1%|          | 9/1563 [00:00<02:33, 10.10batch/s][A
Epoch 1/2:   1%|          | 11/1563 [00:01<02:29, 10.38batch/s][A
Epoch 1/2:   1%|          | 13/1563 [00:01<02:28, 10.46batch/s][A
Epoch 1/2:   1%|          | 15/1563 [00:01<02:28, 10.41batch/s][A
Epoch 1/2:   1%|          | 17/1563 [00:01<02:26, 10.55batch/s][A
Epoch 1/2:   1%|          | 19/1563 [00:01<02:24, 10.70batch/s][A
Epoch 1/2:   1%|▏         | 21/1563 [00:02<02:23, 10.76batch/s][A
Epoch 1/2:   1%|▏         | 2

Epoch 1/2 tamamlandı. Kayıp: 1.7931, Doğruluk: 33.94%



Epoch 2/2:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 2/2:   0%|          | 1/1563 [00:00<03:19,  7.83batch/s][A
Epoch 2/2:   0%|          | 2/1563 [00:00<03:19,  7.83batch/s][A
Epoch 2/2:   0%|          | 3/1563 [00:00<03:06,  8.38batch/s][A
Epoch 2/2:   0%|          | 4/1563 [00:00<03:00,  8.62batch/s][A
Epoch 2/2:   0%|          | 5/1563 [00:00<03:02,  8.52batch/s][A
Epoch 2/2:   0%|          | 6/1563 [00:00<03:04,  8.45batch/s][A
Epoch 2/2:   0%|          | 7/1563 [00:00<03:04,  8.41batch/s][A
Epoch 2/2:   1%|          | 8/1563 [00:00<03:05,  8.37batch/s][A
Epoch 2/2:   1%|          | 9/1563 [00:01<03:06,  8.35batch/s][A
Epoch 2/2:   1%|          | 10/1563 [00:01<03:12,  8.06batch/s][A
Epoch 2/2:   1%|          | 11/1563 [00:01<03:09,  8.18batch/s][A
Epoch 2/2:   1%|          | 12/1563 [00:01<03:05,  8.38batch/s][A
Epoch 2/2:   1%|          | 13/1563 [00:01<03:02,  8.48batch/s][A
Epoch 2/2:   1%|          | 14/1563 [00:01<03:01,  8.53batch/s][A
Epoch 2/2:  

Epoch 2/2 tamamlandı. Kayıp: 1.5690, Doğruluk: 42.15%
RMSNorm Eğitim Süresi: 310.00 saniye, Son Doğruluk: 42.15%

DyT Modeli Eğitiliyor...


Epochs:   0%|          | 0/2 [00:00<?, ?epoch/s]
Epoch 1/2:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 1/2:   0%|          | 1/1563 [00:00<02:59,  8.69batch/s][A
Epoch 1/2:   0%|          | 3/1563 [00:00<02:39,  9.79batch/s][A
Epoch 1/2:   0%|          | 4/1563 [00:00<02:39,  9.78batch/s][A
Epoch 1/2:   0%|          | 6/1563 [00:00<02:27, 10.55batch/s][A
Epoch 1/2:   1%|          | 8/1563 [00:00<02:23, 10.81batch/s][A
Epoch 1/2:   1%|          | 10/1563 [00:00<02:22, 10.87batch/s][A
Epoch 1/2:   1%|          | 12/1563 [00:01<02:21, 10.99batch/s][A
Epoch 1/2:   1%|          | 14/1563 [00:01<02:21, 10.98batch/s][A
Epoch 1/2:   1%|          | 16/1563 [00:01<02:20, 10.98batch/s][A
Epoch 1/2:   1%|          | 18/1563 [00:01<02:20, 11.02batch/s][A
Epoch 1/2:   1%|▏         | 20/1563 [00:01<02:18, 11.10batch/s][A
Epoch 1/2:   1%|▏         | 22/1563 [00:02<02:18, 11.11batch/s][A
Epoch 1/2:   2%|▏         | 24/1563 [00:02<02:17, 11.17batch/s][A
Epoch 1/2:   2%|▏         |

Epoch 1/2 tamamlandı. Kayıp: 1.7113, Doğruluk: 36.97%



Epoch 2/2:   0%|          | 0/1563 [00:00<?, ?batch/s][A
Epoch 2/2:   0%|          | 1/1563 [00:00<03:08,  8.30batch/s][A
Epoch 2/2:   0%|          | 2/1563 [00:00<02:51,  9.11batch/s][A
Epoch 2/2:   0%|          | 3/1563 [00:00<02:49,  9.19batch/s][A
Epoch 2/2:   0%|          | 5/1563 [00:00<02:32, 10.19batch/s][A
Epoch 2/2:   0%|          | 7/1563 [00:00<02:30, 10.33batch/s][A
Epoch 2/2:   1%|          | 9/1563 [00:00<02:26, 10.62batch/s][A
Epoch 2/2:   1%|          | 11/1563 [00:01<02:25, 10.70batch/s][A
Epoch 2/2:   1%|          | 13/1563 [00:01<02:22, 10.86batch/s][A
Epoch 2/2:   1%|          | 15/1563 [00:01<02:24, 10.75batch/s][A
Epoch 2/2:   1%|          | 17/1563 [00:01<02:22, 10.88batch/s][A
Epoch 2/2:   1%|          | 19/1563 [00:01<02:22, 10.86batch/s][A
Epoch 2/2:   1%|▏         | 21/1563 [00:02<02:27, 10.48batch/s][A
Epoch 2/2:   1%|▏         | 23/1563 [00:02<02:34,  9.99batch/s][A
Epoch 2/2:   2%|▏         | 25/1563 [00:02<02:40,  9.61batch/s][A
Epoch 2/2

Epoch 2/2 tamamlandı. Kayıp: 1.3678, Doğruluk: 50.29%
DyT Eğitim Süresi: 299.82 saniye, Son Doğruluk: 50.29%

Karşılaştırma:
RMSNorm - Süre: 310.00s, Doğruluk: 42.15%
DyT - Süre: 299.82s, Doğruluk: 50.29%



