In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount= True)

import os
os.chdir('./drive/MyDrive')

Mounted at /content/drive


#### import

In [3]:
!pip install params
!pip install jiwer

Collecting params
  Downloading params-0.9.0-py3-none-any.whl.metadata (631 bytes)
Downloading params-0.9.0-py3-none-any.whl (11 kB)
Installing collected packages: params
Successfully installed params-0.9.0
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-4.0.0 rapidfuzz-3.13.0


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import editdistance as ed

from collections import defaultdict
import subprocess
import numpy as np
import matplotlib.pyplot
from IPython import display
from jiwer import wer
import params
import time
import re


#### read data
define help function

In [5]:
def read_TIMIT(path):
  '''
  args:
    path: path of TIMIT data(mfcc features, phoneme labels)
  return:
    feats: list of list for each audio samples
    labels: list of list for each audio samples
  '''

  feats, labels= [], []
  length_feats, length_labels= [], []

  # read processed TIMIT data
  # list of dictionarys with keys being 'mfcc', 'phonemes', 'path'
  samples= torch.load(path, weights_only= False)
  for idx in range(len(samples)):
    feats.append(samples[idx]['mfcc'])
    labels.append([phoneme.strip() for phoneme in samples[idx]['phonemes']])
  return feats, labels


execute and review

In [6]:
path= r'timit_mfcc_data.pt'
feats, labels= read_TIMIT(path)


# check mfcc feature matrix dimension
print(f'MFCC feature matrix Shape (one audio sample):\t{feats[-1].shape}')
# check IPA repository
print(f'phoneme labels(one audio sample):\t{labels[-1]}')

# seems like 'h#' marks sos and eos

MFCC feature matrix Shape (one audio sample):	(82, 39)
phoneme labels(one audio sample):	['h#', 'dh', 'ix', 's', 'pcl', 'p', 'iy', 'tcl', 'ch', 's', 'em', 'pcl', 'p', 'ow', 'z', 'iy', 'ax', 'm', 'pau', 'm', 'ay', 'tcl', 'b', 'ax-h', 'gcl', 'g', 'ih', 'n', 'm', 'ah', 'n', 'dcl', 'd', 'ey', 'h#']


In [7]:
print('\n'.join([f'num_frames:\t{len(feats[0])}', f'num_labels:\t{len(labels[0])}']))

num_frames:	78
num_labels:	35


In [8]:
# mark max length for mfcc features and labels
max_len_feats= max([len(feat) for feat in feats])
max_len_labels= max([len(label) for label in labels])


Process Labels

handle h#, pau, epi

In [9]:
# simply remove them
labels = [[symbol for symbol in label if symbol not in 'h# epi pau'] for label in labels]

handle DIPHTHONGS, SIMILAR SOUNDS in labels

In [10]:
# diphthongs

diphthongs= ['ey', 'aw', 'ay', 'ow']
# ey 'bait' -> split
# aw 'bout' -> split
# ay 'bite' -> split
# oy 'boy' -> oh + y
# ow 'boat' -> split

diphthong_regex= re.compile('|'.join(sorted(map(re.escape, diphthongs),
                                              key= len, reverse= True)))
oy_regex= re.compile('oy')

def split_diphthongs(label):
  label= ' '.join(label)
  label= diphthong_regex.sub(lambda x: ' '.join(x.group()), label)
  label = oy_regex.sub('oh y', label).split()
  return label

In [11]:
# similar sounds


merge_ipa= {
    # marginal sounds
    'ax-h': 'ax',
    'bcl': 'b',
    'dcl': 'd',
    'gcl': 'g',
    'kcl': 'k',
    'pcl': 'p',
    'tcl': 't',

    'en': 'n',
    'em': 'm',
    'el': 'l',
    'eng': 'ng',

    ## /ɹ/ sound
    # 'axr': 'r' ? 'ɹ' ?
    # 'dx': 'r',
    # 'nx': 'r',
    # 'er': 'r', 'ɹ' ?

    # /h/ sound
    'hh': 'h',
}



execute and review

In [12]:
# split diphthongs
labels = [split_diphthongs(label) for label in labels]
labels_merge = [[merge_ipa.get(symbol, symbol) for symbol in label] for label in labels]

In [13]:
# handling bcl, dcl, gcl, kcl, pcl, tcl
  # used as abbreviation for voiceless stop or plosive with closure
  # e.g. bcl : bilbial closure, tcl : alveolar closure

# after merging, there may appear identical consecutive phones

# if dcl and d > d
# if dcl was by itself > d

labels_final = []

for label, label_merge in zip(labels, labels_merge):
  label = ' '.join(label)
  label_merge = ' '.join(label_merge)
  for key, item in merge_ipa.items():
    if ' '.join([key, item]) in label:
      label_merge = re.sub(' '.join([item, item]), item, label_merge)

  labels_final.append(label_merge.split())

In [14]:
for i in range(10):
  print(*labels[i])
  print(*labels_merge[i])
  print(*labels_final[i])

sh uw w ax z hh o w l dx ix ng q aa nx uh hv ih z r aa kcl k w ax dh w ah n hv ae n dcl d
sh uw w ax z h o w l dx ix ng q aa nx uh hv ih z r aa k k w ax dh w ah n hv ae n d d
sh uw w ax z h o w l dx ix ng q aa nx uh hv ih z r aa k w ax dh w ah n hv ae n d
b er th dcl d e y pcl aa r dx iy z hh eh v kcl k ah pcl k e y kcl k s ix nx a y s kcl k r iy m
b er th d d e y p aa r dx iy z h eh v k k ah p k e y k k s ix nx a y s k k r iy m
b er th d e y p aa r dx iy z h eh v k ah p k e y k s ix nx a y s k r iy m
s ah m w ax m ix ng gcl g eh dx ix r ih l th r ih l a w dx ax v hv a w s w axr kcl k
s ah m w ax m ix ng g g eh dx ix r ih l th r ih l a w dx ax v hv a w s w axr k k
s ah m w ax m ix ng g eh dx ix r ih l th r ih l a w dx ax v hv a w s w axr k
dh ix pcl e y sh ix n tcl t q eh n dh ax s er dcl jh ix nx axr bcl b o w th axr kcl k uw pcl r eh dx ix ng f axr m dh ix l eh ng th iy ah pcl er e y sh en
dh ix p e y sh ix n t t q eh n dh ax s er d jh ix nx axr b b o w th axr k k uw p r eh dx ix ng 

In [15]:
# redefine labels and the dictionary
labels= labels_final

# split train / dev dataset
train_feats, dev_feats, train_labels, dev_labels= train_test_split(feats, labels)

#### create IPA dictionary

define help function

In [16]:
def create_IPAdictionary(labels):
  '''
  args:
    labels: list of list containing sequence of label for each audio sample
  return: ipa2idx
        dictionary of IPA_label to index
  '''
  ipas= set()
  for label in labels:
    ipas= ipas.union(set(label))
  ipas= sorted(ipas)

  ipa2idx= {ipa:(idx+2) for idx, ipa in enumerate(ipas)}
  ipa2idx['<blank>']= 0

  return ipa2idx

execute and review

In [17]:
ipa2idx= create_IPAdictionary(labels)
print(*sorted(ipa2idx))

<blank> a aa ae ah ao ax axr b ch d dh dx e eh er f g h hv ih ix iy jh k l m n ng nx o oh p q r s sh t th uh uw ux v w y z zh


In [18]:
print(f'the number of IPA labels in TIMIT:\t {len(ipa2idx)}')
print(f'Index of blank symbol <blank>":\t {ipa2idx["<blank>"]}')

the number of IPA labels in TIMIT:	 47
Index of blank symbol <blank>":	 0


In [19]:
# compare phonemic / phonetic symbols between
# provided TIMIT phonecode and transcription data
# https://catalog.ldc.upenn.edu/docs/LDC96S32/PHONCODE.TXT

existing= set(ipa2idx.keys())

phonecode= set(['b', 'd', 'g', 'p', 't', 'k', 'dx', 'q', # stops
             'jh', 'ch', # affricates
             's', 'sh', 'z', 'zh', 'f', 'th', 'v', 'dh', # frcatives
             'm', 'n', 'ng', 'em', 'en', 'eng', 'nx', # nasals
             'l', 'r', 'w', 'y', 'hh', 'hv', 'el', # semivowels and glides
             'iy', 'ih', 'eh', 'ey', 'ae', 'aa', 'aw', 'ay', 'ah', 'ao', 'oy', 'ow', 'uh', 'uw', 'ux', 'er', 'ax', 'ix', 'axr', 'ax-h', # vowels
             'pau', 'epi', 'h#', '1', '2',# others # epi: epenthetic silence # h# : begin/ end marker # 1 : primary stress marker # 2 : secondary stress marker
             ])

In [20]:
print(*sorted((existing - phonecode))) # symbols that does not exist in phonecode but in TIMIT corpus
print(*sorted((phonecode - existing))) # symbols that does not exist (or is removed from) in TIMIT corpus but in phonecode

<blank> a e h o oh
1 2 aw ax-h ay el em en eng epi ey h# hh ow oy pau


In [22]:
# review
# in TIMIT
  # there is no stress markers(not important to our project)
  # we intentionally removed epi, h# and pause
  # we intentionally split the diphthongs (aw, ay, ey, ow, oy)
  # we intentionally merged (1) ax-h, hh to h (2) el to l (3) en eng to n (4) pau to <blank>

#### Dataset + pad
define Dataset Class

In [23]:
class PhonemeASRDataset(Dataset):
  def __init__(self, feats, labels, ipa2idx):
    super(PhonemeASRDataset, self).__init__()
    self.feats, self.labels= feats, labels
    self.ipa2idx= ipa2idx

  def __len__(self):
    return len(self.feats)

  def __getitem__(self, idx):
      feat, label= self.feats[idx], self.labels[idx]
      label= [ipa2idx[ipa] for ipa in label]

      return torch.tensor(feat), torch.tensor(label, dtype= torch.long)

define padding function

In [24]:
def pad_collate(batch, pad_value_feat= 0, pad_value_label= 0):
    '''
      for collate_fn in DataLoader function

    args:
      batch: a list of tuples (mfcc, label)
      return: padded_mfccs, padded_labels, input_lengths, target_lengths
    '''

    mfccs, labels= zip(*batch)

    # find max length for mfcc(time step) and label in the current batch
    max_len_feats= max(mfcc.shape[0] for mfcc in mfccs)
    max_len_labels= max(label.shape[0] for label in labels)

    # pad mfcc matrices and labels
    padded_mfccs= [F.pad(mfcc, (0, 0, 0, max_len_feats - mfcc.shape[0]), value= pad_value_feat) for mfcc in mfccs]
    padded_labels= [F.pad(label, (0, max_len_labels - label.shape[0]), value= pad_value_label) for label in labels]

    # calculate lengths of input and target lengths
    input_lengths = torch.tensor([mfcc.shape[0] for mfcc in padded_mfccs], dtype = torch.long)
    target_lengths = torch.tensor([label.shape[0] for label in padded_labels], dtype = torch.long)


    # Stack the padded tensors
    padded_mfccs= torch.stack(padded_mfccs)
    padded_labels= torch.stack(padded_labels)

    return padded_mfccs, padded_labels, input_lengths, target_lengths



execute and review

In [25]:
train_ds= PhonemeASRDataset(train_feats, train_labels, ipa2idx= ipa2idx)
train_loader= DataLoader(train_ds, batch_size= 1, # can adjust
                          shuffle= True, collate_fn= pad_collate) # yields batch_size x max_len x num_feats as one training batch

In [26]:
a, b, _, _ = next(iter(train_loader))
print(a.shape, b.shape)

torch.Size([1, 102, 39]) torch.Size([1, 40])


### Model

define utility function

In [31]:
import torch

def evaluate_PER(model, data_loader):
    model.eval()
    total_phonemes = 0
    total_errors = 0

    with torch.no_grad():
        for x, y, input_lengths, target_lengths in data_loader:
            x, y = x.to(SETTING['device']), y.to(SETTING['device'])
            input_lengths = compute_cnn_output_lengths(model, input_lengths).to(SETTING['device'])
            target_lengths = target_lengths.to(SETTING['device'])

            # Forward pass
            log_probs = model(x)  # (B, T, C)

            # Get best path (greedy decoding) by argmax over classes
            preds = log_probs.argmax(dim=-1)  # (B, T)

            # Collapse repeated and remove blanks (assuming blank=0)
            # This is a common CTC decoding post-processing:
            def decode_ctc(pred_seq):
                prev = None
                decoded = []
                for p in pred_seq:
                    if p != prev and p != 0:  # skip blanks and repeated
                        decoded.append(p.item())
                    prev = p
                return decoded

            for i in range(preds.size(0)):
                pred_seq = preds[i, :input_lengths[i]]  # cut to valid length

                pred_decoded = decode_ctc(pred_seq)
                target_seq = y[i, :target_lengths[i]].tolist()

                # Calculate phoneme error (Levenshtein distance)
                errors = levenshtein_distance(pred_decoded, target_seq)
                total_errors += errors
                total_phonemes += len(target_seq)

    PER = total_errors / total_phonemes if total_phonemes > 0 else 0.0 ## remove total_phonemes if total_phonemes > 0
    return PER


def levenshtein_distance(seq1, seq2):
    # classic DP edit distance
    m, n = len(seq1), len(seq2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if seq1[i - 1] == seq2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(
                    dp[i - 1][j],    # deletion
                    dp[i][j - 1],    # insertion
                    dp[i - 1][j - 1] # substitution
                )
    return dp[m][n]



def compute_cnn_output_lengths(model, input_lengths):
    """
    Computes the output time lengths after passing through the CNN part
    of ASRModel (init_conv + res_blocks).
    """
    if isinstance(input_lengths, torch.Tensor):
        input_lengths = input_lengths.tolist()

    new_lengths = []
    for length in input_lengths:
        dummy = torch.zeros(1, model.init_conv[0].in_channels, length)
        with torch.no_grad():
            out = model.init_conv(dummy)
            out = model.res_blocks(out)
        new_lengths.append(out.shape[-1])
    return torch.tensor(new_lengths, dtype=torch.long)


def beam_search_decoder(probs, beam_width=10, blank=0):
    """
    Beam search decoder for CTC.
    probs: (time, num_classes) - log probabilities
    """
    T, V = probs.shape
    beams = [(tuple(), 0.0)]  # (prefix, score)

    for t in range(T):
        new_beams = defaultdict(lambda: -float("inf"))
        for prefix, score in beams:
            for c in range(V):
                p = probs[t, c].item()
                new_prefix = prefix + (c,)
                new_beams[new_prefix] = max(new_beams[new_prefix], score + p)
        # Keep top beam_width
        beams = sorted(new_beams.items(), key=lambda x: x[1], reverse=True)[:beam_width]

    # Collapse repeats & remove blanks
    best_seq, _ = beams[0]
    collapsed = []
    prev = None
    for c in best_seq:
        if c != blank and c != prev:
            collapsed.append(c)
        prev = c
    return collapsed


class EarlyStopping:
  def __init__(self, patience=  10, delta= 1e-5, mode= 'max'):
    self.patience=  patience # the number of epochs to wait observe loss
    self.counter= 0
    self.best_score= None
    self.early_stop= False
    self.delta= delta
    self.mode= mode

  def __call__(self, current):
    if self.best_score is None:
      self.best_score= current
      return False

    improvement= (current - self.best_score) if self.mode ==  "max" else (self.best_score - current)

    if improvement <=  self.delta:
      self.counter +=  1
      if self.counter >=  self.patience:
        self.early_stop= True
    else:
      self.best_score= current
      self.counter= 0

    return self.early_stop

def get_loader(feat_path, label_path):
  dataset = Dataset(
      feat_path,
  )



define a model

In [62]:
"""
PyTorch conversion of the TensorFlow/Keras ResNet-BiRNN ASR model
Original Author (TF version): Manish Dhakal
Converted & extended with CTC + beam search decoding: [Your Name]
Year: 2025
"""

class ResBlock(nn.Module):
    def __init__(self, num_cnn_layers, cnn_filters, cnn_kernel_size, use_resnet=True):
        super(ResBlock, self).__init__()
        self.use_resnet = use_resnet
        layers = []
        for _ in range(num_cnn_layers):
            layers.append(nn.Conv1d(cnn_filters, cnn_filters, cnn_kernel_size, padding=cnn_kernel_size // 2))
            layers.append(nn.BatchNorm1d(cnn_filters))
            layers.append(nn.PReLU())
        self.res_block = nn.Sequential(*layers)

    def forward(self, x):
        res = self.res_block(x)
        if self.use_resnet:
            return x + res
        return res


class ASRModel(nn.Module):
    def __init__(self, ip_channel, num_classes, num_res_blocks=3, num_cnn_layers=1, cnn_filters=50,
                 cnn_kernel_size=15, num_rnn_layers=2, rnn_dim=170, num_dense_layers=1,
                 dense_dim=300, use_birnn=True, use_resnet=True, rnn_type="lstm", rnn_dropout=0.15):
        super(ASRModel, self).__init__()

        # Initial Conv layer
        self.init_conv = nn.Sequential(
            nn.Conv1d(ip_channel, cnn_filters, cnn_kernel_size, padding=cnn_kernel_size // 2),
            nn.BatchNorm1d(cnn_filters),
            nn.PReLU()
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResBlock(num_cnn_layers, cnn_filters, cnn_kernel_size, use_resnet) for _ in range(num_res_blocks)]
        )

        # RNN layers
        rnn_module = nn.GRU if rnn_type.lower() == "gru" else nn.LSTM
        rnn_input_dim = cnn_filters
        self.rnns = nn.ModuleList()
        for _ in range(num_rnn_layers):
            if use_birnn:
                self.rnns.append(rnn_module(rnn_input_dim, rnn_dim, batch_first=True, dropout=rnn_dropout, bidirectional=True))
                rnn_input_dim = rnn_dim * 2
            else:
                self.rnns.append(rnn_module(rnn_input_dim, rnn_dim, batch_first=True, dropout=rnn_dropout))
                rnn_input_dim = rnn_dim

        # Dense layers
        dense_layers = []
        dense_in_dim = rnn_input_dim
        for _ in range(num_dense_layers):
            dense_layers.append(nn.Linear(dense_in_dim, dense_dim))
            dense_layers.append(nn.ReLU())
            dense_in_dim = dense_dim
        self.dense_layers = nn.Sequential(*dense_layers)

        # Output layer
        self.out_layer = nn.Linear(dense_in_dim, num_classes)

    def forward(self, x):
        x = x.transpose(1, 2) # (B, T, F) -> (B, F, T)
        x = self.init_conv(x)
        x = self.res_blocks(x)

        x = x.transpose(1, 2) # -> (B, T, F)
        for rnn in self.rnns:
            x, _ = rnn(x)

        x = self.dense_layers(x)
        x = self.out_layer(x)
        x = F.log_softmax(x, dim=-1)  # log_softmax for CTC Loss
        return x  # (B, T, C)

### Train

define train function

In [77]:
def train_fn(train_loader, model, optimizer, loss_fn):
    model.train()
    total_loss = []

    inner_loop = tqdm(train_loader, desc='Batch', leave=False, position=0)

    for x, y, input_lengths, target_lengths in inner_loop:
      # Check for NaNs/Infs before placing it to device
      if torch.isnan(x).any() or torch.isinf(x).any():
        print("NaN or Inf in input!")
        raise ValueError("Invalud model input")

      # place data to device
      x, y = x.to(SETTING["device"]), y.to(SETTING["device"])
      input_lengths = compute_cnn_output_lengths(model, input_lengths).to(SETTING["device"])
      target_lengths = target_lengths.to(SETTING["device"])

      log_probs = model(x)  # (B, T, C)
      log_probs = log_probs.transpose(0, 1)  # (T, B, C)

      # Check for NaNs/Infs before loss
      if torch.isnan(log_probs).any() or torch.isinf(log_probs).any():
          print("NaN or Inf in model output!")
          raise ValueError("Invalid model output")

      y_concat = torch.cat([y[i][:target_lengths[i]] for i in range(y.size(0))])
      # calculate loss
      loss = loss_fn(log_probs, y_concat, input_lengths, target_lengths)

      # Check for NaNs/Infs among loss
      if torch.isnan(loss) or torch.isinf(loss) or torch.isnan(log_probs).any() or torch.isinf(log_probs).any():
        print("NaN or Inf in loss!")
        raise ValueError("Invalid loss")

      optimizer.zero_grad()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
      optimizer.step()


      total_loss.append(loss.item())

    return sum(total_loss) / len(total_loss)


In [74]:
#------------------------------- Traing Settings -------------------------------#
SETTING= {
    "seed": 43,
    "learning_rate": 1e-5, ## gradient exploded at epoch 5/70
    "device": "cuda:0" if torch.cuda.is_available() else "cpu",
    "batch_size": 64,
    "weight_decay": 1e-4,
    "num_epochs": 70,
    "num_workers": 2,
    "pin_memory": True,
    "load_model": False,
    "load_model_file": "./drive/MyDrive/ResLSTM.path.tar",
    "patience": 10,
#    "feat_dir": directory of features
#    "label_dir": direct1ory of labels

}

torch.manual_seed(SETTING["seed"])


#------------------------------- DataLoader -------------------------------#

train_ds= PhonemeASRDataset(train_feats, train_labels, ipa2idx= ipa2idx)
dev_ds= PhonemeASRDataset(dev_feats, dev_labels, ipa2idx= ipa2idx)
# test_ds =  test 어쩌구

train_loader= DataLoader(train_ds,
                          batch_size= SETTING["batch_size"],
                          shuffle= True, collate_fn= pad_collate,
                         num_workers= SETTING["num_workers"],
                         pin_memory= SETTING["pin_memory"]) # yields (B, T, C)
dev_loader= DataLoader(dev_ds,
                      batch_size= SETTING["batch_size"],
                       shuffle= False, collate_fn= pad_collate,
                       num_workers= SETTING["num_workers"],
                      pin_memory= SETTING["pin_memory"])

early_stopping =  EarlyStopping(patience= SETTING["patience"], delta= 0.001, mode =  "max")

In [75]:
INPUT_DIM = 39
NUM_CLASSES = 47

model = ASRModel(
    ip_channel=INPUT_DIM,
    num_classes=NUM_CLASSES,
    num_res_blocks=5,
    num_cnn_layers=2,
    cnn_filters=50,
    cnn_kernel_size=15,
    num_rnn_layers=2,
    rnn_dim=170,
    num_dense_layers=1,
    dense_dim=340,
    use_birnn=True,
    rnn_type="lstm",
    rnn_dropout=0.15
).to(SETTING["device"])



loss_fn = nn.CTCLoss(blank=0, zero_infinity=True, reduction='sum')  # blank token idx=0
optimizer = optim.Adam(model.parameters(), lr=SETTING['learning_rate'], weight_decay=SETTING['weight_decay'])


In [76]:
losses, PER_list, PER_list_train = [], [], []

torch.autograd.set_detect_anomaly(True)
outer_loop = tqdm(range(SETTING["num_epochs"]),desc="Epoch", position=0)
eval_interval = 5

for epoch in outer_loop:
  avg_loss = train_fn(train_loader, model, optimizer, loss_fn)
  losses.append(avg_loss)

  if (epoch + 1) % eval_interval == 0:
    model.eval()
    with torch.no_grad():
      PER_val = evaluate_PER(model, dev_loader)
      PER_train = evaluate_PER(model, train_loader)
      PER_list.append(PER_val)
      PER_list_train.append(PER_train)
  else:
    PER_val = PER_list[-1] if PER_list else 0.0 ## remove if PER_list else 0.0
    PER_train = PER_list_train[-1] if PER_list_train else 0.0 ## remove if PER_list_train else 0.0

  tqdm.write(f"Epoch {epoch+1}/{SETTING['num_epochs']} - Loss: {avg_loss:.6f} - Train PER: {PER_train:.6f} - Val PER: {PER_val:.6f}") ## want to fix this update info at position=0

  if early_stopping(PER_val):
    print(f"Early stopping triggered at epoch {epoch+1}")
    break


Epoch:   0%|          | 0/70 [00:00<?, ?it/s]
Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:04<04:09,  4.62s/it][A
Batch:   4%|▎         | 2/55 [00:09<04:00,  4.54s/it][A
Batch:   5%|▌         | 3/55 [00:16<05:01,  5.80s/it][A
Batch:   7%|▋         | 4/55 [00:20<04:24,  5.18s/it][A
Batch:   9%|▉         | 5/55 [00:25<04:19,  5.20s/it][A
Batch:  11%|█         | 6/55 [00:30<03:59,  4.90s/it][A
Batch:  13%|█▎        | 7/55 [00:34<03:39,  4.58s/it][A
Batch:  15%|█▍        | 8/55 [00:38<03:25,  4.38s/it][A
Batch:  16%|█▋        | 9/55 [00:44<03:46,  4.92s/it][A
Batch:  18%|█▊        | 10/55 [00:47<03:26,  4.59s/it][A
Batch:  20%|██        | 11/55 [00:51<03:10,  4.33s/it][A
Batch:  22%|██▏       | 12/55 [00:57<03:20,  4.65s/it][A
Batch:  24%|██▎       | 13/55 [01:01<03:06,  4.44s/it][A
Batch:  25%|██▌       | 14/55 [01:04<02:51,  4.18s/it][A
Batch:  27%|██▋       | 15/55 [01:11<03:17,  4.94s/it][A
Batch:  29%|██▉       | 16/55 [01:15<03:03, 

Epoch 1 - Avg Loss: 32811.567525
Epoch 1/70 - Loss: 32811.567525 - Train PER: 0.000000 - Val PER: 0.000000



Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:04<03:36,  4.01s/it][A
Batch:   4%|▎         | 2/55 [00:09<04:05,  4.63s/it][A
Batch:   5%|▌         | 3/55 [00:13<03:44,  4.32s/it][A
Batch:   7%|▋         | 4/55 [00:16<03:29,  4.11s/it][A
Batch:   9%|▉         | 5/55 [00:21<03:27,  4.14s/it][A
Batch:  11%|█         | 6/55 [00:28<04:15,  5.21s/it][A
Batch:  13%|█▎        | 7/55 [00:33<04:16,  5.34s/it][A
Batch:  15%|█▍        | 8/55 [00:39<04:15,  5.44s/it][A
Batch:  16%|█▋        | 9/55 [00:43<03:45,  4.89s/it][A
Batch:  18%|█▊        | 10/55 [00:48<03:39,  4.87s/it][A
Batch:  20%|██        | 11/55 [00:53<03:44,  5.10s/it][A
Batch:  22%|██▏       | 12/55 [00:57<03:20,  4.67s/it][A
Batch:  24%|██▎       | 13/55 [01:01<03:11,  4.55s/it][A
Batch:  25%|██▌       | 14/55 [01:06<03:05,  4.52s/it][A
Batch:  27%|██▋       | 15/55 [01:10<02:56,  4.42s/it][A
Batch:  29%|██▉       | 16/55 [01:14<02:45,  4.25s/it][A
Batch:  31%|███       | 17/55 [

Epoch 2 - Avg Loss: 31011.472541
Epoch 2/70 - Loss: 31011.472541 - Train PER: 0.000000 - Val PER: 0.000000



Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:05<04:43,  5.25s/it][A
Batch:   4%|▎         | 2/55 [00:09<04:16,  4.83s/it][A
Batch:   5%|▌         | 3/55 [00:13<03:51,  4.45s/it][A
Batch:   7%|▋         | 4/55 [00:20<04:30,  5.31s/it][A
Batch:   9%|▉         | 5/55 [00:25<04:23,  5.27s/it][A
Batch:  11%|█         | 6/55 [00:31<04:27,  5.47s/it][A
Batch:  13%|█▎        | 7/55 [00:37<04:33,  5.70s/it][A
Batch:  15%|█▍        | 8/55 [00:42<04:11,  5.35s/it][A
Batch:  16%|█▋        | 9/55 [00:47<04:02,  5.28s/it][A
Batch:  18%|█▊        | 10/55 [00:52<03:53,  5.20s/it][A
Batch:  20%|██        | 11/55 [00:56<03:38,  4.97s/it][A
Batch:  22%|██▏       | 12/55 [01:01<03:24,  4.74s/it][A
Batch:  24%|██▎       | 13/55 [01:06<03:27,  4.93s/it][A
Batch:  25%|██▌       | 14/55 [01:10<03:05,  4.52s/it][A
Batch:  27%|██▋       | 15/55 [01:15<03:07,  4.69s/it][A
Batch:  29%|██▉       | 16/55 [01:21<03:22,  5.19s/it][A
Batch:  31%|███       | 17/55 [

Epoch 3 - Avg Loss: 28803.679297
Epoch 3/70 - Loss: 28803.679297 - Train PER: 0.000000 - Val PER: 0.000000



Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:04<04:27,  4.95s/it][A
Batch:   4%|▎         | 2/55 [00:09<03:56,  4.47s/it][A
Batch:   5%|▌         | 3/55 [00:14<04:07,  4.75s/it][A
Batch:   7%|▋         | 4/55 [00:18<03:54,  4.60s/it][A
Batch:   9%|▉         | 5/55 [00:22<03:31,  4.23s/it][A
Batch:  11%|█         | 6/55 [00:27<03:53,  4.76s/it][A
Batch:  13%|█▎        | 7/55 [00:31<03:33,  4.45s/it][A
Batch:  15%|█▍        | 8/55 [00:35<03:16,  4.18s/it][A
Batch:  16%|█▋        | 9/55 [00:39<03:18,  4.31s/it][A
Batch:  18%|█▊        | 10/55 [00:43<03:09,  4.21s/it][A
Batch:  20%|██        | 11/55 [00:47<03:01,  4.12s/it][A
Batch:  22%|██▏       | 12/55 [00:52<02:58,  4.16s/it][A
Batch:  24%|██▎       | 13/55 [00:57<03:14,  4.63s/it][A
Batch:  25%|██▌       | 14/55 [01:01<02:58,  4.37s/it][A
Batch:  27%|██▋       | 15/55 [01:05<02:48,  4.22s/it][A
Batch:  29%|██▉       | 16/55 [01:11<03:02,  4.69s/it][A
Batch:  31%|███       | 17/55 [

Epoch 4 - Avg Loss: 24790.065230
Epoch 4/70 - Loss: 24790.065230 - Train PER: 0.000000 - Val PER: 0.000000



Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:05<05:00,  5.56s/it][A
Batch:   4%|▎         | 2/55 [00:10<04:49,  5.46s/it][A
Batch:   5%|▌         | 3/55 [00:17<05:09,  5.96s/it][A
Batch:   7%|▋         | 4/55 [00:21<04:29,  5.28s/it][A
Batch:   9%|▉         | 5/55 [00:25<03:55,  4.70s/it][A
Batch:  11%|█         | 6/55 [00:30<03:57,  4.86s/it][A
Batch:  13%|█▎        | 7/55 [00:34<03:40,  4.60s/it][A
Batch:  15%|█▍        | 8/55 [00:38<03:25,  4.38s/it][A
Batch:  16%|█▋        | 9/55 [00:45<04:00,  5.22s/it][A
Batch:  18%|█▊        | 10/55 [00:49<03:35,  4.79s/it][A
Batch:  20%|██        | 11/55 [00:53<03:21,  4.57s/it][A
Batch:  22%|██▏       | 12/55 [00:59<03:35,  5.01s/it][A
Batch:  24%|██▎       | 13/55 [01:04<03:23,  4.85s/it][A
Batch:  25%|██▌       | 14/55 [01:07<03:03,  4.46s/it][A
Batch:  27%|██▋       | 15/55 [01:12<03:03,  4.59s/it][A
Batch:  29%|██▉       | 16/55 [01:16<02:52,  4.43s/it][A
Batch:  31%|███       | 17/55 [

Epoch 5 - Avg Loss: 19469.408538


Epoch:   7%|▋         | 5/70 [22:57<5:21:10, 296.47s/it]

Epoch 5/70 - Loss: 19469.408538 - Train PER: 1.000000 - Val PER: 1.000000



Batch:   0%|          | 0/55 [00:00<?, ?it/s][A
Batch:   2%|▏         | 1/55 [00:05<04:36,  5.11s/it][A
Batch:   4%|▎         | 2/55 [00:10<04:29,  5.09s/it][A
Batch:   5%|▌         | 3/55 [00:14<03:55,  4.54s/it][A
Batch:   7%|▋         | 4/55 [00:18<03:41,  4.34s/it][A
Batch:   9%|▉         | 5/55 [00:23<03:55,  4.71s/it][A
Batch:  11%|█         | 6/55 [00:29<04:09,  5.08s/it][A
Batch:  13%|█▎        | 7/55 [00:33<03:46,  4.71s/it][A
Batch:  15%|█▍        | 8/55 [00:38<03:50,  4.90s/it][A
Batch:  16%|█▋        | 9/55 [00:42<03:38,  4.76s/it][A
Batch:  18%|█▊        | 10/55 [00:46<03:17,  4.40s/it][A
Batch:  20%|██        | 11/55 [00:51<03:21,  4.57s/it][A
Batch:  22%|██▏       | 12/55 [00:56<03:21,  4.69s/it][A
Batch:  24%|██▎       | 13/55 [01:00<03:04,  4.40s/it][A
Batch:  25%|██▌       | 14/55 [01:04<03:04,  4.50s/it][A
Batch:  27%|██▋       | 15/55 [01:10<03:06,  4.66s/it][A
Batch:  29%|██▉       | 16/55 [01:15<03:08,  4.83s/it][A
Batch:  31%|███       | 17/55 [

RuntimeError: Function 'LogSoftmaxBackward0' returned nan values in its 0th output.