In [3]:
from datasets import load_from_disk
from transformers import (
    Wav2Vec2ForPreTraining, 
    Wav2Vec2FeatureExtractor, 
    Trainer, 
    TrainingArguments,
    EarlyStoppingCallback,
)
import torch
from typing import Dict, List, Union

# Load data and model
dataset = load_from_disk("../Data/kham_asr_preprocessed")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

Loading dataset from disk:   0%|          | 0/63 [00:00<?, ?it/s]

In [4]:
# Load model with proper configuration for pretraining
model = Wav2Vec2ForPreTraining.from_pretrained(
    "facebook/wav2vec2-base",
)

# CRITICAL: Set masking config - without this, no loss is computed
model.config.mask_time_prob = 0.065  # Probability of masking time steps
model.config.mask_time_length = 10    # Length of mask spans
model.config.mask_feature_prob = 0.0  # No feature masking
model.config.mask_feature_length = 10

# Verify config
print(f"Mask time prob: {model.config.mask_time_prob}")
print(f"Mask time length: {model.config.mask_time_length}")



Mask time prob: 0.065
Mask time length: 10


In [5]:
from dataclasses import dataclass
from typing import Dict, List, Union, Optional
import numpy as np

# Custom DataCollator that generates mask_time_indices
@dataclass
class DataCollatorForWav2Vec2Pretraining:
    model: Wav2Vec2ForPreTraining
    feature_extractor: Wav2Vec2FeatureExtractor
    padding: Union[bool, str] = True
    pad_to_multiple_of: Optional[int] = None
    
    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Extract input_values and convert to list if tensor
        input_values = []
        for feature in features:
            val = feature["input_values"]
            if isinstance(val, torch.Tensor):
                val = val.squeeze().tolist()
            input_values.append(val)
        
        # Pad using feature extractor
        batch = self.feature_extractor.pad(
            {"input_values": input_values},
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )
        
        batch_size = batch["input_values"].shape[0]
        
        # Create attention mask
        mask_indices_seq_length = self.model._get_feat_extract_output_lengths(
            batch["input_values"].shape[-1]
        )
        
        # Make sure masked sequence length is a Python scalar
        if isinstance(mask_indices_seq_length, torch.Tensor):
            mask_indices_seq_length = mask_indices_seq_length.item()
        
        # Create attention mask for the transformer (keep on CPU)
        attention_mask = torch.ones(
            (batch_size, mask_indices_seq_length), 
            dtype=torch.long,
        )
        
        # Generate mask_time_indices
        mask_time_indices = None
        if self.model.config.mask_time_prob > 0:
            mask_time_indices = self._compute_mask_indices(
                (batch_size, mask_indices_seq_length),
                mask_prob=self.model.config.mask_time_prob,
                mask_length=self.model.config.mask_time_length,
                min_masks=2,
            )
            # CRITICAL: Must be boolean tensor
            mask_time_indices = torch.from_numpy(mask_time_indices).bool()
        
        # CRITICAL: Sample negative indices for contrastive loss
        sampled_negative_indices = None
        if self.model.config.num_negatives > 0 and mask_time_indices is not None:
            sampled_negative_indices = self._sample_negative_indices(
                (batch_size, mask_indices_seq_length),
                num_negatives=self.model.config.num_negatives,
                mask_time_indices=mask_time_indices,
            )
            sampled_negative_indices = torch.from_numpy(sampled_negative_indices).long()
        
        return {
            "input_values": batch["input_values"],
            "attention_mask": attention_mask,
            "mask_time_indices": mask_time_indices,
            "sampled_negative_indices": sampled_negative_indices,
        }
    
    def _compute_mask_indices(
        self,
        shape: tuple,
        mask_prob: float,
        mask_length: int,
        min_masks: int = 0,
    ) -> np.ndarray:
        """
        Computes random mask spans for a given shape. Used for masking time steps.
        """
        batch_size, sequence_length = shape
        
        if mask_length < 1:
            raise ValueError("`mask_length` has to be bigger than 0.")
        
        if mask_length > sequence_length:
            raise ValueError(
                f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
                f" and `sequence_length`: {sequence_length}`"
            )
        
        # Compute number of masked spans
        num_masked_spans = int(mask_prob * sequence_length / mask_length + np.random.rand(1).item())
        num_masked_spans = max(num_masked_spans, min_masks)
        
        # Make sure num masked indices <= sequence_length
        if num_masked_spans * mask_length > sequence_length:
            num_masked_spans = sequence_length // mask_length
        
        # Create mask
        mask = np.zeros((batch_size, sequence_length), dtype=bool)
        
        for i in range(batch_size):
            if num_masked_spans == 0:
                continue
                
            # Get random indices to mask
            mask_indices = np.random.choice(
                sequence_length - mask_length + 1,
                num_masked_spans,
                replace=False
            )
            
            # Expand to mask_length
            mask_indices = np.concatenate([
                mask_indices + j for j in range(mask_length)
            ])
            
            mask[i, mask_indices] = True
        
        return mask
    
    def _sample_negative_indices(
        self,
        shape: tuple,
        num_negatives: int,
        mask_time_indices: torch.Tensor,
    ) -> np.ndarray:
        """
        Sample negative indices for contrastive loss.
        For each masked position, sample num_negatives random negative samples.
        """
        batch_size, sequence_length = shape
        
        # Convert mask to numpy if needed
        if isinstance(mask_time_indices, torch.Tensor):
            mask_time_indices_np = mask_time_indices.cpu().numpy()
        else:
            mask_time_indices_np = mask_time_indices
        
        # For each masked position, sample num_negatives negatives
        # Shape: (batch, sequence_length, num_negatives)
        sampled_negatives = np.zeros((batch_size, sequence_length, num_negatives), dtype=np.int32)
        
        for batch_idx in range(batch_size):
            # Get masked positions for this batch item
            masked_positions = np.where(mask_time_indices_np[batch_idx])[0]
            
            if len(masked_positions) == 0:
                continue
            
            # For each position (masked or not), sample negatives
            for seq_idx in range(sequence_length):
                # Sample num_negatives random indices
                # These should be different from seq_idx
                candidates = list(range(sequence_length))
                if seq_idx in candidates:
                    candidates.remove(seq_idx)
                
                if len(candidates) >= num_negatives:
                    neg_indices = np.random.choice(candidates, num_negatives, replace=False)
                else:
                    # If not enough candidates, sample with replacement
                    neg_indices = np.random.choice(candidates, num_negatives, replace=True)
                
                sampled_negatives[batch_idx, seq_idx] = neg_indices
        
        return sampled_negatives

# Initialize data collator
data_collator = DataCollatorForWav2Vec2Pretraining(
    model=model,
    feature_extractor=feature_extractor,
    padding=True,
    pad_to_multiple_of=None,
)

In [12]:
# Training arguments
training_args = TrainingArguments(
    output_dir="./wav2vec_pretrain",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,  # Slightly lower learning rate
    warmup_steps=500,    # Add warmup
    num_train_epochs=5,
    fp16=True,
    save_strategy='epoch',
    remove_unused_columns=False,
    dataloader_drop_last=True,
)

# Custom Trainer that manually computes the loss
class Wav2Vec2PretrainingTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # CRITICAL: Ensure model is in training mode
        if not model.training:
            model.train()
        
        # Debug: print input shapes on first call
        if not hasattr(self, '_printed_shapes'):
            print(f"Input shapes:")
            print(f"  input_values: {inputs['input_values'].shape}")
            print(f"  attention_mask: {inputs['attention_mask'].shape}")
            if inputs.get('mask_time_indices') is not None:
                print(f"  mask_time_indices: {inputs['mask_time_indices'].shape}")
                print(f"  mask_time_indices dtype: {inputs['mask_time_indices'].dtype}")
                print(f"  mask_time_indices sum: {inputs['mask_time_indices'].sum()}")
            if inputs.get('sampled_negative_indices') is not None:
                print(f"  sampled_negative_indices: {inputs['sampled_negative_indices'].shape}")
            print(f"Model training mode: {model.training}")
            print(f"Model config mask_time_prob: {model.config.mask_time_prob}")
            self._printed_shapes = True
        
        # Forward pass
        outputs = model(**inputs)
        
        # If model doesn't compute loss, compute it manually
        if outputs.loss is None:
            # Manually compute contrastive loss
            import torch.nn.functional as F
            
            # Get the outputs we need
            hidden_states = outputs.projected_states  # (batch, seq_len, hidden_size)
            quantized_states = outputs.projected_quantized_states  # (batch, seq_len, hidden_size)
            
            mask_time_indices = inputs.get('mask_time_indices')
            sampled_negative_indices = inputs.get('sampled_negative_indices')
            
            if mask_time_indices is None or sampled_negative_indices is None:
                raise ValueError("Need mask_time_indices and sampled_negative_indices to compute loss")
            
            batch_size, sequence_length, hidden_size = hidden_states.shape
            
            # Mask the hidden states - only compute loss on masked positions
            # mask_time_indices shape: (batch, seq_len)
            masked_hidden = hidden_states[mask_time_indices]  # (num_masked, hidden_size)
            masked_quantized = quantized_states[mask_time_indices]  # (num_masked, hidden_size)
            
            # Sample negatives
            # sampled_negative_indices shape: (batch, seq_len, num_negatives)
            num_negatives = sampled_negative_indices.shape[-1]
            
            # Expand quantized_states for negative sampling
            # Shape: (batch, seq_len, num_negatives, hidden_size)
            negative_quantized_vectors = quantized_states.unsqueeze(2).expand(
                batch_size, sequence_length, num_negatives, hidden_size
            )
            
            # Gather the negative samples
            # Shape: (batch, seq_len, num_negatives, hidden_size)
            sampled_negative_indices_expanded = sampled_negative_indices.unsqueeze(-1).expand(
                batch_size, sequence_length, num_negatives, hidden_size
            )
            negative_quantized = torch.gather(
                negative_quantized_vectors, 
                dim=1,
                index=sampled_negative_indices_expanded
            )
            
            # Only keep negatives for masked positions
            negative_quantized = negative_quantized[mask_time_indices]  # (num_masked, num_negatives, hidden_size)
            
            # Compute cosine similarity
            # Positive similarity (correct targets)
            positive_similarity = F.cosine_similarity(
                masked_hidden.unsqueeze(1),  # (num_masked, 1, hidden_size)
                masked_quantized.unsqueeze(1),  # (num_masked, 1, hidden_size)
                dim=-1
            )  # (num_masked, 1)
            
            # Negative similarity
            negative_similarity = F.cosine_similarity(
                masked_hidden.unsqueeze(1),  # (num_masked, 1, hidden_size)
                negative_quantized,  # (num_masked, num_negatives, hidden_size)
                dim=-1
            )  # (num_masked, num_negatives)
            
            # Concatenate positive and negative logits
            logits = torch.cat([positive_similarity, negative_similarity], dim=1)  # (num_masked, 1 + num_negatives)
            
            # Apply temperature
            temperature = getattr(model.config, 'contrastive_logits_temperature', 0.1)
            logits = logits / temperature
            
            # Target is always index 0 (the positive example)
            targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
            
            # Compute cross entropy loss
            contrastive_loss = F.cross_entropy(logits, targets)
            
            # Add diversity loss (codevector_perplexity encourages diversity)
            diversity_loss = ((num_negatives - outputs.codevector_perplexity) / num_negatives) * model.config.num_codevectors
            diversity_loss_weight = getattr(model.config, 'diversity_loss_weight', 0.1)
            
            loss = contrastive_loss + diversity_loss_weight * diversity_loss
            
            print(f"Manually computed loss: {loss.item():.4f} (contrastive: {contrastive_loss.item():.4f}, diversity: {diversity_loss.item():.4f})")
        else:
            loss = outputs.loss
        
        return (loss, outputs) if return_outputs else loss

# Use custom Trainer
trainer = Wav2Vec2PretrainingTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)



socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


In [7]:
%env WANDB_PROJECT=wav2vec

env: WANDB_PROJECT=wav2vec


In [13]:
trainer.train()

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


Input shapes:
  input_values: torch.Size([4, 126900])
  attention_mask: torch.Size([4, 396])
  mask_time_indices: torch.Size([4, 396])
  mask_time_indices dtype: torch.bool
  mask_time_indices sum: 80
  sampled_negative_indices: torch.Size([4, 396, 100])
Model training mode: True
Model config mask_time_prob: 0.065


Step,Training Loss
500,207.9489
1000,151.4967
1500,139.807
2000,134.3255
2500,128.5356
3000,123.9762
3500,120.218
4000,118.1438
4500,115.902
5000,114.0952


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

TrainOutput(global_step=10515, training_loss=121.58965200151198, metrics={'train_runtime': 36233.4201, 'train_samples_per_second': 9.283, 'train_steps_per_second': 0.29, 'total_flos': 2.4061203328424436e+19, 'train_loss': 121.58965200151198, 'epoch': 5.0})

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.


In [14]:
model.save_pretrained('five_epoch_kham_pretrained')

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
