In [1]:
!pip install jax jaxlib flax transformers datasets

Collecting ml-dtypes>=0.2.0 (from jax)
  Downloading ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)
Downloading ml_dtypes-0.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m21.6 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: ml-dtypes
  Attempting uninstall: ml-dtypes
    Found existing installation: ml-dtypes 0.2.0
    Uninstalling ml-dtypes-0.2.0:
      Successfully uninstalled ml-dtypes-0.2.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow-decision-forests 1.8.1 requires wurlitzer, which is not installed.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.
tensorflow 2.15.0 requires ml-dtypes~=0.2.0, but you have m

In [2]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core import FrozenDict
import optax
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the dataset
dataset = load_dataset("Helsinki-NLP/opus_books", "en-hu")

# Select smaller subsets of the dataset
train_dataset = dataset['train'].select(range(50000))
val_dataset = dataset['train'].select(range(50000, 60000))
test_dataset = dataset['train'].select(range(60000, 70000))

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-hu")

# Preprocess function
def preprocess_function(examples):
    inputs = [ex['en'] for ex in examples['translation']]
    targets = [ex['hu'] for ex in examples['translation']]
    model_inputs = tokenizer(inputs, max_length=128, truncation=True, padding="max_length")
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length")
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize the dataset
tokenized_train = train_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])
tokenized_val = val_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])
tokenized_test = test_dataset.map(preprocess_function, batched=True, remove_columns=["translation"])

# Convert to numpy arrays for JAX
def convert_to_numpy(tokenized_dataset):
    input_ids = jnp.array(tokenized_dataset["input_ids"])
    attention_mask = jnp.array(tokenized_dataset["attention_mask"])
    labels = jnp.array(tokenized_dataset["labels"])
    return input_ids, attention_mask, labels

train_input_ids, train_attention_mask, train_labels = convert_to_numpy(tokenized_train)
val_input_ids, val_attention_mask, val_labels = convert_to_numpy(tokenized_val)
test_input_ids, test_attention_mask, test_labels = convert_to_numpy(tokenized_test)


Downloading readme:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/137151 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/792k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/850k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.57M [00:00<?, ?B/s]



Map:   0%|          | 0/50000 [00:00<?, ? examples/s]



Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [5]:


class Transformer(nn.Module):
    vocab_size: int
    hidden_dim: int = 256
    num_heads: int = 4
    num_layers: int = 3
    max_length: int = 128

    def setup(self):
        self.embedding = nn.Embed(self.vocab_size, self.hidden_dim)
        self.encoder_layers = [nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.hidden_dim) for _ in range(self.num_layers)]
        self.decoder_layers = [nn.SelfAttention(num_heads=self.num_heads, qkv_features=self.hidden_dim) for _ in range(self.num_layers)]
        self.output_layer = nn.Dense(self.vocab_size)

    def __call__(self, x, y):
        x_embed = self.embedding(x)
        y_embed = self.embedding(y)

        for layer in self.encoder_layers:
            x_embed = layer(x_embed)

        for layer in self.decoder_layers:
            y_embed = layer(y_embed)

        logits = self.output_layer(y_embed)
        return logits


In [6]:
model = Transformer(vocab_size=tokenizer.vocab_size)
print("Model Information:")
print("Vocabulary Size:", model.vocab_size)
print("Hidden Dimension:", model.hidden_dim)
print("Number of Heads:", model.num_heads)
print("Number of Layers:", model.num_layers)
print("Max Length:", model.max_length)


Model Information:
Vocabulary Size: 62522
Hidden Dimension: 256
Number of Heads: 4
Number of Layers: 3
Max Length: 128


In [8]:
model = Transformer(vocab_size=tokenizer.vocab_size)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 128), dtype=jnp.int32), jnp.ones((1, 128), dtype=jnp.int32))["params"]
optimizer = optax.adam(learning_rate=0.0001)
opt_state = optimizer.init(params)

In [9]:
# Cross entropy loss function
def cross_entropy_loss(logits, labels):
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return jnp.mean(loss)

In [10]:
# Training step
@jax.jit
def train_step(params, opt_state, batch):
    def loss_fn(params):
        logits = model.apply({'params': params}, batch[0], batch[2])
        loss = cross_entropy_loss(logits, batch[2])
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

In [11]:
import time


start_time = time.time()

for epoch in range(50):
    for _ in range(5000 // 16):
        batch = (train_input_ids[_:_+16], train_attention_mask[_:_+16], train_labels[_:_+16])
        params, opt_state, loss = train_step(params, opt_state, batch)
    print(f'Epoch {epoch+1} - Loss: {loss}')
    
training_time = time.time() - start_time
print("Training Time:", training_time, "seconds")

Epoch 1 - Loss: 2.083096742630005
Epoch 2 - Loss: 1.9809261560440063
Epoch 3 - Loss: 1.9450907707214355
Epoch 4 - Loss: 1.9234602451324463
Epoch 5 - Loss: 1.910798192024231
Epoch 6 - Loss: 1.9001022577285767
Epoch 7 - Loss: 1.884229302406311
Epoch 8 - Loss: 1.86797297000885
Epoch 9 - Loss: 1.1999534368515015
Epoch 10 - Loss: 1.0686688423156738
Epoch 11 - Loss: 0.985762357711792
Epoch 12 - Loss: 0.9380471110343933
Epoch 13 - Loss: 0.9045101404190063
Epoch 14 - Loss: 0.8698063492774963
Epoch 15 - Loss: 0.8451318740844727
Epoch 16 - Loss: 0.8336489200592041
Epoch 17 - Loss: 0.8203508257865906
Epoch 18 - Loss: 0.8065019249916077
Epoch 19 - Loss: 0.7935856580734253
Epoch 20 - Loss: 0.8446008563041687
Epoch 21 - Loss: 0.7896855473518372
Epoch 22 - Loss: 0.7707398533821106
Epoch 23 - Loss: 0.7530943751335144
Epoch 24 - Loss: 0.7365652918815613
Epoch 25 - Loss: 0.716712474822998
Epoch 26 - Loss: 0.6967530250549316
Epoch 27 - Loss: 0.6868158578872681
Epoch 28 - Loss: 0.6815652251243591
Epoch 29

In [None]:
def translate(params, input_text):
    # Tokenize the input text
    inputs = tokenizer(input_text, return_tensors="jax", padding="max_length", truncation=True, max_length=128)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]
    print(f"Input IDs: {input_ids}")
    print(f"Attention Mask: {attention_mask}")

    # Initialize the output sequence with BOS token
    output_ids = jnp.zeros((1, 128), dtype=jnp.int32)
    output_ids = output_ids.at[0, 0].set(tokenizer.bos_token_id)

    generated_tokens = [tokenizer.bos_token_id]

    # Auto-regressive generation
    for i in range(1, 128):
        logits = model.apply({'params': params}, input_ids, output_ids)
        next_token = jnp.argmax(logits[:, i-1], axis=-1).item()  # Convert to scalar integer
        output_ids = output_ids.at[:, i].set(next_token)
        generated_tokens.append(next_token)

        if next_token == tokenizer.eos_token_id:
            break
            
        print(f"Step {i}: Logits shape: {logits.shape}")
        print(f"Step {i}: Next Token: {next_token}")
        print(f"Step {i}: Updated Output IDs: {output_ids}")
        
    # Filter out None values from generated_tokens
    generated_tokens = [token for token in generated_tokens if token is not None]
    # Decode the output token IDs to text
    output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(f"Generated Tokens: {generated_tokens}")
    print(f"Output Text: {output_text}")

    return output_text


# Example usage
translated_text = translate(params, "Source: Project GutenbergAudiobook available here")
print(f"Translated Text: {translated_text}")