<a href="https://colab.research.google.com/github/jcgeo9/Conformer/blob/main/Conformer_with_Libri_Speech_DS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DOWNLOAD LIBRISPEECH DATASET TO GOOGLE DRIVE

In [None]:
#@title  { form-width: "30%" }
# Move into the Drive folder that will hold the dataset.
# NOTE(review): the path is relative, so this presumably expects the working
# directory to be /content with Google Drive already mounted - confirm.
%cd drive/MyDrive/Datasets

# Download the train-clean-100 archive into the current directory.
!wget https://openslr.elda.org/resources/12/train-clean-100.tar.gz

# Unpack the archive into the same Drive folder (absolute paths, so this line
# does not depend on the %cd above).
!tar -xzvf "/content/drive/MyDrive/Datasets/train-clean-100.tar.gz" -C "/content/drive/MyDrive/Datasets/"     #[run this cell to extract tar.gz files]

# CREATE MANIFEST FILE FOR DATASET

In [None]:
import os
import pandas as pd

def get_df_from_dataset_directory(rootdir,saveloc):
  """Build a manifest CSV that pairs every .flac file with its transcript.

  Walks ``rootdir`` recursively, collects every ``.flac`` file and parses every
  ``.txt`` transcript file found on the way. Each transcript line is expected
  to look like ``<utterance-id> <sentence>``. Audio and text are joined on the
  utterance id (the audio file name without its extension), so the result does
  not depend on the order in which files or transcript lines are encountered.

  The manifest is written to ``<saveloc>/manifest_file.csv`` with the columns
  ``audio_path``, ``sentence_location`` (the utterance id) and ``translation``
  (the lower-cased sentence).

  Args:
    rootdir: directory to scan.
    saveloc: existing directory in which ``manifest_file.csv`` is created.

  Raises:
    ValueError: if an audio file has no transcript line with a matching id.
  """
  audio_paths=[]
  transcript_by_id={}

  for subdir, dirs, files in os.walk(rootdir):
    # Sorting in place makes os.walk visit sub-directories in a stable order,
    # so the manifest rows are reproducible between runs.
    dirs.sort()
    for file in sorted(files):
      filepath=os.path.join(subdir, file)
      if file.endswith('.flac'):
        audio_paths.append(filepath)
      elif file.endswith('.txt'):
        # `with` closes the handle even if a line fails to parse.
        with open(filepath, "r", encoding="utf-8") as transcript_file:
          for line in transcript_file:
            parts=line.strip().split(' ',1)
            if len(parts)<2:
              # Blank line or an id without any text: nothing to record.
              continue
            transcript_by_id[parts[0]]=parts[1].strip().lower()

  utterance_ids=[]
  sentences=[]
  missing=[]
  for path in audio_paths:
    utterance_id=os.path.splitext(os.path.basename(path))[0]
    if utterance_id in transcript_by_id:
      utterance_ids.append(utterance_id)
      sentences.append(transcript_by_id[utterance_id])
    else:
      missing.append(path)

  if missing:
    raise ValueError(
        'No transcript found for %d audio file(s), e.g. %s' % (len(missing), missing[0]))

  df = pd.DataFrame({'audio_path':audio_paths,'sentence_location':utterance_ids,'translation':sentences})

  df.to_csv(os.path.join(saveloc,'manifest_file.csv'))



# Scan the extracted train-clean-100 folder and write manifest_file.csv into the LibriSpeech folder.
get_df_from_dataset_directory('/content/drive/MyDrive/Datasets/LibriSpeech/train-clean-100','/content/drive/MyDrive/Datasets/LibriSpeech')

# VOCABULARY AND DATALOADERS

In [None]:
# !apt install ffmpeg

import os 
import pandas as pd 
import numpy as np
import spacy  
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence 
from torch.utils.data import DataLoader, Dataset
import torchaudio.transforms as transforms

# The "en" shortcut only exists in spaCy 2.x; spaCy 3 raises OSError for it.
# Only `spacy_en.tokenizer` is used in this notebook, so fall back to a blank
# English pipeline, which provides the English tokenizer without needing a
# downloaded model.
try:
    spacy_en = spacy.load("en")
except OSError:
    spacy_en = spacy.blank("en")

class Vocabulary:
    """Word-level vocabulary that maps tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``
    and ``<UNK>``. Every other word receives the next free id the moment its
    count reaches ``freq_threshold`` while building the vocabulary.
    """

    def __init__(self, freq_threshold):
        """Create a vocabulary holding only the four special tokens.

        Args:
            freq_threshold: number of occurrences a word needs before it is
                added to the vocabulary.
        """
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        """Return the number of known tokens, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenizer_eng(text):
        """Split ``text`` into lower-cased tokens with the spaCy English tokenizer."""
        return [tok.text.lower() for tok in spacy_en.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        """Add every word that occurs at least ``freq_threshold`` times.

        Args:
            sentence_list: iterable of sentences (strings).
        """
        frequencies = {}
        # Continue after the highest id in use so that a second call extends
        # the vocabulary instead of overwriting ids 4, 5, ...
        idx = len(self.itos)

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                frequencies[word] = frequencies.get(word, 0) + 1

                # `==` (not `>=`) so a word is registered exactly once per call;
                # the membership test keeps ids stable across repeated calls.
                if frequencies[word] == self.freq_threshold and word not in self.stoi:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        """Convert ``text`` to a list of token ids; unknown words map to ``<UNK>``."""
        unk_id = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_id) for token in self.tokenizer_eng(text)]

    def vocab_to_list(self):
        """Return all tokens as a list ordered by insertion (i.e. by id)."""
        return list(self.stoi)

    def convert_pred_to_words(self, stoi_pred):
        """Convert a flat sequence of token ids back to words.

        Args:
            stoi_pred: 1-D sequence of ids. Elements may be Python ints, NumPy
                integers or 0-dim tensors; each one is converted with ``int()``
                before the lookup, because a tensor element never matches the
                integer keys of ``itos``.

        Returns:
            List of words, with ``<UNK>`` for ids that are not in the vocabulary.
        """
        unk_word = self.itos[3]
        return [self.itos.get(int(token), unk_word) for token in stoi_pred]


class EnglishDataset(Dataset):
    """Dataset of (mel spectrogram, transcript ids) pairs read from a manifest CSV.

    Every waveform is zero-padded to ``max_audio_len`` samples before the mel
    spectrogram is computed, and every transcript is padded with ``<PAD>`` to
    ``max_translation_len`` ids, so the default collate function can stack the
    items without a custom ``collate_fn``.
    """

    def __init__(self, root_dir, translation_file, freq_threshold=3,
                 max_audio_len=392400, max_translation_len=78):
        """Read the manifest and build the vocabulary.

        Args:
            root_dir: dataset root; stored but not used for loading, because
                the manifest's ``audio_path`` column is passed to
                ``torchaudio.load`` as is.
            translation_file: CSV with ``audio_path`` and ``translation`` columns.
            freq_threshold: minimum word count for a word to enter the vocabulary.
            max_audio_len: number of samples every waveform is padded to.
            max_translation_len: number of ids every transcript is padded to
                (including ``<SOS>`` and ``<EOS>``).
        """
        self.root_dir = root_dir
        self.df = pd.read_csv(translation_file, header=0)

        # Get audio path,translation columns
        self.audio_path = self.df["audio_path"]
        self.translation = self.df["translation"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.translation.tolist())
        print("Vocab Size:", len(self.vocab))

        self.tensor_size = max_audio_len
        self.tensor_trans_size = max_translation_len
        self.pad_idx = self.vocab.stoi["<PAD>"]

        # One MelSpectrogram transform per sample rate, created on first use,
        # instead of a new transform object for every item.
        self._mel_transforms = {}

    def __len__(self):
        """Return the number of manifest rows."""
        return len(self.df)

    def __getvocab__(self):
        """Return the vocabulary built from the manifest transcripts."""
        return self.vocab

    def _mel_transform(self, sample_rate):
        """Return the cached MelSpectrogram transform for ``sample_rate``."""
        transform = self._mel_transforms.get(sample_rate)
        if transform is None:
            transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)
            self._mel_transforms[sample_rate] = transform
        return transform

    def __getitem__(self, index):
        """Load one example.

        Returns:
            Tuple ``(specgram, n_frames, padded_ids, n_ids)``: the mel
            spectrogram of the padded waveform, the number of frames of the
            un-padded spectrogram, the padded transcript ids and the number of
            ids before padding.
        """
        translation = self.translation[index]

        audio, sample_rate = torchaudio.load(self.audio_path[index])
        # assumes mono audio, so squeezing leaves a 1-D waveform - TODO confirm
        audio = torch.squeeze(audio)

        # Zero-pad the waveform on the right up to the fixed length.
        padded_audio = torch.nn.functional.pad(
            audio, (0, self.tensor_size - len(audio)), value=0)

        transform = self._mel_transform(sample_rate)
        # The un-padded spectrogram is only needed for its true frame count.
        specgram_unpadded = transform(audio)
        specgram = transform(padded_audio)

        # <SOS> + word ids + <EOS>
        numericalized_translation = [self.vocab.stoi["<SOS>"]]
        numericalized_translation += self.vocab.numericalize(translation)
        numericalized_translation.append(self.vocab.stoi["<EOS>"])
        translation_tensor = torch.tensor(numericalized_translation)

        # Right-pad the ids with <PAD> up to the fixed transcript length.
        padded_translation = torch.nn.functional.pad(
            translation_tensor,
            (0, self.tensor_trans_size - translation_tensor.shape[0]),
            value=self.pad_idx)

        return specgram, specgram_unpadded.shape[1], padded_translation, translation_tensor.shape[0]


# class MyCollate:
#     def __call__(self, batch):
#         audio = [item[0].unsqueeze(0) for item in batch]
#         audio = torch.cat(audio, dim=0)

#         # audio=torch.squeeze(audio)

#         targets = [item[1].unsqueeze(0) for item in batch]
#         targets = torch.cat(targets, dim=0)

#         return audio, targets

#number of samples=28539; the default batch size is 16 (1248 training batches for the 70% training split)
#colab's max workers are 4, but if increased we can use 8
def get_loader(root_folder,annotation_file,batch_size=16,num_workers=4,shuffle=True,pin_memory=True,seed=None):
    """Create train/validation/test DataLoaders from a manifest file.

    The dataset is split randomly 70% / 20% / remainder. All three loaders
    drop the last incomplete batch.

    Args:
        root_folder: dataset root, forwarded to ``EnglishDataset``.
        annotation_file: path of the manifest CSV.
        batch_size: batch size used by all three loaders.
        num_workers: worker processes per loader.
        shuffle: whether the *training* loader reshuffles every epoch; the
            validation and test loaders are never shuffled.
        pin_memory: forwarded to every ``DataLoader``.
        seed: optional integer seed for the random split. ``None`` keeps the
            global torch RNG, i.e. a different split on every call.

    Returns:
        ``(train_loader, validation_loader, test_loader, dataset, words_vocab)``.
    """
    dataset = EnglishDataset(root_folder, annotation_file)

    words_vocab=dataset.__getvocab__()

    train_size = int(0.7 * len(dataset))
    validation_size = int(0.2 * len(dataset))
    test_size = len(dataset) - train_size - validation_size

    # Only pass a generator when a seed was requested, so the default call
    # behaves exactly as before.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_dataset, validation_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, validation_size, test_size], **split_kwargs)

    # CPU/GPU loaders. For TPU training replace `shuffle` with one
    # torch.utils.data.distributed.DistributedSampler per split
    # (num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()) and pass it
    # through the `sampler` argument instead.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, drop_last=True)

    train_loader = DataLoader(dataset=train_dataset, shuffle=shuffle, **common)
    validation_loader = DataLoader(dataset=validation_dataset, shuffle=False, **common)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **common)

    return train_loader, validation_loader, test_loader, dataset, words_vocab


if __name__ == "__main__":

    # Build the loaders once. The names bound here are module-level globals that
    # later cells read directly (train_loader, validation_loader, dataset, words_vocab).
    train_loader, validation_loader, test_loader, dataset, words_vocab = get_loader("/content/drive/MyDrive/Datasets/LibriSpeech", "/content/drive/MyDrive/Datasets/LibriSpeech/manifest_file.csv")

    # len() of a DataLoader is its number of batches.
    print("Training Loader number of Batches:", len(train_loader))
    print("Validation Loader number of Batches:", len(validation_loader))
    print("Testing Loader number of Batches:", len(test_loader))


Vocab Size: 16121
Training Loader number of Batches: 1248
Validation Loader number of Batches: 356
Testing Loader number of Batches: 178


# FIND MAX AUDIO LENGTH AND MAX TRANSLATION LENGTH FOR PADDING IN LOADERS

In [None]:
# Scan the whole manifest to find the largest spectrogram length and the
# longest transcript (in tokens); both numbers are printed at the end.
translation_file="drive/MyDrive/Datasets/LibriSpeech/manifest_file.csv"

df_test = pd.read_csv(translation_file, header=0)

# Get audio path,translation columns
audio_path = df_test["audio_path"]
translation =df_test["translation"]

# Largest size seen so far along the last spectrogram axis.
max_shape=0
for i,aud in enumerate(audio_path):
  audio, sample_rate=torchaudio.load(aud)
  # specgram = torchaudio.transforms.Spectrogram()(audio)
  # NOTE(review): MelSpectrogram() is built with its default arguments here,
  # i.e. the loaded sample_rate is not passed - confirm this is intended.
  specgram = torchaudio.transforms.MelSpectrogram()(audio)

  #if statement to keep track
  # if (i % 1000)==0:
  #   print(i, specgram.shape)

  # presumably axis 2 is the frame axis of a (channel, n_mels, frames) tensor - verify
  if (specgram.shape[2]>max_shape):
    max_shape=specgram.shape[2]

# Token count of every transcript, using the same spaCy tokenizer as the vocabulary.
list_of_lens=[]
for i,sent in enumerate(translation):
  a=[tok.text for tok in spacy_en.tokenizer(sent)]
  list_of_lens.append(len(a))
  

print(max_shape)
#+2 because we add SOS and EOS
print(max(list_of_lens)+2)

In [None]:
# Sanity check of the padding pipeline on the first manifest entry only
# (the loop ends with `break`): pad the waveform to max_shape samples and
# print the padded tensor, its shape and the resulting mel-spectrogram shape.
translation_file="drive/MyDrive/Datasets/LibriSpeech/manifest_file.csv"

df_test = pd.read_csv(translation_file, header=0)

# Get audio path,translation columns
audio_path = df_test["audio_path"]
translation =df_test["translation"]

# Fixed waveform length, in samples, that every clip is padded to.
max_shape=392400
for i,aud in enumerate(audio_path):

  audio, sample_rate=torchaudio.load(aud)
  audio=torch.squeeze(audio)

  # Number of zeros to append on the right.
  values_to_pad=max_shape-len(audio)
  m = torch.nn.ConstantPad1d((0, values_to_pad), 0)

  padded_audio=m(audio)
  print(padded_audio)
  print(padded_audio.shape)

  # if (audio.shape[1]>max_shape):
  #   max_shape=audio.shape[1]

  transform=torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)
  # transform_pad_one=torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate,pad=10)
  # specgram = torchaudio.transforms.Spectrogram()(audio)
  specgram1 = transform(padded_audio)
  # specgram2= transform_pad_one(audio)

  print(specgram1.shape)
  # print(specgram2.shape)
  
  # Only the first file is inspected.
  break
#max shape is calculated=392400
print(max_shape)

tensor([6.1035e-05, 6.1035e-05, 9.1553e-05,  ..., 0.0000e+00, 0.0000e+00,
        0.0000e+00])
torch.Size([392400])
torch.Size([128, 1963])
392400


# PACKAGE INSTALLATION & MODEL TRAINING AND EVALUATION

## INSTALL PACKAGES TO USE TPUS AND TORCHAUDIO (COMPATIBLE)

### SETUP CONFORMER FROM https://github.com/sooftware/conformer

In [None]:
#CODE TO INSTALL CONFORMER PACKAGE AUTOMATICALLY
# https://github.com/sooftware/conformer

!pip install git+https://github.com/sooftware/conformer.git


Collecting git+https://github.com/sooftware/conformer.git
  Cloning https://github.com/sooftware/conformer.git to /tmp/pip-req-build-xokjnzbr
  Running command git clone -q https://github.com/sooftware/conformer.git /tmp/pip-req-build-xokjnzbr
Building wheels for collected packages: conformer
  Building wheel for conformer (setup.py) ... [?25l[?25hdone
  Created wheel for conformer: filename=conformer-latest-py3-none-any.whl size=18344 sha256=43c103d35ab990a9b051e971b5763e2c78f4c8dd10dcb5d1dbbf3b7a4c8b00db
  Stored in directory: /tmp/pip-ephem-wheel-cache-hug_980a/wheels/58/e3/8f/c80015975bb214b50aca0fcf6449d6f55154176de96c0a3046
Failed to build conformer
Installing collected packages: conformer
    Running setup.py install for conformer ... [?25l[?25hdone
[33m  DEPRECATION: conformer was installed using the legacy 'setup.py install' method, because a wheel could not be built for it. A possible replacement is to fix the wheel build issue reported above. You can find discussion reg

In [None]:
#CODE TO INSTALL MANUALLY AND BE ABLE TO MODIFY PACKAGE
# https://github.com/sooftware/conformer

# Clone the repository into the current working directory.
!git clone https://github.com/sooftware/conformer.git

# Enter the clone: both the install command and os.getcwd() below rely on it.
%cd conformer

!python3 setup.py install

import os, sys
# Append the clone directory (the current working directory) to the import path.
sys.path.append(os.getcwd())

### SETUP CTC DECODER FROM https://github.com/parlance/ctcdecode

In [None]:
!pip install git+https://github.com/parlance/ctcdecode.git

Collecting git+https://github.com/parlance/ctcdecode.git
  Cloning https://github.com/parlance/ctcdecode.git to /tmp/pip-req-build-ikwdtef4
  Running command git clone -q https://github.com/parlance/ctcdecode.git /tmp/pip-req-build-ikwdtef4
  Running command git submodule update --init --recursive -q
Building wheels for collected packages: ctcdecode
  Building wheel for ctcdecode (setup.py) ... [?25l[?25hdone
  Created wheel for ctcdecode: filename=ctcdecode-1.0.3-cp37-cp37m-linux_x86_64.whl size=13308100 sha256=c1b91a16b403aef16b1f6e5bd796dee286bfe09d27a84c5c60253d0043a8150d
  Stored in directory: /tmp/pip-ephem-wheel-cache-_f5vx1kd/wheels/ee/a2/65/642d6cea0147b4683da85d09137940b92987cbe132b883c9ca
Successfully built ctcdecode
Installing collected packages: ctcdecode
Successfully installed ctcdecode-1.0.3


### SETUP TORCHMETRICS

In [None]:
#CODE TO INSTALL METRICS
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-0.7.3-py3-none-any.whl (398 kB)
[?25l[K     |▉                               | 10 kB 38.8 MB/s eta 0:00:01[K     |█▋                              | 20 kB 22.5 MB/s eta 0:00:01[K     |██▌                             | 30 kB 17.6 MB/s eta 0:00:01[K     |███▎                            | 40 kB 15.8 MB/s eta 0:00:01[K     |████▏                           | 51 kB 7.3 MB/s eta 0:00:01[K     |█████                           | 61 kB 8.6 MB/s eta 0:00:01[K     |█████▊                          | 71 kB 9.1 MB/s eta 0:00:01[K     |██████▋                         | 81 kB 9.3 MB/s eta 0:00:01[K     |███████▍                        | 92 kB 10.2 MB/s eta 0:00:01[K     |████████▎                       | 102 kB 8.5 MB/s eta 0:00:01[K     |█████████                       | 112 kB 8.5 MB/s eta 0:00:01[K     |█████████▉                      | 122 kB 8.5 MB/s eta 0:00:01[K     |██████████▊                     | 133 kB 8.5 MB/s eta

### SETUP TPU WITH TORCHAUDIO

In [None]:
import os
# .get() returns None when the variable is missing, so the assertion message
# below is shown instead of a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

!pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

!pip install torch==1.9.0+cu111 torchaudio -f https://download.pytorch.org/whl/cu111/torch_stable.html

### OTHER

In [None]:
!pip3 install --pre torch torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu113/torch_nightly.html -U

In [None]:
from torchaudio.models import Conformer
import torch
import warnings

warnings.filterwarnings("ignore")

# Smoke test: push one training batch through torchaudio's Conformer encoder.
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
print('Device:', device)

conformer = Conformer(input_dim=128, num_heads=4, ffn_dim=144, num_layers=8, depthwise_conv_kernel_size=31,dropout=0.1).to(device)

for i, (audio,audio_len, translations, translation_len) in enumerate(train_loader):
  # (batch, n_mels, frames) -> (batch, frames, n_mels); named `inputs` so the
  # builtin input() is not shadowed.
  inputs=torch.transpose(audio, 1, 2).to(device)
  lengths=audio_len.to(device)
  print(lengths.shape)
  # Every item is treated as full (padded) length. The batch size and frame
  # count are read from the batch instead of being hard-coded to 4 and 1963,
  # so the cell works for any loader configuration.
  lengths=torch.full((inputs.size(0),), inputs.size(1)).to(device)
  print(inputs.shape)
  print(lengths.shape)

  output,out_len = conformer(inputs, lengths)
  print(output)
  # One batch is enough for the check.
  break



# lengths = torch.randint(1, 400, (10,))  # (batch,)
# input = torch.rand(10, int(lengths.max()), 128)  # (batch, num_frames, input_dim)
# print(input)
# print(lengths)
# output,out_len = conformer(input, lengths)
# print(output.shape)

## TRAINING AND EVALUATION

### LEARNING RATE SCHEDULER

In [None]:
'''A wrapper class for scheduled optimizer '''
import numpy as np

class ScheduledOptim():
    '''Optimizer wrapper that recomputes the learning rate before every step.

    The rate is lr_mul * d_model^-0.5 * min(n^-0.5, n * n_warmup_steps^-1.5),
    where n is the number of steps taken so far (counted from 1).
    '''

    def __init__(self, optimizer, lr_mul, d_model, n_warmup_steps):
        self.n_steps = 0
        self.n_warmup_steps = n_warmup_steps
        self.d_model = d_model
        self.lr_mul = lr_mul
        self._optimizer = optimizer

    def step_and_update_lr(self):
        "Refresh the learning rate, then step the wrapped optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Clear the gradients held by the wrapped optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        "Return the schedule factor for the current step count"
        decay_term = self.n_steps ** (-0.5)
        warmup_term = self.n_steps * self.n_warmup_steps ** (-1.5)
        return (self.d_model ** -0.5) * min(decay_term, warmup_term)

    def _update_learning_rate(self):
        ''' Advance the step counter and write the new rate to every param group '''
        self.n_steps += 1
        new_lr = self.lr_mul * self._get_lr_scale()

        for group in self._optimizer.param_groups:
            group['lr'] = new_lr

In [None]:
class ScheduleOptimizer():
  """Optimizer wrapper implementing a warm-up / inverse-square-root schedule.

  The learning rate at step ``n`` is
  ``model_size^-0.5 * min(n^-0.5, n * warmup^-1.5)``.
  """
  def __init__(self, model_size, warmup, optimizer):
    self.optimizer = optimizer
    self._step = 0
    self.warmup = warmup
    self.model_size = model_size
    self._rate = 0
  
  def state_dict(self):
    """Returns the state of the warmup scheduler as a :class:`dict`.
    It contains an entry for every variable in self.__dict__ which
    is not the optimizer.
    """
    return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
  
  def load_state_dict(self, state_dict):
    """Loads the warmup scheduler's state.
    Arguments:
        state_dict (dict): warmup scheduler state. Should be an object returned
            from a call to :meth:`state_dict`.
    """
    self.__dict__.update(state_dict) 

  def zero_grad(self):
    "Zero out the gradients with the inner optimizer"
    self.optimizer.zero_grad()

  def step(self):
    "Advance one step: set the new rate on every param group, then step the optimizer"
    self._step += 1
    rate = self.rate()
    for p in self.optimizer.param_groups:
        p['lr'] = rate
    self._rate = rate
    self.optimizer.step()
      
  def rate(self, step = None):
    """Return the learning rate for ``step`` (default: the current step).

    Steps below 1 are treated as step 1, because ``0 ** -0.5`` is undefined
    and would otherwise raise ZeroDivisionError when the rate is queried
    before the first call to :meth:`step`.
    """
    if step is None:
        step = self._step
    step = max(step, 1)
    return (self.model_size ** (-0.5) *min(step ** (-0.5), step * self.warmup ** (-1.5))) 

### BEAM SEARCH DECODER

In [None]:
from ctcdecode import CTCBeamDecoder

# Beam-search decoder over the word vocabulary.
# The label list is the vocabulary in id order, so index 0 is "<PAD>", which
# is therefore the token used as the CTC blank (blank_id=0).
# NOTE(review): model_path=None with alpha=0 and beta=0 presumably means no
# language model is used - confirm against the ctcdecode documentation.
decoder = CTCBeamDecoder(
    words_vocab.vocab_to_list(),
    model_path=None,
    alpha=0,
    beta=0,
    cutoff_top_n=40,
    cutoff_prob=1.0,
    beam_width=100,
    num_processes=4,
    blank_id=0,
    # presumably tells the decoder that the model outputs are log-probabilities - verify
    log_probs_input=True
)

In [None]:
import numpy as np
import math


def greedy_search_decoder(predictions):
    """Pick the highest-scoring token at every time step.

    Args:
        predictions: sequence of per-step score vectors (one row per step).

    Returns:
        Tuple ``(output_sequence, sequence_probability)``: the list of arg-max
        indices, one per step, and the product of the per-step maxima.
        An empty input yields ``([], 1.0)``.
    """
    # select token with the maximum probability for each prediction
    output_sequence = [np.argmax(prediction) for prediction in predictions]

    # storing token probabilities
    token_probabilities = [np.max(prediction) for prediction in predictions]

    # multiply individual token-level probabilities to get overall sequence probability
    # (np.prod: the np.product alias is deprecated and missing from NumPy 2.x)
    sequence_probability = np.prod(token_probabilities)

    return output_sequence, sequence_probability
    
# model_prediction = [[0.1, 0.7, 0.1, 0.1],
#                     [0.7, 0.1, 0.1, 0.1],
#                     [0.1, 0.1, 0.6, 0.2],
#                     [0.1, 0.1, 0.1, 0.7],
#                     [0.4, 0.3, 0.2, 0.1]]

# greedy_search_decoder(model_prediction)

def beam_search_decoder(predictions, top_k = 3):
    """Beam search over per-step scores, keeping the ``top_k`` best prefixes.

    Scores are added step by step (log-likelihood style) and a higher total
    is better. Returns a list of ``(sequence, score)`` tuples sorted from best
    to worst; for empty ``predictions`` it is ``[([], 0)]``.
    """
    # a single empty hypothesis with score zero
    beams = [([], 0)]

    for step_scores in predictions:
        # extend every surviving hypothesis with every possible token
        candidates = [
            (prefix + [token], prefix_score + step_scores[token])
            for prefix, prefix_score in beams
            for token in range(len(step_scores))
        ]

        # best (highest) score first; the sort is stable, so ties keep
        # their generation order
        candidates.sort(key = lambda candidate: candidate[1], reverse = True)

        # prune to the beam width
        beams = candidates[:top_k]

    return beams

# model_prediction = [[0.1, 0.7, 0.1, 0.1],
#                     [0.7, 0.1, 0.1, 0.1],
#                     [0.1, 0.1, 0.6, 0.2],
#                     [0.1, 0.1, 0.1, 0.7],
#                     [0.4, 0.3, 0.2, 0.1]]
                    
# beam_search_decoder(model_prediction, top_k = 5)


# beam search
def beam_search_decoder_test(data, k):
    """Beam search that subtracts each step's score and keeps the ``k`` lowest totals.

    Returns a list of ``[sequence, score]`` pairs ordered from lowest to
    highest score; for empty ``data`` it is ``[[[], 0.0]]``.
    """
    beams = [[[], 0.0]]
    for step_scores in data:
        # every surviving hypothesis extended by every token of this step
        expanded = [
            [prefix + [token], total - step_scores[token]]
            for prefix, total in beams
            for token in range(len(step_scores))
        ]
        # lowest score first (stable for ties), then prune to k hypotheses
        expanded.sort(key=lambda pair: pair[1])
        beams = expanded[:k]
    return beams

In [None]:
def beam_search_decoder_torch(post, k):
    """Batched beam search over per-step log-probabilities.

    Parameters:

        post(Tensor) – network output as log-probabilities (the values are
            added across time steps, no logarithm is applied here).
        k(int) – beam size of decoder; must not exceed the vocabulary size.

    Outputs:

        indices(Tensor) – a beam of token-id sequences, best beam first.
        log_prob(Tensor) – summed log likelihood of every returned sequence.

    Shape:

        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:

        >>> post = torch.log_softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder_torch(post, 3)

    """

    batch_size, seq_length, vocab_size = post.shape
    log_post = post
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)  # (batch, k, 1)
    for i in range(1, seq_length):
        # Score of every (beam, next token) pair: (batch, k, vocab).
        scores = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1)
        log_prob, flat_index = scores.reshape(batch_size, -1).topk(k, sorted=True)
        # The top-k runs over the flattened (beam, token) axis, so split each
        # winner back into the beam it extends and the token it appends.
        parent_beam = torch.div(flat_index, vocab_size, rounding_mode='floor')
        token = flat_index % vocab_size
        # Reorder the stored histories so that row j holds the prefix of the
        # beam that produced the j-th best candidate.
        history = indices.gather(
            1, parent_beam.unsqueeze(-1).expand(-1, -1, indices.size(-1)))
        indices = torch.cat([history, token.unsqueeze(-1)], dim=-1)
    return indices, log_prob

### TPU TRAIN

In [None]:
# import torch_xla
# import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
import time

warnings.filterwarnings("ignore")

import torch
import time
import sys
from google.colab import output
import torch.nn as nn
from conformer import Conformer
import torchmetrics
import random
# import torchaudio.models

def _run(flags):
  """Train the Conformer for one epoch on an XLA device, then validate it.

  Relies on notebook globals created by earlier cells: ``dataset``,
  ``train_loader``, ``validation_loader`` and ``beam_search_decoder_torch``.

  Args:
    flags: not read by this function; kept for the ``xmp.spawn`` call convention.
  """
  # Local import: the module-level `xm` import of this cell is commented out,
  # so without it every `xm.` call below raises NameError on a fresh kernel.
  import torch_xla.core.xla_model as xm

  ################################################################################

  def show_progress(batch_index, loader):
    "Replace the previous progress line with 'Batch: i/N'."
    output.clear(output_tags='some_outputs')
    with output.use_tags('some_outputs'):
      sys.stdout.write('Batch: '+ str(batch_index+1)+'/'+str(len(loader)))
      sys.stdout.flush()

  def sort_batch_by_target_length(audio, audio_len, translations, translation_len):
    """Reorder a batch so that target lengths are in descending order.

    Returns float audio, int32 audio lengths, int32 targets and the sorted
    target lengths, all on the CPU.
    """
    sorted_translation_len, order = torch.sort(translation_len, descending=True)
    return (audio[order].float(),
            audio_len[order].int(),
            translations[order].int(),
            sorted_translation_len)

  def forward_with_loss(model, criterion, batch, device):
    "Run one batch through the model and return (outputs, CTC loss)."
    audio, input_lengths, translations, target_lengths = sort_batch_by_target_length(*batch)

    #transpose inputs from (batch, dim, seq_len) to (batch, seq_len, dim)
    inputs = torch.transpose(audio.to(device), 1, 2)
    targets = translations.to(device)

    outputs, output_lengths = model(inputs, input_lengths)

    # CTCLoss expects (seq_len, batch, classes)
    loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)
    return outputs, loss

  ################################################################################

  def train_model(model, optimizer, criterion, loader, device, metric):
    "Run one training epoch and return the mean batch loss (0.0 for an empty loader)."
    running_loss = 0.0
    num_batches = 0
    for i, batch in enumerate(loader):
      show_progress(i, loader)

      optimizer.zero_grad()
      outputs, loss = forward_with_loss(model, criterion, batch, device)
      loss.backward()

      xm.optimizer_step(optimizer)

      running_loss += loss.item()
      num_batches += 1

    output.clear(output_tags='some_outputs')
    # max(..., 1) avoids a division by zero when the loader yields no batch.
    return running_loss / max(num_batches, 1)

  ################################################################################

  def eval_model(model, optimizer, criterion, loader, device, metric):
    """Run one evaluation pass.

    Returns:
      ``(mean batch loss, wer)``. The word error rate is not computed yet, so
      the second value is always 0.0.
    """
    running_loss = 0.0
    wer_calc = 0.0
    num_batches = 0

    for i, batch in enumerate(loader):
      show_progress(i, loader)

      outputs, loss = forward_with_loss(model, criterion, batch, device)

      # Beam-search the batch (beam size 5) and print the raw result for inspection.
      ind,prob=beam_search_decoder_torch(outputs.cpu().detach(),5)
      print(ind)
      print(prob)

      running_loss += loss.item()
      num_batches += 1

    output.clear(output_tags='some_outputs')
    num_batches = max(num_batches, 1)
    return running_loss / num_batches, wer_calc / num_batches

  ################################################################################

  def train_eval_model(epochs):
    "Build model, optimizer and loss, then alternate training and validation."

    device=xm.xla_device()
    print('Device:', device)

    inputdim=dataset[0][0].shape[0]
    print('Input Dimensions:', inputdim)

    #conformer model init; the output layer is sized from the vocabulary
    #instead of a hard-coded class count
    model = nn.DataParallel(Conformer(num_classes=len(dataset.vocab), input_dim=inputdim, encoder_dim=144, num_encoder_layers=16, num_attention_heads=4, conv_kernel_size=31)).to(device)

    # Optimizers specified in the torch.optim package
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

    #loss function
    criterion = nn.CTCLoss().to(device)

    #metrics init
    metric=torchmetrics.WordErrorRate()

    for epoch in range(epochs):
      start = time.time()

      print("Epoch", epoch+1)

      ############################################################################
      #TRAINING

      model.train()
      print("Training")

      epoch_loss=train_model(model=model,optimizer=optimizer, criterion=criterion, loader=train_loader, device=device, metric=metric)

      print(format(epoch_loss, ".4f"))

      ############################################################################
      #EVALUATION

      with torch.no_grad():
        model.train(False)

        print("Validation")

        epoch_val_loss, epoch_val_wer=eval_model(model=model,optimizer=optimizer, criterion=criterion, loader=validation_loader, device=device, metric=metric)

        print(format(epoch_val_loss, ".4f"))

      print(f'Epoch {epoch+1} completed in {(time.time() - start)/60} minutes')

  ###############################################################################

  def metrics_calculation(metric, predictions, targets):
      "Print predictions and targets, then return ``metric(predictions, targets)``."
      print(predictions)
      print(targets)
      wer=metric(predictions, targets)

      return wer

  train_eval_model(1)

# Start training processes
def _mp_fn(rank, flags):
    """Per-process entry point handed to ``xmp.spawn``.

    Sets the default tensor type to ``torch.FloatTensor`` and runs the
    training routine; ``rank`` is accepted but not used.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    _run(flags)

# _run never reads its flags argument; the empty dict only satisfies the _mp_fn signature.
flags={}
# Start a single training process (nprocs=1) using the 'fork' start method.
xmp.spawn(_mp_fn, args=(flags,), nprocs=1, start_method='fork')



# train_eval_model(1)

Device: xla:1
Input Dimensions: 128
Epoch 1
Training


Batch: 10/4994

### GPU TRAIN

In [None]:
import warnings

warnings.filterwarnings("ignore")

import torch
import time
import sys
from google.colab import output
import torch.nn as nn
from conformer import Conformer
import torchmetrics
import random
from torch.utils.tensorboard import SummaryWriter

###############################################################################

def train_model(model, optimizer, criterion, loader, device, metric):
  """Run one training epoch and return the mean loss and mean WER.

  Each batch is reordered by descending target length, passed through the
  model, optimised against ``criterion`` and then decoded with the
  notebook-level ``decoder`` so a word error rate can be computed per
  utterance through ``metrics_calculation``.

  Args:
    model: module called as ``model(inputs, input_lengths)`` that returns
      ``(outputs, output_lengths)``.
    optimizer: optimizer that is zeroed and stepped once per batch.
    criterion: loss called as ``criterion(outputs.transpose(0, 1), targets,
      output_lengths, target_lengths)``.
    loader: iterable of ``(audio, audio_len, translations, translation_len)``
      batches that also supports ``len()``. The four items are assumed to be
      tensors indexed by batch on dim 0 — TODO confirm against the collate fn.
    device: device the inputs and targets are moved to.
    metric: metric object forwarded to ``metrics_calculation``.

  Returns:
    ``(loss_per_epoch, wer_per_epoch)``: the per-batch loss and the per-batch
    mean WER, each averaged over the number of batches. Both are ``nan`` when
    ``loader`` yields no batches, because a mean over nothing is undefined.
  """
  running_loss = 0.0
  running_wer = 0.0
  # Counts processed batches; stays 0 when the loader is empty, so the final
  # averages never depend on a loop variable that might not exist.
  num_batches = 0

  for num_batches, (audio, audio_len, translations, translation_len) in enumerate(loader, start=1):
    # Overwrite the previous progress line instead of appending a new one.
    output.clear(output_tags='some_outputs')
    with output.use_tags('some_outputs'):
      sys.stdout.write('Batch: '+ str(num_batches)+'/'+str(len(loader)))
      sys.stdout.flush()

    # Reorder the whole batch so targets are in descending order of length.
    # Index selection replaces the per-row copy into pre-allocated tensors;
    # the dtype casts match the ones the old zero-filled buffers imposed.
    sorted_translation_len, sorted_indices = torch.sort(translation_len, descending=True)
    sorted_audio = audio[sorted_indices].to(torch.float)
    sorted_audio_len = audio_len[sorted_indices].to(torch.int)
    sorted_translations = translations[sorted_indices].to(torch.int)

    # transpose inputs from (batch, dim, seq_len) to (batch, seq_len, dim)
    inputs = torch.transpose(sorted_audio.to(device), 1, 2)
    input_lengths = sorted_audio_len
    targets = sorted_translations.to(device)
    target_lengths = sorted_translation_len

    optimizer.zero_grad()

    # Forward propagate
    outputs, output_lengths = model(inputs, input_lengths)

    # Calculate CTC Loss; the criterion is fed the outputs with the first two
    # dimensions swapped.
    loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)

    loss.backward()
    optimizer.step()

    # Beam Search Decoding (only the hypotheses and their lengths are used)
    beam_results, _, _, out_lens = decoder.decode(outputs, output_lengths)

    batch_size = outputs.size(0)
    word_err_rate_batch = 0
    for b in range(batch_size):
      # Entry 0 is presumably the best-scoring beam — confirm against the
      # decoder's documentation. out_lens gives its valid length.
      result_beam = beam_results[b][0][:out_lens[b][0]].tolist()
      word_err_rate_batch += metrics_calculation(metric, result_beam, targets[b])

    running_loss += loss.item()
    running_wer += word_err_rate_batch / batch_size

  if num_batches == 0:
    # Nothing was processed: report "undefined" rather than crashing or
    # pretending the epoch scored a perfect 0.
    return float('nan'), float('nan')

  loss_per_epoch = running_loss / num_batches
  wer_per_epoch = running_wer / num_batches

  output.clear(output_tags='some_outputs')
  return loss_per_epoch, wer_per_epoch

################################################################################

def eval_model(model, optimizer, criterion, loader, device, metric):
  """Run one evaluation pass and return the mean loss and mean WER.

  Follows the same per-batch steps as ``train_model`` but never touches the
  optimizer and never back-propagates. Gradient tracking is not disabled
  here; the caller is expected to wrap the call in ``torch.no_grad()``.

  Args:
    model: module called as ``model(inputs, input_lengths)`` that returns
      ``(outputs, output_lengths)``.
    optimizer: unused; kept so the signature mirrors ``train_model``.
    criterion: loss called as ``criterion(outputs.transpose(0, 1), targets,
      output_lengths, target_lengths)``.
    loader: iterable of ``(audio, audio_len, translations, translation_len)``
      batches that also supports ``len()``. The four items are assumed to be
      tensors indexed by batch on dim 0 — TODO confirm against the collate fn.
    device: device the inputs and targets are moved to.
    metric: metric object forwarded to ``metrics_calculation``.

  Returns:
    ``(loss_per_epoch, wer_per_epoch)``: the per-batch loss and the per-batch
    mean WER, each averaged over the number of batches. Both are ``nan`` when
    ``loader`` yields no batches, because a mean over nothing is undefined.
  """
  running_loss = 0.0
  running_wer = 0.0
  # Counts processed batches; stays 0 when the loader is empty, so the final
  # averages never depend on a loop variable that might not exist.
  num_batches = 0

  for num_batches, (audio, audio_len, translations, translation_len) in enumerate(loader, start=1):
    # Overwrite the previous progress line instead of appending a new one.
    output.clear(output_tags='some_outputs')
    with output.use_tags('some_outputs'):
      sys.stdout.write('Batch: '+ str(num_batches)+'/'+str(len(loader)))
      sys.stdout.flush()

    # Reorder the whole batch so targets are in descending order of length.
    # Index selection replaces the per-row copy into pre-allocated tensors;
    # the dtype casts match the ones the old zero-filled buffers imposed.
    sorted_translation_len, sorted_indices = torch.sort(translation_len, descending=True)
    sorted_audio = audio[sorted_indices].to(torch.float)
    sorted_audio_len = audio_len[sorted_indices].to(torch.int)
    sorted_translations = translations[sorted_indices].to(torch.int)

    # transpose inputs from (batch, dim, seq_len) to (batch, seq_len, dim)
    inputs = torch.transpose(sorted_audio.to(device), 1, 2)
    input_lengths = sorted_audio_len
    targets = sorted_translations.to(device)
    target_lengths = sorted_translation_len

    # Forward propagate
    outputs, output_lengths = model(inputs, input_lengths)

    # Calculate CTC Loss; the criterion is fed the outputs with the first two
    # dimensions swapped.
    loss = criterion(outputs.transpose(0, 1), targets, output_lengths, target_lengths)

    # Beam Search Decoding (only the hypotheses and their lengths are used)
    beam_results, _, _, out_lens = decoder.decode(outputs, output_lengths)

    batch_size = outputs.size(0)
    word_err_rate_batch = 0
    for b in range(batch_size):
      # Entry 0 is presumably the best-scoring beam — confirm against the
      # decoder's documentation. out_lens gives its valid length.
      result_beam = beam_results[b][0][:out_lens[b][0]].tolist()
      word_err_rate_batch += metrics_calculation(metric, result_beam, targets[b])

    running_loss += loss.item()
    running_wer += word_err_rate_batch / batch_size

  if num_batches == 0:
    # Nothing was processed: report "undefined" rather than crashing or
    # pretending the pass scored a perfect 0.
    return float('nan'), float('nan')

  loss_per_epoch = running_loss / num_batches
  wer_per_epoch = running_wer / num_batches

  output.clear(output_tags='some_outputs')
  return loss_per_epoch, wer_per_epoch

################################################################################

def train_eval_model(epochs):
  """Train and validate the Conformer, resuming from a checkpoint if present.

  Builds the model, optimizer, CTC loss and WER metric, restores model and
  optimizer state from ``path`` when that file exists, then alternates one
  training pass over ``train_loader`` with one validation pass over
  ``validation_loader``. After every epoch the losses and WERs are logged to
  TensorBoard and a checkpoint is written back to ``path``.

  Relies on the notebook-level ``dataset``, ``train_loader`` and
  ``validation_loader``.

  Args:
    epochs: number of the last epoch to run. Epochs are numbered from 1 and
      the bound is inclusive, so a fresh ``train_eval_model(n)`` performs
      ``n`` epochs and a resumed run continues up to and including epoch
      ``n``.
  """
  path = "/content/drive/MyDrive/Datasets/model.pt"

  # tensorboard writer
  writer = SummaryWriter("/content/drive/MyDrive/Datasets/tensorboard")

  # Everything after the writer exists runs under try/finally so the event
  # file is flushed and closed even when training is interrupted or fails.
  try:
    # device setup
    cuda = torch.cuda.is_available()
    device = torch.device('cuda' if cuda else 'cpu')
    print('Device:', device)

    inputdim = dataset[0][0].shape[0]
    print('Input Dimensions:', inputdim)

    #conformer model init
    model = nn.DataParallel(Conformer(num_classes=16121, input_dim=inputdim, encoder_dim=32, num_encoder_layers=2, num_attention_heads=4, conv_kernel_size=31)).to(device)

    # Optimizers specified in the torch.optim package
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
    # Alternative (disabled):
    # optimizer = ScheduleOptimizer(optimizer=torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09), model_size=144, warmup=10000)

    #loss function
    criterion = nn.CTCLoss().to(device)

    #metrics init
    metric = torchmetrics.WordErrorRate()

    if os.path.exists(path):
      # map_location lets a checkpoint written on a GPU runtime be restored
      # on whatever device is available now (e.g. a CPU-only session).
      checkpoint = torch.load(path, map_location=device)
      model.load_state_dict(checkpoint['model_state_dict'])
      optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      # The checkpoint stores the last finished epoch; resume with the next.
      starting_point = checkpoint['epoch']+1
      loss = checkpoint['loss']
      print("Model Checkpoint Found:")
      print("Starting epoch:", starting_point)
      print("Current Loss:", loss)
    else:
      print("No model Checkpoint found")
      starting_point = 1

    # Epoch numbers are 1-based, so the upper bound must be inclusive;
    # range(starting_point, epochs) would stop one epoch short.
    for epoch in range(starting_point, epochs + 1):
      start = time.time()

      print("Epoch", epoch)

      ##########################################################################
      #TRAINING

      model.train()
      print("Training")

      epoch_loss, epoch_wer = train_model(model=model, optimizer=optimizer, criterion=criterion, loader=train_loader, device=device, metric=metric)

      writer.add_scalar("Loss/Train", epoch_loss, epoch)
      writer.add_scalar("WER/Train", epoch_wer, epoch)
      print("Loss:", format(epoch_loss, ".4f"))
      print("WER:", format(epoch_wer, ".4f"))

      ##########################################################################
      #EVALUATION

      model.eval()
      with torch.no_grad():
        print("Validation")

        epoch_val_loss, epoch_val_wer = eval_model(model=model, optimizer=optimizer, criterion=criterion, loader=validation_loader, device=device, metric=metric)

        writer.add_scalar("Loss/Validation", epoch_val_loss, epoch)
        writer.add_scalar("WER/Validation", epoch_val_wer, epoch)
        print("Loss:", format(epoch_val_loss, ".4f"))
        print("WER:", format(epoch_val_wer, ".4f"))

      # Checkpoint after every epoch so an interrupted session can resume.
      torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'loss': epoch_loss,
              }, path)

      # Push this epoch's scalars to disk together with the checkpoint.
      writer.flush()

      print(f'Epoch {epoch} completed in {(time.time() - start)/60} minutes')
  finally:
    writer.flush()
    writer.close()

###############################################################################

def metrics_calculation(metric, predictions, targets):
  """Score one prediction against its target and return a plain number.

  Both sequences are passed through the notebook-level
  ``words_vocab.convert_pred_to_words`` and joined with single spaces before
  being handed to ``metric``.

  Args:
    metric: callable invoked as ``metric(pred_sentence, target_sentence)``
      whose result exposes ``.item()``.
    predictions: sequence accepted by ``convert_pred_to_words``.
    targets: object exposing ``.tolist()``; the resulting list is converted
      the same way as ``predictions``.

  Returns:
    ``metric(...).item()`` for the two joined sentences.
  """
  predicted_words = words_vocab.convert_pred_to_words(predictions)
  reference_words = words_vocab.convert_pred_to_words(targets.tolist())

  score = metric(' '.join(predicted_words), ' '.join(reference_words))
  return score.item()

# Start (or resume) GPU training with an epoch limit of 100; train_eval_model
# restores model and optimizer state from its checkpoint file when one exists.
train_eval_model(100)

Device: cuda
Input Dimensions: 128
Model Checkpoint Found:
Starting epoch: 50
Current Loss: 3.7623177480239134
Epoch 50
Training
Loss: 3.7447
WER: 0.8467
Validation
Loss: 3.4365
WER: 0.8211
Epoch 50 completed in 94.42254243691762 minutes
Epoch 51
Training
Loss: 3.7133
WER: 0.8452
Validation
Loss: 3.4309
WER: 0.8203
Epoch 51 completed in 91.84995356003444 minutes
Epoch 52
Training
Loss: 3.6878
WER: 0.8443
Validation
Loss: 3.4354
WER: 0.8193
Epoch 52 completed in 91.59916432301203 minutes
Epoch 53
Training
Loss: 3.6572
WER: 0.8427
Validation
Loss: 3.3956
WER: 0.8164
Epoch 53 completed in 91.73973194758098 minutes
Epoch 54
Training
Loss: 3.6356
WER: 0.8415
Validation
Loss: 3.3998
WER: 0.8167
Epoch 54 completed in 92.15700422525406 minutes
Epoch 55
Training
Loss: 3.6089
WER: 0.8401
Validation
Loss: 3.3708
WER: 0.8137
Epoch 55 completed in 92.37225886185963 minutes
Epoch 56
Training
Loss: 3.5898
WER: 0.8391
Validation
Loss: 3.3517
WER: 0.8137
Epoch 56 completed in 93.88647603193918 minutes


Batch: 1241/1248

### TESTING

In [None]:
import warnings

warnings.filterwarnings("ignore")

import torch
import time
import sys
from google.colab import output
import torch.nn as nn
from conformer import Conformer
# import torchmetrics
import random

# Decode the first batch of test_loader with the checkpointed model and print
# the predicted token-id sequences.
path = "/content/drive/MyDrive/Datasets/model.pt"

# device setup
cuda = torch.cuda.is_available()
device = torch.device('cuda' if cuda else 'cpu')
print('Device:', device)

inputdim=dataset[0][0].shape[0]
print('Input Dimensions:', inputdim)

#conformer model init
# NOTE(review): this cell builds the model with encoder_dim=144 and
# num_encoder_layers=16, while the GPU training cell uses encoder_dim=32 and
# num_encoder_layers=2. load_state_dict below needs the architecture that
# produced model.pt — confirm which configuration is current.
model = nn.DataParallel(Conformer(num_classes=16121, input_dim=inputdim, encoder_dim=144, num_encoder_layers=16, num_attention_heads=4, conv_kernel_size=31)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# Alternative (disabled):
# optimizer = ScheduleOptimizer(optimizer=torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09), model_size=144, warmup=10000)

#loss function
criterion = nn.CTCLoss().to(device)

if os.path.exists(path):
  checkpoint = torch.load(path, map_location=device)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  starting_point = checkpoint['epoch']+1
  loss = checkpoint['loss']
  print("Model Checkpoint Found")
else:
  # Without a checkpoint the model keeps its freshly initialised weights, so
  # say so instead of silently printing predictions from an untrained model.
  print("No model Checkpoint found - decoding with untrained weights")

with torch.no_grad():
  model.train(False)

  for i, (audio,audio_len, translations, translation_len) in enumerate(test_loader):
    # Reorder the whole batch so targets are in descending order of length.
    # Index selection replaces the per-row copy into pre-allocated tensors;
    # the dtype casts match the ones the old zero-filled buffers imposed.
    # Assumes the four batch items are tensors indexed by batch on dim 0 —
    # TODO confirm against the collate fn.
    sorted_list, sorted_indices = torch.sort(translation_len, descending=True)

    sorted_audio = audio[sorted_indices].to(torch.float)
    sorted_audio_len = audio_len[sorted_indices].to(torch.int)
    sorted_translations = translations[sorted_indices].to(torch.int)
    sorted_translation_len = sorted_list

    #transpose inputs from (batch, dim, seq_len) to (batch, seq_len, dim)
    inputs = sorted_audio.to(device)
    inputs = torch.transpose(inputs, 1, 2)
    input_lengths = sorted_audio_len
    targets = sorted_translations.to(device)
    target_lengths = sorted_translation_len

    # Forward propagate
    outputs, output_lengths = model(inputs, input_lengths)

    # Beam Search Decoding
    beam_results, beam_scores, timesteps, out_lens = decoder.decode(outputs,output_lengths)

    # One token-id list per utterance. Entry 0 is presumably the best-scoring
    # beam — confirm against the decoder's documentation. out_lens gives its
    # valid length.
    batch_pred_list = [
        beam_results[b][0][:out_lens[b][0]].tolist()
        for b in range(outputs.size(0))
    ]

    print(batch_pred_list)

    # Only the first batch is inspected.
    break

Device: cuda
Input Dimensions: 128
Model Checkpoint Found
[[1, 3, 4, 3, 3, 3, 3, 3, 2], [1, 3, 3, 4, 3, 3, 4, 2, 3, 2], [1, 3, 3, 3, 4, 3, 3, 3, 2], [1, 3, 3, 4, 3, 3, 3, 3, 2], [1, 3, 3, 3, 3, 3, 3, 2], [1, 3, 3, 3, 3, 2, 3, 2], [1, 4, 3, 4, 4, 4, 2], [1, 3, 4, 3, 4, 3, 3, 3, 2], [1, 3, 4, 4], [1, 3, 3, 3, 3, 3, 3, 2], [1, 3, 3, 4, 3, 4, 3, 4, 2, 3, 2], [1, 3, 3, 3, 3], [1, 3, 3, 4, 3], [1, 3, 3], [1, 3, 3], [1, 3, 3]]
