In [1]:
import os
from pathlib import Path

# HPC nodes frequently ship without a system CA bundle, which breaks HTTPS for
# libcurl/TF. Fall back to the bundle that comes with the active conda environment.
conda_prefix = os.environ.get('CONDA_PREFIX')
if not conda_prefix:
    print('CONDA_PREFIX not set; leaving SSL cert settings unchanged.')
else:
    ca_bundle = Path(conda_prefix, 'ssl', 'cacert.pem')
    if not ca_bundle.exists():
        print('Conda CA bundle not found at:', ca_bundle)
    else:
        # setdefault leaves any value that is already exported untouched.
        os.environ.setdefault('SSL_CERT_FILE', str(ca_bundle))
        os.environ.setdefault('CURL_CA_BUNDLE', str(ca_bundle))
        os.environ.setdefault('REQUESTS_CA_BUNDLE', str(ca_bundle))
        print('Using CA bundle:', ca_bundle)

# Use the local Orbax (OCDBT) checkpoint you already downloaded/extracted.
# Set the T5GEMMA_CKPT_DIR environment variable to point somewhere else; when it is
# unset (or empty) the original default location below is used.
CKPT_DIR = Path(
    os.environ.get('T5GEMMA_CKPT_DIR')
    or '/home/chojnowski.h/weishao/chojnowski.h/JaxFM/t5gemma'
)
# Raise explicitly instead of using `assert` (assertions are stripped under
# `python -O`), and require a directory rather than any existing path.
if not CKPT_DIR.is_dir():
    raise FileNotFoundError(f'Checkpoint folder not found: {CKPT_DIR}')
print('Using checkpoint:', CKPT_DIR)

Using CA bundle: /blue/weishao/chojnowski.h/.conda/envs/ml/ssl/cacert.pem
Using checkpoint: /home/chojnowski.h/weishao/chojnowski.h/JaxFM/t5gemma


In [2]:
from gemma import gm
from gemma.research import t5gemma

# Build the matching model skeleton from the preset, then load weights from the local Orbax checkpoint.
# `preset` stays a module-level name because later cells use `preset.tokenizer`.
preset = t5gemma.T5GemmaPreset.GEMMA2_XL_XL
# NOTE(review): presumably this only builds the module definition (no weights) — confirm against the gemma docs.
t5gemma_model = preset.config.make('transformer')

# IMPORTANT: avoid preset.get_checkpoint_from_kaggle(...) (403). Load from disk instead.
# CKPT_DIR is defined and checked in the first cell, so that cell must run before this one.
t5gemma_params = gm.ckpts.load_params(CKPT_DIR)

In [4]:
import jax.numpy as jnp

def input_token_ids(text: str, *, max_input_length: int = 256) -> jnp.ndarray:
    """Tokenize `text` and return ONLY its unpadded token IDs.

    Args:
        text: Input string, tokenized with `preset.tokenizer`.
        max_input_length: Maximum number of tokens to keep; longer inputs are
            truncated on the right. Must be non-negative.

    Returns:
        An int32 array holding at most `max_input_length` token IDs.

    Raises:
        ValueError: If `max_input_length` is negative.
    """
    if max_input_length < 0:
        # A negative bound would make the slice below drop tokens from the END
        # of the sequence instead of capping its length.
        raise ValueError(f'max_input_length must be >= 0, got {max_input_length}')
    token_ids = preset.tokenizer.encode(text)[:max_input_length]
    return jnp.asarray(token_ids, dtype=jnp.int32)

def encode_with_t5gemma_encoder(
    text: str, *, max_input_length: int = 256
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Run ONLY the encoder on `text`.

    The text is tokenized, truncated to `max_input_length`, right-padded to
    exactly `max_input_length` and passed through the encoder as a batch of one.
    The padded positions are stripped from the result again.

    Args:
        text: Input string, tokenized with `preset.tokenizer`.
        max_input_length: Padded sequence length fed to the encoder; longer
            inputs are truncated on the right. Must be non-negative.

    Returns:
        A `(token_ids, encoder_last_hidden_unpadded)` tuple: the unpadded int32
        token IDs, and the last encoder activation restricted to those
        positions (one row per returned token ID).

    Raises:
        ValueError: If `max_input_length` is negative.
    """
    if max_input_length < 0:
        # A negative bound would make the slice below drop tokens from the END
        # of the sequence and skip padding altogether.
        raise ValueError(f'max_input_length must be >= 0, got {max_input_length}')

    token_ids = preset.tokenizer.encode(text)[:max_input_length]
    num_tokens = len(token_ids)

    pad_id = 0  # T5Gemma uses 0 as padding id internally
    padded = token_ids + [pad_id] * (max_input_length - num_tokens)
    input_tokens = jnp.asarray([padded], dtype=jnp.int32)
    # Derive the mask from the token count rather than from `!= pad_id`: a real
    # token whose ID equals the pad ID would otherwise be masked out, and the
    # returned features would no longer line up with the returned token IDs.
    inputs_mask = (jnp.arange(max_input_length) < num_tokens)[None, :]

    encoder_acts = t5gemma_model.apply(
        {'params': t5gemma_params},
        tokens=input_tokens,
        inputs_mask=inputs_mask,
        method=t5gemma_model.compute_encoder_activations,
    )
    encoder_last_hidden = encoder_acts.activations[-1]  # [1, L, d_model]
    # Padding sits on the right, so the valid positions are the first `num_tokens`.
    encoder_last_hidden_unpadded = encoder_last_hidden[0, :num_tokens]  # [T, d_model]

    return jnp.asarray(token_ids, dtype=jnp.int32), encoder_last_hidden_unpadded

# Example: tokens corresponding to the input text (no padding)
text = "A picture of a sunlit sky"
tok = input_token_ids(text, max_input_length=128)
print('token_ids shape:', tok.shape)
print('token_ids:', tok.tolist())

# If you also want encoder features for just those tokens:
tok2, enc_unpadded = encode_with_t5gemma_encoder(text, max_input_length=128)
# Sanity checks: both helpers must agree on the tokenization, and the encoder
# must return exactly one feature row per (unpadded) token.
assert tok2.tolist() == tok.tolist(), 'encoder path tokenized the text differently'
assert enc_unpadded.shape[0] == tok.shape[0], 'expected one feature row per token'
print('encoder features shape:', enc_unpadded.shape)

token_ids shape: (7,)
token_ids: [235280, 5642, 576, 476, 4389, 3087, 8203]
encoder features shape: (7, 2048)
