In [1]:
import torch
import torch.nn as nn
import librosa
import soundfile as sf
from transformers import AutoProcessor, AutoModel, BitsAndBytesConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ---- 4-bit quantization (smallest VRAM) ----
bnb_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # bf16 compute for stability/speed
)


# Some envs prefer explicit dtype even when quantized
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

In [3]:
# 1) Load Gemma 3n (pick an effective size that fits your hardware)
MODEL_ID = "google/gemma-3n-E2B"   # also try: "google/gemma-3n-E4B"
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = AutoProcessor.from_pretrained(MODEL_ID)
# model = AutoModel.from_pretrained(MODEL_ID).to(device)
# model.eval()


print("Loading quantized model… (this can take a moment)")
model = AutoModel.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_4bit,
    torch_dtype=dtype,
    device_map="auto",
    # attn_implementation=attn_impl,  # uncomment if flash‑attn installed
)
model.eval()
print("Loaded.")

Loading quantized model… (this can take a moment)


Loading checkpoint shards: 100%|██████████| 3/3 [00:01<00:00,  1.97it/s]
Some weights of Gemma3nModel were not initialized from the model checkpoint at google/gemma-3n-E2B and are newly initialized: ['audio_tower.conformer.0.attention.attn.k_proj.weight', 'audio_tower.conformer.0.attention.attn.per_dim_scale', 'audio_tower.conformer.0.attention.attn.q_proj.weight', 'audio_tower.conformer.0.attention.attn.relative_position_embedding.pos_proj.weight', 'audio_tower.conformer.0.attention.attn.v_proj.weight', 'audio_tower.conformer.0.attention.post.weight', 'audio_tower.conformer.0.attention.post_norm.weight', 'audio_tower.conformer.0.attention.pre_attn_norm.weight', 'audio_tower.conformer.0.ffw_layer_end.ffw_layer_1.weight', 'audio_tower.conformer.0.ffw_layer_end.ffw_layer_2.weight', 'audio_tower.conformer.0.ffw_layer_end.post_layer_norm.weight', 'audio_tower.conformer.0.ffw_layer_end.pre_layer_norm.weight', 'audio_tower.conformer.0.ffw_layer_start.ffw_layer_1.weight', 'audio_tower.conform

Loaded.


In [4]:

# 2) Load an audio file (mono float32). Gemma 3n expects ~16 kHz; we’ll resample if needed.
# Replace "example.wav" with your file path, or load arrays directly.

path = '/mnt/c/Users/user/Desktop/Roshidat/Workspace/PD_prediction/data/1_data/HC_AH/AH_064F_7AB034C9-72E4-438B-A9B3-AD7FDA1596C5.wav'
wav, sr = sf.read(path)
if wav.ndim > 1:
    wav = librosa.to_mono(wav.T)
target_sr = getattr(processor.feature_extractor, "sampling_rate", 16000)
if sr != target_sr:
    wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    sr = target_sr

In [None]:
import numpy as np
print('wav type:', type(wav), 'shape:', getattr(wav, 'shape', None))
print('sr:', sr)
if wav is None:
    raise ValueError('wav is None')
if not isinstance(wav, np.ndarray):
    wav = np.asarray(wav, dtype=np.float32)
if wav.ndim > 1:
    wav = librosa.to_mono(wav.T)
if wav.dtype != np.float32:
    wav = wav.astype(np.float32)
print('wav after checks:', type(wav), wav.shape, wav.dtype)
inputs = processor(audio=wav, sampling_rate=sr, return_tensors="pt")
if inputs is None:
    raise ValueError('processor returned None. Check audio format and processor compatibility.')
inputs = {k: v.to(device) for k, v in inputs.items()}

wav type: <class 'numpy.ndarray'> shape: (59822,)
sr: 16000
wav after checks: <class 'numpy.ndarray'> (59822,) float32


TypeError: 'NoneType' object is not subscriptable

In [None]:
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
print(outputs)

In [6]:
# # 3) Prepare inputs for the audio encoder
# # AutoProcessor will produce `input_features` (Mel-spec frames) and a mask.
# inputs = processor(audio=wav, sampling_rate=sr, return_tensors="pt")
# inputs = {k: v.to(device) for k, v in inputs.items()}

# # 4) Forward pass to get audio hidden states
# # Set output_hidden_states=True so we can choose which representation to pool.
# with torch.no_grad():
#     outputs = model(**inputs, output_hidden_states=True)

TypeError: 'NoneType' object is not subscriptable

In [5]:


# 5) Grab the last audio-layer features and mean-pool over the time dimension
# `audio_hidden_states` has shape: (batch, seq_len, hidden_size)
audio_hs = outputs.audio_hidden_states  # tuple of layer outputs
last_audio = audio_hs[-1]               # (B, T, D)

# If a frame mask was returned, use it for masked pooling (robust to padding)
mask = inputs.get("input_features_mask")  # (B, T)
if mask is not None:
    mask = mask.to(last_audio.dtype).unsqueeze(-1)  # (B, T, 1)
    summed = (last_audio * mask).sum(dim=1)         # (B, D)
    lengths = mask.sum(dim=1).clamp(min=1e-6)       # (B, 1)
    pooled = summed / lengths                        # (B, D)
else:
    pooled = last_audio.mean(dim=1)                 # (B, D)

# `pooled` is your fixed-size embedding for downstream tasks.
print("Embedding shape:", pooled.shape)

TypeError: 'NoneType' object is not subscriptable

In [None]:


# 6) (Optional) Tiny classifier head example
num_classes = 8  # change for your task
clf = nn.Linear(pooled.shape[-1], num_classes).to(device)

# Dummy forward (replace with your dataloader + training loop)
logits = clf(pooled)            # (B, num_classes)
probs = logits.softmax(dim=-1)  # class probabilities
print("Probs:", probs)
