In [1]:
!pip install params

Collecting params
  Downloading params-0.9.0-py3-none-any.whl.metadata (631 bytes)
Downloading params-0.9.0-py3-none-any.whl (11 kB)
Installing collected packages: params
Successfully installed params-0.9.0


In [2]:
!pip install -U --force-reinstall sympy

Collecting sympy
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading sympy-1.14.0-py3-none-any.whl (6.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m74.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading mpmath-1.3.0-py3-none-any.whl (536 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m536.2/536.2 kB[0m [31m42.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: mpmath, sympy
  Attempting uninstall: mpmath
    Found existing installation: mpmath 1.3.0
    Uninstalling mpmath-1.3.0:
      Successfully uninstalled mpmath-1.3.0
  Attempting uninstall: sympy
    Found existing installation: sympy 1.13.3
    Uninstalling sympy-1.13.3:
      Successfully uninstalled sympy-1.13.3
Successfully installed mpmath-1.3.0 sympy-1.14.0


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# ilinca's file
import gdown

file_id = "13k-ACA6Qt9CJ3MZI6Ot6qD9TAUY3mHUA"
url = f"https://drive.google.com/uc?id={file_id}"
output = "tensor_file.pt"

gdown.download(url, output, quiet=False)

Downloading...
From (original): https://drive.google.com/uc?id=13k-ACA6Qt9CJ3MZI6Ot6qD9TAUY3mHUA
From (redirected): https://drive.google.com/uc?id=13k-ACA6Qt9CJ3MZI6Ot6qD9TAUY3mHUA&confirm=t&uuid=024a9981-a305-46c4-aa61-2501550c624c
To: /content/tensor_file.pt
100%|██████████| 1.12G/1.12G [00:13<00:00, 84.9MB/s]


'tensor_file.pt'

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from tqdm import tqdm
import editdistance as ed

from collections import defaultdict
import subprocess
import numpy as np
import matplotlib.pyplot
from IPython import display
import params
import time
import re


In [6]:
def read_TIMIT(path):
  '''
  args:
    path: path of TIMIT data(mfcc features, phoneme labels)
  return:
    feats: list of list for each audio samples
    labels: list of list for each audio samples
  '''

  feats, labels= [], []
  texts = []
  length_feats, length_labels= [], []

  # read processed TIMIT data
  # list of dictionarys with keys being 'mfcc', 'phonemes', 'path'
  samples= torch.load(path, weights_only= False)
  for idx in range(len(samples)):
    feats.append(samples[idx]['mfcc'])
    labels.append([phoneme.strip() for phoneme in samples[idx]['phonemes']])
    texts.append([text.strip() for text in samples[idx]['text']])

  return feats, labels, texts

In [7]:
path= r"tensor_file.pt"
feats, labels, texts= read_TIMIT(path)
# WHEN WE EXTRACT THEM: size TIME STEPS x FEATURES (which is 39)
# check mfcc feature matrix dimension
print(f'MFCC feature matrix Shape (one audio sample):\t{feats[-1].shape}')
# check ARPA repository
print(f'phoneme labels(one audio sample):\t{labels[-1]}')

MFCC feature matrix Shape (one audio sample):	torch.Size([129, 39])
phoneme labels(one audio sample):	['h#', 'ih', 'tcl', 'm', 'ay', 'tcl', 'hv', 'er', 'tcl', 'ch', 'ux', 'dh', 'ow', 'h#']


In [8]:
print('\n'.join([f'num_frames:\t{len(feats[5])}', f'num_labels:\t{len(labels[5])}']))

num_frames:	252
num_labels:	31


In [9]:
print(feats[1].shape)
print(labels[1])
print(texts[1])
# Looks pretty ok but idk the specifics

torch.Size([209, 39])
['h#', 'ch', 'ae', 'l', 'ax', 'n', 'dcl', 'jh', 'iy', 'tcl', 'ch', 'dcl', 'jh', 'eh', 'nx', 'axr', 'el', 'z', 'ax', 'n', 'tcl', 't', 'eh', 'l', 'ax', 'dcl', 'jh', 'en', 'tcl', 's', 'h#']
['c', 'h', 'a', 'l', 'l', 'e', 'n', 'g', 'e', '', 'e', 'a', 'c', 'h', '', 'g', 'e', 'n', 'e', 'r', 'a', 'l', 's', '', 'i', 'n', 't', 'e', 'l', 'l', 'i', 'g', 'e', 'n', 'c', 'e']


In [10]:
# simply remove them
labels = [[symbol for symbol in label if symbol not in 'h# epi pau'] for label in labels]

handle _diphtongs_, _similar sounds_ in labels

In [11]:
# diphthongs
diphthongs= ['ey', 'aw', 'ay', 'ow']
# ey 'bait' -> split
# aw 'bout' -> split
# ay 'bite' -> split
# oy 'boy' -> oh + y
# ow 'boat' -> split

diphthong_regex= re.compile('|'.join(sorted(map(re.escape, diphthongs),
                                              key= len, reverse= True)))
oy_regex= re.compile('oy')

def split_diphthongs(label):
  label= ' '.join(label)
  label= diphthong_regex.sub(lambda x: ' '.join(x.group()), label)
  label = oy_regex.sub('oh y', label).split()
  return label

In [12]:
# similar sounds
merge_arpa= {
    # replace allophones that likely don't exist in Yoruba
    # even if they exist as allophones of other sounds,
    # substituting them would undermine the distributions learned by the model
    # e.g. devoiced shwa could be by essence regarded as [h], but /h/ would never occur in C_C
    'ax-h': 'ax',  # devoiced schwa
    # closures of stops, that occur without release in English codas,
    # however no consonant codas in Yoruba
    'bcl': 'b',
    'dcl': 'd',
    'gcl': 'g',
    'kcl': 'k',
    'pcl': 'p',
    'tcl': 't',

    # replace syllabic with common sonorant labels
    'en': 'n',
    'em': 'm',
    'el': 'l',
    'eng': 'ng',

    ## rhotic vowels
    # 'axr': 'r' ? 'ɹ' ?
    # 'er': 'r', 'ɹ' ?

    # flapped /t, d, n/ substituted for the closest Yoruba analogs
    'dx': 'r',
    'nx': 'n',

    # /h/ sound
    'hh': 'h',
}

execute and review

In [13]:
# split diphthongs
labels = [split_diphthongs(label) for label in labels]
labels_merge = [[merge_arpa.get(symbol, symbol) for symbol in label] for label in labels]

In [14]:
# handling bcl, dcl, gcl, kcl, pcl, tcl
  # used as abbreviation for voiceless stop or plosive with closure
  # e.g. bcl : bilbial closure, tcl : alveolar closure

# after merging, there may appear identical consecutive phones

# if dcl and d > d
# if dcl was by itself > d

labels_final = []

for label, label_merge in zip(labels, labels_merge):
  label = ' '.join(label)
  label_merge = ' '.join(label_merge)
  for key, item in merge_arpa.items():
    if ' '.join([key, item]) in label:
      label_merge = re.sub(' '.join([item, item]), item, label_merge)

  labels_final.append(label_merge.split())

In [15]:
# redefine labels and the dictionary
labels= labels_final

# split train / dev dataset
train_feats, dev_feats, train_labels, dev_labels= train_test_split(feats, labels)

####create phoneme labels dictionary
set global variables in case we will change the arrangmement with IDs

In [16]:
BLANK_ID = 0
PAD_ID = 1

define help function

In [17]:
def create_ARPAdictionary(labels, include_unk=False):
  '''
  args:
    labels: list of list containing sequence of label for each audio sample
  return: arpa2idx
        dictionary of ARPA_label to index
  '''
  arpas= set()
  for label in labels:
    arpas= arpas.union(set(label))
  arpas= sorted(arpas)

  arpa2idx= {arpa:(idx+2) for idx, arpa in enumerate(arpas)}
  arpa2idx['<blank>'] = BLANK_ID # pad will also be used as this
  if include_unk:
    arpa2idx['<unk>'] = len(arpa2idx) + 1

  return arpa2idx, {v: k for k, v in arpa2idx.items()}

execute and review

In [18]:
arpa2idx, idx2arpa = create_ARPAdictionary(labels)
print(*sorted(arpa2idx))

<blank> a aa ae ah ao ax axr b ch d dh e eh er f g h hv ih ix iy jh k l m n ng o oh p q r s sh t th uh uw ux v w y z zh


In [19]:
print(f'the number of ARPAbet labels in TIMIT:\t {len(arpa2idx)}')
print(f'Index of blank symbol <blank>":\t {arpa2idx["<blank>"]}')
# print(f'Index of unk symbol <unk>":\t {arpa2idx["<unk>"]}')

the number of ARPAbet labels in TIMIT:	 45
Index of blank symbol <blank>":	 0


define num classes after making the dictionaries

In [20]:
NUM_CLASSES = max(arpa2idx.values()) + 1

In [21]:
# compare phonemic / phonetic symbols between
# provided TIMIT phonecode and transcription data
# https://catalog.ldc.upenn.edu/docs/LDC96S32/PHONCODE.TXT

# existing= set(arpa2idx.keys())

# phonecode= set(['b', 'd', 'g', 'p', 't', 'k', 'dx', 'q', # stops
#              'jh', 'ch', # affricates
#              's', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh', # frcatives
#              'm', 'n', 'ng', 'em', 'en', 'eng', 'nx', # nasals
#              'l', 'r', 'w', 'y', 'hh', 'hv', 'el', # semivowels and glides
#              'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao', 'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr', 'ax-h', # vowels
#              'pau', 'epi', 'h#', '1', '2',# others # epi: epenthetic silence # h# : begin/ end marker # 1 : primary stress marker # 2 : secondary stress marker
#              ])
"""
print(*sorted((existing - phonecode))) # symbols that does not exist in phonecode but in TIMIT corpus
print(*sorted((phonecode - existing))) # symbols that does not exist (or is removed from) in TIMIT corpus but in phonecode
"""
# review
# in TIMIT
  # there is no stress markers(not important to our project)
  # we intentionally removed epi, h# and pause
  # we intentionally split the diphthongs (aw, ay, ey, ow, oy)
  # we intentionally merged (1) ax-h, hh to h (2) el to l (3) en eng to n (4) pau to <blank>

'\nprint(*sorted((existing - phonecode))) # symbols that does not exist in phonecode but in TIMIT corpus\nprint(*sorted((phonecode - existing))) # symbols that does not exist (or is removed from) in TIMIT corpus but in phonecode\n'

#### Dataset + pad
define Dataset Class

In [22]:
class PhonemeASRDataset(Dataset):
  def __init__(self, feats, labels, arpa2idx):
    super().__init__()
    self.feats, self.labels = feats, labels
    self.arpa2idx = arpa2idx
    self.has_unk = "<unk>" in arpa2idx

  def __len__(self):
    return len(self.feats)

  def __getitem__(self, idx):
      feat, label = self.feats[idx], self.labels[idx]
      label = [lab.strip() for lab in label]

      if self.has_unk:  # map the unknown symbols to unknown
        label = [self.arpa2idx.get(lab, self.arpa2idx['<unk>']) for lab in label]
      else:
        label = [self.arpa2idx[lab] for lab in label]

      lab = torch.as_tensor(label, dtype=torch.long)
      assert lab.numel() > 0, "Empty target sequence."
      return torch.as_tensor(feat, dtype=torch.float32), lab

define padding function

In [23]:
def pad_collate(batch, pad_value_feat=0.0, pad_value_label=PAD_ID):
    '''
      for collate_fn in DataLoader function

    args:
      batch: a list of tuples (mfcc, label)
      return: padded_mfccs, padded_labels, input_lengths, target_lengths
    '''
    mfccs, labels = zip(*batch)

    # find max length for mfcc (time step) and label in the current batch
    feat_lengths = [mfcc.shape[0] for mfcc in mfccs]
    label_lengths = [label.shape[0] for label in labels]
    max_len_feats = max(feat_lengths)
    max_len_labels = max(label_lengths)

    # calculate lengths of input and target lengths
    input_lengths = torch.tensor(feat_lengths, dtype=torch.long)
    target_lengths = torch.tensor(label_lengths, dtype=torch.long)

    # pad mfcc matrices and labels
    padded_mfccs = [F.pad(mfcc, (0, 0, 0, max_len_feats - mfcc.shape[0]), value=pad_value_feat) for mfcc in mfccs]
    padded_labels = [F.pad(label, (0, max_len_labels - label.shape[0]), value=pad_value_label) for label in labels]

    return torch.stack(padded_mfccs), torch.stack(padded_labels), input_lengths, target_lengths

In [25]:
print("""
for batch, _, _, _ in train_loader:
  for instance in batch:
    print(instance.size())

for batch, _, _, _ in dev_loader:
  for instance in batch:
    print(instance.size())
""")


for batch, _, _, _ in train_loader:
  for instance in batch:
    print(instance.size())

for batch, _, _, _ in dev_loader:
  for instance in batch:
    print(instance.size())



execute and review

In [26]:
# NOw we passed it to the Phoneme ASRDataset, meaning it should be correctly padded
# But again this might just have to do with how it was oragnized by me during preprocessing
# VS how Aaron processed it
# No nevermind according to his code its
train_ds= PhonemeASRDataset(train_feats, train_labels, arpa2idx=arpa2idx)
print(train_ds.labels)
print("pad collate should activate here")
train_loader= DataLoader(train_ds, batch_size=2, # can adjust
                          shuffle=True, collate_fn=pad_collate) # yields batch_size x max_len x num_feats as one training batch

a, b, _, _ = next(iter(train_loader))

[['m', 'eh', 'n', 'iy', 'w', 'eh', 'l', 'th', 'iy', 't', 'a', 'y', 'k', 'ux', 'n', 's', 'p', 'l', 'er', 'jh', 'd', 'ix', 'm', 'b', 'ao', 'q', 'b', 'o', 'w', 'th', 'ix', 'y', 'aa', 'q', 'ih', 'n', 'ix', 's', 'k', 'ux', 'n', 'er'], ['ch', 'ih', 'p', 'o', 'w', 's', 'p', 'o', 'w', 'n', 'd', 'q', 'ae', 'l', 'ax', 'm', 'o', 'w', 'n', 'iy', 'p', 'e', 'y', 'm', 'ix', 'n', 't', 's', 'q', 'ix', 'n', 't', 'ix', 'l', 'dh', 'ax', 'l', 'e', 'y', 'r', 'ih', 's', 'p', 'aa', 's', 'b', 'l', 'd', 'e', 'y', 't'], ['q', 'ao', 'f', 'ah', 'n', 'y', 'l', 'g', 'ih', 't', 'b', 'ae', 'g', 'm', 'ao', 'r', 'dh', 'eh', 'n', 'y', 'ax', 'p', 'uh', 'r', 'ih', 'n'], ['b', 'aa', 'b', 'f', 'a', 'w', 'n', 'm', 'ao', 'r', 'k', 'l', 'ae', 'm', 'z', 'q', 'ix', 't', 'dh', 'ix', 'q', 'o', 'w', 'sh', 'n', 'z', 'q', 'eh', 'd', 'jh'], ['dh', 'ix', 'k', 'l', 'a', 'w', 'd', 'b', 'axr', 's', 'k', 'ah', 'r', 'ao', 'f', 'ax', 'b', 'r', 'ah', 'p', 'l', 'ix'], ['b', 'ix', 't', 'n', 'a', 'w', 'sh', 'iy', 'l', 'uh', 'k', 't', 'q', 'ah', '

####from tensors to seqences
decoding helper

In [27]:
def decode_ctc(pred_seq, blank=BLANK_ID):
    out, prev = [], None
    for p in pred_seq.tolist():
        if p != blank and p != prev:
            out.append(p)
        prev = p
    return out

phoneme errrrrate

In [45]:
def evaluate_PER(model, data_loader, arpa2idx, debug=False):
    model.eval()
    total_phonemes = 0
    total_errors = 0

    empty_decodes = 0
    n_items = 0

    # trouble shooting every step here
    with torch.no_grad():
        for x, y, input_lengths, target_lengths in data_loader:
            # Move to device
            x, y = x.to(SETTING['device']), y.to(SETTING['device'])
            input_lengths, target_lengths = compute_cnn_output_lengths(model, input_lengths), target_lengths.to(SETTING['device'])

            # Forward pass
            log_probs = model(x)  # (B, T, C)
            # Check for NaN/Inf
            if torch.isnan(log_probs).any() or torch.isinf(log_probs).any():
                if debug: print("Warning: NaN or Inf detected in model output during evaluation")
                continue

            # Compute CNN output lengths
            T = log_probs.size(1)  # if log_probs.dim() == 3 else log_probs.size(0)
            cnn_output_lengths = torch.clamp(input_lengths, max=T)

            # Get predictions
            preds = log_probs.argmax(dim=-1)

            for i in range(preds.size(0)):
                n_items += 1
                # Get lengths as scalars
                valid_length, target_length = cnn_output_lengths[i].item(), target_lengths[i].item()

                if valid_length <= 0 or target_length <= 0:
                    if debug: print(f"Skip item {i}: valid_len={valid_length}, target_len={target_length}")
                    continue

                # Extract sequences
                pred_seq = preds[i, :valid_length]
                pred_decoded = decode_ctc(pred_seq)
                target_seq = y[i, :target_length].tolist()
                target_seq = [t for t in target_seq if t != PAD_ID]

                # Skip if either sequence is empty
                if not target_seq:
                    if debug: print(f"Skip item {i}: empty target after PAD filter.")
                    continue
                if not pred_decoded:
                    empty_decodes += 1

                # Calculate errors
                errors = levenshtein_distance(pred_decoded, target_seq)
                total_errors += errors
                total_phonemes += len(target_seq)

                assert (target_lengths <= cnn_output_lengths).all(), "Eval: target exceeds input length"

    # printing if the decoding yielded empty
    PER = total_errors / total_phonemes if total_phonemes > 0 else float('inf')
    if debug and n_items > 0:
        print(f"Decoded-empty items: {empty_decodes}/{n_items} ({empty_decodes / max(1,n_items):.1%})")
        print(f"PER = {PER:.4f} over {total_phonemes} reference phones")

    return PER

In [46]:
# def levenshtein_distance(seq1, seq2):
#   m, n = len(seq1), len(seq2)
#   dp = np.zeros((m+1, n+1), dtype=int)

#   for i in range(m + 1):
#     dp[i][0] = i
#   for j in range(n + 1):
#     dp[0][j] = j

#   for i in range(1, m + 1):
#       for j in range(1, n + 1):
#           if seq1[i - 1] == seq2[j - 1]:
#               dp[i][j] = dp[i - 1][j - 1]
#           else:
#             dp[i][j] = 1 + min(
#                 dp[i - 1][j],    # deletion
#                 dp[i][j - 1],    # insertion
#                 dp[i - 1][j - 1] # substitution
#                 )
#   return dp[m][n]

def levenshtein_distance(a, b):
    """
    our function was perfectly fine, but i just copypasted this "optimized" one
    """
    if a is b: return 0
    if not a: return len(b)
    if not b: return len(a)

    if len(b) < len(a):
        a, b = b, a

    m, n = len(a), len(b)
    prev = [j * 1 for j in range(n + 1)]

    for i in range(1, m + 1):
        cur = [i * 1] + [0] * n
        x = a[i - 1]
        for j in range(1, n + 1):
            y = b[j - 1]
            cur[j] = prev[j - 1] if x == y else min(
                prev[j] + 1,   # delete
                cur[j - 1] + 1, # insert
                prev[j - 1] + 1 # substitute
            )
        prev = cur

    return prev[n]

In [47]:
# didn't use it because was hard to understand the debug brain

def compare(model, data_loader, idx2arpa):
    # idx2arpa = dict((idx, arpa) for arpa, idx in arpa2idx.items())
    model.eval()

    with torch.no_grad():
        for x, y, input_lengths, target_lengths in data_loader:
            # Move to device
            x, y = x.to(SETTING['device']), y.to(SETTING['device'])
            input_lengths, target_lenghts = input_lengths.to(SETTING['device']), target_lengths.to(SETTING['device'])

            # Compute CNN output lengths
            cnn_output_lengths = compute_cnn_output_lengths(model, input_lengths).to(SETTING['device'])

            # Forward pass
            log_probs = model(x)  # (B, T, C)

            # Check for NaN/Inf
            if torch.isnan(log_probs).any() or torch.isinf(log_probs).any():
                print("Warning: NaN or Inf detected in model output during evaluation")
                continue

            # Get predictions
            preds = log_probs.argmax(dim=-1)  # (B, T)

            for i in range(preds.size(0)):
                # Get lengths as scalars
                valid_length, target_length = cnn_output_lengths[i].item(), target_lengths[i].item()

                # Extract sequences
                pred_seq = preds[i, :valid_length]
                pred_decoded = [idx2arpa[seq] for seq in decode_ctc(pred_seq)]
                target_seq = y[i, :target_length].tolist()
                target_seq = [idx2arpa[t] for t in target_seq if t != PAD_ID]

                print(pred_decoded, target_seq)

because our CNN blocks don't reduce time dimension (stride=1)

we don't have to recompute anything, so i subbed with a nothing function returning the same tensor

In [48]:
@torch.no_grad()
def compute_cnn_output_lengths(model, input_lengths):
    if torch.is_tensor(input_lengths):
        return input_lengths.to(dtype=torch.long, device=SETTING["device"])
    return torch.tensor(input_lengths, dtype=torch.long, device=SETTING["device"])

In [49]:
# def compute_cnn_output_lengths(model, input_lengths):
#     """
#     Computes the output time lengths after passing through the CNN part
#     of ASRModel (init_conv + res_blocks).
#     """
#     if isinstance(input_lengths, torch.Tensor):
#         input_lengths = input_lengths.tolist()

#     new_lengths = []
#     for length in input_lengths:
#         dummy = torch.zeros(1, model.init_conv[0].in_channels, length).to(device=SETTING["device"])
#         with torch.no_grad():
#             out = model.init_conv(dummy)
#             out = model.res_blocks(out)
#         new_lengths.append(out.shape[-1])
#     return torch.tensor(new_lengths, dtype=torch.long)

In [50]:
# ## VIBE CODED - doesn't run the legths that were in cashe
# _len_cache = {}
# @torch.no_grad()
# def compute_cnn_output_lengths(model, input_lengths):
#     if torch.is_tensor(input_lengths):
#         input_lengths = input_lengths.tolist()
#     wanted = [int(L) for L in input_lengths]

#     # probe only unseen lengths
#     for L in set(wanted) - _len_cache.keys():
#         was_training = model.training
#         model.eval()                       # avoid BN stat updates
#         C = model.init_conv[0].in_channels
#         dummy = torch.zeros(1, C, L, device=SETTING["device"])
#         out = model.res_blocks(model.init_conv(dummy))
#         _len_cache[L] = out.shape[-1]
#         if was_training: model.train()

#     return torch.tensor([_len_cache[L] for L in wanted], dtype=torch.long, device=SETTING["device"])

In [51]:
# def beam_search_decoder(probs, beam_width=10, blank=0):
#     """
#     Beam search decoder for CTC.
#     probs: (time, num_classes) - log probabilities
#     """
#     T, V = probs.shape
#     beams = [(tuple(), 0.0)]  # (prefix, score)

#     for t in range(T):
#         new_beams = defaultdict(lambda: -float("inf"))
#         for prefix, score in beams:
#             for c in range(V):
#                 p = probs[t, c].item()
#                 new_prefix = prefix + (c,)
#                 new_beams[new_prefix] = max(new_beams[new_prefix], score + p)
#         # Keep top beam_width
#         beams = sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width]

#     # Collapse repeats & remove blanks
#     best_seq, _ = beams[0]
#     collapsed = []
#     prev = None
#     for c in best_seq:
#         if c != blank and c != prev:
#             collapsed.append(c)
#         prev = c
#     return collapsed

In [52]:
# changed the mode from mac to min
class EarlyStopping:
  def __init__(self, patience=  10, delta= 1e-5, mode= 'min'):
    self.patience=  patience # the number of epochs to wait observe loss
    self.counter= 0
    self.best_score= None
    self.early_stop= False
    self.delta= delta
    self.mode= mode

  def __call__(self, current):
    if self.best_score is None:
      self.best_score= current
      return False

    improvement= (current - self.best_score) if self.mode ==  "max" else (self.best_score - current)

    if improvement <=  self.delta:
      self.counter +=  1
      if self.counter >=  self.patience:
        self.early_stop= True
    else:
      self.best_score= current
      self.counter= 0

    return self.early_stop

In [53]:
class ResBlock(nn.Module):
    def __init__(self, num_cnn_layers, cnn_filters, cnn_kernel_size, use_resnet=True):
        super(ResBlock, self).__init__()
        self.use_resnet = use_resnet
        layers = []
        for _ in range(num_cnn_layers):
            layers.append(nn.Conv1d(cnn_filters, cnn_filters, cnn_kernel_size, padding=cnn_kernel_size // 2))
            layers.append(nn.BatchNorm1d(cnn_filters))
            layers.append(nn.PReLU())
        self.res_block = nn.Sequential(*layers)

    def forward(self, x):
        res = self.res_block(x)
        if self.use_resnet:
            return x + res
        return res


class ASRModel(nn.Module):
    def __init__(self, ip_channel, num_classes, num_res_blocks=3, num_cnn_layers=1, cnn_filters=50,
                 cnn_kernel_size=15, num_rnn_layers=2, rnn_dim=170, num_dense_layers=1,
                 dense_dim=300, use_birnn=True, use_resnet=True, rnn_type="lstm", rnn_dropout=0.6):
        super(ASRModel, self).__init__()

        # Initial Conv layer
        self.init_conv = nn.Sequential(
            nn.Conv1d(ip_channel, cnn_filters, cnn_kernel_size, padding=cnn_kernel_size // 2),
            nn.BatchNorm1d(cnn_filters),
            nn.PReLU()
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResBlock(num_cnn_layers, cnn_filters, cnn_kernel_size, use_resnet) for _ in range(num_res_blocks)]
        )

        # RNN layers
        rnn_module = nn.GRU if rnn_type.lower() == "gru" else nn.LSTM
        rnn_input_dim = cnn_filters
        self.rnns = nn.ModuleList()
        for _ in range(num_rnn_layers):
            if use_birnn:
                self.rnns.append(rnn_module(rnn_input_dim, rnn_dim, batch_first=True, dropout=rnn_dropout, bidirectional=True))
                rnn_input_dim = rnn_dim * 2
            else:
                self.rnns.append(rnn_module(rnn_input_dim, rnn_dim, batch_first=True, dropout=rnn_dropout))
                rnn_input_dim = rnn_dim

        # Dense layers
        dense_layers = []
        dense_in_dim = rnn_input_dim
        for _ in range(num_dense_layers):
            dense_layers.append(nn.Linear(dense_in_dim, dense_dim))
            dense_layers.append(nn.ReLU())
            dense_in_dim = dense_dim
        self.dense_layers = nn.Sequential(*dense_layers)

        # Output layer
        self.out_layer = nn.Linear(dense_in_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
        x = self.init_conv(x)
        x = self.res_blocks(x)

        x = x.transpose(1, 2) # -> (B, T, C)
        for rnn in self.rnns:
            x, _ = rnn(x)

        x = self.dense_layers(x)
        x = self.out_layer(x)
        x = F.log_softmax(x, dim=-1)  # log_softmax for CTC Loss
        return x  # (B, T, C)

def save_checkpoint(model, optimizer, filename='checkpoint.pth.tar'):
  checkpoint = {
      "state_dict": model.state_dict(),
      "optimizer": optimizer.state_dict(),
  }
  print("=> Saving checkpoint")
  torch.save(checkpoint, filename)

def load_checkpoint(checkpoint, model, optimizer):
  print("=> Loading checkpoint")
  model.load_state_dict(checkpoint["state_dict"])
  optimizer.load_state_dict(checkpoint["optimizer"])


In [54]:
def train_fn(train_loader, model, optimizer, loss_fn):
    model.train()
    total_losses = []

    inner_loop = tqdm(train_loader, desc='Batch', leave=False, position=0)

    for x, y, input_lengths, target_lengths in inner_loop:
      # Target lengths should be the longest tensor in the batch (for labels)
      # Check for NaNs/Infs before placing it to device
      if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf in input!")
        raise ValueError("Invalid model input")

      # place data to device
      x, y = x.to(SETTING["device"]), y.to(SETTING["device"])
      input_lengths = compute_cnn_output_lengths(model, input_lengths).to(SETTING["device"])
      target_lengths = target_lengths.to(SETTING["device"])

      log_probs = model(x) # (B, T, C)
      T = log_probs.size(1) # TRUE time length
      input_lengths = torch.clamp(input_lengths, max=T)  # also trim the batch
      log_probs = log_probs.transpose(0, 1) # (T, B, C) for CTCLoss

      # Feasibility check: no silent zero-loss
      bad = (target_lengths > input_lengths).nonzero(as_tuple=False).flatten()
      if bad.numel() > 0:
            raise ValueError(
                f"CTC invalid: target_lengths > input_lengths for items {bad.tolist()} "
                f"(max target={int(target_lengths.max())}, max input={int(input_lengths.max())}, T={T})"
            )

      y_concat = torch.cat([y[i][:target_lengths[i].item()] for i in range(y.size(0))])
      # calculate loss
      loss = loss_fn(log_probs, y_concat, input_lengths, target_lengths)

      # Check for NaNs/Infs among loss
      if torch.isnan(loss) or torch.isinf(loss) or torch.isnan(log_probs).any() or torch.isinf(log_probs).any():
        print("NaN or Inf in loss!")
        raise ValueError("Invalid loss")

      optimizer.zero_grad()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
      optimizer.step()

      total_losses.append(loss.item())

    return sum(total_losses) / len(total_losses)

In [55]:
def eval_fn(dev_loader, model, loss_fn):
    model.eval()
    total_losses = []

    with torch.no_grad():
        for x, y, input_lengths, target_lengths in dev_loader:
            # Check for NaNs/Infs in input
            if torch.isnan(x).any() or torch.isinf(x).any():
                print("NaN or Inf in input during evaluation!")
                continue  # Skip this batch instead of raising

            x, y = x.to(SETTING['device']), y.to(SETTING['device'])
            input_lengths = compute_cnn_output_lengths(model, input_lengths).to(SETTING['device'])
            target_lengths = target_lengths.to(SETTING['device'])

            log_probs = model(x)  # (B, T, C)
            T = log_probs.size(1)  # true time steps
            input_lengths = torch.clamp(input_lengths, max=T)  # same
            log_probs = log_probs.transpose(0, 1)  # (T, B, C)

            # sanity check
            if (target_lengths > input_lengths).any():
                raise ValueError("CTC invalid in eval: target length > input length")

            y_concat = torch.cat([y[i, :target_lengths[i].item()] for i in range(y.size(0))])
            loss = loss_fn(log_probs, y_concat, input_lengths, target_lengths)

            if torch.isnan(loss) or torch.isinf(loss):
                print("NaN or Inf in loss during evaluation!")
                continue

            total_losses.append(loss.item())

    return sum(total_losses) / len(total_losses) if total_losses else float('inf')

In [56]:
SETTING = {
    "seed": 43,
    "learning_rate": 1e-3,  # made much bigger
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "batch_size": 64,
    "weight_decay": 1e-4,
    "num_epochs": 30,
    "num_workers": 2,
    "pin_memory": True,
    "load_model": True,
    "load_model_file": "/content/drive/MyDrive/ResNetCTC.path.tar",
    "patience": 10,
#    "feat_dir": directory of features
#    "label_dir": direct1ory of labels
}

In [57]:
torch.manual_seed(SETTING["seed"])

#------------------------------- DataLoader -------------------------------#

train_ds= PhonemeASRDataset(train_feats, train_labels, arpa2idx= arpa2idx)
dev_ds= PhonemeASRDataset(dev_feats, dev_labels, arpa2idx= arpa2idx)
# test_ds =  test 어쩌구

train_loader= DataLoader(train_ds,
                          batch_size= SETTING["batch_size"],
                          shuffle= True, collate_fn= pad_collate,
                         num_workers= SETTING["num_workers"],
                         pin_memory= SETTING["pin_memory"]) # yields (B, T, C)
dev_loader= DataLoader(dev_ds,
                      batch_size= SETTING["batch_size"],
                       shuffle= False, collate_fn= pad_collate,
                       num_workers= SETTING["num_workers"],
                      pin_memory= SETTING["pin_memory"])

early_stopping =  EarlyStopping(patience= SETTING["patience"], delta= 0.001, mode =  "min")

INPUT_DIM = 39

# --------------------------------- Model & Optimizer --------------------------------- #
 # Why 4 tensors lol
a,b,c,d = next(iter(train_loader))

# idx2arpa = dict((idx, arpa) for arpa, idx in arpa2idx.items())

model = ASRModel(
    ip_channel=INPUT_DIM,
    num_classes=NUM_CLASSES,
    num_res_blocks=3,
    num_cnn_layers=1,
    cnn_filters=50,
    cnn_kernel_size=15, # changing this caused error
    num_rnn_layers=2,
    rnn_dim=170,
    num_dense_layers=1,
    dense_dim=300,
    use_birnn=True,
    rnn_type="lstm",
    rnn_dropout=0.2
).to(SETTING["device"])

# change from chatgpt, apparently avoids getting stuck on something
with torch.no_grad():
    if model.out_layer.bias is not None:
        model.out_layer.bias[BLANK_ID] -= 2.0

loss_fn = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True, reduction='sum')  # blank token idx=0
#optimizer = optim.Adam(model.parameters(), lr=SETTING['learning_rate'], weight_decay=SETTING['weight_decay'])

# -------



In [58]:
print(SETTING['weight_decay'])

0.0001


In [59]:
import sympy
print(sympy.__version__)
import sympy
import sympy.printing

1.14.0


In [60]:
optimizer = optim.Adam(model.parameters(), lr=SETTING['learning_rate'], weight_decay=SETTING['weight_decay'])

In [61]:
scheduler = ReduceLROnPlateau(optimizer, 'min')

train_losses, PER_list_train = [], []
dev_losses, PER_list_dev = [], []

torch.autograd.set_detect_anomaly(True)
outer_loop = tqdm(range(SETTING["num_epochs"]),desc="Epoch", position=0)
eval_interval = 5

for epoch in outer_loop:
  train_avg_loss = train_fn(train_loader, model, optimizer, loss_fn)
  train_losses.append(train_avg_loss)
  dev_avg_loss = eval_fn(dev_loader, model, loss_fn)
  dev_losses.append(dev_avg_loss)

  scheduler.step(dev_avg_loss)

  if epoch % eval_interval == 0:
    model.eval()
    with torch.no_grad():
      PER_val = evaluate_PER(model, dev_loader, arpa2idx)

      PER_train = evaluate_PER(model, train_loader, arpa2idx)

      PER_list_dev.append(PER_val)
      PER_list_train.append(PER_train)
  else:
    PER_train, PER_val = PER_list_train[-1], PER_list_dev[-1]

  tqdm.write(f"Epoch {epoch+1}/{SETTING['num_epochs']} - Train Loss: {train_avg_loss:.6f} - Val Loss: {dev_avg_loss:.6f} - Train PER: {PER_train:.6f} - Val PER: {PER_val:.6f}")

  if early_stopping(PER_val):
    print(f"Early stopping triggered at epoch {epoch+1}")
    break

save_checkpoint(model, optimizer, '/content/drive/MyDrive/ResNetCTC.pth.tar')
file = f'seting+{time.time}'

with open(file, mode = 'w') as f_out:
  for setting in SETTING:
    f_out.write(setting)

# compare(model, dev_loader, arpa2idx)


Epoch:   3%|▎         | 1/30 [04:04<1:58:18, 244.77s/it]

Epoch 1/30 - Train Loss: 11663.355711 - Val Loss: 5967.966260 - Train PER: 0.980772 - Val PER: 0.980824


Epoch:   7%|▋         | 2/30 [06:23<1:25:00, 182.15s/it]

Epoch 2/30 - Train Loss: 3675.716397 - Val Loss: 2527.081265 - Train PER: 0.980772 - Val PER: 0.980824


Epoch:  10%|█         | 3/30 [08:42<1:13:06, 162.47s/it]

Epoch 3/30 - Train Loss: 2097.930565 - Val Loss: 1878.703611 - Train PER: 0.980772 - Val PER: 0.980824


Epoch:  13%|█▎        | 4/30 [11:00<1:06:20, 153.09s/it]

Epoch 4/30 - Train Loss: 1636.531718 - Val Loss: 1565.044424 - Train PER: 0.980772 - Val PER: 0.980824


Epoch:  17%|█▋        | 5/30 [13:19<1:01:37, 147.91s/it]

Epoch 5/30 - Train Loss: 1335.724238 - Val Loss: 1319.342482 - Train PER: 0.980772 - Val PER: 0.980824


Epoch:  20%|██        | 6/30 [17:31<1:13:18, 183.26s/it]

Epoch 6/30 - Train Loss: 1099.379485 - Val Loss: 1156.559916 - Train PER: 0.160186 - Val PER: 0.178536


Epoch:  23%|██▎       | 7/30 [19:50<1:04:38, 168.64s/it]

Epoch 7/30 - Train Loss: 895.473922 - Val Loss: 956.664358 - Train PER: 0.160186 - Val PER: 0.178536


Epoch:  27%|██▋       | 8/30 [22:07<58:13, 158.79s/it]  

Epoch 8/30 - Train Loss: 731.477083 - Val Loss: 827.523191 - Train PER: 0.160186 - Val PER: 0.178536


Epoch:  30%|███       | 9/30 [24:26<53:22, 152.52s/it]

Epoch 9/30 - Train Loss: 583.595154 - Val Loss: 694.257481 - Train PER: 0.160186 - Val PER: 0.178536


Epoch:  33%|███▎      | 10/30 [26:44<49:21, 148.09s/it]

Epoch 10/30 - Train Loss: 465.566543 - Val Loss: 544.677104 - Train PER: 0.160186 - Val PER: 0.178536


Epoch:  37%|███▋      | 11/30 [30:56<56:57, 179.87s/it]

Epoch 11/30 - Train Loss: 372.841519 - Val Loss: 508.444260 - Train PER: 0.058532 - Val PER: 0.078447


Epoch:  37%|███▋      | 11/30 [31:10<53:50, 170.01s/it]


KeyboardInterrupt: 