# Configuring the Environment

In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import librosa
from datasets import load_dataset, Dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2FeatureExtractor
from jiwer import wer

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
import numpy as np 
import pandas as pd
import os
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [None]:
import divexplorer 
import pandas as pd
pd.set_option('max_colwidth', None)
import os
import numpy as np

from utils_analysis import filter_itemset_df_by_attributes, slice_by_itemset

from divexplorer.FP_DivergenceExplorer import FP_DivergenceExplorer
from divexplorer.FP_Divergence import FP_Divergence

In [None]:
## Set device
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Fix the random seeds of PyTorch (CPU and current CUDA device) and NumPy
# so that runs are repeatable.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# LibriSpeech Dataset - Inference and Evaluation

## Load Model

In [None]:
# Load the CTC speech-recognition model with hidden states exposed in its
# outputs, and move it to the selected device.
# NOTE(review): the model is read from the local path "wav2vec2-base-960h"
# (local_files_only=True) while the processor and feature extractor use the
# "facebook/wav2vec2-base-960h" identifier -- confirm both refer to the same
# checkpoint.
model = Wav2Vec2ForCTC.from_pretrained(
    "wav2vec2-base-960h", 
    output_hidden_states=True,
    local_files_only=True
    ).to(device)
# Processor is used below to decode predicted token ids into text.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
# Feature extractor turns raw audio arrays into model input values.
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base-960h")

## Load Dataset

In [None]:
# Load the "clean" configuration of librispeech_asr: the train.360 split for
# training, plus the validation and test splits.
dataset_train = load_dataset("librispeech_asr", "clean", split="train.360")
dataset_valid = load_dataset("librispeech_asr", "clean", split="validation")
dataset_test = load_dataset("librispeech_asr", "clean", split="test")

## Inference

In [None]:
## Train set
# Run the model on every training utterance (one utterance per forward pass)
# and collect, per utterance: last-layer hidden states, averaged hidden
# states, logits, logit sequence length, decoded transcription and WER.
last_hidden_states_concatenation_train = []
avg_hidden_states_concatenation_train = []
logits_concatenation_train = []
sequence_lengths_train = []
transcriptions_train = []
wers_train = []

with torch.no_grad():
    for i in tqdm(range(0, len(dataset_train))):
        # Index the dataset once per iteration instead of once per field.
        sample = dataset_train[i]
        input_values = feature_extractor(
            sample["audio"]["array"], 
            return_tensors="pt", 
            padding="longest",
            sampling_rate=feature_extractor.sampling_rate
            ).input_values
        outputs = model(input_values.to(device))
        # Store CPU copies so the growing lists do not hold accelerator
        # memory for the whole split.
        last_hidden_states_concatenation_train.append(outputs.hidden_states[-1].cpu())
        logits_concatenation_train.append(outputs.logits.cpu())
        sequence_lengths_train.append(outputs.logits.shape[1])

        # Average each hidden state over axis 1 of its squeezed array, then
        # average the results across all returned hidden states.
        # NOTE(review): confirm axis=1 is the intended axis to reduce.
        avg = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(a, axis=1) for a in avg ]
        avg_hidden_states_concatenation_train.append(np.mean(avg, axis=0))

        # Greedy decoding: most likely token per frame, then decode to text.
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)
        transcriptions_train.append(transcription[0])

        # Per-utterance WER (the original passed two arguments to
        # list.append, which raises TypeError and never computed the WER).
        wers_train.append(wer(sample["text"], transcription[0]))

# WER over the whole split.
print("WER:", wer(dataset_train["text"], transcriptions_train))

In [None]:
## Valid set
# Run the model on every validation utterance (one utterance per forward
# pass) and collect, per utterance: last-layer hidden states, averaged hidden
# states, logits, logit sequence length, decoded transcription and WER.
last_hidden_states_concatenation_valid = []
avg_hidden_states_concatenation_valid = []
logits_concatenation_valid = []
sequence_lengths_valid = []
transcriptions_valid = []
wers_valid = []

with torch.no_grad():
    for i in tqdm(range(0, len(dataset_valid))):
        # Index the dataset once per iteration instead of once per field.
        sample = dataset_valid[i]
        input_values = feature_extractor(
            sample["audio"]["array"], 
            return_tensors="pt", 
            padding="longest",
            sampling_rate=feature_extractor.sampling_rate
            ).input_values
        outputs = model(input_values.to(device))
        # Store CPU copies so the growing lists do not hold accelerator
        # memory for the whole split.
        last_hidden_states_concatenation_valid.append(outputs.hidden_states[-1].cpu())
        logits_concatenation_valid.append(outputs.logits.cpu())
        sequence_lengths_valid.append(outputs.logits.shape[1])

        # Average each hidden state over axis 1 of its squeezed array, then
        # average the results across all returned hidden states.
        # NOTE(review): confirm axis=1 is the intended axis to reduce.
        avg = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(a, axis=1) for a in avg ]
        avg_hidden_states_concatenation_valid.append(np.mean(avg, axis=0))

        # Greedy decoding: most likely token per frame, then decode to text.
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)
        transcriptions_valid.append(transcription[0])

        # Per-utterance WER (the original passed two arguments to
        # list.append, which raises TypeError and never computed the WER).
        wers_valid.append(wer(sample["text"], transcription[0]))

# WER over the whole split.
print("WER:", wer(dataset_valid["text"], transcriptions_valid))

In [None]:
## Test set  
# Run the model on every test utterance (one utterance per forward pass) and
# collect, per utterance: last-layer hidden states, averaged hidden states,
# logits, logit sequence length, decoded transcription and WER.
last_hidden_states_concatenation_test = []
avg_hidden_states_concatenation_test = []
logits_concatenation_test = []
transcriptions_test = []
sequence_lengths_test = []
wers_test = []

with torch.no_grad():
    for i in tqdm(range(0, len(dataset_test))):
        # Index the dataset once per iteration instead of once per field.
        sample = dataset_test[i]
        input_values = feature_extractor(
            sample["audio"]["array"], 
            return_tensors="pt", 
            padding="longest",
            sampling_rate=feature_extractor.sampling_rate
            ).input_values
        outputs = model(input_values.to(device))
        # Store CPU copies so the growing lists do not hold accelerator
        # memory for the whole split.
        last_hidden_states_concatenation_test.append(outputs.hidden_states[-1].cpu())
        logits_concatenation_test.append(outputs.logits.cpu())
        sequence_lengths_test.append(outputs.logits.shape[1])
        
        # Average each hidden state over axis 1 of its squeezed array, then
        # average the results across all returned hidden states.
        # NOTE(review): confirm axis=1 is the intended axis to reduce.
        avg = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(a, axis=1) for a in avg ]
        # Append to the *test* list (the original appended to the valid list,
        # leaving the test features empty and polluting the valid features).
        avg_hidden_states_concatenation_test.append(np.mean(avg, axis=0))

        # Greedy decoding: most likely token per frame, then decode to text.
        predicted_ids = torch.argmax(outputs.logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)
        transcriptions_test.append(transcription[0])

        # Per-utterance WER (the original passed two arguments to
        # list.append, which raises TypeError and never computed the WER).
        wers_test.append(wer(sample["text"], transcription[0]))

# WER over the whole split.
print("WER:", wer(dataset_test["text"], transcriptions_test))

In [None]:
## Save features
# Persist every per-split list produced by the inference cells under
# "pretrained/". The directory is created if it does not exist yet.
os.makedirs("pretrained", exist_ok=True)

torch.save(last_hidden_states_concatenation_train, "pretrained/last_hidden_states_train.pt")
torch.save(avg_hidden_states_concatenation_train, "pretrained/avg_hidden_states_train.pt")
torch.save(logits_concatenation_train, "pretrained/logits_train.pt")
torch.save(sequence_lengths_train, "pretrained/sequence_lengths_train.pt")
torch.save(transcriptions_train, "pretrained/transcriptions_train.pt")
torch.save(wers_train, "pretrained/wers_train.pt")

torch.save(last_hidden_states_concatenation_valid, "pretrained/last_hidden_states_valid.pt")
torch.save(avg_hidden_states_concatenation_valid, "pretrained/avg_hidden_states_valid.pt")
torch.save(logits_concatenation_valid, "pretrained/logits_valid.pt")
torch.save(sequence_lengths_valid, "pretrained/sequence_lengths_valid.pt")
torch.save(transcriptions_valid, "pretrained/transcriptions_valid.pt")
# Saved because the loading cell reads 'pretrained/wers_valid.pt'.
torch.save(wers_valid, "pretrained/wers_valid.pt")

torch.save(last_hidden_states_concatenation_test, "pretrained/last_hidden_states_test.pt")
torch.save(avg_hidden_states_concatenation_test, "pretrained/avg_hidden_states_test.pt")
torch.save(logits_concatenation_test, "pretrained/logits_test.pt")
torch.save(sequence_lengths_test, "pretrained/sequence_lengths_test.pt")
torch.save(transcriptions_test, "pretrained/transcriptions_test.pt")
# Saved because the loading cell reads 'pretrained/wers_test.pt'.
torch.save(wers_test, "pretrained/wers_test.pt")

# Loading pretrained features

In [None]:
## Load dataset
# Same three librispeech_asr "clean" splits as in the inference section,
# loaded again so this section can be run without re-running inference.
dataset_train = load_dataset("librispeech_asr", "clean", split="train.360")
dataset_valid = load_dataset("librispeech_asr", "clean", split="validation")
dataset_test = load_dataset("librispeech_asr", "clean", split="test")

In [None]:
## Load hidden states and logits
# Read back the per-split lists written by the "Save features" cell.
# The logits are stored as 'pretrained/logits_<split>.pt' by that cell; the
# original code looked for 'logits_concatenation_<split>.pt', which is never
# written.
print("Loading train features...")
avg_hidden_states_train = torch.load('pretrained/avg_hidden_states_train.pt')
last_hidden_states_train = torch.load('pretrained/last_hidden_states_train.pt')
logits_concatenation_train = torch.load('pretrained/logits_train.pt')
sequence_lengths_train = torch.load('pretrained/sequence_lengths_train.pt')
transcriptions_train = torch.load('pretrained/transcriptions_train.pt')
wers_train = torch.load('pretrained/wers_train.pt')

print("Loading valid features...")
avg_hidden_states_valid = torch.load('pretrained/avg_hidden_states_valid.pt')
last_hidden_states_valid = torch.load('pretrained/last_hidden_states_valid.pt')
logits_concatenation_valid = torch.load('pretrained/logits_valid.pt')
sequence_lengths_valid = torch.load('pretrained/sequence_lengths_valid.pt')
transcriptions_valid = torch.load('pretrained/transcriptions_valid.pt')
wers_valid = torch.load('pretrained/wers_valid.pt')

print("Loading test features...")
avg_hidden_states_test = torch.load('pretrained/avg_hidden_states_test.pt')
last_hidden_states_test = torch.load('pretrained/last_hidden_states_test.pt')
logits_concatenation_test = torch.load('pretrained/logits_test.pt')
sequence_lengths_test = torch.load('pretrained/sequence_lengths_test.pt')
transcriptions_test = torch.load('pretrained/transcriptions_test.pt')
wers_test = torch.load('pretrained/wers_test.pt')

# Prediction

In [None]:
# Binary correctness label per utterance: 1 when the decoded transcription is
# string-equal to the reference text, 0 otherwise.
prediction_train = (np.array(dataset_train["text"]) == np.array(transcriptions_train)).astype(int)
prediction_valid = (np.array(dataset_valid["text"]) == np.array(transcriptions_valid)).astype(int)
prediction_test = (np.array(dataset_test["text"]) == np.array(transcriptions_test)).astype(int)

# Confidence Model 

In [None]:
## Confidence model
class ConfidenceModel(nn.Module):
    """Small MLP: two hidden blocks followed by a sigmoid-activated output layer.

    Each hidden block applies Linear -> GELU -> Dropout(0.1) -> LayerNorm; the
    same LayerNorm instance is shared by both blocks.

    Args:
        input_size: Number of input features.
        hidden_size: Width of both hidden layers.
        output_size: Number of output units.
    """

    def __init__(self, input_size=768, hidden_size=500, output_size=1):
        super().__init__()
        # Sub-modules are registered in the same order as before so that
        # random initialisation consumes the RNG identically.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.GELU()  # attribute is named `relu` but holds a GELU
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)
        self.norm = nn.LayerNorm(hidden_size)
        self._init_linear_layers()

    def _init_linear_layers(self):
        """Kaiming-normal weights and zero biases for every Linear sub-module."""
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return sigmoid(linear3(h)) where h is x passed through both hidden blocks."""
        hidden = x
        for linear in (self.linear1, self.linear2):
            hidden = self.norm(self.dropout(self.relu(linear(hidden))))
        return self.sigmoid(self.linear3(hidden))

In [None]:
## Train, valid and test
def train(model, inputs, labels, criterion, optimizer):
    """Perform a single optimisation step on one batch.

    Puts the model in training mode, clears gradients, runs a forward pass on
    `inputs` cast to float, back-propagates `criterion(outputs, labels)` and
    steps the optimizer.

    Returns:
        (outputs, loss) where loss is a Python float.
    """
    model.train()
    optimizer.zero_grad()
    predictions = model(inputs.float())
    batch_loss = criterion(predictions, labels)
    batch_loss.backward()
    optimizer.step()
    return predictions, batch_loss.item()

def val(model, inputs, labels, criterion):
    """Evaluate the model on a validation batch.

    Puts the model in eval mode and computes `criterion(outputs, labels)` on
    `inputs` cast to float. The forward pass runs under `torch.no_grad()` so
    no autograd graph is built for evaluation data.

    Returns:
        (outputs, loss) where loss is a Python float and outputs carries no
        gradient history.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
    return outputs, loss.item()

def test(model, inputs, labels=None, criterion=None):
    model.eval()
    if labels is None and criterion is None:
        outputs = model(inputs.float())
        return outputs
    else:
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        return outputs, loss.item()

In [None]:
from torchsummary import summary

# Instantiate a ConfidenceModel and print its layer summary.
# Note: this rebinds `model`, which earlier in the notebook held the
# Wav2Vec2ForCTC network.
model = ConfidenceModel(input_size=768, hidden_size=500, output_size=1)
model = model.to(device)
summary(model, input_size=(768,))

# Problem Setup

In [None]:
HIDDEN_SIZE = 500       # hidden width passed to ConfidenceModel
BATCH_SIZE = 4096       # mini-batch size used by the training loops
NUM_SUBGROUPS = 2       # number of most-divergent itemsets kept as subgroups
EPOCHS = 10000          # upper bound on training epochs
MIN_SUP = 0.05          # min_support passed to getFrequentPatternDivergence
TH_REDUNDANCY = 0.001   # th_redundancy passed to getDivergence
# Backward-compatible alias for the previous, misspelled constant name.
TH_REDUNDANDY = TH_REDUNDANCY
PRETRAIN = True         # whether the confidence model is pretrained first

# Problematic Subgroup Identification - DivExplorer, Step 2

## Prepare df

In [None]:
# Map speaker id (as a string) -> gender, read from SPEAKERS.TXT.
# Assumes the standard LibriSpeech layout "ID | SEX | SUBSET | ..." with ';'
# marking comment/header lines -- TODO confirm against the local file.
# Splitting on '|' replaces the previous token-index lookup that depended on
# the number of digits in the id and failed on header lines.
speakers = {}
with open('SPEAKERS.TXT', 'r') as f:
    for line in f:
        line = line.strip()
        # Skip blank lines and ';' comment/header lines.
        if not line or line.startswith(';'):
            continue
        fields = [field.strip() for field in line.split('|')]
        # A data row needs at least the id and the gender columns.
        if len(fields) < 2:
            continue
        speakers[fields[0]] = fields[1]

# Attach a "gender" column to each split by looking up every speaker id.
gender_train = [speakers[str(sID)] for sID in dataset_train["speaker_id"]]
dataset_train = dataset_train.add_column("gender", gender_train)

gender_valid = [speakers[str(sID)] for sID in dataset_valid["speaker_id"]]
dataset_valid = dataset_valid.add_column("gender", gender_valid) 

gender_test = [speakers[str(sID)] for sID in dataset_test["speaker_id"]]
dataset_test = dataset_test.add_column("gender", gender_test) 

In [None]:
# Build one metadata frame per split from the corresponding CSV and add the
# gender, correctness label ('prediction') and per-utterance 'WER' columns.
# NOTE(review): the added columns are aligned by position, so this presumably
# relies on the CSV rows being in the same order as the dataset -- confirm.
df_train = pd.read_csv('speech_metadata_train.csv')
df_train['gender'] = gender_train
df_train['prediction'] = prediction_train 
df_train['WER'] = wers_train

df_valid = pd.read_csv('speech_metadata_valid.csv')
df_valid['gender'] = gender_valid
df_valid['prediction'] = prediction_valid 
df_valid['WER'] = wers_valid

df_test = pd.read_csv('speech_metadata_test.csv')
df_test['gender'] = gender_test
df_test['prediction'] = prediction_test  
df_test['WER'] = wers_test

In [None]:
# concat df train and valid 
# Train rows come first, then valid rows; the index is renumbered from 0.
df_trainvalid = pd.concat([df_train, df_valid], ignore_index=True)
len(df_trainvalid)

## Utils

In [None]:
## Define abbreviations for plot and visualization
from divexplorer.FP_Divergence import abbreviateDict
# Substring -> shorter label, applied in insertion order by sortItemset
# (so earlier replacements affect what later keys can still match).
abbreviations = {
    'total_silence': 'tot_silence',
    'speaker_id': 'spkID',
    'trimmed': 'trim',
    'total_': 'tot_',
    'speed_rate_word_trimmed': 'speakRate_trim',
    'trim_duration': 'trim_dur',
    'speed_rate_word': 'speakRate',
    'speed_rate_char': 'speakCharRate',
    'duration': 'dur',
}

# Independent shallow copy of the mapping above.
abbreviations_shorter = dict(abbreviations)

## Function for sorting data cohorts
def sortItemset(x, abbreviations={}):
    """Render an itemset as a sorted, comma-separated string.

    Args:
        x: Iterable of item strings.
        abbreviations: Mapping of substring -> replacement, applied in
            iteration order to the joined string. Not modified.

    Returns:
        The items sorted, joined with ", ", with abbreviations applied.
    """
    label = ", ".join(sorted(x))
    for full_name, short_name in abbreviations.items():
        label = label.replace(full_name, short_name)
    return label

def attributes_in_itemset(itemset, attributes, alls = True):
    """Tell whether an itemset's attributes match a list of attributes.

    Each item is an "attribute=value" string; only the part before the first
    "=" is compared.

    Args:
        itemset (frozenset): the itemset to inspect
        attributes (list): attribute names of interest
        alls (bool): If True, every attribute of the itemset must be in
            `attributes`. If False, at least one must be.

    Returns:
        bool: False for the empty itemset when `attributes` is non-empty
        (the empty itemset describes the whole dataset); otherwise the
        all/any membership result described above.
    """
    if itemset == frozenset() and attributes:
        return False

    # Lazily yield the attribute name of each item so all()/any() can stop
    # at the first decisive item.
    item_attributes = (item.split("=")[0] for item in itemset)
    if alls:
        return all(name in attributes for name in item_attributes)
    return any(name in attributes for name in item_attributes)
    
def filter_itemset_df_by_attributes(df: pd.DataFrame, attributes: list, alls = True, itemset_col_name: str = "itemsets") -> pd.DataFrame:
    """Keep only the rows whose itemset matches the given attributes.

    Args:
        df (pd.DataFrame): itemsets and their associated info.
        attributes (list): attribute names of interest.
        alls (bool): If True, keep itemsets whose attributes are ALL in
            `attributes`. If False, keep itemsets with AT LEAST one attribute
            in `attributes`.
        itemset_col_name (str): name of the column holding the itemsets,
            "itemsets" by default.

    Returns:
        pd.DataFrame: the rows of `df` selected by attributes_in_itemset.
    """

    def _is_selected(itemset):
        return attributes_in_itemset(itemset, attributes, alls = alls)

    keep_mask = df[itemset_col_name].apply(_is_selected)
    return df.loc[keep_mask]

In [None]:
# Alternative (disabled) configuration: use the binary 'prediction' column as target.
# ## Target for DivExplorer: 
# # 'prediction' is 1 if predicted_intent == original_intent, 0 otherwise
# target_col = 'prediction' 
# target_metric = 'd_posr'
# target_div = 'd_accuracy'
# t_value_col = 't_value_tp_fn'

## Target for DivExplorer: 'WER'
# Active configuration: the per-utterance 'WER' column is the target.
target_col = 'WER' 
target_metric = 'd_outcome'
# Name given to the divergence column after renaming: 'd_WER'.
target_div = f'd_{target_col}'
t_value_col = 't_value_outcome'
printable_columns = ['support', 'itemsets','WER', 'd_WER', 't_value']

In [None]:
# Alternative (disabled) column setup matching the 'prediction' target.
# ## Columns for visualization
# show_cols = [
#        'support', 
#        'itemsets', 
#        '#errors', 
#        '#corrects', 
#        'accuracy', 
#        'd_accuracy', 
#        't_value', 
#        'support_count', 
#        'length']
# remapped_cols = {
#        'tn': '#errors', 
#        'tp': '#corrects', 
#        'posr': 'accuracy', 
#        target_metric: target_div, 
#        't_value_tp_fn': 't_value'
#        }

## Columns for visualization
# Rename map applied to the frequent-pattern frame before selecting show_cols.
remapped_cols = { 
       "outcome": target_col, 
       "d_outcome": target_div, 
       t_value_col: 't_value'}
# Columns kept (and their order) after the rename.
show_cols = [
       'support', 
       'itemsets', 
       target_col, 
       target_div, 
       'support_count', 
       'length', 
       't_value'
       ]

## Columns of the df file that we are going to analyze 
demo_cols = ['gender']

# Columns whose discretized values are relabelled to low/medium/high below.
signal_cols = ['total_silence', 'total_duration', 'n_words', 'speed_rate_word']

# Attributes passed to the discretization step.
input_cols = demo_cols + signal_cols 

## Train

In [None]:
## Discretize the dataframe
from util_discretization import discretize

# Discretize the input attributes of the train frame (3 quantile bins
# requested); the target column is carried along.
df_discretized = discretize(
    df_train[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
# The label is chosen from the textual form of each discretized value:
# "<=..." -> low, ">..." -> high, "(...]" -> medium. Anything else is an error.
replace_values = {}

for col in signal_cols:

    for v in df_discretized[col].unique():
        if "<=" == v[0:2]:
            replace_values[v] = "low"
        elif ">" == v[0]:
            replace_values[v] = "high"
        elif "("  == v[0] and "]"  == v[-1]:
            replace_values[v] = "medium"
        else:
            raise ValueError(v)

    # Assign the result back to the column: an in-place replace on the
    # selected column (chained assignment) is not guaranteed to modify the
    # frame itself.
    df_discretized[col] = df_discretized[col].replace(replace_values)

## Create dict of Divergence df
# fp_diver = FP_DivergenceExplorer(df_discretized, true_class_name=target_col, class_map={"P":1, "N":0})
fp_diver = FP_DivergenceExplorer(df_discretized, target_name=target_col)
FP_fm = fp_diver.getFrequentPatternDivergence(min_support=MIN_SUP, metrics=[target_metric])
# Rename DivExplorer's generic column names and keep only the columns shown.
FP_fm.rename(columns=remapped_cols, inplace=True)
FP_fm = FP_fm[show_cols].copy()
# FP_fm['accuracy'] = round(FP_fm['accuracy'], 5)
# FP_fm['d_accuracy'] = round(FP_fm['d_accuracy'], 5)
FP_fm['WER'] = round(FP_fm['WER'], 5)
FP_fm['d_WER'] = round(FP_fm['d_WER'], 5)
FP_fm['t_value'] = round(FP_fm['t_value'], 2)
fp_divergence = FP_Divergence(FP_fm, target_div)

In [None]:
## Compute the divergence for Wav2Vec2-Base
# The result of getDivergence is reversed before taking the head.
# NOTE(review): presumably this puts the most divergent itemsets first --
# confirm against getDivergence's sort order.
FPdiv = fp_divergence.getDivergence(th_redundancy=TH_REDUNDANCY)[::-1] 

## Retrieve Most Divergent Itemsets 
from copy import deepcopy
pr = FPdiv.head(NUM_SUBGROUPS).copy()
pr["support"] = pr["support"].round(2)
# Express WER and its divergence as percentages for display.
pr["WER"] = (pr["WER"]*100).round(3)
pr["d_WER"] = (pr["d_WER"]*100).round(3)
# pr["#errors"] = pr["#errors"].astype(int)
# pr["#corrects"] = pr["#corrects"].astype(int)
# pr["accuracy"] = (pr["accuracy"]*100).round(3)
# pr["d_accuracy"] = (pr["d_accuracy"]*100).round(3)
display(pr)

## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
df_discretized["subgID"] = 0
# One list of "attribute=value" items per retained itemset (at most
# NUM_SUBGROUPS, fewer if fewer itemsets were found).
itemsets = [list(itemset) for itemset in pr.itemsets.values]
# Vectorized equivalent of the former row-by-row loop: itemsets are visited
# in rank order and only rows still labelled 0 are assigned, so each row
# gets the id of the first itemset it matches.
for value, itemset in enumerate(itemsets):
    in_subgroup = pd.Series(True, index=df_discretized.index)
    for item in itemset:
        k, v = item.split("=", 1)
        in_subgroup &= df_discretized[k] == v
    unassigned = df_discretized["subgID"] == 0
    df_discretized.loc[in_subgroup & unassigned, "subgID"] = value+1

# Number of samples per subgroup id (0 = in no retained subgroup).
for i in range(0,NUM_SUBGROUPS+1):
    print(len(df_discretized.loc[df_discretized["subgID"]==i]))
    
df_discretized.to_csv("df_discretized.csv", index=False)

## Valid

In [None]:
## Discretize the dataframe
from util_discretization import discretize

# Discretize the input attributes of the validation frame (3 quantile bins
# requested); the target column is carried along.
# NOTE(review): discretize is called on this split alone, so the bin edges
# are presumably derived from validation data rather than reused from the
# train split -- confirm this is intended.
df_discretized_valid = discretize(
    df_valid[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
# The label is chosen from the textual form of each discretized value:
# "<=..." -> low, ">..." -> high, "(...]" -> medium. Anything else is an error.
replace_values = {}

for col in signal_cols:

    for v in df_discretized_valid[col].unique():
        if "<=" == v[0:2]:
            replace_values[v] = "low"
        elif ">" == v[0]:
            replace_values[v] = "high"
        elif "("  == v[0] and "]"  == v[-1]:
            replace_values[v] = "medium"
        else:
            raise ValueError(v)

    # Assign the result back to the column: an in-place replace on the
    # selected column (chained assignment) is not guaranteed to modify the
    # frame itself.
    df_discretized_valid[col] = df_discretized_valid[col].replace(replace_values)

In [None]:
## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
df_discretized_valid["subgID"] = 0
# Vectorized equivalent of the former row-by-row loop: the itemsets found on
# the train split are visited in rank order and only rows still labelled 0
# are assigned, so each row gets the id of the first itemset it matches.
for value, itemset in enumerate(itemsets):
    in_subgroup = pd.Series(True, index=df_discretized_valid.index)
    for item in itemset:
        k, v = item.split("=", 1)
        in_subgroup &= df_discretized_valid[k] == v
    unassigned = df_discretized_valid["subgID"] == 0
    df_discretized_valid.loc[in_subgroup & unassigned, "subgID"] = value+1

# Number of samples per subgroup id (0 = in no retained subgroup).
for i in range(0,NUM_SUBGROUPS+1):
    print(len(df_discretized_valid.loc[df_discretized_valid["subgID"]==i]))

df_discretized_valid.to_csv("df_discretized_valid.csv", index=False)

## Test

In [None]:
## Discretize the dataframe
from util_discretization import discretize

# Discretize the input attributes of the test frame (3 quantile bins
# requested); the target column is carried along.
# NOTE(review): discretize is called on this split alone, so the bin edges
# are presumably derived from test data rather than reused from the train
# split -- confirm this is intended.
df_discretized_test = discretize(
    df_test[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
# The label is chosen from the textual form of each discretized value:
# "<=..." -> low, ">..." -> high, "(...]" -> medium. Anything else is an error.
replace_values = {}

for col in signal_cols:

    for v in df_discretized_test[col].unique():
        if "<=" == v[0:2]:
            replace_values[v] = "low"
        elif ">" == v[0]:
            replace_values[v] = "high"
        elif "("  == v[0] and "]"  == v[-1]:
            replace_values[v] = "medium"
        else:
            raise ValueError(v)

    # Assign the result back to the column: an in-place replace on the
    # selected column (chained assignment) is not guaranteed to modify the
    # frame itself.
    df_discretized_test[col] = df_discretized_test[col].replace(replace_values)

In [None]:
## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
df_discretized_test["subgID"] = 0
# Vectorized equivalent of the former row-by-row loop: the itemsets found on
# the train split are visited in rank order and only rows still labelled 0
# are assigned, so each row gets the id of the first itemset it matches.
for value, itemset in enumerate(itemsets):
    in_subgroup = pd.Series(True, index=df_discretized_test.index)
    for item in itemset:
        k, v = item.split("=", 1)
        in_subgroup &= df_discretized_test[k] == v
    unassigned = df_discretized_test["subgID"] == 0
    df_discretized_test.loc[in_subgroup & unassigned, "subgID"] = value+1

# Number of samples per subgroup id (0 = in no retained subgroup).
for i in range(0,NUM_SUBGROUPS+1):
    print(len(df_discretized_test.loc[df_discretized_test["subgID"]==i]))
    
df_discretized_test.to_csv("df_discretized_test.csv", index=False)

# Full Pipeline - Steps 1 & 3

In [None]:
# Metadata columns appended to the confidence-model features.
# df_cm covers train followed by valid rows (from df_trainvalid);
# df_cm_test covers the test rows.
df_cm = df_trainvalid[[
    'total_silence', 'n_words', 'speed_rate_word'
    ]]
df_cm_test = df_test[[
    'total_silence', 'n_words', 'speed_rate_word'
    ]]

## Pretraining the CM

In [None]:
# Build the feature matrix and binary targets for each split. Along dim=1
# the features are concatenated in this order: logits, logit sequence length,
# averaged hidden states (further averaged over dim=1), total_silence,
# n_words, speed_rate_word.
# NOTE(review): torch.tensor(...) is applied to Python lists of per-utterance
# arrays/tensors; this presumably requires every entry to have the same shape
# -- confirm against the saved features.
X_train = torch.tensor(avg_hidden_states_train).squeeze()
X_train = torch.mean(X_train, dim=1)
X_train = torch.cat((
    torch.tensor(logits_concatenation_train),
    torch.tensor(sequence_lengths_train).unsqueeze(dim=1),
    X_train,
    # df_cm holds train rows first, so the first len(dataset_train) rows are train.
    torch.tensor(df_cm['total_silence'])[:len(dataset_train)].unsqueeze(1),
    torch.tensor(df_cm['n_words'])[:len(dataset_train)].unsqueeze(1),
    torch.tensor(df_cm['speed_rate_word'])[:len(dataset_train)].unsqueeze(1),
    ), dim=1)
y_train = torch.tensor(prediction_train).unsqueeze(1)

X_val = torch.tensor(avg_hidden_states_valid).squeeze()
X_val = torch.mean(X_val, dim=1)
X_val = torch.cat((
    torch.tensor(logits_concatenation_valid),
    torch.tensor(sequence_lengths_valid).unsqueeze(dim=1),
    X_val,
    # The remaining rows of df_cm (after the train rows) are the valid rows.
    torch.tensor(df_cm['total_silence'])[len(dataset_train):].unsqueeze(1),
    torch.tensor(df_cm['n_words'])[len(dataset_train):].unsqueeze(1),
    torch.tensor(df_cm['speed_rate_word'])[len(dataset_train):].unsqueeze(1),
    ), dim=1)
y_val = torch.tensor(prediction_valid).unsqueeze(1)

X_test = torch.tensor(avg_hidden_states_test).squeeze()
X_test = torch.mean(X_test, dim=1)
X_test = torch.cat((
    torch.tensor(logits_concatenation_test),
    torch.tensor(sequence_lengths_test).unsqueeze(dim=1),
    X_test,
    torch.tensor(df_cm_test['total_silence']).unsqueeze(1),
    torch.tensor(df_cm_test['n_words']).unsqueeze(1),
    torch.tensor(df_cm_test['speed_rate_word']).unsqueeze(1),
    ), dim=1)
y_test = torch.tensor(prediction_test).unsqueeze(1)

In [None]:
# Re-seed so pretraining is repeatable regardless of earlier cells.
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# Placeholders; overwritten with the test-time results when PRETRAIN is set.
best_auc = 0
best_acc = 0
best_output = 0
best_model = 0

if PRETRAIN:
    ## Create model
    model = ConfidenceModel(
        input_size=X_train.shape[1],
        hidden_size=HIDDEN_SIZE, 
        output_size=1
        ).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.NAdam(model.parameters(), lr=0.005) 

    ## Train model
    train_losses = []
    val_losses = []
    val_aucs = []
    # Previously undefined, which raised NameError after training.
    test_losses = []
    test_aucs = []

    for epoch in range(EPOCHS):
        
        ## Train in batches
        for i in range(0, len(X_train), BATCH_SIZE):
            train_output, train_loss = train(
                model, 
                X_train[i:i+BATCH_SIZE].float().to(device), 
                y_train[i:i+BATCH_SIZE].float().to(device), 
                criterion, 
                optimizer
                )
        # Only the loss of the last batch of the epoch is recorded.
        train_losses.append(train_loss)
            
        val_scores, val_loss = val(
            model, 
            X_val.float().to(device), 
            y_val.float().to(device),
            criterion
            )
        val_losses.append(val_loss)
        # Accuracy uses the 0.5-thresholded decisions; ROC-AUC is a ranking
        # metric and is computed from the continuous scores.
        val_output = (val_scores > 0.5).float()
        val_acc = accuracy_score(y_val, val_output.cpu().detach().numpy())
        val_auc = roc_auc_score(y_val, val_scores.cpu().detach().numpy())
        val_aucs.append(val_auc)     

        # Early stopping: validation loss rose in two consecutive epochs.
        # Needs three recorded losses (the unguarded lookup raised IndexError
        # during the first epochs).
        if len(val_losses) >= 3 and val_losses[-1] > val_losses[-2] > val_losses[-3]:
            break

    ## Test accuracy and AUC
    test_scores, test_loss = test(
        model, 
        X_test.float().to(device), 
        y_test.float().to(device), 
        criterion
        )
    test_output = (test_scores > 0.5).float()
    test_acc = accuracy_score(y_test, test_output.cpu().detach().numpy())
    test_auc = roc_auc_score(y_test, test_scores.cpu().detach().numpy())
    test_losses.append(test_loss)   
    test_aucs.append(test_auc)         
            
    # The model at the stopping epoch is kept; no per-epoch best is tracked.
    best_auc = test_auc
    best_acc = test_acc
    best_output = test_output
    best_model = model

    ## Confusion matrix
    from sklearn.metrics import confusion_matrix
    print("Test accuracy: ", round(best_acc*100, 2), "%")
    print("Test AUC: ", round(best_auc, 2))

## Problematic Subgroups Prediction

In [None]:
## Create train, val, test split
# Subgroup id targets (0 = none, 1..NUM_SUBGROUPS) taken from the "subgID"
# column assigned in the DivExplorer section.
y_train_subs = torch.tensor(df_discretized['subgID'])
y_val_subs = torch.tensor(df_discretized_valid['subgID'])
y_test_subs = torch.tensor(df_discretized_test['subgID'])

In [None]:
## Train, valid and test
def train(
    model, 
    inputs, 
    labels, 
    criterion, 
    optimizer, 
    criterion1=None,
    labels1=None,
    alpha=0.5,
    criterion2=None,
    labels2=None,
    ):
    """Perform one optimisation step with up to two auxiliary losses.

    The base loss is `criterion(outputs, labels)`. For each auxiliary
    criterion that is given (criterion1, then criterion2) the running loss
    becomes `alpha * loss + (1 - alpha) * aux`, where the auxiliary loss uses
    its own labels (moved to the global `device`) when provided and `labels`
    otherwise.

    Returns:
        (outputs, loss) where loss is a Python float.
    """
    model.train()
    optimizer.zero_grad()
    predictions = model(inputs.float())
    total_loss = criterion(predictions, labels)

    if criterion1 is not None:
        targets1 = labels if labels1 is None else labels1.to(device)
        aux_loss1 = criterion1(predictions, targets1)
        total_loss = alpha * total_loss + (1-alpha) * aux_loss1

    if criterion2 is not None:
        if labels2 is not None:
            aux_loss2 = criterion2(predictions, labels2.to(device))
        else:
            # Only this branch casts the auxiliary loss to float.
            aux_loss2 = criterion2(predictions, labels).float()
        total_loss = alpha * total_loss + (1-alpha) * aux_loss2

    total_loss.backward()
    optimizer.step()
    return predictions, total_loss.item()

def val(
    model, 
    inputs, 
    labels, 
    criterion, 
    criterion1=None, 
    labels1=None,
    alpha=0.5,
    criterion2=None,
    labels2=None,
    ):
    """Evaluate on a validation batch with up to two auxiliary losses.

    The base loss is `criterion(outputs, labels)`. For each auxiliary
    criterion that is given (criterion1, then criterion2) the running loss
    becomes `alpha * loss + (1 - alpha) * aux`, where the auxiliary loss uses
    its own labels (moved to the global `device`) when provided and `labels`
    otherwise. Everything runs under `torch.no_grad()` so no autograd graph
    is built for evaluation data.

    Returns:
        (outputs, loss) where loss is a Python float and outputs carries no
        gradient history.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        if criterion1 is not None:
            if labels1 is None:
                loss1 = criterion1(outputs, labels)
            else:
                loss1 = criterion1(outputs, labels1.to(device))
            loss = alpha * loss + (1-alpha) * loss1
        if criterion2 is not None:
            if labels2 is None:
                loss2 = criterion2(outputs, labels)
            else:
                loss2 = criterion2(outputs, labels2.to(device))
            loss = alpha * loss + (1-alpha) * loss2
    return outputs, loss.item()

def test(
    model, 
    inputs, 
    labels=None, 
    criterion=None, 
    criterion1=None, 
    labels1=None,
    alpha=0.5,
    criterion2=None,
    labels2=None,
    ):
    model.eval()
    if labels is None and criterion is None:
        outputs = model(inputs.float())
        return outputs
    else:
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        if criterion1 is not None:
            if labels1 is None:
                loss1 = criterion1(outputs, labels)
            else:
                loss1 = criterion1(outputs, labels1.to(device))
            loss = alpha * loss + (1-alpha) * loss1
        if criterion2 is not None:
            if labels2 is None:
                loss2 = criterion2(outputs, labels)
            else:
                loss2 = criterion2(outputs, labels2.to(device))
            loss = alpha * loss + (1-alpha) * loss2
        return outputs, loss.item()

In [None]:
## Subgroup classifier: choose the loss, train with early stopping, evaluate on the test split
criterion_loss = 'CE+MSE'
ALPHA = 0.6

# Rebuild the targets from the dataframes so that re-running this cell always
# starts from the raw subgroup ids (some branches below reshape / cast them).
y_train_subs = torch.tensor(df_discretized['subgID'])
y_val_subs = torch.tensor(df_discretized_valid['subgID'])
y_test_subs = torch.tensor(df_discretized_test['subgID'])

# CPU copy of the test ids used only for the sklearn metrics, so that the
# metrics do not depend on the device / shape / dtype chosen for the loss.
y_test_true = y_test_subs.numpy()

torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

best_f1macro = 0
best_acc = 0
best_output = 0
best_model_step3 = None

if PRETRAIN:
    # NOTE: this swaps the output layer of `best_model` in place; `model` is
    # the same object, not a copy.
    best_model.linear3 = nn.Linear(
        HIDDEN_SIZE,
        NUM_SUBGROUPS+1
        ).to(device)
    model = best_model
else:
    model = ConfidenceModel(
        input_size=X_train.shape[1],
        hidden_size=HIDDEN_SIZE,
        output_size=NUM_SUBGROUPS+1
        ).to(device)

## Criterion and optimizer
# Defaults: no auxiliary losses and no dedicated auxiliary targets.
criterion1 = None
criterion2 = None
y_train_subs_1 = y_val_subs_1 = y_test_subs_1 = None
y_train_subs_2 = y_val_subs_2 = y_test_subs_2 = None

if criterion_loss in ('MSE', 'L1', 'MLML', 'MSE+L1', 'MSE+MLML'):
    # Regression-style main loss: targets become float column vectors.
    if criterion_loss == 'L1':
        criterion = nn.L1Loss()
    elif criterion_loss == 'MLML':
        criterion = nn.MultiLabelSoftMarginLoss()
    else:
        criterion = nn.MSELoss()
    if criterion_loss == 'MSE+L1':
        criterion1 = nn.L1Loss()
    elif criterion_loss == 'MSE+MLML':
        criterion1 = nn.MultiLabelSoftMarginLoss()
    X_train = X_train.float()
    X_val = X_val.float()
    X_test = X_test.float()
    y_train_subs = y_train_subs.unsqueeze(dim=1).float()
    y_val_subs = y_val_subs.unsqueeze(dim=1).float()
    y_test_subs = y_test_subs.unsqueeze(dim=1).float()
elif criterion_loss == 'CE':
    criterion = nn.CrossEntropyLoss()
elif criterion_loss in ('CE+MLML', 'CE+MSE', 'CE+L1'):
    criterion = nn.CrossEntropyLoss()
    if criterion_loss == 'CE+MLML':
        criterion1 = nn.MultiLabelSoftMarginLoss()
    elif criterion_loss == 'CE+MSE':
        criterion1 = nn.MSELoss()
    else:
        criterion1 = nn.L1Loss()
    y_train_subs_1 = y_train_subs.unsqueeze(dim=1).float().to(device)
    y_val_subs_1 = y_val_subs.unsqueeze(dim=1).float().to(device)
    y_test_subs_1 = y_test_subs.unsqueeze(dim=1).float().to(device)
elif criterion_loss in ('CE+MSE+L1', 'CE+MSE+MLML'):
    criterion = nn.CrossEntropyLoss()
    criterion1 = nn.MSELoss()
    criterion2 = nn.L1Loss() if criterion_loss == 'CE+MSE+L1' else nn.MultiLabelSoftMarginLoss()
    y_train_subs_1 = y_train_subs.unsqueeze(dim=1).float().to(device)
    y_val_subs_1 = y_val_subs.unsqueeze(dim=1).float().to(device)
    y_test_subs_1 = y_test_subs.unsqueeze(dim=1).float().to(device)
    y_train_subs_2 = y_train_subs.unsqueeze(dim=1).float().to(device)
    y_val_subs_2 = y_val_subs.unsqueeze(dim=1).float().to(device)
    y_test_subs_2 = y_test_subs.unsqueeze(dim=1).float().to(device)
elif criterion_loss == 'MLML+MSE+CE':
    criterion = nn.MultiLabelSoftMarginLoss()
    criterion1 = nn.MSELoss()
    # Must be an instance: the bare class would be *constructed* when called.
    criterion2 = nn.CrossEntropyLoss()
    # Cross-entropy targets are the untouched class ids, so take them before
    # the main targets are reshaped and cast to float below.
    y_train_subs_2 = y_train_subs.to(device)
    y_val_subs_2 = y_val_subs.to(device)
    y_test_subs_2 = y_test_subs.to(device)
    y_train_subs = y_train_subs.unsqueeze(dim=1).float().to(device)
    y_val_subs = y_val_subs.unsqueeze(dim=1).float().to(device)
    y_test_subs = y_test_subs.unsqueeze(dim=1).float().to(device)
    # The MSE targets are the same float column vectors (unsqueezed once only).
    y_train_subs_1 = y_train_subs
    y_val_subs_1 = y_val_subs
    y_test_subs_1 = y_test_subs
else:
    # `break` is not allowed outside a loop; fail loudly instead.
    raise ValueError(f"Error: criterion_loss not valid: {criterion_loss!r}")

optimizer = optim.NAdam(model.parameters(), lr=0.005)

## Train and validate model
train_losses = []
val_losses = []
for epoch in tqdm(range(EPOCHS)):

    train_output, train_loss = train(
        model,
        X_train.to(device),
        y_train_subs.to(device),
        criterion,
        optimizer,
        criterion1,
        y_train_subs_1,
        alpha=ALPHA,
        criterion2=criterion2,
        labels2=y_train_subs_2
        )

    val_output, val_loss = val(
        model,
        X_val.to(device),
        y_val_subs.to(device),
        criterion,
        criterion1,
        y_val_subs_1,
        alpha=ALPHA,
        criterion2=criterion2,
        labels2=y_val_subs_2
        )

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Early stopping: stop once the current validation loss is no better than
    # either of the two previous epochs (needs at least three recorded values).
    if len(val_losses) >= 3 and val_loss >= val_losses[-2] and val_loss >= val_losses[-3]:
        break

test_output, test_loss = test(
    model,
    X_test.to(device),
    y_test_subs.to(device),
    criterion,
    criterion1,
    y_test_subs_1,
    alpha=ALPHA,
    criterion2=criterion2,
    labels2=y_test_subs_2
    )
test_output = test_output.cpu().detach().numpy()
# Predicted class = index of the largest output.
test_output = np.argmax(test_output, axis=1)
test_acc = accuracy_score(y_test_true, test_output)
test_f1 = f1_score(y_test_true, test_output, average='macro')

best_f1macro = test_f1
best_acc = test_acc
best_output = test_output
best_model_step3 = model

print("Test Accuracy: ", best_acc)
# NOTE: labelled "EER" but the printed value is the error rate, 1 - accuracy.
print("Test EER: ", 1-best_acc)
print("Test F1 Macro: ", best_f1macro)
print("--------------------\n")

# Random baselines

In [None]:
## Random baseline: every test sample gets a class id drawn at random from [0, NUM_SUBGROUPS]

SEED = 1
np.random.seed(SEED)

random_pred = np.random.randint(low=0, high=NUM_SUBGROUPS+1, size=len(y_test_subs))
random_acc = accuracy_score(y_test_subs, random_pred)
random_f1 = f1_score(y_test_subs, random_pred, average='macro')
print("Test Accuracy: ", random_acc)
print("F1 Macro: ", random_f1)

## K = 2
# Test Accuracy:    0.3316793893129771
# F1 Macro:         0.3248792544071231

## K = 3
# Test Accuracy:    0.24923664122137404
# F1 Macro:         0.2250682740747604

## K = 4
# Test Accuracy:    0.21259541984732824
# F1 Macro:         0.16816526324538308

## K = 5    
# Test Accuracy:    0.17022900763358778
# F1 Macro:         0.13941870409296506

In [None]:
## Majority-class baseline: every test sample is assigned the most frequent test class

SEED = 1
np.random.seed(SEED)

# number of test samples carrying each class id in [0, NUM_SUBGROUPS]
counts = np.array(
    [len(y_test_subs[y_test_subs == class_id]) for class_id in range(NUM_SUBGROUPS+1)],
    dtype=float,
)

# constant prediction vector holding the id of the largest class
majority_class = np.argmax(counts)
most_frequent_pred = np.full(len(y_test_subs), majority_class, dtype=float)
majority_acc = accuracy_score(y_test_subs, most_frequent_pred)
majority_f1 = f1_score(y_test_subs, most_frequent_pred, average='macro')
print("Test Accuracy: ", majority_acc)
print("F1 Macro: ", majority_f1)

## K = 2
# Test Accuracy:    0.44236641221374046
# F1 Macro:         0.20446326188586048

## K = 3
# Test Accuracy:    0.40419847328244274
# F1 Macro:         0.14392497961402556

## K = 4
# Test Accuracy:    0.40419847328244274
# F1 Macro:         0.14392497961402556

## K = 5
# Test Accuracy:    0.33053435114503815
# F1 Macro:         0.09936890418818131