In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import pandas as pd
import librosa
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import random
from sklearn.utils import resample
import openl3

train=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_train_3sp_sampled.csv")
test=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_test_3sp.csv")
torch.manual_seed(101)

class TextFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def extract_features(self, texts):
        # Tokenize input texts
        tokenized_texts = self.tokenizer.batch_encode_plus(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Move tensors to the appropriate device
        input_ids = tokenized_texts['input_ids'].to(self.device)
        attention_mask = tokenized_texts['attention_mask'].to(self.device)
        #print(input_ids.shape)
        # Extract text features from BERT model
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask)
            #print(type(outputs))
            text_features = outputs[0][:, 0, :].cpu().numpy()  # Use the [CLS] token embedding
        #print("Text feat dim:"+str(text_features))
        return text_features

    
class MultimodalDataset(Dataset):
        def __init__(self, data_fr, transform_text=None):
            self.data_fr = data_fr
            self.transform_text = transform_text

        def __len__(self):
            return len(self.data_fr)

        def __getitem__(self, index):
        #print(audio_file)
            text_p =  self.data_fr.loc[index, "text"]
            class_id = self.data_fr.loc[index, "Group"]
        #audio_file = self.audio_files[index]
        #text = self.texts[index]


        # Extract text features
            text_features = self.transform_text.extract_features([text_p])[0]

            return text_features, class_id

    
    
batch_size = 32

# Initialize feature extractor

text_feature_extractor = TextFeatureExtractor()
# Create dataset
dataset_tr = MultimodalDataset(train, transform_text=text_feature_extractor)
dataset_ts = MultimodalDataset(test, transform_text=text_feature_extractor)

# Create dataloader
tr_dataloader = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True)
ts_dataloader = DataLoader(dataset_ts, batch_size=batch_size, shuffle=True)

print("Done data loading")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Done data loading


In [2]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        #print(x.shape)
        x = torch.unsqueeze(x, 1)
        #print("After")
        #print(x.shape)
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])
        return out

class AccuracyMetric:
    def __init__(self):
        self.correct, self.total = None, None
        self.reset()

    def update(self, y_pred, y_true):
        self.correct += torch.sum(y_pred.argmax(-1) == y_true).item()
        self.total += y_true.size(0)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0    

# Set the hyperparameters
#input_size = embeddings.shape[1]
hidden_size = 128
num_classes = 2
learning_rate = 0.001
num_epochs = 10
batch_size = 32

model = LSTMClassifier(768, hidden_size, num_classes)
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
next(model.parameters()).device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epochs = 12

# training loop
train_loss_history = []
train_accuracy_history = []

valid_loss_history = []
valid_accuracy_history = []

accuracy = AccuracyMetric()
best_acc=0

for epoch in range(1, epochs + 1):
    print(f"[INFO] Epoch: {epoch}")
    model.train()

    batch_train_loss = []
    batch_valid_loss = []

    predicted_label=[]

    true_label=[]
    
    for osfeat, y_batch in tqdm(tr_dataloader):
        # perform single training step
        #print(X_batch.shape)
        #print(osfeat.shape)
        model.zero_grad()
        #print("Before")
        
        osfeat, y_batch = osfeat.to(device), y_batch.to(device)
        y_predicted = model(osfeat.float())
        #print("After Model call")
        loss = criterion(y_predicted, y_batch)
        loss.backward()
        optimizer.step()
        
        accuracy.update(y_predicted, y_batch)
        batch_train_loss.append(loss.item())

    mean_epoch_loss_train = np.mean(batch_train_loss)
    train_accuracy = accuracy.compute()

    train_loss_history.append(mean_epoch_loss_train)
    train_accuracy_history.append(train_accuracy)
    accuracy.reset()

    model.eval()
    with torch.no_grad():
        for osfeat, y_batch in tqdm(ts_dataloader):
            #print(X_batch.shape)
            true_label+=list(y_batch.tolist())
            osfeat, y_batch = osfeat.to(device), y_batch.to(device)

            y_predicted = model(osfeat.float())
            #print(y_predicted)
            loss_val = criterion(y_predicted, y_batch)
            predicted_label+=y_predicted.argmax(-1).tolist()
            accuracy.update(y_predicted, y_batch)
            batch_valid_loss.append(loss_val.item())

    mean_epoch_loss_valid = np.mean(batch_valid_loss)
    valid_accuracy = accuracy.compute()

    if valid_accuracy > best_acc:
        best_acc=valid_accuracy
        model_depression_text_LSTM_best = torch.jit.script(model)
        model_depression_text_LSTM_best.save('model_depression_text_LSTM_best.pt')
        


    valid_loss_history.append(mean_epoch_loss_valid)
    valid_accuracy_history.append(valid_accuracy)
    accuracy.reset()
    print(classification_report(true_label, predicted_label))
    print(
        f"Train loss: {mean_epoch_loss_train:0.4f}, Train accuracy: {train_accuracy: 0.4f}"
    )
    print(
        f"Validation loss: {mean_epoch_loss_valid:0.4f}, Validation accuracy: {valid_accuracy: 0.4f}"
    )


cpu
[INFO] Epoch: 1


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:53:14<00:00, 21.17s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [11:08<00:00, 11.14s/it]


              precision    recall  f1-score   support

           0       0.76      0.42      0.55      1349
           1       0.33      0.68      0.45       565

    accuracy                           0.50      1914
   macro avg       0.55      0.55      0.50      1914
weighted avg       0.64      0.50      0.52      1914

Train loss: 0.6836, Train accuracy:  0.5443
Validation loss: 0.7037, Validation accuracy:  0.5010
[INFO] Epoch: 2


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [2:43:21<00:00, 30.54s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [11:56<00:00, 11.95s/it]


              precision    recall  f1-score   support

           0       0.73      0.62      0.67      1349
           1       0.33      0.44      0.37       565

    accuracy                           0.57      1914
   macro avg       0.53      0.53      0.52      1914
weighted avg       0.61      0.57      0.58      1914

Train loss: 0.6699, Train accuracy:  0.5862
Validation loss: 0.6819, Validation accuracy:  0.5674
[INFO] Epoch: 3


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:09:44<00:00, 13.03s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [11:46<00:00, 11.77s/it]


              precision    recall  f1-score   support

           0       0.76      0.33      0.46      1349
           1       0.32      0.76      0.45       565

    accuracy                           0.45      1914
   macro avg       0.54      0.54      0.45      1914
weighted avg       0.63      0.45      0.46      1914

Train loss: 0.6572, Train accuracy:  0.6032
Validation loss: 0.7570, Validation accuracy:  0.4545
[INFO] Epoch: 4


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:07:22<00:00, 12.59s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [11:33<00:00, 11.56s/it]


              precision    recall  f1-score   support

           0       0.75      0.32      0.45      1349
           1       0.31      0.74      0.44       565

    accuracy                           0.44      1914
   macro avg       0.53      0.53      0.44      1914
weighted avg       0.62      0.44      0.44      1914

Train loss: 0.6518, Train accuracy:  0.6117
Validation loss: 0.7764, Validation accuracy:  0.4436
[INFO] Epoch: 5


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:12:56<00:00, 13.63s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:41<00:00, 12.69s/it]


              precision    recall  f1-score   support

           0       0.72      0.76      0.74      1349
           1       0.34      0.30      0.32       565

    accuracy                           0.63      1914
   macro avg       0.53      0.53      0.53      1914
weighted avg       0.61      0.63      0.62      1914

Train loss: 0.6431, Train accuracy:  0.6237
Validation loss: 0.6490, Validation accuracy:  0.6254
[INFO] Epoch: 6


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:10:35<00:00, 13.19s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:15<00:00, 12.26s/it]


              precision    recall  f1-score   support

           0       0.73      0.51      0.60      1349
           1       0.32      0.56      0.41       565

    accuracy                           0.53      1914
   macro avg       0.53      0.54      0.51      1914
weighted avg       0.61      0.53      0.55      1914

Train loss: 0.6390, Train accuracy:  0.6277
Validation loss: 0.7210, Validation accuracy:  0.5251
[INFO] Epoch: 7


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:11:46<00:00, 13.42s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:14<00:00, 12.24s/it]


              precision    recall  f1-score   support

           0       0.72      0.49      0.58      1349
           1       0.31      0.56      0.40       565

    accuracy                           0.51      1914
   macro avg       0.52      0.52      0.49      1914
weighted avg       0.60      0.51      0.53      1914

Train loss: 0.6265, Train accuracy:  0.6508
Validation loss: 0.7329, Validation accuracy:  0.5084
[INFO] Epoch: 8


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:13:34<00:00, 13.75s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:16<00:00, 12.27s/it]


              precision    recall  f1-score   support

           0       0.72      0.66      0.69      1349
           1       0.33      0.40      0.36       565

    accuracy                           0.58      1914
   macro avg       0.53      0.53      0.53      1914
weighted avg       0.61      0.58      0.59      1914

Train loss: 0.6203, Train accuracy:  0.6537
Validation loss: 0.6811, Validation accuracy:  0.5810
[INFO] Epoch: 9


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:12:33<00:00, 13.56s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:17<00:00, 12.29s/it]


              precision    recall  f1-score   support

           0       0.73      0.27      0.39      1349
           1       0.30      0.76      0.43       565

    accuracy                           0.41      1914
   macro avg       0.52      0.51      0.41      1914
weighted avg       0.60      0.41      0.40      1914

Train loss: 0.6112, Train accuracy:  0.6648
Validation loss: 0.8641, Validation accuracy:  0.4133
[INFO] Epoch: 10


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:11:59<00:00, 13.46s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [14:07<00:00, 14.13s/it]


              precision    recall  f1-score   support

           0       0.71      0.93      0.80      1349
           1       0.33      0.09      0.14       565

    accuracy                           0.68      1914
   macro avg       0.52      0.51      0.47      1914
weighted avg       0.60      0.68      0.61      1914

Train loss: 0.6018, Train accuracy:  0.6683
Validation loss: 0.6588, Validation accuracy:  0.6776
[INFO] Epoch: 11


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:11:36<00:00, 13.39s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:12<00:00, 12.21s/it]


              precision    recall  f1-score   support

           0       0.70      0.42      0.52      1349
           1       0.29      0.58      0.39       565

    accuracy                           0.46      1914
   macro avg       0.50      0.50      0.46      1914
weighted avg       0.58      0.46      0.48      1914

Train loss: 0.5988, Train accuracy:  0.6689
Validation loss: 0.8048, Validation accuracy:  0.4650
[INFO] Epoch: 12


100%|██████████████████████████████████████████████████████████████████████████████| 321/321 [1:12:53<00:00, 13.63s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [12:20<00:00, 12.34s/it]

              precision    recall  f1-score   support

           0       0.72      0.74      0.73      1349
           1       0.33      0.31      0.32       565

    accuracy                           0.61      1914
   macro avg       0.52      0.52      0.52      1914
weighted avg       0.60      0.61      0.61      1914

Train loss: 0.5798, Train accuracy:  0.6869
Validation loss: 0.6756, Validation accuracy:  0.6118



