# In [None]:
# Self-Supervised Learning (SSL) Training with SpecAugment

import os

import librosa
import numpy as np
import pandas as pd
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertForMaskedLM,
    BertTokenizer,
    DataCollatorForLanguageModeling,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForPreTraining,
    Wav2Vec2Model,
    Wav2Vec2Processor,
)
# Private helpers; this is how the Hugging Face Wav2Vec2ForPreTraining example builds its inputs.
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _compute_mask_indices,
    _sample_negative_indices,
)

# Audio Dataset with SpecAugment
class AudioDataset(Dataset):
    """Directory of ``.wav`` files served as wav2vec2 ``input_values``.

    Each item is a 1-D float tensor for one file, loaded at 16 kHz and
    optionally augmented with time/frequency masking first. Files differ in
    length, so batching with a ``DataLoader`` needs a padding ``collate_fn``
    (the default collate cannot stack tensors of different lengths).
    """

    def __init__(self, data_dir, augment=True):
        """Index the ``.wav`` files in *data_dir*.

        Args:
            data_dir: Directory scanned (non-recursively) for names ending
                in ``.wav``.
            augment: If True, ``__getitem__`` applies ``apply_spec_augment``.
        """
        # sorted() makes item indices independent of os.listdir's arbitrary order.
        self.files = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')
        )
        # facebook/wav2vec2-base is an audio-only checkpoint with no tokenizer, so
        # Wav2Vec2Processor (feature extractor + tokenizer) cannot be loaded from
        # it; the feature extractor alone produces input_values.
        self.processor = Wav2Vec2FeatureExtractor.from_pretrained('facebook/wav2vec2-base')
        self.augment = augment

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return the (optionally augmented) ``input_values`` for file *idx*."""
        audio, sr = librosa.load(self.files[idx], sr=16000)  # resampled to 16 kHz
        if self.augment:
            audio = self.apply_spec_augment(audio, sr)
        input_values = self.processor(audio, return_tensors="pt", sampling_rate=16000).input_values
        return input_values.squeeze(0)  # drop the batch dim added for a single clip

    def apply_spec_augment(self, audio, sr):
        """Mask one random time band and one random mel band, then resynthesize.

        Args:
            audio: 1-D waveform array.
            sr: Sample rate of *audio*.

        Returns:
            A waveform with the same number of samples as *audio*.
        """
        # Mask the *power* mel spectrogram. The masks fill bins with 0, which is
        # silence in the power domain (in dB with ref=np.max, 0 would be the
        # loudest level), and mel_to_audio expects power values, not dB.
        mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=128)
        spec = torch.tensor(mel_spec)
        spec = T.TimeMasking(time_mask_param=80)(spec)
        spec = T.FrequencyMasking(freq_mask_param=30)(spec)
        # length= trims/pads the resynthesized signal to the input length.
        return librosa.feature.inverse.mel_to_audio(spec.numpy(), sr=sr, length=len(audio))

# SSL Audio Training

def collate_audio_batch(batch):
    """Stack variable-length 1-D waveforms into one zero-padded batch.

    Args:
        batch: Sequence of 1-D float tensors, one per clip.

    Returns:
        Tensor of shape ``(len(batch), longest_length)``, right-padded with 0.
    """
    # wav2vec2-base's feature extractor returns no attention mask; Hugging Face
    # recommends zero-padding such inputs and passing no attention_mask.
    return torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.0)


def train_ssl_audio():
    """Continue wav2vec2's contrastive pretraining on ``data/processed/*.wav``.

    Each step masks spans of the latent feature sequence. The model is trained
    to identify the true quantized latent for each masked frame among sampled
    negatives; ``Wav2Vec2ForPreTraining`` computes that loss. The encoder
    weights are saved to ``models/ssl_audio_model.pth`` with the same keys as
    ``Wav2Vec2Model.state_dict()``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The pretraining head carries the quantizer/projections the contrastive loss needs.
    model = Wav2Vec2ForPreTraining.from_pretrained('facebook/wav2vec2-base').to(device)
    dataset = AudioDataset("data/processed")
    # Clips differ in length, so the default collate (torch.stack) cannot batch them.
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_audio_batch)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    # wav2vec 2.0 paper masking: 6.5% of frames start a 10-frame span.
    # _compute_mask_indices expresses that rate as mask_prob = 0.065 * 10.
    mask_prob, mask_length = 0.65, 10

    model.train()
    for epoch in range(10):
        total_loss, num_batches = 0.0, 0
        for batch in dataloader:
            batch = batch.to(device)
            batch_size, raw_length = batch.shape
            seq_length = int(model._get_feat_extract_output_lengths(raw_length))
            # NOTE(review): padded frames are not excluded from masking.
            # min_masks=2 ensures every sufficiently long clip contributes masked frames.
            mask_time_indices = _compute_mask_indices(
                shape=(batch_size, seq_length),
                mask_prob=mask_prob,
                mask_length=mask_length,
                min_masks=2,
            )
            sampled_negative_indices = _sample_negative_indices(
                features_shape=(batch_size, seq_length),
                num_negatives=model.config.num_negatives,
                mask_time_indices=mask_time_indices,
            )
            outputs = model(
                batch,
                mask_time_indices=torch.tensor(mask_time_indices, device=device, dtype=torch.long),
                sampled_negative_indices=torch.tensor(
                    sampled_negative_indices, device=device, dtype=torch.long
                ),
            )
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()
        # Average over the epoch; max() guards an empty dataset.
        print(f"Epoch {epoch+1}: Loss = {total_loss / max(num_batches, 1)}")

    os.makedirs("models", exist_ok=True)
    # Save only the encoder so the checkpoint keys match Wav2Vec2Model.
    torch.save(model.wav2vec2.state_dict(), "models/ssl_audio_model.pth")
    print("SSL Audio training completed.")

# Entry point for the audio stage when this file is executed directly (also
# true inside a notebook kernel). It runs before the text-side definitions
# below are executed.
if __name__ == "__main__":
    train_ssl_audio()

# Text Dataset for SSL
class TextDataset(Dataset):
    """Transcripts from the ``transcript`` column of a CSV, tokenized for BERT.

    Each item is an ``(input_ids, attention_mask)`` pair of 1-D tensors with
    at most 512 tokens. Items are not padded, so batching them with a
    ``DataLoader`` needs a padding ``collate_fn`` (the default collate cannot
    stack tensors of different lengths).
    """

    def __init__(self, data_path):
        """Load the non-missing transcripts from *data_path* and the tokenizer."""
        transcripts = pd.read_csv(data_path)["transcript"].dropna()
        # Cells pandas parsed as numbers would otherwise reach the tokenizer as non-strings.
        self.data = transcripts.astype(str).tolist()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize transcript *idx*, truncated to at most 512 tokens."""
        # Padding to the longest sequence of a single string is no padding at
        # all, so this is stated explicitly rather than via padding=True.
        inputs = self.tokenizer(
            self.data[idx], return_tensors="pt", padding=False, truncation=True, max_length=512
        )
        # squeeze(0) drops the batch dimension added for a single string.
        return inputs.input_ids.squeeze(0), inputs.attention_mask.squeeze(0)

# SSL Text Training
def train_ssl_text():
    """Continue BERT masked-language-model pretraining on the transcripts.

    Batches are built by ``DataCollatorForLanguageModeling``. It pads each
    batch to its longest example and selects 15% of the non-special tokens as
    prediction targets, corrupting them in the input. Every other label is
    -100, so the loss ignores it. The model's state dict is saved to
    ``models/ssl_text_model.pth``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BertForMaskedLM.from_pretrained('bert-base-uncased').to(device)
    dataset = TextDataset("data/processed/transcripts.csv")
    mlm_collator = DataCollatorForLanguageModeling(
        tokenizer=dataset.tokenizer, mlm=True, mlm_probability=0.15
    )

    def collate(examples):
        # Dataset items are (input_ids, attention_mask) tensor pairs; the
        # collator expects one mapping per example.
        return mlm_collator([
            {"input_ids": ids.tolist(), "attention_mask": mask.tolist()}
            for ids, mask in examples
        ])

    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

    model.train()
    for epoch in range(10):
        total_loss, num_batches = 0.0, 0
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            # With no selected targets, the mean cross-entropy would be NaN.
            if not (labels != -100).any():
                continue
            # Labels hold original ids only at the corrupted positions. Using the
            # uncorrupted input as labels would reduce MLM to copying the input.
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()
            num_batches += 1
        scheduler.step()
        # Average over the epoch; max() guards an empty dataset.
        print(f"Epoch {epoch+1}: Loss = {total_loss / max(num_batches, 1)}")

    os.makedirs("models", exist_ok=True)
    torch.save(model.state_dict(), "models/ssl_text_model.pth")
    print("SSL Text training completed.")

# Entry point for the text stage when this file is executed directly. It runs
# after the audio-stage guard above has finished.
if __name__ == "__main__":
    train_ssl_text()
