<a href="https://colab.research.google.com/github/hashd2035/model_exercise/blob/main/Wav2Vec2_Pytorch_KD1_DryRun.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1. Install necessary libraries:

In [4]:
!pip install torch transformers pytorch-lightning datasets



## 2. Define the Teacher and Student Models

In [5]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch.nn.functional as F
import torch
import pytorch_lightning as pl

# Teacher Model: wav2vec2-base (pre-trained)
teacher_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")

# Student Model: Same architecture for this dry run (or you can define a smaller custom model)
student_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base")


Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 3. Set Up the Dataset (Subset)

In [7]:
from datasets import load_dataset, Audio

# Load the PolyAI minds14 dataset (English subset or use multiple languages if desired)
dataset = load_dataset("PolyAI/minds14", name="ko-KR", split="train[:5%]")  # Use a small subset for testing

# Preview the dataset
print(dataset)

minds14.py:   0%|          | 0.00/5.83k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.28k [00:00<?, ?B/s]

The repository for PolyAI/minds14 contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/PolyAI/minds14.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N] y


MInDS-14.zip:   0%|          | 0.00/471M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
    num_rows: 30
})


## 4. Preprocess the Data

### For MInDS-14

Wav2Vec2 is pretrained on 16kHz sampled speech audio. It is important your audio data’s sampling rate matches the sampling rate of the dataset used to pretrain the model. If your data’s sampling rate isn’t the same, then you need to resample your data.

In [8]:
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))

In [9]:
dataset

Dataset({
    features: ['path', 'audio', 'transcription', 'english_transcription', 'intent_class', 'lang_id'],
    num_rows: 30
})

In [10]:
# Preprocess function to process the audio and intents
def prepare_dataset(batch):
    audio = batch["audio"]["array"]
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]  # Audio features
    batch["labels"] = torch.tensor(batch["intent_class"], dtype=torch.long)  # Intent class as labels
    return batch

# Apply preprocessing
dataset = dataset.map(prepare_dataset, remove_columns=["transcription", "english_transcription", "audio", "intent_class", "lang_id"])
dataset.set_format(type="torch", columns=["input_values", "labels"])

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

### For LJSpeech for Dry Run

In [None]:
from datasets import load_dataset

# Load a small subset of the Common Voice dataset (English)
# dataset = load_dataset("mozilla-foundation/common_voice_11_0", "en", split="train[:1%]")  # Use only 1% for a dry run
# Load the LJ Speech dataset
dataset = load_dataset("lj_speech", split="train")
# Load the processor for Wav2Vec2
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

In [None]:
Wav2Vec2 is pretrained on 16kHz sampled speech audio. It is important your audio data’s sampling rate matches the sampling rate of the dataset used to pretrain the model. If your data’s sampling rate isn’t the same, then you need to resample your data.

In [None]:
# Preprocess function to process the audio
def prepare_dataset(batch):
    audio = batch["audio"]["array"]  # Access the audio data
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]  # Process the audio
    batch["labels"] = processor.tokenizer(batch["text"]).input_ids  # Tokenize the transcription
    return batch

# Apply preprocessing
dataset = dataset.map(prepare_dataset, remove_columns=["text", "audio"])
dataset.set_format(type="torch", columns=["input_values", "labels"])

### For mozilla-foundation/common_voice_11_0 for actual run

In [None]:
# Preprocessing function
def prepare_dataset(batch):
    audio = batch["audio"]["array"]
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch

# Apply preprocessing
dataset = dataset.map(prepare_dataset, remove_columns=["audio", "sentence"])
dataset.set_format(type='torch', columns=['input_values', 'labels'])


## 5. Define the Knowledge Distillation Model


In [39]:
class DistillationModel(pl.LightningModule):
    def __init__(self, teacher_model, student_model):
        super().__init__()
        self.teacher = teacher_model
        self.student = student_model
        self.teacher.eval()  # Freeze the teacher model weights

    def forward(self, input_values):
        return self.student(input_values)

    def training_step(self, batch, batch_idx):
        input_values = batch["input_values"]
        labels = batch["labels"]

        # Student model forward pass
        student_outputs = self.student(input_values)
        student_logits = student_outputs.logits  # Shape: [batch_size, sequence_length, num_classes]

        # Pool the logits over the sequence dimension (mean pooling or using last time step)
        pooled_student_logits = student_logits.mean(dim=1)  # Shape: [batch_size, num_classes]

        # Teacher model forward pass (no gradients)
        with torch.no_grad():
            teacher_logits = self.teacher(input_values).logits
            pooled_teacher_logits = teacher_logits.mean(dim=1)  # Also pool teacher logits

        # Knowledge distillation loss (KL divergence between student and teacher logits)
        distillation_loss = F.kl_div(
            F.log_softmax(pooled_student_logits, dim=-1),
            F.softmax(pooled_teacher_logits, dim=-1),
            reduction="batchmean"
        )

        # Classification loss (cross-entropy with hard labels)
        ce_loss = F.cross_entropy(pooled_student_logits, labels)

        # Combine the two losses
        loss = distillation_loss + ce_loss
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


## 6. Set Up the DataLoader

In [40]:
from torch.utils.data import DataLoader

# Define a collate function to handle padding
def collate_fn(batch):
    # Extract the input values and labels
    input_values = [example['input_values'] for example in batch]
    labels = [example['labels'] for example in batch]

    # Create a list of dictionaries for input_values (to match expected input format)
    batch_inputs = [{"input_values": iv} for iv in input_values]

    # Apply padding to input_values using the processor
    input_values_padded = processor.pad(
        batch_inputs, padding=True, return_tensors="pt"
    ).input_values

    # Convert labels to tensor (since labels are fixed length, we don't need to pad them)
    labels_padded = torch.tensor(labels)

    return {"input_values": input_values_padded, "labels": labels_padded}

# Use the custom collate function in your DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)


## 7. Train the Student Model

In [41]:
# PyTorch Lightning Trainer (Updated for new API)
trainer = pl.Trainer(
    max_epochs=1,
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1 if torch.cuda.is_available() else None
)
student_model = DistillationModel(teacher_model, student_model)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [42]:
# Train the student model
trainer.fit(student_model, dataloader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name    | Type              | Params | Mode 
------------------------------------------------------
0 | teacher | Wav2Vec2ForCTC    | 94.4 M | eval 
1 | student | DistillationModel | 188 M  | train
------------------------------------------------------
188 M     Trainable params
0         Non-trainable params
188 M     Total params
755.171   Total estimated model params size (MB)
12        Modules in train mode
446       Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
