In [1]:
from pathlib import Path
import csv
from itertools import groupby
import h5py
import numpy as np
import sklearn
from sklearn.cluster import KMeans
from sklearn.svm import SVC
# from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import random
import matplotlib
from matplotlib import pyplot as plt

# Configuration & Utilities

In [2]:
random.seed(42)
COLORS = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
ANNOTATIONS_PATH = "/media/xtrem/data/experiments/nicolingua-0001-language-id/language-id-annotations/metadata.csv"
FEATURE_DIRS = [
    '/media/xtrem/data/experiments/nicolingua-0001-language-id/wav2vec_features-c',
    '/media/xtrem/data/experiments/nicolingua-0001-language-id/wav2vec_features-z',
    '/media/xtrem/data/experiments/nicolingua-0001-language-id/retrained-wav2vec_features-c',
    '/media/xtrem/data/experiments/nicolingua-0001-language-id/retrained-wav2vec_features-z'
]
# RESULTS_DIR = 'results_005_lang_id_classification'
RESULTS_DIR = 'E007/results_007_gn_lang_classification_contrastive'
GPU_ID = 1

In [3]:
annotation_specification = {
    0: {
        'id': 0,
        'label': "maninka",
        'required_tags': set(['ct-speech', 'lng-maninka']),
        'forbidden_tags':  set(['lng-susu', 'lng-pular'])
    },
    1: {
        'id': 1,
        'label': "susu",
        'required_tags': set(['ct-speech', 'lng-susu']),
        'forbidden_tags':  set(['lng-maninka', 'lng-pular'])
    },
    2: {
        'id': 2,
        'label': "pular",
        'required_tags': set(['ct-speech', 'lng-pular']),
        'forbidden_tags':  set(['lng-susu', 'lng-maninka'])
    }
}

In [4]:
bias_category_specification = [
    {
        "category": "utterance",
        "subcategories": [
            {
                "subcategory": "verbal_nod",
                "tags": set(["utt-verbal-nod"])
            },
            {
                "subcategory": "multilingual",
                "tags": set(["utt-multi-lingual", "utt-multi-lingual-named-endity"])
            }
        ]
    },
    {
        "category": "speaker_count",
        "subcategories": [
            {
                "subcategory": "single",
                "tags": set(["spkr-single"])
            },
            {
                "subcategory": "multiple",
                "tags": set(["spkr-mult", "spkr-multi"])
            }
        ]
    },
    {
        "category": "gender",
        "subcategories": [
            {
                "subcategory": "male",
                "tags": set(["spkr-male"])
            },
            {
                "subcategory": "female",
                "tags": set(["spkr-female"])
            },
        ]
    },
    {
        "category": "language",
        "subcategories": [
            {
                "subcategory": "susu",
                "tags": set(["lng-susu"])
            },
            {
                "subcategory": "maninka",
                "tags": set(["lng-maninka"])
            },
            {
                "subcategory": "pular",
                "tags": set(["lng-pular"])
            }
        ]
    },
    {
        "category": "channel",
        "subcategories": [
            {
                "subcategory": "telephone",
                "tags": set(["ct-telephone"])
            },
            {
                "subcategory": "noise",
                "tags": set(["ct-noise"])
            },
            {
                "subcategory": "music",
                "tags": set(["ct-fg-music", "ct-tr-music", "ct-bg-music"])
            }
        ]
    }
]

flat_bias_category_specification = {}

for c in bias_category_specification:
    for sc in c['subcategories']:
        k = f"{c['category']}_{sc['subcategory']}"
        flat_bias_category_specification[k] = sc['tags']

_ = [print(f"{k}: {v}") for k,v in flat_bias_category_specification.items()]

utterance_verbal_nod: {'utt-verbal-nod'}
utterance_multilingual: {'utt-multi-lingual-named-endity', 'utt-multi-lingual'}
speaker_count_single: {'spkr-single'}
speaker_count_multiple: {'spkr-multi', 'spkr-mult'}
gender_male: {'spkr-male'}
gender_female: {'spkr-female'}
language_susu: {'lng-susu'}
language_maninka: {'lng-maninka'}
language_pular: {'lng-pular'}
channel_telephone: {'ct-telephone'}
channel_noise: {'ct-noise'}
channel_music: {'ct-tr-music', 'ct-bg-music', 'ct-fg-music'}


In [5]:
def to_user_friendly_feature_name(fv_name):
    name = fv_name \
        .replace("features-", "") \
        .replace("wav2vec_", "") \
        .replace("average", "avg") \
        .replace("timestep", "T") \
        .replace("c.", "Context") \
        .replace("z.", "Latent")
    return name

# Load annotations

In [6]:
def load_annotations(a_file_path, a_specification):
    with open(ANNOTATIONS_PATH) as f:
        reader = csv.DictReader(f)
        for row in reader:
            tag_set = set([t.strip() for t in row['tags'].split(";")])
            for label in annotation_specification.keys():
                spec = annotation_specification[label]
                if spec['required_tags'].issubset(tag_set):
                    if spec['forbidden_tags'].isdisjoint(tag_set):
                        yield row['file'], label, tag_set
                        break

data = list(load_annotations(ANNOTATIONS_PATH, annotation_specification))
audio_files, audio_labels, audio_tags = zip(*data)

## Inspect label counts

In [7]:
def inspect_label_counts():
    for label in annotation_specification:
        count = len([l for l in audio_labels if l == label])
        print("{:10} ({}): {}".format(
            annotation_specification[label]['label'],
            label, 
            count
        ))
inspect_label_counts()

maninka    (0): 114
susu       (1): 32
pular      (2): 28


## Balance data

In [8]:
count_per_class = 28
data = list(load_annotations(ANNOTATIONS_PATH, annotation_specification))
balanced_data = []
for label in annotation_specification:
    balanced_data.extend([d for d in data if d[1] == label][:count_per_class])
audio_files, audio_labels, audio_tags = zip(*balanced_data)

In [9]:
inspect_label_counts()

maninka    (0): 28
susu       (1): 28
pular      (2): 28


## Inspect bias category counts in balanded data

In [10]:
def inspect_bias_category_counts():
    for name, tags in flat_bias_category_specification.items():
        count = len([ts for ts in audio_tags if len(tags.intersection(ts))>0])
        print(name, count, "/", len(audio_tags))
            
inspect_bias_category_counts()

utterance_verbal_nod 48 / 84
utterance_multilingual 21 / 84
speaker_count_single 25 / 84
speaker_count_multiple 58 / 84
gender_male 81 / 84
gender_female 15 / 84
language_susu 28 / 84
language_maninka 28 / 84
language_pular 28 / 84
channel_telephone 27 / 84
channel_noise 21 / 84
channel_music 24 / 84


# Prepare 10 cross validation folds

In [11]:
TRAIN_PERCENT = .6
FOLD_COUNT = 10

n = len(audio_files)
n_train = int(np.ceil(n * .6))
n_test = n - n_train
all_indices = range(n)

cv_folds = {}
train_count_by_index = {i:0 for i in all_indices}
test_count_by_index = {i:0 for i in all_indices}

for fold_index in range(FOLD_COUNT):
    fold_rsampler = np.random.RandomState(seed=fold_index)
    train_index_set = set(fold_rsampler.choice(all_indices, n_train, replace=False))
    test_index_set = set(all_indices).difference(train_index_set)
        
    cv_folds[fold_index] = {
        'train_indices': sorted(list(train_index_set)),
        'test_indices': sorted(list(test_index_set)),
    }


# Load features

In [12]:
def load_features(audio_files, features_input_dir):
    id_list = []
    features_list = []

    for audio_file_name in audio_files:
        feature_file_name = audio_file_name.replace(".wav", ".h5context")
        feature_path = Path(features_input_dir) / feature_file_name
        with h5py.File(feature_path, 'r') as f:
            features_shape = f['info'][1:].astype(int)
            features = np.array(f['features'][:]).reshape(features_shape)
            # features = pool_feature_last_seq(features)
            features_list.append(features)
    return features_list

In [13]:
raw_features = {}
for feature_dir in FEATURE_DIRS:
    feature_name = Path(feature_dir).stem
    raw_features[feature_name] = load_features(audio_files, feature_dir)

## Inspect feature shapes

In [14]:
for feature_name in raw_features.keys():
    print("feature_name: {}. feature shape: {}".format(
        to_user_friendly_feature_name(feature_name),
        raw_features[feature_name][0].shape
    ))

feature_name: c. feature shape: (2998, 512)
feature_name: z. feature shape: (2998, 512)
feature_name: retrained-c. feature shape: (2998, 512)
feature_name: retrained-z. feature shape: (2998, 512)


## Extract feature vectors

In [15]:
def extract_last_timestep_features(raw_features):
    return raw_features[-1, :]

def extract_neuron_average_features(raw_features):
    return np.mean(raw_features, axis=0)

def identity(x):
    return x

feature_extractors = {
    'last_timestep': extract_last_timestep_features,
    'neuron_average': extract_neuron_average_features,
    'raw_features': identity
}

In [16]:
feature_vectors = {}
for feature_name in raw_features.keys():
    for feature_extractor_name in feature_extractors.keys():
        fv_name = f"{feature_name}__{feature_extractor_name}"
        feature_vectors[fv_name] = []
        for f in raw_features[feature_name]:
            feature_vectors[fv_name].append(
                feature_extractors[feature_extractor_name](f)
            )
            
        feature_vectors[fv_name] = np.array(feature_vectors[fv_name])

## Inspect feature vectors

In [17]:
for fvname in feature_vectors.keys():
    print(fvname, feature_vectors[fvname][0].shape)

wav2vec_features-c__last_timestep (512,)
wav2vec_features-c__neuron_average (512,)
wav2vec_features-c__raw_features (2998, 512)
wav2vec_features-z__last_timestep (512,)
wav2vec_features-z__neuron_average (512,)
wav2vec_features-z__raw_features (2998, 512)
retrained-wav2vec_features-c__last_timestep (512,)
retrained-wav2vec_features-c__neuron_average (512,)
retrained-wav2vec_features-c__raw_features (2998, 512)
retrained-wav2vec_features-z__last_timestep (512,)
retrained-wav2vec_features-z__neuron_average (512,)
retrained-wav2vec_features-z__raw_features (2998, 512)


# Classification Models

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_model_summary import summary
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

In [19]:
dropout_p = 0.7
    
class LangIdCNN_Mean2(nn.Module):
    def __init__(self):
        super(LangIdCNN_Mean2, self).__init__()
        
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=3, kernel_size=1)
        
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        self.drop1 = nn.Dropout(p=dropout_p)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3)
        self.drop2 = nn.Dropout(p=dropout_p)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop3 = nn.Dropout(p=dropout_p)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop4 = nn.Dropout(p=dropout_p)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.lin5 = nn.Linear(in_features=9, out_features=3)
        
    def forward(self, x):
        x = x.permute(0, 2, 1)
        
        x = self.conv0(x)
        
        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)
        
        
        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)
        
        v1 = torch.mean(x, dim=2)
        
        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)
        
        v2 = torch.mean(x, dim=2)
        
        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)
        
        v3 = torch.mean(x, dim=2)
        
        v = torch.cat((v1, v2, v3), axis=1)
        
        x = self.lin5(v)
        
        return v, x
    
class LangIdCNN_Mean2_Contrastive(nn.Module):
    def __init__(self):
        super(LangIdCNN_Mean2_Contrastive, self).__init__()
        
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=3, kernel_size=1)
        
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=3)
        self.drop1 = nn.Dropout(p=dropout_p)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=3, kernel_size=3)
        self.drop2 = nn.Dropout(p=dropout_p)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop3 = nn.Dropout(p=dropout_p)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop4 = nn.Dropout(p=dropout_p)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.lin5 = nn.Linear(in_features=9, out_features=3)
        
    def forward(self, x):
        x = x.permute(0, 2, 1)
        
        x = self.conv0(x)
        
        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)
        
        
        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)
        
        v1 = torch.mean(x, dim=2)
        
        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)
        
        v2 = torch.mean(x, dim=2)
        
        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)
        
        v3 = torch.mean(x, dim=2)
        
        v = torch.cat((v1, v2, v3), axis=1)
        
        x = self.lin5(v)
        
        return v, x
    
class LangIdCNN_Mean3(nn.Module):
    def __init__(self):
        super(LangIdCNN_Mean3, self).__init__()
        
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=3, kernel_size=1)
        
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop1 = nn.Dropout(p=dropout_p)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop2 = nn.Dropout(p=dropout_p)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop3 = nn.Dropout(p=dropout_p)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop4 = nn.Dropout(p=dropout_p)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.lin5 = nn.Linear(in_features=9, out_features=3)
        
    def forward(self, x):
        x = x.permute(0, 2, 1)
        
        x = self.conv0(x)
        
        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)
        
        
        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)
        
        v1 = torch.mean(x, dim=2)
        
        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)
        
        v2 = torch.mean(x, dim=2)
        
        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)
        
        v3 = torch.mean(x, dim=2)
        
        v = torch.cat((v1, v2, v3), axis=1)
        
        x = self.lin5(v)
        
        return v, x
    

class LangIdCNN_Mean3_Contrastive(nn.Module):
    def __init__(self):
        super(LangIdCNN_Mean3_Contrastive, self).__init__()
        
        self.conv0 = nn.Conv1d(in_channels=512, out_channels=3, kernel_size=1)
        
        self.conv1 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop1 = nn.Dropout(p=dropout_p)
        self.pool1 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop2 = nn.Dropout(p=dropout_p)
        self.pool2 = nn.AvgPool1d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop3 = nn.Dropout(p=dropout_p)
        self.pool3 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.conv4 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=3)
        self.drop4 = nn.Dropout(p=dropout_p)
        self.pool4 = nn.AvgPool1d(kernel_size=2, stride=2)

        self.lin5 = nn.Linear(in_features=9, out_features=3)
        
    def forward(self, x):
        x = x.permute(0, 2, 1)
        
        x = self.conv0(x)
        
        x = self.conv1(x)
        x = F.elu(x)
        x = self.drop1(x)
        x = self.pool1(x)
        
        
        x = self.conv2(x)
        x = F.elu(x)
        x = self.drop2(x)
        x = self.pool2(x)
        
        v1 = torch.mean(x, dim=2)
        
        x = self.conv3(x)
        x = F.elu(x)
        x = self.drop3(x)
        x = self.pool3(x)
        
        v2 = torch.mean(x, dim=2)
        
        x = self.conv4(x)
        x = F.elu(x)
        x = self.drop4(x)
        x = self.pool4(x)
        
        v3 = torch.mean(x, dim=2)
        
        v = torch.cat((v1, v2, v3), axis=1)
        
        x = self.lin5(v)
        
        return v, x

# Model classes and their contrastive version are identical. The Contrastive part of the name is a marker
# LangIdCNN_Mean3_Contrastive = LangIdCNN_Mean3

# LangIdCNN_Mean2_Contrastive = LangIdCNN_Mean2

# Train Classification Models

In [20]:
def get_data_for_fold(fold_id, feature_name, batch_size):
    train_indices = cv_folds[fold_id]['train_indices']
    test_indices = cv_folds[fold_id]['test_indices']    

    train_x = np.take(raw_features[feature_name], train_indices, axis=0)
    train_y = np.take(audio_labels, train_indices, axis=0)
    train_tags = np.take(audio_tags, train_indices, axis=0)
    train_bias_category_labels = {}
    for category, tags in flat_bias_category_specification.items():
        train_bias_category_labels[category] = [1 if len(tags.intersection(ts))>0 else 0 for ts in train_tags]

    test_x = np.take(raw_features[feature_name], test_indices, axis=0)
    test_y = np.take(audio_labels, test_indices, axis=0)
    test_tags = np.take(audio_tags, test_indices, axis=0)
    test_bias_category_labels = {}
    for category, tags in flat_bias_category_specification.items():
        test_bias_category_labels[category] = [1 if len(tags.intersection(ts))>0 else 0 for ts in test_tags]
        
    
    return train_x, train_y, test_x, test_y, train_bias_category_labels, test_bias_category_labels

def get_audio_files_for_fold(fold_id):
    train_indices = cv_folds[fold_id]['train_indices']
    test_indices = cv_folds[fold_id]['test_indices']
    
    train_files = np.take(audio_files, train_indices, axis=0)
    test_files = np.take(audio_files, test_indices, axis=0)
    
    return train_files, test_files
    

    
def get_loaders_for_fold(fold_id, feature_name, batch_size):
    
    train_x, train_y, test_x, test_y, train_bias_category_labels, test_bias_category_labels = \
        get_data_for_fold(fold_id, feature_name, batch_size)
    
    
    
    train_dataset = TensorDataset(
        torch.tensor(train_x), 
        torch.tensor(train_y)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size)

    test_dataset = TensorDataset(
        torch.tensor(test_x), 
        torch.tensor(test_y)
    )

    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    return train_loader, test_loader, train_bias_category_labels, test_bias_category_labels


def get_predictions_for_logits(logits):
    probs = F.softmax(logits, dim=1)
    return torch.argmax(probs, dim=1)

In [21]:
def contrastive_loss(v, y):
    print(v)
    print(y)
    return 10

In [22]:
def train_on_fold(model_class, fold_id, feature_name, batch_size, epochs, use_contrastive_term):
    device = torch.device(f"cuda:{GPU_ID}")
    torch.manual_seed(0)
    results = {}
    
    model = model_class().to(device)

    train_loader, test_loader, train_bias_category_labels, test_bias_category_labels = get_loaders_for_fold(fold_id, feature_name, batch_size)

    print(summary(model_class(), torch.zeros((10, 2998, 512)), show_input=False))

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(1, epochs+1):
        model.train()
        train_loss = 0
        pred_train_classes = []
        true_train_classes = []

        for batch_idx, (x, y) in enumerate(train_loader):
            x = x.to(device)
            y = y.to(device)
            
            optimizer.zero_grad()
            representations, outputs = model(x)
            pred_train_classes.extend(
                get_predictions_for_logits(outputs).cpu().numpy()
            )
            true_train_classes.extend(y.cpu().numpy())
            loss = criterion(outputs, y)
            if (use_contrastive_term):
                loss += contrastive_loss(representations, y)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        train_n = len(true_train_classes)
        
        train_loss = train_loss/len(train_loader)
        train_acc = sklearn.metrics.accuracy_score(true_train_classes, pred_train_classes)
        train_acc_by_bais_category = {
            category: sklearn.metrics.accuracy_score(true_train_classes, pred_train_classes, sample_weight=sw)
            for category, sw in train_bias_category_labels.items()
        }
        

        pred_test_logits = []
        
        pred_test_classes = []
        true_test_classes = []
        
        model.eval()
        test_loss = 0
        for batch_idx, (x, y) in enumerate(test_loader):
            x = x.to(device)
            y = y.to(device)
            
            representations, outputs = model(x)
            
            pred_test_logits.extend(outputs.detach().cpu().numpy())

            pred_test_classes.extend(
                get_predictions_for_logits(outputs).cpu().numpy()
            )

            true_test_classes.extend(y.cpu().numpy())

            loss = criterion(outputs, y)
            if (use_contrastive_term):
                loss += contrastive_loss(representations, y)
                
            test_loss += loss.item()


        test_n = len(true_test_classes)
        
        test_loss = test_loss / len(test_loader)
        test_acc = sklearn.metrics.accuracy_score(true_test_classes, pred_test_classes)
        test_acc_by_bais_category = {
            category: sklearn.metrics.accuracy_score(true_test_classes, pred_test_classes, sample_weight=sw)
            for category, sw in test_bias_category_labels.items()
        }

        if epoch%10==0:
            print(f"Epoch: {epoch}. Train Loss: {train_loss:0.4}. Test Loss: {test_loss:0.4}. Train Acc: {train_acc:0.4}. Test Acc:{test_acc:0.4}")
        
        results[epoch] = {
            'epoch': epoch,
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_acc': train_acc,
            'test_acc': test_acc,
            'train_n': train_n,
            'test_n': test_n,
            'test_logits': pred_test_logits,
            'test_true_classes': true_test_classes
        }
        
        for c in train_acc_by_bais_category:
            results[epoch][f"train_acc_{c}"] = train_acc_by_bais_category[c]
            results[epoch][f"train_n_{c}"] = int(np.sum(train_bias_category_labels[c]))
            
        for c in test_acc_by_bais_category:
            results[epoch][f"test_acc_{c}"] = test_acc_by_bais_category[c]
            results[epoch][f"test_n_{c}"] = int(np.sum(test_bias_category_labels[c]))
        
    del model
    return results
    

In [23]:
import csv
from pathlib import Path

def write_epoch_test_logits(model_name, all_folds_results):
    for result_entry in all_folds_results:
        feature_name = result_entry['feature_name']
        fold_index = result_entry['fold_index']
        
        train_audio_files, test_audio_files = get_audio_files_for_fold(fold_index)
        
        for epoch in sorted(result_entry['epochs'].keys()):
            
            parent_dir = Path(f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}_data")
            parent_dir.mkdir(parents=True, exist_ok=True)

            file_name = parent_dir / f"epoch_{epoch:04}.csv"
            with open(file_name, "w") as f:
                writer = csv.DictWriter(f, fieldnames=["fold_id", "datum_index", "datum_name", "true_class_id", "logits"])
                writer.writeheader()
                
                test_logits = result_entry['epochs'][epoch]["test_logits"]
                test_true_classes = result_entry['epochs'][epoch]["test_true_classes"]
                
                for datum_index in range(len(test_logits)):
                    writer.writerow({
                        "fold_id": fold_index,
                        "datum_index": datum_index,
                        "datum_name": test_audio_files[datum_index],
                        "true_class_id": test_true_classes[datum_index],
                        "logits": test_logits[datum_index]
                    })
                

            annotation_specification.keys()

            field_names = ["index", "fname", "logits"]

            fname = f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}.csv"

def save_results(model_name, all_folds_results):
    for result_entry in all_folds_results:
        feature_name = result_entry['feature_name']
        fold_index = result_entry['fold_index']
        
        Path(RESULTS_DIR).mkdir(exist_ok=True)
        fname = f"{RESULTS_DIR}/{model_name}/{feature_name}_{fold_index}.csv"
        Path(fname).parent.mkdir(parents=True, exist_ok=True)
        with open(fname, 'w') as f:
            fieldnames = sorted(result_entry['epochs'][1].keys())
            fieldnames.remove("test_logits") # logged separately, differently
            fieldnames.remove("test_true_classes") # logged separately, differently
            
            writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
            
            writer.writeheader()
            
            for epoch in sorted(result_entry['epochs'].keys()):
                writer.writerow(result_entry['epochs'][epoch])

In [24]:
model_classes = [
    LangIdCNN_Mean3_Contrastive,
    LangIdCNN_Mean3,
    LangIdCNN_Mean2_Contrastive,
    LangIdCNN_Mean2,
]

for model_class in model_classes:
    all_folds_results = []
    for fold_index in cv_folds:
        for feature_name in raw_features:
            use_contrastive_term = "Contrastive" in model_class.__name__
            
            print(f"{model_class.__name__} using {feature_name} on fold#{fold_index}. Contrastive: {use_contrastive_term}")
            resutls = train_on_fold(model_class, fold_index, feature_name, batch_size=100, epochs=1000, use_contrastive_term=use_contrastive_term)
            all_folds_results.append({
                'fold_index': fold_index,
                'feature_name': feature_name,
                'epochs': resutls
            })
            save_results(model_class.__name__, all_folds_results)
            write_epoch_test_logits(model_class.__name__, all_folds_results)

LangIdCNN_Mean3_Contrastive using wav2vec_features-c on fold#0. Contrastive: True
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Conv1d-1       [10, 3, 2998]           1,539           1,539
          Conv1d-2       [10, 3, 2996]              30              30
         Dropout-3       [10, 3, 2996]               0               0
       AvgPool1d-4       [10, 3, 1498]               0               0
          Conv1d-5       [10, 3, 1496]              30              30
         Dropout-6       [10, 3, 1496]               0               0
       AvgPool1d-7        [10, 3, 748]               0               0
          Conv1d-8        [10, 3, 746]              30              30
         Dropout-9        [10, 3, 746]               0               0
      AvgPool1d-10        [10, 3, 373]               0               0
         Conv1d-11        [10, 3, 371]              30           

tensor([[-0.1751,  0.0169,  0.0305,  0.2331,  0.1546,  0.2501,  0.0991, -0.1412,
          0.1426],
        [-0.1868,  0.0165,  0.0495,  0.2340,  0.1565,  0.2676,  0.1026, -0.1405,
          0.1454],
        [-0.1770,  0.0186,  0.0326,  0.2334,  0.1549,  0.2512,  0.0992, -0.1410,
          0.1426],
        [-0.1770,  0.0180,  0.0329,  0.2334,  0.1549,  0.2519,  0.0994, -0.1410,
          0.1428],
        [-0.1809,  0.0147,  0.0411,  0.2333,  0.1556,  0.2606,  0.1011, -0.1408,
          0.1443],
        [-0.1785,  0.0221,  0.0325,  0.2339,  0.1550,  0.2498,  0.0990, -0.1411,
          0.1422],
        [-0.1794,  0.0123,  0.0404,  0.2330,  0.1554,  0.2611,  0.1012, -0.1407,
          0.1445],
        [-0.1797,  0.0225,  0.0342,  0.2340,  0.1552,  0.2511,  0.0993, -0.1411,
          0.1424],
        [-0.1787,  0.0184,  0.0352,  0.2335,  0.1552,  0.2539,  0.0998, -0.1410,
          0.1431],
        [-0.1759,  0.0189,  0.0305,  0.2333,  0.1547,  0.2493,  0.0989, -0.1411,
          0.1423],


tensor([[-0.1691,  0.0190,  0.0113,  0.2308,  0.1550,  0.2316,  0.0923, -0.1460,
          0.1368],
        [-0.1812,  0.0185,  0.0315,  0.2317,  0.1571,  0.2501,  0.0960, -0.1455,
          0.1398],
        [-0.1710,  0.0207,  0.0134,  0.2312,  0.1553,  0.2326,  0.0925, -0.1459,
          0.1368],
        [-0.1711,  0.0201,  0.0138,  0.2311,  0.1553,  0.2334,  0.0927, -0.1459,
          0.1370],
        [-0.1762,  0.0160,  0.0247,  0.2311,  0.1563,  0.2450,  0.0949, -0.1457,
          0.1391],
        [-0.1731,  0.0237,  0.0145,  0.2316,  0.1556,  0.2326,  0.0925, -0.1459,
          0.1367],
        [-0.1737,  0.0145,  0.0217,  0.2307,  0.1559,  0.2429,  0.0945, -0.1457,
          0.1388],
        [-0.1743,  0.0241,  0.0162,  0.2317,  0.1558,  0.2339,  0.0928, -0.1460,
          0.1369],
        [-0.1735,  0.0200,  0.0175,  0.2313,  0.1557,  0.2368,  0.0933, -0.1459,
          0.1376],
        [-0.1702,  0.0208,  0.0118,  0.2311,  0.1552,  0.2313,  0.0923, -0.1460,
          0.1367],


tensor([[-0.1635,  0.0205, -0.0081,  0.2286,  0.1554,  0.2135,  0.0858, -0.1509,
          0.1313],
        [-0.1761,  0.0200,  0.0135,  0.2295,  0.1576,  0.2330,  0.0896, -0.1504,
          0.1344],
        [-0.1654,  0.0223, -0.0059,  0.2290,  0.1557,  0.2145,  0.0859, -0.1507,
          0.1313],
        [-0.1656,  0.0217, -0.0055,  0.2289,  0.1557,  0.2153,  0.0862, -0.1508,
          0.1315],
        [-0.1718,  0.0170,  0.0082,  0.2289,  0.1570,  0.2296,  0.0889, -0.1505,
          0.1340],
        [-0.1681,  0.0249, -0.0035,  0.2294,  0.1561,  0.2158,  0.0863, -0.1508,
          0.1314],
        [-0.1683,  0.0162,  0.0028,  0.2285,  0.1563,  0.2251,  0.0880, -0.1505,
          0.1333],
        [-0.1693,  0.0252, -0.0019,  0.2295,  0.1563,  0.2170,  0.0865, -0.1508,
          0.1316],
        [-0.1685,  0.0212, -0.0004,  0.2291,  0.1563,  0.2201,  0.0871, -0.1507,
          0.1323],
        [-0.1649,  0.0221, -0.0070,  0.2289,  0.1556,  0.2137,  0.0859, -0.1509,
          0.1312],


tensor([[-0.1569,  0.0284,  0.0017,  0.2561,  0.1623,  0.2108,  0.1161, -0.1343,
          0.1189],
        [-0.1520,  0.0239,  0.0007,  0.2205,  0.1729,  0.2254,  0.1120, -0.1648,
          0.1439],
        [-0.1715,  0.0215, -0.0040,  0.2204,  0.1525,  0.1804,  0.0805, -0.1534,
          0.1074],
        [-0.1639,  0.0205, -0.0193,  0.2256,  0.1378,  0.1934,  0.0851, -0.1317,
          0.0889],
        [-0.1580,  0.0144, -0.0320,  0.2195,  0.1512,  0.1976,  0.1027, -0.1648,
          0.1260],
        [-0.1626,  0.0179, -0.0054,  0.2385,  0.1503,  0.2146,  0.0601, -0.1657,
          0.1440],
        [-0.1529,  0.0224, -0.0241,  0.2276,  0.1388,  0.1799,  0.0662, -0.1151,
          0.1369],
        [-0.1648,  0.0220, -0.0112,  0.2258,  0.1424,  0.1986,  0.0781, -0.1269,
          0.1300],
        [-0.1610,  0.0225, -0.0095,  0.2350,  0.1489,  0.2104,  0.0583, -0.1566,
          0.1499],
        [-0.1445,  0.0267, -0.0112,  0.2325,  0.1459,  0.1710,  0.0808, -0.1276,
          0.1390],


tensor([[-0.1476,  0.0248, -0.0154,  0.2004,  0.1673,  0.2008,  0.0782, -0.1638,
          0.1352],
        [-0.1513,  0.0314, -0.0411,  0.2205,  0.1440,  0.1816,  0.0767, -0.1969,
          0.1420],
        [-0.1642,  0.0250, -0.0124,  0.2133,  0.1684,  0.1957,  0.0657, -0.1548,
          0.1501],
        [-0.1405,  0.0429, -0.0432,  0.2143,  0.1425,  0.1617,  0.0675, -0.1404,
          0.1192],
        [-0.1528,  0.0272, -0.0135,  0.2084,  0.1574,  0.2092,  0.0761, -0.1640,
          0.1216],
        [-0.1504,  0.0213, -0.0288,  0.2206,  0.1575,  0.1807,  0.0799, -0.1637,
          0.1310],
        [-0.1359,  0.0168, -0.0348,  0.2400,  0.1647,  0.1772,  0.0924, -0.1281,
          0.1055],
        [-0.1518,  0.0369, -0.0342,  0.2175,  0.1617,  0.1835,  0.0657, -0.1700,
          0.1262],
        [-0.1572,  0.0253, -0.0241,  0.2245,  0.1705,  0.1837,  0.0819, -0.1640,
          0.0971],
        [-0.1332,  0.0238, -0.0243,  0.2320,  0.1517,  0.1849,  0.0778, -0.1118,
          0.1246],


tensor([[-0.1423,  0.0191, -0.0464,  0.2226,  0.1548,  0.1626,  0.0750, -0.1609,
          0.1332],
        [-0.1467,  0.0234, -0.0503,  0.2327,  0.1456,  0.1831,  0.0656, -0.1369,
          0.1065],
        [-0.1575,  0.0194, -0.0324,  0.2159,  0.1707,  0.1963,  0.0900, -0.1475,
          0.1199],
        [-0.1331,  0.0357, -0.0553,  0.2161,  0.1773,  0.1380,  0.0624, -0.1629,
          0.1100],
        [-0.1528,  0.0213, -0.0519,  0.2107,  0.1610,  0.1889,  0.0961, -0.1767,
          0.1164],
        [-0.1393,  0.0194, -0.0340,  0.2081,  0.1601,  0.1728,  0.0691, -0.1460,
          0.1451],
        [-0.1330,  0.0352, -0.0540,  0.2266,  0.1547,  0.1440,  0.0609, -0.1592,
          0.1262],
        [-0.1450,  0.0289, -0.0437,  0.2230,  0.1406,  0.2117,  0.0959, -0.1314,
          0.1454],
        [-0.1328,  0.0259, -0.0416,  0.2318,  0.1638,  0.1572,  0.0822, -0.1664,
          0.1136],
        [-0.1425,  0.0289, -0.0460,  0.2167,  0.1301,  0.1589,  0.0633, -0.1280,
          0.1124],


tensor([[-0.1504,  0.0222, -0.0499,  0.2153,  0.1671,  0.1761,  0.0496, -0.1786,
          0.1216],
        [-0.1457,  0.0316, -0.0447,  0.2224,  0.1627,  0.1629,  0.0667, -0.1339,
          0.0985],
        [-0.1405,  0.0145, -0.0448,  0.2136,  0.1508,  0.1731,  0.0655, -0.1380,
          0.1241],
        [-0.1239,  0.0190, -0.0637,  0.2224,  0.1576,  0.1407,  0.0535, -0.1596,
          0.0981],
        [-0.1317,  0.0127, -0.0613,  0.2092,  0.1496,  0.1799,  0.0887, -0.1416,
          0.1289],
        [-0.1339,  0.0250, -0.0448,  0.2404,  0.1761,  0.1629,  0.0680, -0.1911,
          0.0963],
        [-0.1322,  0.0312, -0.0712,  0.2204,  0.1432,  0.1411,  0.0659, -0.1786,
          0.1007],
        [-0.1325,  0.0369, -0.0581,  0.2394,  0.1392,  0.1557,  0.0657, -0.1443,
          0.1165],
        [-0.1412,  0.0301, -0.0495,  0.2304,  0.1638,  0.1554,  0.0503, -0.1431,
          0.1152],
        [-0.1317,  0.0338, -0.0659,  0.2174,  0.1623,  0.1508,  0.0533, -0.1568,
          0.1200],


tensor([[-0.1383,  0.0144, -0.0635,  0.1962,  0.1476,  0.1682,  0.0683, -0.1675,
          0.1294],
        [-0.1398,  0.0280, -0.0832,  0.2238,  0.1805,  0.1351,  0.0394, -0.1795,
          0.0899],
        [-0.1494,  0.0264, -0.0586,  0.2328,  0.1639,  0.1623,  0.0793, -0.1695,
          0.1221],
        [-0.1313,  0.0248, -0.0858,  0.2288,  0.1507,  0.1387,  0.0378, -0.1288,
          0.1071],
        [-0.1399,  0.0224, -0.0788,  0.2202,  0.1567,  0.1404,  0.0605, -0.1819,
          0.0909],
        [-0.1440,  0.0144, -0.0696,  0.2308,  0.1447,  0.1648,  0.0766, -0.1518,
          0.1220],
        [-0.1167,  0.0248, -0.0676,  0.2394,  0.1539,  0.1554,  0.0771, -0.1538,
          0.1145],
        [-0.1314,  0.0342, -0.0686,  0.2185,  0.1333,  0.1324,  0.0673, -0.1406,
          0.1053],
        [-0.1181,  0.0249, -0.0767,  0.2187,  0.1590,  0.1396,  0.0790, -0.1876,
          0.0940],
        [-0.1194,  0.0169, -0.0767,  0.2279,  0.1499,  0.1360,  0.0412, -0.1415,
          0.1127],


tensor([[-0.1286,  0.0240, -0.0774,  0.2240,  0.1485,  0.1317,  0.0724, -0.1291,
          0.0682],
        [-0.1289,  0.0255, -0.0854,  0.2157,  0.1593,  0.1382,  0.0622, -0.1863,
          0.1177],
        [-0.1399,  0.0270, -0.0579,  0.2083,  0.1494,  0.1586,  0.0357, -0.1624,
          0.1200],
        [-0.1116,  0.0282, -0.1163,  0.1879,  0.1583,  0.1024,  0.0432, -0.1795,
          0.1125],
        [-0.1301,  0.0265, -0.0947,  0.2294,  0.1488,  0.1502,  0.0562, -0.1431,
          0.0958],
        [-0.1388,  0.0232, -0.0948,  0.2240,  0.1511,  0.1360,  0.0752, -0.1557,
          0.1128],
        [-0.1204,  0.0303, -0.0862,  0.2204,  0.1502,  0.1285,  0.0315, -0.1718,
          0.0869],
        [-0.1281,  0.0156, -0.0904,  0.2149,  0.1490,  0.1337,  0.0723, -0.1750,
          0.1018],
        [-0.1192,  0.0178, -0.0707,  0.1957,  0.1717,  0.1765,  0.0692, -0.1990,
          0.1124],
        [-0.1229,  0.0207, -0.0901,  0.2053,  0.1607,  0.1418,  0.0532, -0.1674,
          0.1123],


KeyboardInterrupt: 

In [None]:
model_classes[0].__name__

In [None]:
 "Contrastive" in "Contrasteiveee"