In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import pandas as pd
import librosa
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import random
from sklearn.utils import resample
import openl3

train=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_train_3sp_sampled.csv")
test=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_test_3sp.csv")

class TextFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def extract_features(self, texts):
        # Tokenize input texts
        tokenized_texts = self.tokenizer.batch_encode_plus(
            texts,
            padding='max_length',
            #padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Move tensors to the appropriate device
        input_ids = tokenized_texts['input_ids'].to(self.device)
        attention_mask = tokenized_texts['attention_mask'].to(self.device)
        #print(input_ids.shape)
        # Extract text features from BERT model
        #with torch.no_grad():
            #outputs = self.model(input_ids, attention_mask=attention_mask)
            #print(type(outputs))
           #text_features = outputs[0][:, 0, :].cpu().numpy()  # Use the [CLS] token embedding
        #print("Text feat dim:"+str(text_features))
        return input_ids

    
class MultimodalDataset(Dataset):
        def __init__(self, data_fr, transform_text=None):
            self.data_fr = data_fr
            self.transform_text = transform_text

        def __len__(self):
            return len(self.data_fr)

        def __getitem__(self, index):
        #print(audio_file)
            text_p =  self.data_fr.loc[index, "text"]
            class_id = self.data_fr.loc[index, "Group"]
        #audio_file = self.audio_files[index]
        #text = self.texts[index]


        # Extract text features
            text_features = self.transform_text.extract_features([text_p])[0]

            return text_features, class_id

    
    
batch_size = 32

# Initialize feature extractor

text_feature_extractor = TextFeatureExtractor()
# Create dataset
dataset_tr = MultimodalDataset(train, transform_text=text_feature_extractor)
dataset_ts = MultimodalDataset(test, transform_text=text_feature_extractor)

# Create dataloader
tr_dataloader = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True)
ts_dataloader = DataLoader(dataset_ts, batch_size=batch_size, shuffle=True)

print("Done data loading")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Done data loading


In [5]:
#class LSTMClassifier(nn.Module):
    #def __init__(self, input_size, hidden_size, num_classes):
        #super(LSTMClassifier, self).__init__()
        #self.hidden_size = hidden_size
        #self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        #self.fc = nn.Linear(hidden_size, num_classes)
    
    #def forward(self, x):
        ##print(x.shape)
        #x = torch.unsqueeze(x, 1)
        ##print("After")
        ##print(x.shape)
        #_, (h_n, _) = self.lstm(x)
        #out = self.fc(h_n[-1])
        #return out

class AccuracyMetric:
    def __init__(self):
        self.correct, self.total = None, None
        self.reset()

    def update(self, y_pred, y_true):
        self.correct += torch.sum(y_pred.argmax(-1) == y_true).item()
        self.total += y_true.size(0)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0    

# Set the hyperparameters
#input_size = embeddings.shape[1]
hidden_size = 128
num_classes = 2
learning_rate = 0.001
num_epochs = 10
batch_size = 32

model = torch.jit.load('model_depression_text_LSTM_best_bert.pt')
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
next(model.parameters()).device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epochs = 22

# training loop
train_loss_history = []
train_accuracy_history = []

valid_loss_history = []
valid_accuracy_history = []

accuracy = AccuracyMetric()
best_acc=0




#model = torch.jit.load('model_depression_text_LSTM_best_bert.pt')
for epoch in range(1, epochs + 1):
    
    print(f"[INFO] Epoch: {epoch}")
    model.train()

    batch_train_loss = []
    batch_valid_loss = []

    predicted_label=[]

    true_label=[]
    
    for osfeat, y_batch in tqdm(tr_dataloader):
        # perform single training step
        #print(X_batch.shape)
        #print(osfeat.shape)
        model.zero_grad()
        #print("Before")
        
        osfeat, y_batch = osfeat.to(device), y_batch.to(device)
        y_predicted = model(osfeat.float())
        #print("After Model call")
        loss = criterion(y_predicted, y_batch)
        loss.backward()
        optimizer.step()
        
        accuracy.update(y_predicted, y_batch)
        batch_train_loss.append(loss.item())

    mean_epoch_loss_train = np.mean(batch_train_loss)
    train_accuracy = accuracy.compute()

    train_loss_history.append(mean_epoch_loss_train)
    train_accuracy_history.append(train_accuracy)
    accuracy.reset()

    model.eval()
    with torch.no_grad():
        for osfeat, y_batch in tqdm(ts_dataloader):
            #print(X_batch.shape)
            true_label+=list(y_batch.tolist())
            osfeat, y_batch = osfeat.to(device), y_batch.to(device)

            y_predicted = model(osfeat.float())
            #print(y_predicted)
            loss_val = criterion(y_predicted, y_batch)
            predicted_label+=y_predicted.argmax(-1).tolist()
            accuracy.update(y_predicted, y_batch)
            batch_valid_loss.append(loss_val.item())

    mean_epoch_loss_valid = np.mean(batch_valid_loss)
    valid_accuracy = accuracy.compute()

    if valid_accuracy > best_acc:
        best_acc=valid_accuracy
        model_depression_text_LSTM_final = torch.jit.script(model)
        model_depression_text_LSTM_final.save('model_depression_text_LSTM_final.pt')
        


    valid_loss_history.append(mean_epoch_loss_valid)
    valid_accuracy_history.append(valid_accuracy)
    accuracy.reset()
    print(classification_report(true_label, predicted_label))
    print(
        f"Train loss: {mean_epoch_loss_train:0.4f}, Train accuracy: {train_accuracy: 0.4f}"
    )
    print(
        f"Validation loss: {mean_epoch_loss_valid:0.4f}, Validation accuracy: {valid_accuracy: 0.4f}"
    )




cpu
[INFO] Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:41<00:00,  7.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:06<00:00,  9.69it/s]


              precision    recall  f1-score   support

           0       0.71      0.78      0.74      1349
           1       0.31      0.24      0.27       565

    accuracy                           0.62      1914
   macro avg       0.51      0.51      0.51      1914
weighted avg       0.59      0.62      0.60      1914

Train loss: 0.6964, Train accuracy:  0.5279
Validation loss: 0.6676, Validation accuracy:  0.6186
[INFO] Epoch: 2


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:38<00:00,  8.42it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:05<00:00, 10.46it/s]


              precision    recall  f1-score   support

           0       0.71      0.64      0.67      1349
           1       0.31      0.38      0.34       565

    accuracy                           0.56      1914
   macro avg       0.51      0.51      0.51      1914
weighted avg       0.59      0.56      0.58      1914

Train loss: 0.6948, Train accuracy:  0.5225
Validation loss: 0.6813, Validation accuracy:  0.5643
[INFO] Epoch: 3


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:44<00:00,  7.28it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:06<00:00,  9.48it/s]


              precision    recall  f1-score   support

           0       0.71      1.00      0.83      1349
           1       1.00      0.00      0.00       565

    accuracy                           0.71      1914
   macro avg       0.85      0.50      0.42      1914
weighted avg       0.79      0.71      0.58      1914

Train loss: 0.6921, Train accuracy:  0.5253
Validation loss: 0.6249, Validation accuracy:  0.7053
[INFO] Epoch: 4


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:15<00:00,  4.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:12<00:00,  4.94it/s]


              precision    recall  f1-score   support

           0       0.64      0.03      0.06      1349
           1       0.29      0.96      0.45       565

    accuracy                           0.30      1914
   macro avg       0.47      0.49      0.25      1914
weighted avg       0.54      0.30      0.17      1914

Train loss: 0.6935, Train accuracy:  0.5268
Validation loss: 0.7776, Validation accuracy:  0.3046
[INFO] Epoch: 5


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:12<00:00,  4.40it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:11<00:00,  5.35it/s]


              precision    recall  f1-score   support

           0       0.71      0.95      0.81      1349
           1       0.40      0.07      0.12       565

    accuracy                           0.69      1914
   macro avg       0.55      0.51      0.47      1914
weighted avg       0.62      0.69      0.61      1914

Train loss: 0.6940, Train accuracy:  0.5194
Validation loss: 0.6499, Validation accuracy:  0.6938
[INFO] Epoch: 6


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:14<00:00,  4.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:11<00:00,  5.14it/s]


              precision    recall  f1-score   support

           0       0.71      1.00      0.83      1349
           1       1.00      0.00      0.00       565

    accuracy                           0.71      1914
   macro avg       0.85      0.50      0.42      1914
weighted avg       0.79      0.71      0.58      1914

Train loss: 0.6953, Train accuracy:  0.5217
Validation loss: 0.6311, Validation accuracy:  0.7053
[INFO] Epoch: 7


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:14<00:00,  4.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.12it/s]


              precision    recall  f1-score   support

           0       0.71      1.00      0.83      1349
           1       1.00      0.00      0.00       565

    accuracy                           0.71      1914
   macro avg       0.85      0.50      0.42      1914
weighted avg       0.79      0.71      0.58      1914

Train loss: 0.6934, Train accuracy:  0.5154
Validation loss: 0.6164, Validation accuracy:  0.7053
[INFO] Epoch: 8


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:10<00:00,  4.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.38it/s]


              precision    recall  f1-score   support

           0       0.69      0.09      0.16      1349
           1       0.29      0.90      0.44       565

    accuracy                           0.33      1914
   macro avg       0.49      0.50      0.30      1914
weighted avg       0.57      0.33      0.24      1914

Train loss: 0.6999, Train accuracy:  0.5218
Validation loss: 0.7113, Validation accuracy:  0.3292
[INFO] Epoch: 9


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:59<00:00,  5.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:07<00:00,  7.92it/s]


              precision    recall  f1-score   support

           0       0.65      0.08      0.15      1349
           1       0.29      0.89      0.44       565

    accuracy                           0.32      1914
   macro avg       0.47      0.49      0.29      1914
weighted avg       0.54      0.32      0.23      1914

Train loss: 0.6919, Train accuracy:  0.5192
Validation loss: 0.7135, Validation accuracy:  0.3224
[INFO] Epoch: 10


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:49<00:00,  6.44it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:07<00:00,  8.15it/s]


              precision    recall  f1-score   support

           0       0.65      0.07      0.13      1349
           1       0.29      0.91      0.44       565

    accuracy                           0.32      1914
   macro avg       0.47      0.49      0.28      1914
weighted avg       0.55      0.32      0.22      1914

Train loss: 0.6925, Train accuracy:  0.5263
Validation loss: 0.7210, Validation accuracy:  0.3182
[INFO] Epoch: 11


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:50<00:00,  6.32it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:07<00:00,  8.01it/s]


              precision    recall  f1-score   support

           0       0.70      0.97      0.82      1349
           1       0.29      0.03      0.06       565

    accuracy                           0.69      1914
   macro avg       0.50      0.50      0.44      1914
weighted avg       0.58      0.69      0.59      1914

Train loss: 0.6928, Train accuracy:  0.5140
Validation loss: 0.6613, Validation accuracy:  0.6912
[INFO] Epoch: 12


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:48<00:00,  6.63it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:07<00:00,  8.29it/s]


              precision    recall  f1-score   support

           0       0.72      0.54      0.62      1349
           1       0.32      0.50      0.39       565

    accuracy                           0.53      1914
   macro avg       0.52      0.52      0.51      1914
weighted avg       0.60      0.53      0.55      1914

Train loss: 0.6931, Train accuracy:  0.5223
Validation loss: 0.7004, Validation accuracy:  0.5329
[INFO] Epoch: 13


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:48<00:00,  6.55it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:07<00:00,  8.02it/s]


              precision    recall  f1-score   support

           0       0.66      0.07      0.12      1349
           1       0.29      0.92      0.44       565

    accuracy                           0.32      1914
   macro avg       0.48      0.49      0.28      1914
weighted avg       0.55      0.32      0.22      1914

Train loss: 0.6902, Train accuracy:  0.5182
Validation loss: 0.7414, Validation accuracy:  0.3182
[INFO] Epoch: 14


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:01<00:00,  5.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:08<00:00,  7.48it/s]


              precision    recall  f1-score   support

           0       0.73      0.57      0.64      1349
           1       0.32      0.49      0.39       565

    accuracy                           0.55      1914
   macro avg       0.53      0.53      0.51      1914
weighted avg       0.61      0.55      0.56      1914

Train loss: 0.6944, Train accuracy:  0.5185
Validation loss: 0.6951, Validation accuracy:  0.5455
[INFO] Epoch: 15


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:51<00:00,  6.24it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:08<00:00,  6.95it/s]


              precision    recall  f1-score   support

           0       0.73      0.57      0.64      1349
           1       0.32      0.48      0.39       565

    accuracy                           0.55      1914
   macro avg       0.52      0.53      0.51      1914
weighted avg       0.61      0.55      0.57      1914

Train loss: 0.6934, Train accuracy:  0.5227
Validation loss: 0.6999, Validation accuracy:  0.5465
[INFO] Epoch: 16


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:54<00:00,  5.93it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:08<00:00,  6.94it/s]


              precision    recall  f1-score   support

           0       0.70      0.97      0.82      1349
           1       0.25      0.02      0.04       565

    accuracy                           0.69      1914
   macro avg       0.48      0.50      0.43      1914
weighted avg       0.57      0.69      0.59      1914

Train loss: 0.6919, Train accuracy:  0.5195
Validation loss: 0.6658, Validation accuracy:  0.6912
[INFO] Epoch: 17


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:54<00:00,  5.86it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:08<00:00,  7.02it/s]


              precision    recall  f1-score   support

           0       0.71      0.23      0.35      1349
           1       0.30      0.77      0.43       565

    accuracy                           0.39      1914
   macro avg       0.50      0.50      0.39      1914
weighted avg       0.58      0.39      0.37      1914

Train loss: 0.6882, Train accuracy:  0.5284
Validation loss: 0.7112, Validation accuracy:  0.3903
[INFO] Epoch: 18


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:03<00:00,  5.04it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.02it/s]


              precision    recall  f1-score   support

           0       0.66      0.06      0.12      1349
           1       0.29      0.92      0.44       565

    accuracy                           0.32      1914
   macro avg       0.48      0.49      0.28      1914
weighted avg       0.55      0.32      0.21      1914

Train loss: 0.6894, Train accuracy:  0.5286
Validation loss: 0.7455, Validation accuracy:  0.3171
[INFO] Epoch: 19


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:59<00:00,  5.41it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:10<00:00,  5.95it/s]


              precision    recall  f1-score   support

           0       0.65      0.07      0.13      1349
           1       0.29      0.90      0.44       565

    accuracy                           0.32      1914
   macro avg       0.47      0.49      0.29      1914
weighted avg       0.54      0.32      0.22      1914

Train loss: 0.6935, Train accuracy:  0.5235
Validation loss: 0.7226, Validation accuracy:  0.3192
[INFO] Epoch: 20


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:01<00:00,  5.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.60it/s]


              precision    recall  f1-score   support

           0       0.71      0.98      0.82      1349
           1       0.33      0.02      0.05       565

    accuracy                           0.70      1914
   macro avg       0.52      0.50      0.43      1914
weighted avg       0.60      0.70      0.59      1914

Train loss: 0.6923, Train accuracy:  0.5244
Validation loss: 0.6572, Validation accuracy:  0.6975
[INFO] Epoch: 21


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:57<00:00,  5.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:08<00:00,  7.15it/s]


              precision    recall  f1-score   support

           0       0.67      0.07      0.12      1349
           1       0.29      0.92      0.44       565

    accuracy                           0.32      1914
   macro avg       0.48      0.49      0.28      1914
weighted avg       0.56      0.32      0.21      1914

Train loss: 0.6954, Train accuracy:  0.5193
Validation loss: 0.7567, Validation accuracy:  0.3182
[INFO] Epoch: 22


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [01:00<00:00,  5.27it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:09<00:00,  6.61it/s]


              precision    recall  f1-score   support

           0       0.71      0.94      0.81      1349
           1       0.33      0.07      0.12       565

    accuracy                           0.68      1914
   macro avg       0.52      0.51      0.46      1914
weighted avg       0.60      0.68      0.60      1914

Train loss: 0.6907, Train accuracy:  0.5255
Validation loss: 0.6798, Validation accuracy:  0.6823
