In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import TrainerCallback, TrainerState, TrainerControl
from transformers import AutoModelForAudioClassification, TrainingArguments, Trainer, AutoFeatureExtractor, AutoConfig
import evaluate
from sklearn.metrics import f1_score

In [2]:
# Load the dataset from a local pickle file.
# NOTE(review): unpickling can execute arbitrary code, so only load files from a trusted source.
df = pd.read_pickle('AnnoMI-ast-new.pkl')

In [4]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
# Build the client-only dataset: keep the rows whose interlocutor is the client and
# expose the embedding / talk-type columns under the generic names 'inputs' and 'labels'.
# Selecting with .loc and renaming in one chain produces a new frame, so nothing
# derived from `df` is mutated in place and the cell can be re-run safely.
df_client = (
    df.loc[df['interlocutor'] == 'client', ['client_ast_emb', 'client_talk_type']]
    .rename(columns={'client_ast_emb': 'inputs', 'client_talk_type': 'labels'})
)

# Therapist variant (swap in to classify therapist behaviour instead):
# df_therapist = (
#     df.loc[df['interlocutor'] == 'therapist', ['therapist_ast_emb', 'main_therapist_behaviour']]
#     .rename(columns={'therapist_ast_emb': 'inputs', 'main_therapist_behaviour': 'labels'})
# )

In [6]:
# Distinct class names found in the label column. The position of each name in this
# array is what the mapping cell below uses as its integer class id.
# NOTE(review): Series.unique() presumably returns values in order of first appearance,
# so the id assignment depends on row order — confirm if stable ids matter.
labels = df_client['labels'].unique()
# labels = df_therapist['labels'].unique()
labels

array(['neutral', 'change', 'sustain'], dtype=object)

In [7]:
# Integer id -> class name, numbered by position in `labels`.
id2label = dict(enumerate(labels))
# Inverse lookup: class name -> integer id.
label2id = {name: class_id for class_id, name in id2label.items()}


In [8]:
# Spot-check the mapping: show the class name stored under id 2.
id2label[2]

'sustain'

In [9]:
# Hold out 20% of the rows for evaluation; the fixed random_state makes the split repeatable.
# NOTE(review): the split is not stratified by label — consider whether class
# proportions should be preserved between the two parts.
train_data, test_data = train_test_split(df_client, test_size=0.2, random_state=42)
# train_data, test_data = train_test_split(df_therapist, test_size=0.2, random_state=42)

In [10]:
class CustomDataset(Dataset):
    """Map-style dataset over a dataframe with 'inputs' and 'labels' columns.

    Each item is a dict with two entries: "input_values", the stored input with
    its leading axis squeezed out, and "labels", the row's label translated to
    its integer id through ``label2id``.
    """

    def __init__(self, dataframe, label2id):
        """Keep references to the dataframe and the label-name -> id mapping."""
        self.label2id = label2id
        self.data = dataframe

    def __len__(self):
        """Number of rows in the wrapped dataframe."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the model inputs and the encoded label for positional row ``idx``."""
        row = self.data.iloc[idx]

        # Encode the label first so an unknown label name fails before the input is touched.
        label_id = self.label2id[row['labels']]

        # The stored input carries a leading axis that squeeze(0) removes.
        features = row['inputs'].squeeze(0)

        return {"input_values": features, "labels": label_id}


# Wrap both splits as datasets; they share the same label-name -> id mapping.
train_dataset = CustomDataset(train_data, label2id)
eval_dataset = CustomDataset(test_data, label2id)


In [11]:
# Feature extractor published with the AST AudioSet checkpoint; it is handed to the Trainer further down.
feature_extractor = AutoFeatureExtractor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

In [12]:
# Accuracy metric object, loaded once and reused by compute_metrics.
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Return the accuracy result for one evaluation pass.

    ``eval_pred.predictions`` holds per-class scores (argmax over axis 1 gives the
    predicted class id) and ``eval_pred.label_ids`` holds the reference ids.
    """
    predicted_ids = np.argmax(eval_pred.predictions, axis=1)
    return accuracy.compute(references=eval_pred.label_ids, predictions=predicted_ids)

def f1_score_macro(eval_pred):
    """Return ``{"f1_macro": value}`` for a ``(scores, reference_ids)`` pair.

    The per-class scores are reduced to class ids with an argmax over axis 1
    before the macro-averaged F1 is computed.
    """
    scores, reference_ids = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    macro_f1 = f1_score(reference_ids, predicted_ids, average="macro")
    return {"f1_macro": macro_f1}

class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Per sample the loss is ``-alpha_t * (1 - p_t) ** gamma * log(p_t)``, where
    ``p_t`` is the softmax probability assigned to the target class.

    Args:
        alpha: Class-balancing factor. A scalar scales the reduced loss uniformly;
            a list, tuple or 1-D tensor with one entry per class weights every
            sample by the entry of its target class. The default of 1 leaves the
            loss unscaled.
        gamma: Focusing exponent; larger values down-weight samples whose target
            class already has a high probability.
        reduction: 'none', 'mean' or 'sum', forwarded to ``F.nll_loss``. With a
            per-class ``alpha`` and 'mean', the result is normalised by the summed
            weights of the targets, which is how ``F.nll_loss`` treats ``weight``.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for logits ``inputs`` and class-id ``targets``."""
        log_prob = F.log_softmax(inputs, dim=-1)
        prob = torch.exp(log_prob)
        # nll_loss negates and picks the target column, so feeding it the
        # modulated log-probabilities yields (1 - p_t)^gamma * -log(p_t).
        modulated = ((1 - prob) ** self.gamma) * log_prob

        alpha = self.alpha
        per_class = isinstance(alpha, (list, tuple)) or (
            torch.is_tensor(alpha) and alpha.dim() > 0
        )
        if per_class:
            # Per-class weights: let nll_loss apply them to each sample's target class.
            class_weight = torch.as_tensor(
                alpha, dtype=modulated.dtype, device=modulated.device
            )
            return F.nll_loss(
                modulated, targets, weight=class_weight, reduction=self.reduction
            )

        # Scalar alpha: a uniform factor on the reduced (or per-sample) loss.
        return alpha * F.nll_loss(modulated, targets, reduction=self.reduction)
    
class ThresholdEarlyStoppingCallback(TrainerCallback):
    """Stop training as soon as an evaluation metric exceeds a threshold.

    Args:
        threshold: Value the metric must strictly exceed for training to stop.
            Defaults to 0.60.
        metric_name: Key looked up in the metrics dict passed to ``on_evaluate``.
            Defaults to 'eval_f1_macro'.
    """

    def __init__(self, threshold=0.60, metric_name='eval_f1_macro'):
        super().__init__()
        self.threshold = threshold
        self.metric_name = metric_name

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Set ``control.should_training_stop`` once the metric is above the threshold.

        Raises:
            KeyError: If the configured metric is absent from ``metrics``; the
                message lists the keys that were available.
        """
        if metrics is None or self.metric_name not in metrics:
            available = sorted(metrics) if metrics else []
            raise KeyError(
                f"metric {self.metric_name!r} not found in evaluation metrics; "
                f"available keys: {available}"
            )
        if metrics[self.metric_name] > self.threshold:
            control.should_training_stop = True
        return control

In [13]:
# Configure the checkpoint for this task: the class count and both label mappings
# are stored on the config that the model is built from.
num_labels = len(id2label)
config = AutoConfig.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593", num_labels=num_labels, 
                                    id2label=id2label, label2id=label2id)
# ignore_mismatched_sizes=True lets loading proceed even though the checkpoint's
# classifier weights have a different shape than the head configured here.
model = AutoModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593", config=config, ignore_mismatched_sizes=True
)
# Replace the head's dense layer with a fresh hidden_size -> num_labels linear layer.
# NOTE(review): the loading warning already reports classifier.dense as newly
# initialised with the new shape, so this reassignment looks redundant — confirm before removing.
model.classifier.dense = torch.nn.Linear(model.config.hidden_size, num_labels)
# Move the model to the device selected earlier in the notebook.
model = model.to(device)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([3]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Training configuration for the fine-tuning run.
training_args = TrainingArguments(
    output_dir="./output_ast",
    # Evaluate and checkpoint once per epoch; both strategies are set to the same
    # value, which load_best_model_at_end below relies on.
    # NOTE(review): newer transformers releases appear to rename this argument to
    # `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    # 16 samples per step accumulated over 4 steps = 64 samples per optimizer update per device.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=200,
    warmup_ratio=0.1,
    logging_steps=100,
    # Restore the checkpoint with the best "f1_macro", the key returned by f1_score_macro.
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",
    # Keep a single checkpoint on disk.
    save_total_limit=1,
    push_to_hub=False,
)

In [15]:
# Assemble the Trainer from the model, arguments and datasets prepared above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # The feature extractor is handed over through the `tokenizer` slot.
    # NOTE(review): newer transformers releases appear to prefer `processing_class` here — confirm.
    tokenizer=feature_extractor,
    # Only macro-F1 is reported; the accuracy-based compute_metrics defined earlier is not used.
    compute_metrics=f1_score_macro,
    # No callbacks are registered, so ThresholdEarlyStoppingCallback is inactive for this run.
    callbacks=[],
)

In [16]:
# Launch fine-tuning with the Trainer configured above.
# NOTE(review): the recorded run ends in a KeyboardInterrupt, so later cells
# presumably ran against a partially trained model — confirm.
trainer.train()

Epoch,Training Loss,Validation Loss,F1 Macro
0,No log,0.880848,0.323396
1,0.965800,0.859473,0.328906
2,0.861000,0.852322,0.360371
4,0.803400,0.866924,0.367122


KeyboardInterrupt: 

In [None]:
from sklearn.metrics import classification_report

# Run the model over the evaluation split. The reference ids get their own name so
# the notebook-level `labels` array (the class names used to build label2id) is not
# overwritten, which would break a re-run of the mapping cell.
eval_output = trainer.predict(eval_dataset)
predictions = np.argmax(eval_output.predictions, axis=1)
true_label_ids = eval_output.label_ids

# Pass the class ids and their names explicitly so every row of the report is tied
# to the right class, even when a class is missing from the evaluation split.
class_ids = sorted(id2label)
class_names = [id2label[class_id] for class_id in class_ids]

# Print classification report
print(classification_report(
    true_label_ids,
    predictions,
    labels=class_ids,
    target_names=class_names,
))

                 precision    recall  f1-score   support

       question       0.53      0.51      0.52       364
therapist_input       0.49      0.28      0.36       187
     reflection       0.54      0.28      0.37       282
          other       0.57      0.85      0.69       455

       accuracy                           0.55      1288
      macro avg       0.53      0.48      0.48      1288
   weighted avg       0.54      0.55      0.52      1288

