In [1]:
from pathlib import Path
import csv
import itertools
import numpy as np
from pprint import PrettyPrinter


import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_model_summary import summary
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

import h5py

import sklearn
import sklearn.metrics

pprint = PrettyPrinter()

## Configuration

In [2]:
# Only metadata rows whose 'language' column is in this set are loaded.
SELECTED_LANGUAGES = {'_language_independent', 'francais', 'maninka', 'pular', 'susu'}

# Root of the dataset; feature directories and the annotation CSV live below it.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
BASE_DIR = Path('/media/xtrem/data/datasets/taibou_annotations/2020-08-22_02/guinean_virtual_assistant_speech_recognition/')


# CSV with one row per annotated audio segment.
ANNOTATIONS_PATH = BASE_DIR/ "annotated_segments" / "metadata.csv"


# Each name is a sub-directory of BASE_DIR holding one feature file per segment.
FEATURE_NAMES = [
    "wav2vec_features-c", 
    "wav2vec_features-z", 
    "retrained-wav2vec_features-c", 
    "retrained-wav2vec_features-z"
]


# Fraction of (speaker, language) pairs assigned to the training side of each fold.
TRAIN_PERCENT = .8
# Number of independently sampled train/test splits.
FOLD_COUNT = 10

# Dropout probabilities used by the models: after each conv stage / before the linear layer.
CONV_DROPOUT_P = 0.7
FC_DROPOUT_P = 0.5

# Output directory; the dropout settings are encoded in the name.
RESULTS_DIR = f'results_101_gn_lang_classification__conv_dropout_{CONV_DROPOUT_P}__fc_dropout_{FC_DROPOUT_P}'


# Index of the CUDA device used for training.
GPU_ID = 0
# Every feature sequence is zero-padded to this many frames.
MAX_FEATURE_SEQUENCE_LENGTH = 200

## Load & shuffle metadata records

In [3]:
def load_metadata():
    """Yield the annotation rows whose language is in ``SELECTED_LANGUAGES``.

    Reads the CSV at ``ANNOTATIONS_PATH`` lazily with ``csv.DictReader``, so
    each yielded item is a dict keyed by the CSV header.

    Yields:
        dict: one metadata row per selected audio segment, in file order.
    """
    # newline='' is what the csv module expects: it lets the reader handle line
    # endings itself so newlines embedded in quoted fields survive intact.
    with open(ANNOTATIONS_PATH, newline='') as f:
        for row in csv.DictReader(f):
            if row['language'] in SELECTED_LANGUAGES:
                yield row

In [4]:
# Materialise the filtered metadata rows (load_metadata is a generator).
metadata_records = list(load_metadata())

# Shuffle in place with a dedicated, fixed-seed RNG so the record order is
# reproducible and independent of NumPy's global random state.
metadata_shuffler_rs = np.random.RandomState(seed=42)
metadata_shuffler_rs.shuffle(metadata_records)

In [5]:
# Metadata columns along which accuracy is later broken down.
bias_category_fields = ["device_id", "language", "speaker_gender", "speaker_mothertongue"]

# For every such column, the sorted list of distinct values seen in the data.
bias_categories = {
    field: sorted({record[field] for record in metadata_records})
    for field in bias_category_fields
}

for category, category_values in bias_categories.items():
    print(f"\n{category}: \n\t{','.join(category_values)}")


device_id: 
	d001,d002

language: 
	_language_independent,francais,maninka,pular,susu

speaker_gender: 
	F,M

speaker_mothertongue: 
	maninka,pular,susu


### Labels

In [6]:
# Label vocabulary: distinct labels in sorted order, each mapped to its index.
asr_class_names = sorted({r['label'] for r in metadata_records})
asr_class_count = len(asr_class_names)
asr_class_id_by_name = {c:i for i, c in enumerate(asr_class_names)}

print("Classes")
# A plain loop rather than a trailing list comprehension: the comprehension was
# the cell's last expression, so the notebook rendered a long list of `None`.
for class_name, class_id in asr_class_id_by_name.items():
    print(f"{class_id:4}: {class_name}")

Classes
   0: 101_wake_word__francais
   1: 101_wake_word__maninka
   2: 101_wake_word__pular
   3: 101_wake_word__susu
   4: 201_add_contact__francais
   5: 201_add_contact__maninka
   6: 201_add_contact__pular
   7: 201_add_contact__susu
   8: 202_search_contact__francais
   9: 202_search_contact__maninka
  10: 202_search_contact__pular
  11: 202_search_contact__susu
  12: 203_update_contact__francais
  13: 203_update_contact__maninka
  14: 203_update_contact__pular
  15: 203_update_contact__susu
  16: 204_delete_contact__francais
  17: 204_delete_contact__maninka
  18: 204_delete_contact__pular
  19: 204_delete_contact__susu
  20: 205_call_contact__francais
  21: 205_call_contact__maninka
  22: 205_call_contact__pular
  23: 205_call_contact__susu
  24: 206_yes__francais
  25: 206_yes__maninka
  26: 206_yes__pular
  27: 206_yes__susu
  28: 207_no__francais
  29: 207_no__maninka
  30: 207_no__pular
  31: 207_no__susu
  32: 301_zero__francais
  33: 301_zero__maninka
  34: 301_zero__pul

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

### Inspect metadata

In [7]:
def count_by_attribute(records, attribute_names):
    """Count records for every combination of observed attribute values.

    The combinations are the cartesian product of the distinct values seen for
    each attribute, so combinations that never occur together are still
    reported, with a count of 0.

    Args:
        records: iterable of mappings (e.g. metadata rows).
        attribute_names: keys to group by, in the order they should appear in
            each result tuple.

    Yields:
        tuple: ``(attribute_values, count)`` where ``attribute_values`` is a
        tuple aligned with ``attribute_names``; results are yielded in sorted
        order of ``attribute_values``.
    """
    from collections import Counter

    # Materialise once so one-shot iterables (generators) can be scanned twice.
    records = list(records)

    observed_values = [{r[name] for r in records} for name in attribute_names]

    # Single pass over the records instead of re-scanning them for every
    # combination of values.
    combination_counts = Counter(
        tuple(r[name] for name in attribute_names) for r in records
    )

    for attribute_values in sorted(itertools.product(*observed_values)):
        # Counter returns 0 for combinations that never occurred.
        yield (attribute_values, combination_counts[attribute_values])

In [8]:
# (heading, attributes to group by) for every breakdown printed below.
attribute_reports = [
    ("RECORDS BY DEVICE", ['device_id']),
    ("RECORDS BY LANGUAGE", ['language']),
    ("RECORDS BY GENDER", ['speaker_gender']),
    ("RECORDS BY AGE", ['speaker_age']),
    ("RECORDS BY SPEAKER", ['speaker_id']),
    ("RECORDS BY SPEAKER BY LANGUAGE", ['speaker_id', 'language']),
    # Previously headed "RECORDS BY SPEAKER BY LABEL" although the grouping is
    # by label only.
    ("RECORDS BY LABEL", ['label']),
]

for heading, attribute_names in attribute_reports:
    print(heading)
    for row in sorted(count_by_attribute(metadata_records, attribute_names)):
        print(f"\t{row}")
    print("")

RECORDS BY DEVICE
	(('d001',), 2716)
	(('d002',), 2572)

RECORDS BY LANGUAGE
	(('_language_independent',), 1948)
	(('francais',), 820)
	(('maninka',), 866)
	(('pular',), 606)
	(('susu',), 1048)

RECORDS BY GENDER
	(('F',), 1837)
	(('M',), 3451)

RECORDS BY AGE
	(('12',), 158)
	(('13',), 168)
	(('15',), 402)
	(('17',), 700)
	(('18',), 527)
	(('19',), 176)
	(('20',), 176)
	(('27',), 170)
	(('28',), 122)
	(('29',), 158)
	(('31',), 170)
	(('32',), 194)
	(('33',), 122)
	(('34',), 43)
	(('35',), 332)
	(('37',), 206)
	(('38',), 294)
	(('43',), 122)
	(('44',), 315)
	(('5',), 43)
	(('55',), 122)
	(('61',), 240)
	(('63',), 158)
	(('67',), 170)

RECORDS BY SPEAKER
	(('s001',), 122)
	(('s002',), 86)
	(('s003',), 122)
	(('s004',), 122)
	(('s005',), 86)
	(('s006',), 86)
	(('s007',), 122)
	(('s008',), 43)
	(('s009',), 158)
	(('s010',), 176)
	(('s011',), 43)
	(('s012',), 122)
	(('s013',), 43)
	(('s014',), 158)
	(('s015',), 194)
	(('s016',), 122)
	(('s017',), 158)
	(('s018',), 158)
	(('s019',), 158)
	(

## Prepare Cross Validation Folds
- Partition by (speaker, language)
- Each (speaker, language) pair corresponds to `utterance_count * device_count` records.
- For each fold, all `utterance_count * device_count` records for the same speaker in the same language are in either the TRAIN or the TEST set, but never both.


In [9]:
def generate_train_test_records_per_fold(all_records, train_percent=None, fold_count=None):
    """Build per-fold train/test splits grouped by (speaker, language).

    Every fold draws its own random subset of (speaker_id, language) pairs for
    training, using ``RandomState(seed=fold_index)``; the remaining pairs form
    the test side. All records of one pair therefore land on the same side of
    a given fold. Folds are sampled independently of each other.

    Args:
        all_records: sequence of metadata rows with 'speaker_id' and
            'language' keys. It is iterated more than once.
        train_percent: fraction of pairs used for training (rounded up).
            Defaults to the module-level ``TRAIN_PERCENT``.
        fold_count: number of folds. Defaults to the module-level
            ``FOLD_COUNT``.

    Returns:
        dict: ``{fold_index: {"train_records": [...], "test_records": [...]}}``
        with records kept in their original relative order.
    """
    if train_percent is None:
        train_percent = TRAIN_PERCENT
    if fold_count is None:
        fold_count = FOLD_COUNT

    # Sorted so that index -> pair is stable, which keeps the seeded sampling
    # reproducible regardless of set iteration order.
    all_speaker_languages = sorted({(r['speaker_id'], r['language']) for r in all_records})

    sl_count = len(all_speaker_languages)
    all_sl_indices = range(sl_count)
    train_sl_count = int(np.ceil(sl_count * train_percent))

    records_per_fold = {}
    for fold_index in range(fold_count):
        fold_rsampler = np.random.RandomState(seed=fold_index)

        train_sl_index_set = set(fold_rsampler.choice(all_sl_indices, train_sl_count, replace=False))
        train_sl_set = {all_speaker_languages[i] for i in train_sl_index_set}

        # Every record's pair is in all_speaker_languages, so "not in train"
        # is exactly the test side; one pass fills both lists.
        train_records = []
        test_records = []
        for r in all_records:
            if (r['speaker_id'], r['language']) in train_sl_set:
                train_records.append(r)
            else:
                test_records.append(r)

        records_per_fold[fold_index] = {
            "train_records": train_records,
            "test_records": test_records
        }

    return records_per_fold

In [10]:
# Build the train/test record lists for every fold once; later cells index this by fold.
records_per_fold = generate_train_test_records_per_fold(metadata_records)

### Inspect Folds

In [11]:
# Per fold, list the (speaker, language) pairs present on each side of the split.
for fold_index in range(FOLD_COUNT):
    fold_records = records_per_fold[fold_index]
    train_records = fold_records["train_records"]
    test_records = fold_records["test_records"]

    print(f"Fold {fold_index} -- TRAIN")
    for row in sorted(count_by_attribute(train_records, ['speaker_id', 'language'])):
        # Skip combinations that do not occur on this side.
        if row[1] > 0:
            print(f"\t{row}")

    print(f"Fold {fold_index} -- TEST")
    for row in sorted(count_by_attribute(test_records, ['speaker_id', 'language'])):
        if row[1] > 0:
            print(f"\t{row}")

    print("---------------------")

Fold 0 -- TRAIN
	(('s001', '_language_independent'), 50)
	(('s001', 'maninka'), 36)
	(('s001', 'susu'), 36)
	(('s002', '_language_independent'), 50)
	(('s002', 'maninka'), 36)
	(('s003', '_language_independent'), 50)
	(('s003', 'maninka'), 36)
	(('s003', 'susu'), 36)
	(('s004', '_language_independent'), 50)
	(('s004', 'susu'), 36)
	(('s005', '_language_independent'), 50)
	(('s006', '_language_independent'), 50)
	(('s006', 'maninka'), 36)
	(('s007', '_language_independent'), 50)
	(('s007', 'pular'), 36)
	(('s007', 'susu'), 36)
	(('s008', '_language_independent'), 25)
	(('s008', 'maninka'), 18)
	(('s009', '_language_independent'), 50)
	(('s009', 'pular'), 36)
	(('s009', 'susu'), 36)
	(('s010', '_language_independent'), 50)
	(('s010', 'francais'), 36)
	(('s010', 'maninka'), 36)
	(('s010', 'pular'), 36)
	(('s010', 'susu'), 18)
	(('s011', '_language_independent'), 25)
	(('s011', 'susu'), 18)
	(('s012', '_language_independent'), 50)
	(('s012', 'francais'), 36)
	(('s012', 'susu'), 36)
	(('s01

	(('s001', 'maninka'), 36)
	(('s001', 'susu'), 36)
	(('s002', 'maninka'), 36)
	(('s003', '_language_independent'), 50)
	(('s003', 'maninka'), 36)
	(('s003', 'susu'), 36)
	(('s004', '_language_independent'), 50)
	(('s004', 'maninka'), 36)
	(('s005', '_language_independent'), 50)
	(('s005', 'susu'), 36)
	(('s006', '_language_independent'), 50)
	(('s007', '_language_independent'), 50)
	(('s007', 'pular'), 36)
	(('s007', 'susu'), 36)
	(('s008', '_language_independent'), 25)
	(('s009', 'susu'), 36)
	(('s010', 'francais'), 36)
	(('s010', 'pular'), 36)
	(('s010', 'susu'), 18)
	(('s011', '_language_independent'), 25)
	(('s011', 'susu'), 18)
	(('s012', '_language_independent'), 50)
	(('s012', 'francais'), 36)
	(('s012', 'susu'), 36)
	(('s013', '_language_independent'), 25)
	(('s013', 'susu'), 18)
	(('s014', '_language_independent'), 50)
	(('s014', 'francais'), 36)
	(('s015', '_language_independent'), 50)
	(('s015', 'maninka'), 36)
	(('s015', 'pular'), 36)
	(('s016', '_language_independent'), 50

	(('s001', '_language_independent'), 50)
	(('s002', '_language_independent'), 50)
	(('s002', 'maninka'), 36)
	(('s003', '_language_independent'), 50)
	(('s003', 'maninka'), 36)
	(('s003', 'susu'), 36)
	(('s004', 'maninka'), 36)
	(('s005', '_language_independent'), 50)
	(('s005', 'susu'), 36)
	(('s006', '_language_independent'), 50)
	(('s006', 'maninka'), 36)
	(('s007', '_language_independent'), 50)
	(('s007', 'pular'), 36)
	(('s007', 'susu'), 36)
	(('s008', '_language_independent'), 25)
	(('s008', 'maninka'), 18)
	(('s009', '_language_independent'), 50)
	(('s009', 'francais'), 36)
	(('s009', 'pular'), 36)
	(('s009', 'susu'), 36)
	(('s010', '_language_independent'), 50)
	(('s010', 'pular'), 36)
	(('s010', 'susu'), 18)
	(('s011', '_language_independent'), 25)
	(('s011', 'susu'), 18)
	(('s012', '_language_independent'), 50)
	(('s012', 'francais'), 36)
	(('s013', '_language_independent'), 25)
	(('s013', 'susu'), 18)
	(('s014', '_language_independent'), 50)
	(('s014', 'francais'), 36)
	(('s

	(('s001', 'maninka'), 36)
	(('s001', 'susu'), 36)
	(('s002', '_language_independent'), 50)
	(('s002', 'maninka'), 36)
	(('s003', '_language_independent'), 50)
	(('s003', 'maninka'), 36)
	(('s003', 'susu'), 36)
	(('s004', '_language_independent'), 50)
	(('s004', 'maninka'), 36)
	(('s004', 'susu'), 36)
	(('s005', '_language_independent'), 50)
	(('s006', '_language_independent'), 50)
	(('s006', 'maninka'), 36)
	(('s007', '_language_independent'), 50)
	(('s007', 'pular'), 36)
	(('s007', 'susu'), 36)
	(('s008', 'maninka'), 18)
	(('s009', '_language_independent'), 50)
	(('s009', 'francais'), 36)
	(('s009', 'susu'), 36)
	(('s010', '_language_independent'), 50)
	(('s010', 'francais'), 36)
	(('s010', 'maninka'), 36)
	(('s010', 'pular'), 36)
	(('s010', 'susu'), 18)
	(('s011', '_language_independent'), 25)
	(('s011', 'susu'), 18)
	(('s012', '_language_independent'), 50)
	(('s012', 'francais'), 36)
	(('s013', '_language_independent'), 25)
	(('s013', 'susu'), 18)
	(('s014', '_language_independent'

In [12]:
# Per fold, show the size of each side and its breakdown by language.
for fold_index in range(FOLD_COUNT):
    fold_records = records_per_fold[fold_index]
    train_records = fold_records["train_records"]
    test_records = fold_records["test_records"]

    print(f"Fold {fold_index} -- TRAIN: ({len(train_records)})")
    for row in sorted(count_by_attribute(train_records, ['language'])):
        # Skip languages with no record on this side.
        if row[1] > 0:
            print(f"\t{row}")

    print(f"Fold {fold_index} -- TEST: ({len(test_records)})")
    for row in sorted(count_by_attribute(test_records, ['language'])):
        if row[1] > 0:
            print(f"\t{row}")

    print("---------------------")

Fold 0 -- TRAIN: (4302)
	(('_language_independent',), 1700)
	(('francais',), 604)
	(('maninka',), 682)
	(('pular',), 568)
	(('susu',), 748)
Fold 0 -- TEST: (986)
	(('_language_independent',), 248)
	(('francais',), 216)
	(('maninka',), 184)
	(('pular',), 38)
	(('susu',), 300)
---------------------
Fold 1 -- TRAIN: (4288)
	(('_language_independent',), 1598)
	(('francais',), 668)
	(('maninka',), 606)
	(('pular',), 494)
	(('susu',), 922)
Fold 1 -- TEST: (1000)
	(('_language_independent',), 350)
	(('francais',), 152)
	(('maninka',), 260)
	(('pular',), 112)
	(('susu',), 126)
---------------------
Fold 2 -- TRAIN: (4297)
	(('_language_independent',), 1573)
	(('francais',), 728)
	(('maninka',), 754)
	(('pular',), 418)
	(('susu',), 824)
Fold 2 -- TEST: (991)
	(('_language_independent',), 375)
	(('francais',), 92)
	(('maninka',), 112)
	(('pular',), 188)
	(('susu',), 224)
---------------------
Fold 3 -- TRAIN: (4140)
	(('_language_independent',), 1398)
	(('francais',), 668)
	(('maninka',), 720)
	

## Load Features

In [13]:
def load_features(records, feature_name):
    """Load and zero-pad the feature matrix of every record.

    For each record the file ``BASE_DIR / feature_name / <file>.h5context`` is
    read (the record's '.wav' name with the extension swapped). The stored
    flat 'features' array is reshaped using ``info[1:]`` and copied into a
    ``(MAX_FEATURE_SEQUENCE_LENGTH, 512)`` zero matrix.

    Sequences longer than ``MAX_FEATURE_SEQUENCE_LENGTH`` are truncated to the
    first ``MAX_FEATURE_SEQUENCE_LENGTH`` frames; previously they raised a
    shape-mismatch error during the padding assignment.

    Args:
        records: iterable of metadata rows with a 'file' key.
        feature_name: name of the feature sub-directory under ``BASE_DIR``.

    Returns:
        list of ``numpy.ndarray``, one padded matrix per record, in input order.
    """
    features_list = []

    features_input_dir = BASE_DIR / feature_name

    for r in records:
        feature_file_name = r['file'].replace(".wav", ".h5context")
        feature_path = features_input_dir / feature_file_name
        with h5py.File(feature_path, 'r') as f:
            # NOTE(review): info[1:] is used as the (frames, dims) shape of the
            # flat 'features' dataset -- presumably info[0] is other metadata.
            features_shape = f['info'][1:].astype(int)
            features = np.array(f['features'][:]).reshape(features_shape)

        # Copy at most MAX_FEATURE_SEQUENCE_LENGTH frames; the rest stays zero.
        frame_count = min(features.shape[0], MAX_FEATURE_SEQUENCE_LENGTH)
        padded_features = np.zeros((MAX_FEATURE_SEQUENCE_LENGTH, 512), dtype=features.dtype)
        padded_features[:frame_count, :] = features[:frame_count]

        features_list.append(padded_features)
    return features_list

In [14]:
def get_bias_category_labels(records):
    """Return 0/1 membership vectors for every bias category value.

    For each category in the module-level ``bias_categories`` and each of its
    values, the result holds a list aligned with ``records`` that is 1 where
    the record has that value and 0 otherwise. Keys have the form
    ``"<category>__<value>"``.
    """
    return {
        f"{category}__{value}": [1 if record[category] == value else 0 for record in records]
        for category, values in bias_categories.items()
        for value in values
    }

# Classification Models

In [15]:
class ASR_CNN1(nn.Module):
    """1-D CNN classifier over padded 512-dimensional feature sequences.

    A 1x1 convolution first reduces the 512 input channels to 3. Four
    conv -> ELU -> dropout -> average-pool stages follow. The time-averaged
    outputs of stages 2, 3 and 4 (3 channels each) are concatenated into a
    9-dimensional vector, passed through dropout and fed to a linear layer
    with ``asr_class_count`` outputs.
    """

    def __init__(self):
        super().__init__()

        # Channel reduction: 512 -> 3, one frame at a time.
        self.conv0 = nn.Conv1d(512, 3, kernel_size=1)

        # Stage 1 squeezes to a single channel.
        self.conv1 = nn.Conv1d(3, 1, kernel_size=3)
        self.drop1 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(1, 3, kernel_size=3)
        self.drop2 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(3, 3, kernel_size=3)
        self.drop3 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(3, 3, kernel_size=3)
        self.drop4 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.drop5 = nn.Dropout(p=FC_DROPOUT_P)

        # 9 = 3 channels x 3 pooled stage outputs.
        self.lin6 = nn.Linear(9, asr_class_count)

    @staticmethod
    def _stage(x, conv, drop, pool):
        """Apply one conv -> ELU -> dropout -> average-pool stage."""
        return pool(drop(F.elu(conv(x))))

    def forward(self, x):
        """Return ``(representation, logits)`` for a batch.

        ``x`` is permuted from (batch, time, channels) to the
        (batch, channels, time) layout that ``nn.Conv1d`` consumes.
        """
        x = self.conv0(x.permute(0, 2, 1))

        x = self._stage(x, self.conv1, self.drop1, self.pool1)

        x = self._stage(x, self.conv2, self.drop2, self.pool2)
        stage_means = [x.mean(dim=2)]

        x = self._stage(x, self.conv3, self.drop3, self.pool3)
        stage_means.append(x.mean(dim=2))

        x = self._stage(x, self.conv4, self.drop4, self.pool4)
        stage_means.append(x.mean(dim=2))

        representation = self.drop5(torch.cat(stage_means, dim=1))
        logits = self.lin6(representation)

        return representation, logits
    
    
class ASR_CNN2(nn.Module):
    """1-D CNN classifier; like the single-channel variant but 3 channels wide throughout.

    A 1x1 convolution reduces the 512 input channels to 3, followed by four
    conv -> ELU -> dropout -> average-pool stages that all keep 3 channels.
    The time-averaged outputs of stages 2, 3 and 4 are concatenated into a
    9-dimensional vector, passed through dropout and fed to a linear layer
    with ``asr_class_count`` outputs.
    """

    def __init__(self):
        super(ASR_CNN2, self).__init__()

        # Channel reduction: 512 -> 3, one frame at a time.
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=3, kernel_size=1)

        self.conv1 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop1 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)

        # in_channels must match conv1's 3 output channels; it was 1, which
        # made forward() fail with a channel-mismatch error.
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop2 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop3 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop4 = nn.Dropout(p=CONV_DROPOUT_P)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.drop5 = nn.Dropout(p=FC_DROPOUT_P)

        # 9 = 3 channels x 3 pooled stage outputs.
        self.lin6 = nn.Linear(in_features=9, out_features=asr_class_count)

    def forward(self, x):
        """Return ``(representation, logits)`` for a batch.

        ``x`` is permuted from (batch, time, channels) to the
        (batch, channels, time) layout that ``nn.Conv1d`` consumes.
        """
        x = x.permute(0, 2, 1)

        x = self.conv0(x)

        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)

        # Time-average of each stage's output is one slice of the representation.
        v1 = torch.mean(x, dim=2)

        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)

        v2 = torch.mean(x, dim=2)

        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)

        v3 = torch.mean(x, dim=2)

        v = torch.cat((v1, v2, v3), dim=1)
        v = self.drop5(v)

        x = self.lin6(v)

        return v, x

In [16]:
def get_data_for_fold(fold_id, feature_name):
    """Load features, class ids and bias-category labels for one fold.

    Args:
        fold_id: key into the module-level ``records_per_fold``.
        feature_name: feature sub-directory passed on to ``load_features``.

    Returns:
        tuple: ``(train_x, train_y, test_x, test_y, train_bias_category_labels,
        test_bias_category_labels)``. The x/y entries are NumPy arrays; the
        bias label dicts come from ``get_bias_category_labels``. All are
        aligned with the fold's record order.
    """
    # Use the fold_id argument. The previous code read a global `fold_index`,
    # which only worked when the caller's loop variable had that exact name.
    fold_records = records_per_fold[fold_id]
    train_records = fold_records["train_records"]
    test_records = fold_records["test_records"]

    train_features = load_features(train_records, feature_name)
    test_features = load_features(test_records, feature_name)

    train_x = np.array(train_features)
    test_x = np.array(test_features)

    train_y = np.array([asr_class_id_by_name[r['label']] for r in train_records])
    test_y = np.array([asr_class_id_by_name[r['label']] for r in test_records])

    train_bias_category_labels = get_bias_category_labels(train_records)
    test_bias_category_labels = get_bias_category_labels(test_records)

    return train_x, train_y, test_x, test_y, train_bias_category_labels, test_bias_category_labels

    
def get_loaders_for_fold(fold_id, feature_name, batch_size):
    """Wrap one fold's arrays in PyTorch ``DataLoader`` objects.

    Neither loader shuffles, so batches follow the fold's record order and
    stay aligned with the returned bias-category label lists.

    Returns:
        tuple: ``(train_loader, test_loader, train_bias_category_labels,
        test_bias_category_labels)``.
    """
    (
        train_x, train_y,
        test_x, test_y,
        train_bias_category_labels, test_bias_category_labels,
    ) = get_data_for_fold(fold_id, feature_name)

    def build_loader(inputs, targets):
        # torch.tensor copies the NumPy data into new tensors.
        dataset = TensorDataset(torch.tensor(inputs), torch.tensor(targets))
        return DataLoader(dataset, batch_size=batch_size)

    train_loader = build_loader(train_x, train_y)
    test_loader = build_loader(test_x, test_y)

    return train_loader, test_loader, train_bias_category_labels, test_bias_category_labels


def get_predictions_for_logits(logits):
    """Return the index of the most probable class for each row of ``logits``.

    Softmax is taken over dim 1 and the arg-max over the same dimension.
    """
    class_probabilities = F.softmax(logits, dim=1)
    return class_probabilities.argmax(dim=1)

In [17]:
def train_on_fold(model_class, fold_id, feature_name, batch_size, epochs, use_contrastive_term):
    """Train a fresh ``model_class`` instance on one fold and evaluate it after every epoch.

    Args:
        model_class: zero-argument callable returning an ``nn.Module`` whose
            ``forward`` returns ``(representations, logits)``.
        fold_id: fold passed to ``get_loaders_for_fold``.
        feature_name: feature set passed to ``get_loaders_for_fold``.
        batch_size: batch size for both loaders.
        epochs: number of training epochs.
        use_contrastive_term: when true, the loss is the equal-weight mix of
            cross-entropy and ``contrastive_loss(representations, y)``.

    Returns:
        dict: ``{epoch: metrics}`` for epochs ``1..epochs``. Each metrics dict
        holds losses, accuracies and sample counts for train and test, the
        per-bias-category accuracy (``train_acc_<cat>`` / ``test_acc_<cat>``)
        and count (``train_n_<cat>`` / ``test_n_<cat>``), plus the raw
        ``test_logits`` and ``test_true_classes`` of that epoch.

    Notes:
        * Per-category accuracy is ``nan`` when the category has no sample in
          the split, instead of raising on an all-zero ``sample_weight``.
        * The bias-category labels are matched to predictions by position, so
          the loaders must iterate in record order (no shuffling).
        * ``test_logits`` is kept for every epoch, which can use a lot of
          memory for long runs.
    """
    # Fall back to the CPU so the code also runs where CUDA is unavailable.
    device = torch.device(f"cuda:{GPU_ID}" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(0)
    results = {}

    model = model_class().to(device)

    train_loader, test_loader, train_bias_category_labels, test_bias_category_labels = \
        get_loaders_for_fold(fold_id, feature_name, batch_size)

    # The summary runs on a separate CPU instance of the model.
    print(summary(model_class(), torch.zeros((10, MAX_FEATURE_SEQUENCE_LENGTH, 512)), show_input=False))

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    def compute_loss(representations, outputs, targets):
        # One definition for train and test so the two losses are comparable.
        # Previously the test loss added the contrastive term unweighted.
        loss = criterion(outputs, targets)
        if use_contrastive_term:
            loss = 0.5*loss + 0.5*contrastive_loss(representations, targets)
        return loss

    def accuracy_by_bias_category(true_classes, pred_classes, bias_category_labels):
        # The 0/1 membership vector acts as sample_weight, restricting the
        # accuracy to the members of that category.
        accuracies = {}
        for category, sample_weight in bias_category_labels.items():
            if np.sum(sample_weight) > 0:
                accuracies[category] = sklearn.metrics.accuracy_score(
                    true_classes, pred_classes, sample_weight=sample_weight
                )
            else:
                # No member of this category in the split: accuracy undefined.
                accuracies[category] = float('nan')
        return accuracies

    for epoch in range(1, epochs+1):
        # ---- training pass ----
        model.train()
        train_loss = 0
        pred_train_classes = []
        true_train_classes = []

        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            representations, outputs = model(x)
            pred_train_classes.extend(
                get_predictions_for_logits(outputs).cpu().numpy()
            )
            true_train_classes.extend(y.cpu().numpy())

            loss = compute_loss(representations, outputs, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_n = len(true_train_classes)

        # Mean of the per-batch losses.
        train_loss = train_loss/len(train_loader)
        train_acc = sklearn.metrics.accuracy_score(true_train_classes, pred_train_classes)
        train_acc_by_bias_category = accuracy_by_bias_category(
            true_train_classes, pred_train_classes, train_bias_category_labels
        )

        # ---- evaluation pass ----
        pred_test_logits = []
        pred_test_classes = []
        true_test_classes = []

        model.eval()
        test_loss = 0
        # No gradients are needed for evaluation; skipping them saves memory and time.
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                y = y.to(device)

                representations, outputs = model(x)

                pred_test_logits.extend(outputs.cpu().numpy())
                pred_test_classes.extend(
                    get_predictions_for_logits(outputs).cpu().numpy()
                )
                true_test_classes.extend(y.cpu().numpy())

                test_loss += compute_loss(representations, outputs, y).item()

        test_n = len(true_test_classes)

        test_loss = test_loss / len(test_loader)
        test_acc = sklearn.metrics.accuracy_score(true_test_classes, pred_test_classes)
        test_acc_by_bias_category = accuracy_by_bias_category(
            true_test_classes, pred_test_classes, test_bias_category_labels
        )

        if epoch%10==0:
            print(f"Epoch: {epoch}. Train Loss: {train_loss:0.4}. Test Loss: {test_loss:0.4}. Train Acc: {train_acc:0.4}. Test Acc:{test_acc:0.4}")

        results[epoch] = {
            'epoch': epoch,
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_acc': train_acc,
            'test_acc': test_acc,
            'train_n': train_n,
            'test_n': test_n,
            'test_logits': pred_test_logits,
            'test_true_classes': true_test_classes
        }

        for c in train_acc_by_bias_category:
            results[epoch][f"train_acc_{c}"] = train_acc_by_bias_category[c]
            results[epoch][f"train_n_{c}"] = int(np.sum(train_bias_category_labels[c]))

        for c in test_acc_by_bias_category:
            results[epoch][f"test_acc_{c}"] = test_acc_by_bias_category[c]
            results[epoch][f"test_n_{c}"] = int(np.sum(test_bias_category_labels[c]))

    del model
    return results

In [18]:
import csv
from pathlib import Path

def write_epoch_test_logits(model_name, all_folds_results):
    """Write the per-datum test logits of every epoch to CSV files.

    For each result entry one directory
    ``RESULTS_DIR/<model_name>/<feature_name>_<fold_index>_data`` is created,
    containing one ``epoch_NNNN.csv`` per epoch with the columns ``fold_id``,
    ``datum_index``, ``datum_name``, ``true_class_id`` and ``logits``.

    Args:
        model_name: name of the model, used as a directory component.
        all_folds_results: iterable of dicts with keys 'feature_name',
            'fold_index' and 'epochs' (``{epoch: metrics}`` where each metrics
            dict has 'test_logits' and 'test_true_classes').

    Notes:
        ``datum_name`` is looked up by position in the fold's test records, so
        the logits must be in test-record order.
        NOTE(review): ``logits`` is written as the string form of whatever
        object is stored (e.g. a NumPy array's printed form) -- confirm the
        consumers can parse that.
    """
    for result_entry in all_folds_results:
        feature_name = result_entry['feature_name']
        fold_index = result_entry['fold_index']

        test_records = records_per_fold[fold_index]['test_records']

        # The directory is the same for all epochs of an entry: create it once.
        parent_dir = Path(f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}_data")
        parent_dir.mkdir(parents=True, exist_ok=True)

        for epoch in sorted(result_entry['epochs'].keys()):
            epoch_results = result_entry['epochs'][epoch]
            test_logits = epoch_results["test_logits"]
            test_true_classes = epoch_results["test_true_classes"]

            file_name = parent_dir / f"epoch_{epoch:04}.csv"
            # newline='' as required by the csv module for files it writes to.
            with open(file_name, "w", newline='') as f:
                writer = csv.DictWriter(f, fieldnames=["fold_id", "datum_index", "datum_name", "true_class_id", "logits"])
                writer.writeheader()

                for datum_index in range(len(test_logits)):
                    writer.writerow({
                        "fold_id": fold_index,
                        "datum_index": datum_index,
                        "datum_name": test_records[datum_index]['file'],
                        "true_class_id": test_true_classes[datum_index],
                        "logits": test_logits[datum_index]
                    })

def save_results(model_name, all_folds_results):
    """Write the per-epoch scalar metrics of each result entry to a CSV file.

    One file ``RESULTS_DIR/<model_name>/<feature_name>_<fold_index>.csv`` is
    written per entry, overwriting any existing file, with one row per epoch.
    The bulky ``test_logits`` and ``test_true_classes`` values are left out.

    Args:
        model_name: name of the model, used as a directory component.
        all_folds_results: iterable of dicts with keys 'feature_name',
            'fold_index' and 'epochs' (``{epoch: metrics}``). Entries with no
            epochs are skipped.
    """
    # Logged separately, in a different format.
    excluded_fields = {"test_logits", "test_true_classes"}

    for result_entry in all_folds_results:
        feature_name = result_entry['feature_name']
        fold_index = result_entry['fold_index']
        epochs = result_entry['epochs']

        if not epochs:
            # Nothing was recorded for this entry, so there is no header to derive.
            continue

        fname = Path(f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}.csv")
        fname.parent.mkdir(parents=True, exist_ok=True)

        # Derive the columns from the earliest recorded epoch rather than
        # assuming an epoch numbered 1 exists.
        first_epoch = min(epochs)
        fieldnames = sorted(k for k in epochs[first_epoch] if k not in excluded_fields)

        # newline='' as required by the csv module for files it writes to.
        with open(fname, 'w', newline='') as f:
            # extrasaction='ignore' silently drops the excluded keys from each row.
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
            writer.writeheader()

            for epoch in sorted(epochs.keys()):
                writer.writerow(epochs[epoch])

In [None]:
# Train every model on every (fold, feature set) combination and save the
# per-epoch metrics after each run.
model_classes = [
    ASR_CNN1,
    ASR_CNN2,
]

for model_class in model_classes:
    model_name = model_class.__name__
    # The contrastive loss term is enabled purely by naming convention.
    use_contrastive_term = "Contrastive" in model_name

    all_folds_results = []
    for fold_index in range(FOLD_COUNT):
        for feature_name in FEATURE_NAMES:
            print(f"{model_name} using {feature_name} on fold#{fold_index}. Contrastive: {use_contrastive_term}")
            results = train_on_fold(model_class, fold_index, feature_name, batch_size=250, epochs=1000, use_contrastive_term=use_contrastive_term)
            fold_result = {
                'fold_index': fold_index,
                'feature_name': feature_name,
                'epochs': results
            }
            all_folds_results.append(fold_result)
            # Save only the run that just finished. Each entry maps to its own
            # file, so passing the whole list rewrote every earlier file on
            # each iteration without changing its content.
            save_results(model_name, [fold_result])
            # write_epoch_test_logits(model_name, [fold_result])

ASR_CNN1 using wav2vec_features-c on fold#0. Contrastive: False
