In [None]:
!pip install git+https://github.com/obbwins/whisper-jax.git@fix-jax-compatibility
!pip install librosa soundfile
!pip install accelerate==0.31.0

In [None]:
import jax.numpy as jnp
from datasets import load_dataset
from jax import device_get, jit
from transformers import WhisperProcessor

from whisper_jax import FlaxWhisperForConditionalGeneration

# Load the Whisper processor and the Flax Whisper model from the same checkpoint.
# With `_do_init=False`, from_pretrained returns the model and its weights
# separately (hence the tuple unpacking). `params` must then be passed
# explicitly to every model call.
WHISPER_CHECKPOINT = "openai/whisper-base"
processor = WhisperProcessor.from_pretrained(WHISPER_CHECKPOINT)
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
    WHISPER_CHECKPOINT,
    _do_init=False,
)

def generate_fn(input_features, params):
    """Generate token ids for a batch of Whisper input features.

    Runs `model.generate` in "transcribe" mode, without timestamps, capped at
    the model config's `max_length`.

    Args:
        input_features: Pre-processed features as returned by the Whisper processor.
        params: Model weights. The model was loaded with `_do_init=False`, so
            they must be supplied on every call.

    Returns:
        The `.sequences` field of the generation output (generated token ids).

    NOTE(review): `model` is read from the global scope, and later cells
    rebind that name to other models. Trace/compile this before those run.
    """
    generation = model.generate(
        input_features,
        params=params,
        task="transcribe",
        return_timestamps=False,
        max_length=model.config.max_length,
    )
    return generation.sequences

# Wrap generate_fn with jax.jit; compilation happens on its first call below.
# NOTE(review): generate_fn reads the global `model`, which later cells rebind
# to a different model. Any retrace of p_generate after that would pick it up.
p_generate = jit(generate_fn)

# Load one dummy sample from the LibriSpeech test dataset.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]  # the "array" and "sampling_rate" keys are read below

# Pre-process: convert the raw audio array into log-mel input features (NumPy arrays).
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="np").input_features

# Run only the encoder. The model was loaded with _do_init=False, so the
# weights in `params` must always be passed explicitly to model.encode.
output = model.encode(input_features=input_features, params=params)
last_hidden_state = output.last_hidden_state  # used as the AudioAdapter input in a later cell

print(last_hidden_state)
print(last_hidden_state.shape)

# Full encoder-decoder generation (JIT-compiled the first time it is called).
pred_ids = p_generate(input_features, params)

# Post-process: convert token ids to text, dropping special tokens.
transcription = processor.batch_decode(pred_ids, skip_special_tokens=True)
print(transcription)

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn

class AudioAdapter(nn.Module):
    """
    Audio adapter that temporally downsamples audio-encoder embeddings with a
    single strided 1D convolution.

    The kernel size and stride are both ``downsample_factor`` (4 by default).
    Each output frame is therefore computed from a non-overlapping window of
    ``downsample_factor`` consecutive input frames, giving a 4x temporal
    downsampling with the default.

    Attributes:
        embedding_dim: Number of output channels (features) of the convolution.
        downsample_factor: Kernel size and stride along the time axis.
            Defaults to 4, matching the previous hard-coded behaviour.
    """
    embedding_dim: int
    downsample_factor: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Performs the forward pass of the adapter.

        Args:
            x (jnp.ndarray): Output of the audio encoder in channel-last layout,
                             assumed (batch_size, sequence_length, channels).

        Returns:
            jnp.ndarray: Downsampled tensor of shape
                         (batch_size, ceil(sequence_length / downsample_factor), embedding_dim).
                         The length is rounded up because flax.linen.Conv
                         defaults to 'SAME' padding.
        """
        # flax.linen.Conv expects channel-last (N, L, C) input. The Whisper
        # encoder output fed in by this notebook is already in that layout
        # (expected (1, 1500, 512) for one whisper-base clip — check the
        # printed shape), so no transpose is needed.
        downsampler = nn.Conv(
            features=self.embedding_dim,
            kernel_size=(self.downsample_factor,),
            strides=(self.downsample_factor,),
        )
        return downsampler(x)

# Example usage: run the adapter on the Whisper encoder output computed earlier.
input_tensor = last_hidden_state

# Fixed PRNG key so the adapter's random initialisation is reproducible.
key = jax.random.PRNGKey(0)

# Keep the channel size unchanged: adapter output features = input features.
embedding_dim = input_tensor.shape[2]
adapter = AudioAdapter(embedding_dim=embedding_dim)

# Named `adapter_params` (not `params`) so the Whisper weights held in
# `params`, which p_generate / model.encode still need, are not overwritten.
adapter_params = adapter.init(key, input_tensor)['params']

# Perform a forward pass
output_tensor = adapter.apply({'params': adapter_params}, input_tensor)

# Sanity check: with kernel = stride = 4 and flax's default 'SAME' padding,
# the time axis shrinks to ceil(L / 4) and the channel count is preserved.
expected_len = -(-input_tensor.shape[1] // 4)
assert output_tensor.shape == (input_tensor.shape[0], expected_len, embedding_dim), output_tensor.shape

print(f"Original input tensor shape: {input_tensor.shape}")
print(f"Downsampled output tensor shape: {output_tensor.shape}")

In [None]:
# Explore loading a Mistral-family checkpoint and inspecting its module tree.
# NOTE(review): despite the jax imports, MistralForCausalLM below is the
# PyTorch model class, not a Flax/JAX model.
import jax
import jax.numpy as jnp
from flax import nnx
from huggingface_hub import HfApi
from transformers import AutoProcessor, AutoModelForCausalLM, MistralForCausalLM


# Step 1: Initialize the processor and model.
# Alternative checkpoints that have been tried are kept commented out.
# model_name = "mistralai/Mistral-Small-24B-Base-2501"
model_name = "mistralai/Voxtral-Mini-3B-2507"
# model_name = "mistralai/Magistral-Small-2506"
# model_name = "ministral/Ministral-3b-instruct"
processor = AutoProcessor.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, token = token)
# NOTE(review): Voxtral is an audio+text checkpoint. Loading it through the
# text-only MistralForCausalLM class may not match its config or load its
# audio-encoder weights; transformers documents VoxtralForConditionalGeneration
# for these checkpoints — confirm which class is intended.
model = MistralForCausalLM.from_pretrained(model_name)

# Display the loaded model.
nnx.display(model)

# NOTE(review): the plan below is unverified. The keyword arguments
# (`audios=`, `audio_embeddings=`) and the "JAX model" wording have not been
# checked against the Voxtral processor/model API; confirm before enabling.
# # Step 2: Define your input embeddings
# # The user-provided audio embeddings of shape (1, 375, 512).
# # The processor expects a list of embeddings.
# audio_embeddings = output_tensor

# # The user-provided text embeddings for the question.
# text_question = "What is the primary topic of the audio?"

# # Step 3: Use the processor to create a single input format
# # The processor combines the audio and text inputs into a format the model can use.
# inputs = processor(
#     audios=audio_embeddings,
#     text=text_question,
#     return_tensors="jax"
# )

# # Step 4: Generate a text response using the JAX model
# # The model will take the combined inputs and generate a text output.
# # You can customize generation parameters like max_new_tokens.
# outputs = model.generate(
#     input_ids=inputs.input_ids,
#     attention_mask=inputs.attention_mask,
#     audio_embeddings=inputs.audio_embeddings,
#     max_new_tokens=50
# )

# # Step 5: Decode the output to get the final text response
# response = processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]

# print("Generated Response:")
# print(response)

In [None]:
# Load a Voxtral checkpoint in PyTorch (float16, devices placed automatically)
# and display it.
from flax import nnx  # imported here so this cell does not depend on an earlier cell
from transformers import AutoProcessor, AutoModelForCausalLM
import torch

# Choose the model size you want to use
model_name = "mistralai/Voxtral-Mini-3B-2507" # Or use "mistralai/Voxtral-Small-24B-2507"

# Load the processor and model.
# NOTE(review): confirm AutoModelForCausalLM resolves Voxtral checkpoints in the
# installed transformers version. The Voxtral docs use
# VoxtralForConditionalGeneration. device_map="auto" relies on `accelerate`.
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

nnx.display(model)

In [None]:
# NOTE(review): this cell repeats the earlier MistralForCausalLM loading cell
# (without the commented-out generation plan). Consider keeping only one copy.
import jax
import jax.numpy as jnp
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    FlaxAutoModelForCausalLM,
    MistralForCausalLM,
)
from flax import nnx

# Step 1: load the processor and the model for the selected checkpoint.
# Other checkpoints that have been tried here:
#   "mistralai/Mistral-Small-24B-Base-2501"
#   "mistralai/Magistral-Small-2506"
#   "ministral/Ministral-3b-instruct"
# Token-authenticated variant:
#   AutoModelForCausalLM.from_pretrained(model_name, token=token)
model_name = "mistralai/Voxtral-Mini-3B-2507"
processor = AutoProcessor.from_pretrained(model_name)
model = MistralForCausalLM.from_pretrained(model_name)

# Display the loaded (PyTorch) model.
nnx.display(model)

In [None]:
# NOTE(review): `text` is not defined in any visible cell, so this raises
# NameError after Restart Kernel -> Run All. TODO: define `text` here, or
# delete this cell. (Perhaps `transcription` from the Whisper cell was
# intended — confirm.)
print(text)