In [1]:
from pathlib import Path
import csv
import itertools
import numpy as np
from pprint import PrettyPrinter


import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_model_summary import summary
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

import h5py

import sklearn
import sklearn.metrics

# Notebook-wide PrettyPrinter instance. Note that this binds the name `pprint`
# to an instance, not to the standard-library `pprint` module.
pprint = PrettyPrinter()

## Configuration

In [2]:
# Languages whose annotated segments are kept when the metadata is loaded.
SELECTED_LANGUAGES = {'_language_independent', 'francais', 'maninka', 'pular', 'susu'}

# Dataset root: the annotation file and every feature directory are resolved
# relative to this path.
BASE_DIR = Path('/media/xtrem/data/experiments/nicolingua-0002-va-asr/datasets/gn_va_asr_dataset_2020-08-24_02')

ANNOTATIONS_PATH = BASE_DIR / "annotated_segments" / "metadata.csv"

# Feature sets to evaluate; each name is also the sub-directory of BASE_DIR
# that holds the corresponding feature files.
FEATURE_NAMES = [
    "wav2vec_features-c",
    "wav2vec_features-z",
    "retrained-wav2vec_features-c",
    "retrained-wav2vec_features-z",
]

# Hyper-parameter grid: one trial is run per element of the cartesian product.
# The model also accepts 'avg' pooling: CONV_POOLING_TYPES = ['avg', 'max']
CONV_POOLING_TYPES = ['max']
OBJECTIVE_TYPES = ['voice_cmd', 'voice_cmd__and__voice_cmd_lng']
CONV_DROPOUT_PROBABILITIES = [0.2]
FC_DROPOUT_PROBABILITIES = [0.2]

# Fraction of (speaker, language) groups assigned to the training split, and
# the number of independently sampled splits ("folds").
TRAIN_PERCENT = 0.7
FOLD_COUNT = 5

RESULTS_DIR = 'results_101'
EPOCHS = 1000
BATCH_SIZE = 512
# Every feature sequence is zero-padded to this many frames.
MAX_FEATURE_SEQUENCE_LENGTH = 200

GPU_ID = 1
# Fall back to the CPU when the requested GPU does not exist, so the notebook
# can still be run end-to-end on machines without (enough) CUDA devices.
if torch.cuda.is_available() and GPU_ID < torch.cuda.device_count():
    device = torch.device(f"cuda:{GPU_ID}")
else:
    device = torch.device("cpu")

## Load & shuffle metadata records

In [3]:
def load_metadata():
    """Yield the annotation rows whose language is in SELECTED_LANGUAGES.

    Rows are read from ANNOTATIONS_PATH with ``csv.DictReader``, so each
    yielded record is a dict keyed by the CSV header. Rows are yielded lazily,
    in file order.
    """
    # newline='' hands newline handling to the csv module, as its documentation
    # requires (otherwise newlines embedded in quoted fields are misread).
    with open(ANNOTATIONS_PATH, newline='') as f:
        for row in csv.DictReader(f):
            if row['language'] in SELECTED_LANGUAGES:
                yield row

In [4]:
# Materialise every selected annotation row in memory.
metadata_records = [record for record in load_metadata()]

# Shuffle in place with a fixed seed so the record order is reproducible.
metadata_shuffler_rs = np.random.RandomState(42)
metadata_shuffler_rs.shuffle(metadata_records)

In [5]:
# Metadata fields along which per-category accuracies are reported.
bias_category_fields = [
    "device_id",
    "language",
    "speaker_gender",
    "speaker_mothertongue",
]

# Map each field to the sorted list of distinct values found in the metadata.
bias_categories = {
    field: sorted({record[field] for record in metadata_records})
    for field in bias_category_fields
}

# Plain loop instead of a list comprehension: nothing is collected, and the
# IPython `_` variable is left untouched.
for category_name, category_values in bias_categories.items():
    print(f"\n{category_name}: \n\t{','.join(category_values)}")


device_id: 
	d001,d002,d003

language: 
	_language_independent,francais,maninka,pular,susu

speaker_gender: 
	F,M

speaker_mothertongue: 
	maninka,pular,susu


### Labels

In [6]:
def _index_classes(field):
    """Return (sorted class names, class count, name -> id map) for a metadata field.

    Class ids are the positions of the names in the sorted list of distinct
    values that `field` takes across `metadata_records`.
    """
    names = sorted({record[field] for record in metadata_records})
    return names, len(names), {name: class_id for class_id, name in enumerate(names)}


def _print_class_ids(title, class_id_by_name, id_width, rule="---------------------"):
    """Print `title`, one right-aligned `id: name` line per class, then `rule`."""
    print(title)
    for class_name, class_id in class_id_by_name.items():
        print(f"{class_id:{id_width}}: {class_name}")
    print(rule)


# VOICE COMMANDS
voice_cmd_class_names, voice_cmd_class_count, voice_cmd_class_id_by_name = \
    _index_classes('label')
_print_class_ids("Classes - Voice Commands", voice_cmd_class_id_by_name, id_width=4)


# VOICE COMMAND LANGUAGES
voice_cmd_lng_class_names, voice_cmd_lng_class_count, voice_cmd_lng_class_id_by_name = \
    _index_classes('language')
_print_class_ids("Classes - Voice Command Languages", voice_cmd_lng_class_id_by_name, id_width=3)


# SPEAKER MOTHERTONGUE
spkr_mothertongue_class_names, spkr_mothertongue_class_count, spkr_mothertongue_class_id_by_name = \
    _index_classes('speaker_mothertongue')
_print_class_ids("Classes - Speaker Mothertongues", spkr_mothertongue_class_id_by_name, id_width=3)


# SPEAKER GENDER
spkr_gender_class_names, spkr_gender_class_count, spkr_gender_class_id_by_name = \
    _index_classes('speaker_gender')
# The final rule is one dash longer than the others; kept so the printed
# output is unchanged.
_print_class_ids(
    "Classes - Speaker Gender",
    spkr_gender_class_id_by_name,
    id_width=3,
    rule="----------------------",
)



Classes - Voice Commands
   0: 101_wake_word__francais
   1: 101_wake_word__maninka
   2: 101_wake_word__pular
   3: 101_wake_word__susu
   4: 201_add_contact__francais
   5: 201_add_contact__maninka
   6: 201_add_contact__pular
   7: 201_add_contact__susu
   8: 202_search_contact__francais
   9: 202_search_contact__maninka
  10: 202_search_contact__pular
  11: 202_search_contact__susu
  12: 203_update_contact__francais
  13: 203_update_contact__maninka
  14: 203_update_contact__pular
  15: 203_update_contact__susu
  16: 204_delete_contact__francais
  17: 204_delete_contact__maninka
  18: 204_delete_contact__pular
  19: 204_delete_contact__susu
  20: 205_call_contact__francais
  21: 205_call_contact__maninka
  22: 205_call_contact__pular
  23: 205_call_contact__susu
  24: 206_yes__francais
  25: 206_yes__maninka
  26: 206_yes__pular
  27: 206_yes__susu
  28: 207_no__francais
  29: 207_no__maninka
  30: 207_no__pular
  31: 207_no__susu
  32: 301_zero__francais
  33: 301_zero__maninka
  

### Inspect metadata

In [7]:
def count_by_attribute(records, attribute_names):
    """Yield ``(values, count)`` for every combination of attribute values.

    Args:
        records: iterable of mappings (e.g. metadata rows); each must contain
            every key listed in `attribute_names`.
        attribute_names: sequence of keys to group by.

    Yields:
        ``(values, count)`` tuples in sorted order of `values`, where `values`
        is one element of the cartesian product of the distinct values seen for
        each attribute and `count` is the number of records matching all of
        them. Combinations that no record matches are yielded with a count
        of 0.
    """
    # Single pass over the records: tally how often each combination occurs,
    # instead of rescanning every record for every combination.
    counts = {}
    for record in records:
        key = tuple(record[name] for name in attribute_names)
        counts[key] = counts.get(key, 0) + 1

    # Distinct values per attribute, recovered from the observed combinations.
    distinct_values = [
        {key[position] for key in counts}
        for position in range(len(attribute_names))
    ]

    for combination in sorted(itertools.product(*distinct_values)):
        yield (combination, counts.get(combination, 0))

In [8]:
# (title, attributes to group by) for each report, printed in this order.
record_count_reports = [
    ("RECORDS BY DEVICE", ['device_id']),
    ("RECORDS BY LANGUAGE", ['language']),
    ("RECORDS BY GENDER", ['speaker_gender']),
    ("RECORDS BY AGE", ['speaker_age']),
    ("RECORDS BY SPEAKER", ['speaker_id']),
    ("RECORDS BY SPEAKER BY LANGUAGE", ['speaker_id', 'language']),
    # This report groups by label only, so the title no longer claims a
    # per-speaker breakdown.
    ("RECORDS BY LABEL", ['label']),
]

for report_title, report_attributes in record_count_reports:
    print(report_title)
    for row in sorted(count_by_attribute(metadata_records, report_attributes)):
        print(f"\t{row}")
    print("")

RECORDS BY DEVICE
	(('d001',), 2759)
	(('d002',), 2741)
	(('d003',), 2759)

RECORDS BY LANGUAGE
	(('_language_independent',), 3072)
	(('francais',), 1260)
	(('maninka',), 1356)
	(('pular',), 909)
	(('susu',), 1662)

RECORDS BY GENDER
	(('F',), 2820)
	(('M',), 5439)

RECORDS BY AGE
	(('12',), 237)
	(('13',), 252)
	(('15',), 603)
	(('17',), 1050)
	(('18',), 855)
	(('19',), 273)
	(('20',), 291)
	(('27',), 255)
	(('28',), 183)
	(('29',), 237)
	(('31',), 255)
	(('32',), 291)
	(('33',), 183)
	(('34',), 129)
	(('35',), 498)
	(('37',), 309)
	(('38',), 441)
	(('43',), 183)
	(('44',), 540)
	(('5',), 129)
	(('55',), 183)
	(('61',), 390)
	(('63',), 237)
	(('67',), 255)

RECORDS BY SPEAKER
	(('s001',), 183)
	(('s002',), 129)
	(('s003',), 183)
	(('s004',), 183)
	(('s005',), 129)
	(('s006',), 129)
	(('s007',), 183)
	(('s008',), 129)
	(('s009',), 237)
	(('s010',), 291)
	(('s011',), 129)
	(('s012',), 183)
	(('s013',), 129)
	(('s014',), 237)
	(('s015',), 291)
	(('s016',), 183)
	(('s017',), 237)
	(('s018

## Prepare Cross Validation Folds
- Partition by (speaker, language)
- Each (speaker, language) pair corresponds to `utterance_count * device_count` records
- For each fold, all `utterance_count * device_count` records for the same speaker in the same language are either in the TRAIN or the VALIDATION sets, but not both.


In [9]:
def generate_train_test_records_per_fold(all_records):
    """Split the records into train/test sets for each of FOLD_COUNT folds.

    The unit of sampling is the (speaker_id, language) pair: for every fold,
    ceil(TRAIN_PERCENT * pair_count) pairs are drawn without replacement using
    a RandomState seeded with the fold index, and all records of a drawn pair
    go to the training set; every other record goes to the test set.

    Returns:
        dict mapping fold index -> {"train_records": [...], "test_records": [...]},
        each list keeping the order of `all_records`.
    """
    speaker_language_pairs = sorted(
        {(record['speaker_id'], record['language']) for record in all_records}
    )
    pair_count = len(speaker_language_pairs)
    pair_indices = range(pair_count)
    train_pair_count = int(np.ceil(pair_count * TRAIN_PERCENT))

    records_per_fold = {}
    for fold_index in range(FOLD_COUNT):
        # Seeding with the fold index makes each fold's split reproducible.
        sampler = np.random.RandomState(seed=fold_index)
        sampled_indices = sampler.choice(pair_indices, train_pair_count, replace=False)
        train_pairs = {speaker_language_pairs[i] for i in sampled_indices}

        # Every record's pair is either sampled (train) or not (test).
        train_records, test_records = [], []
        for record in all_records:
            pair = (record['speaker_id'], record['language'])
            if pair in train_pairs:
                train_records.append(record)
            else:
                test_records.append(record)

        records_per_fold[fold_index] = {
            "train_records": train_records,
            "test_records": test_records,
        }

    return records_per_fold

In [10]:
# Train/test record lists for every fold, keyed by fold index.
records_per_fold = generate_train_test_records_per_fold(metadata_records)

### Inspect Folds

In [11]:
# Per fold, list the (speaker, language) pairs present in each split.
for fold_index in range(FOLD_COUNT):
    train_records = records_per_fold[fold_index]["train_records"]
    test_records = records_per_fold[fold_index]["test_records"]

    for split_name, split_records in (("TRAIN", train_records), ("TEST", test_records)):
        print(f"Fold {fold_index} -- {split_name}")
        for row in sorted(count_by_attribute(split_records, ['speaker_id', 'language'])):
            # Skip (speaker, language) combinations that have no record in this split.
            if row[1] > 0:
                print(f"\t{row}")

    print("---------------------")

Fold 0 -- TRAIN
	(('s001', '_language_independent'), 75)
	(('s001', 'maninka'), 54)
	(('s001', 'susu'), 54)
	(('s002', '_language_independent'), 75)
	(('s002', 'maninka'), 54)
	(('s003', '_language_independent'), 75)
	(('s003', 'maninka'), 54)
	(('s003', 'susu'), 54)
	(('s004', '_language_independent'), 75)
	(('s004', 'susu'), 54)
	(('s005', '_language_independent'), 75)
	(('s006', '_language_independent'), 75)
	(('s006', 'maninka'), 54)
	(('s007', '_language_independent'), 75)
	(('s007', 'pular'), 54)
	(('s007', 'susu'), 54)
	(('s008', '_language_independent'), 75)
	(('s008', 'maninka'), 54)
	(('s009', 'pular'), 54)
	(('s009', 'susu'), 54)
	(('s010', '_language_independent'), 75)
	(('s010', 'maninka'), 54)
	(('s010', 'pular'), 54)
	(('s010', 'susu'), 54)
	(('s011', '_language_independent'), 75)
	(('s011', 'susu'), 54)
	(('s012', '_language_independent'), 75)
	(('s012', 'francais'), 54)
	(('s012', 'susu'), 54)
	(('s013', '_language_independent'), 75)
	(('s013', 'susu'), 54)
	(('s014', 

	(('s002', 'maninka'), 54)
	(('s003', '_language_independent'), 75)
	(('s003', 'maninka'), 54)
	(('s004', '_language_independent'), 75)
	(('s004', 'maninka'), 54)
	(('s005', '_language_independent'), 75)
	(('s005', 'susu'), 54)
	(('s006', '_language_independent'), 75)
	(('s007', '_language_independent'), 75)
	(('s007', 'pular'), 54)
	(('s007', 'susu'), 54)
	(('s008', '_language_independent'), 75)
	(('s009', 'susu'), 54)
	(('s010', 'francais'), 54)
	(('s010', 'pular'), 54)
	(('s010', 'susu'), 54)
	(('s011', 'susu'), 54)
	(('s012', '_language_independent'), 75)
	(('s012', 'francais'), 54)
	(('s012', 'susu'), 54)
	(('s013', '_language_independent'), 75)
	(('s013', 'susu'), 54)
	(('s014', '_language_independent'), 75)
	(('s014', 'francais'), 54)
	(('s015', '_language_independent'), 75)
	(('s015', 'maninka'), 54)
	(('s015', 'pular'), 54)
	(('s016', '_language_independent'), 75)
	(('s016', 'maninka'), 54)
	(('s016', 'susu'), 54)
	(('s017', '_language_independent'), 75)
	(('s017', 'francais')

In [12]:
# Per fold, show the size of each split and its record count per language.
for fold_index in range(FOLD_COUNT):
    train_records = records_per_fold[fold_index]["train_records"]
    test_records = records_per_fold[fold_index]["test_records"]

    for split_name, split_records in (("TRAIN", train_records), ("TEST", test_records)):
        print(f"Fold {fold_index} -- {split_name}: ({len(split_records)})")
        for row in sorted(count_by_attribute(split_records, ['language'])):
            # Skip languages that have no record in this split.
            if row[1] > 0:
                print(f"\t{row}")

    print("---------------------")

Fold 0 -- TRAIN: (6018)
	(('_language_independent',), 2400)
	(('francais',), 738)
	(('maninka',), 960)
	(('pular',), 852)
	(('susu',), 1068)
Fold 0 -- TEST: (2241)
	(('_language_independent',), 672)
	(('francais',), 522)
	(('maninka',), 396)
	(('pular',), 57)
	(('susu',), 594)
---------------------
Fold 1 -- TRAIN: (5940)
	(('_language_independent',), 2325)
	(('francais',), 978)
	(('maninka',), 906)
	(('pular',), 519)
	(('susu',), 1212)
Fold 1 -- TEST: (2319)
	(('_language_independent',), 747)
	(('francais',), 282)
	(('maninka',), 450)
	(('pular',), 390)
	(('susu',), 450)
---------------------
Fold 2 -- TRAIN: (6036)
	(('_language_independent',), 2397)
	(('francais',), 972)
	(('maninka',), 912)
	(('pular',), 573)
	(('susu',), 1182)
Fold 2 -- TEST: (2223)
	(('_language_independent',), 675)
	(('francais',), 288)
	(('maninka',), 444)
	(('pular',), 336)
	(('susu',), 480)
---------------------
Fold 3 -- TRAIN: (5796)
	(('_language_independent',), 2097)
	(('francais',), 918)
	(('maninka',), 

## Load Features

In [13]:
def load_features(records, feature_name):
    """Load the zero-padded feature matrix of every record.

    For each record, the file ``BASE_DIR / feature_name / <file>`` is read,
    where ``<file>`` is the record's 'file' value with ".wav" replaced by
    ".h5context". The flat 'features' dataset is reshaped using the entries of
    the 'info' dataset after the first one, then copied into the top rows of a
    (MAX_FEATURE_SEQUENCE_LENGTH, 512) zero array of the same dtype.

    NOTE(review): this presumes every file has at most
    MAX_FEATURE_SEQUENCE_LENGTH rows of width 512; other shapes make the
    padding assignment fail — confirm against the feature extractor.

    Returns:
        list of padded arrays, one per record, in the order of `records`.
    """
    feature_dir = Path(BASE_DIR / feature_name)

    padded_per_record = []
    for record in records:
        h5_name = record['file'].replace(".wav", ".h5context")
        with h5py.File(feature_dir / h5_name, 'r') as h5_file:
            shape = h5_file['info'][1:].astype(int)
            values = np.array(h5_file['features'][:]).reshape(shape)

            # Rows beyond the sequence length stay zero.
            padded = np.zeros((MAX_FEATURE_SEQUENCE_LENGTH, 512), dtype=values.dtype)
            padded[:shape[0], :] = values

            padded_per_record.append(padded)

    return padded_per_record

In [14]:
def get_bias_category_labels(records):
    """Build one 0/1 membership list per bias category value.

    Returns:
        dict keyed by "<category>__<value>" for every value of every category
        in `bias_categories`; each entry is a list aligned with `records`,
        holding 1 where the record has that value and 0 otherwise.
    """
    return {
        f"{category}__{value}": [1 if record[category] == value else 0 for record in records]
        for category in bias_categories
        for value in bias_categories[category]
    }

# Classification Models

In [15]:
class ASRCNN(nn.Module):
    """1-D convolutional classifier with one or two classification heads.

    The network applies a 1x1 convolution (512 -> 8 channels) followed by four
    conv / ELU / dropout / pooling stages. The outputs of stages 2, 3 and 4 are
    each averaged over the time axis and concatenated into a
    16 + 32 + 64 = 112 dimensional vector that feeds the linear head(s).
    """

    def __init__(self,
                 conv_pooling_type,
                 conv_dropout_p,
                 fc_dropout_p,
                 voice_cmd_neuron_count,
                 voice_cmd_lng_neuron_count,
                 objective_type
                ):
        """Create the layers.

        Args:
            conv_pooling_type: "max" or "avg"; pooling used after each conv stage.
            conv_dropout_p: dropout probability inside the conv stages.
            fc_dropout_p: dropout probability applied before the linear head(s).
            voice_cmd_neuron_count: output size of the voice-command head.
            voice_cmd_lng_neuron_count: output size of the language head (only
                created for the 'voice_cmd__and__voice_cmd_lng' objective).
            objective_type: 'voice_cmd' or 'voice_cmd__and__voice_cmd_lng'.

        Raises:
            ValueError: if `conv_pooling_type` is not "max" or "avg".
        """
        super(ASRCNN, self).__init__()

        if conv_pooling_type not in {"max", "avg"}:
            raise ValueError(f"Unknown Conv Pooling Type: {conv_pooling_type}")

        conv_pooling_class_by_type = {
            "max": nn.MaxPool1d,
            "avg": nn.AvgPool1d,
        }

        conv_pooling_class = conv_pooling_class_by_type[conv_pooling_type]

        self.objective_type = objective_type

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed is unchanged.
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=8, kernel_size=1)

        self.conv1 = nn.Conv1d(in_channels=8, out_channels=8, kernel_size=3)
        self.drop1 = nn.Dropout(p=conv_dropout_p)
        self.pool1 = conv_pooling_class(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=3)
        self.drop2 = nn.Dropout(p=conv_dropout_p)
        self.pool2 = conv_pooling_class(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)
        self.drop3 = nn.Dropout(p=conv_dropout_p)
        self.pool3 = conv_pooling_class(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3)
        self.drop4 = nn.Dropout(p=conv_dropout_p)
        self.pool4 = conv_pooling_class(kernel_size=2, stride=2)

        self.drop5 = nn.Dropout(p=fc_dropout_p)

        # 112 = 16 + 32 + 64: channel counts of the three pooled stage outputs
        # that forward() concatenates.
        self.lin61 = nn.Linear(in_features=112, out_features=voice_cmd_neuron_count)

        # 'voice_cmd', 'voice_cmd__and__voice_cmd_lng'
        if self.objective_type == 'voice_cmd__and__voice_cmd_lng':
            self.lin62 = nn.Linear(in_features=112, out_features=voice_cmd_lng_neuron_count)

    def forward(self, x):
        """Compute the logits for a batch.

        Args:
            x: tensor of shape (batch, sequence_length, 512); it is permuted to
                (batch, 512, sequence_length) before the first convolution.

        Returns:
            For the 'voice_cmd' objective, the voice-command logits. For
            'voice_cmd__and__voice_cmd_lng', the tuple
            (voice-command logits, language logits).

        Raises:
            ValueError: if `self.objective_type` is not one of the two
                supported objectives.
        """
        x = x.permute(0, 2, 1)
        x = self.conv0(x)

        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)

        # Average over time: one 16-dim summary of the stage-2 output.
        v1 = torch.mean(x, dim=2)

        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)

        # 32-dim summary of the stage-3 output.
        v2 = torch.mean(x, dim=2)

        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)

        # 64-dim summary of the stage-4 output.
        v3 = torch.mean(x, dim=2)

        v = torch.cat((v1, v2, v3), dim=1)
        v = self.drop5(v)

        if self.objective_type == 'voice_cmd':
            logits_voice_cmd = self.lin61(v)
            return logits_voice_cmd
        elif self.objective_type == 'voice_cmd__and__voice_cmd_lng':
            logits_voice_cmd = self.lin61(v)
            logits_voice_cmd_lng = self.lin62(v)
            return logits_voice_cmd, logits_voice_cmd_lng
        else:
            # Raising a bare string is itself a TypeError; raise a real exception.
            raise ValueError(f"Unknown objective type: {self.objective_type}")

In [16]:
def get_data_for_fold(fold_id, feature_name):
    """Assemble feature arrays, label arrays and bias labels for one fold.

    Returns:
        (train_x, train_y, test_x, test_y,
         train_bias_category_labels, test_bias_category_labels) where the `x`
        values are arrays stacking the padded features of each record, the `y`
        values are dicts of integer class-id arrays keyed by 'voice_cmd',
        'voice_cmd_lng', 'spkr_mothertongue' and 'spkr_gender', and the bias
        labels come from get_bias_category_labels.
    """
    fold = records_per_fold[fold_id]
    train_records = fold["train_records"]
    test_records = fold["test_records"]

    def encode_targets(records):
        # Translate each record's metadata into the integer ids of every target.
        return {
            'voice_cmd': np.array(
                [voice_cmd_class_id_by_name[r['label']] for r in records]),
            'voice_cmd_lng': np.array(
                [voice_cmd_lng_class_id_by_name[r['language']] for r in records]),
            'spkr_mothertongue': np.array(
                [spkr_mothertongue_class_id_by_name[r['speaker_mothertongue']] for r in records]),
            'spkr_gender': np.array(
                [spkr_gender_class_id_by_name[r['speaker_gender']] for r in records]),
        }

    train_x = np.array(load_features(train_records, feature_name))
    test_x = np.array(load_features(test_records, feature_name))

    train_y = encode_targets(train_records)
    test_y = encode_targets(test_records)

    return (
        train_x,
        train_y,
        test_x,
        test_y,
        get_bias_category_labels(train_records),
        get_bias_category_labels(test_records),
    )

    
def get_loaders_for_fold(fold_id, feature_name, batch_size):
    """Wrap one fold's data in DataLoaders.

    Each dataset item is (features, voice-command id, language id); the
    speaker mothertongue and gender labels are not part of the datasets.

    Returns:
        (train_loader, test_loader,
         train_bias_category_labels, test_bias_category_labels)
    """
    (train_x, train_y, test_x, test_y,
     train_bias_category_labels, test_bias_category_labels) = get_data_for_fold(fold_id, feature_name)

    def build_loader(x, y):
        dataset = TensorDataset(
            torch.tensor(x),
            torch.tensor(y['voice_cmd']),
            torch.tensor(y['voice_cmd_lng']),
        )
        # No shuffle argument is passed; test() relies on the loader order
        # matching the order of the per-record bias labels.
        return DataLoader(dataset, batch_size=batch_size)

    train_loader = build_loader(train_x, train_y)
    test_loader = build_loader(test_x, test_y)

    return train_loader, test_loader, train_bias_category_labels, test_bias_category_labels


def get_predictions_for_logits(logits):
    """Return, for each row of `logits`, the index of the most probable class.

    The logits are turned into probabilities with a softmax over dim 1 and the
    arg-max over that same dimension is returned.
    """
    return F.softmax(logits, dim=1).argmax(dim=1)

In [17]:
def train(model, optimizer, criterion, objective_type, train_loader):
    """Run one optimisation pass over `train_loader`.

    Args:
        model: network to train; put in training mode here.
        optimizer: optimizer stepping the model parameters.
        criterion: loss applied to each (logits, targets) pair.
        objective_type: 'voice_cmd' (single head) or
            'voice_cmd__and__voice_cmd_lng' (both heads; the two losses are
            averaged).
        train_loader: yields (x, y_voice_cmd, y_voice_cmd_lng) batches.

    Returns:
        The sum of the per-batch loss values over the pass. This total used to
        be accumulated and discarded; callers that ignore the return value are
        unaffected.

    Raises:
        ValueError: if `objective_type` is not one of the supported objectives.
    """
    model.train()
    train_loss = 0

    for x, y_voice_cmd, y_voice_cmd_lng in train_loader:
        x = x.to(device)
        y_voice_cmd = y_voice_cmd.to(device)
        y_voice_cmd_lng = y_voice_cmd_lng.to(device)

        optimizer.zero_grad()
        outputs = model(x)

        if objective_type == 'voice_cmd':
            logits_voice_cmd = outputs
            loss = criterion(logits_voice_cmd, y_voice_cmd)
        elif objective_type == 'voice_cmd__and__voice_cmd_lng':
            logits_voice_cmd, logits_voice_cmd_lng = outputs
            # Equal-weight average of the command loss and the language loss.
            loss = (criterion(logits_voice_cmd, y_voice_cmd) + criterion(logits_voice_cmd_lng, y_voice_cmd_lng)) / 2
        else:
            raise ValueError(f"Unknown objective type: {objective_type}")

        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    return train_loss
        

def test(model, criterion, objective_type, loader, bias_category_labels):
    """Evaluate `model` on `loader` and compute overall and per-category accuracy.

    Args:
        model: network to evaluate; put in eval mode here.
        criterion: loss applied to each (logits, targets) pair.
        objective_type: 'voice_cmd' or 'voice_cmd__and__voice_cmd_lng'.
        loader: yields (x, y_voice_cmd, y_voice_cmd_lng) batches.
        bias_category_labels: dict of per-sample 0/1 lists, in the same order
            as the samples produced by `loader`; each list is used as the
            sample weights of one per-category accuracy.

    Returns:
        (n, average_loss, acc, acc_by_bias_category, acc_lng,
         acc_by_bias_category_lng) where `n` is the number of evaluated
        samples and `average_loss` is the accumulated loss divided by `n`. For
        the 'voice_cmd' objective the language accuracies are reported as -1.

    Raises:
        ValueError: if `objective_type` is not one of the supported objectives.
    """
    model.eval()
    accumulated_loss = 0

    pred_classes = []
    true_classes = []

    pred_classes_lng = []
    true_classes_lng = []

    # Evaluation only: disable autograd so no computation graph is built for
    # the forward passes.
    with torch.no_grad():
        for x, y_voice_cmd, y_voice_cmd_lng in loader:
            x = x.to(device)
            y_voice_cmd = y_voice_cmd.to(device)
            y_voice_cmd_lng = y_voice_cmd_lng.to(device)

            outputs = model(x)

            if objective_type == 'voice_cmd':
                logits_voice_cmd = outputs

                pred_classes.extend(
                    get_predictions_for_logits(logits_voice_cmd).cpu().numpy()
                )
                true_classes.extend(y_voice_cmd.cpu().numpy())

                loss = criterion(logits_voice_cmd, y_voice_cmd)
            elif objective_type == 'voice_cmd__and__voice_cmd_lng':
                logits_voice_cmd, logits_voice_cmd_lng = outputs
                pred_classes.extend(
                    get_predictions_for_logits(logits_voice_cmd).cpu().numpy()
                )
                true_classes.extend(y_voice_cmd.cpu().numpy())

                pred_classes_lng.extend(
                    get_predictions_for_logits(logits_voice_cmd_lng).cpu().numpy()
                )
                true_classes_lng.extend(y_voice_cmd_lng.cpu().numpy())

                loss = (criterion(logits_voice_cmd, y_voice_cmd) + criterion(logits_voice_cmd_lng, y_voice_cmd_lng)) / 2
            else:
                raise ValueError(f"Unknown objective type: {objective_type}")

            accumulated_loss += loss.item()

    n = len(true_classes)

    average_loss = accumulated_loss / n

    acc = sklearn.metrics.accuracy_score(true_classes, pred_classes)
    # The 0/1 membership lists act as sample weights, restricting each
    # accuracy to the samples of one category value.
    acc_by_bias_category = {
        category: sklearn.metrics.accuracy_score(true_classes, pred_classes, sample_weight=sw)
        for category, sw in bias_category_labels.items()
    }

    if objective_type == 'voice_cmd__and__voice_cmd_lng':
        acc_lng = sklearn.metrics.accuracy_score(true_classes_lng, pred_classes_lng)
        acc_by_bias_category_lng = {
            category: sklearn.metrics.accuracy_score(true_classes_lng, pred_classes_lng, sample_weight=sw)
            for category, sw in bias_category_labels.items()
        }
    else:
        # No language head: report -1 as a "not applicable" sentinel.
        acc_lng = -1
        acc_by_bias_category_lng = {category: -1 for category in bias_category_labels}

    return n, average_loss, acc, acc_by_bias_category, acc_lng, acc_by_bias_category_lng
      
        
def train_on_fold(model, fold_id, feature_name, objective_type, batch_size, epochs):
    """Train `model` on one fold and record metrics after every epoch.

    After each training pass the model is evaluated with test() on both the
    training loader and the test loader. A progress line is printed every 10th
    epoch.

    Returns:
        dict mapping epoch number (1..epochs) to a flat dict of metrics:
        sample counts, losses and accuracies for both splits, plus the
        per-bias-category accuracies and sample counts.
    """
    torch.manual_seed(0)

    loaders_and_labels = get_loaders_for_fold(fold_id, feature_name, batch_size)
    train_loader, test_loader, train_bias_category_labels, test_bias_category_labels = loaders_and_labels

    # Print the layer summary using a dummy batch of 10 all-zero sequences.
    dummy_batch = torch.zeros((10, MAX_FEATURE_SEQUENCE_LENGTH, 512)).to(device)
    print(summary(model, dummy_batch, show_input=False))
    print(f"train_n: {len(train_loader.dataset)}")
    print(f"test_n: {len(test_loader.dataset)}")

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(reduction='sum')

    results = {}
    for epoch in range(1, epochs + 1):
        # One optimisation pass over the training set.
        train(model, optimizer, criterion, objective_type, train_loader)

        # Metrics on the training set.
        (train_n, train_loss, train_acc,
         train_acc_by_category, train_acc_lng, train_acc_lng_by_category) = test(
            model, criterion, objective_type, train_loader, train_bias_category_labels)

        # Metrics on the held-out set.
        (test_n, test_loss, test_acc,
         test_acc_by_category, test_acc_lng, test_acc_lng_by_category) = test(
            model, criterion, objective_type, test_loader, test_bias_category_labels)

        if epoch % 10 == 0:
            print(f"Epoch: {epoch}. Train Loss: {train_loss:0.4}. Test Loss: {test_loss:0.4}. Train Acc: {train_acc:0.4}. Test Acc:{test_acc:0.4}")

        epoch_metrics = {
            'epoch': epoch,

            'train_n': train_n,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'train_acc_lng': train_acc_lng,

            'test_n': test_n,
            'test_loss': test_loss,
            'test_acc': test_acc,
            'test_acc_lng': test_acc_lng,
        }

        for category, accuracy in train_acc_by_category.items():
            epoch_metrics[f"train_acc_{category}"] = accuracy
            epoch_metrics[f"train_n_{category}"] = int(np.sum(train_bias_category_labels[category]))

        for category, accuracy in train_acc_lng_by_category.items():
            epoch_metrics[f"train_acc_lng_{category}"] = accuracy

        for category, accuracy in test_acc_by_category.items():
            epoch_metrics[f"test_acc_{category}"] = accuracy
            epoch_metrics[f"test_n_{category}"] = int(np.sum(test_bias_category_labels[category]))

        for category, accuracy in test_acc_lng_by_category.items():
            epoch_metrics[f"test_acc_lng_{category}"] = accuracy

        results[epoch] = epoch_metrics

    return results

In [18]:
import csv
from pathlib import Path

def results_exist(model_name, feature_name, fold_id):
    """Tell whether the results CSV for this model / feature / fold is already on disk."""
    csv_path = Path(f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_id}.csv")
    return csv_path.is_file()
    

def save_results(model_name, all_folds_results):
    """Write one CSV of per-epoch metrics for every entry in ``all_folds_results``.

    Each entry is a dict with the keys ``'feature_name'``, ``'fold_index'`` and
    ``'epochs'`` (a dict mapping an epoch number to a flat dict of metrics).
    The file goes to ``{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}.csv``
    with one row per epoch in ascending epoch order and columns sorted by name.

    Raises:
        ValueError: if an entry has no epochs, or (from ``csv.DictWriter`` with
            ``extrasaction='raise'``) if a later epoch contains a metric that
            is absent from the first epoch's row.
    """
    for result_entry in all_folds_results:
        feature_name = result_entry['feature_name']
        fold_index = result_entry['fold_index']
        epochs = result_entry['epochs']

        # Work out the header BEFORE touching the filesystem: if this fails we
        # must not leave an empty CSV behind, because the existence of the file
        # is what marks a trial as already done.
        epoch_ids = sorted(epochs.keys())
        if not epoch_ids:
            raise ValueError(
                f"no epoch results to save for feature {feature_name!r}, fold {fold_index!r}"
            )
        # Use the first recorded epoch for the column set instead of assuming
        # that an epoch numbered 1 exists.
        fieldnames = sorted(epochs[epoch_ids[0]].keys())

        fname = f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}.csv"
        # parents=True also creates RESULTS_DIR itself when it is missing.
        Path(fname).parent.mkdir(parents=True, exist_ok=True)

        # newline='' is required by the csv module so that it controls line
        # endings itself (otherwise blank rows appear on some platforms).
        with open(fname, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='raise')
            writer.writeheader()
            for epoch in epoch_ids:
                writer.writerow(epochs[epoch])

In [None]:
# Full grid of trials: every fold crossed with every hyper-parameter setting,
# feature set and objective, in the order they will be run.
trial_params = list(
    itertools.product(
        range(FOLD_COUNT),
        CONV_POOLING_TYPES,
        CONV_DROPOUT_PROBABILITIES,
        FC_DROPOUT_PROBABILITIES,
        FEATURE_NAMES,
        OBJECTIVE_TYPES,
    )
)

print("Plan:")
for planned_trial in trial_params:
    print(planned_trial)
print()

for fold_id, conv_pooling_type, conv_dropout_p, fc_dropout_p, feature_name, objective_type in trial_params:

    # The model name encodes the whole configuration; it is also the name of
    # the sub-directory the results for this configuration are saved under.
    model_name = (
        "ASRCNN"
        f"__conv_pool_{conv_pooling_type}"
        f"__conv_dp_{conv_dropout_p}"
        f"__fc_dp_{fc_dropout_p}"
        f"__fn_{feature_name}"
        f"__obj_{objective_type}"
    )

    # Resume support: a trial whose results CSV is already on disk is not re-run.
    if results_exist(model_name, feature_name, fold_id):
        print(f"skipping ({fold_id}, {conv_pooling_type}, {conv_dropout_p}, {fc_dropout_p}, {feature_name}, {objective_type})")
        continue

    # A fresh model is built for every trial.
    model = ASRCNN(
        conv_pooling_type,
        conv_dropout_p,
        fc_dropout_p,
        voice_cmd_neuron_count=voice_cmd_class_count,
        voice_cmd_lng_neuron_count=voice_cmd_lng_class_count,
        objective_type=objective_type,
    ).to(device)

    print(f"{model_name} using {feature_name} on fold#{fold_id}")

    epochs_results = train_on_fold(
        model,
        fold_id,
        feature_name,
        objective_type,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
    )

    # save_results expects a list of fold entries; here it holds just this fold.
    # (Writing per-epoch test logits is not done in this cell.)
    save_results(
        model_name,
        [{
            'fold_index': fold_id,
            'feature_name': feature_name,
            'epochs': epochs_results,
        }],
    )

    # Drop the reference before the next trial builds its own model.
    del model

Plan:
(0, 'max', 0.1, 0.1, 'wav2vec_features-c', 'voice_cmd')
(0, 'max', 0.1, 0.1, 'wav2vec_features-c', 'voice_cmd__and__voice_cmd_lng')
(0, 'max', 0.1, 0.1, 'wav2vec_features-z', 'voice_cmd')
(0, 'max', 0.1, 0.1, 'wav2vec_features-z', 'voice_cmd__and__voice_cmd_lng')
(0, 'max', 0.1, 0.1, 'retrained-wav2vec_features-c', 'voice_cmd')
(0, 'max', 0.1, 0.1, 'retrained-wav2vec_features-c', 'voice_cmd__and__voice_cmd_lng')
(0, 'max', 0.1, 0.1, 'retrained-wav2vec_features-z', 'voice_cmd')
(0, 'max', 0.1, 0.1, 'retrained-wav2vec_features-z', 'voice_cmd__and__voice_cmd_lng')
(1, 'max', 0.1, 0.1, 'wav2vec_features-c', 'voice_cmd')
(1, 'max', 0.1, 0.1, 'wav2vec_features-c', 'voice_cmd__and__voice_cmd_lng')
(1, 'max', 0.1, 0.1, 'wav2vec_features-z', 'voice_cmd')
(1, 'max', 0.1, 0.1, 'wav2vec_features-z', 'voice_cmd__and__voice_cmd_lng')
(1, 'max', 0.1, 0.1, 'retrained-wav2vec_features-c', 'voice_cmd')
(1, 'max', 0.1, 0.1, 'retrained-wav2vec_features-c', 'voice_cmd__and__voice_cmd_lng')
(1, 'max',

Epoch: 450. Train Loss: 0.04048. Test Loss: 1.135. Train Acc: 0.9953. Test Acc:0.7831
Epoch: 460. Train Loss: 0.03851. Test Loss: 1.132. Train Acc: 0.9963. Test Acc:0.7867
Epoch: 470. Train Loss: 0.03794. Test Loss: 1.156. Train Acc: 0.9952. Test Acc:0.7827
Epoch: 480. Train Loss: 0.0434. Test Loss: 1.175. Train Acc: 0.9925. Test Acc:0.7733
Epoch: 490. Train Loss: 0.02869. Test Loss: 1.163. Train Acc: 0.9978. Test Acc:0.7867
Epoch: 500. Train Loss: 0.03276. Test Loss: 1.174. Train Acc: 0.997. Test Acc:0.7791
Epoch: 510. Train Loss: 0.03528. Test Loss: 1.171. Train Acc: 0.9955. Test Acc:0.7769
Epoch: 520. Train Loss: 0.02681. Test Loss: 1.184. Train Acc: 0.9972. Test Acc:0.7805
Epoch: 530. Train Loss: 0.03155. Test Loss: 1.2. Train Acc: 0.9967. Test Acc:0.7715
Epoch: 540. Train Loss: 0.02142. Test Loss: 1.178. Train Acc: 0.9987. Test Acc:0.7885
Epoch: 550. Train Loss: 0.02521. Test Loss: 1.193. Train Acc: 0.9972. Test Acc:0.7863
Epoch: 560. Train Loss: 0.0225. Test Loss: 1.194. Train Ac

Epoch: 210. Train Loss: 0.3293. Test Loss: 0.97. Train Acc: 0.9377. Test Acc:0.7336
Epoch: 220. Train Loss: 0.2984. Test Loss: 0.9476. Train Acc: 0.9501. Test Acc:0.743
Epoch: 230. Train Loss: 0.2904. Test Loss: 0.951. Train Acc: 0.95. Test Acc:0.743
Epoch: 240. Train Loss: 0.269. Test Loss: 0.9556. Train Acc: 0.9551. Test Acc:0.7367
Epoch: 250. Train Loss: 0.2598. Test Loss: 0.9402. Train Acc: 0.9575. Test Acc:0.7421
Epoch: 260. Train Loss: 0.2474. Test Loss: 0.9417. Train Acc: 0.9566. Test Acc:0.743
Epoch: 270. Train Loss: 0.232. Test Loss: 0.9497. Train Acc: 0.9616. Test Acc:0.7434
Epoch: 280. Train Loss: 0.2391. Test Loss: 0.967. Train Acc: 0.9593. Test Acc:0.7291
Epoch: 290. Train Loss: 0.2183. Test Loss: 0.9446. Train Acc: 0.9614. Test Acc:0.7367
Epoch: 300. Train Loss: 0.2078. Test Loss: 0.9824. Train Acc: 0.9636. Test Acc:0.7385
Epoch: 310. Train Loss: 0.2019. Test Loss: 0.9631. Train Acc: 0.9673. Test Acc:0.7381
Epoch: 320. Train Loss: 0.1985. Test Loss: 0.9808. Train Acc: 0.9

Epoch: 10. Train Loss: 4.012. Test Loss: 4.066. Train Acc: 0.054. Test Acc:0.04328
Epoch: 20. Train Loss: 3.307. Test Loss: 3.428. Train Acc: 0.1437. Test Acc:0.1272
Epoch: 30. Train Loss: 2.894. Test Loss: 3.089. Train Acc: 0.2494. Test Acc:0.2044
Epoch: 40. Train Loss: 2.395. Test Loss: 2.695. Train Acc: 0.3785. Test Acc:0.2918
Epoch: 50. Train Loss: 1.998. Test Loss: 2.389. Train Acc: 0.487. Test Acc:0.3775
Epoch: 60. Train Loss: 1.68. Test Loss: 2.123. Train Acc: 0.5758. Test Acc:0.4498
Epoch: 70. Train Loss: 1.431. Test Loss: 1.925. Train Acc: 0.6338. Test Acc:0.502
Epoch: 80. Train Loss: 1.189. Test Loss: 1.668. Train Acc: 0.709. Test Acc:0.5716
Epoch: 90. Train Loss: 1.043. Test Loss: 1.537. Train Acc: 0.7433. Test Acc:0.5988
Epoch: 100. Train Loss: 0.9003. Test Loss: 1.405. Train Acc: 0.7836. Test Acc:0.6372
Epoch: 110. Train Loss: 0.8208. Test Loss: 1.353. Train Acc: 0.8038. Test Acc:0.6484
Epoch: 120. Train Loss: 0.759. Test Loss: 1.303. Train Acc: 0.8141. Test Acc:0.6609
Epo

Epoch: 980. Train Loss: 0.02882. Test Loss: 1.227. Train Acc: 0.9968. Test Acc:0.7671
Epoch: 990. Train Loss: 0.03591. Test Loss: 1.265. Train Acc: 0.9944. Test Acc:0.7657
Epoch: 1000. Train Loss: 0.03307. Test Loss: 1.228. Train Acc: 0.9947. Test Acc:0.7715
ASRCNN__conv_pool_max__conv_dp_0.1__fc_dp_0.1__fn_wav2vec_features-z__obj_voice_cmd__and__voice_cmd_lng using wav2vec_features-z on fold#0
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Conv1d-1        [10, 8, 200]           4,104           4,104
          Conv1d-2        [10, 8, 198]             200             200
         Dropout-3        [10, 8, 198]               0               0
       MaxPool1d-4         [10, 8, 99]               0               0
          Conv1d-5        [10, 16, 97]             400             400
         Dropout-6        [10, 16, 97]               0               0
       MaxPool1d-7        [10, 16

Epoch: 750. Train Loss: 0.1297. Test Loss: 1.095. Train Acc: 0.9754. Test Acc:0.7153
Epoch: 760. Train Loss: 0.1217. Test Loss: 1.091. Train Acc: 0.9819. Test Acc:0.7122
Epoch: 770. Train Loss: 0.1186. Test Loss: 1.092. Train Acc: 0.9804. Test Acc:0.718
Epoch: 780. Train Loss: 0.121. Test Loss: 1.083. Train Acc: 0.9831. Test Acc:0.7175
Epoch: 790. Train Loss: 0.1146. Test Loss: 1.098. Train Acc: 0.9837. Test Acc:0.7158
Epoch: 800. Train Loss: 0.107. Test Loss: 1.098. Train Acc: 0.9847. Test Acc:0.7242
Epoch: 810. Train Loss: 0.1044. Test Loss: 1.105. Train Acc: 0.9829. Test Acc:0.7202
Epoch: 820. Train Loss: 0.1167. Test Loss: 1.119. Train Acc: 0.9797. Test Acc:0.7126
Epoch: 830. Train Loss: 0.1061. Test Loss: 1.116. Train Acc: 0.9839. Test Acc:0.7224
Epoch: 840. Train Loss: 0.1053. Test Loss: 1.121. Train Acc: 0.9827. Test Acc:0.7166
Epoch: 850. Train Loss: 0.09664. Test Loss: 1.115. Train Acc: 0.9855. Test Acc:0.7251
Epoch: 860. Train Loss: 0.09791. Test Loss: 1.122. Train Acc: 0.985

Epoch: 530. Train Loss: 0.1212. Test Loss: 1.282. Train Acc: 0.9732. Test Acc:0.7238
Epoch: 540. Train Loss: 0.1053. Test Loss: 1.248. Train Acc: 0.9787. Test Acc:0.73
Epoch: 550. Train Loss: 0.09747. Test Loss: 1.26. Train Acc: 0.9816. Test Acc:0.7323
Epoch: 560. Train Loss: 0.09696. Test Loss: 1.267. Train Acc: 0.9804. Test Acc:0.734
Epoch: 570. Train Loss: 0.092. Test Loss: 1.273. Train Acc: 0.9831. Test Acc:0.7314
Epoch: 580. Train Loss: 0.08758. Test Loss: 1.263. Train Acc: 0.9831. Test Acc:0.7318
Epoch: 590. Train Loss: 0.08386. Test Loss: 1.278. Train Acc: 0.9835. Test Acc:0.7327
Epoch: 600. Train Loss: 0.07617. Test Loss: 1.286. Train Acc: 0.9872. Test Acc:0.7358
Epoch: 610. Train Loss: 0.07372. Test Loss: 1.269. Train Acc: 0.9875. Test Acc:0.7354
Epoch: 620. Train Loss: 0.07133. Test Loss: 1.287. Train Acc: 0.9869. Test Acc:0.7407
Epoch: 630. Train Loss: 0.06833. Test Loss: 1.294. Train Acc: 0.9877. Test Acc:0.7367
Epoch: 640. Train Loss: 0.0657. Test Loss: 1.288. Train Acc: 0

Epoch: 290. Train Loss: 0.3336. Test Loss: 1.107. Train Acc: 0.9337. Test Acc:0.7006
Epoch: 300. Train Loss: 0.3119. Test Loss: 1.077. Train Acc: 0.9405. Test Acc:0.7082
Epoch: 310. Train Loss: 0.2942. Test Loss: 1.079. Train Acc: 0.9516. Test Acc:0.7024
Epoch: 320. Train Loss: 0.2908. Test Loss: 1.07. Train Acc: 0.9472. Test Acc:0.7117
Epoch: 330. Train Loss: 0.2847. Test Loss: 1.074. Train Acc: 0.9482. Test Acc:0.7086
Epoch: 340. Train Loss: 0.2778. Test Loss: 1.101. Train Acc: 0.9477. Test Acc:0.7077
Epoch: 350. Train Loss: 0.2614. Test Loss: 1.093. Train Acc: 0.9581. Test Acc:0.7166
Epoch: 360. Train Loss: 0.2503. Test Loss: 1.076. Train Acc: 0.9585. Test Acc:0.7202
Epoch: 370. Train Loss: 0.2414. Test Loss: 1.082. Train Acc: 0.959. Test Acc:0.7207
Epoch: 380. Train Loss: 0.2475. Test Loss: 1.06. Train Acc: 0.9595. Test Acc:0.7153
Epoch: 390. Train Loss: 0.2479. Test Loss: 1.056. Train Acc: 0.9578. Test Acc:0.7282
Epoch: 400. Train Loss: 0.2264. Test Loss: 1.04. Train Acc: 0.9651. 

Epoch: 70. Train Loss: 1.345. Test Loss: 1.844. Train Acc: 0.6555. Test Acc:0.5216
Epoch: 80. Train Loss: 1.183. Test Loss: 1.736. Train Acc: 0.6996. Test Acc:0.548
Epoch: 90. Train Loss: 1.046. Test Loss: 1.614. Train Acc: 0.7363. Test Acc:0.5774
Epoch: 100. Train Loss: 0.9367. Test Loss: 1.531. Train Acc: 0.767. Test Acc:0.6015
Epoch: 110. Train Loss: 0.8439. Test Loss: 1.45. Train Acc: 0.7921. Test Acc:0.6234
Epoch: 120. Train Loss: 0.7864. Test Loss: 1.417. Train Acc: 0.8074. Test Acc:0.627
Epoch: 130. Train Loss: 0.7409. Test Loss: 1.409. Train Acc: 0.8174. Test Acc:0.6381
Epoch: 140. Train Loss: 0.6702. Test Loss: 1.341. Train Acc: 0.8363. Test Acc:0.6573
Epoch: 150. Train Loss: 0.6526. Test Loss: 1.355. Train Acc: 0.836. Test Acc:0.6515
Epoch: 160. Train Loss: 0.6207. Test Loss: 1.341. Train Acc: 0.8438. Test Acc:0.6568
Epoch: 170. Train Loss: 0.5775. Test Loss: 1.303. Train Acc: 0.8571. Test Acc:0.6689
Epoch: 180. Train Loss: 0.5393. Test Loss: 1.291. Train Acc: 0.8669. Test Ac

Epoch: 10. Train Loss: 2.616. Test Loss: 2.755. Train Acc: 0.08043. Test Acc:0.07318
Epoch: 20. Train Loss: 2.013. Test Loss: 2.284. Train Acc: 0.2832. Test Acc:0.2084
Epoch: 30. Train Loss: 1.63. Test Loss: 1.927. Train Acc: 0.4598. Test Acc:0.3597
Epoch: 40. Train Loss: 1.345. Test Loss: 1.65. Train Acc: 0.5788. Test Acc:0.4623
Epoch: 50. Train Loss: 1.148. Test Loss: 1.476. Train Acc: 0.665. Test Acc:0.5399
Epoch: 60. Train Loss: 1.014. Test Loss: 1.388. Train Acc: 0.723. Test Acc:0.5828
Epoch: 70. Train Loss: 0.9176. Test Loss: 1.313. Train Acc: 0.7669. Test Acc:0.6145
Epoch: 80. Train Loss: 0.8322. Test Loss: 1.261. Train Acc: 0.7993. Test Acc:0.6359
Epoch: 90. Train Loss: 0.7722. Test Loss: 1.222. Train Acc: 0.8174. Test Acc:0.6484
Epoch: 100. Train Loss: 0.7214. Test Loss: 1.205. Train Acc: 0.838. Test Acc:0.66
Epoch: 110. Train Loss: 0.6821. Test Loss: 1.174. Train Acc: 0.8456. Test Acc:0.6546
Epoch: 120. Train Loss: 0.6397. Test Loss: 1.141. Train Acc: 0.8629. Test Acc:0.664
E

Epoch: 980. Train Loss: 0.04673. Test Loss: 1.277. Train Acc: 0.9924. Test Acc:0.7122
Epoch: 990. Train Loss: 0.04034. Test Loss: 1.262. Train Acc: 0.9958. Test Acc:0.718
Epoch: 1000. Train Loss: 0.03937. Test Loss: 1.29. Train Acc: 0.9963. Test Acc:0.7193
ASRCNN__conv_pool_max__conv_dp_0.1__fc_dp_0.1__fn_wav2vec_features-c__obj_voice_cmd using wav2vec_features-c on fold#1
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Conv1d-1        [10, 8, 200]           4,104           4,104
          Conv1d-2        [10, 8, 198]             200             200
         Dropout-3        [10, 8, 198]               0               0
       MaxPool1d-4         [10, 8, 99]               0               0
          Conv1d-5        [10, 16, 97]             400             400
         Dropout-6        [10, 16, 97]               0               0
       MaxPool1d-7        [10, 16, 48]               0 

Epoch: 760. Train Loss: 0.02631. Test Loss: 1.314. Train Acc: 0.9975. Test Acc:0.7477
Epoch: 770. Train Loss: 0.02634. Test Loss: 1.254. Train Acc: 0.997. Test Acc:0.7577
Epoch: 780. Train Loss: 0.02709. Test Loss: 1.324. Train Acc: 0.9965. Test Acc:0.7516
Epoch: 790. Train Loss: 0.03205. Test Loss: 1.298. Train Acc: 0.9951. Test Acc:0.7581
Epoch: 800. Train Loss: 0.03374. Test Loss: 1.355. Train Acc: 0.9941. Test Acc:0.7529
Epoch: 810. Train Loss: 0.02352. Test Loss: 1.295. Train Acc: 0.9978. Test Acc:0.7615
Epoch: 820. Train Loss: 0.01802. Test Loss: 1.313. Train Acc: 0.9988. Test Acc:0.7637
Epoch: 830. Train Loss: 0.01691. Test Loss: 1.318. Train Acc: 0.9988. Test Acc:0.7697
Epoch: 840. Train Loss: 0.01419. Test Loss: 1.334. Train Acc: 0.9997. Test Acc:0.7598
Epoch: 850. Train Loss: 0.02019. Test Loss: 1.367. Train Acc: 0.997. Test Acc:0.7598
Epoch: 860. Train Loss: 0.01706. Test Loss: 1.326. Train Acc: 0.9985. Test Acc:0.7572
Epoch: 870. Train Loss: 0.01571. Test Loss: 1.329. Train

Epoch: 520. Train Loss: 0.08711. Test Loss: 1.061. Train Acc: 0.9875. Test Acc:0.737
Epoch: 530. Train Loss: 0.08615. Test Loss: 1.072. Train Acc: 0.9875. Test Acc:0.7322
Epoch: 540. Train Loss: 0.08026. Test Loss: 1.069. Train Acc: 0.987. Test Acc:0.7322
Epoch: 550. Train Loss: 0.07655. Test Loss: 1.064. Train Acc: 0.9911. Test Acc:0.7404
Epoch: 560. Train Loss: 0.08143. Test Loss: 1.086. Train Acc: 0.9854. Test Acc:0.7262
Epoch: 570. Train Loss: 0.07958. Test Loss: 1.078. Train Acc: 0.9887. Test Acc:0.7374
Epoch: 580. Train Loss: 0.0834. Test Loss: 1.096. Train Acc: 0.986. Test Acc:0.7322
Epoch: 590. Train Loss: 0.07429. Test Loss: 1.082. Train Acc: 0.9884. Test Acc:0.7236
Epoch: 600. Train Loss: 0.07568. Test Loss: 1.112. Train Acc: 0.985. Test Acc:0.7365
Epoch: 610. Train Loss: 0.07422. Test Loss: 1.096. Train Acc: 0.9875. Test Acc:0.7395
Epoch: 620. Train Loss: 0.07328. Test Loss: 1.088. Train Acc: 0.9874. Test Acc:0.7456
Epoch: 630. Train Loss: 0.07308. Test Loss: 1.138. Train Ac