In [1]:
!pip install -qq nltk rouge transformers

In [2]:
import transformers
import torch
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.generation.logits_process import WhisperTimeStampLogitsProcessor
from transformers.models.whisper.tokenization_whisper import TASK_IDS, TO_LANGUAGE_CODE


class WhisperForAudioCaptioning(transformers.WhisperForConditionalGeneration):
    """Whisper conditional-generation model whose decoder prompt can be extended.

    Compared with the parent class, `generate` accepts `forced_ac_decoder_ids`:
    extra token ids appended after Whisper's own forced prefix (start token,
    language, task, optional no-timestamps token) before decoding starts.
    `forward` accepts the same argument but ignores it.
    """

    def forward(
        self,
        input_features: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        forced_ac_decoder_ids: Optional[torch.LongTensor] = None, # added to be ignored when passed from trainer
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        """Run the parent forward pass, discarding `forced_ac_decoder_ids`.

        Every argument except `forced_ac_decoder_ids` is passed through
        unchanged to `WhisperForConditionalGeneration.forward`.
        """
        return super().forward(
            input_features=input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    # copy-pasted and adapted from transformers.WhisperForConditionalGeneration.generate
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        forced_ac_decoder_ids: Optional[torch.Tensor] = None,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=False,
        return_timestamps=True,
        task="transcribe",
        language="english",
        **kwargs,
    ):
        """Generate token ids with Whisper's forced prefix plus a custom prefix.

        The forced decoder ids (start token, language, task, optional
        no-timestamps token) are turned into explicit `decoder_input_ids`,
        `forced_ac_decoder_ids` is concatenated after them, and the result is
        handed to the generic `generate` implementation that sits behind
        `transformers.WhisperPreTrainedModel` in the MRO.

        Args:
            inputs: Forwarded as the first positional argument of the
                underlying `generate`.
            forced_ac_decoder_ids: 2-D tensor of token ids (one row per
                sample) appended after the Whisper prefix. When omitted, an
                empty prefix is used.
            generation_config: Generation config; defaults to
                `self.generation_config`. NOTE: it is modified in place
                (`return_timestamps`, `language`, `task`,
                `forced_decoder_ids`).
            logits_processor: Optional extra logits processors. When
                timestamps are enabled they are kept and run after the
                timestamp processor.
            stopping_criteria, prefix_allowed_tokens_fn, synced_gpus:
                Forwarded unchanged.
            return_timestamps: Whether to enable timestamp prediction.
            task: Task name; must be one of `TASK_IDS`.
            language: Language name, two-letter language code, or language
                token present in `generation_config.lang_to_id`.
            **kwargs: Forwarded unchanged to the underlying `generate`.

        Returns:
            Whatever the underlying `generate` returns.

        Raises:
            ValueError: If the generation config lacks
                `no_timestamps_token_id` while `return_timestamps` is not
                None, or if the language or task is unsupported.
            AssertionError: If the forced decoder ids do not form a
                contiguous prefix starting at position 0.
        """
        if generation_config is None:
            generation_config = self.generation_config

        if return_timestamps is not None:
            if not hasattr(generation_config, "no_timestamps_token_id"):
                raise ValueError(
                    "You are trying to return timestamps, but the generation config is not properly set."
                    "Make sure to initialize the generation config with the correct attributes that are needed such as `no_timestamps_token_id`."
                    "For more details on how to generate the approtiate config, refer to https://github.com/huggingface/transformers/issues/21878#issuecomment-1451902363"
                )

            generation_config.return_timestamps = return_timestamps
        else:
            generation_config.return_timestamps = False

        if language is not None:
            generation_config.language = language
        if task is not None:
            generation_config.task = task

        forced_decoder_ids = []
        if task is not None or language is not None:
            if hasattr(generation_config, "language"):
                requested_language = generation_config.language
                if requested_language in generation_config.lang_to_id:
                    # Already a language token such as "<|en|>".
                    language_token = requested_language
                elif requested_language in TO_LANGUAGE_CODE:
                    # Full language name such as "english".
                    language_token = f"<|{TO_LANGUAGE_CODE[requested_language]}|>"
                elif requested_language in TO_LANGUAGE_CODE.values():
                    # Bare language code such as "en".
                    language_token = f"<|{requested_language}|>"
                else:
                    # Show the list matching the style the caller most likely used.
                    is_language_code = isinstance(requested_language, str) and len(requested_language) == 2
                    raise ValueError(
                        f"Unsupported language: {requested_language}. Language should be one of:"
                        f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
                    )
                forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
            else:
                # NOTE(review): a None token id cannot be turned into a tensor
                # below, so this branch looks unusable as written - confirm.
                forced_decoder_ids.append((1, None))  # automatically detect the language

            if hasattr(generation_config, "task"):
                if generation_config.task in TASK_IDS:
                    forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
                else:
                    raise ValueError(
                        f"The `{generation_config.task}`task is not supported. The task should be one of `{TASK_IDS}`"
                    )
            else:
                forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))  # defaults to transcribe
            if hasattr(generation_config, "no_timestamps_token_id") and not generation_config.return_timestamps:
                idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
                forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))

        # Legacy code for backward compatibility
        elif hasattr(self.config, "forced_decoder_ids") and self.config.forced_decoder_ids is not None:
            forced_decoder_ids = self.config.forced_decoder_ids
        elif (
            hasattr(self.generation_config, "forced_decoder_ids")
            and self.generation_config.forced_decoder_ids is not None
        ):
            forced_decoder_ids = self.generation_config.forced_decoder_ids

        if generation_config.return_timestamps:
            # Keep caller-supplied processors instead of silently dropping them.
            timestamp_processor = WhisperTimeStampLogitsProcessor(generation_config)
            if logits_processor is None:
                logits_processor = [timestamp_processor]
            else:
                logits_processor = [timestamp_processor, *logits_processor]

        decoder_input_ids = None

        if len(forced_decoder_ids) > 0:
            # get the token sequence coded in forced_decoder_ids
            # sorted() copies, so a list owned by the (generation) config is
            # not reordered in place.
            forced_decoder_ids = sorted(forced_decoder_ids)
            if forced_decoder_ids[0][0] != 0:
                forced_decoder_ids = [(0, self.config.decoder_start_token_id)] + forced_decoder_ids

            position_indices, decoder_input_ids = zip(*forced_decoder_ids)
            assert tuple(position_indices) == tuple(range(len(position_indices))), "forced_decoder_ids is not a (continuous) prefix, we can't handle that"

            device = self.get_decoder().device

            if forced_ac_decoder_ids is None:
                # Empty per-sample prefix; size it from the inputs when they
                # are available so batched calls still line up.
                features = inputs if inputs is not None else kwargs.get("input_features")
                empty_batch = features.shape[0] if features is not None else 1
                forced_ac_decoder_ids = torch.empty((empty_batch, 0), device=device, dtype=torch.long)
            else:
                # The prefix usually comes straight from a tokenizer (CPU);
                # torch.cat below needs both tensors on the same device.
                forced_ac_decoder_ids = forced_ac_decoder_ids.to(device)

            # enrich every sample's forced_ac_decoder_ids with Whisper's forced_decoder_ids
            batch_size = forced_ac_decoder_ids.shape[0]
            fluff_len = len(decoder_input_ids)
            decoder_input_ids = torch.tensor(decoder_input_ids, device=device, dtype=torch.long)
            decoder_input_ids = decoder_input_ids.expand((batch_size, fluff_len))
            decoder_input_ids = torch.cat([decoder_input_ids, forced_ac_decoder_ids], dim=1)

            generation_config.forced_decoder_ids = forced_decoder_ids

        # Deliberately start the MRO lookup after WhisperPreTrainedModel so the
        # Whisper-specific generate() of the direct parent is bypassed.
        return super(transformers.WhisperPreTrainedModel, self).generate(
            inputs,
            generation_config,
            logits_processor,
            stopping_criteria,
            prefix_allowed_tokens_fn,
            synced_gpus,
            decoder_input_ids=decoder_input_ids,
            **kwargs,
        )


In [3]:
import nltk
# Download the NLTK data packages 'punkt' and 'wordnet'.
# Presumably these back word_tokenize and meteor_score used in a later cell -
# confirm against the installed NLTK version.
nltk.download('punkt')
nltk.download('wordnet')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


True

In [4]:
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge
from nltk.translate.meteor_score import meteor_score
from nltk.translate.bleu_score import SmoothingFunction
from nltk import word_tokenize

def preprocess_caption(caption):
    """Return *caption* lowercased, keeping only letters and whitespace.

    Every character that is neither alphabetic nor whitespace is dropped,
    which removes punctuation and also digits and other symbols.
    """
    def _is_kept(char):
        return char.isalpha() or char.isspace()

    return ''.join(map(str.lower, filter(_is_kept, caption)))

def compute_metrics(reference_captions, generated_caption):
    """Score one generated caption against its reference captions.

    All captions are normalised with `preprocess_caption` and tokenised with
    `word_tokenize` before scoring.

    Args:
        reference_captions: Iterable of reference caption strings.
        generated_caption: The caption string to evaluate.

    Returns:
        dict with keys 'BLEU1'..'BLEU4', 'ROUGEL' and 'METEOR'.
    """
    metrics = {}

    # Preprocess and tokenize captions
    gen_words = word_tokenize(preprocess_caption(generated_caption))
    ref_words = [word_tokenize(preprocess_caption(ref)) for ref in reference_captions]

    # Smoothed BLEU-1..4 with uniform n-gram weights
    smoothing = SmoothingFunction().method1
    for order in range(1, 5):
        metrics[f'BLEU{order}'] = sentence_bleu(
            ref_words,
            gen_words,
            weights=[1 / order] * order,
            smoothing_function=smoothing,
        )

    # ROUGE-L: score against each reference on its own and keep the best
    # F-score. Concatenating all references into a single string would inflate
    # the reference length and therefore deflate recall.
    hypothesis_text = ' '.join(gen_words)
    best_rouge_l = 0.0
    if hypothesis_text:  # Rouge raises on an empty hypothesis
        rouge = Rouge()
        for ref in ref_words:
            reference_text = ' '.join(ref)
            if not reference_text:  # an empty reference cannot be scored
                continue
            score = rouge.get_scores(hypothesis_text, reference_text)[0]['rouge-l']['f']
            best_rouge_l = max(best_rouge_l, score)
    metrics['ROUGEL'] = best_rouge_l

    # Compute METEOR (using tokenized hypothesis)
    metrics['METEOR'] = meteor_score(ref_words, gen_words)

    return metrics

In [5]:
# Load the pre-trained Whisper *small* audio-captioning checkpoint
# (the checkpoint id below is the "whisper-small" variant).
checkpoint = "MU-NLPC/whisper-small-audio-captioning"
# Access token passed to from_pretrained for the model; left empty here.
hugging_face_token = ""

# Model, tokenizer and feature extractor are all loaded from the same checkpoint.
model = WhisperForAudioCaptioning.from_pretrained(checkpoint,token=hugging_face_token)
tokenizer = transformers.WhisperTokenizer.from_pretrained(checkpoint, language="en", task="caption", predict_timestamps=True)
feature_extractor = transformers.WhisperFeatureExtractor.from_pretrained(checkpoint)

config.json:   0%|          | 0.00/2.25k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/967M [00:00<?, ?B/s]



generation_config.json:   0%|          | 0.00/3.62k [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/800 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.08k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.08k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


preprocessor_config.json:   0%|          | 0.00/339 [00:00<?, ?B/s]

In [6]:
import librosa

def generate_caption(audio_file, style_prefix="clotho > caption: ", strip_prefix=True):
    """Generate a caption for one audio file with the module-level model.

    Args:
        audio_file: Path of the audio file; loaded with librosa at the
            feature extractor's sampling rate.
        style_prefix: Text tokenised and forced as the start of the decoder
            output to select the caption style.
        strip_prefix: When True (default), remove the leading style prefix
            from the decoded text so only the caption itself is returned.
            Pass False to get the raw decoded string.

    Returns:
        The decoded caption string.
    """
    audio, sampling_rate = librosa.load(audio_file, sr=feature_extractor.sampling_rate)
    features = feature_extractor(audio, sampling_rate=sampling_rate, return_tensors="pt").input_features

    # Prepare the caption style
    style_prefix_tokens = tokenizer(
        "", text_target=style_prefix, return_tensors="pt", add_special_tokens=False
    ).labels

    model.eval()
    outputs = model.generate(
        inputs=features.to(model.device),
        # Keep the prefix on the same device as the model; it is concatenated
        # with decoder ids created on the model's device.
        forced_ac_decoder_ids=style_prefix_tokens.to(model.device),
        max_length=300,
        num_beams=4,
        return_timestamps=True,
        early_stopping=True,
    )
    caption = tokenizer.decode(outputs[0], decode_with_timestamps=True, skip_special_tokens=True)

    if strip_prefix:
        # The generated sequence includes the forced prompt tokens, so the
        # decoded text normally starts with the style prefix; leaving it in
        # would add spurious words to every caption that gets scored.
        prefix_text = style_prefix.strip()
        candidate = caption.lstrip()
        if prefix_text and candidate.startswith(prefix_text):
            caption = candidate[len(prefix_text):].lstrip()

    return caption

In [7]:
import pandas as pd
import soundfile as sf
from transformers import WhisperFeatureExtractor

# Locations of the Clotho evaluation metadata and reference-caption tables
metadata_path = 'drive/MyDrive/metrics/clotho_metadata_evaluation.csv'
captions_path = 'drive/MyDrive/metrics/clotho_captions_evaluation.csv'

clotho_metadata = pd.read_csv(metadata_path)
clotho_captions = pd.read_csv(captions_path)

# Join both tables on 'file_name' (default inner join): only files whose name
# appears identically in both tables are kept. If the names differ between the
# two files, they must be normalised before this step.
combined_data = pd.merge(clotho_metadata, clotho_captions, on='file_name')

# Running totals per metric, used later to compute the overall mean
cumulative_metrics = dict.fromkeys(
    ['BLEU1', 'BLEU2', 'BLEU3', 'BLEU4', 'ROUGEL', 'METEOR'], 0
)
num_files = 0

# Function to load and preprocess audio
def preprocess_audio(file_name):
    """Load *file_name* and convert it to feature-extractor inputs.

    The audio is resampled on load to the feature extractor's own sampling
    rate, so the rate reported to the extractor always matches the samples
    (previously the native-rate signal was labelled as 16000 Hz).

    Returns:
        The output of calling the module-level `feature_extractor`.
    """
    import librosa  # local import so this cell does not depend on cell order

    target_rate = feature_extractor.sampling_rate
    audio, sample_rate = librosa.load(file_name, sr=target_rate)
    return feature_extractor(audio, sampling_rate=sample_rate)

# Iterate over the dataset, caption every clip and accumulate its scores
for index, row in combined_data.iterrows():
    audio_file = row['file_name']
    # Evaluation clips live next to the CSV files, in the 'evaluation' folder
    audio_path = f'drive/MyDrive/metrics/evaluation/{audio_file}'

    # Generate caption (generate_caption loads and preprocesses the audio
    # itself, so no separate preprocessing pass is needed here)
    generated_caption = generate_caption(audio_path)

    # Extract reference captions (columns caption_1 .. caption_5)
    reference_captions = [row[f'caption_{i}'] for i in range(1, 6)]

    # Compute metrics
    metrics = compute_metrics(reference_captions, generated_caption)

    # Aggregate the metrics
    for key in cumulative_metrics:
        cumulative_metrics[key] += metrics[key]
    num_files += 1

    # Progress indicator
    print(f"Testing on file {num_files}")

# Calculate the mean metrics; with no processed files the mean is undefined,
# so report NaN instead of failing with a ZeroDivisionError
if num_files:
    mean_metrics = {key: val / num_files for key, val in cumulative_metrics.items()}
else:
    mean_metrics = {key: float('nan') for key in cumulative_metrics}

# Print the mean metrics
print("Mean Metrics for all audio files:", mean_metrics)

Testing on file 1
Testing on file 2
Testing on file 3
Testing on file 4
Testing on file 5
Testing on file 6
Testing on file 7
Testing on file 8
Testing on file 9
Testing on file 10
Testing on file 11
Testing on file 12
Testing on file 13
Testing on file 14
Testing on file 15
Testing on file 16
Testing on file 17
Testing on file 18
Testing on file 19
Testing on file 20
Testing on file 21
Testing on file 22
Testing on file 23
Testing on file 24
Testing on file 25
Testing on file 26
Testing on file 27
Testing on file 28
Testing on file 29
Testing on file 30
Testing on file 31
Testing on file 32
Testing on file 33
Testing on file 34
Testing on file 35
Testing on file 36
Testing on file 37
Testing on file 38
Testing on file 39
Testing on file 40
Testing on file 41
Testing on file 42
Testing on file 43
Testing on file 44
Testing on file 45
Testing on file 46
Testing on file 47
Testing on file 48
Testing on file 49
Testing on file 50
Testing on file 51
Testing on file 52
Testing on file 53
Te