In [3]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
from google.cloud import storage

In [79]:
weights_path = "../data/weights_30.pth"

In [80]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [81]:
class AudioToTextSmallModel(nn.Module):
    def __init__(self):
        super(AudioToTextSmallModel, self).__init__()
        # Initialize T5 model and tokenizer
        self.t5 = T5ForConditionalGeneration.from_pretrained("t5-small")

    def forward(self, audio_embeddings, labels=None):
        # Ensure correct shape for inputs_embeds: (batch_size, seq_length, embedding_dim)
        # T5 expects the shape (batch_size, seq_length, embedding_dim)
        projected_embeddings = audio_embeddings.unsqueeze(1)  # Add seq_length dimension (usually 1 for this case)

        # Generate outputs with T5
        outputs = self.t5(
            inputs_embeds=projected_embeddings,
            labels=labels
        )
        return outputs
    
# class AudioToTextBaseModel(nn.Module):
#     def __init__(self):
#         super(AudioToTextBaseModel, self).__init__()
#         # Initialize T5 model and tokenizer with t5-large
#         self.t5 = T5ForConditionalGeneration.from_pretrained("t5-base")
#         # Linear layer to project 512-dimensional CLAP embeddings to 1024-dimensional embeddings
#         self.projection_layer = nn.Linear(512, 768)

#     def forward(self, audio_embeddings, labels=None):
#         # Project audio embeddings from 512 to 1024 dimensions
#         projected_embeddings = self.projection_layer(audio_embeddings)
        
#         # Add seq_length dimension (usually 1 for this case)
#         projected_embeddings = projected_embeddings.unsqueeze(1)

#         # Generate outputs with T5
#         outputs = self.t5(
#             inputs_embeds=projected_embeddings,
#             labels=labels
#         )
#         return outputs

tokenizer = T5Tokenizer.from_pretrained("t5-small")
    
# Initialize the model and tokenizer
model = AudioToTextSmallModel().to(device)  # Move the model to GPU

In [82]:
model.load_state_dict(torch.load(weights_path))

# Step 4: Set the model to evaluation mode (if you only need to do inference)
model.eval()

  model.load_state_dict(torch.load(weights_path))


AudioToTextSmallModel(
  (t5): T5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=512, bias=False)
                (k): Linear(in_features=512, out_features=512, bias=False)
                (v): Linear(in_features=512, out_features=512, bias=False)
                (o): Linear(in_features=512, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 8)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=512, out_features=2048, bias=False)
                (wo): Linear(in_f

In [83]:
# Load the training data
train_data = torch.load('../data/train_data.pt')
test_data = torch.load('../data/test_data.pt')

  train_data = torch.load('../data/train_data.pt')
  test_data = torch.load('../data/test_data.pt')


In [84]:
train_embeddings = torch.tensor(np.array(train_data["embeddings"])).to(device)  # Move to GPU
train_labels = [str(label) for label in train_data["labels"]]

test_embeddings = torch.tensor(np.array(test_data["embeddings"])).to(device)  # Move to GPU
test_labels = [str(label) for label in test_data["labels"]]

In [85]:
# Ensure all labels are strings
for label in train_labels:
    if label is None or not isinstance(label, str):
        print("Label has an error or is not a string")

In [86]:
# Tokenize the labels (convert them into token IDs) just once
train_tokenized_labels = tokenizer(train_labels, padding=True, truncation=True, return_tensors="pt").input_ids.to(device)  # Move to GPU

test_tokenized_labels = tokenizer(test_labels, padding=True, truncation=True, return_tensors="pt").input_ids.to(device)  # Move to GPU


In [87]:
# Create a DataLoader for your train data
train_dataset = TensorDataset(train_embeddings, train_tokenized_labels)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)

test_dataset = TensorDataset(test_embeddings, test_tokenized_labels)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [88]:
optimizer = optim.AdamW(model.parameters(), lr=1e-5)  # You can adjust the learning rate

In [89]:
def evaluate_final_loss(model, data_loader):
    total_loss = 0
    for i, batch in enumerate(data_loader):
        audio_embeddings, labels = batch

        # Move data to GPU
        audio_embeddings = audio_embeddings.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(audio_embeddings, labels=labels)

        # Calculate loss
        loss = outputs.loss
        total_loss += loss.item()

    # Calculate and print the loss for this epoch
    avg_loss = total_loss / len(data_loader)
    return avg_loss

In [90]:
evaluate_final_loss(model, train_loader)

0.9895334436864609

In [91]:
evaluate_final_loss(model, test_loader)

1.0829760775634936

In [92]:
model.eval()

AudioToTextSmallModel(
  (t5): T5ForConditionalGeneration(
    (shared): Embedding(32128, 512)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 512)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=512, out_features=512, bias=False)
                (k): Linear(in_features=512, out_features=512, bias=False)
                (v): Linear(in_features=512, out_features=512, bias=False)
                (o): Linear(in_features=512, out_features=512, bias=False)
                (relative_attention_bias): Embedding(32, 8)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseActDense(
                (wi): Linear(in_features=512, out_features=2048, bias=False)
                (wo): Linear(in_f

In [102]:
def inference(example_embedding):
    with torch.no_grad():
        generated_ids = model.t5.generate(
            inputs_embeds=example_embedding.view(1, 1, 512),
            max_length=100,  # Adjust as needed
            early_stopping=True
        )
    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [103]:
def run_predictions(model, data_loader):
    pred_text = []
    true_text = []
    for i, batch in enumerate(data_loader):
        audio_embeddings, labels = batch
        pred = inference(audio_embeddings[0])
        true = tokenizer.decode(labels[0], skip_special_tokens=True)
        pred_text.append(pred)
        true_text.append(true)

    # Calculate and print the loss for this epoch
    return true_text, pred_text

In [None]:
train_true, train_pred = run_predictions(model, test_loader)
for i in range(5):
    print(train_true[i])
    print(train_pred[i])

In [None]:
evaluate_scores(model, train_loader)

In [98]:
print(test_embeddings[0] == test_embeddings[1])

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [99]:
inference(torch.randn(512,).to(device))

torch.Size([512])


'<pad>The low quality recording features a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a snare sound effect, a s'

In [20]:
train_data["filenames"][:5]

['MrMXYO2fzJ4', 'OPX9ukYun3o', 'aHZdDmYFZN0', 'H_He9_zHk8I', 'm7i4g_o-znQ']