# Frame to Phoneme Classifier

In [2]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [14]:
drivepath_shallow = '/content/gdrive/MyDrive/DL_Group_Project/Dataset/Preprocessed_Data'

In [15]:
drivepath_spec = '/content/gdrive/MyDrive/DL_Group_Project/experiments/specialized_detectors/models'

In [4]:
!pip install tqdm



In [5]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
from tqdm import tqdm
import time
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import pandas as pd

In [6]:
NUM_EPOCHS = 100
BATCH_SIZE = 64
HIDDEN_SIZE_shallow = 128
HIDDEN_SIZE_spec = 128
MODEL_VERSION = 1
LEARNING_RATE = 0.01
LOGISTIC_THRESHOLD = 0.5
OTHER_PHONEMES_PERCENT = 0.1

In [7]:
cuda = torch.cuda.is_available()
num_workers = 8 if cuda else 0
DEVICE = "cuda" if cuda else "cpu"
print("Cuda = "+str(cuda)+" with num_workers = "+str(num_workers))

Cuda = True with num_workers = 8


In [8]:
class PhonemesDataset(Dataset):
    
    def __init__(self, basepath, mode):
      phoneme_features = np.zeros((1, 40))  # eliminate this row
      phoneme_labels = np.zeros((1))  # eliminate this row

      with os.scandir(basepath) as entries:
        for entry in entries:
          if entry.is_file():
            if "features" in entry.name and mode in entry.name:
              phoneme_tag = entry.name.split("_")[0]

              features_filepath = entry.path
              labels_filepath = f"{basepath}/{phoneme_tag}_{mode}_labels.npy"

              other_phoneme_features = np.load(features_filepath, allow_pickle=True)
              other_phoneme_labels = np.load(labels_filepath, allow_pickle=True)
              
              # stack to phoneme features
              phoneme_features = np.concatenate((phoneme_features, other_phoneme_features))
              phoneme_labels = np.concatenate((phoneme_labels, other_phoneme_labels))
              
      self.X = phoneme_features[1:]
      self.Y = phoneme_labels[1:]

    def __len__(self):
        return len(self.X)
 
    # get a row at an index
    def __getitem__(self, index):
        x = torch.Tensor(self.X[index]).float()
        y = torch.as_tensor(self.Y[index]).long()
    
        return x,y

In [9]:
class SpecializedDataset(Dataset):
    
    def __init__(self, datapath, mode, task_name, phonemes_class_0, phonemes_class_1):
      """
      phonemes_class_0: list of phoneme names for class 0
      phonemes_class_1: list of phoneme names for class 1
      """
      complete_features = np.zeros((1, 40))  # eliminate this row
      complete_labels = np.zeros((1))  # eliminate this row

      # go through all files in datapath, check phoneme if class=0 or class=1
      # assign that label and discard silence frames
      class_0_phonemes_found = []
      class_1_phonemes_found = []
      with os.scandir(datapath) as entries:
        for entry in entries:
          if entry.is_file():
            if "features" in entry.name and mode in entry.name:
              phoneme_tag = entry.name.split("_")[0]

              # find phoneme in class_0 or class_1 list and assign label
              phoneme_class = None
              if phoneme_tag in phonemes_class_0:
                phoneme_class = 0
                class_0_phonemes_found.append(phoneme_tag)
              if phoneme_tag in phonemes_class_1:
                phoneme_class = 1
                class_1_phonemes_found.append(phoneme_tag)

              if phoneme_class is None:
                print(f"phoneme '{phoneme_tag}' not found on class 0 nor class 1 lists; skip") if DEBUG else None
                continue

              print(f"phoneme '{phoneme_tag}' is class: {phoneme_class}") if DEBUG else None

              features_filepath = entry.path
              labels_filepath = f"{datapath}/{phoneme_tag}_{mode}_labels.npy"

              phoneme_features = np.load(features_filepath, allow_pickle=True)
              phoneme_labels = np.load(labels_filepath, allow_pickle=True)
              print(f"{phoneme_tag} total features: {phoneme_features.shape}") if DEBUG else None
              print(f"{phoneme_tag} total labels: {phoneme_labels.shape}") if DEBUG else None

              # find frames where label != 0 (non-silence)
              non_zero_indexes = phoneme_labels.nonzero()
              phoneme_features = phoneme_features[non_zero_indexes]
              phoneme_labels = phoneme_labels[non_zero_indexes]
              print(f"{phoneme_tag} no-silence features: {phoneme_features.shape}") if DEBUG else None
              print(f"{phoneme_tag} no-silence labels: {phoneme_labels.shape}") if DEBUG else None

              phoneme_labels[:] = phoneme_class  # label=class

              # stack to phoneme features
              complete_features = np.concatenate((complete_features, phoneme_features))
              complete_labels = np.concatenate((complete_labels, phoneme_labels))
      
      self.X = complete_features[1:]
      self.Y = complete_labels[1:]
      print(f"[task={task_name}] {self.X.shape} features")
      print(f"[task={task_name}] {self.Y.shape} labels")

      if sorted(class_0_phonemes_found) != sorted(phonemes_class_0):
        raise Exception(f"class 0 phonemes found ({sorted(class_0_phonemes_found)}) != expected phonemes ({sorted(phonemes_class_0)})")
 
      if sorted(class_1_phonemes_found) != sorted(phonemes_class_1):
        raise Exception(f"class 1 phonemes found ({sorted(class_1_phonemes_found)}) != expected phonemes ({sorted(phonemes_class_1)})")
 

    def __len__(self):
        return len(self.X)
 
    # get a row at an index
    def __getitem__(self, index):
        x = torch.Tensor(self.X[index]).float()
        y = torch.as_tensor(self.Y[index]).float()
    
        return x,y

In [10]:
def make_dataloader(dataset, train, batch_size):
  if train:
    shuffle = True
    drop_last = True
  else:
    shuffle = False
    drop_last = False
    
  loader = DataLoader(dataset=dataset, batch_size=batch_size,
                      drop_last=drop_last, shuffle=shuffle,
                      pin_memory=True, num_workers=8)
  
  return loader

In [11]:
class PhonemeShallowDetector(nn.Module):
  
  def __init__(self, hidden_size, activation):
    super(PhonemeShallowDetector, self).__init__()
    
    self.linear_layer = nn.Linear(in_features=40, out_features=hidden_size)
    self.bn_layer = nn.BatchNorm1d(num_features=hidden_size)
    self.activation = activation
    self.output_layer = nn.Linear(in_features=hidden_size, out_features=1)
    self.sigmoid = nn.Sigmoid()
    seq_params = [
      self.linear_layer,
      self.bn_layer,
      self.activation,
      self.output_layer,
      self.sigmoid
    ]

    self.network = nn.Sequential(*seq_params)
    
  def forward(self, x):
    return self.network(x)

In [12]:
class SpecializedShallowDetector(nn.Module):
  
  def __init__(self, hidden_size, activation):
    super(SpecializedShallowDetector, self).__init__()
    
    self.linear_layer = nn.Linear(in_features=40, out_features=hidden_size)
    self.bn_layer = nn.BatchNorm1d(num_features=hidden_size)
    self.activation = activation
    self.output_layer = nn.Linear(in_features=hidden_size, out_features=1)
    self.sigmoid = nn.Sigmoid()
    seq_params = [
      self.linear_layer,
      self.bn_layer,
      self.activation,
      self.output_layer,
      self.sigmoid
    ]

    self.network = nn.Sequential(*seq_params)
    
  def forward(self, x):
    return self.network(x)

In [27]:
class FramePhonemeClassifierModel(nn.Module):
  
  def __init__(self, phoneme_mapper, specialized_mapper):
    super(FramePhonemeClassifierModel, self).__init__()
    
    self.phoneme_mapper = phoneme_mapper
    self.specialized_mapper = specialized_mapper

    # specialized detectors

    specialized_detectors = []
    for spec_idx, spec_type in enumerate(SPECIALIZED_TASKS):
      specialized_detector = SpecializedShallowDetector(hidden_size=HIDDEN_SIZE_spec, 
                                            activation=nn.LeakyReLU()).to(DEVICE)
      specialized_detectors.append(specialized_detector)
    
    self.specialized_detectors = nn.ModuleList(specialized_detectors)


    shallow_detectors = []
    # generate the PhonemeShallowDetectors 
    for phoneme_index, phoneme_tag in phoneme_mapper.items():
      shallow_detector = PhonemeShallowDetector(hidden_size=HIDDEN_SIZE_shallow, 
                                                activation=nn.LeakyReLU())
      shallow_detectors.append(shallow_detector)

    self.shallow_detectors = nn.ModuleList(shallow_detectors)

    self.linear_layer = nn.Linear(in_features=len(phoneme_mapper), out_features=len(phoneme_mapper))

    self.initialize_specialized_detectors()
    self.initialize_shallow_detectors()

  def initialize_shallow_detectors(self):
    # load weights from shallow detectors pre-trained models
    for phoneme_index, phoneme_tag in self.phoneme_mapper.items():
      phoneme_shallow_detector = self.shallow_detectors[phoneme_index]

      phoneme_model_path = f"{drivepath_shallow}/shallow_detectors/model_{phoneme_tag}_{MODEL_VERSION}_99"
      temp = torch.load(phoneme_model_path)
      phoneme_shallow_detector.load_state_dict(temp['model_state_dict'])
  
  def initialize_specialized_detectors(self):
    # load weights from shallow detectors pre-trained models
    for spec_idx, spec_type in self.phoneme_mapper.items():
      specialized_detector = self.specialized_detectors[spec_idx]

      spec_model_path = f"{drivepath_spec}/model_{spec_type}_{MODEL_VERSION}_29"
      temp = torch.load(spec_model_path)
      specialized_detector.load_state_dict(temp['model_state_dict'])
  
  def forward(self, x):
    
    # go through specialized
    
    shallow_outputs = []
    for phoneme_index, phoneme_tag in self.phoneme_mapper.items():
      phoneme_shallow_detector = self.shallow_detectors[phoneme_index]

      # run frame through shallow detector
      output = phoneme_shallow_detector(x)
      shallow_outputs.append(output.reshape(-1))

    # convert to torch tensor
    shallow_outputs = torch.vstack(shallow_outputs).T
    outputs = self.linear_layer(shallow_outputs)

    return outputs

In [16]:
class FramePhonemeClassifier():

  def __init__(self, phoneme_mapper, specialized_mapper):

    train_data = PhonemesDataset(basepath=drivepath_shallow, mode="train")
    self.train_loader = make_dataloader(dataset=train_data, train=True, batch_size=BATCH_SIZE)
    print(f"train_data.shape: {train_data.X.shape}")

    dev_data = PhonemesDataset(basepath=drivepath_shallow, mode="dev")
    self.dev_loader = make_dataloader(dataset=dev_data, train=False, batch_size=BATCH_SIZE)
    print(f"dev_data.shape: {dev_data.X.shape}")
    
    self.model = FramePhonemeClassifierModel(phoneme_mapper, specialized_mapper).to(DEVICE)

    self.criterion = nn.CrossEntropyLoss()
    self.optimizer = torch.optim.SGD(self.model.parameters(), lr=LEARNING_RATE, momentum=0.9)
    self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, 'min')

    self.train_loss_per_epoch = []
    self.train_acc_per_epoch = []
    self.dev_loss_per_epoch = []
    self.dev_acc_per_epoch = []
  
  def save_model(self, epoch):
    model_epoch_path = "{}/complete_classifier/model_{}_{}".format(drivepath,
                                                                 MODEL_VERSION, 
                                                                 epoch)
    torch.save({
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'scheduler_state_dict': self.scheduler.state_dict(),
    }, model_epoch_path)
    # print('saved model: {}'.format(model_epoch_path))

  def train(self, epochs):
    # Run training and track with wandb
    total_batches = len(self.train_loader) * epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0

    for epoch in tqdm(range(epochs)):
        train_loss = 0.0
        start_time = time.time()
        total_predictions = 0
        correct_predictions = 0

        true_labels = []
        predictions = []
        for _, (features, targets) in enumerate(self.train_loader):
            batch_loss, outputs = self.train_batch(features, targets)
            train_loss += batch_loss

            example_ct += len(features)
            batch_ct += 1

            # check number of correct predictions
            output_classes = torch.argmax(outputs.log_softmax(1), dim=1).detach().cpu()  # convert to class labels
            total_predictions += len(output_classes)
            correct_predictions += torch.sum(targets == output_classes)

            true_labels += list(targets)
            predictions += list(output_classes)

        end_time = time.time()

        train_loss /= example_ct
        print(f"training loss: {train_loss}; time: {end_time - start_time}s")
        
        if (epoch + 1) % 10 == 0 or epoch == (epochs - 1):
          report = classification_report(true_labels, predictions, output_dict=True)
          df = pd.DataFrame(report).transpose()
          df.to_csv(f"{drivepath}/final_classsifier/reports_train_{MODEL_VERSION}_{epoch + 1}.csv", index=False)
          self.save_model(epoch)

        train_acc = (correct_predictions/total_predictions) * 100.0
        print(f"training accuracy: {train_acc}%")

        self.train_loss_per_epoch.append(train_loss)
        self.train_acc_per_epoch.append(train_acc)

        # evaluate model with validation data
        dev_loss, dev_acc = self.evaluate_model(epoch)
        
        self.dev_loss_per_epoch.append(dev_loss)
        self.dev_acc_per_epoch.append(dev_acc)

        # Step with the scheduler
        self.scheduler.step(dev_loss)
      
    # epoch completed, save model
    self.save_model(epoch)

  def train_batch(self, features, targets):
    features, targets = features.to(DEVICE), targets.to(DEVICE)
    targets = targets.reshape(-1, 1)

    self.optimizer.zero_grad()

    # Forward pass ➡
    outputs = self.model(features)
    loss = self.criterion(outputs, targets.reshape(-1))  # compare with target outputs
    # Backward pass ⬅
    loss.backward()
    # Step with optimizer
    self.optimizer.step()

    return loss.item(), outputs

  def evaluate_model(self, epoch):

    with torch.no_grad():
      self.model.eval()

      running_loss = 0.0
      total_predictions = 0.0
      correct_predictions = 0.0

      true_labels = []
      predictions = []

      example_ct = 0
      start_time = time.time()
      for batch_idx, (features, targets) in enumerate(self.dev_loader):
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)
        targets = targets.reshape(-1, 1)

        example_ct += len(features)

        outputs = self.model(features)
        outputs = outputs.to(DEVICE)

        # check number of correct predictions
        output_classes = torch.argmax(outputs.log_softmax(1), dim=1)  # convert to class labels
        total_predictions += len(output_classes)
        correct_predictions += torch.sum(targets.reshape(-1) == output_classes)

        loss = self.criterion(outputs, targets.reshape(-1)).detach()
        running_loss += loss.item()

        true_labels += list(targets.detach().cpu())
        predictions += list(output_classes.detach().cpu())
      
      end_time = time.time()

      running_loss /= example_ct
      print(f"testing loss: {running_loss}; time: {end_time - start_time}s")
      acc = (correct_predictions/total_predictions) * 100.0
      print(f"testing accuracy: {acc}%")

      if (epoch + 1) % 10 == 0:
        report = classification_report(true_labels, predictions, output_dict=True)
        df = pd.DataFrame(report).transpose()
        df.to_csv(f"{drivepath}/final_classsifier/reports_dev_{MODEL_VERSION}_{epoch + 1}.csv", index=False)
  
      return running_loss, acc

# Train classifier

In [21]:
%cd /content/gdrive/MyDrive/DL_Group_Project/Dataset/Preprocessed_Data

/content/gdrive/.shortcut-targets-by-id/1qwJK2jyGMl2dPnVFe6JNZvrrG45HoonZ/DL_Group_Project/Dataset/Preprocessed_Data


In [23]:
from utilities import PHONEME_MAPPER
from utilities import SPECIALIZED_TASKS

In [None]:
%cd /

/


In [None]:
print(PHONEME_MAPPER)

{0: 'SIL', 1: 'AE', 2: 'AH', 3: 'AW', 4: 'AY', 5: 'B', 6: 'BIT', 7: 'D', 8: 'DH', 9: 'EE', 10: 'FF', 11: 'G', 12: 'HH', 13: 'IH', 14: 'II', 15: 'J', 16: 'K', 17: 'LL', 18: 'MM', 19: 'NN', 20: 'OH', 21: 'OO', 22: 'OW', 23: 'OY', 24: 'P', 25: 'RR', 26: 'SH', 27: 'SS', 28: 'T', 29: 'TH', 30: 'UE', 31: 'UH', 32: 'VV', 33: 'WW', 34: 'YY', 35: 'ZZ'}


In [24]:
print(SPECIALIZED_TASKS)

{'1_vowel_vs_consonant': {0: ['EE', 'IH', 'EH', 'AE', 'UH', 'ER', 'AH', 'AW', 'OO', 'UE'], 1: ['FF', 'HH', 'MM', 'NN', 'NG', 'RR', 'SS', 'SH', 'VV', 'WW', 'YY', 'ZZ']}, '3_highvowel_vs_lowvowel': {0: ['EE', 'IH', 'UE', 'OO'], 1: ['AE', 'AH', 'AW']}, '4_voiced_vs_unvoiced_fricatives': {0: ['DH', 'VV', 'ZZ'], 1: ['FF', 'SS', 'SH', 'TH']}, '5_ss_vs_zz': {0: ['SS'], 1: ['ZZ']}, '6_b_vs_p': {0: ['B'], 1: ['P']}, '7_dh_vs_th': {0: ['DH'], 1: ['TH']}, '8_ww_vs_yy': {0: ['WW'], 1: ['YY']}, '9_ee_vs_aw': {0: ['EE'], 1: ['AW']}, '10_ah_vs_aw': {0: ['AH'], 1: ['AW']}, '11_mm_vs_nn': {0: ['MM'], 1: ['NN']}}


In [None]:
classifier = FramePhonemeClassifier(PHONEME_MAPPER, SPECIALIZED_TASKS)
classifier.train(epochs=NUM_EPOCHS)


  cpuset_checked))


train_data.shape: (49976, 40)
dev_data.shape: (10724, 40)








  0%|          | 0/100 [00:00<?, ?it/s][A[A[A[A[A[A

training loss: 0.026293274411597314; time: 17.92014217376709s
training accuracy: 60.637020111083984%








  1%|          | 1/100 [00:19<32:19, 19.59s/it][A[A[A[A[A[A

testing loss: 0.028341506538007384; time: 1.6592097282409668s
testing accuracy: 57.627750396728516%
training loss: 0.01059438670412279; time: 16.894349575042725s
training accuracy: 70.30648803710938%








  2%|▏         | 2/100 [00:37<31:25, 19.23s/it][A[A[A[A[A[A

testing loss: 0.023644539369537688; time: 1.4994187355041504s
testing accuracy: 68.25811004638672%
training loss: 0.005959935536066818; time: 16.883193016052246s
training accuracy: 75.68910217285156%








  3%|▎         | 3/100 [00:56<30:46, 19.04s/it][A[A[A[A[A[A

testing loss: 0.020977551176593803; time: 1.6836140155792236s
testing accuracy: 71.34465026855469%
training loss: 0.004016672528516023; time: 16.728933095932007s
training accuracy: 78.18109130859375%








  4%|▍         | 4/100 [01:14<30:04, 18.80s/it][A[A[A[A[A[A

testing loss: 0.019691348855384736; time: 1.5055067539215088s
testing accuracy: 74.35658264160156%
training loss: 0.002976094618415794; time: 16.619519233703613s
training accuracy: 79.94591522216797%








  5%|▌         | 5/100 [01:33<29:32, 18.66s/it][A[A[A[A[A[A

testing loss: 0.01889817395547307; time: 1.6868975162506104s
testing accuracy: 74.8508071899414%
training loss: 0.002346035377241862; time: 16.981719493865967s
training accuracy: 80.92548370361328%








  6%|▌         | 6/100 [01:51<29:09, 18.61s/it][A[A[A[A[A[A

testing loss: 0.018225480854811398; time: 1.5170307159423828s
testing accuracy: 76.0257339477539%
training loss: 0.0019377700736295882; time: 16.57915163040161s
training accuracy: 81.7528076171875%








  7%|▋         | 7/100 [02:09<28:36, 18.45s/it][A[A[A[A[A[A

testing loss: 0.017386962838557087; time: 1.486297845840454s
testing accuracy: 76.45468139648438%
training loss: 0.001632117870180175; time: 16.718029499053955s
training accuracy: 82.5340576171875%








  8%|▊         | 8/100 [02:27<28:10, 18.37s/it][A[A[A[A[A[A

testing loss: 0.017096040452907264; time: 1.4576747417449951s
testing accuracy: 77.1260757446289%
training loss: 0.001412103206325227; time: 16.773301362991333s
training accuracy: 82.97676086425781%








  9%|▉         | 9/100 [02:46<27:48, 18.34s/it][A[A[A[A[A[A

testing loss: 0.016770988544369314; time: 1.465083122253418s
testing accuracy: 78.26370239257812%
training loss: 0.0012355741031038073; time: 16.788450241088867s


  _warn_prf(average, modifier, msg_start, len(result))


training accuracy: 83.5556869506836%








 10%|█         | 10/100 [03:05<28:03, 18.70s/it][A[A[A[A[A[A

testing loss: 0.01632744227826084; time: 1.4573919773101807s
testing accuracy: 78.28235626220703%
training loss: 0.0011058059365256301; time: 16.75439429283142s
training accuracy: 83.81009674072266%








 11%|█         | 11/100 [03:24<27:37, 18.62s/it][A[A[A[A[A[A

testing loss: 0.015863287319936043; time: 1.6718213558197021s
testing accuracy: 78.61805725097656%
training loss: 0.000987828955000553; time: 16.62383508682251s
training accuracy: 84.19271087646484%








 12%|█▏        | 12/100 [03:42<27:07, 18.50s/it][A[A[A[A[A[A

testing loss: 0.01603064478883233; time: 1.5627880096435547s
testing accuracy: 77.9932861328125%
training loss: 0.0009008577972911372; time: 16.654590845108032s
training accuracy: 84.385009765625%








 13%|█▎        | 13/100 [04:00<26:43, 18.43s/it][A[A[A[A[A[A

testing loss: 0.01571180577939102; time: 1.6175179481506348s
testing accuracy: 79.7650146484375%
training loss: 0.0008226104624374964; time: 16.698410034179688s
training accuracy: 84.6394271850586%








 14%|█▍        | 14/100 [04:18<26:18, 18.36s/it][A[A[A[A[A[A

testing loss: 0.015784698440173987; time: 1.4704539775848389s
testing accuracy: 78.50615692138672%
training loss: 0.0007572538426352872; time: 16.82315969467163s
training accuracy: 84.93789672851562%








 15%|█▌        | 15/100 [04:37<26:00, 18.36s/it][A[A[A[A[A[A

testing loss: 0.016631874220750752; time: 1.5276532173156738s
testing accuracy: 77.74151611328125%
training loss: 0.0007011331536117583; time: 16.92078924179077s
training accuracy: 85.0821304321289%








 16%|█▌        | 16/100 [04:55<25:43, 18.37s/it][A[A[A[A[A[A

testing loss: 0.015037650419611043; time: 1.4587209224700928s
testing accuracy: 79.75569152832031%
training loss: 0.0006554265883203377; time: 16.694656133651733s
training accuracy: 85.38060760498047%








 17%|█▋        | 17/100 [05:13<25:22, 18.34s/it][A[A[A[A[A[A

testing loss: 0.01595822441871377; time: 1.5591380596160889s
testing accuracy: 79.69041442871094%
training loss: 0.000612076391244408; time: 16.924307584762573s
training accuracy: 85.33052825927734%








 18%|█▊        | 18/100 [05:32<25:07, 18.38s/it][A[A[A[A[A[A

testing loss: 0.014967180228542366; time: 1.5494422912597656s
testing accuracy: 80.4830322265625%
training loss: 0.000575239917228197; time: 16.671935081481934s
training accuracy: 85.50680541992188%








 19%|█▉        | 19/100 [05:50<24:47, 18.37s/it][A[A[A[A[A[A

testing loss: 0.014973872318635988; time: 1.6529953479766846s
testing accuracy: 79.97948455810547%
training loss: 0.0005382172934388598; time: 17.05228853225708s
training accuracy: 85.67507934570312%








 20%|██        | 20/100 [06:10<24:56, 18.71s/it][A[A[A[A[A[A

testing loss: 0.015479607552745724; time: 1.5276920795440674s
testing accuracy: 80.1193618774414%
training loss: 0.0005107157295956408; time: 16.64642024040222s
training accuracy: 85.80529022216797%








 21%|██        | 21/100 [06:28<24:25, 18.55s/it][A[A[A[A[A[A

testing loss: 0.015219622763017584; time: 1.5004088878631592s
testing accuracy: 79.90489196777344%
training loss: 0.0004835142929862429; time: 16.68054986000061s
training accuracy: 85.99158477783203%








 22%|██▏       | 22/100 [06:46<23:58, 18.44s/it][A[A[A[A[A[A

testing loss: 0.01458742088405155; time: 1.4967267513275146s
testing accuracy: 80.35247802734375%
training loss: 0.00045718554495091456; time: 16.73304510116577s
training accuracy: 86.0556869506836%








 23%|██▎       | 23/100 [07:04<23:35, 18.38s/it][A[A[A[A[A[A

testing loss: 0.014924306530854187; time: 1.4904823303222656s
testing accuracy: 80.52033233642578%
training loss: 0.0004361879275430145; time: 16.877310037612915s
training accuracy: 86.1338119506836%








 24%|██▍       | 24/100 [07:23<23:18, 18.41s/it][A[A[A[A[A[A

testing loss: 0.014099321217289703; time: 1.5791246891021729s
testing accuracy: 81.22901916503906%
training loss: 0.00041489737504758897; time: 16.78594422340393s
training accuracy: 86.32612609863281%








 25%|██▌       | 25/100 [07:41<22:58, 18.37s/it][A[A[A[A[A[A

testing loss: 0.014308764077153325; time: 1.5004889965057373s
testing accuracy: 80.9585952758789%
training loss: 0.0003962958672316286; time: 16.711018323898315s
training accuracy: 86.29808044433594%








 26%|██▌       | 26/100 [07:59<22:37, 18.34s/it][A[A[A[A[A[A

testing loss: 0.014944728906448766; time: 1.5379834175109863s
testing accuracy: 80.66952514648438%
training loss: 0.0003790121187999003; time: 16.640118837356567s
training accuracy: 86.35617065429688%








 27%|██▋       | 27/100 [08:17<22:14, 18.28s/it][A[A[A[A[A[A

testing loss: 0.014338642287914632; time: 1.4862592220306396s
testing accuracy: 80.52033233642578%
training loss: 0.00036263247291325515; time: 16.926597356796265s
training accuracy: 86.50440216064453%








 28%|██▊       | 28/100 [08:36<22:00, 18.34s/it][A[A[A[A[A[A

testing loss: 0.014348514130737122; time: 1.529355764389038s
testing accuracy: 80.57627868652344%
training loss: 0.0003482040620403188; time: 16.68149161338806s
training accuracy: 86.49439239501953%








 29%|██▉       | 29/100 [08:54<21:38, 18.29s/it][A[A[A[A[A[A

testing loss: 0.014464197678269846; time: 1.480614423751831s
testing accuracy: 80.38977813720703%
training loss: 0.00033518163054488984; time: 16.867227792739868s
training accuracy: 86.68870544433594%








 30%|███       | 30/100 [09:13<21:41, 18.60s/it][A[A[A[A[A[A

testing loss: 0.014675950223676574; time: 1.524510383605957s
testing accuracy: 80.4830322265625%
training loss: 0.00032232649980815445; time: 16.623897552490234s
training accuracy: 86.77684020996094%








 31%|███       | 31/100 [09:32<21:16, 18.50s/it][A[A[A[A[A[A

testing loss: 0.014608494003665665; time: 1.6253774166107178s
testing accuracy: 81.11712646484375%
training loss: 0.00031081551947304085; time: 16.778342247009277s
training accuracy: 86.7568130493164%








 32%|███▏      | 32/100 [09:50<20:54, 18.44s/it][A[A[A[A[A[A

testing loss: 0.014453029185005508; time: 1.5300781726837158s
testing accuracy: 80.88399505615234%
training loss: 0.0002996251471709712; time: 16.77653408050537s
training accuracy: 86.85295867919922%








 33%|███▎      | 33/100 [10:08<20:32, 18.40s/it][A[A[A[A[A[A

testing loss: 0.01444818414205804; time: 1.4972858428955078s
testing accuracy: 81.09847259521484%
training loss: 0.0002905817074770723; time: 16.708003997802734s
training accuracy: 86.92107391357422%








 34%|███▍      | 34/100 [10:27<20:11, 18.35s/it][A[A[A[A[A[A

testing loss: 0.014411213110572408; time: 1.516857624053955s
testing accuracy: 80.4830322265625%
training loss: 0.00028008843423293786; time: 16.728912353515625s
training accuracy: 87.0252456665039%








 35%|███▌      | 35/100 [10:45<19:53, 18.36s/it][A[A[A[A[A[A

testing loss: 0.014187616531982746; time: 1.6525294780731201s
testing accuracy: 81.39686584472656%
training loss: 0.00025225303697491334; time: 16.699416399002075s
training accuracy: 88.0769271850586%








 36%|███▌      | 36/100 [11:03<19:32, 18.31s/it][A[A[A[A[A[A

testing loss: 0.013388238724978367; time: 1.4855570793151855s
testing accuracy: 82.36665344238281%
training loss: 0.00024263251880867948; time: 17.00139880180359s
training accuracy: 88.21113586425781%








 37%|███▋      | 37/100 [11:22<19:17, 18.38s/it][A[A[A[A[A[A

testing loss: 0.013360777038195737; time: 1.5115885734558105s
testing accuracy: 82.11488342285156%
