In [84]:
import torch.nn as nn


class Embedding(nn.Module):

    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size,embedding_dim=d_model)


    def forward(self,x):
        return self.embedding(x)

In [85]:
import torch.nn as nn
import torch


class PositionalEncoding(nn.Module):

    def __init__(self, d_model, max_len):

        super().__init__()

        pe = torch.zeros(max_len,d_model)
        positions = torch.arange(0,max_len).unsqueeze(1)

        steps = torch.arange(0,d_model,2)

        pe[:,0::2] = torch.sin(positions / (10000) ** (steps / d_model))
        pe[:,1::2] = torch.cos(positions / (10000) ** (steps / d_model))

        pe = pe.unsqueeze(0)

        self.register_buffer("pe",pe)

    def forward(self,x):

        length = x.size(1)

        x = x + self.pe[:,:length,:]

        return x

In [None]:
import torch
import torch.nn as nn
from flash_attn import flash_attn_func

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, mask_future=False):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.future_mask = mask_future
        self.d_k = d_model // n_heads
        assert self.d_k % 8 == 0

        # Linear projections
        self.query_transform = nn.Linear(d_model, d_model, bias=False)
        self.key_transform   = nn.Linear(d_model, d_model, bias=False)
        self.value_transform = nn.Linear(d_model, d_model, bias=False)
        self.output_transform = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x):

        B, S, _ = x.size()
        x = x.view(B, S, self.n_heads, self.d_k).permute(0, 2, 1, 3).contiguous()
        return x

    def combine_heads(self, x):
        B, H, S, D = x.size()
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(B, S, self.d_model)

    def forward(self, q, k, v, mask=None):
        Q = self.query_transform(q)  
        K = self.key_transform(k)
        V = self.value_transform(v)

        B, S_q, _ = Q.size()
        _, S_k, _ = K.size()

        Q = Q.view(B, S_q, self.n_heads, self.d_k)
        K = K.view(B, S_k, self.n_heads, self.d_k)
        V = V.view(B, S_k, self.n_heads, self.d_k)
        out = flash_attn_func(
            Q, K, V,
            dropout_p=0.0,
            causal=self.future_mask
        )

        out = out.reshape(B, S_q, self.d_model)
        out = self.output_transform(out)
        return out



In [87]:
import torch.nn as nn


class PostionalFeedForward(nn.Module):

    def __init__(self,d_model,d_hl,dropout=0.1):

        super().__init__()

        self.ff = nn.Sequential(
            nn.Linear(d_model, d_hl),
            nn.ReLU(),
            nn.Linear(d_hl, d_model),
            nn.Dropout(dropout)
        )

    def forward(self,x):
        x = self.ff(x)
        return x

In [88]:
import torch
import torch.nn as nn




class BaseTransformerLayer(nn.Module):

    def __init__(self,input_dim, num_heads, feature_dim, dropout=0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(input_dim,num_heads)
        self.feature_transformation = PostionalFeedForward(input_dim,feature_dim,dropout)

        self.layer_norm_1 = nn.LayerNorm(input_dim)
        self.layer_norm_2 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask = None):
        attn_out = self.self_attention(x,x,x,mask)
        x = x + self.dropout1(attn_out)
        x = self.layer_norm_1(x)

        ff_out = self.feature_transformation(x)
        x = x +self.dropout2(ff_out)
        x = self.layer_norm_2(x)

        return x

class TransformerDecoderLayer(nn.Module):

    def __init__(self,input_dim,num_heads, feature_dim,dropout=0.1):
        super().__init__()

        self.self_attention = MultiHeadAttention(
            input_dim,
            num_heads,
            mask_future=True
        )

        self.encoder_attention = MultiHeadAttention(input_dim,num_heads,mask_future=False)
        self.feature_transformation = PostionalFeedForward(
            input_dim,feature_dim,dropout
        )
        self.layer_norm_1 = nn.LayerNorm(input_dim)
        self.layer_norm_2 = nn.LayerNorm(input_dim)
        self.layer_norm_3 = nn.LayerNorm(input_dim)

        self.dropout1 = nn.Dropout(dropout)

    def forward(self, target, encoder_output, src_mask=None, target_mask=None):

        attn_out = self.self_attention(target,target,target,mask=target_mask)
        target = target + self.dropout1(attn_out)
        target = self.layer_norm_1(target)

        attn_out = self.encoder_attention(q=target,k=encoder_output,v=encoder_output,mask=src_mask)
        target = target + self.dropout1(attn_out)
        target = self.layer_norm_2(target)

        ff_out = self.feature_transformation(target)
        target = target + self.dropout1(ff_out)
        target = self.layer_norm_3(target)
        return target



In [89]:
import torch
import torch.nn as nn


class Decoder(nn.Module):
    def __init__(
        self, d_model,  n_heads, dim_feedforward, dropout, num_layers
    ):
        super().__init__()
        self.decoder_layers = nn.ModuleList(
            [
                TransformerDecoderLayer(d_model, n_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, tgt_emb, src_emb, src_mask, tgt_mask):
        for layer in self.decoder_layers:
            tgt_emb = layer(tgt_emb, src_emb, src_mask, tgt_mask)
        return tgt_emb


In [90]:
import torch
import torch.nn as nn




class Encoder(nn.Module):
    def __init__(
        self, d_model,  n_heads, dim_feedforward, dropout, num_layers
    ):
        super().__init__()
        self.encoder_layers = nn.ModuleList(
            [
                BaseTransformerLayer(d_model, n_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, src_emb, src_mask):
        for layer in self.encoder_layers:
            src_emb = layer(src_emb, src_mask)
        return src_emb



In [91]:
import torch.nn as nn


class TransformerEmbedding(nn.Module):

    def __init__(self, vocab_size, d_model, max_len):
        super().__init__()

        self.token_emb = Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len)

    def forward(self, token):
        device = token.device
        token_emb = self.token_emb(token).to(device,non_blocking=True)
        pos_emb = self.pos_emb(token_emb).to(device,non_blocking=True)

        return pos_emb


In [107]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import math


class Transformer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        n_heads: int,
        num_encoder_layers: int,
        num_decoder_layers: int,
        dim_feedforward: int,
        dropout: float,
        max_len: int,
    ):
        super().__init__()

        self.d_model = d_model

        self.embedding = TransformerEmbedding(
            vocab_size=vocab_size,
            d_model=d_model,
            max_len=max_len,
        )

        self.encoder = Encoder(
            num_layers=num_encoder_layers,
            d_model=d_model,
            n_heads=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.decoder = Decoder(
            num_layers=num_decoder_layers,
            d_model=d_model,
            n_heads=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout
        )

        self.output_projection = nn.Linear(d_model, vocab_size, bias=False)

        self.output_projection.weight = self.embedding.token_emb.embedding.weight

    def generate(self, src, max_length=50):
        device = next(self.parameters()).device
        src = src.to(device)
        PAD = 0
        batch_size = src.size(0)

        all_outputs = []

        for i in range(batch_size):
            seq = src[i]
            length = (seq != PAD).sum().item()
            trimmed_seq = seq[:length].unsqueeze(0)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                # Encoder
                src_emb = self.embedding(trimmed_seq)
                memory = self.encoder(src_emb, None)

                # Start token for decoder
                tgt = torch.tensor([[1]], device=device, dtype=torch.long)

                # Autoregressive decoding loop
                for _ in range(max_length):
                    tgt_emb = self.embedding(tgt)
                    output = self.decoder(tgt_emb, memory, None, None)
                    logits = self.output_projection(output)
                    next_token = torch.argmax(logits[:, -1, :], dim=-1)
                    tgt = torch.cat([tgt, next_token.unsqueeze(1)], dim=1)

                    if next_token.item() == 2:  # EOS token
                        break

            all_outputs.append(tgt.squeeze(0))

        return all_outputs



    def forward(
        self,
        src: torch.Tensor,
        tgt: torch.Tensor,
        src_mask: torch.Tensor = None,
        tgt_mask: torch.Tensor = None
    ):

        src_emb = self.embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)

        encoder_output = self.encoder(
            src_emb,
            src_mask=src_mask,
        )

        masked_encoder_output = encoder_output * src_mask.unsqueeze(-1)

        output = self.decoder(
            tgt_emb,
            masked_encoder_output,
            tgt_mask=tgt_mask,
            src_mask=src_mask

        )

        logits = self.output_projection(output)

        return logits



In [93]:
import re
from datasets import load_dataset
from torch.utils.data import Dataset
import torch

class TranslationDataSet:

    def __init__(self,min_len=5,max_len=64,max_ratio=1.5,limit=None):
        self.WHITELIST = "abcdefghijklmnopqrstuvwxyzäöüß0123456789.,!?()[]{}:;-&$@#%£€/\\|_+*¥ "
        self.WHITELIST_SET = set(self.WHITELIST.lower())
        self.min_ratio = 1 / max_ratio
        self.max_ratio = max_ratio
        self.max_len = max_len
        self.min_len = min_len
        self.data = load_dataset("wmt17", "de-en")
        self.limit = limit if limit != None else len(self.data["train"])

    def get_wmt17_datset(self,split="train"):
        return self.data[split].select(range(self.limit))

    def _preprocess_text(self,text):
        text = text.lower()

        text = re.sub(r'http\S+|www\S+|<.*?>', '', text)

        text = "".join(c for c in text if c in self.WHITELIST_SET)

        text = re.sub(r'\s+', ' ', text).strip()

        return text

    def clean_sentence_pair(self,example: dict):
        source_lang = 'de'
        target_lang = 'en'

        source_text = example['translation'][source_lang]
        target_text = example['translation'][target_lang]

        keep_example = True
        cleaned_source = self._preprocess_text(source_text)
        cleaned_target = self._preprocess_text(target_text)


        source_len = len(cleaned_source.split())
        target_len = len(cleaned_target.split())


        if not (self.min_len <= source_len <= self.max_len and self.min_len <= target_len <= self.max_len):
            keep_example = False

        if source_len > 0 and target_len > 0:
            ratio = source_len / target_len
            if not (self.min_ratio <= ratio <= self.max_ratio):
                keep_example = False
        else:
            keep_example = False

        example['translation'][source_lang] = cleaned_source
        example['translation'][target_lang] = cleaned_target
        example['keep'] = keep_example

        return example


class TranslationTorchDataset(Dataset):
    def __init__(self,dataset, tokenizer,max_len=64, src_lang="de",tgt_lang="en"):
        super().__init__()
        self.data = dataset
        self.tokenizer = tokenizer
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.max_len = max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        element = self.data[index]["translation"]

        src_ids = self.tokenizer.encode(element[self.src_lang])[:self.max_len]
        tgt_ids = self.tokenizer.encode(element[self.tgt_lang])[:self.max_len]

        return {"src": torch.tensor(src_ids,dtype=torch.long),"tgt": torch.tensor(tgt_ids,dtype=torch.long)}


In [94]:
import math
from torch.optim.lr_scheduler import LRScheduler

class TransformerLR(LRScheduler):

    def __init__(self, optimizer,d_model,warmup_steps, last_epoch = -1):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.last_epoch = last_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(1,self.last_epoch + 1)
        lr = self.d_model ** -0.5 * min(step ** -0.5, step * self.warmup_steps ** -1.5)
        lrs = [lr for _ in self.optimizer.param_groups]

        return lrs

In [95]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from transformers import GPT2TokenizerFast
import os

class MyTokenizer:
    def __init__(self, vocab_size=50000, save_dir="my_gpt2_bpe"):
        self.vocab_size = vocab_size
        self.save_dir = save_dir
        self.tokenizer = None
        self.pad_token_id = None
        self.bos_token_id = None
        self.eos_token_id = None

    def train(self, corpus):
        tokenizer = Tokenizer(BPE(unk_token="<|unk|>"))
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.decoder = ByteLevelDecoder()

        special_tokens = [
            "[PAD]",
            "[BOS]",
            "[EOS]",
            "[UNK]",
        ]

        trainer = BpeTrainer(
            vocab_size=self.vocab_size,
            special_tokens=special_tokens,
        )

        tokenizer.train_from_iterator(corpus, trainer=trainer)

        os.makedirs(self.save_dir, exist_ok=True)
        tokenizer.save(f"{self.save_dir}/tokenizer.json")

        self.tokenizer = GPT2TokenizerFast(
            tokenizer_file=f"{self.save_dir}/tokenizer.json",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            unk_token="[UNK]",
        )
        self.pad_token_id = self.tokenizer.pad_token_id
        self.bos_token_id = self.tokenizer.bos_token_id
        self.eos_token_id = self.tokenizer.eos_token_id

    def load(self, path):
        self.tokenizer = GPT2TokenizerFast(
            tokenizer_file=f"{path}/tokenizer.json",
            pad_token="[PAD]",
            bos_token="[BOS]",
            eos_token="[EOS]",
            unk_token="[UNK]",
        )
        return self

    def encode(self, text,add_special_tokens=True):
        ids = self.tokenizer.encode(text)
        if add_special_tokens:
            ids = [self.tokenizer.bos_token_id] + ids + [self.tokenizer.eos_token_id]
        return ids

    def decode(self, ids):
        return self.tokenizer.decode(ids)




In [104]:
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import os
from torch.utils.data import random_split
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
import sacrebleu
from tqdm import tqdm





def corpus_iterator(dictionary):
    for ex in dictionary:
        yield ex["translation"]["de"]
        yield ex["translation"]["en"]


def collate_fn(batch, pad_id):
    srcs = [b["src"] for b in batch]
    tgts = [b["tgt"] for b in batch]

    src_batch = nn.utils.rnn.pad_sequence(srcs, batch_first=True, padding_value=pad_id)
    tgt_batch = nn.utils.rnn.pad_sequence(tgts, batch_first=True, padding_value=pad_id)

    src_mask = (src_batch != pad_id)
    tgt_padding_mask = (tgt_batch != pad_id)

    return src_batch, tgt_batch, src_mask.long(), tgt_padding_mask.long()

class CollateWithPad:
    def __init__(self, pad_id):
        self.pad_id = pad_id

    def __call__(self, batch):
        return collate_fn(batch, self.pad_id)

def compute_lengths(dataset):
    return [len(item["src"]) + len(item["tgt"]) for item in dataset]

def run_validation(model, val_loader, loss_fn, device,limit=-1):
    model.eval()
    total_val_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch_idx, (src, tgt, src_mask, tgt_padding_mask) in enumerate(val_loader):
            src = src.to(device, non_blocking=True)
            tgt = tgt.to(device, non_blocking=True)
            src_mask = src_mask.to(device, non_blocking=True)
            tgt_padding_mask = tgt_padding_mask.to(device, non_blocking=True)

            decoder_input = tgt[:, :-1]
            decoder_target = tgt[:, 1:]
            tgt_mask_input = tgt_padding_mask[:, :-1]

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits = model(
                    src,
                    decoder_input,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask_input
                )
                loss = loss_fn(
                    logits.reshape(-1, logits.size(-1)),
                    decoder_target.reshape(-1)
                )

            total_val_loss += loss.item()
            num_batches += 1
            if batch_idx == limit:
                break

    model.train()
    return total_val_loss / num_batches

def compute_bleu(
    model,
    val_loader,
    tokenizer,
    device,
    max_length=50,
    max_batches=None
):
    model.eval()

    hypotheses = []
    references = []

    with torch.no_grad():
        for i, (src, tgt, src_mask, _) in enumerate(
            tqdm(val_loader, desc="Computing BLEU", leave=False)
        ):
            if max_batches is not None and i >= max_batches:
                break

            src = src.to(device)
            tgt = tgt.to(device)

            pred = model.generate(
                src,
                max_length=max_length
            )

            for p, r in zip(pred, tgt):
                p = [
                    t.item() for t in p
                    if t.item() not in {
                        0,
                        1,
                        2
                    }
                ]
                r = [
                    t.item() for t in r
                    if t.item() not in {
                        0,
                        1,
                        2
                    }
                ]

                hypotheses.append(tokenizer.decode(p))
                references.append(tokenizer.decode(r))

    bleu = sacrebleu.corpus_bleu(hypotheses, [references])
    return bleu.score



In [None]:
import time

def main():
    ds = TranslationDataSet(limit=150000)
    raw = ds.get_wmt17_datset()
    cleaned = raw.map(ds.clean_sentence_pair)
    cleaned = cleaned.filter(lambda x: x["keep"])

    tokenizer_path = r"/content/drive/MyDrive/mytransformer/models/WMTBPETokenizer"

    import os
    if os.path.exists(tokenizer_path):
        print("Loading existing tokenizer...")
        tokenizer = MyTokenizer().load(tokenizer_path)
    else:
        print("Tokenizer not found. Training a new tokenizer...")
        tokenizer = MyTokenizer()
        iterator = corpus_iterator(cleaned)  # or whatever iterator you have
        tokenizer.train(iterator)
        tokenizer.save(tokenizer_path)
        print(f"Tokenizer saved at {tokenizer_path}")

    pad_id = tokenizer.tokenizer.pad_token_id
    torchdataset = TranslationTorchDataset(cleaned, tokenizer)

    subset_size = 100000
    subset_dataset = torch.utils.data.Subset(torchdataset, list(range(subset_size)))
    train_size, val_size = 80000, 20000
    train_dataset, val_dataset = random_split(subset_dataset, [train_size, val_size])
    collator = CollateWithPad(pad_id)

    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=False, collate_fn=collator)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, collate_fn=collator)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Transformer(50000,512,8,6,6,2048,0.1,64).to(device)

    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1, betas=(0.9, 0.98), eps=1e-9)
    scaler = torch.cuda.amp.GradScaler()
    scheduler = TransformerLR(optimizer, d_model=model.d_model, warmup_steps=4000)

    num_epochs = 5
    writer = SummaryWriter(log_dir="/content/drive/MyDrive/mytransformer/models/Transformermodel/flashattention/log")
    best_val_loss = float("inf")
    best_globsl_val_loss = 999999
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        batch_times = []
        batch_memories = []

        for batch_idx, (src, tgt, src_mask, tgt_padding_mask) in enumerate(
              tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
          ):
            start_time = time.perf_counter()  
            optimizer.zero_grad()
            src, tgt = src.to(device), tgt.to(device)
            src_mask, tgt_padding_mask = src_mask.to(device), tgt_padding_mask.to(device)

            decoder_input = tgt[:, :-1]
            decoder_target = tgt[:, 1:]
            tgt_mask_input = tgt_padding_mask[:, :-1]

            torch.cuda.reset_peak_memory_stats(device) 

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                logits = model(
                    src,
                    decoder_input,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask_input
                )
                loss = loss_fn(
                    logits.reshape(-1, logits.size(-1)),
                    decoder_target.reshape(-1)
                )

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            total_loss += loss.item()

            
            batch_time = time.perf_counter() - start_time
            batch_times.append(batch_time)
            writer.add_scalar("Time/batch", batch_time, epoch * len(train_loader) + batch_idx)

            
            batch_memory = torch.cuda.max_memory_allocated(device) / 1024 ** 2  # MB
            batch_memories.append(batch_memory)

            writer.add_scalar("Memory/peak_batch", batch_memory, epoch * len(train_loader) + batch_idx)
        globsl_val_loss = run_validation(model, val_loader, loss_fn, device,limit=10)
        # bleu = compute_bleu(model, val_loader, tokenizer, device="cuda", max_batches=1)
        avg_train_loss = total_loss / len(train_loader)
        avg_batch_time = sum(batch_times) / len(batch_times)
        avg_batch_memory = sum(batch_memories) / len(batch_memories)
        print(f"Epoch {epoch+1} finished. Loss: {avg_train_loss:.4f} | "
              f"Avg batch time: {avg_batch_time:.3f}s | "
              f"Avg peak memory: {avg_batch_memory:.1f} MB"
              f"Validation loss: {globsl_val_loss}")

        writer.add_scalar("Loss/train_epoch", avg_train_loss, epoch)
        # writer.add_scalar("BLEU/epoch", bleu, epoch)
        writer.add_scalar("Loss/validation_epoch", globsl_val_loss, epoch)
        if globsl_val_loss < best_globsl_val_loss:
            print("validation loss got improved!")
            checkpoint = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "loss": avg_train_loss,
                "global_step" : epoch
            }
            best_globsl_val_loss = globsl_val_loss
            torch.save(checkpoint, "/content/drive/MyDrive/mytransformer/models/Transformermodel/flashattention/checkpoint_best_val_loss.pt")
    writer.close()

if __name__ == "__main__":
    torch.multiprocessing.freeze_support()
    main()

Loading existing tokenizer...


  scaler = torch.cuda.amp.GradScaler()


Epoch 1 finished. Loss: 92.7942 | Avg batch time: 0.167s | Avg peak memory: 48759.4 MBValidation loss: 32.86406222256747
validation loss got improved!




Epoch 2 finished. Loss: 31.9296 | Avg batch time: 0.168s | Avg peak memory: 39342.8 MBValidation loss: 26.42453384399414
validation loss got improved!




Epoch 3 finished. Loss: 26.9332 | Avg batch time: 0.168s | Avg peak memory: 31951.4 MBValidation loss: 25.149976210160688
validation loss got improved!




Epoch 4 finished. Loss: 22.7849 | Avg batch time: 0.168s | Avg peak memory: 31951.4 MBValidation loss: 17.41531077298251
validation loss got improved!




Epoch 5 finished. Loss: 18.3375 | Avg batch time: 0.168s | Avg peak memory: 31951.4 MBValidation loss: 14.375019853765314
validation loss got improved!
