In [45]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import FrozenDict
import optax
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the dataset
dataset = load_dataset("Helsinki-NLP/opus_books", "en-hu")

# Select smaller subsets of the dataset
train_dataset = dataset['train'].select(range(50000))
val_dataset = dataset['train'].select(range(50000, 60000))
test_dataset = dataset['train'].select(range(60000, 70000))

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hu")

In [46]:
# Define the model
class Transformer(nn.Module):
    vocab_size: int
    hidden_dim: int = 256
    num_heads: int = 4
    num_layers: int = 3
    max_length: int = 128

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.hidden_dim)
        self.encoder_layers = [nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.hidden_dim) for _ in range(self.num_layers)]
        self.decoder_layers = [nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.hidden_dim) for _ in range(self.num_layers)]
        self.output_layer = nn.Dense(self.vocab_size)

    def __call__(self, x, y):
        x_embed = self.embedding(x)
        y_embed = self.embedding(y)

        for layer in self.encoder_layers:
            x_embed = layer(x_embed)

        for layer in self.decoder_layers:
            y_embed = layer(y_embed)

        logits = self.output_layer(y_embed)
        return logits


In [47]:
@jax.jit
def train_step(params, opt_state, batch):
    def loss_fn(params):
        logits = model.apply({'params': params}, batch[0], batch[2])
        
        # Convert labels to one-hot encoding
        labels_one_hot = jax.nn.one_hot(batch[2], num_classes=tokenizer.vocab_size)
        
        # Compute softmax cross-entropy loss
        loss = optax.softmax_cross_entropy(logits=logits, labels=labels_one_hot)
        
        return jnp.mean(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


In [48]:
# Define pmap'd training step
@jax.jit
def mapped_train_step(params, opt_state, batch):
    params, opt_state, loss = jax.pmap(
        train_step, axis_name='batch')(params, opt_state, batch)
    loss = jax.pmean(loss, axis_name='batch')
    return params, opt_state, loss


In [49]:
# Tokenize the dataset
def preprocess_function(examples):
    inputs = [ex['en'] for ex in examples['translation']]
    targets = [ex['hu'] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_train = train_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])
tokenized_val = val_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])
tokenized_test = test_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])

# Convert to numpy arrays for JAX
def convert_to_numpy(tokenized_dataset):
    input_ids = jnp.array(tokenized_dataset["input_ids"])
    attention_mask = jnp.array(tokenized_dataset["attention_mask"])
    labels = jnp.array(tokenized_dataset["labels"])
    return input_ids, attention_mask, labels

train_input_ids, train_attention_mask, train_labels = convert_to_numpy(tokenized_train)
val_input_ids, val_attention_mask, val_labels = convert_to_numpy(tokenized_val)
test_input_ids, test_attention_mask, test_labels = convert_to_numpy(tokenized_test)


In [50]:
# Instantiate the model
model = Transformer(vocab_size=tokenizer.vocab_size)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128), dtype=jnp.int32), jnp.ones((1, 128), dtype=jnp.int32))["params"]
optimizer = optax.adam(learning_rate=0.0001)
opt_state = optimizer.init(params)


In [51]:
print("Train Input IDs shape:", train_input_ids.shape)
print("Train Attention Mask shape:", train_attention_mask.shape)
print("Train Labels shape:", train_labels.shape)

Train Input IDs shape: (50000, 128)
Train Attention Mask shape: (50000, 128)
Train Labels shape: (50000, 128)


In [52]:
def train_model(train_data, num_epochs, batch_size, initial_params, optimizer):
    opt_state = optimizer.init(initial_params)
    params = initial_params

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in train_data:
            params, opt_state, loss = train_step(params, opt_state, batch)
            epoch_loss += loss

        avg_epoch_loss = epoch_loss / len(train_data)
        print(f"Epoch {epoch+1}, Loss: {avg_epoch_loss}")


In [53]:
optimizer = optax.adam(learning_rate=1e-3)

In [54]:
import time

# Record the start time
start_time = time.time()

# Define the training parameters
batch_size = 16
train_data = (train_input_ids, train_attention_mask, train_labels)
initial_params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128), dtype=jnp.int32), jnp.ones((1, 128), dtype=jnp.int32))["params"]

# Train the model
trained_params= train_model(train_data, 50, batch_size, initial_params, optimizer)

# Record the end time
end_time = time.time()

# Calculate the training time
training_time = end_time - start_time

# Display the training time
print("Training time:", training_time, "seconds")

Epoch 1, Loss: 10.527602195739746
Epoch 2, Loss: 6.242406368255615
Epoch 3, Loss: 2.7567012310028076
Epoch 4, Loss: 2.0361435413360596
Epoch 5, Loss: 1.0315613746643066
Epoch 6, Loss: 1.1941053867340088
Epoch 7, Loss: 1.0223288536071777
Epoch 8, Loss: 1.0766180753707886
Epoch 9, Loss: 0.9996814131736755
Epoch 10, Loss: 0.8261687159538269
Epoch 11, Loss: 0.612897515296936
Epoch 12, Loss: 0.5480751395225525
Epoch 13, Loss: 0.4410651922225952
Epoch 14, Loss: 0.5039677619934082
Epoch 15, Loss: 0.4538334608078003
Epoch 16, Loss: 0.43136781454086304
Epoch 17, Loss: 0.4076271951198578
Epoch 18, Loss: 0.3754669427871704
Epoch 19, Loss: 0.3384438455104828
Epoch 20, Loss: 0.3069915771484375
Epoch 21, Loss: 0.30940699577331543
Epoch 22, Loss: 0.2740011215209961
Epoch 23, Loss: 0.2542256712913513
Epoch 24, Loss: 0.23851990699768066
Epoch 25, Loss: 0.23250965774059296
Epoch 26, Loss: 0.23146545886993408
Epoch 27, Loss: 0.22077354788780212
Epoch 28, Loss: 0.21609269082546234
Epoch 29, Loss: 0.213431