In [7]:
import os
os. chdir('../')

In [38]:
import torch
from torch import nn
from pytorch_lightning import LightningModule
from transformers import Wav2Vec2ForCTC, Wav2Vec2Model
from constants.mir_constants import TrainingArgs, WAV2VEC2_ARGS
from dataclasses import dataclass, asdict
import json
import argparse

In [20]:
from flash.audio import SpeechRecognitionData

In [21]:
print(json.dumps(asdict(WAV2VEC2_ARGS), indent = 4))

{
    "TRAIN_FILE_PATH": "/home/users/gmenon/notebooks/home/users/gmenon/notebooks/train_song_metadata_en_demucs_cleaned.csv",
    "TEST_FILE_PATH": "/home/users/gmenon/notebooks/home/users/gmenon/notebooks/validation_song_metadata_en_demucs_cleaned.csv",
    "MODEL_BACKBONE": "facebook/wav2vec2-large-960h-lv60-self",
    "BATCH_SIZE": 1,
    "NUM_EPOCHS": 15,
    "MODEL_SAVE_PATH": "/home/users/gmenon/workspace/songsLyricsGenerator/src/model_artefacts/wav2vec2_demucs_en_finetuned_model.pt",
    "FINETUNE_STRATEGY": "no_freeze_deepspeed",
    "LR_SCHEDULER": "reduce_on_plateau_schedule"
}


In [77]:
datamodule = SpeechRecognitionData.from_csv("consolidated_file_path",
                                                         "transcription_capitalized",
                                                         train_file=WAV2VEC2_ARGS.TRAIN_FILE_PATH,
                                                         test_file=WAV2VEC2_ARGS.TEST_FILE_PATH,
                                                         batch_size=WAV2VEC2_ARGS.BATCH_SIZE
                                                         )

  exec(code_obj, self.user_global_ns, self.user_ns)


In [78]:
datamodule

<flash.audio.speech_recognition.data.SpeechRecognitionData at 0x7ffe9a5af940>

In [79]:
import flash
import torchaudio
from pytorch_lightning import LightningDataModule

class SpeechDataModule(LightningDataModule):
    def __init__(self, audio_paths, transcript_paths, batch_size=16):
        super().__init__()
        self.batch_size = 2

    def prepare_data(self):
        # Download and extract the audio files and transcripts, if necessary.
        pass

    def setup(self, stage=None):
        # Create the dataset
        # datamodule = 
        self.dataset = SpeechRecognitionData.from_csv("consolidated_file_path",
                                                         "transcription_capitalized",
                                                         train_file=WAV2VEC2_ARGS.TRAIN_FILE_PATH,
                                                         test_file=WAV2VEC2_ARGS.TEST_FILE_PATH,
                                                         batch_size=WAV2VEC2_ARGS.BATCH_SIZE
                                                         )

        # Create the dataloader
        self.train_dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        )
        self.val_dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
        )

    def train_dataloader(self):
        return self.train_dataloader

    def val_dataloader(self):
        return self.val_dataloader

In [71]:
SpeechDataModule

__main__.SpeechDataModule

In [84]:
import torch
from torch import nn
from pytorch_lightning import LightningModule, Trainer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Model, AutoModelForSeq2SeqLM

class Wav2SeqModel(LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.wav2vec2 = Wav2Vec2ForCTC.from_pretrained(hparams.wav2vec2_model)
        self.seq2seq = AutoModelForSeq2SeqLM.from_pretrained(hparams.lm_model)

    def forward(self, audio):
        x = self.wav2vec2(audio)
        logits = self.seq2seq(x.view(x.size(0), -1))
        return logits

    def training_step(self, batch):
        audio, labels = batch
        logits = self(audio)
        loss = nn.CTCLoss(blank_id=0).forward(logits, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch):
        audio, labels = batch
        logits = self(audio)
        loss = nn.CTCLoss(blank_id=0).forward(logits, labels)
        self.log('val_loss', loss, on_step=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=hparams.learning_rate)




In [85]:
import pytorch_lightning as pl

In [87]:
# class DALIDataset(pl.LightningDataModule):
#     def __init__(self, batch_size: int = 4,
#                   train_path :Optional[str] = None,
#                     validation_path: Optional[str] = None,
#                       model_backbone: pl.LightningModule = None,
#                       args: TrainingArgs = WAV2VEC2_ARGS
#                       ):
        
#         super().__init__()
#         self.train_path = train_path if validation_path is not None else args.TRAIN_FILE_PATH
#         self.validation_path = validation_path if validation_path is not None else args.TEST_FILE_PATH
#         self.model_backbone = model_backbone if model_backbone is not None else args.MODEL_BACKBONE

#         def prepare_data(self):
#             pass
        
#         def setup(self):
#             train_df = pd.read_csv(WAV2VEC2_ARGS.TRAIN_FILE_PATH) 
#             validation_df = pd.read_csv(WAV2VEC2_ARGS.TEST_FILE_PATH)
#             songs_metadata = pd.concat([train_df,validation_df], ignore_index = True)
#             audio_dataset = Dataset.from_dict(
#                 {"audio": list(songs_metadata["file_name"]),
#                  "transcription": list(songs_metadata["transcription"])}).cast_column("audio", Audio(sampling_rate=16_000))
#             audio_dataset["transcription"] = audio_dataset["transcription"] = re.sub(WAV2VEC2_ARGS.CHARS_TO_REMOVE_FROM_TRANSCRIPTS, '', audio_dataset["transcription"]).upper()
#             audio_dataset = audio_dataset.train_test_split(test_size=0.2, shuffle=True)

In [88]:
hparams = argparse.Namespace()
hparams.wav2vec2_model = 'facebook/wav2vec2-base'
hparams.lm_model = 'facebook/bart-large'
hparams.vocab_size = 10000
hparams.learning_rate = 0.001

model = Wav2SeqModel(hparams)
trainer = Trainer(max_epochs=1,devices=1, accelerator="gpu")
trainer.fit(model,datamodule.train_dataloader, datamodule.val_dataloader)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type                         | Params
----------------------------------------------------------
0 | wav2vec2 | Wav2Vec2ForCTC               | 94.4 M
1 | seq2seq  | BartForConditionalGeneration | 406 M 
----------------------------------------------------------
500 M     Trainable params
0         Non-trainable params
500 M     Total params
2,002.751 Total estimated model params size (MB)
SLURM auto-requeueing enabled. Setting signal handlers.


Sanity Checking: 0it [00:00, ?it/s]

TypeError: 'method' object is not iterable

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule

class SpeechDataset(Dataset):
    def __init__(self, audio_paths, transcript_paths):
        self.audio_paths = audio_paths
        self.transcript_paths = transcript_paths

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, index):
        audio_path = self.audio_paths[index]
        transcript_path = self.transcript_paths[index]

        audio = torch.load(audio_path)
        transcript = torch.load(transcript_path)

        return audio, transcript

class SpeechDataModule(LightningDataModule):
    def __init__(self, audio_paths, transcript_paths, batch_size=16):
        super().__init__()
        self.audio_paths = audio_paths
        self.transcript_paths = transcript_paths
        self.batch_size = batch_size

    def prepare_data(self):
        # Download and extract the dataset
        pass

    def setup(self, stage=None):
        # Create the dataset
        self.dataset = SpeechDataset(self.audio_paths, self.transcript_paths)

        # Create the dataloader
        self.train_dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        )
        self.val_dataloader = DataLoader(
            self.dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
        )

    def train_dataloader(self):
        return self.train_dataloader

    def val_dataloader(self):
        return self.val_dataloader