<a href="https://colab.research.google.com/github/monoramasn/Speech_fairness/blob/main/adapter_wishper_new.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install -U accelerate
! pip install -U transformers
!pip install datasets
!pip install evaluate
!pip install jiwer



In [None]:
#pip install openai-whisper

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import evaluate
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from typing import Any, Dict, List, Union
from datasets import DatasetDict, Audio, load_from_disk, concatenate_datasets
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
from datasets import load_dataset, load_metric
voxpopuli_dataset = load_dataset("facebook/voxpopuli", "lt")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [None]:
#dataset_lt = voxpopuli_dataset.remove_columns(['audio_id', 'language', 'raw_text', 'gender', 'speaker_id', 'is_gold_transcript', 'accent'])
dataset_lt = voxpopuli_dataset.remove_columns(['audio_id', 'language', 'gender', 'raw_text', 'speaker_id', 'is_gold_transcript', 'accent'])

In [None]:
dataset_lt

DatasetDict({
    train: Dataset({
        features: ['audio', 'normalized_text'],
        num_rows: 456
    })
    validation: Dataset({
        features: ['audio', 'normalized_text'],
        num_rows: 3
    })
    test: Dataset({
        features: ['audio', 'normalized_text'],
        num_rows: 42
    })
})

In [None]:
gradient_checkpointing = True
freeze_feature_encoder = False
freeze_encoder = False

do_normalize_eval = True
do_lower_case = False
do_remove_punctuation = False
normalizer = BasicTextNormalizer()

In [None]:
model_checkpoint= "openai/whisper-base"
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_checkpoint)
tokenizer = WhisperTokenizer.from_pretrained(model_checkpoint, language="Lithuanian", task="transcribe")
processor = WhisperProcessor.from_pretrained(model_checkpoint, language="Lithuanian", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(model_checkpoint)

if model.config.decoder_start_token_id is None:
    raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

if freeze_feature_encoder:
    model.freeze_feature_encoder()

if freeze_encoder:
    model.freeze_encoder()
    model.model.encoder.gradient_checkpointing = False


model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

if gradient_checkpointing:
    model.config.use_cache = False

In [None]:
model_checkpoint_name = model_checkpoint.split("/")[-1]
repo_name = f"{model_checkpoint_name}-demo-colab"

In [None]:
def prepare_dataset(batch):
    # Load and resample audio data to the expected sampling rate
    audio = batch["audio"]
    input_features = processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    #input_features = input_features.reshape(-1, 80, 3000)

    # Ensure the last dimension of input_features is 3000
    if input_features.shape[-1] < 3000:
        padding = torch.zeros(3000 - input_features.shape[-1])
        input_features = torch.cat([input_features, padding], dim=0)

    batch["input_features"] = input_features

    # Compute input length of audio sample in seconds
    batch["input_length"] = len(audio["array"]) / audio["sampling_rate"]

    #gender_mapping = {'male': 0, 'female': 1}
    #batch["gender_labels"] = gender_mapping[batch["gender"].lower()]

    # Optional pre-processing steps
    transcription = batch["normalized_text"]
    if do_lower_case:
        transcription = transcription.lower()
    if do_remove_punctuation:
        transcription = normalizer(transcription).strip()

    # Encode target text to label ids
    batch["labels"] = processor.tokenizer(transcription, padding="max_length", max_length=max_label_length).input_ids

    return batch

max_label_length = model.config.max_length
min_input_length = 0.0
max_input_length = 30.0
def is_in_length_range(length, labels):
    return min_input_length < length < max_input_length and 0 < len(labels) < max_label_length

In [None]:
# Apply preprocessing and ensure 'labels' key is added
dataset_lt1 = dataset_lt.map(prepare_dataset, batch_size=8)

Map:   0%|          | 0/3 [00:00<?, ? examples/s]

In [None]:
dataset_lt1

DatasetDict({
    train: Dataset({
        features: ['audio', 'normalized_text', 'input_features', 'input_length', 'labels'],
        num_rows: 456
    })
    validation: Dataset({
        features: ['audio', 'normalized_text', 'input_features', 'input_length', 'labels'],
        num_rows: 3
    })
    test: Dataset({
        features: ['audio', 'normalized_text', 'input_features', 'input_length', 'labels'],
        num_rows: 42
    })
})

In [None]:
train_dataset=dataset_lt1["train"]
val_dataset=dataset_lt1["test"]

In [None]:
train_dataset

Dataset({
    features: ['audio', 'normalized_text', 'input_features', 'input_length', 'labels'],
    num_rows: 456
})

In [None]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        #gender_labels = [feature["gender_labels"] for feature in features]
        #batch["gender_labels"] = torch.tensor(gender_labels)

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [None]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
print('DATASET PREPARATION COMPLETED')

DATASET PREPARATION COMPLETED


In [None]:
data_loader = DataLoader(train_dataset, batch_size=8, collate_fn=data_collator, shuffle=True)

In [None]:
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7fd427e04b80>

In [None]:
import numpy as np
wer_metric = load_metric("wer")
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


adapter-whisper

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn
import torch.nn.functional as F

# Attention mechanism
class WhisperAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, k, v, q):
        k = self.k_proj(k)
        v = self.v_proj(v)
        q = self.q_proj(q)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32))
        attn_probs = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_probs, v)
        return self.out_proj(context)

# Encoder layer
class WhisperEncoderLayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.self_attn = WhisperAttention(d_model)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.activation_fn = nn.GELU()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        attn_output = self.self_attn(x, x, x)
        x = self.self_attn_layer_norm(x + attn_output)

        fc_output = self.fc2(self.activation_fn(self.fc1(x)))
        x = self.final_layer_norm(x + fc_output)
        return x

# Encoder
class WhisperEncoder(nn.Module):
    def __init__(self, input_dim, d_model, n_layers, d_ff, max_len):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, d_model, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.embed_positions = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([WhisperEncoderLayer(d_model, d_ff) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.transpose(1, 2)
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        position_embeddings = self.embed_positions(position_ids)
        x += position_embeddings

        for layer in self.layers:
            x = layer(x)

        x = self.layer_norm(x)
        return x

# Decoder layer
class WhisperDecoderLayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.self_attn = WhisperAttention(d_model)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperAttention(d_model)
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.activation_fn = nn.GELU()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output):
        self_attn_output = self.self_attn(x, x, x)
        x = self.self_attn_layer_norm(x + self_attn_output)

        enc_attn_output = self.encoder_attn(encoder_output, encoder_output, x)
        x = self.encoder_attn_layer_norm(x + enc_attn_output)

        fc_output = self.fc2(self.activation_fn(self.fc1(x)))
        x = self.final_layer_norm(x + fc_output)
        return x

# Decoder
class WhisperDecoder(nn.Module):
    def __init__(self, d_model, n_layers, d_ff, max_len, vocab_size):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model, padding_idx=50257)
        self.embed_positions = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([WhisperDecoderLayer(d_model, d_ff) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output):
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        x = self.embed_tokens(x) + self.embed_positions(position_ids)

        for layer in self.layers:
            x = layer(x, encoder_output)

        x = self.layer_norm(x)
        return x

# Adapter Layer
class AdapterLayer(nn.Module):
    def __init__(self, d_model, d_adapter):
        super().__init__()
        self.down_proj = nn.Linear(d_model, d_adapter)
        self.activation = nn.GELU()
        self.up_proj = nn.Linear(d_adapter, d_model)

    def forward(self, x):
        x = self.down_proj(x)
        x = self.activation(x)
        x = self.up_proj(x)
        return x

# Complete model with Adapter Layer
class WhisperForfinetuneWithAdapter(nn.Module):
    def __init__(self, input_dim=80, encoder_d_model=512, encoder_n_layers=6, encoder_d_ff=2048, max_len=1500, vocab_size=51865, d_adapter=256):
        super().__init__()
        self.encoder = WhisperEncoder(input_dim, encoder_d_model, encoder_n_layers, encoder_d_ff, max_len)
        self.adapter = AdapterLayer(encoder_d_model, d_adapter)
        self.decoder = WhisperDecoder(encoder_d_model, encoder_n_layers, encoder_d_ff, max_len, vocab_size)
        self.proj_out = nn.Linear(encoder_d_model, vocab_size, bias=False)

    def forward(self, input_features, labels):
        encoder_output = self.encoder(input_features)
        adapter_output = self.adapter(encoder_output)
        decoder_output = self.decoder(labels, adapter_output)
        logits = self.proj_out(decoder_output)

        outputs = {'logits': logits}
        if labels is not None:
          loss_fn = nn.CrossEntropyLoss()
          loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
          outputs['loss'] = loss

        return outputs

# Instantiate the model with adapter
model_with_adapter = WhisperForfinetuneWithAdapter()

In [None]:
#output

NameError: name 'output' is not defined

In [None]:
from transformers import TrainingArguments

training_args = Seq2SeqTrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=5,
  fp16=True,
  gradient_checkpointing=False,
  save_steps=50,
  eval_steps=50,
  logging_steps=50,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
  push_to_hub=False,
)

In [None]:
from transformers import Trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model_with_adapter,
    train_dataset=dataset_lt1["train"],
    eval_dataset=dataset_lt1["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

Step,Training Loss,Validation Loss


Step,Training Loss,Validation Loss


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.46 GiB. GPU 0 has a total capacty of 15.77 GiB of which 1.41 GiB is free. Process 29479 has 14.36 GiB memory in use. Of the allocated memory 7.40 GiB is allocated by PyTorch, and 6.56 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
trainer.evaluate()

{'eval_loss': 0.17271128296852112,
 'eval_wer': 0.4129511677282378,
 'eval_runtime': 43.3749,
 'eval_samples_per_second': 1.176,
 'eval_steps_per_second': 0.161,
 'epoch': 5.0}

Adapter fusion with Group-dro

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Basic Adapter Layer
class AdapterLayer(nn.Module):
    def __init__(self, d_model, d_adapter):
        super().__init__()
        self.down_proj = nn.Linear(d_model, d_adapter)
        self.activation = nn.GELU()
        self.up_proj = nn.Linear(d_adapter, d_model)

    def forward(self, x):
        x = self.down_proj(x)
        x = self.activation(x)
        x = self.up_proj(x)
        return x

# Gender-Specific Adapter Layer
class GenderSpecificAdapterLayer(nn.Module):
    def __init__(self, d_model, d_adapter):
        super().__init__()
        self.male_adapter = AdapterLayer(d_model, d_adapter)
        self.female_adapter = AdapterLayer(d_model, d_adapter)

    def forward(self, x, gender_labels):
        # Apply adapters for each gender
        male_mask = (gender_labels == 0)
        female_mask = (gender_labels == 1)

        # Apply male adapter where gender label is 0 (male)
        x_male = self.male_adapter(x[male_mask])

        # Apply female adapter where gender label is 1 (female)
        x_female = self.female_adapter(x[female_mask])

        # Ensure that both parts are of the same data type
        if x.dtype != x_male.dtype:
            x_male = x_male.to(x.dtype)
        if x.dtype != x_female.dtype:
            x_female = x_female.to(x.dtype)

        # Reassemble the output tensor, preserving original order
        output = torch.zeros_like(x)
        output[male_mask] = x_male
        output[female_mask] = x_female

        return output

class AdapterFusionLayer(nn.Module):
    def __init__(self, d_model, d_adapter):
        super().__init__()
        self.gender_adapters = GenderSpecificAdapterLayer(d_model, d_adapter)

    def forward(self, x, gender_labels):
        return self.gender_adapters(x, gender_labels)

# Attention Mechanism
class WhisperAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=True)
        self.q_proj = nn.Linear(d_model, d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

    def forward(self, k, v, q):
        k = self.k_proj(k)
        v = self.v_proj(v)
        q = self.q_proj(q)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(k.size(-1), dtype=torch.float32))
        attn_probs = F.softmax(attn_scores, dim=-1)

        context = torch.matmul(attn_probs, v)
        return self.out_proj(context)

# Encoder Layer with Gender Adapter
class WhisperEncoderLayerWithGenderAdapter(nn.Module):
    def __init__(self, d_model, d_ff, d_adapter):
        super().__init__()
        self.self_attn = WhisperAttention(d_model)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.activation_fn = nn.GELU()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)
        self.gender_adapter = GenderSpecificAdapterLayer(d_model, d_adapter)

    def forward(self, x, gender_labels):
        attn_output = self.self_attn(x, x, x)
        x = self.self_attn_layer_norm(x + attn_output)

        fc_output = self.fc2(self.activation_fn(self.fc1(x)))
        x = self.final_layer_norm(x + fc_output)

        x = self.gender_adapter(x, gender_labels)
        return x

# Encoder with Gender Adapter
class WhisperEncoderWithGenderAdapter(nn.Module):
    def __init__(self, input_dim, d_model, n_layers, d_ff, max_len, d_adapter):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, d_model, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=3, stride=2, padding=1)
        self.embed_positions = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([WhisperEncoderLayerWithGenderAdapter(d_model, d_ff, d_adapter) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, gender_label):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.transpose(1, 2)
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        position_embeddings = self.embed_positions(position_ids)
        x += position_embeddings

        for layer in self.layers:
            x = layer(x, gender_label)

        x = self.layer_norm(x)
        return x

# Decoder Layer with Gender Adapter
class WhisperDecoderLayerWithGenderAdapter(nn.Module):
    def __init__(self, d_model, d_ff, d_adapter):
        super().__init__()
        self.self_attn = WhisperAttention(d_model)
        self.self_attn_layer_norm = nn.LayerNorm(d_model)
        self.encoder_attn = WhisperAttention(d_model)
        self.encoder_attn_layer_norm = nn.LayerNorm(d_model)
        self.activation_fn = nn.GELU()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.final_layer_norm = nn.LayerNorm(d_model)
        self.gender_adapter = GenderSpecificAdapterLayer(d_model, d_adapter)

    def forward(self, x, encoder_output, gender_labels):
        self_attn_output = self.self_attn(x, x, x)
        x = self.self_attn_layer_norm(x + self_attn_output)

        enc_attn_output = self.encoder_attn(encoder_output, encoder_output, x)
        x = self.encoder_attn_layer_norm(x + enc_attn_output)

        fc_output = self.fc2(self.activation_fn(self.fc1(x)))
        x = self.final_layer_norm(x + fc_output)

        x = self.gender_adapter(x, gender_labels)
        return x

# Decoder with Gender Adapter
class WhisperDecoderWithGenderAdapter(nn.Module):
    def __init__(self, d_model, n_layers, d_ff, max_len, vocab_size, d_adapter):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, d_model, padding_idx=50257)
        self.embed_positions = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([WhisperDecoderLayerWithGenderAdapter(d_model, d_ff, d_adapter) for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, gender_labels):
        position_ids = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        x = self.embed_tokens(x) + self.embed_positions(position_ids)

        for layer in self.layers:
            x = layer(x, encoder_output, gender_labels)

        x = self.layer_norm(x)
        return x

# Complete Model with Gender-Specific Adapter Layers

class WhisperfinetuneWithAdapterFusion(nn.Module):
    def __init__(self, input_dim=80, encoder_d_model=512, encoder_n_layers=6, encoder_d_ff=2048, max_len=1500, vocab_size=51865, d_adapter=256):
        super().__init__()
        # Pass d_adapter to the encoder
        self.encoder = WhisperEncoderWithGenderAdapter(input_dim, encoder_d_model, encoder_n_layers, encoder_d_ff, max_len, d_adapter)
        self.adapter_fusion = AdapterFusionLayer(encoder_d_model, d_adapter)
        self.decoder = WhisperDecoderWithGenderAdapter(encoder_d_model, encoder_n_layers, encoder_d_ff, max_len, vocab_size, d_adapter)
        self.proj_out = nn.Linear(encoder_d_model, vocab_size, bias=False)

    def forward(self, input_features, labels, gender_labels):
        # Pass gender_labels to the encoder
        encoder_output = self.encoder(input_features, gender_labels)
        fused_adapter_output = self.adapter_fusion(encoder_output, gender_labels)
        decoder_output = self.decoder(labels, fused_adapter_output, gender_labels)
        logits = self.proj_out(decoder_output)

        outputs = {'logits': logits}
        if labels is not None:
          loss_fn = nn.CrossEntropyLoss()
          loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
          outputs['loss'] = loss

        return outputs

AdapterFusion = WhisperfinetuneWithAdapterFusion(input_dim=80, encoder_d_model=512, encoder_n_layers=6, encoder_d_ff=2048, max_len=1500, vocab_size=51865, d_adapter=256)

In [None]:
import torch

# Assuming these dimensions based on your model definition
batch_size = 1
audio_feature_size = 80  # Number of features in your audio data
seq_length = 3000        # Sequence length of your audio data

# Create dummy audio features
dummy_audio_features = torch.randn(batch_size, audio_feature_size, seq_length)

# Create dummy labels (e.g., for sequence-to-sequence, these could be token IDs)
vocab_size = 51865       # Vocabulary size (based on your model)
dummy_labels = torch.randint(low=0, high=vocab_size, size=(batch_size, 100)) # Example label sequence

# Create dummy gender labels (0 for male, 1 for female for example)
dummy_gender_labels = torch.tensor([0]) # Assuming batch size of 1

output = AdapterFusion(dummy_audio_features, dummy_labels, dummy_gender_labels)

In [None]:
dummy_audio_features.shape
dummy_labels.shape
#dummy_gender_labels.shape

torch.Size([1, 100])

In [None]:
# Optimizer
from torch.optim import Adam
learning_rate = 0.001
optimizer = Adam(AdapterFusion.parameters(), lr=learning_rate)

def train_one_epoch(AdapterFusion, data_loader, optimizer, device):
    AdapterFusion.train()
    for batch in data_loader:
        input_features = batch['input_features'].to(device)
        labels = batch['labels'].to(device)
        gender_labels = batch['gender_labels'].to(device)

        outputs = AdapterFusion(input_features, labels, gender_labels)
        logits = outputs['logits']

        # Compute loss for each gender group
        loss_male = F.cross_entropy(logits[gender_labels == 0], labels[gender_labels == 0]) if (gender_labels == 0).any() else torch.tensor(0.)
        loss_female = F.cross_entropy(logits[gender_labels == 1], labels[gender_labels == 1]) if (gender_labels == 1).any() else torch.tensor(0.)

        # Choose the highest loss for backpropagation
        worst_group_loss = max(loss_male, loss_female)

        optimizer.zero_grad()
        worst_group_loss.backward()
        optimizer.step()

In [None]:
from torch.cuda.amp import GradScaler, autocast
from transformers import Seq2SeqTrainer, TrainingArguments

class GenderAwareTrainer(Seq2SeqTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scaler = GradScaler()

    def training_step(self, model, inputs):
        gender_labels = inputs.pop("gender_labels")
        use_amp = self.args.fp16

        with autocast(enabled=use_amp):
            outputs = model(**inputs, gender_labels=gender_labels)
            loss = outputs['loss']

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.optimizer.zero_grad(set_to_none=True)
        self.scaler.update()

        return loss

In [None]:
from transformers import TrainingArguments
training_args = Seq2SeqTrainingArguments(
  output_dir=repo_name,
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=5,
  fp16=True,
  gradient_checkpointing=False,
  save_steps=50,
  eval_steps=50,
  logging_steps=50,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
  push_to_hub=False,
)

In [None]:
# Instantiate the custom trainer
trainer = GenderAwareTrainer(
    args=training_args,
    model=AdapterFusion,  # Make sure this is your gender-aware model
    train_dataset=dataset_lt1["train"],
    eval_dataset=dataset_lt1["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()