In [2]:
import os
import whisper
from whisper.tokenizer import Tokenizer

import torch
from torch import nn
import speechbrain as sb
import torch.nn.functional as F
from pytorch_metric_learning import losses
from torch.utils.data import Dataset, DataLoader

from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split

import random
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt

import librosa
import torchaudio.transforms as transforms

In [3]:
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

In [4]:
EXP_NAME = 'joint_ce-0.5accent'
# Dict for accents to index.
accents_to_index = {
    'ABA' : 0, # Arabic
    'SKA' : 0,
    'YBAA' : 0,
#     'ZHAA' : 0,
    'BWC' : 1, # Mandarin
    'LXC' : 1,
    'NCC' : 1,
#     'TXHC' : 1,
    'ASI' : 2, # Hindi
    'RRBI' : 2,
    'SVBI' : 2,
#     'TNI' : 2,
    'HJK' : 3, # Korean
    'HKK' : 3,
    'YDCK' : 3,
#     'YKWK' : 3,
    'EBVS' : 4, # Spanish
    'ERMS' : 4,
    'MBMPS' : 4,
#     'NJS' : 4,
    'HQTV' : 5, # Vitenamese
    'PNV' : 5,
    'THV' : 5,
#     'TLV' : 5
    }

held_out_set = {
    'ZHAA' : 0,
    'TXHC' : 1,
    'TNI' : 2,
    'YKWK' : 3,
    'NJS' : 4,
    'TLV' : 5
}
# Dict for index to accents.
index_to_accents = {
    0 : 'Arabic',
    1 : 'Mandarin',
    2 : 'Hindi',
    3 : 'Korean',
    4 : 'Spanish',
    5 : 'Vitenamese'
    }

PWD = os.getcwd()
WHISPER_HIDDEN_LAYER = 384 # Dim of Whisper Last Hidden layer of Encoder
TSNE_SAMPLES = 2400 # Number of Data points required to plot T-SNE 
batch_size = 8 # Batch Size
SAMPLING_RATE = 16000 # Required SR for Whisper Feature Extractor
TEMPERATURE = 0.1 # Supervised Constrative Loss Temparture
ALPHA = 1 # Weight of SCL 
MODEL = "tiny.en" # Pre-Trained Whisper Model
NUM_EPOCH = 5
MULTILINGUAL = None # For large-v2, please change it to ".en"
PATIENCE = 20 # early stopping patience; how long to wait after last time validation loss improved.

In [5]:
path = f'{PWD}/dataset/l2_arctic/ZHAA/transcript/'
classes = os.listdir(path)
classes = [x.split('.')[0] for x in classes]
classes_to_index = {}
index_to_classes = {}
for i, j in enumerate(classes):
    classes_to_index[j] = i
    index_to_classes[i] = j


In [6]:
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

In [7]:
def scale_to_01_range(x):
    value_range = (np.max(x) - np.min(x))
    starts_from_zero = x - np.min(x)
 
    return starts_from_zero / value_range

In [8]:
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.1):
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, feature_vectors, labels):
        # Normalize feature vectors
        feature_vectors_normalized = F.normalize(feature_vectors.squeeze(), p=2, dim=1)
        # Compute logits
        logits = torch.div(
            torch.matmul(
                feature_vectors_normalized, torch.transpose(feature_vectors_normalized, 0, 1)
            ),
            self.temperature, 
        )
        return losses.NTXentLoss(temperature=self.temperature)(logits, torch.squeeze(labels))

In [9]:
woptions = whisper.DecodingOptions()
wtokenizer = whisper.tokenizer.get_tokenizer(
    MULTILINGUAL,
    task=woptions.task
    )

In [10]:
class L2ArcticDataset(Dataset):
    def __init__(self, split = 'train'):
        self.paths = []
        self.tokenizer = wtokenizer

        # Get paths of the wavs for all accents.
        for accent in accents_to_index:
            path = f'{PWD}/dataset/l2_arctic/{accent}/wav/'
            dir_list = os.listdir(path)
            dir_list = [[path + dir, accents_to_index[accent], classes_to_index[dir.split('.')[0]]] for dir in dir_list]
            self.paths.extend(dir_list)
        
        # 80% of the paths 
        if split == 'train':
            self.paths, _ = train_test_split(
              self.paths,
              shuffle=True,
              random_state = 42,
              test_size = 0.2
            )
        # 20% of the paths
        elif split == 'val':
            _, self.paths = train_test_split(
              self.paths,
              shuffle=True,
              random_state = 42,
              test_size = 0.2
            )
        elif split == 'held':
            # Get paths of the wavs for all accents.
            self.paths = []
            for accent in held_out_set:
                path = f'{PWD}/dataset/l2_arctic/{accent}/wav/'
                dir_list = os.listdir(path)
                dir_list = [[path + dir, held_out_set[accent], classes_to_index[dir.split('.')[0]]] for dir in dir_list]
                self.paths.extend(dir_list)

    
    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # load audio and pad/trim it to fit 30 seconds
        audio = whisper.load_audio(self.paths[idx][0])
        audio = whisper.pad_or_trim(audio)

        # make log-Mel spectrogram and move to the same device as the model
        mel = whisper.log_mel_spectrogram(audio)

        accent = self.paths[idx][1]
        ngram = self.paths[idx][2]
        
        transcript = self.paths[idx][0].replace("/wav/", "/transcript/")
        transcript = transcript.replace(".wav", ".txt")
        transcript = open(transcript, 'r').read()
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(transcript)
        labels = text[1:] + [self.tokenizer.eot]

        # Return extracted input feature and respective accent as label.
        return {
            "input_ids": mel,
            "accent": accent,
            "ngram" : ngram,
            "labels": labels,
            "dec_input_ids": text
        }

In [11]:
# Initialize train, val dataset
train_dataset = L2ArcticDataset('train')
val_dataset = L2ArcticDataset('val')
held_out_dataset = L2ArcticDataset('held')

In [12]:
class WhisperDataCollatorWhithPadding:
    def __call__(sefl, features):
        mel_augment = transforms.FrequencyMasking(freq_mask_param=80)
        input_ids, labels, dec_input_ids, accents, ngrams = [], [], [], [], []
        for f in features:
            input_ids.append(f["input_ids"])
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])
            accents.append(f["accent"])
            ngrams.append(f["ngram"])

        for f in features:
            input_ids.append(mel_augment(f["input_ids"]))
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])
            accents.append(f["accent"])
            ngrams.append(f["ngram"])

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])

        
        label_lengths = [len(lab) for lab in labels]
        dec_input_ids_length = [len(e) for e in dec_input_ids]
        max_label_len = max(label_lengths + dec_input_ids_length)

        labels = [
            np.pad(
                lab,
                (0, max_label_len - lab_len),
                'constant',
                constant_values=-100
            )
            for lab, lab_len in zip(labels, label_lengths)
        ]
        
        dec_input_ids = [
            np.pad(
                e,
                (0, max_label_len - e_len),
                'constant',
                constant_values=wtokenizer.eot
            ) 
            for e, e_len in zip(dec_input_ids, dec_input_ids_length)
        ]

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids
        batch["accents"] = torch.tensor(accents, dtype=torch.int64)
        batch["ngrams"] = torch.tensor(ngrams, dtype=torch.int64)

        return batch

In [13]:
if device == "cuda:1":
  num_workers = 4 # Dataloader stuck after 1 epoch, so have to make it zero.
  pin_memory = True
else:
  num_workers = 0
  pin_memory = False

# Dataloader for training step.
train_loader = DataLoader(
    train_dataset, 
    batch_size=batch_size,
    collate_fn=WhisperDataCollatorWhithPadding(),
    shuffle=True, 
    num_workers=num_workers,
    pin_memory=pin_memory
    )

# Dataloader for training step.
val_loader = DataLoader(
    val_dataset, 
    batch_size=batch_size,
    collate_fn=WhisperDataCollatorWhithPadding(),
    shuffle=True, 
    num_workers=num_workers,
    pin_memory=pin_memory
    )

# Dataloader for Held Out Dataset.
held_out_loader = DataLoader(
    held_out_dataset, 
    batch_size=batch_size,
    collate_fn=WhisperDataCollatorWhithPadding(),
    shuffle=False, 
    num_workers=num_workers,
    pin_memory=pin_memory
    )

In [14]:
# Set seed
set_seed()

# Load whsiper model
model = whisper.load_model(MODEL)

# See extra device is available, if so make data parallel.
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = nn.DataParallel(model)

# Model parameters to device 
model.to(device)

# Compute gradiant for parameters of both encoder and FFN. 
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2e-6
    )

# Define Loss Function
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
scloss_fn = losses.SupConLoss(temperature=TEMPERATURE)

Random seed set as 42


Let's use 4 GPUs!


In [15]:
# Empty cache for proper utilization of GPU.
torch.cuda.empty_cache()

num_epochs = NUM_EPOCH
training_epoch_loss = []
validation_epoch_loss = []
held_out_epoch_loss = []

for epoch in range(num_epochs):
    with tqdm(
        train_loader,
        total=round(len(train_dataset) / batch_size + 0.5),
        unit = "batch") as tepoch:
    
        # Training Step
        model.train()
        torch.set_grad_enabled(True)
        total_loss = []
        
        for batch in tepoch:

            # Input_ids --> wav, Accents --> Resp. accent
            input_ids = batch["input_ids"].to(device)
            accents = batch["accents"].to(device)
            ngrams = batch["ngrams"].to(device)
            labels = batch["labels"].long().to(device)
            dec_input_ids = batch["dec_input_ids"].long().to(device)

            
            # Description of progress bar
            tepoch.set_description(f"Epoch {epoch + 1}") 
            
            # Extract Last Layer of Encoder
            embedding = model.module.encoder(input_ids)
            
            # Aggregate using mean. 
            embedding_mean = torch.mean(embedding, dim=1)

            # Pass Embedding decoder to get logits.
            out = model.module.decoder(dec_input_ids, embedding)

            # Loss = CE + alpha * SCL
            loss = loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) - 0.5*(scloss_fn(embedding_mean, accents))

            # update loss in progress bar 
            tepoch.set_postfix(loss=loss.item())

            # Back Prop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss.append(loss.item())
        avg_loss = np.array(total_loss).mean()
        training_epoch_loss.append(avg_loss)
        print(f"Epoch {epoch+1}, train loss={avg_loss:.4f}")

    with tqdm(
        val_loader,
        total=round(len(val_dataset) / batch_size + 0.5),
        unit = "batch") as vepoch:
        # Validation Step
        model.eval()
        torch.set_grad_enabled(False)

        total_loss = []
        for batch in vepoch:
            
            # Input_ids --> .wav, Labels --> Ground Truth, Dec_input_ids --> Decoder's input. 
            input_ids = batch["input_ids"].to(device)
            accents = batch["accents"].long().to(device)
            ngrams = batch["ngrams"].to(device)
            labels = batch["labels"].long().to(device)
            dec_input_ids = batch["dec_input_ids"].long().to(device)

            
            # Description of progress bar
            tepoch.set_description(f"Epoch {epoch + 1}") 
            
            # Extract Last Layer of Encoder
            embedding = model.module.encoder(input_ids)
            
            # Aggregate using mean. 
            embedding_mean = torch.mean(embedding, dim=1)

            # Pass Embedding decoder to get logits.
            out = model.module.decoder(dec_input_ids, embedding)

            # Loss = CE + alpha * SCL
            loss = loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) - 0.5*(scloss_fn(embedding_mean, accents))

            # update loss in progress bar 
            vepoch.set_postfix(loss=loss.item())
                        
            total_loss.append(loss.item())
        avg_loss = np.array(total_loss).mean()
        validation_epoch_loss.append(avg_loss)
        print(f"Epoch {epoch+1}, validation loss={avg_loss:.4f}")
        
    with tqdm(
        held_out_loader,
        total=round(len(held_out_dataset) / batch_size + 0.5),
        unit = "batch") as hepoch:
        # Validation Step
        model.eval()
        torch.set_grad_enabled(False)

        total_loss = []
        for batch in hepoch:
            
            # Input_ids --> .wav, Labels --> Ground Truth, Dec_input_ids --> Decoder's input. 
            input_ids = batch["input_ids"].to(device)
            accents = batch["accents"].long().to(device)
            ngrams = batch["ngrams"].to(device)
            labels = batch["labels"].long().to(device)
            dec_input_ids = batch["dec_input_ids"].long().to(device)

            
            # Description of progress bar
            tepoch.set_description(f"Epoch {epoch + 1}") 
            
            # Extract Last Layer of Encoder
            embedding = model.module.encoder(input_ids)
            
            # Aggregate using mean. 
            embedding_mean = torch.mean(embedding, dim=1)

            # Pass Embedding decoder to get logits.
            out = model.module.decoder(dec_input_ids, embedding)

            # Loss = CE + alpha * SCL
            loss = loss_fn(out.view(-1, out.size(-1)), labels.view(-1)) - 0.5*(scloss_fn(embedding_mean, accents))

            # update loss in progress bar 
            hepoch.set_postfix(loss=loss.item())
                        
            total_loss.append(loss.item())
        avg_loss = np.array(total_loss).mean()
        held_out_epoch_loss.append(avg_loss)
        print(f"Epoch {epoch+1}, held out loss={avg_loss:.4f}")

plt.plot(training_epoch_loss, label='train_loss')
plt.plot(validation_epoch_loss, label='val_loss')
plt.plot(held_out_epoch_loss, label='held_out_loss')
plt.legend()
plt.savefig(f"{PWD}/loss/" + f"loss_joint_{EXP_NAME}.png")
plt.close()

  0%|          | 0/2008 [00:00<?, ?batch/s]

Epoch 1: 100%|██████████| 2008/2008 [32:29<00:00,  1.03batch/s, loss=-7.68]


Epoch 1, train loss=-6.6345


100%|██████████| 502/502 [08:14<00:00,  1.01batch/s, loss=-6.09]


Epoch 1, validation loss=-7.3586


100%|██████████| 849/849 [13:46<00:00,  1.03batch/s, loss=1.33] 


Epoch 1, held out loss=0.8086


Epoch 2: 100%|██████████| 2008/2008 [32:26<00:00,  1.03batch/s, loss=-6.95]


Epoch 2, train loss=-7.4499


100%|██████████| 502/502 [08:11<00:00,  1.02batch/s, loss=-7.85]


Epoch 2, validation loss=-7.6147


100%|██████████| 849/849 [13:44<00:00,  1.03batch/s, loss=1.11] 


Epoch 2, held out loss=0.6203


Epoch 3: 100%|██████████| 2008/2008 [32:21<00:00,  1.03batch/s, loss=-9.04]


Epoch 3, train loss=-7.6441


100%|██████████| 502/502 [08:10<00:00,  1.02batch/s, loss=-7.89]


Epoch 3, validation loss=-7.7480


100%|██████████| 849/849 [13:38<00:00,  1.04batch/s, loss=0.434]


Epoch 3, held out loss=0.4893


Epoch 4: 100%|██████████| 2008/2008 [32:29<00:00,  1.03batch/s, loss=-7.49]


Epoch 4, train loss=-7.7882


100%|██████████| 502/502 [08:10<00:00,  1.02batch/s, loss=-8.39]


Epoch 4, validation loss=-7.8368


100%|██████████| 849/849 [13:43<00:00,  1.03batch/s, loss=0.501] 


Epoch 4, held out loss=0.4081


Epoch 5: 100%|██████████| 2008/2008 [32:22<00:00,  1.03batch/s, loss=-8.36]


Epoch 5, train loss=-7.8821


100%|██████████| 502/502 [08:10<00:00,  1.02batch/s, loss=-8.04]


Epoch 5, validation loss=-7.8643


100%|██████████| 849/849 [13:42<00:00,  1.03batch/s, loss=0.328] 

Epoch 5, held out loss=0.3477





In [16]:
def plot_tsne(data = 'held_out'):
    if data == "held_out":
        loader = held_out_loader
        dataset = held_out_dataset
    elif data == "val":
        loader = val_loader
        dataset =  val_dataset

    # T-SNE Plot 

    # Froze batch norm and dropout layer at time of evaluation. 
    model.eval()
    with torch.no_grad():
        embeds, targets = [], []
        with tqdm(
            enumerate(loader),
            total=round(len(dataset) / batch_size + 0.5),
            unit = "batch") as tepoch:
            for i, batch in tepoch :  
                input_ids = batch["input_ids"].to(device)
                accents = batch["accents"].to(device)
                
                embeddings = model.module.embed_audio(
                    input_ids) 

                embeddings = torch.mean(embeddings, dim=1)

                embeds.extend(embeddings.cpu().detach().numpy())

                targets.extend(accents.cpu().detach().numpy())

    tsne = TSNE(n_components=2).fit_transform(np.array(embeds))
    tx = tsne[:, 0]
    ty = tsne[:, 1]

    tx = scale_to_01_range(tx)
    ty = scale_to_01_range(ty)

    df = pd.DataFrame()
    df["y"] = [index_to_accents[i] for i in targets]
    df["comp-1"] = tx
    df["comp-2"] = ty

    sns.scatterplot(x="comp-1", y="comp-2", hue=df.y.tolist(),
                    palette=sns.color_palette("hls", len(index_to_accents)),
                    data=df).set(title="T-SNE projection")


    plt.savefig(f"{PWD}/t-sne/" + f"tnse_{data}_{EXP_NAME}.png")
    plt.close()

In [17]:
plot_tsne('held_out')
plot_tsne('val')

  0%|          | 0/849 [00:00<?, ?batch/s]

100%|██████████| 849/849 [13:08<00:00,  1.08batch/s]
100%|██████████| 502/502 [05:57<00:00,  1.41batch/s]


In [None]:
torch.save(model, f"{PWD}/models/{EXP_NAME}")