# Imports

In [1]:
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torch.nn as nn
import nemo.collections.asr as nemo_asr
import json
from torch.utils.data import Dataset,DataLoader
from typing import Optional

      def forward(
    
      def backward(ctx, grad_output):
    
      def forward(
    
      def backward(ctx, grad_output):
    


# Utils

## Content Embeddings

In [2]:
def extract_content_embeddings(mel_spectrogram_tensor, encoder,device):
    mel_tensor = mel_spectrogram_tensor.to(device)
    lengths = torch.full((mel_tensor.shape[0],), mel_tensor.shape[2], dtype=torch.int32).to(device)

    # # Ensure the input tensor has the correct shape
    # if len(mel_tensor.shape) == 2:
    #     # Adding batch dimension (1, mel_channels, time_frames)
    #     mel_tensor = mel_tensor.unsqueeze(0)

    encoder.eval()
    with torch.no_grad():
        # content_embeddings = encoder(mel_tensor)
        content_embeddings, _ = encoder(
            audio_signal=mel_tensor,
            length=lengths
        )        

    return content_embeddings 

## Duration Augmented Content Embeddings 

In [3]:
def cosine_similarity(v1, v2):
    return F.cosine_similarity(v1.unsqueeze(0), v2.unsqueeze(0), dim=-1).item()

def group_similar_vectors(vectors, threshold, vector_duration):
    grouped_vectors = []
    durations = []

    current_group = [vectors[:, 0]]  # Start with the first vector (slice across batch dimension)
    current_duration = vector_duration

    for i in range(1, vectors.shape[-1]):
        sim = cosine_similarity(vectors[:, i], vectors[:, i - 1])
        # print(sim)
        # If cosine similarity is above threshold, group the vectors
        if sim > threshold:
            # print("Exceeded Threshold")
            current_group.append(vectors[:, i])
            current_duration += vector_duration
        else:
            # Compute the average of the current group and save it
            averaged_vector = torch.mean(torch.stack(current_group, dim=-1), dim=-1)
            grouped_vectors.append(averaged_vector)
            durations.append(current_duration)

            # Start a new group
            current_group = [vectors[:, i]]
            current_duration = vector_duration

    # Append the last group
    if current_group:
        averaged_vector = torch.mean(torch.stack(current_group, dim=-1), dim=-1)
        grouped_vectors.append(averaged_vector)
        durations.append(current_duration)

    return torch.stack(grouped_vectors, dim=-1), durations


def duration_augmented_representation(content_embeddings, T=0.925, vector_duration=46.44):
    # Remove the batch dimension for processing
    content_vectors = content_embeddings.squeeze(0)  # Shape: [256, num_content_vectors]

    grouped_vectors, new_durations = group_similar_vectors(content_vectors, T, vector_duration)

    # Convert durations to seconds if needed
    new_durations_in_seconds = [d / 1000 for d in new_durations]

    # Add the batch dimension back to the output
    grouped_vectors = grouped_vectors.unsqueeze(0)  # Shape: [1, 256, num_grouped_vectors]

    return grouped_vectors, new_durations_in_seconds

## Pitch Contour Feature Extraction

In [4]:
def extract_f0(audio_path, sr=22050):
    """
    Extract the fundamental frequency (F0) contour using PYin algorithm.
    
    Args:
    - audio_path (str): Path to the audio file.
    - sr (int): Sampling rate. Default is 22050.
    
    Returns:
    - f0_tensor (torch.Tensor): Extracted F0 contour as a PyTorch tensor.
    - pitch_tensor (torch.Tensor): Extracted pitch contour as a PyTorch tensor.
    """
    # Load the audio file
    y, sr = librosa.load(audio_path, sr=sr)

    # Compute the F0 (fundamental frequency) using PYin
    f0_contour, voiced_flag, voiced_probs = librosa.pyin(
        y, 
        fmin=librosa.note_to_hz('C2'),  # Minimum pitch (in Hz)
        fmax=librosa.note_to_hz('C7')   # Maximum pitch (in Hz)
    )
    
    # Replace unvoiced frames (None) with zeros or some placeholder
    f0_contour = np.nan_to_num(f0_contour)
    # print("reahced here")
    # Convert to PyTorch tensor
    f0_tensor = torch.tensor(f0_contour, dtype=torch.float32)
    # print("reahced here 1")
    # Optionally, if you want to extract a pitch contour (e.g., using voiced probabilities)
    # pitch_tensor = torch.tensor(voiced_probs, dtype=torch.float32)  # Using voiced probabilities as pitch contour

    return f0_tensor


# Function to normalize the F0 contour
def normalize_f0(f0_contour):

    # Filter out unvoiced (NaN or 0 values)
    # voiced_f0 = f0_contour[f0_contour > 0]  # Exclude unvoiced frames
    
    # Compute mean and standard deviation only for voiced frames
    mean_f0 = np.mean(f0_contour)
    std_f0 = np.std(f0_contour)
    
    # Normalize F0 contour (keep NaNs for unvoiced frames)
    normalized_f0 = (f0_contour - mean_f0) / std_f0

    # Optionally: Clip or apply ReLU to remove negative values (if required)
    # normalized_f0 = np.clip(normalized_f0, 0, None)
    
    return normalized_f0


## Speaker Embedding 

In [None]:
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large')

# Dataset

In [13]:
class SynthesizerNoTransformDataset(Dataset):
    def __init__(self, data,root_dir,content_encoder,speaker_encoder, device):
        # Load the JSON file
        # with open(json_file, 'r') as f:
        self.data = data
        self.content_encoder = content_encoder
        self.speaker_encoder = speaker_encoder
        self.device = device 
        self.root_dir = root_dir           
    
    def __len__(self):
        # Return the number of entries in the dataset
        return len(self.data)
    
    def extract_content_embeddings(self, mel_spectrogram_tensor):
        mel_tensor = mel_spectrogram_tensor.unsqueeze(0).to(self.device)  # Add batch dimension
        lengths = torch.full((mel_tensor.shape[0],), mel_tensor.shape[2], dtype=torch.int32).to(self.device)

        self.content_encoder.eval()
        with torch.no_grad():
            content_embeddings, _ = self.content_encoder(
                audio_signal=mel_tensor,
                length=lengths
            )

        return content_embeddings.cpu()   
    
    def extract_speaker_embeddings(self, audio_path):
        self.speaker_encoder.eval()
        with torch.no_grad():
            speaker_embedding = self.speaker_encoder.get_embedding(audio_path).cpu()
        return speaker_embedding      
    
    def __getitem__(self, idx):
        mel_path =  self.root_dir + "/mel_spectrograms/" + self.data[idx]['mel_filepath']
        audio_path = self.root_dir + "/audios/" + self.data[idx]['audio_filepath']
        # print(audio_path)
        mel_spectrogram = torch.from_numpy(np.load(mel_path))
        content_embeddings = self.extract_content_embeddings(mel_spectrogram) 
        speaker_embeddings = self.extract_speaker_embeddings(audio_path)  
        # duration_augmented_content_embeddings, durations = self.duration_augmented_representation(content_embeddings)
        # print(duration_augmented_content_embeddings.shape)
        # print("Reached herte amin")
        # normalized_pitch_contour =self.normalize_f0(audio_path)
        # print(normalized_pitch_contour.shape)
        
        return content_embeddings.squeeze(),speaker_embeddings.squeeze(),audio_path
    # ,duration_augmented_content_embeddings
    # ,normalized_pitch_contour

In [7]:
# def collate_fn(batch):
#     """
#     Pads the content embeddings, speaker embeddings, and duration-augmented content embeddings in the batch to the same length.
    
#     Args:
#     - batch (list of tuples): Each tuple contains content_embeddings, speaker_embeddings, duration_augmented_content_embeddings.
    
#     Returns:
#     - padded_content_embeddings (torch.Tensor): Padded content embeddings.
#     - speaker_batch (torch.Tensor): Speaker embeddings.
#     - padded_duration_augmented_embeddings (torch.Tensor): Padded duration-augmented content embeddings.
#     """
#     # Unzip the batch into separate components
#     content_embeddings, speaker_embeddings,audio_path = zip(*batch)
    
#     # Find the maximum length in the batch for padding
#     max_len_content = max([embedding.size(-1) for embedding in content_embeddings])
#     # max_len_duration = max([embedding.size(-1) for embedding in duration_augmented_content_embeddings])
    
#     # Pad the content embeddings to the maximum length
#     padded_content_embeddings = [F.pad(embedding, (0, max_len_content - embedding.size(-1))) for embedding in content_embeddings]
    
#     # Pad the duration-augmented content embeddings to the maximum length
#     # padded_duration_augmented_embeddings = [F.pad(embedding, (0, max_len_duration - embedding.size(-1))) for embedding in duration_augmented_content_embeddings]

#     # Stack the speaker embeddings (assuming they are already of fixed size)
#     speaker_batch = torch.stack(speaker_embeddings)
#     # duration_augmented_content_batch = torch.stack(duration_augmented_content_embeddings)
#     # normalized_pitch_contour_batch = torch.stack(normalized_pitch_contour)
    
#     # Stack the padded embeddings
#     padded_content_embeddings = torch.stack(padded_content_embeddings)
#     # f0_tensors_padded = pad_sequence(normalized_pitch_contour, batch_first=True, padding_value=0.0)
#     # padded_duration_augmented_embeddings = torch.stack(padded_duration_augmented_embeddings)

#     return padded_content_embeddings, speaker_batch,audio_path

In [14]:
def collate_fn(batch):
    """
    Pads the content embeddings, speaker embeddings, and duration-augmented content embeddings in the batch to the same length.
    
    Args:
    - batch (list of tuples): Each tuple contains content_embeddings, speaker_embeddings, audio_path.
    
    Returns:
    - padded_content_embeddings (torch.Tensor): Padded content embeddings.
    - speaker_batch (torch.Tensor): Speaker embeddings.
    - audio_path (list): Audio paths from the batch.
    - padding_mask (torch.Tensor): Mask indicating valid entries.
    """
    # Unzip the batch into separate components
    content_embeddings, speaker_embeddings, audio_path = zip(*batch)
    
    # Find the maximum length in the batch for padding
    max_len_content = max([embedding.size(-1) for embedding in content_embeddings])
    
    # Pad the content embeddings to the maximum length
    padded_content_embeddings = [F.pad(embedding, (0, max_len_content - embedding.size(-1))) for embedding in content_embeddings]
    
    # Stack the speaker embeddings (assuming they are already of fixed size)
    speaker_batch = torch.stack(speaker_embeddings)
    
    # Stack the padded content embeddings
    padded_content_embeddings = torch.stack(padded_content_embeddings)

    # Create a padding mask: 1 for valid entries, 0 for padding
    padding_mask = (padded_content_embeddings != 0).float()  # Assuming 0 is the padding value

    return padded_content_embeddings, speaker_batch, audio_path, padding_mask

In [16]:
data_path = "/media/keagan/hdd/project_data/SelfVC/data/val_filelist_clean.json"
with open(data_path, 'r') as file:
    data = json.load(file)



In [6]:
class ConformerEncoder256(nn.Module):
    def __init__(self, original_encoder):
        super(ConformerEncoder256, self).__init__()
        self.encoder = original_encoder
        # Add a 1D convolution with stride to downsample to 256
        self.downsample = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=1, stride=2) # Example

    def forward(self, audio_signal, length):
        # Pass through original encoder
        embeddings, lengths = self.encoder(audio_signal=audio_signal, length=length)
        # Reduce the embedding size
        embeddings = self.downsample(embeddings)
        return embeddings, lengths

In [7]:
conformer_encoder_256 = torch.load('/home/keagan/Documents/projects/SelfVC/models/conformer_encoder_v2.pth')
speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large')

      conformer_encoder_256 = torch.load('/home/keagan/Documents/projects/SelfVC/models/conformer_encoder_v2.pth')
    


[NeMo I 2024-10-21 11:52:21 cloud:58] Found existing object /home/keagan/.cache/torch/NeMo/NeMo_1.23.0/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.
[NeMo I 2024-10-21 11:52:21 cloud:64] Re-using file from: /home/keagan/.cache/torch/NeMo/NeMo_1.23.0/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo
[NeMo I 2024-10-21 11:52:21 common:924] Instantiating model from pre-trained checkpoint


[NeMo W 2024-10-21 11:52:21 modelPT:165] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json
    sample_rate: 16000
    labels: null
    batch_size: 64
    shuffle: true
    is_tarred: false
    tarred_audio_filepaths: null
    tarred_shard_strategy: scatter
    augmentor:
      noise:
        manifest_path: /manifests/noise/rir_noise_manifest.json
        prob: 0.5
        min_snr_db: 0
        max_snr_db: 15
      speed:
        prob: 0.5
        sr: 16000
        resample_type: kaiser_fast
        min_speed_rate: 0.95
        max_speed_rate: 1.05
    num_workers: 15
    pin_memory: true
    
[NeMo W 2024-10-21 11:52:21 modelPT:172] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method 

[NeMo I 2024-10-21 11:52:21 features:289] PADDING: 16


      return torch.load(model_weights, map_location='cpu')
    


[NeMo I 2024-10-21 11:52:22 save_restore_connector:249] Model EncDecSpeakerLabelModel was successfully restored from /home/keagan/.cache/torch/NeMo/NeMo_1.23.0/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.


In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
root_dir = "/media/keagan/hdd/project_data/SelfVC/data"
dataset = SynthesizerNoTransformDataset(data,root_dir,conformer_encoder_256,speaker_model,device)

In [18]:
batch_size = 4  # Adjust as needed
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn)

In [19]:
for i, (content_embeddings,speaker_embeddings,audio_path,padding_mask) in enumerate(dataloader):
    # duration_augmented_content_embeddings
    # normalized_pitch_contour
    print(f"Batch {i+1}:")
    print(f"content_embeddings Shape: {content_embeddings.shape}")
    print(f"speaker_embeddings Shape: {speaker_embeddings.shape}")
    print(f"audio_path: {audio_path}")
    print(f"padding_mask: {padding_mask.shape}")
    # print(f"duration_augmented_content_embeddings Shape: {duration_augmented_content_embeddings.shape}")
    # print(f"normalized_pitch_contour Shape: {normalized_pitch_contour.shape}")
    # Optionally, break after first batch for testing
    if i == 0:
        break

      with torch.cuda.amp.autocast(enabled=False):
    
      with torch.cuda.amp.autocast(enabled=False):
    


Batch 1:
content_embeddings Shape: torch.Size([4, 256, 199])
speaker_embeddings Shape: torch.Size([4, 192])
audio_path: ('/media/keagan/hdd/project_data/SelfVC/data/audios/2787_157400_000050_000000.wav', '/media/keagan/hdd/project_data/SelfVC/data/audios/6918_61317_000026_000000.wav', '/media/keagan/hdd/project_data/SelfVC/data/audios/2960_155152_000016_000003.wav', '/media/keagan/hdd/project_data/SelfVC/data/audios/3307_145145_000050_000005.wav')
padding_mask: torch.Size([4, 256, 199])


In [41]:
padding_mask.shape

torch.Size([4, 256, 199])

In [100]:
f0_tensor = extract_f0("/home/keagan/Documents/projects/SelfVC/data/audios/4957_36386_000058_000002.wav", sr=22050)

reahced here
reahced here 1


In [101]:
f0_tensor

tensor([  0.0000,   0.0000,   0.0000,   0.0000,   0.0000, 174.6141, 169.6432,
        167.6947, 165.7685, 161.0494, 158.2827, 169.6432, 220.0000, 245.5194,
        261.6256, 278.7883, 285.3047, 283.6615, 257.1310, 216.2205, 258.6205,
        245.5194, 233.0819, 223.8455, 220.0000, 220.0000, 222.5563, 225.1423,
        248.3722, 246.9417, 248.3722, 252.7136, 260.1187, 264.6655, 260.1187,
          0.0000,   0.0000, 220.0000, 211.2820, 204.0850, 194.8689, 190.4180,
        191.5211, 197.1331, 197.1331, 186.0689, 211.2820, 210.0652, 204.0850,
        202.9096, 202.9096, 207.6523, 207.6523,   0.0000,   0.0000, 248.3722,
        249.8110, 252.7136, 261.6256, 269.2918, 277.1826, 283.6615,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
          0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.00

In [97]:
def normalize_f0(audio_path, sr=22050):
    """
    Normalize the F0 contour.
    
    Args:
    - f0_tensor (torch.Tensor): F0 contour as a PyTorch tensor.
    
    Returns:
    - normalized_f0 (torch.Tensor): Normalized F0 contour.
    """
    f0_tensor = extract_f0(audio_path, sr=22050)
    # Filter out unvoiced (NaN or 0 values)
    # voiced_f0 = f0_tensor[f0_tensor > 0]  # Exclude unvoiced frames
    
    # Compute mean and standard deviation only for voiced frames
    mean_f0 = f0_tensor.mean()
    std_f0 = f0_tensor.std()
    
    # Normalize F0 contour (keep NaNs for unvoiced frames)
    normalized_f0 = (f0_tensor - mean_f0) / std_f0

    return normalized_f0     

In [99]:
f0_tensor_norm = normalize_f0("/home/keagan/Documents/projects/SelfVC/data/audios/4957_36386_000058_000002.wav", sr=22050)

reahced here
reahced here 1


In [102]:
f0_tensor_norm

tensor([-1.1538, -1.1538, -1.1538, -1.1538, -1.1538,  0.3828,  0.3390,  0.3219,
         0.3049,  0.2634,  0.2391,  0.3390,  0.7822,  1.0067,  1.1485,  1.2995,
         1.3568,  1.3424,  1.1089,  0.7489,  1.1220,  1.0067,  0.8973,  0.8160,
         0.7822,  0.7822,  0.8047,  0.8274,  1.0318,  1.0193,  1.0318,  1.0700,
         1.1352,  1.1752,  1.1352, -1.1538, -1.1538,  0.7822,  0.7055,  0.6421,
         0.5610,  0.5218,  0.5316,  0.5809,  0.5809,  0.4836,  0.7055,  0.6947,
         0.6421,  0.6318,  0.6318,  0.6735,  0.6735, -1.1538, -1.1538,  1.0318,
         1.0445,  1.0700,  1.1485,  1.2159,  1.2854,  1.3424, -1.1538, -1.1538,
        -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538,
        -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538,
        -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538, -1.1538,
        -1.1538, -1.1538, -1.1538, -1.1538,  0.6841,  0.8855,  1.0445,  1.1485,
         1.1752,  1.1887,  1.2023, -1.15

## MODEL

In [89]:
# class PositionwiseConvFF(nn.Module):
#     def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
#         super(PositionwiseConvFF, self).__init__()
#         self.d_model = d_model
#         self.d_inner = d_inner
#         self.dropout = dropout

#         self.CoreNet = nn.Sequential(
#             nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
#             nn.ReLU(),
#             nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
#             nn.Dropout(dropout),
#         )
#         self.layer_norm = nn.LayerNorm(d_model)
#         self.pre_lnorm = pre_lnorm

#     def forward(self, inp):
#         if self.pre_lnorm:
#             core_out = inp.transpose(1, 2)
#             core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
#             core_out = core_out.transpose(1, 2)
#             output = core_out + inp
#         else:
#             core_out = inp.transpose(1, 2)
#             core_out = self.CoreNet(core_out)
#             core_out = core_out.transpose(1, 2)
#             output = self.layer_norm(inp + core_out).to(inp.dtype)
#         return output


# class MultiHeadAttn(nn.Module):
#     def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1, pre_lnorm=False):
#         super(MultiHeadAttn, self).__init__()
#         self.n_head = n_head
#         self.d_model = d_model
#         self.d_head = d_head
#         self.scale = 1 / (d_head ** 0.5)
#         self.pre_lnorm = pre_lnorm

#         self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
#         # self.qkv_net = nn.Linear(d_model, n_head * d_head)
#         self.drop = nn.Dropout(dropout)
#         self.dropatt = nn.Dropout(dropatt)
#         self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
#         self.layer_norm = nn.LayerNorm(d_model)

#     def forward(self, inp, attn_mask=None):
#         # print(inp.shape)
#         residual = inp
#         if self.pre_lnorm:
#             inp = self.layer_norm(inp)

#         n_head, d_head = self.n_head, self.d_head
#         # head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
#         #b_size,512,time_steps
#         inp_permuted = inp.permute(0, 2, 1)
#         #b_size,time_steps,512
#         if not None:
#             attn_mask = attn_mask.permute(0, 2, 1)
#             #b_size,time_steps,512
#         # print(inp_permuted.shape)
#         qkv_out = self.qkv_net(inp_permuted)
#         #b_size,time_steps,192
#         # print(qkv_out.shape)
#         head_q, head_k, head_v = torch.chunk(qkv_out, 3, dim=2)
#         #b_size,time_steps,64,b_size,time_steps,64,b_size,time_steps,64
#         # print(head_q.shape,head_k.shape,head_v.shape)
#         head_q = head_q.view(inp_permuted.size(0), inp_permuted.size(1), n_head, d_head)
#         head_k = head_k.view(inp_permuted.size(0), inp_permuted.size(1), n_head, d_head)
#         head_v = head_v.view(inp_permuted.size(0), inp_permuted.size(1), n_head, d_head)
#         # print(head_q.shape,head_k.shape,head_v.shape)
#         q = head_q.permute(2, 0, 1, 3).reshape(-1, inp_permuted.size(1), d_head)
#         k = head_k.permute(2, 0, 1, 3).reshape(-1, inp_permuted.size(1), d_head)
#         v = head_v.permute(2, 0, 1, 3).reshape(-1, inp_permuted.size(1), d_head)
#         # print(q.shape,k.shape,v.shape)
#         #4,145,64 * 4,64,145 = 4,145,145 
#         attn_score = torch.bmm(q, k.transpose(1, 2))
#         print(attn_score.shape)
#         attn_score.mul_(self.scale)
#         print(attn_score.shape)

#         if attn_mask is not None:
#             print("reached MultiHeadAttn 1")
#             # print(attn_mask.shape)
#             # print(attn_mask.size(2))
#             # attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
#             attn_mask = attn_mask.repeat(n_head, attn_mask.size(1), 1)
#             print(attn_mask.shape)
#             print(attn_score.shape)
#             attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))
            

#         attn_prob = torch.softmax(attn_score, dim=2)
#         attn_prob = self.dropatt(attn_prob)
#         attn_vec = torch.bmm(attn_prob, v)

#         attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
#         attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
#             inp.size(0), inp.size(1), n_head * d_head)

#         attn_out = self.o_net(attn_vec)
#         attn_out = self.drop(attn_out)

#         if self.pre_lnorm:
#             output = residual + attn_out
#         else:
#             output = self.layer_norm(residual + attn_out)

#         return output


# class TransformerLayer(nn.Module):
#     def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout, **kwargs):
#         super(TransformerLayer, self).__init__()
#         self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
#         self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout, pre_lnorm=kwargs.get('pre_lnorm'))

#     def forward(self, dec_inp, mask=None):
#         # print(dec_inp.shape)
#         # print(mask.squeeze(2).shape)
#         output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
#         output *= mask
#         output = self.pos_ff(output)
#         output *= mask
#         return output


# def mask_from_lens(lens, max_len: Optional[int] = None):
#     if max_len is None:
#         max_len = lens.max()
#     ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
#     mask = torch.lt(ids, lens.unsqueeze(1))
#     return mask


# class FFTransformer(nn.Module):
#     def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size, dropout, dropatt, dropemb=0.0, pre_lnorm=False):
#         super(FFTransformer, self).__init__()
#         self.d_model = d_model
#         self.n_head = n_head
#         self.d_head = d_head

#         self.drop = nn.Dropout(dropemb)
#         self.layers = nn.ModuleList()

#         for _ in range(n_layer):
#             self.layers.append(
#                 TransformerLayer(
#                     n_head, d_model, d_head, d_inner, kernel_size, dropout,
#                     dropatt=dropatt, pre_lnorm=pre_lnorm
#                 )
#             )

#     def forward(self, dec_inp, seq_lens=None, conditioning=0):
#         mask = (dec_inp != 0).unsqueeze(2)
#         out = self.drop(dec_inp + conditioning)

#         for layer in self.layers:
#             out = layer(out, mask=mask)

#         out = self.drop(out)
#         return out

In [77]:
class PositionwiseConvFF(nn.Module):
    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
        super(PositionwiseConvFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
            nn.ReLU(),
            # nn.Dropout(dropout),  # worse convergence
            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        return self._forward(inp)

    def _forward(self, inp):
        if self.pre_lnorm:
            # layer normalization + positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
            core_out = core_out.transpose(1, 2)

            # residual connection
            output = core_out + inp
        else:
            # positionwise feed-forward
            core_out = inp.transpose(1, 2)
            core_out = self.CoreNet(core_out)
            core_out = core_out.transpose(1, 2)

            # residual connection + layer normalization
            output = self.layer_norm(inp + core_out).to(inp.dtype)

        return output

class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0.1,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.scale = 1 / (d_head ** 0.5)
        self.pre_lnorm = pre_lnorm

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head)
        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, inp, attn_mask=None):
        return self._forward(inp, attn_mask)

    def _forward(self, inp, attn_mask=None):
        residual = inp

        if self.pre_lnorm:
            # layer normalization
            inp = self.layer_norm(inp)

        n_head, d_head = self.n_head, self.d_head

        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
        head_q = head_q.view(inp.size(0), inp.size(1), n_head, d_head)
        head_k = head_k.view(inp.size(0), inp.size(1), n_head, d_head)
        head_v = head_v.view(inp.size(0), inp.size(1), n_head, d_head)

        q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)
        v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), d_head)

        attn_score = torch.bmm(q, k.transpose(1, 2))
        attn_score.mul_(self.scale)
        print(attn_score.shape)

        if attn_mask is not None:
            attn_score = attn_score.masked_fill(attn_mask == 0, float('-inf'))        

        # if attn_mask is not None:
        #     attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
        #     attn_mask = attn_mask.repeat(n_head, attn_mask.size(2), 1)
        #     print(attn_mask.shape)
        #     attn_score.masked_fill_(attn_mask.to(torch.bool), -float('inf'))

        attn_prob = F.softmax(attn_score, dim=2)
        attn_prob = self.dropatt(attn_prob)
        attn_vec = torch.bmm(attn_prob, v)

        attn_vec = attn_vec.view(n_head, inp.size(0), inp.size(1), d_head)
        attn_vec = attn_vec.permute(1, 2, 0, 3).contiguous().view(
            inp.size(0), inp.size(1), n_head * d_head)

        # linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            # residual connection
            output = residual + attn_out
        else:
            # residual connection + layer normalization
            output = self.layer_norm(residual + attn_out)

        output = output.to(attn_out.dtype)

        return output
    
class TransformerLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, kernel_size, dropout,
                 **kwargs):
        super(TransformerLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseConvFF(d_model, d_inner, kernel_size, dropout,
                                         pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, mask=None):
        
        output = self.dec_attn(dec_inp, attn_mask=mask)
        print(output.shape)
        print(mask.shape)
        # output = self.dec_attn(dec_inp)
        # output *= mask
        output = self.pos_ff(output)
        # output *= mask
        return output
    
class FFTransformer(nn.Module):
    def __init__(self, n_layer, n_head, d_model, d_head, d_inner, kernel_size,
                 dropout, dropatt, dropemb=0.0, embed_input=True,
                 n_embed=None, d_embed=None, pre_lnorm=False):
        super(FFTransformer, self).__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head
        # self.padding_idx = padding_idx

        # if embed_input:
        #     self.word_emb = nn.Embedding(n_embed, d_embed or d_model,
        #                                  padding_idx=self.padding_idx)
        # else:
        #     self.word_emb = None

        # self.pos_emb = PositionalEmbedding(self.d_model)
        self.drop = nn.Dropout(dropemb)
        self.layers = nn.ModuleList()

        for _ in range(n_layer):
            self.layers.append(
                TransformerLayer(
                    n_head, d_model, d_head, d_inner, kernel_size, dropout,
                    dropatt=dropatt, pre_lnorm=pre_lnorm)
            )

    def forward(self, dec_inp, mask=None):
        # if self.word_emb is None:
        #     inp = dec_inp
        #     mask = mask_from_lens(seq_lens).unsqueeze(2)
        # else:
        #     inp = self.word_emb(dec_inp)
        #     # [bsz x L x 1]
        #     mask = (dec_inp != self.padding_idx).unsqueeze(2)

        # pos_seq = torch.arange(inp.size(1), device=inp.device).to(inp.dtype)
        # pos_emb = self.pos_emb(pos_seq) * mask

        # out = self.drop(inp + pos_emb + conditioning)
        # mask = mask_from_lens(seq_lens).unsqueeze(2)
        # mask = (dec_inp != self.padding_idx).unsqueeze(2)
        # print(mask.shape)
        out = self.drop(dec_inp )

        for layer in self.layers:
            out = layer(out, mask=mask)
            # out = layer(out)

        out = self.drop(out)
        return out

In [24]:
class SynthesizerModel(nn.Module):
    def __init__(self, content_dim, speaker_dim, projected_dim):
        super(SynthesizerModel, self).__init__()
        
        # Linear layer to project speaker embeddings to the same dimension as content embeddings
        self.speaker_projection = nn.Linear(speaker_dim, projected_dim)
        
    
    def forward(self, content_embeddings, speaker_embeddings):
        """
        Inputs:
        - content_embeddings: Tensor of shape [batch_size, content_dim, time_steps]
        - speaker_embeddings: Tensor of shape [batch_size, speaker_dim]

        Returns:
        - Processed output
        """
        # Project the speaker embeddings from [batch_size, speaker_dim] to [batch_size, projected_dim]
        projected_speaker_embeddings = self.speaker_projection(speaker_embeddings)  # Shape: [batch_size, projected_dim]

        # Expand the speaker embeddings across the time dimension to match the content embeddings
        # Resulting shape: [batch_size, projected_dim, time_steps]
        projected_speaker_embeddings = projected_speaker_embeddings.unsqueeze(2).expand(-1, -1, content_embeddings.size(2))
        
        # Concatenate the content and speaker embeddings along the feature dimension
        combined_embeddings = torch.cat([content_embeddings, projected_speaker_embeddings], dim=1)

        
        return combined_embeddings

In [22]:
batch_size = 4
content_dim = 256
speaker_dim = 192
time_steps = 145
projected_dim = 256  # Project speaker embeddings to this dimension

# Create a dummy batch of inputs
# content_embeddings = torch.randn(batch_size, content_dim, time_steps)
# speaker_embeddings = torch.randn(batch_size, speaker_dim)

# Initialize the model
model = SynthesizerModel(content_dim=content_dim, speaker_dim=speaker_dim, projected_dim=projected_dim)

# Forward pass
output = model(content_embeddings, speaker_embeddings)

print(f"Output shape: {output.shape}") 

Output shape: torch.Size([4, 512, 199])


In [78]:
custom_encoder = FFTransformer(
    n_layer=6, n_head=1,
    d_model=512,
    d_head=64,
    d_inner=1536,
    kernel_size=3,
    dropout=0.1,
    dropatt=0.1,
    dropemb=0.0,)

In [80]:
output_permute = output.permute(0,2,1)
padding_mask_permute = padding_mask.permute(0,2,1)
mask_reduced = padding_mask_permute[:, :, 0]
mask_reduced = mask_reduced.unsqueeze(1)
encoded_result = custom_encoder(output_permute,mask_reduced)

torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])
torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])
torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])
torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])
torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])
torch.Size([4, 199, 199])
torch.Size([4, 199, 512])
torch.Size([4, 1, 199])


In [81]:
encoded_result.shape

torch.Size([4, 199, 512])

In [49]:
padding_mask_permute.shape

torch.Size([4, 199, 256])

In [64]:
mask_reduced = padding_mask_permute[:, :, 0]
mask_reduced = mask_reduced.unsqueeze(1)


In [63]:
mask_reduced.shape

torch.Size([4, 1, 199])

In [56]:
output_permute.shape

torch.Size([4, 199, 512])

In [55]:
input_dim =512
query_fc = nn.Linear(input_dim, input_dim)
key_fc = nn.Linear(input_dim, input_dim)

In [57]:
queries = query_fc(output_permute)
keys = key_fc(output_permute)

In [58]:
queries.shape

torch.Size([4, 199, 512])

In [59]:
keys.shape

torch.Size([4, 199, 512])

In [60]:
batch_size, seq_length, _ = output_permute.size()

In [61]:
queries = queries.view(batch_size, seq_length, 1, 64).transpose(1, 2)  # (batch_size, num_heads, seq_length, head_dim)
keys = keys.view(batch_size, seq_length, 1, 64).transpose(1, 2)


RuntimeError: shape '[4, 199, 1, 64]' is invalid for input of size 407552

In [None]:
# 4,1,199,64  4,1,64,199