In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import pandas as pd
import librosa
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import random
from sklearn.utils import resample
import openl3

train=pd.read_csv("C:/Users/Administrator/Desktop/Text Modelling/DAIC_train_3sp_sampled.csv")
test=pd.read_csv("C:/Users/Administrator/Desktop/Text Modelling/DAIC_test_3sp.csv")
torch.manual_seed(101)

class TextFeatureExtractor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def extract_features(self, texts):
        # Tokenize input texts
        tokenized_texts = self.tokenizer.batch_encode_plus(
            texts,
            padding='max_length',
            #padding=True,
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Move tensors to the appropriate device
        input_ids = tokenized_texts['input_ids'].to(self.device)
        attention_mask = tokenized_texts['attention_mask'].to(self.device)
        #print(input_ids.shape)
        # Extract text features from BERT model
        #with torch.no_grad():
            #outputs = self.model(input_ids, attention_mask=attention_mask)
            #print(type(outputs))
           #text_features = outputs[0][:, 0, :].cpu().numpy()  # Use the [CLS] token embedding
        #print("Text feat dim:"+str(text_features))
        return input_ids

    
class MultimodalDataset(Dataset):
        def __init__(self, data_fr, transform_text=None):
            self.data_fr = data_fr
            self.transform_text = transform_text

        def __len__(self):
            return len(self.data_fr)

        def __getitem__(self, index):
        #print(audio_file)
            text_p =  self.data_fr.loc[index, "text"]
            class_id = self.data_fr.loc[index, "Group"]
        #audio_file = self.audio_files[index]
        #text = self.texts[index]


        # Extract text features
            text_features = self.transform_text.extract_features([text_p])[0]

            return text_features, class_id

    
    
batch_size = 32

# Initialize feature extractor

text_feature_extractor = TextFeatureExtractor()
# Create dataset
dataset_tr = MultimodalDataset(train, transform_text=text_feature_extractor)
dataset_ts = MultimodalDataset(test, transform_text=text_feature_extractor)

# Create dataloader
tr_dataloader = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True)
ts_dataloader = DataLoader(dataset_ts, batch_size=batch_size, shuffle=True)

print("Done data loading")


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<class '__main__.MultimodalDataset'>
Done data loading


In [6]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        #print(x.shape)
    
        x = torch.unsqueeze(x, 1)
        #print("After")
        #print(x.shape)
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])
        return out

class AccuracyMetric:
    def __init__(self):
        self.correct, self.total = None, None
        self.reset()

    def update(self, y_pred, y_true):
        self.correct += torch.sum(y_pred.argmax(-1) == y_true).item()
        self.total += y_true.size(0)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0    

# Set the hyperparameters
#input_size = embeddings.shape[1]
hidden_size = 128
num_classes = 2
learning_rate = 0.001
num_epochs = 10
batch_size = 32

model = LSTMClassifier(512, hidden_size, num_classes)
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
next(model.parameters()).device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epochs = 12

# training loop
train_loss_history = []
train_accuracy_history = []

valid_loss_history = []
valid_accuracy_history = []

accuracy = AccuracyMetric()
best_acc=0
for epoch in range(1, epochs + 1):
    print(f"[INFO] Epoch: {epoch}")
    model.train()

    batch_train_loss = []
    batch_valid_loss = []

    predicted_label=[]

    true_label=[]
    
    for osfeat, y_batch in tqdm(tr_dataloader):
        # perform single training step
        #print(X_batch.shape)
        #print(osfeat.shape)
        model.zero_grad()
        #print("Before")
        
        osfeat, y_batch = osfeat.to(device), y_batch.to(device)
        y_predicted = model(osfeat.float())
        #print("After Model call")
        loss = criterion(y_predicted, y_batch)
        loss.backward()
        optimizer.step()
        
        accuracy.update(y_predicted, y_batch)
        batch_train_loss.append(loss.item())

    mean_epoch_loss_train = np.mean(batch_train_loss)
    train_accuracy = accuracy.compute()

    train_loss_history.append(mean_epoch_loss_train)
    train_accuracy_history.append(train_accuracy)
    accuracy.reset()

    model.eval()
    with torch.no_grad():
        for osfeat, y_batch in tqdm(ts_dataloader):
            #print(X_batch.shape)
            true_label+=list(y_batch.tolist())
            osfeat, y_batch = osfeat.to(device), y_batch.to(device)

            y_predicted = model(osfeat.float())
            #print(y_predicted)
            loss_val = criterion(y_predicted, y_batch)
            predicted_label+=y_predicted.argmax(-1).tolist()
            accuracy.update(y_predicted, y_batch)
            batch_valid_loss.append(loss_val.item())

    mean_epoch_loss_valid = np.mean(batch_valid_loss)
    valid_accuracy = accuracy.compute()

    if valid_accuracy > best_acc:
        best_acc=valid_accuracy
        model_depression_text_LSTM_best_bert = torch.jit.script(model)
        model_depression_text_LSTM_best_bert.save('model_depression_text_LSTM_best_bert.pt')
        


    valid_loss_history.append(mean_epoch_loss_valid)
    valid_accuracy_history.append(valid_accuracy)
    accuracy.reset()
    print(classification_report(true_label, predicted_label))
    print(
        f"Train loss: {mean_epoch_loss_train:0.4f}, Train accuracy: {train_accuracy: 0.4f}"
    )
    print(
        f"Validation loss: {mean_epoch_loss_valid:0.4f}, Validation accuracy: {valid_accuracy: 0.4f}"
    )


cpu
[INFO] Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [05:25<00:00,  1.01s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:33<00:00,  1.77it/s]


              precision    recall  f1-score   support

           0       0.69      0.05      0.09      1349
           1       0.29      0.95      0.45       565

    accuracy                           0.31      1914
   macro avg       0.49      0.50      0.27      1914
weighted avg       0.57      0.31      0.20      1914

Train loss: 0.6967, Train accuracy:  0.5207
Validation loss: 0.8198, Validation accuracy:  0.3135
[INFO] Epoch: 2


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [06:01<00:00,  1.12s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:30<00:00,  1.94it/s]


              precision    recall  f1-score   support

           0       0.72      0.11      0.19      1349
           1       0.30      0.90      0.45       565

    accuracy                           0.34      1914
   macro avg       0.51      0.51      0.32      1914
weighted avg       0.60      0.34      0.27      1914

Train loss: 0.6938, Train accuracy:  0.5204
Validation loss: 0.7806, Validation accuracy:  0.3443
[INFO] Epoch: 3


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [04:21<00:00,  1.23it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:32<00:00,  1.87it/s]


              precision    recall  f1-score   support

           0       0.70      0.74      0.72      1349
           1       0.28      0.24      0.26       565

    accuracy                           0.59      1914
   macro avg       0.49      0.49      0.49      1914
weighted avg       0.58      0.59      0.58      1914

Train loss: 0.6949, Train accuracy:  0.5225
Validation loss: 0.6790, Validation accuracy:  0.5930
[INFO] Epoch: 4


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [04:34<00:00,  1.17it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:33<00:00,  1.79it/s]


              precision    recall  f1-score   support

           0       0.71      1.00      0.83      1349
           1       0.67      0.00      0.01       565

    accuracy                           0.71      1914
   macro avg       0.69      0.50      0.42      1914
weighted avg       0.69      0.71      0.58      1914

Train loss: 0.6951, Train accuracy:  0.5223
Validation loss: 0.6312, Validation accuracy:  0.7053
[INFO] Epoch: 5


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [03:56<00:00,  1.36it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:21<00:00,  2.77it/s]


              precision    recall  f1-score   support

           0       0.72      0.75      0.73      1349
           1       0.33      0.30      0.31       565

    accuracy                           0.62      1914
   macro avg       0.52      0.52      0.52      1914
weighted avg       0.60      0.62      0.61      1914

Train loss: 0.6960, Train accuracy:  0.5200
Validation loss: 0.6762, Validation accuracy:  0.6155
[INFO] Epoch: 6


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [03:06<00:00,  1.72it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:22<00:00,  2.66it/s]


              precision    recall  f1-score   support

           0       0.71      0.93      0.80      1349
           1       0.33      0.08      0.13       565

    accuracy                           0.68      1914
   macro avg       0.52      0.51      0.47      1914
weighted avg       0.59      0.68      0.61      1914

Train loss: 0.6941, Train accuracy:  0.5254
Validation loss: 0.6712, Validation accuracy:  0.6782
[INFO] Epoch: 7


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [03:12<00:00,  1.66it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:29<00:00,  2.04it/s]


              precision    recall  f1-score   support

           0       0.69      0.16      0.26      1349
           1       0.29      0.82      0.43       565

    accuracy                           0.36      1914
   macro avg       0.49      0.49      0.35      1914
weighted avg       0.57      0.36      0.31      1914

Train loss: 0.6946, Train accuracy:  0.5160
Validation loss: 0.7272, Validation accuracy:  0.3574
[INFO] Epoch: 8


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [04:43<00:00,  1.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:33<00:00,  1.79it/s]


              precision    recall  f1-score   support

           0       0.70      0.99      0.82      1349
           1       0.26      0.01      0.02       565

    accuracy                           0.70      1914
   macro avg       0.48      0.50      0.42      1914
weighted avg       0.57      0.70      0.59      1914

Train loss: 0.6954, Train accuracy:  0.5220
Validation loss: 0.6569, Validation accuracy:  0.6980
[INFO] Epoch: 9


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [05:25<00:00,  1.01s/it]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:33<00:00,  1.81it/s]


              precision    recall  f1-score   support

           0       0.71      0.43      0.54      1349
           1       0.30      0.57      0.39       565

    accuracy                           0.47      1914
   macro avg       0.50      0.50      0.46      1914
weighted avg       0.58      0.47      0.49      1914

Train loss: 0.6942, Train accuracy:  0.5213
Validation loss: 0.6874, Validation accuracy:  0.4728
[INFO] Epoch: 10


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [04:43<00:00,  1.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:35<00:00,  1.67it/s]


              precision    recall  f1-score   support

           0       0.69      0.06      0.12      1349
           1       0.29      0.93      0.45       565

    accuracy                           0.32      1914
   macro avg       0.49      0.50      0.28      1914
weighted avg       0.57      0.32      0.21      1914

Train loss: 0.6931, Train accuracy:  0.5257
Validation loss: 0.7525, Validation accuracy:  0.3197
[INFO] Epoch: 11


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [03:23<00:00,  1.58it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:21<00:00,  2.78it/s]


              precision    recall  f1-score   support

           0       0.65      0.05      0.09      1349
           1       0.29      0.94      0.45       565

    accuracy                           0.31      1914
   macro avg       0.47      0.49      0.27      1914
weighted avg       0.55      0.31      0.20      1914

Train loss: 0.6947, Train accuracy:  0.5159
Validation loss: 0.7782, Validation accuracy:  0.3114
[INFO] Epoch: 12


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [03:58<00:00,  1.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:26<00:00,  2.31it/s]

              precision    recall  f1-score   support

           0       0.69      0.19      0.30      1349
           1       0.29      0.79      0.42       565

    accuracy                           0.37      1914
   macro avg       0.49      0.49      0.36      1914
weighted avg       0.57      0.37      0.34      1914

Train loss: 0.6915, Train accuracy:  0.5284
Validation loss: 0.7107, Validation accuracy:  0.3694



