In [1]:
from typing import Sequence

import jax.numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import tensorflow

In [2]:
from datasets import load_dataset
import keras
import tensorflow as tf
# from torch import Tensor
# import torch
# import torch.nn as nn
# from torch.nn import Transformer
import math
import os

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# DEVICE = keras.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from flax.training import checkpoints
from flax.training import common_utils
from flax.training import dynamic_scale as dynamic_scale_lib
from flax.training import train_state
import optax
class TrainState(train_state.TrainState):
  """Flax ``TrainState`` extended with one extra field, ``dynamic_scale``."""
  # NOTE(review): presumably carries the loss-scaling state for mixed-precision
  # training; nothing in the visible cells reads this field — confirm it is needed.
  dynamic_scale: dynamic_scale_lib.DynamicScale

## 1. Load data

In [3]:
# Carve the single 'train' split of the en-hu configuration into three slices:
# first 80% for training, the next 10% for validation, the last 10% for testing.
train_ds, val_ds, test_ds = load_dataset(
    path='Helsinki-NLP/opus_books',
    name='en-hu',
    split=['train[:80%]', 'train[80%:-10%]', 'train[-10%:]'],
)

In [4]:
def get_all_sentences(ds, lang):
    """Lazily yield the ``lang`` side of every translation pair in ``ds``.

    Each element of ``ds`` is indexed as ``item['translation'][lang]``.
    """
    yield from (record['translation'][lang] for record in ds)
        
# Integer ids of the special tokens. They are listed in the same order as the
# special_tokens passed to the trainer in get_tokenizer ([UNK], [PAD], [BOS], [EOS]);
# presumably the trainer assigns ids in that order — verify if the list changes.
UNK_IDX = 0
PAD_IDX = 1
BOS_IDX = 2
EOS_IDX = 3
def get_tokenizer(ds, lang, path_tokenizer=''):
    """Load a word-level tokenizer from disk, or train one on ``ds``.

    Args:
        ds: iterable of dataset rows, forwarded to ``get_all_sentences``.
        lang: language key selecting which side of each pair to train on.
        path_tokenizer: optional path to a serialized tokenizer. When it is
            empty/``None`` or does not point to an existing file, a new
            tokenizer is trained instead.

    Returns:
        A ``Tokenizer`` whose special tokens are [UNK], [PAD], [BOS], [EOS].
    """
    # Test for a falsy path first so '' or None never reaches os.path.isfile
    # (os.path.isfile(None) raises TypeError).
    if path_tokenizer and os.path.isfile(path_tokenizer):
        return Tokenizer.from_file(path_tokenizer)

    tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(
        special_tokens=["[UNK]", "[PAD]", "[BOS]", "[EOS]"],
        min_frequency=2,
    )
    tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
    # NOTE: the trained tokenizer is not written back to path_tokenizer, so it
    # is retrained on every call unless a file already exists there.
    return tokenizer

In [6]:
# One tokenizer per language, keyed by language code; both are built from the
# training split only.
tokenizer = {lang: get_tokenizer(train_ds, lang) for lang in ('en', 'hu')}

## 2. Build model

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Inputs are batch-first: ``(..., seq_len, emb_size)``.

    Attributes:
        emb_size: embedding width.
        dropout: dropout rate applied after the positions are added.
        maxlen: longest supported sequence length.
    """
    emb_size: int
    dropout: float
    maxlen: int = 5000

    def setup(self):
        den = jnp.exp(-jnp.arange(0, self.emb_size, 2) * math.log(10000) / self.emb_size)
        pos = jnp.arange(0, self.maxlen).reshape(self.maxlen, 1)
        angles = pos * den
        # JAX arrays are immutable, so fill the table with functional updates.
        table = jnp.zeros((self.maxlen, self.emb_size))
        table = table.at[:, 0::2].set(jnp.sin(angles))
        # For an odd emb_size there is one fewer cosine column than sine column.
        table = table.at[:, 1::2].set(jnp.cos(angles)[:, : self.emb_size // 2])
        self.pos_embedding = table
        # Cannot be named `dropout`: that name is already the rate field.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, token_embedding, deterministic: bool = True):
        """Return ``token_embedding`` plus positions, with optional dropout.

        Dropout is only active when ``deterministic`` is False, in which case
        a 'dropout' RNG must be supplied to ``apply``.
        """
        seq_len = token_embedding.shape[-2]
        return self.dropout_layer(
            token_embedding + self.pos_embedding[:seq_len],
            deterministic=deterministic,
        )

    def forward(self, token_embedding):
        """PyTorch-style alias of ``__call__`` (dropout disabled)."""
        return self(token_embedding)


class TokenEmbedding(nn.Module):
    """Embedding lookup scaled by ``sqrt(emb_size)``."""
    vocab_size: int
    emb_size: int

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.emb_size)

    def __call__(self, tokens):
        # Token ids are cast to int32 so float-typed id arrays are accepted too.
        token_ids = jnp.asarray(tokens).astype(jnp.int32)
        return self.embedding(token_ids) * math.sqrt(self.emb_size)

    def forward(self, tokens):
        """PyTorch-style alias of ``__call__``."""
        return self(tokens)


def _to_attend_mask(mask):
    """Turn a PyTorch-convention mask into a boolean "may attend" mask.

    A boolean input marks blocked positions with True. Any other dtype is
    treated as an additive mask where 0 keeps a position and a non-zero value
    (typically -inf) blocks it. ``None`` is passed through.
    """
    if mask is None:
        return None
    mask = jnp.asarray(mask)
    if mask.dtype == jnp.bool_:
        return jnp.logical_not(mask)
    return mask == 0


def _build_attention_mask(attn_mask=None, key_padding_mask=None):
    """Merge a (q_len, kv_len) mask and a (batch, kv_len) padding mask.

    Returns a boolean array broadcastable to (batch, heads, q_len, kv_len)
    with True where attention is allowed, or None when both inputs are None.
    """
    allowed = None
    square = _to_attend_mask(attn_mask)
    if square is not None:
        allowed = square[None, None, :, :]
    keys = _to_attend_mask(key_padding_mask)
    if keys is not None:
        keys = keys[:, None, None, :]
        allowed = keys if allowed is None else jnp.logical_and(allowed, keys)
    return allowed


class _EncoderLayer(nn.Module):
    """Self-attention + feed-forward block, each with residual and LayerNorm."""
    emb_size: int
    nhead: int
    dim_feedforward: int
    dropout: float

    @nn.compact
    def __call__(self, x, mask, deterministic):
        attended = nn.MultiHeadDotProductAttention(
            num_heads=self.nhead,
            qkv_features=self.emb_size,
            dropout_rate=self.dropout,
            deterministic=deterministic,
        )(x, x, mask=mask)
        attended = nn.Dropout(rate=self.dropout)(attended, deterministic=deterministic)
        x = nn.LayerNorm()(x + attended)

        hidden = nn.relu(nn.Dense(self.dim_feedforward)(x))
        hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=deterministic)
        hidden = nn.Dense(self.emb_size)(hidden)
        hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=deterministic)
        return nn.LayerNorm()(x + hidden)


class _DecoderLayer(nn.Module):
    """Masked self-attention, cross-attention over memory, then feed-forward."""
    emb_size: int
    nhead: int
    dim_feedforward: int
    dropout: float

    @nn.compact
    def __call__(self, x, memory, self_mask, cross_mask, deterministic):
        attended = nn.MultiHeadDotProductAttention(
            num_heads=self.nhead,
            qkv_features=self.emb_size,
            dropout_rate=self.dropout,
            deterministic=deterministic,
        )(x, x, mask=self_mask)
        attended = nn.Dropout(rate=self.dropout)(attended, deterministic=deterministic)
        x = nn.LayerNorm()(x + attended)

        crossed = nn.MultiHeadDotProductAttention(
            num_heads=self.nhead,
            qkv_features=self.emb_size,
            dropout_rate=self.dropout,
            deterministic=deterministic,
        )(x, memory, mask=cross_mask)
        crossed = nn.Dropout(rate=self.dropout)(crossed, deterministic=deterministic)
        x = nn.LayerNorm()(x + crossed)

        hidden = nn.relu(nn.Dense(self.dim_feedforward)(x))
        hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=deterministic)
        hidden = nn.Dense(self.emb_size)(hidden)
        hidden = nn.Dropout(rate=self.dropout)(hidden, deterministic=deterministic)
        return nn.LayerNorm()(x + hidden)


class _TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` encoder layers followed by a final LayerNorm."""
    num_layers: int
    emb_size: int
    nhead: int
    dim_feedforward: int
    dropout: float

    @nn.compact
    def __call__(self, x, mask=None, deterministic=True):
        for _ in range(self.num_layers):
            x = _EncoderLayer(
                self.emb_size, self.nhead, self.dim_feedforward, self.dropout
            )(x, mask, deterministic)
        return nn.LayerNorm()(x)


class _TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` decoder layers followed by a final LayerNorm."""
    num_layers: int
    emb_size: int
    nhead: int
    dim_feedforward: int
    dropout: float

    @nn.compact
    def __call__(self, x, memory, self_mask=None, cross_mask=None, deterministic=True):
        for _ in range(self.num_layers):
            x = _DecoderLayer(
                self.emb_size, self.nhead, self.dim_feedforward, self.dropout
            )(x, memory, self_mask, cross_mask, deterministic)
        return nn.LayerNorm()(x)


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer for translation, written as a Flax module.

    Token arrays are batch-first ``(batch, seq_len)``. Mask arguments keep the
    PyTorch convention: square masks are either additive floats (0 keeps,
    -inf blocks) or booleans (True blocks); padding masks are booleans of shape
    ``(batch, seq_len)`` with True at padding positions. Any mask may be None.

    Construct with the same positional order as before:
    ``Seq2SeqTransformer(num_encoder_layers, num_decoder_layers, emb_size,
    nhead, src_vocab_size, tgt_vocab_size, dim_feedforward, dropout)``.
    ``emb_size`` must be divisible by ``nhead``.
    """
    num_encoder_layers: int
    num_decoder_layers: int
    emb_size: int
    nhead: int
    src_vocab_size: int
    tgt_vocab_size: int
    dim_feedforward: int = 512
    dropout: float = 0.1

    def setup(self):
        # Sub-modules live in setup() (not @nn.compact) so that encode() and
        # decode() can reuse them through apply(..., method=...).
        self.encoder = _TransformerEncoder(
            self.num_encoder_layers, self.emb_size, self.nhead,
            self.dim_feedforward, self.dropout)
        self.decoder = _TransformerDecoder(
            self.num_decoder_layers, self.emb_size, self.nhead,
            self.dim_feedforward, self.dropout)
        self.generator = nn.Dense(self.tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(self.src_vocab_size, self.emb_size)
        self.tgt_tok_emb = TokenEmbedding(self.tgt_vocab_size, self.emb_size)
        self.positional_encoding = PositionalEncoding(
            self.emb_size, dropout=self.dropout)

    def __call__(self,
                 src: jnp.ndarray,
                 trg: jnp.ndarray,
                 src_mask: jnp.ndarray = None,
                 tgt_mask: jnp.ndarray = None,
                 src_padding_mask: jnp.ndarray = None,
                 tgt_padding_mask: jnp.ndarray = None,
                 memory_key_padding_mask: jnp.ndarray = None,
                 deterministic: bool = True):
        """Return logits of shape ``(batch, tgt_len, tgt_vocab_size)``.

        Set ``deterministic=False`` (and pass ``rngs={'dropout': key}`` to
        ``apply``) to enable dropout during training.
        """
        src_emb = self.positional_encoding(
            self.src_tok_emb(src), deterministic=deterministic)
        tgt_emb = self.positional_encoding(
            self.tgt_tok_emb(trg), deterministic=deterministic)
        memory = self.encoder(
            src_emb,
            _build_attention_mask(src_mask, src_padding_mask),
            deterministic)
        outs = self.decoder(
            tgt_emb,
            memory,
            _build_attention_mask(tgt_mask, tgt_padding_mask),
            _build_attention_mask(None, memory_key_padding_mask),
            deterministic)
        return self.generator(outs)

    def forward(self,
                src: jnp.ndarray,
                trg: jnp.ndarray,
                src_mask: jnp.ndarray,
                tgt_mask: jnp.ndarray,
                src_padding_mask: jnp.ndarray,
                tgt_padding_mask: jnp.ndarray,
                memory_key_padding_mask: jnp.ndarray,
                deterministic: bool = True):
        """PyTorch-style alias of ``__call__`` with every mask required."""
        return self(src, trg, src_mask, tgt_mask, src_padding_mask,
                    tgt_padding_mask, memory_key_padding_mask,
                    deterministic=deterministic)

    def encode(self, src: jnp.ndarray, src_mask: jnp.ndarray):
        """Run only the encoder (dropout disabled) and return its memory."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        return self.encoder(src_emb, _build_attention_mask(src_mask))

    def decode(self, tgt: jnp.ndarray, memory: jnp.ndarray, tgt_mask: jnp.ndarray):
        """Run only the decoder (dropout disabled); the generator is not applied."""
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(tgt))
        return self.decoder(tgt_emb, memory, _build_attention_mask(tgt_mask))

In [77]:
def generate_square_subsequent_mask(sz):
    """Build the additive causal mask for a sequence of length ``sz``.

    Returns a float ``(sz, sz)`` array whose entry ``[i, j]`` is ``0.0`` when
    query position ``i`` may attend to key position ``j`` (``j <= i``) and
    ``-inf`` otherwise.
    """
    # The PyTorch original used triu(...).transpose(0, 1) to get a lower
    # triangle. On a JAX/NumPy array transpose(0, 1) is the identity
    # permutation, which left the triangle upper (future visible, past hidden).
    # Build the lower triangle directly instead.
    may_attend = np.tril(np.ones((sz, sz), dtype=bool))
    return np.where(may_attend, 0.0, -np.inf)


def create_mask(src, tgt, batch_first=False):
    """Build the attention and padding masks for a source/target batch.

    Args:
        src: source token ids.
        tgt: target token ids.
        batch_first: layout of ``src``/``tgt``. ``False`` (default, the
            original behavior) means ``(seq_len, batch)``; ``True`` means
            ``(batch, seq_len)``, which is the layout ``get_ds`` produces.

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-False boolean ``(src_len, src_len)`` array,
        ``tgt_mask`` comes from ``generate_square_subsequent_mask`` and the
        padding masks are boolean ``(batch, seq_len)`` arrays that are True
        where the token equals ``PAD_IDX``.
    """
    seq_axis = 1 if batch_first else 0
    src_seq_len = src.shape[seq_axis]
    tgt_seq_len = tgt.shape[seq_axis]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = np.zeros((src_seq_len, src_seq_len)).astype(bool)

    src_padding_mask = src == PAD_IDX
    tgt_padding_mask = tgt == PAD_IDX
    if not batch_first:
        # Sequence-first inputs: transpose so the padding masks are (batch, seq_len).
        src_padding_mask = src_padding_mask.T
        tgt_padding_mask = tgt_padding_mask.T
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

## 3. Initialize parameters and model

In [7]:

SRC_VOCAB_SIZE = tokenizer['en'].get_vocab_size()
TGT_VOCAB_SIZE = tokenizer['hu'].get_vocab_size()
EMB_SIZE = 128 #512
NHEAD = 8
FFN_HID_DIM = 128 #512
BATCH_SIZE = 128
# BATCH_SIZE = 64
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

start_step = 0
rng = jax.random.key(2)
rng, init_rng = jax.random.split(rng)
# Dummy inputs are (batch, sequence length) arrays of token ids. The second
# dimension is EMB_SIZE only because get_ds pads every sentence to
# maxlen=EMB_SIZE tokens; it is not the embedding width.
input_shape = (BATCH_SIZE, EMB_SIZE)
target_shape = (BATCH_SIZE, EMB_SIZE)

# Token ids must be integers: an embedding lookup cannot index with floats.
initial_variables = jax.jit(transformer.init)(
    init_rng,
    jnp.ones(input_shape, jnp.int32),
    jnp.ones(target_shape, jnp.int32),
)

# for p in transformer.parameters():
#     if p.dim() > 1:
#         nn.init.xavier_uniform_(p)

# transformer = transformer.to(DEVICE)

# loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

## 4. Train model

In [40]:
def tensor_transform(token_ids):
    """Wrap an encoding's ids as ``[BOS_IDX, *ids, EOS_IDX]`` in one 1-D tensor.

    ``token_ids`` must expose the id list as ``token_ids.ids``.
    """
    pieces = [[BOS_IDX], token_ids.ids, [EOS_IDX]]
    return tf.concat(pieces, 0)

# def collate_fn(batch):
#     src_batch, tgt_batch = [], []
#     for sample in batch:
#         src_batch.append(tensor_transform(tokenizer['en'].encode(sample['en'].rstrip("\n"))))
#         tgt_batch.append(tensor_transform(tokenizer['hu'].encode(sample['hu'].rstrip("\n"))))

#     src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
#     tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
#     return src_batch, tgt_batch

In [49]:
def get_ds(data_raw, max_len=None):
    """Tokenize and pad every translation pair in ``data_raw``.

    Args:
        data_raw: iterable of dicts with an ``'en'`` and a ``'hu'`` sentence.
        max_len: fixed length of every output row; defaults to ``EMB_SIZE``
            (the value that was previously hard-coded).

    Returns:
        ``(src, tgt)``: two integer arrays of shape ``(num_pairs, max_len)``.
        Each row is ``BOS, tokens..., EOS`` right-padded with ``PAD_IDX``.
    """
    if max_len is None:
        max_len = EMB_SIZE

    def pad(sequences):
        # The default truncating='pre' would drop BOS from over-long sentences;
        # cut the tail instead.
        padded = keras.utils.pad_sequences(
            sequences, maxlen=max_len, padding='post', truncating='post',
            value=PAD_IDX)
        # A row whose last slot is not padding is exactly full or was cut:
        # make sure it still ends with EOS.
        padded[padded[:, -1] != PAD_IDX, -1] = EOS_IDX
        return padded

    src, tgt = [], []
    for sample in data_raw:
        src.append(tensor_transform(tokenizer['en'].encode(sample['en'].rstrip("\n"))))
        tgt.append(tensor_transform(tokenizer['hu'].encode(sample['hu'].rstrip("\n"))))

    return pad(src), pad(tgt)

In [79]:
# Slice the padded (src, tgt) arrays row by row so that each dataset element
# is one (source sentence, target sentence) pair.
dataset = tf.data.Dataset.from_tensor_slices(
    tensors=get_ds(train_ds['translation'])
)

[[    2  8077    75 ...     1     1     1]
 [    2 13466     7 ...     1     1     1]
 [    2   440     0 ...     1     1     1]
 ...
 [    2    14  8350 ...     1     1     1]
 [    2    14   417 ...     1     1     1]
 [    2     0     4 ...     1     1     1]] (109721, 128)


In [80]:
# Input pipeline: shuffle within a sliding buffer, group into mini-batches and
# prefetch so the next batch is prepared while the current one is consumed.
batch_size = 32
buffer_size = 1000  # number of elements held in the shuffle buffer
dataset = (
    dataset
    .shuffle(buffer_size)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
def train_step(
    state: train_state.TrainState,
    batch: jnp.ndarray,
    learning_rate_fn=None,
    label_smoothing=0.0,
    dropout_rng=None,
):
    """Run one teacher-forced optimisation step on a ``(inputs, targets)`` batch.

    Args:
        state: train state providing ``apply_fn``, ``params``, ``step`` and
            ``apply_gradients``.
        batch: pair ``(inputs, targets)`` of ``(batch, seq_len)`` token-id arrays.
        learning_rate_fn: optional schedule ``step -> learning rate``; when
            given, its value is reported under ``metrics['learning_rate']``.
        label_smoothing: probability mass spread uniformly over the vocabulary.
        dropout_rng: optional PRNG key. When given, ``apply_fn`` is additionally
            called with ``deterministic=False`` and ``rngs={'dropout': key}``,
            so it must accept that keyword.

    Returns:
        ``(new_state, metrics)`` where ``metrics`` holds the mean ``'loss'``
        and ``'accuracy'`` over non-padding target tokens.
    """
    inputs, targets = batch
    inputs = jnp.asarray(inputs)
    targets = jnp.asarray(targets)

    # Teacher forcing: the decoder reads tokens [0, n-1) and predicts [1, n).
    decoder_inputs = targets[:, :-1]
    labels = targets[:, 1:].astype(jnp.int32)

    src_len = inputs.shape[1]
    tgt_len = decoder_inputs.shape[1]
    # Same conventions as create_mask: an all-False boolean source mask, an
    # additive causal target mask (0 keeps, -inf blocks) and padding masks
    # that are True at PAD_IDX.
    src_mask = jnp.zeros((src_len, src_len), dtype=bool)
    tgt_mask = jnp.where(
        jnp.tril(jnp.ones((tgt_len, tgt_len), dtype=bool)), 0.0, -jnp.inf)
    src_padding_mask = inputs == PAD_IDX
    tgt_padding_mask = decoder_inputs == PAD_IDX

    step = state.step
    apply_kwargs = {}
    if dropout_rng is not None:
        # Fold in the step so each update draws a different dropout mask.
        apply_kwargs = {
            'deterministic': False,
            'rngs': {'dropout': jax.random.fold_in(dropout_rng, step)},
        }

    # Padding positions contribute neither to the loss nor to the accuracy.
    weights = (labels != PAD_IDX).astype(jnp.float32)
    normalizer = jnp.maximum(weights.sum(), 1.0)

    def loss_fn(params):
        logits = state.apply_fn(
            {'params': params},
            inputs,
            decoder_inputs,
            src_mask,
            tgt_mask,
            src_padding_mask,
            tgt_padding_mask,
            src_padding_mask,  # memory_key_padding_mask
            **apply_kwargs,
        )
        log_probs = jax.nn.log_softmax(logits.astype(jnp.float32), axis=-1)
        nll = -jnp.take_along_axis(log_probs, labels[..., None], axis=-1)[..., 0]
        # Cross-entropy against the uniform distribution over the vocabulary.
        uniform_ce = -jnp.mean(log_probs, axis=-1)
        token_loss = (1.0 - label_smoothing) * nll + label_smoothing * uniform_ce
        loss = jnp.sum(token_loss * weights) / normalizer
        return loss, logits

    gradient_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = gradient_fn(state.params)
    state = state.apply_gradients(grads=grads)

    correct = (jnp.argmax(logits, axis=-1) == labels).astype(jnp.float32)
    metrics = {
        'loss': loss,
        'accuracy': jnp.sum(correct * weights) / normalizer,
    }
    if learning_rate_fn is not None:
        metrics['learning_rate'] = learning_rate_fn(step)
    return state, metrics


# learning_rate_fn is a Python callable, not an array, so it must be a static
# argument of the jitted function.
train_step = jax.jit(train_step, static_argnames=('learning_rate_fn',))


In [None]:
def train_and_evaluate(train_dataset, eval_dataset, test_dataset, state, epochs):
    """Train for ``epochs`` epochs, validating and logging after each one.

    Args:
        train_dataset: batched ``tf.data.Dataset`` fed to ``train_step``.
        eval_dataset: batched ``tf.data.Dataset`` fed to ``eval_step``.
        test_dataset: accepted for interface compatibility; not used here.
        state: initial train state.
        epochs: number of passes over ``train_dataset``.

    Returns:
        The train state after the last epoch.

    NOTE(review): ``tqdm``, ``wandb``, ``eval_step`` and ``accumulate_metrics``
    are not defined in the visible notebook cells — confirm they are imported
    or defined before this function is called.
    """
    for epoch in tqdm(range(1, epochs + 1)):
        # ============== Training ============== #
        # Iterate over the dataset directly instead of range(cardinality):
        # the cardinality can be unknown or infinite (negative sentinels),
        # which would make range() silently skip every batch.
        train_batch_metrics = []
        for batch in train_dataset.as_numpy_iterator():
            state, metrics = train_step(state, batch)
            train_batch_metrics.append(metrics)
        train_batch_metrics = accumulate_metrics(train_batch_metrics)

        # ============== Validation ============= #
        eval_batch_metrics = []
        for batch in eval_dataset.as_numpy_iterator():
            metrics = eval_step(state, batch)
            eval_batch_metrics.append(metrics)
        eval_batch_metrics = accumulate_metrics(eval_batch_metrics)

        # Log Metrics to Weights & Biases
        wandb.log({
            "Train Loss": train_batch_metrics['loss'],
            "Train Accuracy": train_batch_metrics['accuracy'],
            "Validation Loss": eval_batch_metrics['loss'],
            "Validation Accuracy": eval_batch_metrics['accuracy']
        }, step=epoch)

    return state