# Configuring the Environment

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import librosa
from datasets import load_dataset, Dataset
from transformers import Wav2Vec2ForSequenceClassification, Wav2Vec2FeatureExtractor

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, roc_auc_score
from sklearn.model_selection import train_test_split
import numpy as np 
import pandas as pd
import os
from tqdm import tqdm
import json

import warnings
warnings.filterwarnings('ignore')

In [None]:
import divexplorer 
import pandas as pd
pd.set_option('max_colwidth', None)
import os
import numpy as np

from utils_analysis import filter_itemset_df_by_attributes, slice_by_itemset

from divexplorer.FP_DivergenceExplorer import FP_DivergenceExplorer
from divexplorer.FP_Divergence import FP_Divergence

In [None]:
## Set device
# Use the first CUDA GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Fix the PyTorch (CPU and CUDA) and NumPy random seeds for this session.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)

# ITALIC Dataset - Inference and Evaluation

## Utils

In [None]:
def map_to_array(example, audio_col = 'path'):
    """Decode the audio file of one example and store it under ``"speech"``.

    Args:
        example: mapping holding the audio file path under ``audio_col``.
        audio_col: key of the entry that contains the audio path.

    Returns:
        The same ``example`` object, updated in place with the waveform.
    """
    # Load as mono at 16 kHz; the sample rate returned by librosa is not needed.
    loaded = librosa.load(example[audio_col], sr=16000, mono=True)
    example["speech"] = loaded[0]
    return example

In [None]:
def preprocess_function(examples, max_duration=10):
    """Turn raw audio into model inputs with the global ``feature_extractor``.

    Every input is padded or truncated to ``max_duration`` seconds at the
    extractor's own sampling rate, and PyTorch tensors are returned.

    Args:
        examples: raw waveform(s) accepted by the feature extractor.
        max_duration: target length in seconds.

    Returns:
        Whatever the feature extractor returns for these arguments.
    """
    target_rate = feature_extractor.sampling_rate
    return feature_extractor(
        examples,
        sampling_rate=target_rate,
        padding='max_length',
        max_length=int(target_rate * max_duration),
        truncation=True,
        return_tensors="pt",
    )

## Model

In [None]:
# True: use the Hugging Face hub checkpoint and the "_hf" output files;
# False: use the local "italic_original" checkpoint.
HF = False

In [None]:
## Load model
if HF:
    print("Loading model from HF")
    # Hidden states are requested here as well: the inference cells read
    # `outputs.hidden_states` regardless of which checkpoint is loaded.
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        "superb/wav2vec2-base-superb-ic",
        output_hidden_states=True
        ).to(device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")

    # The accuracy cells need `label2id`; take both label maps from the loaded config.
    # NOTE(review): unlike the JSON-decoded maps of the other branch, the keys of
    # `id2label` are presumably ints here - confirm before indexing it.
    label2id = model_w2v2.config.label2id
    id2label = model_w2v2.config.id2label
else:
    print("Loading model from local directory")
    model_checkpoint = "italic_original"
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_checkpoint,
        output_hidden_states=True,
        local_files_only=True
        ).to(device)
    # NOTE(review): the feature extractor is loaded from the hub model
    # "facebook/wav2vec2-xls-r-300m", not from the local checkpoint directory -
    # presumably the base model of the checkpoint; confirm.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

    # Label <-> id maps are read straight from the checkpoint's config file.
    with open(os.path.join(model_checkpoint, "config.json"), "r") as f:
        config = json.load(f)
    label2id = config["label2id"]
    id2label = config["id2label"]

## Train

In [None]:
## Load and preprocess dataset
# Metadata table of the training split (file name suggests the 80% portion).
df = pd.read_csv('data/italic/train_data_80.csv')

# Wrap the dataframe as a Dataset and add a "speech" entry (decoded waveform)
# to every row, reading the audio file referenced by the 'path' column.
dataset = Dataset.from_pandas(df)
dataset = dataset.map(lambda x: map_to_array(x, audio_col='path'))

In [None]:
## Inference
# One entry per utterance: a pooled hidden-state vector and the classifier logits.
hidden_states_concatenation = []
logits_concatenation = []

with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        inputs = preprocess_function(dataset[i]["speech"]).to(device)
        outputs = model_w2v2(**inputs)

        # Every hidden state is squeezed, averaged along axis 1, and the
        # per-layer vectors are then averaged across layers.
        # NOTE(review): assuming each hidden state is (1, frames, dim), axis=1
        # averages over the feature dimension and leaves one value per frame -
        # confirm that pooling over time (axis=0) was not intended.
        hidden_states = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(hs, axis=1) for hs in hidden_states ]
        avg_hs = np.mean(avg, axis=0)

        hidden_states_concatenation.append(avg_hs)
        # Store the logits on the CPU so the list does not pin device memory and
        # the file written by torch.save can be loaded on a machine without a GPU.
        logits_concatenation.append(outputs.logits.detach().cpu())

In [None]:
## Save hidden states and logits
# Both Python lists are serialized with torch.save; the "_hf" suffix keeps the
# outputs of the two checkpoints in separate files.
if HF == True:
    torch.save(hidden_states_concatenation, 'pretrained/italic/hidden_states_train_hf.pt')
    torch.save(logits_concatenation, 'pretrained/italic/logits_train_hf.pt')
else:
    torch.save(hidden_states_concatenation, 'pretrained/italic/hidden_states_train.pt')
    torch.save(logits_concatenation, 'pretrained/italic/logits_train.pt')

In [None]:
## Load hidden states and logits
# Replaces the in-memory lists with the copies stored on disk for the selected checkpoint.
if HF == True:
    hidden_states_concatenation = torch.load('pretrained/italic/hidden_states_train_hf.pt')
    logits_concatenation = torch.load('pretrained/italic/logits_train_hf.pt')
else:
    hidden_states_concatenation = torch.load('pretrained/italic/hidden_states_train.pt')
    logits_concatenation = torch.load('pretrained/italic/logits_train.pt')

### Intent Accuracy

In [None]:
# Predicted intent id of every utterance: argmax over the classifier logits.
intents_predicted_train = [
    torch.argmax(logits.detach().cpu(), dim=-1).item()
    for logits in tqdm(logits_concatenation)
]

# Ground-truth ids are read from the dataframe column: `dataset` was built from
# `df` in the same row order, and iterating it would materialise every stored
# waveform just to fetch one label.
intents_train_gt = [ label2id[intent] for intent in df['intent'] ]

# 1 when the predicted id equals the ground-truth id, 0 otherwise.
is_correct_train = [
    int(predicted == int(expected))
    for predicted, expected in zip(intents_predicted_train, intents_train_gt)
]
df['prediction'] = is_correct_train

# Store the pooled hidden-state vector of each utterance as a float array.
df['hidden_states'] = [hs.squeeze() for hs in hidden_states_concatenation]
df['hidden_states'] = df['hidden_states'].apply(lambda x: x.astype(float))

In [None]:
# Keep the predicted intent id of every row alongside the 0/1 correctness flag.
df['predicted_intent'] = intents_predicted_train

In [None]:
# Destination CSV of the annotated training dataframe; the name depends on the checkpoint in use.
output_folder = 'data/italic/italic_train_80_hf.csv' if HF == True else 'data/italic/italic_train_80.csv'
df.to_csv(output_folder, index=False)

## Validation

In [None]:
## Load and preprocess dataset
# Metadata table of the validation split.
df_valid = pd.read_csv('data/italic/valid_data.csv')

# Wrap the dataframe as a Dataset and add a "speech" entry (decoded waveform)
# to every row, reading the audio file referenced by the 'path' column.
dataset_valid = Dataset.from_pandas(df_valid) 
dataset_valid = dataset_valid.map(lambda x: map_to_array(x, audio_col='path'))

In [None]:
## Inference
# One entry per utterance: a pooled hidden-state vector and the classifier logits.
hidden_states_concatenation_valid = []
logits_concatenation_valid = []

with torch.no_grad():
    for i in tqdm(range(len(dataset_valid))):
        inputs = preprocess_function(dataset_valid[i]["speech"]).to(device)
        outputs = model_w2v2(**inputs)

        # Every hidden state is squeezed, averaged along axis 1, and the
        # per-layer vectors are then averaged across layers.
        # NOTE(review): assuming each hidden state is (1, frames, dim), axis=1
        # averages over the feature dimension and leaves one value per frame -
        # confirm that pooling over time (axis=0) was not intended.
        hidden_states = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(hs, axis=1) for hs in hidden_states ]
        avg_hs = np.mean(avg, axis=0)

        hidden_states_concatenation_valid.append(avg_hs)
        # Store the logits on the CPU so the list does not pin device memory and
        # the file written by torch.save can be loaded on a machine without a GPU.
        logits_concatenation_valid.append(outputs.logits.detach().cpu())

In [None]:
## Save hidden states and logits
# Both Python lists are serialized with torch.save; the "_hf" suffix keeps the
# outputs of the two checkpoints in separate files.
if HF == True:
    torch.save(hidden_states_concatenation_valid, 'pretrained/italic/hidden_states_valid_hf.pt')
    torch.save(logits_concatenation_valid, 'pretrained/italic/logits_valid_hf.pt')
else:
    torch.save(hidden_states_concatenation_valid, 'pretrained/italic/hidden_states_valid.pt')
    torch.save(logits_concatenation_valid, 'pretrained/italic/logits_valid.pt')

In [None]:
## Load hidden states and logits
# Replaces the in-memory lists with the copies stored on disk for the selected checkpoint.
if HF == True:
    hidden_states_concatenation_valid = torch.load('pretrained/italic/hidden_states_valid_hf.pt')
    logits_concatenation_valid = torch.load('pretrained/italic/logits_valid_hf.pt')
else:
    hidden_states_concatenation_valid = torch.load('pretrained/italic/hidden_states_valid.pt')
    logits_concatenation_valid = torch.load('pretrained/italic/logits_valid.pt')

### Intent Accuracy

In [None]:
# Predicted intent id of every utterance: argmax over the classifier logits.
intents_predicted_valid = [
    torch.argmax(logits.detach().cpu(), dim=-1).item()
    for logits in tqdm(logits_concatenation_valid)
]

# Ground-truth ids are read from the dataframe column: `dataset_valid` was built
# from `df_valid` in the same row order, and iterating it would materialise
# every stored waveform just to fetch one label.
intents_valid_gt = [ label2id[intent] for intent in df_valid['intent'] ]

# 1 when the predicted id equals the ground-truth id, 0 otherwise.
is_correct_valid = [
    int(predicted == int(expected))
    for predicted, expected in zip(intents_predicted_valid, intents_valid_gt)
]
df_valid['prediction'] = is_correct_valid

# Store the pooled hidden-state vector of each utterance as a float array.
df_valid['hidden_states'] = [hs.squeeze() for hs in hidden_states_concatenation_valid]
df_valid['hidden_states'] = df_valid['hidden_states'].apply(lambda x: x.astype(float))

In [None]:
# Destination CSV of the annotated validation dataframe; the name depends on the checkpoint in use.
output_folder = 'data/italic/italic_valid_hf.csv' if HF == True else 'data/italic/italic_valid.csv'
df_valid.to_csv(output_folder, index=False)

# Loading Pretrained Features

In [None]:
# True: use the Hugging Face hub checkpoint and the "_hf" feature files;
# False: use the local "italic_original" checkpoint.
HF = False

In [None]:
## Load model
if HF:
    print("Loading model from HF")
    # Hidden states are requested here as well, so both branches configure the
    # model identically for code that reads `outputs.hidden_states`.
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        "superb/wav2vec2-base-superb-ic",
        output_hidden_states=True
        ).to(device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")

    # The accuracy cells need `label2id`; take both label maps from the loaded config.
    # NOTE(review): unlike the JSON-decoded maps of the other branch, the keys of
    # `id2label` are presumably ints here - confirm before indexing it.
    label2id = model_w2v2.config.label2id
    id2label = model_w2v2.config.id2label
else:
    print("Loading model from local directory")
    model_checkpoint = "italic_original"
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_checkpoint,
        output_hidden_states=True,
        local_files_only=True
        ).to(device)
    # NOTE(review): the feature extractor is loaded from the hub model
    # "facebook/wav2vec2-xls-r-300m", not from the local checkpoint directory -
    # presumably the base model of the checkpoint; confirm.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

    # Label <-> id maps are read straight from the checkpoint's config file.
    with open(os.path.join(model_checkpoint, "config.json"), "r") as f:
        config = json.load(f)
    label2id = config["label2id"]
    id2label = config["id2label"]

In [None]:
## Load and preprocess dataset
# Only the metadata is needed here: the audio-derived features come from disk.
df = pd.read_csv('data/italic/train_data_80.csv')
dataset = Dataset.from_pandas(df)

## Load hidden states and logits
if HF == True:
    hidden_states_concatenation_train = torch.load('pretrained/italic/hidden_states_train_hf.pt')
    logits_concatenation_train = torch.load('pretrained/italic/logits_train_hf.pt')
else:
    hidden_states_concatenation_train = torch.load('pretrained/italic/hidden_states_train.pt')
    logits_concatenation_train = torch.load('pretrained/italic/logits_train.pt')

# Predicted intent id per utterance (argmax over the stored logits).
intents_predicted_train = [
    torch.argmax(logits_concatenation_train[idx].detach().cpu(), dim=-1).item()
    for idx in tqdm(range(len(logits_concatenation_train)))
]

# Ground-truth intent ids, then a 0/1 flag telling whether the prediction matches.
intents_train_gt = [ label2id[row['intent']] for row in tqdm(dataset) ]
is_correct_train = [
    int(intents_predicted_train[idx] == int(expected))
    for idx, expected in enumerate(intents_train_gt)
]
df['prediction'] = is_correct_train

# Pooled hidden-state vector of each utterance, squeezed and cast to float.
df['hidden_states'] = [hs.squeeze() for hs in hidden_states_concatenation_train]
df['hidden_states'] = df['hidden_states'].apply(lambda x: x.astype(float))

In [None]:
## Load and preprocess dataset
# Only the metadata is needed here: the audio-derived features come from disk.
df_valid = pd.read_csv('data/italic/valid_data.csv')
dataset_valid = Dataset.from_pandas(df_valid)

## Load hidden states and logits
if HF:
    hidden_states_concatenation_valid = torch.load('pretrained/italic/hidden_states_valid_hf.pt')
    # Validation logits: this branch previously loaded the *training* logits file.
    logits_concatenation_valid = torch.load('pretrained/italic/logits_valid_hf.pt')
else:
    hidden_states_concatenation_valid = torch.load('pretrained/italic/hidden_states_valid.pt')
    logits_concatenation_valid = torch.load('pretrained/italic/logits_valid.pt')

# A file belonging to another split would otherwise be consumed silently.
if len(logits_concatenation_valid) != len(df_valid):
    raise ValueError(
        f"Loaded {len(logits_concatenation_valid)} logits for {len(df_valid)} validation rows"
    )

# Predicted intent id per utterance (argmax over the stored logits).
intents_predicted_valid = [
    torch.argmax(logits.detach().cpu(), dim=-1).item()
    for logits in tqdm(logits_concatenation_valid)
]

# Ground-truth intent ids, then a 0/1 flag telling whether the prediction matches.
intents_valid_gt = [ label2id[dv['intent']] for dv in tqdm(dataset_valid) ]
is_correct_valid = [
    int(predicted == int(expected))
    for predicted, expected in zip(intents_predicted_valid, intents_valid_gt)
]
df_valid['prediction'] = is_correct_valid

# Pooled hidden-state vector of each utterance, squeezed and cast to float.
df_valid['hidden_states'] = [hs.squeeze() for hs in hidden_states_concatenation_valid]
df_valid['hidden_states'] = df_valid['hidden_states'].apply(lambda x: x.astype(float))

# Confidence Model 

In [None]:
## Confidence model
class ConfidenceModel(nn.Module):
    """MLP with two hidden layers and a sigmoid on the output layer.

    Pipeline: linear1 -> GELU -> dropout -> LayerNorm -> linear2 -> GELU ->
    dropout -> LayerNorm -> linear3 -> sigmoid.

    The attribute named ``relu`` holds a GELU activation, and a single
    LayerNorm instance (one set of parameters) is applied after both hidden
    layers.

    Args:
        input_size: number of input features.
        hidden_size: width of both hidden layers.
        output_size: number of output units.
    """

    def __init__(self, input_size=768, hidden_size=1000, output_size=1):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size)
        self.linear3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.GELU()
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)
        self.norm = nn.LayerNorm(hidden_size)
        # Kaiming-normal weights and zero biases for the linear layers,
        # visited in the order in which they were created.
        for layer in (self.linear1, self.linear2, self.linear3):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            nn.init.zeros_(layer.bias)

    def forward(self,x):
        """Return the sigmoid-activated output of the network for ``x``."""
        for hidden_layer in (self.linear1, self.linear2):
            x = self.norm(self.dropout(self.relu(hidden_layer(x))))
        return self.sigmoid(self.linear3(x))

In [None]:
## Train, valid and test
def train(model, inputs, labels, criterion, optimizer):
    """Perform a single optimisation step on one batch.

    Args:
        model: module to optimise; switched to training mode.
        inputs: batch of features, cast to float before the forward pass.
        labels: targets handed to ``criterion`` together with the outputs.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer holding the parameters of ``model``.

    Returns:
        Tuple ``(outputs, loss)``: the model outputs for the batch and the
        loss as a Python float.
    """
    model.train()
    # Gradients left over from a previous step must not accumulate.
    optimizer.zero_grad()
    predictions = model(inputs.float())
    batch_loss = criterion(predictions, labels)
    batch_loss.backward()
    optimizer.step()
    return predictions, batch_loss.item()

def val(model, inputs, labels, criterion):
    """Evaluate the model on one batch without tracking gradients.

    Args:
        model: module to evaluate; switched to evaluation mode.
        inputs: batch of features, cast to float before the forward pass.
        labels: targets handed to ``criterion`` together with the outputs.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(outputs, loss)``: the model outputs and the loss as a Python
        float. The outputs carry no autograd history.
    """
    model.eval()
    # No backward pass follows, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
    return outputs, loss.item()

def test(model, inputs, labels=None, criterion=None):
    """Run the model in evaluation mode, optionally computing a loss.

    Args:
        model: module to evaluate; switched to evaluation mode.
        inputs: batch of features, cast to float before the forward pass.
        labels: optional targets; must be given together with ``criterion``.
        criterion: optional loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``outputs`` when neither ``labels`` nor ``criterion`` is given,
        otherwise the tuple ``(outputs, loss)`` with the loss as a Python float.

    Raises:
        ValueError: if only one of ``labels`` and ``criterion`` is provided.
    """
    if (labels is None) != (criterion is None):
        # Previously this case failed later with an unrelated error.
        raise ValueError("labels and criterion must be provided together")

    model.eval()
    # No backward pass follows, so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(inputs.float())
        if criterion is None:
            return outputs
        loss = criterion(outputs, labels)
    return outputs, loss.item()

# Problem Setup

In [None]:
HIDDEN_SIZE = 1000     # hidden width passed to ConfidenceModel
BATCH_SIZE = 4096      # mini-batch size of the confidence-model pretraining loop
NUM_SUBGROUPS = 2      # number of leading itemsets that become subgroup classes (class 0 = none)
EPOCHS = 10000         # upper bound on epochs; both training loops may stop earlier
PRETRAIN = True        # fine-tune the pretrained confidence model instead of a fresh one

# DivExplorer

## Utils

In [None]:
## Define abbreviations for plot and visualization
from divexplorer.FP_Divergence import abbreviateDict
# Substring -> shorter label, applied in insertion order when itemsets are rendered.
abbreviations = {
    'total_silence': 'tot_silence',
    'trimmed': 'trim',
    'total_': '',
    'speed_rate_word': 'speakRate',
    'speed_rate_char': 'speakCharRate',
    'duration': 'dur',
}

# Independent shallow copy of the mapping above.
abbreviations_shorter = dict(abbreviations)

## Function for sorting data cohorts
def sortItemset(x, abbreviations=None):
    """Render an itemset as a sorted, comma-separated string.

    Args:
        x: iterable of item strings (e.g. a frozenset of ``"attr=value"``).
        abbreviations: optional mapping ``substring -> replacement``; the
            replacements are applied one after another, in the mapping's
            iteration order, to the joined string.

    Returns:
        The items sorted and joined with ``", "``, after the replacements.
    """
    text = ", ".join(sorted(x))
    # `None` replaces the former mutable `{}` default.
    for old, new in (abbreviations or {}).items():
        text = text.replace(old, new)
    return text

def attributes_in_itemset(itemset, attributes, alls = True):
    """Tell whether the attributes of an itemset belong to ``attributes``.

    Args:
        itemset (frozenset): items written as ``"attribute=value"``.
        attributes (list): attribute names of interest.
        alls (bool): if True, EVERY attribute of the itemset must be in
            ``attributes``; if False, AT LEAST ONE of them must be.

    Returns:
        bool: the outcome of the check. The empty itemset (i.e. the whole
        dataset) is rejected whenever ``attributes`` is non-empty.
    """
    if itemset == frozenset() and attributes:
        return False

    # The attribute name is the part of each item before the first "=".
    names = (item.split("=")[0] for item in itemset)
    if alls:
        return all(name in attributes for name in names)
    return any(name in attributes for name in names)
    
def filter_itemset_df_by_attributes(df: pd.DataFrame, attributes: list, alls = True, itemset_col_name: str = "itemsets") -> pd.DataFrame:
    """Keep the rows whose itemset passes ``attributes_in_itemset``.

    Args:
        df (pd.DataFrame): itemsets together with their statistics.
        attributes (list): attribute names of interest.
        alls (bool): if True, every attribute of an itemset must be in
            ``attributes``; if False, at least one of them must be.
        itemset_col_name (str): name of the column holding the itemsets.

    Returns:
        pd.DataFrame: the selected rows of ``df``.
    """
    def _is_admitted(itemset):
        return attributes_in_itemset(itemset, attributes, alls = alls)

    keep_rows = df[itemset_col_name].apply(_is_admitted)
    return df.loc[keep_rows]

In [None]:
## Target for DivExplorer: 
# 'prediction' is 1 if predicted_intent == original_intent, 0 otherwise
target_col = 'prediction' 
# Metric requested from DivExplorer; it is renamed to `target_div` for display.
target_metric = 'd_posr'
target_div = 'd_accuracy'
# Name of the t-value column in DivExplorer's output (renamed to 't_value' for display).
t_value_col = 't_value_tp_fn'

In [None]:
## Columns for visualization
# Columns kept in the divergence table, named as they are after `remapped_cols` is applied.
show_cols = ['support', 'itemsets', '#errors', '#corrects', 'accuracy', \
                'd_accuracy', 't_value', 'support_count', 'length']
# DivExplorer output name -> display name.
remapped_cols = {'tn': '#errors', 'tp': '#corrects', 'posr': 'accuracy', \
                target_metric: target_div, 't_value_tp_fn': 't_value'}

## Columns of the df file that we are going to analyze 
demo_cols = ['gender', 'age', 'region', 'nationality', 'lisp', 'education']

# Derived from the intent label in the discretization cells.
slot_cols = ['action', 'scenario']

# Recording-setting columns; not included in `input_cols` below.
rec_set_cols = ['environment', 'device', 'field']

signal_cols = ['total_silence', 'total_duration', 'trimmed_duration', 
       'n_words', 'speed_rate_word', 'speed_rate_word_trimmed']   
 
# Attributes handed to the discretization and pattern-mining steps.
input_cols = demo_cols + signal_cols + slot_cols

In [None]:
# select the columns of interest
divexpl_cols = [
    'prediction',
    'intent',
    'gender', 'age', 'region', 'nationality', 'lisp', 'education', 'environment', 'device', 'field',
    'total_silence', 'total_duration', 'trimmed_duration', 'n_words', 'speed_rate_word', 'speed_rate_word_trimmed'
    ]

# Explicit copies: new columns are added to these frames later on, which must
# neither touch `df` / `df_valid` nor write into a slice of them.
df_divexpl = df[divexpl_cols].copy()
df_valid_divexpl = df_valid[divexpl_cols].copy()

## Train

In [None]:
MIN_SUP =       0.2     # min_support passed to getFrequentPatternDivergence
th_redundancy = 0.01    # th_redundancy passed to getDivergence

In [None]:
## Derive the slot attributes from the intent label
# .assign() returns a new dataframe, so nothing is written into a slice of `df`.
# NOTE(review): the token before the first "_" is stored as 'action' and the
# following one as 'scenario' - confirm this matches the dataset's intent naming.
df_divexpl = df_divexpl.assign(
    action=df_divexpl['intent'].apply(lambda x: x.split("_")[0]),
    scenario=df_divexpl['intent'].apply(lambda x: x.split("_")[1]),
)

## Discretize the dataframe
from divergence_utils import discretize

df_discretized = discretize(
    df_divexpl[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
replace_values = {}

for signal_col in signal_cols:

    # Bin labels are expected to look like "<=a", "(a, b]" or ">b".
    for v in df_discretized[signal_col].unique():
        if v.startswith("<="):
            replace_values[v] = "low"
        elif v.startswith(">"):
            replace_values[v] = "high"
        elif v.startswith("(") and v.endswith("]"):
            replace_values[v] = "medium"
        else:
            raise ValueError(v)

    # Assign the result back: replace(inplace=True) on a column selection acts
    # on a temporary object and is not guaranteed to update the dataframe.
    df_discretized[signal_col] = df_discretized[signal_col].replace(replace_values)

## Create dict of Divergence df
fp_diver = FP_DivergenceExplorer(df_discretized, true_class_name=target_col, class_map={"P":1, "N":0})
FP_fm = fp_diver.getFrequentPatternDivergence(min_support=MIN_SUP, metrics=[target_metric])
# Rename to the display names and keep only the columns listed in `show_cols`.
FP_fm.rename(columns=remapped_cols, inplace=True)
FP_fm = FP_fm[show_cols].copy()
FP_fm['accuracy'] = round(FP_fm['accuracy'], 5)
FP_fm['d_accuracy'] = round(FP_fm['d_accuracy'], 5)
FP_fm['t_value'] = round(FP_fm['t_value'], 2)
fp_divergence = FP_Divergence(FP_fm, target_div)

In [None]:
## Compute the divergence for Wav2Vec2-Base
# Redundancy-pruned itemsets, in the reverse of the row order returned by getDivergence.
# NOTE(review): presumably this puts the most negatively divergent itemsets first - confirm.
FPdiv = fp_divergence.getDivergence(th_redundancy=th_redundancy)[::-1] #0.05

## Retrieve Most Divergent Itemsets 
from copy import deepcopy
# Presentation copy: rounded support, integer counts, accuracies as percentages.
pr = FPdiv.copy()
pr["support"] = pr["support"].round(2)
pr["#errors"] = pr["#errors"].astype(int)
pr["#corrects"] = pr["#corrects"].astype(int)
pr["accuracy"] = (pr["accuracy"]*100).round(3)
pr["d_accuracy"] = (pr["d_accuracy"]*100).round(3)
pr.head(5)

In [None]:
## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise

df_discretized["subgID"] = 0
# The first NUM_SUBGROUPS itemsets of `pr`, each as a list of "attribute=value" items.
itemsets = [list(pr.itemsets.values[i]) for i in range(NUM_SUBGROUPS)]

# Itemsets are visited in order and only rows still labelled 0 are updated, so a
# row matching several itemsets keeps the id of the first one. Boolean masks
# replace the row-by-row lookup, which was slow and required a 0..n-1 index.
for subgroup_id, itemset in enumerate(itemsets, start=1):
    in_subgroup = pd.Series(True, index=df_discretized.index)
    for item in itemset:
        # Split on the first "=" only: the value itself may contain "=".
        attribute, value = item.split("=", 1)
        in_subgroup &= (df_discretized[attribute] == value)
    unassigned = df_discretized["subgID"] == 0
    df_discretized.loc[in_subgroup & unassigned, "subgID"] = subgroup_id

# Number of samples per class (0 = none of the selected itemsets).
for i in range(0,NUM_SUBGROUPS+1):
    print(len(df_discretized.loc[df_discretized["subgID"]==i]))

# df_discretized.to_csv("df_discretized.csv", index=False)

## Valid

In [None]:
## Derive the slot attributes from the intent label
# .assign() returns a new dataframe, so nothing is written into a slice of `df_valid`.
# NOTE(review): the token before the first "_" is stored as 'action' and the
# following one as 'scenario' - confirm this matches the dataset's intent naming.
df_valid_divexpl = df_valid_divexpl.assign(
    action=df_valid_divexpl['intent'].apply(lambda x: x.split("_")[0]),
    scenario=df_valid_divexpl['intent'].apply(lambda x: x.split("_")[1]),
)

## Discretize the dataframe
from divergence_utils import discretize

# NOTE(review): presumably the bin edges are derived from the dataframe passed
# in, i.e. from the validation data rather than from the training data - confirm
# that "low" / "medium" / "high" are meant to be split-specific.
df_discretized_valid = discretize(
    df_valid_divexpl[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
replace_values = {}

for signal_col in signal_cols:

    # Bin labels are expected to look like "<=a", "(a, b]" or ">b".
    for v in df_discretized_valid[signal_col].unique():
        if v.startswith("<="):
            replace_values[v] = "low"
        elif v.startswith(">"):
            replace_values[v] = "high"
        elif v.startswith("(") and v.endswith("]"):
            replace_values[v] = "medium"
        else:
            raise ValueError(v)

    # Assign the result back: replace(inplace=True) on a column selection acts
    # on a temporary object and is not guaranteed to update the dataframe.
    df_discretized_valid[signal_col] = df_discretized_valid[signal_col].replace(replace_values)

In [None]:
## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
df_discretized_valid["subgID"] = 0

# `itemsets` holds the itemsets selected on the training data. They are visited
# in order and only rows still labelled 0 are updated, so a row matching several
# itemsets keeps the id of the first one. Boolean masks replace the row-by-row
# lookup, which was slow and required a 0..n-1 index.
for subgroup_id, itemset in enumerate(itemsets, start=1):
    in_subgroup = pd.Series(True, index=df_discretized_valid.index)
    for item in itemset:
        # Split on the first "=" only: the value itself may contain "=".
        attribute, value = item.split("=", 1)
        in_subgroup &= (df_discretized_valid[attribute] == value)
    unassigned = df_discretized_valid["subgID"] == 0
    df_discretized_valid.loc[in_subgroup & unassigned, "subgID"] = subgroup_id

# Number of samples per class (0 = none of the selected itemsets).
for i in range(0,NUM_SUBGROUPS+1):
    print(len(df_discretized_valid.loc[df_discretized_valid["subgID"]==i]))

# df_discretized_valid.to_csv("df_discretized_valid.csv", index=False)

# CM Pretraining and Finetuning

In [None]:
# Columns used by the confidence model: the 0/1 correctness target, the pooled
# hidden-state vector and three scalar signal features (train and validation).
df_cm = df[[
    'prediction', 
    'hidden_states',
    'total_silence', 'n_words', 'speed_rate_word'
    ]]
df_cm_valid = df_valid[[
    'prediction', 
    'hidden_states',
    'total_silence', 'n_words', 'speed_rate_word'
    ]]

In [None]:
# Convert every stored logits tensor into a squeezed NumPy array (one per utterance).
logits_train = [ lgt.detach().cpu().numpy().squeeze() for lgt in logits_concatenation_train ] 
logits_valid = [ lgt.detach().cpu().numpy().squeeze() for lgt in logits_concatenation_valid ]

## Pretraining the CM

In [None]:
## Create train and val split
# Feature matrix, concatenated along dim=1: classifier logits, pooled hidden-state
# vector, then the three scalar signal features (each unsqueezed to a column).
# NOTE(review): building tensors straight from pandas Series presumably relies on
# the default 0..n-1 index - confirm if the dataframes are ever filtered/reindexed.
X_train = torch.cat((
    torch.tensor(logits_train),
    torch.tensor(df_cm['hidden_states']),
    torch.tensor(df_cm['total_silence']).unsqueeze(1),
    torch.tensor(df_cm['n_words']).unsqueeze(1),
    torch.tensor(df_cm['speed_rate_word']).unsqueeze(1),
    ), dim=1)
# Target: 1 when the intent classifier was correct on the utterance, 0 otherwise.
y_train = torch.tensor(df_cm['prediction']).unsqueeze(1)

X_val = torch.cat((
    torch.tensor(logits_valid),
    torch.tensor(df_cm_valid['hidden_states']),
    torch.tensor(df_cm_valid['total_silence']).unsqueeze(1),
    torch.tensor(df_cm_valid['n_words']).unsqueeze(1),
    torch.tensor(df_cm_valid['speed_rate_word']).unsqueeze(1),
    ), dim=1)
y_val = torch.tensor(df_cm_valid['prediction']).unsqueeze(1)

In [None]:
import copy

seeds = [1, 10, 42]

# NOTE: every seed writes to the same file and rebinds `best_model`, so only the
# result of the last seed survives this loop.
for seed in seeds:

    print("Seed: ", seed)
    SEED = seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    best_auc = 0
    best_acc = 0
    best_output = 0
    best_model = 0
    best_epoch = 0

    ## Create model
    model = ConfidenceModel(
        input_size=X_train.shape[1],
        hidden_size=HIDDEN_SIZE, 
        output_size=1
        ).to(device)        

    criterion = nn.BCELoss()
    optimizer = optim.NAdam(model.parameters(), lr=0.005)

    ## Train model
    train_losses = []
    val_losses = []
    val_aucs = []

    for epoch in range(EPOCHS):
            
        ## Train in batches
        for i in range(0, len(X_train), BATCH_SIZE):
            train_output, train_loss = train(
                model, 
                X_train[i:i+BATCH_SIZE].float().to(device), 
                y_train[i:i+BATCH_SIZE].float().to(device), 
                criterion, 
                optimizer
                )
        # Loss of the last mini-batch of the epoch.
        train_losses.append(train_loss)
            
        ## Validate
        val_output, val_loss = val(
            model, 
            X_val.float().to(device), 
            y_val.float().to(device),
            criterion
            )
        val_losses.append(val_loss)

        # ROC AUC is a ranking metric and needs the continuous scores; the
        # 0/1 decisions at threshold 0.5 are used for the accuracy only.
        val_scores = val_output.cpu().detach().numpy()
        val_output = (val_output > 0.5).float()
        val_acc = accuracy_score(y_val, val_output.cpu().detach().numpy())
        val_auc = roc_auc_score(y_val, val_scores)
        val_aucs.append(val_auc)     
    
        if val_auc > best_auc:
            best_auc = val_auc
            best_acc = val_acc
            best_output = val_output
            # Snapshot the weights: `model` keeps changing in later epochs, so
            # keeping a plain reference would always end up as the last epoch.
            best_model = copy.deepcopy(model)
            best_epoch = epoch

        # Early stopping: after 1000 epochs, stop once the validation loss has
        # increased twice in a row.
        if epoch > 1000:
            if val_losses[-1] > val_losses[-2] and val_losses[-2] > val_losses[-3]:
                break

    ## Print metrics 
    print("Best epoch: ", best_epoch)
    print("Val accuracy: ", round(best_acc*100, 2), "%")
    print("Val AUC: ", round(best_auc, 2))

    ## Save model
    if HF:
        torch.save(best_model, 'cm_pt_ft/italic/confidence_model_pt_hf.pt')
    else:
        torch.save(best_model, 'cm_pt_ft/italic/confidence_model_pt.pt')
    print("Model saved!")

## Problematic Subgroups Prediction

In [None]:
## Create train and val split
# Same feature matrix as for pretraining: logits, pooled hidden-state vector and
# the three scalar signal features, concatenated along dim=1.
X_train = torch.cat((
    torch.tensor(logits_train),
    torch.tensor(df_cm['hidden_states']),
    torch.tensor(df_cm['total_silence']).unsqueeze(1),
    torch.tensor(df_cm['n_words']).unsqueeze(1),
    torch.tensor(df_cm['speed_rate_word']).unsqueeze(1),
    ), dim=1)
# Target: subgroup id of every sample (0 = none of the selected itemsets).
y_train_subs = torch.tensor(df_discretized['subgID'])

X_val = torch.cat((
    torch.tensor(logits_valid),
    torch.tensor(df_cm_valid['hidden_states']),
    torch.tensor(df_cm_valid['total_silence']).unsqueeze(1),
    torch.tensor(df_cm_valid['n_words']).unsqueeze(1),
    torch.tensor(df_cm_valid['speed_rate_word']).unsqueeze(1),
    ), dim=1)
y_val_subs = torch.tensor(df_discretized_valid['subgID'])

In [None]:
import copy

seeds = [1, 10, 42]

# NOTE: every seed writes to the same file, so only the last seed's model stays on disk.
for seed in seeds:

    print("Seed: ", seed)
    SEED = seed
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    np.random.seed(SEED)

    best_f1macro = 0
    best_acc = 0
    best_output = 0
    best_epoch = 0
    best_model_ft = None

    if PRETRAIN:
        # Start every seed from an untouched copy of the pretrained model;
        # swapping the head on `best_model` itself would make later seeds
        # continue from the weights fine-tuned by the previous seed.
        model = copy.deepcopy(best_model)
        model.linear3 = nn.Linear(
            HIDDEN_SIZE,
            NUM_SUBGROUPS+1
            ).to(device)
    else:
        model = ConfidenceModel(
            input_size=X_train.shape[1],
            hidden_size=HIDDEN_SIZE,
            output_size=NUM_SUBGROUPS+1
            ).to(device)

    ## Criterion and optimizer
    # NOTE(review): ConfidenceModel applies a sigmoid to its outputs, while
    # CrossEntropyLoss expects unnormalised scores - confirm this is intended.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.NAdam(model.parameters(), lr=0.005)

    ## Train and validate model
    train_losses = []
    val_losses = []
    for epoch in range(EPOCHS):
        # Full-batch update followed by an evaluation on the validation set.
        train_output, train_loss = train(
            model,
            X_train.to(device),
            y_train_subs.to(device),
            criterion,
            optimizer
            )
        val_output, val_loss = val(
            model,
            X_val.to(device),
            y_val_subs.to(device),
            criterion
            )
        val_output = val_output.cpu().detach().numpy()
        val_output = np.argmax(val_output, axis=1)
        val_acc = accuracy_score(y_val_subs, val_output)
        val_f1 = f1_score(y_val_subs, val_output, average='macro')
        if val_f1 > best_f1macro:
            best_f1macro = val_f1
            best_acc = val_acc
            best_output = val_output
            best_epoch = epoch
            # Keep the weights that produced the reported metrics.
            best_model_ft = copy.deepcopy(model)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # Early stopping: after 2000 epochs, stop when the current validation
        # loss is higher than each of the two previous ones.
        if epoch > 2000 and val_loss > val_losses[-2] and val_loss > val_losses[-3]:
            break

    print("Best Epoch: ", best_epoch)
    print("Val Accuracy: ", best_acc)
    print("Val F1 Macro: ", best_f1macro)
    print("--------------------\n")

    ## Save model
    # Save the best-F1 snapshot; fall back to the final model if none was taken.
    model_to_save = best_model_ft if best_model_ft is not None else model
    if HF:
        torch.save(model_to_save, 'cm_pt_ft/italic/confidence_model_ft_hf.pt')
    else:
        torch.save(model_to_save, 'cm_pt_ft/italic/confidence_model_ft.pt')
    print("Model saved!")

# Select new data

## Prepare data

In [None]:
def map_to_array(example, audio_col = 'path'):
    """Decode the audio file of one example and store it under ``"speech"``.

    Args:
        example: mapping holding the audio file path under ``audio_col``.
        audio_col: key of the entry that contains the audio path.

    Returns:
        The same ``example`` object, updated in place with the waveform.
    """
    # Load as mono at 16 kHz; the sample rate returned by librosa is not needed.
    loaded = librosa.load(example[audio_col], sr=16000, mono=True)
    example["speech"] = loaded[0]
    return example

def preprocess_function(examples, max_duration=10):
    """Turn raw audio into model inputs with the global ``feature_extractor``.

    Every input is padded or truncated to ``max_duration`` seconds at the
    extractor's own sampling rate, and PyTorch tensors are returned.

    Args:
        examples: raw waveform(s) accepted by the feature extractor.
        max_duration: target length in seconds.

    Returns:
        Whatever the feature extractor returns for these arguments.
    """
    target_rate = feature_extractor.sampling_rate
    return feature_extractor(
        examples,
        sampling_rate=target_rate,
        padding='max_length',
        max_length=int(target_rate * max_duration),
        truncation=True,
        return_tensors="pt",
    )

In [None]:
## Load model
if HF:
    print("Loading model from HF")
    # Hidden states are requested here as well: the inference cell reads
    # `outputs.hidden_states` regardless of which checkpoint is loaded.
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        "superb/wav2vec2-base-superb-ic",
        output_hidden_states=True
        ).to(device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")

    # The accuracy cell needs `label2id`; take both label maps from the loaded config.
    # NOTE(review): unlike the JSON-decoded maps of the other branch, the keys of
    # `id2label` are presumably ints here - confirm before indexing it.
    label2id = model_w2v2.config.label2id
    id2label = model_w2v2.config.id2label
else:
    print("Loading model from local directory")
    model_checkpoint = "italic_original"
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_checkpoint,
        output_hidden_states=True,
        local_files_only=True
        ).to(device)
    # NOTE(review): the feature extractor is loaded from the hub model
    # "facebook/wav2vec2-xls-r-300m", not from the local checkpoint directory -
    # presumably the base model of the checkpoint; confirm.
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

    # Label <-> id maps are read straight from the checkpoint's config file.
    with open(os.path.join(model_checkpoint, "config.json"), "r") as f:
        config = json.load(f)
    label2id = config["label2id"]
    id2label = config["id2label"]

In [None]:
# Metadata table of the held-out portion of the training data (file name suggests 20%).
df_left_out = pd.read_csv('data/italic/train_data_20.csv')

# Wrap the dataframe as a Dataset and add a "speech" entry (decoded waveform)
# to every row, reading the audio file referenced by the 'path' column.
dataset_left_out = Dataset.from_pandas(df_left_out)
dataset_left_out = dataset_left_out.map(lambda x: map_to_array(x, audio_col='path'))

In [None]:
## Inference
# One entry per utterance: a pooled hidden-state vector and the classifier logits.
hidden_states_concatenation = []
logits_concatenation = []

with torch.no_grad():
    for i in tqdm(range(len(dataset_left_out))):
        inputs = preprocess_function(dataset_left_out[i]["speech"]).to(device)
        outputs = model_w2v2(**inputs)

        # Every hidden state is squeezed, averaged along axis 1, and the
        # per-layer vectors are then averaged across layers.
        # NOTE(review): assuming each hidden state is (1, frames, dim), axis=1
        # averages over the feature dimension and leaves one value per frame -
        # confirm that pooling over time (axis=0) was not intended.
        hidden_states = [ hs.detach().cpu().numpy().squeeze() for hs in outputs.hidden_states ]
        avg = [ np.mean(hs, axis=1) for hs in hidden_states ]
        avg_hs = np.mean(avg, axis=0)

        hidden_states_concatenation.append(avg_hs)
        # Store the logits on the CPU so the list does not pin device memory and
        # the file written by torch.save can be loaded on a machine without a GPU.
        logits_concatenation.append(outputs.logits.detach().cpu())

In [None]:
# Serialize the features of the held-out data; the "_hf" suffix separates the two checkpoints.
# NOTE(review): these paths sit directly under 'pretrained/', whereas the train and
# validation features are stored under 'pretrained/italic/' - confirm this is intended.
if HF == True:
    torch.save(hidden_states_concatenation, 'pretrained/hidden_states_train_left_out_hf.pt')
    torch.save(logits_concatenation, 'pretrained/logits_train_left_out_hf.pt')
else:
    torch.save(hidden_states_concatenation, 'pretrained/hidden_states_train_left_out.pt')
    torch.save(logits_concatenation, 'pretrained/logits_train_left_out.pt')

In [None]:
# Predicted intent id of every utterance: argmax over the classifier logits.
intents_predicted = [
    torch.argmax(logits.detach().cpu(), dim=-1).item()
    for logits in tqdm(logits_concatenation)
]

# Ground-truth ids are read from the dataframe column: `dataset_left_out` was
# built from `df_left_out` in the same row order, and iterating it would
# materialise every stored waveform just to fetch one label.
intents_gt = [ label2id[intent] for intent in df_left_out['intent'] ]

# 1 when the predicted id equals the ground-truth id, 0 otherwise.
is_correct = [
    int(predicted == int(expected))
    for predicted, expected in zip(intents_predicted, intents_gt)
]
df_left_out['prediction'] = is_correct

# Store the pooled hidden-state vector of each utterance as a float array.
df_left_out['hidden_states'] = [hs.squeeze() for hs in hidden_states_concatenation]
df_left_out['hidden_states'] = df_left_out['hidden_states'].apply(lambda x: x.astype(float))

In [None]:
# Destination CSV of the annotated held-out dataframe.
output_folder = 'data/italic/italic_train_20.csv'
df_left_out.to_csv(output_folder, index=False)

## Discretize

In [None]:
## Define abbreviations for plot and visualization
from divexplorer.FP_Divergence import abbreviateDict
# Substring -> shorter label. The replacements are applied in insertion order,
# so 'total_silence' must stay ahead of the generic 'total_' prefix.
abbreviations = {
    'total_silence': 'tot_silence',
    'trimmed': 'trim',
    'total_': '',
    'speed_rate_word': 'speakRate',
    'speed_rate_char': 'speakCharRate',
    'duration': 'dur',
}

# Independent shallow copy of the mapping above
abbreviations_shorter = dict(abbreviations)

## Function for sorting data cohorts
def sortItemset(x, abbreviations=None):
    """Return the items of an itemset as one sorted, comma-separated string.

    Args:
        x (iterable of str): the items, e.g. a frozenset of "attribute=value".
        abbreviations (dict, optional): substring -> replacement pairs applied
            to the joined string in the mapping's iteration order. ``None``
            (the default) or an empty mapping leaves the string untouched.

    Returns:
        str: the sorted items joined with ", ", with abbreviations applied.
    """
    label = ", ".join(sorted(x))
    # `or {}` replaces the former mutable default argument and accepts None.
    for old, new in (abbreviations or {}).items():
        label = label.replace(old, new)
    return label

def attributes_in_itemset(itemset, attributes, alls = True):
    """Tell whether the attributes of an itemset belong to a set of interest.

    Each item is expected to look like "attribute=value"; only the part before
    the first "=" is compared against `attributes`.

    Args:
        itemset (frozenset): the itemset to test
        attributes (list): the attribute names of interest
        alls (bool): If True, every attribute of the itemset must be one of
        `attributes`. If False, a single matching attribute is enough.

    Returns:
        bool: the outcome of the check. The empty itemset (which describes the
        entire dataset) is rejected whenever `attributes` is non-empty.
    """
    # Never admit the empty itemset when some attribute is requested
    if itemset == frozenset() and attributes:
        return False

    admitted = (item.split("=")[0] in attributes for item in itemset)

    # all() stops at the first attribute that is not admitted,
    # any() stops at the first one that is.
    return all(admitted) if alls else any(admitted)
    
def filter_itemset_df_by_attributes(df: pd.DataFrame, attributes: list, alls = True, itemset_col_name: str = "itemsets") -> pd.DataFrame:
    """Keep the rows whose itemset passes `attributes_in_itemset`.

    Args:
        df (pd.DataFrame): the itemsets together with their info.
        attributes (list): the attribute names of interest
        alls (bool): If True, every attribute of an itemset must be one of
        `attributes`. If False, a single matching attribute is enough.
        itemset_col_name (str) : name of the column holding the itemsets,
        "itemsets" by default

    Returns:
        pd.DataFrame: the rows of `df` whose itemset is admitted
    """
    # Boolean mask with one entry per row of df
    keep = df[itemset_col_name].apply(
        lambda itemset: attributes_in_itemset(itemset, attributes, alls = alls)
    )
    return df.loc[keep]

In [None]:
## Target for DivExplorer
# 'prediction' holds 1 when predicted_intent == original_intent, 0 otherwise
target_col = "prediction"
# Column names used below when renaming and displaying the results
target_metric = "d_posr"
target_div = "d_accuracy"
t_value_col = "t_value_tp_fn"

In [None]:
## Columns for visualization
# Columns shown in the result tables
show_cols = [
    'support', 'itemsets', '#errors', '#corrects', 'accuracy',
    'd_accuracy', 't_value', 'support_count', 'length',
]
# Original column name -> display name
remapped_cols = {
    'tn': '#errors',
    'tp': '#corrects',
    'posr': 'accuracy',
    target_metric: target_div,
    't_value_tp_fn': 't_value',
}

## Columns of the df file that we are going to analyze
demo_cols = ['gender', 'age', 'region', 'nationality', 'lisp', 'education']

slot_cols = ['action', 'scenario']

rec_set_cols = ['environment', 'device', 'field']

signal_cols = [
    'total_silence', 'total_duration', 'trimmed_duration',
    'n_words', 'speed_rate_word', 'speed_rate_word_trimmed',
]

# rec_set_cols is defined above but is not part of the analysed attributes
input_cols = [*demo_cols, *signal_cols, *slot_cols]

In [None]:
# select the columns of interest
# Take an explicit copy: new columns are added to this frame afterwards, and
# writing into a bare column selection of df_left_out would be a chained
# assignment (SettingWithCopyWarning, silenced here by the warnings filter).
df_left_out_divexpl = df_left_out[[
    'prediction',
    'intent',
    'gender', 'age', 'region', 'nationality', 'lisp', 'education', 'environment', 'device', 'field',
    'total_silence', 'total_duration', 'trimmed_duration', 'n_words', 'speed_rate_word', 'speed_rate_word_trimmed'
    ]].copy()

In [None]:
# Split every intent on "_" once, then keep the first two tokens.
# NOTE(review): an intent with more than one "_" loses everything after the
# second token; also confirm that token 0 really is the action and token 1
# the scenario for this dataset's intent naming.
_intent_tokens = df_left_out_divexpl['intent'].apply(lambda intent: intent.split("_"))
df_left_out_divexpl['action'] = _intent_tokens.apply(lambda tokens: tokens[0])
df_left_out_divexpl['scenario'] = _intent_tokens.apply(lambda tokens: tokens[1])

## Discretize the dataframe
from divergence_utils import discretize

df_discretized_train_left_out = discretize(
    df_left_out_divexpl[input_cols+[target_col]],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
# The mapping is shared by all signal columns and grows as each is visited.
replace_values = {}

for i, signal_col in enumerate(signal_cols):

    for v in df_discretized_train_left_out[signal_col].unique():
        if v.startswith("<="):
            replace_values[v] = "low"
        elif v.startswith(">"):
            replace_values[v] = "high"
        elif v.startswith("(") and v.endswith("]"):
            replace_values[v] = "medium"
        else:
            # Unexpected bin label: fail loudly rather than leave it unmapped
            raise ValueError(v)

    # Assign the result back instead of calling replace(..., inplace=True) on
    # a column selection: that is a chained assignment, which pandas warns
    # about and which does not update the frame under copy-on-write.
    df_discretized_train_left_out[signal_col] = (
        df_discretized_train_left_out[signal_col].replace(replace_values)
    )

## Predict challenging subgroups IDs

In [None]:
# NumPy copy of every logits tensor, squeezed
logits_train_left_out = [
    sample_logits.detach().cpu().numpy().squeeze()
    for sample_logits in logits_concatenation
]

In [None]:
## Create train input
# Column-wise concatenation of: logits | pooled hidden states | total_silence |
# n_words | speed_rate_word. The per-sample arrays are stacked with NumPy
# first: building a tensor straight from a list of ndarrays (or from an
# object-dtype Series of ndarrays) goes through a slow element-wise fallback.
X_train_left_out = torch.cat((
    torch.from_numpy(np.stack(logits_train_left_out)),
    torch.from_numpy(np.stack(df_left_out['hidden_states'].to_numpy())),
    torch.tensor(df_left_out['total_silence'].to_numpy()).unsqueeze(1),
    torch.tensor(df_left_out['n_words'].to_numpy()).unsqueeze(1),
    torch.tensor(df_left_out['speed_rate_word'].to_numpy()).unsqueeze(1),
    ), dim=1)

In [None]:
# Run the model on the held-out features.
# NOTE(review): `test` is presumably returning one score per class for each
# row, so the argmax over axis 1 is the predicted subgroup ID - confirm.
train_left_out_output = test(model, X_train_left_out.to(device))
train_left_out_output = train_left_out_output.cpu().detach().numpy()
train_left_out_output = train_left_out_output.argmax(axis=1)

In [None]:
# Keep the rows of df_left_out whose predicted subgroup ID is not 0
df_left_out['subgID'] = train_left_out_output
print(len(df_left_out))

_is_divergent = df_left_out['subgID'].ne(0)
divergent_samples = df_left_out.loc[_is_divergent]
print(len(divergent_samples))

# Drop the helper columns before the rows are merged into the training split
divergent_samples = divergent_samples.drop(columns=['hidden_states', 'prediction', 'subgID'])

In [None]:
# 80% training split, and its number of rows
df_train_80 = pd.read_csv("data/italic/train_data_80.csv")
print(df_train_80.shape[0])

In [None]:
## Stack divergent_samples below df_train_80 (rows are re-indexed from 0)
df_new = pd.concat([df_train_80, divergent_samples], ignore_index=True)

In [None]:
approach = 'csi'
# The baselines below reuse this count to cap how many samples they add
num_samples = len(divergent_samples)

# index=False keeps the row index out of the file, as the other exports of
# this notebook do; otherwise a reader gets an extra unnamed column.
df_new.to_csv(
    f'data/italic/new_data/train_data_{approach}_k{NUM_SUBGROUPS}_{num_samples}.csv',
    index=False,
)

# Random Baseline

In [None]:
## Random baseline: give every held-out sample a random subgroup ID drawn
## uniformly from 0..NUM_SUBGROUPS (both ends included)
random_pred = np.random.randint(low=0, high=NUM_SUBGROUPS+1, size=len(X_train_left_out))

In [None]:
# Keep the rows of df_left_out for which random_pred is different from 0
df_left_out['subgID'] = random_pred
print(len(df_left_out))

divergent_samples = df_left_out.loc[df_left_out['subgID'].ne(0)]
print(len(divergent_samples))

# Shuffle with a fixed seed, then keep at most num_samples rows
divergent_samples = (
    divergent_samples
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
    .iloc[:num_samples]
)
print(len(divergent_samples))

# Drop the helper columns before the rows are merged into the training split
divergent_samples = divergent_samples.drop(columns=['hidden_states', 'prediction', 'subgID'])

In [None]:
## Stack divergent_samples below df_train_80 (rows are re-indexed from 0)
df_new = pd.concat([df_train_80, divergent_samples], ignore_index=True)

In [None]:
approach = 'random'

# index=False keeps the row index out of the file, as the other exports of
# this notebook do; otherwise a reader gets an extra unnamed column.
df_new.to_csv(
    f'data/italic/new_data/train_data_{approach}_k{NUM_SUBGROUPS}_{num_samples}.csv',
    index=False,
)

# KNN Baseline

In [None]:
## KNN baseline that assigns each sample to the most frequent class among its k nearest neighbors
from sklearn.neighbors import KNeighborsClassifier

SEED = 1
best_k, best_acc, best_f1 = 0, 0, 0

# Try k = 2..9 and keep the k with the highest validation accuracy
# (the first one wins on ties, because the comparison is strict).
for k in range(2,10):

    knn = KNeighborsClassifier(n_neighbors=k).fit(X_train, y_train_subs)
    knn_pred = knn.predict(X_val)

    acc = accuracy_score(y_val_subs, knn_pred)
    f1 = f1_score(y_val_subs, knn_pred, average='macro')

    if acc > best_acc:
        best_k, best_acc, best_f1 = k, acc, f1

for _label, _value in (("Best K: ", best_k), ("Accuracy: ", best_acc), ("F1 Macro: ", best_f1)):
    print(_label, _value)

In [None]:
# Refit with the selected k, then label the held-out samples
knn = KNeighborsClassifier(n_neighbors=best_k).fit(X_train, y_train_subs)
knn_pred = knn.predict(X_train_left_out)

In [None]:
# Keep the rows of df_left_out for which knn_pred is different from 0
df_left_out['subgID'] = knn_pred
print(len(df_left_out))

divergent_samples = df_left_out.loc[df_left_out['subgID'].ne(0)]
print(len(divergent_samples))

# Shuffle with a fixed seed, then keep at most num_samples rows
divergent_samples = (
    divergent_samples
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
    .iloc[:num_samples]
)
print(len(divergent_samples))

# Drop the helper columns before the rows are merged into the training split
divergent_samples = divergent_samples.drop(columns=['hidden_states', 'prediction', 'subgID'])

In [None]:
# 80% training split, and its number of rows
df_train_80 = pd.read_csv("data/italic/train_data_80.csv")
print(df_train_80.shape[0])

In [None]:
## Stack divergent_samples below df_train_80 (rows are re-indexed from 0)
df_new = pd.concat([df_train_80, divergent_samples], ignore_index=True)

In [None]:
approach = 'knn'

# index=False keeps the row index out of the file, as the other exports of
# this notebook do; otherwise a reader gets an extra unnamed column.
df_new.to_csv(
    f'data/italic/new_data/train_data_{approach}_k{NUM_SUBGROUPS}_{num_samples}.csv',
    index=False,
)

# CM Baseline

In [None]:
## Create train input
# Column-wise concatenation of: logits | pooled hidden states | total_silence |
# n_words | speed_rate_word. The per-sample arrays are stacked with NumPy
# first: building a tensor straight from a list of ndarrays (or from an
# object-dtype Series of ndarrays) goes through a slow element-wise fallback.
X_train_left_out = torch.cat((
    torch.from_numpy(np.stack(logits_train_left_out)),
    torch.from_numpy(np.stack(df_left_out['hidden_states'].to_numpy())),
    torch.tensor(df_left_out['total_silence'].to_numpy()).unsqueeze(1),
    torch.tensor(df_left_out['n_words'].to_numpy()).unsqueeze(1),
    torch.tensor(df_left_out['speed_rate_word'].to_numpy()).unsqueeze(1),
    ), dim=1)

In [None]:
# cm_model is another name for best_model, not a copy: swapping the output
# layer below therefore changes best_model as well.
# NOTE(review): the new linear3 layer is freshly initialised and no training
# step is visible here - confirm its predictions are meant to be used as is.
cm_model = best_model
cm_model.linear3 = nn.Linear(HIDDEN_SIZE, NUM_SUBGROUPS+1).to(device)

In [None]:
# Run cm_model on the held-out features.
# NOTE(review): `test` is presumably returning one score per class for each
# row, so the argmax over axis 1 is the predicted subgroup ID - confirm.
train_left_out_output = test(cm_model, X_train_left_out.to(device))
train_left_out_output = train_left_out_output.cpu().detach().numpy()
train_left_out_output = train_left_out_output.argmax(axis=1)

In [None]:
# Keep the rows of df_left_out whose predicted subgroup ID is not 0
df_left_out['subgID'] = train_left_out_output
print(len(df_left_out))

divergent_samples = df_left_out.loc[df_left_out['subgID'].ne(0)]
print(len(divergent_samples))

# Drop the helper columns, shuffle with a fixed seed, then keep at most
# num_samples rows
divergent_samples = (
    divergent_samples
    .drop(columns=['hidden_states', 'prediction', 'subgID'])
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
    .iloc[:num_samples]
)
print(len(divergent_samples))


In [None]:
# 80% training split, and its number of rows
df_train_80 = pd.read_csv("data/italic/train_data_80.csv")
print(df_train_80.shape[0])

In [None]:
## Stack divergent_samples below df_train_80 (rows are re-indexed from 0)
df_new = pd.concat([df_train_80, divergent_samples], ignore_index=True)

In [None]:
approach = 'cm'

# index=False keeps the row index out of the file, as the other exports of
# this notebook do; otherwise a reader gets an extra unnamed column.
df_new.to_csv(
    f'data/italic/new_data/train_data_{approach}_k{NUM_SUBGROUPS}_{num_samples}.csv',
    index=False,
)

# Supervised Oracle

In [None]:
# Checkpoint source for the model loaded next: True -> Hugging Face hub,
# False -> local directory
HF = False

In [None]:
## Use the wav2vec2 model to predict the intent of the held out samples
# If the prediction is not correct, this sample is considered divergent, thus it is added to the training set

## Load model
if HF == True:
    # NOTE: this branch does not define label2id / id2label
    print("Loading model from HF")
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ic").to(device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ic")
else: 
    print("Loading model from local directory")
    model_checkpoint = "italic_original"
    model_w2v2 = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_checkpoint, 
        output_hidden_states=True,
        local_files_only=True
        ).to(device)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m")

    # Parse the checkpoint configuration into its own name. Binding it to the
    # global `config` would overwrite the value that later cells use as a key
    # of fp_divergence_dict, and a dict cannot be used as a dictionary key.
    with open(os.path.join(model_checkpoint, "config.json"), "r") as f:
        model_config = json.load(f)
    label2id = model_config["label2id"]
    id2label = model_config["id2label"]

In [None]:
## Reload the held-out 20% split and attach the decoded audio to every row
df_left_out = pd.read_csv('data/italic/train_data_20.csv')
# map_to_array reads the file referenced by the 'path' column and stores the
# waveform under the "speech" key of each example.
dataset_left_out = Dataset.from_pandas(df_left_out).map(
    lambda example: map_to_array(example, audio_col='path')
)

In [None]:
# Logits of every held-out utterance, in dataset order
logits_concatenation_left_out = []

with torch.no_grad():
    for i in tqdm(range(len(dataset_left_out))):
        inputs = preprocess_function(dataset_left_out[i]["speech"]).to(device)
        outputs = model_w2v2(**inputs)
        # Store the logits on the CPU so that one device tensor per sample is
        # not kept alive for the whole split.
        logits_concatenation_left_out.append(outputs.logits.detach().cpu())

In [None]:
## Turn the logits into predicted intent ids
intents_predicted_left_out = []
for i in tqdm(range(len(logits_concatenation_left_out))):
    logits_left_out = logits_concatenation_left_out[i].detach().cpu()
    # Bug fix: take the argmax of this sample's logits. The previous code read
    # `logits_train`, a name this loop never assigns, so every prediction came
    # from an unrelated (or undefined) tensor.
    intents_predicted_left_out.append(torch.argmax(logits_left_out, dim=-1).item())

# Read the ground truth from the dataframe column the Dataset was built from:
# iterating the Dataset row by row would also materialise the "speech" array
# of every sample just to read its intent.
intents_gt_left_out = [ label2id[intent] for intent in tqdm(df_left_out['intent']) ]

# 'prediction' is 1 when the predicted intent matches the ground truth, else 0
is_correct_left_out = [
    int(pred == int(gt))
    for pred, gt in zip(intents_predicted_left_out, intents_gt_left_out)
]
df_left_out['prediction'] = is_correct_left_out

In [None]:
## Keep only the misclassified samples ('prediction' equal to 0)
_is_wrong = df_left_out['prediction'].eq(0)
df_left_out = df_left_out.loc[_is_wrong]
print(df_left_out.shape[0])

In [None]:
# Save the current df_left_out without its row index; the path is relative to
# the working directory
df_left_out.to_csv('train_data_20_wrong.csv', index=False)

In [None]:
import pandas as pd 

# Samples saved as 'train_data_20_wrong.csv', and the 80% training split
df_left_out = pd.read_csv("train_data_20_wrong.csv")
df = pd.read_csv("data/italic/train_data_80.csv")

In [None]:
# Row counts of the two frames (df_left_out, df)
len(df_left_out), len(df)

In [None]:
# Shuffle the samples with a fixed seed and keep at most num_samples of them
df_left_out = df_left_out.sample(frac=1, random_state=42).reset_index(drop=True)
df_left_out1 = df_left_out.iloc[:num_samples]

# Stack them below the 80% training split (rows are re-indexed from 0)
df_new = pd.concat([df, df_left_out1], ignore_index=True)
print(len(df), len(df_left_out1), len(df_new))

# NOTE(review): the file name uses `K`, whereas the other exports use
# NUM_SUBGROUPS - confirm that K is defined and holds the intended value.
df_new.to_csv(f'data/italic/new_data/train_data_erroneous_k{K}_{num_samples}.csv', index=False)

# Clustering Baseline

In [None]:
## Discretize the dataframe
from divergence_utils import discretize

df_left_out = pd.read_csv('data/italic/train_data_20.csv')
# Keep only the cluster-assignment column for num_clusters clusters. The
# explicit copy makes this an independent frame: a "subgID" column is written
# into it afterwards, which on a bare column selection would be a chained
# assignment.
df_discretized_rest = df_left_out[[f'speech_cluster_id_{num_clusters}']].copy()

In [None]:
print("Number of problematic subgroups: ", NUM_SUBGROUPS)

fp_divergence_i = fp_divergence_dict[config]
# Reverse the divergence table and keep its first NUM_SUBGROUPS rows
FPdiv = fp_divergence_i.getDivergence(th_redundancy=th_redundancy)[::-1] 
pr_bot = FPdiv.head(NUM_SUBGROUPS).copy()
# Only the first item of every itemset is used. Iterating over the rows that
# exist avoids an IndexError when fewer than NUM_SUBGROUPS are available.
itemsets = [list(itemset)[0] for itemset in pr_bot.itemsets.values]

## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
# Column-wise masks replace the former row-by-row .loc loop. Itemsets are
# visited in rank order and only rows that are still 0 are updated, so a
# sample keeps the first subgroup it matches.
subg_ids = pd.Series(0, index=df_discretized_rest.index)
for value, itemset in enumerate(itemsets):
    k, v = itemset.split("=", 1)
    matches = (df_discretized_rest[k] == int(v)) & (subg_ids == 0)
    subg_ids.loc[matches] = value + 1
df_discretized_rest["subgID"] = subg_ids

## Keep only the held-out rows with subgID != 0, capped at num_samples
df_train_rest = pd.read_csv("data/italic/train_data_20.csv")
df_train_rest = df_train_rest.loc[df_discretized_rest["subgID"]!=0][:num_samples]
print("Total number of samples in to be added: ", len(df_train_rest))

## Append df_train_rest to df_train
df_train = pd.read_csv("data/italic/train_data_80.csv")
# DataFrame.append no longer exists in pandas 2.x; pd.concat is the replacement
df_train = pd.concat([df_train, df_train_rest], ignore_index=True)
df_train.to_csv(f"data/italic/new_data/train_data_clustering_k{NUM_SUBGROUPS}.csv", index=False)

# Metadata Oracle

In [None]:
## Discretize the dataframe
from divergence_utils import discretize

df_left_out = pd.read_csv('data/italic/train_data_20.csv')

df_discretized_rest = discretize(
    df_left_out[input_cols],
    bins=3,
    attributes=input_cols,
    strategy="quantile", 
    round_v = 2,
    min_distinct=5,
)

## Replace values with ranges: "low", "medium", "high"
# The mapping is shared by all signal columns and grows as each is visited.
replace_values = {}

for i, signal_col in enumerate(signal_cols):

    for v in df_discretized_rest[signal_col].unique():
        if v.startswith("<="):
            replace_values[v] = "low"
        elif v.startswith(">"):
            replace_values[v] = "high"
        elif v.startswith("(") and v.endswith("]"):
            replace_values[v] = "medium"
        else:
            # Unexpected bin label: fail loudly rather than leave it unmapped
            raise ValueError(v)

    # Assign the result back instead of calling replace(..., inplace=True) on
    # a column selection: that is a chained assignment, which pandas warns
    # about and which does not update the frame under copy-on-write.
    df_discretized_rest[signal_col] = df_discretized_rest[signal_col].replace(replace_values)

In [None]:
from tqdm import tqdm

print("Number of problematic subgroups: ", NUM_SUBGROUPS)

fp_divergence_i = fp_divergence_dict[config]
# Reverse the divergence table and keep its first NUM_SUBGROUPS rows
FPdiv = fp_divergence_i.getDivergence(th_redundancy=th_redundancy)[::-1] 
pr_bot = FPdiv.head(NUM_SUBGROUPS).copy()
# Each entry is the list of "attribute=value" items of one itemset. Iterating
# over the rows that exist avoids an IndexError when fewer than NUM_SUBGROUPS
# are available.
itemsets = [list(itemset) for itemset in pr_bot.itemsets.values]

## Create a column in the df, and assign a class to each sample:
# - 1 if the sample is in the most divergent itemset
# - 2 if the sample is in the second most divergent itemset
# - 3 if the sample is in the third most divergent itemset
# - ...
# - 0 otherwise
# Column-wise masks replace the former row-by-row .loc loop. Itemsets are
# visited in rank order and only rows that are still 0 are updated, so a
# sample keeps the first subgroup it matches.
subg_ids = pd.Series(0, index=df_discretized_rest.index)
for value, itemset in enumerate(itemsets):
    matches = subg_ids == 0
    for item in itemset:
        # Split on the first "=" only, so a value containing "=" stays whole
        k, v = item.split("=", 1)
        matches = matches & (df_discretized_rest[k] == v)
    subg_ids.loc[matches] = value + 1
df_discretized_rest["subgID"] = subg_ids

## Keep only the held-out rows with subgID != 0, capped at num_samples
df_train_rest = pd.read_csv("data/italic/train_data_20.csv")
df_train_rest = df_train_rest.loc[df_discretized_rest["subgID"]!=0][:num_samples]
print("Total number of samples in to be added: ", len(df_train_rest))

## Append df_train_rest to df_train
df_train = pd.read_csv("data/italic/train_data_80.csv")
# DataFrame.append no longer exists in pandas 2.x; pd.concat is the replacement
df_train = pd.concat([df_train, df_train_rest], ignore_index=True)
df_train.to_csv(f"data/italic/new_data/train_data_metadata_oracle_k{NUM_SUBGROUPS}.csv", index=False)