In [4]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchaudio
import torchaudio.transforms as T

from transformers import BertTokenizer

from torchmultimodal.modules.fusions.concat_fusion import ConcatFusionModule
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.models.masked_auto_encoder.model import audio_mae
from torchmultimodal.models.masked_auto_encoder.model import MAEOutput

import os
import json
import yaml
from tqdm import tqdm


In [5]:
# opening the config file and extracting the parameters
with open("cfg.yaml", "r") as stream:
    try:
        config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)

# audio
n_mels = config["audio"]["n_mels"]
target_sample_rate = config["audio"]["target_sample_rate"]
max_time_steps = config["audio"]["max_time_steps"]
n_fft = config["audio"]["n_fft"]
hop_length = config["audio"]["hop_length"]

# text
max_len = config["text"]["max_len"]

#training
batch_size = config["training"]["batch_size"]
lr = config["training"]["lr"]
epochs = config["training"]["epochs"]

## Model  
The model is composed of a text pipeline, from Bert architecture, and an audio pipeline, audio masked autoencoder, that get fused together in aconcatenation fusion model. This is then connected with a final linear layer used as classifier.

In [6]:
class AudioTextMultimodalModel(nn.Module):
    def __init__(self, audio_encoder, text_encoder, fusion_module, output_dim):
        super().__init__()
        self.audio_encoder = audio_encoder
        self.text_encoder = text_encoder
        self.fusion_module = fusion_module

        # Calculate the input size for the classifier
        audio_output_dim = self.audio_encoder.output_dim
        text_output_dim = self.text_encoder.output_dim
        fusion_output_dim = audio_output_dim + text_output_dim

        self.classifier = nn.Linear(fusion_output_dim, output_dim)

    def forward(self, audio, input_ids, attention_mask):
        audio_output = self.audio_encoder(audio)
        # Extract the embeddings tensor from the MAEOutput object and apply pooling.
        # audio_output is an MAEOutput NamedTuple. We need the encoder_output field.
        # The encoder_output field can be a TransformerOutput or a Tensor.
        # If it's a TransformerOutput, we need its last_hidden_state.
        if isinstance(audio_output, MAEOutput):
            encoder_out = audio_output.encoder_output
            if hasattr(encoder_out, 'last_hidden_state') and isinstance(encoder_out.last_hidden_state, torch.Tensor):
                # Assuming last_hidden_state is (batch_size, num_patches, embedding_dim)
                # Apply mean pooling over the patch dimension (dim=1)
                audio_features = encoder_out.last_hidden_state.mean(dim=1)
            elif isinstance(encoder_out, torch.Tensor) and encoder_out.ndim == 3:
                 # If encoder_output itself is the tensor (batch_size, num_patches, embedding_dim)
                 audio_features = encoder_out.mean(dim=1)
            else:
                 # Fallback or error handling if the encoder_output structure is unexpected
                 print("Error: Unexpected audio encoder output structure within MAEOutput. Cannot extract features for fusion.")
                 raise TypeError("Audio encoder output within MAEOutput is not in an expected format for feature extraction and pooling.")
        elif isinstance(audio_output, torch.Tensor) and audio_output.ndim >= 2:
             # If the audio_encoder directly returned a tensor (less likely based on MAEOutput definition, but for robustness)
             if audio_output.ndim == 3: # Assuming (batch_size, sequence_length, embedding_dim)
                 audio_features = audio_output.mean(dim=1)
             else: # Assuming (batch_size, embedding_dim)
                 audio_features = audio_output
        else:
             # Fallback or error handling if the top-level output structure is unexpected
             print("Error: Unexpected top-level audio encoder output structure. Cannot extract features for fusion.")
             raise TypeError("Audio encoder output is not in an expected format for feature extraction and pooling.")
        
        text_features = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state.mean(dim=1)  # Average pooling over the sequence length
        fused_features = self.fusion_module({"audio": audio_features, "text": text_features})
        return self.classifier(fused_features)

# Instantiate components
audio_encoder = audio_mae(input_size=(128, 512)) # input size should match your audio data (n_mels, max_time_steps)
# text_encoder = CLIPTextEncoder()
text_encoder = bert_text_encoder()
fusion_module = ConcatFusionModule()
model = AudioTextMultimodalModel(audio_encoder, text_encoder, fusion_module, output_dim=10)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


## Audio Preprocessing
The audio dataset is loaded, resampled to a common sampling rate, converted to mono if there is more than one channel, and then transformed in the Mel Spectrogram format, needed by the audio MAE

In [7]:
def load_audio(file_path, target_sample_rate=target_sample_rate):
    # Load the audio file
    waveform, sample_rate = torchaudio.load(file_path)
    
    # Resample if the sample rate is different from the target
    if sample_rate != target_sample_rate:
        resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)
    
    # Convert to mono (if stereo)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    
    return waveform

def preprocess_audio(file_path, target_sample_rate=16000, n_mels=128, max_time_steps=None, n_fft=1024, hop_length=512):
    # Load the audio
    waveform = load_audio(file_path, target_sample_rate)

    # Convert to mel spectrogram
    # Ensure the audio tensor is (channels, time) for MelSpectrogram
    if waveform.ndim == 1:
         waveform = waveform.unsqueeze(0) # Add channel dimension if missing

    mel_spectrogram = T.MelSpectrogram(
        sample_rate=target_sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels
    )
    mel_features = mel_spectrogram(waveform)

    # Normalize the mel spectrogram
    # Add a small epsilon for numerical stability in case of zero std
    mean = mel_features.mean()
    std = mel_features.std()
    mel_features = (mel_features - mean) / (std + 1e-5)

    # Pad or truncate to max_time_steps
    if max_time_steps is not None:
        if mel_features.shape[2] < max_time_steps:
            # Pad with zeros
            padding = max_time_steps - mel_features.shape[2]
            mel_features = torch.nn.functional.pad(mel_features, (0, padding))
        else:
            # Truncate to max_time_steps
            mel_features = mel_features[:, :, :max_time_steps]

    return mel_features

## Text Preprocessing
The text dataset is tokenized, using the Bert tokenizer, in order to generate the tokens and the attention mask needed by the Bert encoder

In [8]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# Tokenize text
def tokenize_text(text, max_len=512):
    tokens = tokenizer(
        text,
        padding="max_length",  # Pad to max_len
        truncation=True,       # Truncate if longer than max_len
        max_length=max_len,    # Maximum sequence length
        return_tensors="pt",   # Return PyTorch tensors
    )
    return tokens["input_ids"].squeeze(0), tokens["attention_mask"].squeeze(0)

## Generation of Batched Dataset
The collate method generates the dataset uniting preprocessed audio and text relative to the same label

In [9]:
def collate_fn(batch):
    audio, input_ids, attention_mask, labels = zip(*batch)
    
    # Pad audio tensors to the same length
    audio = torch.stack(audio)
    
    # Stack tokenized text and attention masks
    input_ids = torch.stack(input_ids)  # Assuming text is already tokenized and of fixed length
    attention_mask = torch.stack(attention_mask)

    # Convert labels to tensor
    labels = torch.tensor(labels)
    
    return audio, input_ids, attention_mask, labels

## Dataset Class
Used for creating a preprocessed dataset from the data available, that can be then used in the dataloader

In [10]:
class AudioTextDataset(Dataset):
    def __init__(self, audio_files, text_data, labels, audio_transform=None,
                 text_transform=None, max_len=512, max_time_steps=512,
                 target_sample_rate=16000, n_mels=128, n_fft=1024, hop_length=512):
        self.audio_files = audio_files
        self.text_data = text_data
        self.labels = labels
        self.audio_transform = audio_transform
        self.text_transform = text_transform
        self.target_sample_rate = target_sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.max_len = max_len
        self.max_time_steps = max_time_steps

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Load and preprocess audio
        audio_path = self.audio_files[idx]
        audio = preprocess_audio(audio_path, target_sample_rate=self.target_sample_rate,
                                 n_mels=self.n_mels, max_time_steps=self.max_time_steps,
                                 n_fft=self.n_fft, hop_length=self.hop_length)  # Preprocess audio
        
        # Apply additional audio transforms if provided
        if self.audio_transform:
            audio = self.audio_transform(audio)

        # Tokenize and preprocess text
        text = self.text_data[idx]
        if self.text_transform:
            input_ids, attention_mask = self.text_transform(text, max_len=self.max_len)

        label = self.labels[idx]
        return audio, input_ids, attention_mask, label

# Generation of Dataloader
The data is extracted from its location, preprocessed and inserted in a dataloader that will then be used in the training and in the evaluation

In [11]:
audio_files = []  # List of audio file paths or preprocessed tensors
text_data = []    # List of text strings
labels = []       # List of labels

with open('data/data.json', 'r') as f:
    data = json.load(f)
    for entry in data:
        audio_files.append(os.path.join('data', entry['file']))
        text_data.append(entry['text'])
        labels.append(entry['label'])

# Instantiate dataset and dataloaders
train_dataset = AudioTextDataset(audio_files, text_data, labels, audio_transform=None, text_transform=tokenize_text,
                                 max_len=max_len, max_time_steps=max_time_steps,
                                 target_sample_rate=target_sample_rate, n_mels=n_mels,
                                 n_fft=n_fft, hop_length=hop_length)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

val_dataset = AudioTextDataset(audio_files, text_data, labels, audio_transform=None, text_transform=tokenize_text,
                                 max_len=max_len, max_time_steps=max_time_steps,
                                 target_sample_rate=target_sample_rate, n_mels=n_mels,
                                 n_fft=n_fft, hop_length=hop_length)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

## Training and Evaluation

In [12]:
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for audio, input_ids, attention_mask, labels in tqdm(dataloader):
        audio, input_ids, attention_mask, labels = audio.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(audio, input_ids, attention_mask)
        # Compute loss
        loss = criterion(outputs, labels)
        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for audio, input_ids, attention_mask, labels in tqdm(dataloader):
            audio, input_ids, attention_mask, labels = audio.to(device), input_ids.to(device), attention_mask.to(device), labels.to(device)

            # Forward pass
            outputs = model(audio, input_ids, attention_mask)
            # Compute loss
            loss = criterion(outputs, labels)
            # Accumulate loss
            total_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total
    return total_loss / len(dataloader), accuracy

In [13]:
# Move model to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define optimizer and loss function
optimizer = Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Train and evaluate
epochs = 10
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_accuracy = evaluate(model, val_loader, criterion, device)
    print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

Epoch 1/10


100%|██████████| 3/3 [00:01<00:00,  2.42it/s]
100%|██████████| 3/3 [00:00<00:00, 12.93it/s]


Train Loss: 4.7846, Val Loss: 0.8973, Val Accuracy: 0.3333
Epoch 2/10


100%|██████████| 3/3 [00:00<00:00,  5.81it/s]
100%|██████████| 3/3 [00:00<00:00, 12.80it/s]


Train Loss: 1.0261, Val Loss: 0.5571, Val Accuracy: 0.6667
Epoch 3/10


100%|██████████| 3/3 [00:00<00:00,  5.27it/s]
100%|██████████| 3/3 [00:00<00:00, 11.33it/s]


Train Loss: 0.6330, Val Loss: 0.6005, Val Accuracy: 0.6667
Epoch 4/10


100%|██████████| 3/3 [00:00<00:00,  5.28it/s]
100%|██████████| 3/3 [00:00<00:00, 13.23it/s]


Train Loss: 0.6329, Val Loss: 0.5620, Val Accuracy: 0.6667
Epoch 5/10


100%|██████████| 3/3 [00:00<00:00,  5.86it/s]
100%|██████████| 3/3 [00:00<00:00, 13.66it/s]


Train Loss: 0.8027, Val Loss: 0.5772, Val Accuracy: 0.6667
Epoch 6/10


100%|██████████| 3/3 [00:00<00:00,  5.70it/s]
100%|██████████| 3/3 [00:00<00:00, 15.59it/s]


Train Loss: 0.5784, Val Loss: 0.5950, Val Accuracy: 0.3333
Epoch 7/10


100%|██████████| 3/3 [00:00<00:00,  5.82it/s]
100%|██████████| 3/3 [00:00<00:00, 13.73it/s]


Train Loss: 0.6312, Val Loss: 0.2763, Val Accuracy: 1.0000
Epoch 8/10


100%|██████████| 3/3 [00:00<00:00,  5.03it/s]
100%|██████████| 3/3 [00:00<00:00, 14.85it/s]


Train Loss: 0.4023, Val Loss: 0.1761, Val Accuracy: 1.0000
Epoch 9/10


100%|██████████| 3/3 [00:00<00:00,  5.25it/s]
100%|██████████| 3/3 [00:00<00:00, 11.98it/s]


Train Loss: 0.2486, Val Loss: 0.1063, Val Accuracy: 1.0000
Epoch 10/10


100%|██████████| 3/3 [00:00<00:00,  5.88it/s]
100%|██████████| 3/3 [00:00<00:00, 12.51it/s]

Train Loss: 0.0859, Val Loss: 0.0081, Val Accuracy: 1.0000



