In [1]:
import pandas as pd
import numpy as np
from utils import ReferenceEncoderClassifier, prepare_tensor, _pad_tensor, fixed_timbre_perturb, sliced_timbre_perturb, train, finegrained_timbre_perturb
from samplers import PerfectBatchSampler
from losses import AngleProtoLoss
from reference_encoder import ReferenceEncoder
from torch.utils.data import DataLoader,random_split
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import torch
import os
import sys
sys.path.append('../')
from omegaconf import OmegaConf
from vits_extend.stft import TacotronSTFT
from vits.utils import load_wav_to_torch
from vits.data_utils import load_filepaths
import librosa
import IPython
import torch.multiprocessing as mp
# mp.set_start_method('spawn')


# Making pandas dataframe from wav32 folder

In [2]:
## Resampling validation wavs
from scipy.io import wavfile

def resample_wave(wav_in, wav_out, sample_rate):
    """Resample an audio file and save it as peak-normalised 16-bit PCM.

    Args:
        wav_in: path of the audio file to read.
        wav_out: path of the WAV file to write.
        sample_rate: sampling rate in Hz passed to ``librosa.load`` and
            written into the output file header.
    """
    wav, _ = librosa.load(wav_in, sr=sample_rate)

    # First pass: bring the peak to 0.6. A silent (all-zero) signal has a peak
    # of 0 and an empty one has no peak at all; dividing by it would fill the
    # buffer with NaN or raise, so those signals are left untouched.
    peak = np.abs(wav).max() if wav.size else 0.0
    if peak > 0:
        wav = wav / peak * 0.6

    # Second pass: scale to the int16 range at 60 % of full scale. The 0.01
    # floor keeps the divisor away from zero.
    scaled_peak = np.max(np.abs(wav)) if wav.size else 0.0
    wav = wav / max(0.01, scaled_peak) * 32767 * 0.6
    wavfile.write(wav_out, sample_rate, wav.astype(np.int16))

In [3]:
# Load the validation metadata list (semicolon-separated CSV) and preview it.
# NOTE: `val_df` is rebound further down to a frame built from `wav_val`, so
# this version is only used to obtain the original wav paths.
val_df = pd.read_csv('../../unicamp-aiv2/TTS/data_path/training_lists/mp_styles_val.csv', sep=';')
val_df.head()

Unnamed: 0,phonetic_transcription,wav_path,speaker,style
0,#p kk 'aa zz uc nn 'an wn tt 'en nh ac ss 'uu ...,/l/disk1/awstebas/data/TTS/speaker-adriana/eps...,adriana,neutro
1,#p uc tt 'en pp uc dz ic rx ee kk uu pp ee rr ...,/l/disk1/awstebas/data/TTS/speaker-adriana/eps...,adriana,animado
2,#p nn uc ll ii mm 'ii ts ic #c ss uu pp ee rr ...,/l/disk1/awstebas/data/TTS/speaker-adriana/eps...,adriana,rispido
3,#p 'uu mm ac pp ee ss 'oo ac kk ic ss 'aa ij b...,/l/disk1/awstebas/data/TTS/speaker-adriana/eps...,adriana,neutro
4,#p kk ic uc ss 'oh uw bb rd 'ii lh ic nn uc ss...,/l/disk1/awstebas/data/TTS/speaker-adriana/eps...,adriana,animado


In [4]:
# Run this once (IF ALREADY RAN DONT RUN AGAIN)
# Processing all validation files
# Builds `wav_val`, the list of output paths for the resampled validation
# files. Each output name is the last five components of the source path
# joined with '_', placed in the 32 kHz validation folder.
# The actual resampling call is commented out below; uncomment it only for
# the first run, the list itself is always rebuilt.
wav_val = []
for wav in val_df.wav_path:
    wav_out = '../../data/valid_adriana_carlos_rosana_32k/' + '_'.join(wav.split('/')[-5:])
# #     break
    wav_val.append(wav_out)
#     resample_wave(wav, wav_out, 32000)

In [5]:
# !ls ../data_svc/waves-32k

In [6]:
# Every training entry unpacks into seven fields; only the wav path is kept.
# A '.' is prepended to each path (the preview below shows the resulting
# '../data_svc/...' form).
tr_items = load_filepaths('../files/train.txt')
wav_train = [
    '.' + wavpath
    for wavpath, spec, pitch, vec, ppg, spk, style_id in tr_items
]

In [7]:
# len(wav_train), len(wav_val)

In [8]:
wav_train[:3]

['../data_svc/waves-32k/rosana/speaker-rosana_rf_selecionadas_wav22_rf_selecionadas0577.wav',
 '../data_svc/waves-32k/adriana/speaker-adriana_eps_acolhedor_rf_wav22_eps_acolhedor_rf0407.wav',
 '../data_svc/waves-32k/carlos/speaker-carlos_riqueza_fonetica_wav22_riqueza_fonetica0578.wav']

In [9]:
# Sanity check: every resampled validation file must exist on disk.
# The offending path is printed so a missing file can actually be located.
missing_val = [w for w in wav_val if not os.path.isfile(w)]
for w in missing_val:
    print('error: missing validation wav:', w)

In [10]:
# Numeric id assigned to each speaking style
style2id = {
    "acolhedor": 0,
    "animado": 1,
    "neutro": 2,
    "rispido": 3
}

def path2style(wav_path):
    """Infer the style id of a recording from its file path.

    A path containing ``eps_acolhedor``, ``eps_animado`` or ``eps_rispido``
    (checked in that order) maps to that style's id; every other path is
    treated as ``neutro``.
    """
    for style in ("acolhedor", "animado", "rispido"):
        if 'eps_' + style in wav_path:
            return style2id[style]
    return style2id["neutro"]

In [11]:
# Making dataframes

# Style id of every file, derived from its path
styles_tr = list(map(path2style, wav_train))
styles_val = list(map(path2style, wav_val))

# One row per wav file: its path and its numeric style label.
# `val_df` deliberately replaces the CSV-based frame loaded earlier.
train_df = pd.DataFrame({'wav_path': wav_train, 'style': styles_tr})
val_df = pd.DataFrame({'wav_path': wav_val, 'style': styles_val})

In [12]:
train_df['style'].value_counts()

2    14755
0     1243
1     1243
3     1186
Name: style, dtype: int64

In [13]:
# train_df[train_df['style']!=2]

In [14]:
val_df['style'].value_counts()

2    227
0     39
1     38
3     37
Name: style, dtype: int64

# Defining how to read mel from wavpath

In [15]:
# Audio front-end settings are read from the project YAML config.
config_path = '../configs/base_pretrainedstyleencoder.yaml'
rank = 0  # only referenced by the commented-out CUDA device line below
# device = torch.device('cuda:{:d}'.format(rank))
device = torch.device('cpu')  # mel extraction is pinned to the CPU
hp = OmegaConf.load(config_path)

# Mel-spectrogram extractor; every parameter except `center` and `device`
# comes from hp.data. This module-level `stft` (and `device`) is read directly
# by EncoderDataset.collate_fn.
stft = TacotronSTFT(filter_length=hp.data.filter_length,
                    hop_length=hp.data.hop_length,
                    win_length=hp.data.win_length,
                    n_mel_channels=hp.data.mel_channels,
                    sampling_rate=hp.data.sampling_rate,
                    mel_fmin=hp.data.mel_fmin,
                    mel_fmax=hp.data.mel_fmax,
                    center=False,
                    device=device)
def read_wav(filename):
    """Load a wav file as a normalised tensor with a leading batch axis.

    The samples returned by ``load_wav_to_torch`` are divided by
    ``hp.data.max_wav_value`` and given an extra leading dimension.
    The file's sampling rate is read but not checked against the config.
    """
    samples, _sampling_rate = load_wav_to_torch(filename)
    return (samples / hp.data.max_wav_value).unsqueeze(0)

In [16]:
import numpy as np
#with librosa
# wav = read_wav(train_df.wav_path[0]).cuda()
y , sr = librosa.load(train_df.wav_path[0], sr = None)
wav = torch.tensor(y).to(device).unsqueeze(0)
mel = stft.mel_spectrogram(wav)
mel.shape, wav.shape, wav[0][:3]

(torch.Size([1, 100, 464]),
 torch.Size([1, 148756]),
 tensor([0.0000e+00, 6.1035e-05, 1.2207e-04]))

In [17]:
import numpy as np
wav = read_wav(train_df.wav_path[0])
mel = stft.mel_spectrogram(wav)
mel.shape, wav.shape, wav[0][:3]

(torch.Size([1, 100, 464]),
 torch.Size([1, 148756]),
 tensor([0.0000e+00, 6.1035e-05, 1.2207e-04]))

In [18]:
wav_perturbed = fixed_timbre_perturb(wav.squeeze(0).cpu().numpy(), sr = 32000, segment_size= 32000//2, formant_rate=1.4, pitch_steps = 0.01, pitch_floor=75, pitch_ceil=600, fname='null')

In [19]:
wav_perturbed.dtype

dtype('float64')

In [20]:
IPython.display.Audio(wav_perturbed, rate = 32000)

In [21]:
wav.squeeze(0).cpu().numpy()[:3]

array([0.0000000e+00, 6.1035156e-05, 1.2207031e-04], dtype=float32)

In [22]:
torch.tensor(wav.squeeze(0).cpu().numpy()).to(device).unsqueeze(0)[0][:3]

tensor([0.0000e+00, 6.1035e-05, 1.2207e-04])

In [23]:
torch.tensor(wav.squeeze(0).cpu().numpy()).to(device).unsqueeze(0).shape

torch.Size([1, 148756])

In [24]:
# Print every listed path (validation first, then training) whose text after
# the last '.' is not exactly 'wav'.
for path_list in (wav_val, wav_train):
    for w in path_list:
        if w.split('.')[-1] != 'wav':
            print(w)

# Defining utils for training

In [25]:
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Clip the gradient norm and report whether it is infinite.

    Args:
        model: module whose parameters are clipped when ``amp_opt_params`` is
            not given.
        grad_clip: maximum gradient norm passed to ``clip_grad_norm_``.
        ignore_stopnet: when True (and ``amp_opt_params`` is not given), leave
            out parameters whose name contains ``"stopnet"``.
        amp_opt_params: optional explicit collection of parameters to clip;
            takes precedence over ``model`` when truthy.

    Returns:
        ``(grad_norm, skip_flag)`` where ``grad_norm`` is the value returned by
        ``clip_grad_norm_`` and ``skip_flag`` is True when that norm is
        infinite.
    """
    # Choose which parameters to clip
    if amp_opt_params:
        params = amp_opt_params
    elif ignore_stopnet:
        params = [p for n, p in model.named_parameters() if "stopnet" not in n]
    else:
        params = model.parameters()

    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)

    # compatibility with different torch versions: the norm may be a plain
    # float or a tensor
    if isinstance(grad_norm, float):
        norm_is_inf = np.isinf(grad_norm)
    else:
        norm_is_inf = torch.isinf(grad_norm)

    skip_flag = False
    if norm_is_inf:
        print(" | > Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag

In [26]:
import random
class EncoderDataset(Dataset):
    """Dataset of audio files grouped by class for encoder training.

    Indexing returns a lightweight dict ``{"wav_file_path", "class_name"}``;
    the audio is loaded, cropped and converted to a mel spectrogram in
    :meth:`collate_fn`, which relies on the module-level ``stft`` and
    ``device`` objects.
    """

    def __init__(
        self,
        df,
        voice_len=1.6,
        num_classes_in_batch=5,
        num_utter_per_class=10,
        verbose=True,
        augmentation_config=None,
        use_torch_spec=False,
        use_timbre_perturb=False,
    ):
        """
        Args:
            df: table with a ``wav_path`` column and a ``style`` column (the
                class label of each file).
            voice_len (float): length in seconds of the crop taken from each
                file (at the fixed 32 kHz sample rate).
            num_classes_in_batch (int): only echoed in the verbose report.
            num_utter_per_class (int): classes with fewer files than this are
                dropped.
            verbose (bool): print diagnostic information.
            augmentation_config: accepted but not used by this class.
            use_torch_spec (bool): stored on the instance; not read here.
            use_timbre_perturb (bool): apply ``fixed_timbre_perturb`` to each
                crop in :meth:`collate_fn`.
        """
        super().__init__()
        self.items = df
        self.sample_rate = 32000
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_utter_per_class = num_utter_per_class
        self.verbose = verbose
        self.use_torch_spec = use_torch_spec

        # Replaces the raw table in self.items with the filtered item dicts
        self.classes, self.items = self.__parse_items()
        self.classname_to_classid = dict(zip(self.classes, range(len(self.classes))))

        self.use_timbre_perturb = use_timbre_perturb

        if self.verbose:
            report = (
                "\n > DataLoader initialization",
                f" | > Classes per Batch: {num_classes_in_batch}",
                f" | > Number of instances : {len(self.items)}",
                f" | > Sequence length: {self.seq_len}",
                f" | > Num Classes: {len(self.classes)}",
                f" | > Classes: {self.classes}",
            )
            for line in report:
                print(line)

    def load_wav(self, filename):
        """Load ``filename`` with librosa at ``self.sample_rate``; return the samples."""
        samples, _ = librosa.load(filename, sr=self.sample_rate)
        return samples

    def __parse_items(self):
        """Filter the input table into the sorted class list and the item dicts.

        Keeps only classes that have at least ``num_utter_per_class`` files,
        and within those only files strictly longer than ``seq_len`` samples
        (each candidate file is loaded to measure its length).
        """
        paths = self.items["wav_path"].values
        styles = self.items['style'].values

        # Group file paths by class
        class_to_utters = {}
        for path_, class_name in zip(paths, styles):
            class_to_utters.setdefault(class_name, []).append(path_)

        # keep only classes with at least self.num_utter_per_class samples
        class_to_utters = {
            name: utters
            for name, utters in class_to_utters.items()
            if len(utters) >= self.num_utter_per_class
        }
        classes = sorted(class_to_utters.keys())

        new_items = []
        for path_, class_name in zip(paths, styles):
            # ignore filtered classes
            if class_name not in classes:
                continue
            # ignore audios that are not longer than one crop
            if self.load_wav(path_).shape[0] <= self.seq_len:
                continue
            new_items.append({"wav_file_path": path_, "class_name": class_name})

        return classes, new_items

    def __len__(self):
        return len(self.items)

    def get_num_classes(self):
        """Number of classes kept after filtering."""
        return len(self.classes)

    def get_class_list(self):
        """The list of class names."""
        return self.classes

    def set_classes(self, classes):
        """Replace the class list and rebuild the name -> id mapping from it."""
        self.classes = classes
        self.classname_to_classid = dict(zip(self.classes, range(len(self.classes))))

    def get_map_classid_to_classname(self):
        """Inverse of ``classname_to_classid``."""
        return {c_id: c_n for c_n, c_id in self.classname_to_classid.items()}

    def __getitem__(self, idx):
        return self.items[idx]

    def collate_fn(self, batch):
        """Turn a list of item dicts into ``(feats, labels)`` tensors.

        For every item a random crop of ``seq_len`` samples is taken,
        optionally timbre-perturbed, and converted to a mel spectrogram with
        the module-level ``stft``. ``feats`` is the stack of those
        spectrograms, ``labels`` a LongTensor of class ids.
        """
        labels, feats = [], []
        for item in batch:
            utter_path, class_name = item["wav_file_path"], item["class_name"]

            # get classid
            class_id = self.classname_to_classid[class_name]

            # load wav file and take a random fixed-length crop
            wav = self.load_wav(utter_path)
            offset = random.randint(0, wav.shape[0] - self.seq_len)
            wav = wav[offset : offset + self.seq_len]

            if self.use_timbre_perturb:
                wav = fixed_timbre_perturb(
                    wav,
                    sr = self.sample_rate,
                    segment_size= self.sample_rate//2,
                    formant_rate=1.4,
                    pitch_steps = 0.01,
                    pitch_floor=75,
                    pitch_ceil=600,
                    fname='null',
                ).astype(np.float32)

            mel = stft.mel_spectrogram(torch.tensor(wav).to(device).unsqueeze(0)).squeeze(0)
            feats.append(mel)
            labels.append(class_id)

        return torch.stack(feats), torch.LongTensor(labels)

In [27]:
from torch.utils.data import DataLoader

# Training data gets timbre perturbation; validation data is left untouched.
train_dataset = EncoderDataset(train_df, use_timbre_perturb = True) #use_tp = True)
val_dataset = EncoderDataset(val_df)

num_classes_in_batch = 4
num_utter_per_class = 10
is_val = True
classes = val_dataset.get_class_list()

# The sampler yields indices into the dataset it was built from, so the
# training sampler must be built over the training items (it previously used
# val_dataset.items, which mis-indexed train_dataset).
train_sampler = PerfectBatchSampler(
    train_dataset.items,
    classes,
    batch_size=num_classes_in_batch * num_utter_per_class,  # total batch size
    num_classes_in_batch=num_classes_in_batch,
    num_gpus=1,
    shuffle=True,
    drop_last=True,
)

eval_sampler = PerfectBatchSampler(
    val_dataset.items,
    classes,
    batch_size=num_classes_in_batch * num_utter_per_class,  # total batch size
    num_classes_in_batch=num_classes_in_batch,
    num_gpus=1,
    shuffle=not is_val,
    drop_last=True,
)

# train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=train_dataset.collate_fn, num_workers=4)


 > DataLoader initialization
 | > Classes per Batch: 5
 | > Number of instances : 17443
 | > Sequence length: 51200
 | > Num Classes: 4
 | > Classes: [0, 1, 2, 3]

 > DataLoader initialization
 | > Classes per Batch: 5
 | > Number of instances : 340
 | > Sequence length: 51200
 | > Num Classes: 4
 | > Classes: [0, 1, 2, 3]


In [28]:
classes

[0, 1, 2, 3]

In [29]:


# Batches are formed by the samplers; each dataset's own collate_fn loads the
# audio and computes the mel spectrograms.
train_dataloader = DataLoader(train_dataset, collate_fn=train_dataset.collate_fn, num_workers=4,  batch_sampler=train_sampler)

val_dataloader = DataLoader(val_dataset, collate_fn=val_dataset.collate_fn, num_workers=4,  batch_sampler=eval_sampler)

# Smoke test: pull one validation batch and show the (features, labels) shapes.
batch = next(iter(val_dataloader))
batch[0].shape, batch[1].shape

(torch.Size([40, 100, 160]), torch.Size([40]))

In [30]:
inputs, labels = batch
# Group the samples of each class together: the sampler interleaves classes
# ([0,1,2,3,0,1,2,3,...]) while the loss expects them contiguous
# ([0,0,...,1,1,...]). View as (utterance, class), swap the two axes, flatten.
labels = (
    labels.view(num_utter_per_class, num_classes_in_batch)
    .transpose(0, 1)
    .reshape(labels.shape)
)
inputs = (
    inputs.view(num_utter_per_class, num_classes_in_batch, -1)
    .transpose(0, 1)
    .reshape(inputs.shape)
)

labels.shape, inputs.shape

(torch.Size([40]), torch.Size([40, 100, 160]))

In [31]:
batch[1]

tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,
        0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])

In [32]:
labels

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

In [33]:
class ReferenceEncoderRepresentation(nn.Module):
    """NN module creating a fixed size embedding from a mel spectrogram.

    inputs: mel spectrograms [batch_size, num_mel, num_spec_frames]
        (the last two axes are swapped before the reference encoder is called)
    outputs: [batch_size, embedding_dim]

    Args:
        num_mel: number of mel channels.
        embedding_dim: size of the produced embedding.
        use_nonlinear_proj: forwarded to ``ReferenceEncoder``. It used to be
            ignored (``False`` was always passed); the default is unchanged.
    """

    def __init__(self, num_mel, embedding_dim, use_nonlinear_proj = False):

        super().__init__()
        self.num_mel = num_mel
        self.re = ReferenceEncoder(num_mel, embedding_dim, use_nonlinear_proj = use_nonlinear_proj)

    def forward(self, inputs):
        # [batch, num_mel, frames] -> [batch, frames, num_mel]
        inputs = inputs.permute(0,2,1)

        out = self.re(inputs)

        # NOTE(review): squeeze(0) only has an effect when the first axis of
        # `out` has size 1. Confirm against ReferenceEncoder's output shape
        # that this cannot drop the batch axis for a batch of one.
        return out.squeeze(0)

In [34]:
m = ReferenceEncoderRepresentation(100, 128)

In [35]:
out = m(inputs)

out.shape

torch.Size([40, 128])

In [36]:
out.view(num_classes_in_batch, out.shape[0] // num_classes_in_batch, -1).shape

torch.Size([4, 10, 128])

In [37]:
criterion = AngleProtoLoss()
criterion(
                out.view(num_classes_in_batch, out.shape[0] // num_classes_in_batch, -1), labels
            )

 > Initialized Angular Prototypical loss


tensor(1.2817, grad_fn=<NllLossBackward0>)

# Defining training function

In [38]:
## Also defining global constants
# Batch layout used by evaluation() when regrouping a batch by class
eval_num_utter_per_class = 10
eval_num_classes_in_batch = 4

# Batch layout used by train() and by the samplers
num_utter_per_class = 10
num_classes_in_batch = 4

num_loader_workers = 4  # only used as the smoothing window of the averaged loader time
grad_clip = 3  # maximum gradient norm passed to check_update

print_step = 1000  # log every N global steps
save_step = 1000  # write an iteration checkpoint every N global steps
epochs = 1000
lr_decay = 0  # falsy: train() never calls scheduler.step()

save_path = './tpangre'  # directory that receives the checkpoints

run_eval = True  # evaluate on the validation loader after every epoch

def evaluation(model, criterion, data_loader, global_step):
    """Return the mean loss of ``model`` over ``data_loader``.

    Gradients are disabled for the forward passes. The function does not
    switch the model to eval mode itself. It reads the module-level
    ``eval_num_utter_per_class``, ``eval_num_classes_in_batch`` and
    ``use_cuda``. ``global_step`` is accepted but not used.
    """
    n_utter = eval_num_utter_per_class
    n_classes = eval_num_classes_in_batch

    running_loss = 0
    for inputs, labels in data_loader:
        with torch.no_grad():
            # Group the samples of each class together: the sampler produces
            # [3,2,1,3,2,1], the loss needs [3,3,2,2,1,1].
            labels = labels.view(n_utter, n_classes).transpose(0, 1).reshape(labels.shape)
            inputs = inputs.view(n_utter, n_classes, -1).transpose(0, 1).reshape(inputs.shape)

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass; embeddings are reshaped to (class, utterance, dim)
            outputs = model(inputs)
            loss = criterion(
                outputs.view(n_classes, outputs.shape[0] // n_classes, -1), labels
            )

            running_loss += loss.item()

    return running_loss / len(data_loader)

import time
def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader, global_step):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    NOTE: this definition shadows the ``train`` imported from ``utils`` at the
    top of the notebook.

    Reads the module-level settings ``epochs``, ``num_utter_per_class``,
    ``num_classes_in_batch``, ``num_loader_workers``, ``grad_clip``,
    ``print_step``, ``save_step``, ``lr_decay``, ``save_path``, ``run_eval``
    and ``use_cuda``.

    Args:
        model: network producing one embedding per input sample.
        optimizer: optimizer updating ``model``.
        scheduler: learning-rate scheduler, or None. Only stepped when
            ``lr_decay`` is truthy.
        criterion: loss taking embeddings shaped (class, utterance, dim) and
            the labels.
        data_loader: training batches of ``(inputs, labels)``.
        eval_data_loader: validation batches, passed to ``evaluation``.
        global_step: step counter to continue from.

    Returns:
        ``(best_loss, global_step)``: the lowest validation loss seen
        (``inf`` if evaluation never ran) and the updated step counter.
    """
    # torch.save does not create missing directories; without this the first
    # checkpoint write fails after a full save interval of training.
    os.makedirs(save_path, exist_ok=True)

    model.train()
    best_loss = float("inf")
    avg_loader_time = 0
    end_time = time.time()
    for epoch in range(epochs):
        tot_loss = 0
        epoch_time = 0
        for _, data in enumerate(data_loader):
            start_time = time.time()

            # setup input data
            inputs, labels = data
            # Group the samples of each class together: the sampler produces
            # [3,2,1,3,2,1], the loss needs [3,3,2,2,1,1].
            labels = torch.transpose(labels.view(num_utter_per_class, num_classes_in_batch), 0, 1).reshape(
                labels.shape
            )
            inputs = torch.transpose(inputs.view(num_utter_per_class, num_classes_in_batch, -1), 0, 1).reshape(
                inputs.shape
            )

            loader_time = time.time() - end_time
            global_step += 1

            optimizer.zero_grad()

            # dispatch data to GPU
            if use_cuda:
                inputs = inputs.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            # forward pass model
            outputs = model(inputs)

            # loss computation
            loss = criterion(
                outputs.view(num_classes_in_batch, outputs.shape[0] // num_classes_in_batch, -1), labels
            )
            loss.backward()
            grad_norm, skip_flag = check_update(model, grad_clip)
            # An infinite gradient norm turns the clipped gradients into NaN;
            # stepping with them would corrupt the weights, so skip the update.
            if not skip_flag:
                optimizer.step()
                # The scheduler is stepped after the optimizer (PyTorch's
                # documented order) and only when one was actually provided.
                if lr_decay and scheduler is not None:
                    scheduler.step()

            step_time = time.time() - start_time
            epoch_time += step_time

            # accumulate the total epoch loss
            tot_loss += loss.item()

            # Averaged loader time (running average weighted by worker count)
            avg_loader_time = (
                1 / num_loader_workers * loader_time + (num_loader_workers - 1) / num_loader_workers * avg_loader_time
                if avg_loader_time != 0
                else loader_time
            )
            current_lr = optimizer.param_groups[0]["lr"]

            if global_step % print_step == 0:
                print(
                    "   | > Step:{}  Loss:{:.5f}  GradNorm:{:.5f}  "
                    "StepTime:{:.2f}  LoaderTime:{:.2f}  AvGLoaderTime:{:.2f}  LR:{:.6f}".format(
                        global_step, loss.item(), grad_norm, step_time, loader_time, avg_loader_time, current_lr
                    ),
                    flush=True,
                )

            if global_step % save_step == 0:
                print("Saving iteration model")
                torch.save(model.state_dict(), f'{save_path}/checkpoint_{global_step}.pth')

            end_time = time.time()

        print("")
        print(
            ">>> Epoch:{}  AvgLoss: {:.5f} GradNorm:{:.5f}  "
            "EpochTime:{:.2f} AvGLoaderTime:{:.2f} ".format(
                epoch, tot_loss / len(data_loader), grad_norm, epoch_time, avg_loader_time
            ),
            flush=True,
        )
        # evaluation
        if run_eval:
            model.eval()
            eval_loss = evaluation(model, criterion, eval_data_loader, global_step)
            print("\n\n")
            print("--> EVAL PERFORMANCE")
            print(
                "   | > Epoch:{}  AvgLoss: {:.5f} ".format(epoch, eval_loss),
                flush=True,
            )

            # keep a copy of the weights whenever the validation loss improves
            if(eval_loss < best_loss):
                print("Saving best model")
                torch.save(model.state_dict(), f'{save_path}/best_model_{global_step}.pth')
                best_loss = eval_loss
            model.train()

    return best_loss, global_step

In [39]:
model = ReferenceEncoderRepresentation(100, 128)
optimizer = optim.RAdam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08)
criterion = AngleProtoLoss()
scheduler = None
global_step = 0
# Use the GPU only when one is present; a hard-coded True makes the .cuda()
# calls below (and those in train/evaluation) fail on CPU-only hosts.
use_cuda = torch.cuda.is_available()

if use_cuda:
    model = model.cuda()
    criterion.cuda()

from torch.utils.data import DataLoader

# Datasets are rebuilt here so this cell can be run on its own.
train_dataset = EncoderDataset(train_df, use_timbre_perturb = True) #use_tp = True)
val_dataset = EncoderDataset(val_df)

is_val = True
classes = val_dataset.get_class_list()

# Each sampler is built over the items of the dataset it will index.
train_sampler = PerfectBatchSampler(
    train_dataset.items,
    classes,
    batch_size=num_classes_in_batch * num_utter_per_class,  # total batch size
    num_classes_in_batch=num_classes_in_batch,
    num_gpus=1,
    shuffle=True,
    drop_last=True,
)

eval_sampler = PerfectBatchSampler(
    val_dataset.items,
    classes,
    batch_size=num_classes_in_batch * num_utter_per_class,  # total batch size
    num_classes_in_batch=num_classes_in_batch,
    num_gpus=1,
    shuffle=not is_val,
    drop_last=True,
)

train_dataloader = DataLoader(train_dataset, collate_fn=train_dataset.collate_fn, num_workers=4,  batch_sampler=train_sampler)

val_dataloader = DataLoader(val_dataset, collate_fn=val_dataset.collate_fn, num_workers=4,  batch_sampler=eval_sampler)

# Smoke test: pull one validation batch and show its shapes.
batch = next(iter(val_dataloader))
batch[0].shape, batch[1].shape

# train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True, collate_fn=train_dataset.collate_fn, num_workers=4)


 > Initialized Angular Prototypical loss

 > DataLoader initialization
 | > Classes per Batch: 5
 | > Number of instances : 17443
 | > Sequence length: 51200
 | > Num Classes: 4
 | > Classes: [0, 1, 2, 3]

 > DataLoader initialization
 | > Classes per Batch: 5
 | > Number of instances : 340
 | > Sequence length: 51200
 | > Num Classes: 4
 | > Classes: [0, 1, 2, 3]


(torch.Size([40, 100, 160]), torch.Size([40]))

In [40]:
%%timeit
batch = next(iter(val_dataloader))
batch[0].shape, batch[1].shape

363 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Enable cuDNN and its autotuner.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Seeds torch's RNG only. The model was already constructed in an earlier
# cell, so this seed does not affect its initial weights.
torch.manual_seed(54321)

# Launch training; the best validation loss is discarded, the step counter is
# kept so a later call can continue from it.
_, global_step = train(model, optimizer, scheduler, criterion, train_dataloader, val_dataloader, global_step)

Clamping max value higher than 1 to 1
