# Specialized Task Shallow Detectors

*   Using latest email from Baker



In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
drivepath = '/content/gdrive/MyDrive/DL_Group_Project'
datapath = '/content/gdrive/MyDrive/preprocessed_data/preprocessed_data'
output_path = f'{drivepath}/experiments/specialized_detectors/'

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import time
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
NUM_EPOCHS = 30
BATCH_SIZE = 64
HIDDEN_SIZE = 128
MODEL_VERSION = 1
LEARNING_RATE = 0.01
LOGISTIC_THRESHOLD = 0.5

In [None]:
cuda = torch.cuda.is_available()
num_workers = 8 if cuda else 0
DEVICE = "cuda" if cuda else "cpu"
print("Cuda = "+str(cuda)+" with num_workers = "+str(num_workers))

Cuda = True with num_workers = 8


In [None]:
DEBUG = False

In [None]:
class SpecializedDataset(Dataset):
    
    def __init__(self, datapath, mode, task_name, phonemes_class_0, phonemes_class_1):
      """
      phonemes_class_0: list of phoneme names for class 0
      phonemes_class_1: list of phoneme names for class 1
      """
      complete_features = np.zeros((1, 40))  # eliminate this row
      complete_labels = np.zeros((1))  # eliminate this row

      # go through all files in datapath, check phoneme if class=0 or class=1
      # assign that label and discard silence frames
      class_0_phonemes_found = []
      class_1_phonemes_found = []
      with os.scandir(datapath) as entries:
        for entry in entries:
          if entry.is_file():
            if "features" in entry.name and mode in entry.name:
              phoneme_tag = entry.name.split("_")[0]

              # find phoneme in class_0 or class_1 list and assign label
              phoneme_class = None
              if phoneme_tag in phonemes_class_0:
                phoneme_class = 0
                class_0_phonemes_found.append(phoneme_tag)
              if phoneme_tag in phonemes_class_1:
                phoneme_class = 1
                class_1_phonemes_found.append(phoneme_tag)

              if phoneme_class is None:
                print(f"phoneme '{phoneme_tag}' not found on class 0 nor class 1 lists; skip") if DEBUG else None
                continue

              print(f"phoneme '{phoneme_tag}' is class: {phoneme_class}") if DEBUG else None

              features_filepath = entry.path
              labels_filepath = f"{datapath}/{phoneme_tag}_{mode}_labels.npy"

              phoneme_features = np.load(features_filepath, allow_pickle=True)
              phoneme_labels = np.load(labels_filepath, allow_pickle=True)
              print(f"{phoneme_tag} total features: {phoneme_features.shape}") if DEBUG else None
              print(f"{phoneme_tag} total labels: {phoneme_labels.shape}") if DEBUG else None

              # find frames where label != 0 (non-silence)
              non_zero_indexes = phoneme_labels.nonzero()
              phoneme_features = phoneme_features[non_zero_indexes]
              phoneme_labels = phoneme_labels[non_zero_indexes]
              print(f"{phoneme_tag} no-silence features: {phoneme_features.shape}") if DEBUG else None
              print(f"{phoneme_tag} no-silence labels: {phoneme_labels.shape}") if DEBUG else None

              phoneme_labels[:] = phoneme_class  # label=class

              # stack to phoneme features
              complete_features = np.concatenate((complete_features, phoneme_features))
              complete_labels = np.concatenate((complete_labels, phoneme_labels))
      
      self.X = complete_features[1:]
      self.Y = complete_labels[1:]
      print(f"[task={task_name}] {self.X.shape} features")
      print(f"[task={task_name}] {self.Y.shape} labels")

      if sorted(class_0_phonemes_found) != sorted(phonemes_class_0):
        raise Exception(f"class 0 phonemes found ({sorted(class_0_phonemes_found)}) != expected phonemes ({sorted(phonemes_class_0)})")
 
      if sorted(class_1_phonemes_found) != sorted(phonemes_class_1):
        raise Exception(f"class 1 phonemes found ({sorted(class_1_phonemes_found)}) != expected phonemes ({sorted(phonemes_class_1)})")
 

    def __len__(self):
        return len(self.X)
 
    # get a row at an index
    def __getitem__(self, index):
        x = torch.Tensor(self.X[index]).float()
        y = torch.as_tensor(self.Y[index]).float()
    
        return x,y

In [None]:
def make_dataloader(dataset, train, batch_size):
  if train:
    shuffle = True
    drop_last = True
  else:
    shuffle = False
    drop_last = False
    
  loader = DataLoader(dataset=dataset, batch_size=batch_size,
                      drop_last=drop_last, shuffle=shuffle,
                      pin_memory=True, num_workers=8)
  
  return loader

In [None]:
class SpecializedShallowDetector(nn.Module):
  
  def __init__(self, hidden_size, activation):
    super(SpecializedShallowDetector, self).__init__()
    
    self.linear_layer = nn.Linear(in_features=40, out_features=hidden_size)
    self.bn_layer = nn.BatchNorm1d(num_features=hidden_size)
    self.activation = activation
    self.output_layer = nn.Linear(in_features=hidden_size, out_features=1)
    self.sigmoid = nn.Sigmoid()
    seq_params = [
      self.linear_layer,
      self.bn_layer,
      self.activation,
      self.output_layer,
      self.sigmoid
    ]

    self.network = nn.Sequential(*seq_params)
    
  def forward(self, x):
    return self.network(x)

## 1. Generating predictions csv for specialized detectors 

In [None]:
%cd /content/gdrive/MyDrive/DL_Group_Project/Utilities
from utilities import SPECIALIZED_TASKS
%cd /
print(SPECIALIZED_TASKS)

/content/gdrive/.shortcut-targets-by-id/1qwJK2jyGMl2dPnVFe6JNZvrrG45HoonZ/DL_Group_Project/Utilities
/
{'1_vowel_vs_consonant': {0: ['EE', 'IH', 'EH', 'AE', 'UH', 'ER', 'AH', 'AW', 'OO', 'UE'], 1: ['FF', 'HH', 'MM', 'NN', 'NG', 'RR', 'SS', 'SH', 'VV', 'WW', 'YY', 'ZZ']}, '3_highvowel_vs_lowvowel': {0: ['EE', 'IH', 'UE', 'OO'], 1: ['AE', 'AH', 'AW']}, '4_voiced_vs_unvoiced_fricatives': {0: ['DH', 'VV', 'ZZ'], 1: ['FF', 'SS', 'SH', 'TH']}, '5_ss_vs_zz': {0: ['SS'], 1: ['ZZ']}, '6_b_vs_p': {0: ['B'], 1: ['P']}, '7_dh_vs_th': {0: ['DH'], 1: ['TH']}, '8_ww_vs_yy': {0: ['WW'], 1: ['YY']}, '9_ee_vs_aw': {0: ['EE'], 1: ['AW']}, '10_ah_vs_aw': {0: ['AH'], 1: ['AW']}, '11_mm_vs_nn': {0: ['MM'], 1: ['NN']}}


In [92]:
class ValidateSpecializedDetector():

  def __init__(self, task_name, phonemes_class_0, phonemes_class_1):
    self.task_name = task_name
    self.class_0 = phonemes_class_0
    self.class_1 = phonemes_class_0

    dev_data = SpecializedDataset(datapath=datapath, mode="dev", 
                                  task_name=task_name, 
                                  phonemes_class_0=phonemes_class_0, 
                                  phonemes_class_1=phonemes_class_1)
    self.dev_loader = make_dataloader(dataset=dev_data, train=False, batch_size=BATCH_SIZE)

    self.model = SpecializedShallowDetector(hidden_size=HIDDEN_SIZE, 
                                            activation=nn.LeakyReLU()).to(DEVICE)
    self.criterion = nn.BCELoss()
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')
  
  def load_model(self, epoch, model_version = MODEL_VERSION):
    self.model_epoch_path = "{}/models/model_{}_{}_{}".format(output_path, self.task_name, 
                                                         model_version, epoch)
    checkpoint = torch.load(self.model_epoch_path)
    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.scheduler = checkpoint['scheduler_state_dict']
    self.model.to(DEVICE)

    self.epoch = epoch
    self.model_version = model_version

    print('loaded model: {}'.format(self.model_epoch_path))

  def plot_misclassification(self, df, s):

      df['frame'] = df.index

      misclass_df = df[df['true_labels']!=df['predictions']]

      # gca stands for 'get current axis'
      ax = plt.gca()

      misclass_df.plot(kind='scatter',x = 'frame', y='true_labels', ax = ax, style=".", s = s)
      misclass_df.plot(kind='scatter',x = 'frame', y='predictions', color='red', ax = ax, style=".", s = s)

      title = "{}".format(self.task_name)
      plt.title(title)
      plot_path = "{}/updated_misclass/plot_{}_{}_{}.png".format(output_path, self.task_name, 
                                                         self.model_version, self.epoch)
      plt.savefig(plot_path)
      print('saved plot: {}'.format(plot_path))
      plt.clf()

  def plot_true(self, df):

      # gca stands for 'get current axis'
      ax = plt.gca()

      df.plot(kind='line',y='true_labels', ax = ax)

      title = "{}".format(self.task_name)
      plt.title(title)
      plot_path = "{}/updated_misclass/true_labels_{}_{}_{}.png".format(output_path, self.task_name, 
                                                         self.model_version, self.epoch)
      plt.savefig(plot_path)
      print('saved plot: {}'.format(plot_path))
      plt.clf()

  def evaluate_model_misclassification(self):

    with torch.no_grad():
      self.model.eval()

      running_loss = 0.0
      total_predictions = 0.0
      correct_predictions = 0.0

      true_labels = []
      predictions = []

      example_ct = 0
      start_time = time.time()
      for batch_idx, (features, targets) in enumerate(self.dev_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
        targets = targets.reshape(-1, 1)

        example_ct += len(features)

        outputs = self.model(features)
        outputs = outputs.to(DEVICE)

        # check number of correct predictions
        output_classes = torch.where(outputs > LOGISTIC_THRESHOLD, 1, 0)  # convert to class labels
        total_predictions += len(output_classes)
        correct_predictions += torch.sum(targets == output_classes)

        loss = self.criterion(outputs, targets).detach()
        running_loss += loss.item()

        true_labels += list(targets.detach().cpu())
        predictions += list(output_classes.detach().cpu())
      
      end_time = time.time()

      running_loss /= example_ct
      print(f"testing loss: {running_loss}; time: {end_time - start_time}s")
  
      print(classification_report(true_labels, predictions))
      print(confusion_matrix(true_labels, predictions))
      classification_pd = pd.DataFrame(
    {'true_labels': [a.item() for a in true_labels],
     'predictions': [a.item() for a in predictions]
    })
      csv_path = "{}/updated_misclass/report_{}_{}_{}.csv".format(output_path, self.task_name, 
                                                         self.model_version, self.epoch)
      classification_pd.to_csv(csv_path, index=True)
      if len(classification_pd.index) < 1000:
        s = 0.5
      else:
        s = 0.09
      self.plot_misclassification(classification_pd, s)
      self.plot_true(classification_pd)


In [93]:
not_trained = []
for task_name, classes_dict in SPECIALIZED_TASKS.items():
  try:
    detector = ValidateSpecializedDetector(task_name, classes_dict[0], classes_dict[1])
    detector.load_model(epoch = 29, model_version = 1)
    detector.evaluate_model_misclassification()
  except Exception as e:
    not_trained.append((task_name, str(e)))

[task=1_vowel_vs_consonant] (3307, 40) features
[task=1_vowel_vs_consonant] (3307,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_1_vowel_vs_consonant_1_29


  cpuset_checked))


testing loss: 0.0035523334779983063; time: 0.5781617164611816s
              precision    recall  f1-score   support

         0.0       0.90      0.83      0.86      1383
         1.0       0.88      0.94      0.91      1924

    accuracy                           0.89      3307
   macro avg       0.89      0.88      0.89      3307
weighted avg       0.89      0.89      0.89      3307

[[1143  240]
 [ 120 1804]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_1_vowel_vs_consonant_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_1_vowel_vs_consonant_1_29.png
[task=3_highvowel_vs_lowvowel] (1065, 40) features
[task=3_highvowel_vs_lowvowel] (1065,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_3_highvowel_vs_lowvowel_1_29


  cpuset_checked))


testing loss: 0.0029278355654976176; time: 0.4590907096862793s
              precision    recall  f1-score   support

         0.0       0.89      0.99      0.94       616
         1.0       0.99      0.83      0.90       449

    accuracy                           0.92      1065
   macro avg       0.94      0.91      0.92      1065
weighted avg       0.93      0.92      0.92      1065

[[611   5]
 [ 77 372]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_3_highvowel_vs_lowvowel_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_3_highvowel_vs_lowvowel_1_29.png
[task=4_voiced_vs_unvoiced_fricatives] (1116, 40) features
[task=4_voiced_vs_unvoiced_fricatives] (1116,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_4_voiced_vs_unvoiced_fricatives_1_29


  cpuset_checked))


testing loss: 0.0038350047510264168; time: 0.4625999927520752s
              precision    recall  f1-score   support

         0.0       0.81      0.90      0.85       463
         1.0       0.92      0.85      0.89       653

    accuracy                           0.87      1116
   macro avg       0.87      0.88      0.87      1116
weighted avg       0.88      0.87      0.87      1116

[[415  48]
 [ 95 558]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_4_voiced_vs_unvoiced_fricatives_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_4_voiced_vs_unvoiced_fricatives_1_29.png
[task=5_ss_vs_zz] (321, 40) features
[task=5_ss_vs_zz] (321,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_5_ss_vs_zz_1_29


  cpuset_checked))


testing loss: 0.004418619725609494; time: 0.41211986541748047s
              precision    recall  f1-score   support

         0.0       0.85      0.99      0.91       158
         1.0       0.99      0.83      0.90       163

    accuracy                           0.91       321
   macro avg       0.92      0.91      0.91       321
weighted avg       0.92      0.91      0.91       321

[[156   2]
 [ 28 135]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_5_ss_vs_zz_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_5_ss_vs_zz_1_29.png
[task=6_b_vs_p] (169, 40) features
[task=6_b_vs_p] (169,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_6_b_vs_p_1_29


  cpuset_checked))


testing loss: 0.008018413062631732; time: 0.39914369583129883s
              precision    recall  f1-score   support

         0.0       0.74      0.96      0.83        96
         1.0       0.91      0.55      0.68        73

    accuracy                           0.78       169
   macro avg       0.82      0.75      0.76       169
weighted avg       0.81      0.78      0.77       169

[[92  4]
 [33 40]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_6_b_vs_p_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_6_b_vs_p_1_29.png
[task=7_dh_vs_th] (310, 40) features
[task=7_dh_vs_th] (310,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_7_dh_vs_th_1_29


  cpuset_checked))


testing loss: 0.004478254000986776; time: 0.42003583908081055s
              precision    recall  f1-score   support

         0.0       0.88      0.78      0.83       138
         1.0       0.84      0.92      0.88       172

    accuracy                           0.85       310
   macro avg       0.86      0.85      0.85       310
weighted avg       0.86      0.85      0.85       310

[[107  31]
 [ 14 158]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_7_dh_vs_th_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_7_dh_vs_th_1_29.png
[task=8_ww_vs_yy] (310, 40) features
[task=8_ww_vs_yy] (310,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_8_ww_vs_yy_1_29


  cpuset_checked))


testing loss: 0.005077230882260107; time: 0.40673232078552246s
              precision    recall  f1-score   support

         0.0       0.92      0.84      0.88       155
         1.0       0.85      0.92      0.89       155

    accuracy                           0.88       310
   macro avg       0.88      0.88      0.88       310
weighted avg       0.88      0.88      0.88       310

[[130  25]
 [ 12 143]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_8_ww_vs_yy_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_8_ww_vs_yy_1_29.png
[task=9_ee_vs_aw] (388, 40) features
[task=9_ee_vs_aw] (388,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_9_ee_vs_aw_1_29


  cpuset_checked))


testing loss: 0.0030441126538939853; time: 0.4337780475616455s
              precision    recall  f1-score   support

         0.0       0.93      0.99      0.96       226
         1.0       0.99      0.90      0.94       162

    accuracy                           0.95       388
   macro avg       0.96      0.94      0.95       388
weighted avg       0.95      0.95      0.95       388

[[224   2]
 [ 17 145]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_9_ee_vs_aw_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_9_ee_vs_aw_1_29.png
[task=10_ah_vs_aw] (305, 40) features
[task=10_ah_vs_aw] (305,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_10_ah_vs_aw_1_29


  cpuset_checked))


testing loss: 0.0038384346199817343; time: 0.4168875217437744s
              precision    recall  f1-score   support

         0.0       0.85      0.99      0.92       143
         1.0       0.99      0.85      0.91       162

    accuracy                           0.91       305
   macro avg       0.92      0.92      0.91       305
weighted avg       0.93      0.91      0.91       305

[[142   1]
 [ 25 137]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_10_ah_vs_aw_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_10_ah_vs_aw_1_29.png
[task=11_mm_vs_nn] (391, 40) features
[task=11_mm_vs_nn] (391,) labels
loaded model: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//models/model_11_mm_vs_nn_1_29


  cpuset_checked))


testing loss: 0.00415665073239285; time: 0.4244225025177002s
              precision    recall  f1-score   support

         0.0       0.87      0.96      0.92       191
         1.0       0.96      0.86      0.91       200

    accuracy                           0.91       391
   macro avg       0.92      0.91      0.91       391
weighted avg       0.92      0.91      0.91       391

[[184   7]
 [ 27 173]]
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/plot_11_mm_vs_nn_1_29.png
saved plot: /content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors//updated_misclass/true_labels_11_mm_vs_nn_1_29.png


<Figure size 432x288 with 0 Axes>

In [94]:
for task_name, error in not_trained:
  print(f"'{task_name}' not trained")
  print(f"error: {error}")
  print("")