In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import (
    RobertaTokenizerFast,
    RobertaForSequenceClassification,
    TrainingArguments,
    Trainer,
    AutoConfig,
)
import pandas as pd
import librosa
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import random
from sklearn.utils import resample
import openl3

train=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_train_3sp_sampled.csv")
test=pd.read_csv("C:/Users/Administrator/Downloads/DAIC_test_3sp.csv")
torch.manual_seed(101)

class TextFeatureExtractor:
    def __init__(self):
        self.tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        self.model = RobertaForSequenceClassification.from_pretrained('roberta-base')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def extract_features(self, texts):
        # Tokenize input texts
        tokenized_texts = self.tokenizer.batch_encode_plus(
            texts,
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )

        # Move tensors to the appropriate device
        input_ids = tokenized_texts['input_ids'].to(self.device)
        #print(tokenized_texts['input_ids'].shape)
        attention_mask = tokenized_texts['attention_mask'].to(self.device)
        #print(input_ids.shape)
        # Extract text features from BERT model
        #with torch.no_grad():
            #outputs = self.model(input_ids, attention_mask=attention_mask)
            #print(outputs[0])
            #print(type(outputs))
            #text_features = outputs[0][:, 0, :].cpu().numpy()  # Use the [CLS] token embedding
            # print(text_features)
        #print("Text feat dim:"+str(text_features))
        return input_ids

    
class MultimodalDataset(Dataset):
        def __init__(self, data_fr, transform_text=None):
            self.data_fr = data_fr
            self.transform_text = transform_text

        def __len__(self):
            return len(self.data_fr)

        def __getitem__(self, index):
        #print(audio_file)
            text_p =  self.data_fr.loc[index, "text"]
            class_id = self.data_fr.loc[index, "Group"]
        #audio_file = self.audio_files[index]
        #text = self.texts[index]


        # Extract text features
            text_features = self.transform_text.extract_features([text_p])[0]

            return text_features, class_id

    
    
batch_size = 32

# Initialize feature extractor

text_feature_extractor = TextFeatureExtractor()
# Create dataset
dataset_tr = MultimodalDataset(train, transform_text=text_feature_extractor)
dataset_ts = MultimodalDataset(test, transform_text=text_feature_extractor)

# Create dataloader
tr_dataloader = DataLoader(dataset_tr, batch_size=batch_size, shuffle=True)
ts_dataloader = DataLoader(dataset_ts, batch_size=batch_size, shuffle=True)

print("Done data loading")


Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.weight', 'classi

Done data loading


In [8]:
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(LSTMClassifier, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size, num_classes)
    
    #propagation:
    def forward(self, x):
        #print(x.shape)
        x = torch.unsqueeze(x, 1)
        #print("After")
        #print(x.shape)
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])
        return out

class AccuracyMetric:
    def __init__(self):
        self.correct, self.total = None, None
        self.reset()

    def update(self, y_pred, y_true):
        self.correct += torch.sum(y_pred.argmax(-1) == y_true).item()
        self.total += y_true.size(0)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0    

# Set the hyperparameters
#input_size = embeddings.shape[1]
hidden_size = 128
num_classes = 2
learning_rate = 0.001
num_epochs = 10
batch_size = 32

model = LSTMClassifier(512, hidden_size, num_classes)
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)
next(model.parameters()).device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
epochs = 12

# training loop
train_loss_history = []
train_accuracy_history = []

valid_loss_history = []
valid_accuracy_history = []

accuracy = AccuracyMetric()
best_acc=0
for epoch in range(1, epochs + 1):
    print(f"[INFO] Epoch: {epoch}")
    model.train()

    batch_train_loss = []
    batch_valid_loss = []

    predicted_label=[]

    true_label=[]
    
    for osfeat, y_batch in tqdm(tr_dataloader):
        # perform single training step
        #print(X_batch.shape)
        #print(osfeat.shape)
        model.zero_grad()
        #print("Before")
        
        #osfeat = true features(~ x_batch for test)
        
        osfeat, y_batch = osfeat.to(device), y_batch.to(device)
        y_predicted = model(osfeat.float())
        #print("After Model call")
        loss = criterion(y_predicted, y_batch)
        loss.backward()
        optimizer.step()
        
        accuracy.update(y_predicted, y_batch)
        batch_train_loss.append(loss.item())

    mean_epoch_loss_train = np.mean(batch_train_loss)
    train_accuracy = accuracy.compute()

    train_loss_history.append(mean_epoch_loss_train)
    train_accuracy_history.append(train_accuracy)
    accuracy.reset()

    model.eval()
    with torch.no_grad():
        for osfeat, y_batch in tqdm(ts_dataloader):
            #print(X_batch.shape)
            true_label+=list(y_batch.tolist())
            osfeat, y_batch = osfeat.to(device), y_batch.to(device)

            y_predicted = model(osfeat.float())
            #print(y_predicted)
            loss_val = criterion(y_predicted, y_batch)
            predicted_label+=y_predicted.argmax(-1).tolist()
            accuracy.update(y_predicted, y_batch)
            batch_valid_loss.append(loss_val.item())

    mean_epoch_loss_valid = np.mean(batch_valid_loss)
    valid_accuracy = accuracy.compute()

    if valid_accuracy > best_acc:
        best_acc=valid_accuracy
        model_depression_text_LSTM_best_roberta = torch.jit.script(model)
        model_depression_text_LSTM_best_roberta.save('model_depression_text_LSTM_best_roberta.pt')
        


    valid_loss_history.append(mean_epoch_loss_valid)
    valid_accuracy_history.append(valid_accuracy)
    accuracy.reset()
    print(classification_report(true_label, predicted_label))
    print(
        f"Train loss: {mean_epoch_loss_train:0.4f}, Train accuracy: {train_accuracy: 0.4f}"
    )
    print(
        f"Validation loss: {mean_epoch_loss_valid:0.4f}, Validation accuracy: {valid_accuracy: 0.4f}"
    )


cpu
[INFO] Epoch: 1


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:25<00:00, 12.56it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:04<00:00, 14.53it/s]


              precision    recall  f1-score   support

           0       0.70      0.67      0.69      1349
           1       0.29      0.32      0.30       565

    accuracy                           0.57      1914
   macro avg       0.49      0.49      0.49      1914
weighted avg       0.58      0.57      0.57      1914

Train loss: 0.6953, Train accuracy:  0.5252
Validation loss: 0.6781, Validation accuracy:  0.5664
[INFO] Epoch: 2


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:28<00:00, 11.31it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:04<00:00, 14.97it/s]


              precision    recall  f1-score   support

           0       0.71      0.66      0.69      1349
           1       0.30      0.35      0.32       565

    accuracy                           0.57      1914
   macro avg       0.50      0.50      0.50      1914
weighted avg       0.59      0.57      0.58      1914

Train loss: 0.6912, Train accuracy:  0.5355
Validation loss: 0.6793, Validation accuracy:  0.5700
[INFO] Epoch: 3


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:27<00:00, 11.49it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 16.54it/s]


              precision    recall  f1-score   support

           0       0.72      0.30      0.42      1349
           1       0.30      0.72      0.43       565

    accuracy                           0.42      1914
   macro avg       0.51      0.51      0.42      1914
weighted avg       0.60      0.42      0.42      1914

Train loss: 0.6928, Train accuracy:  0.5300
Validation loss: 0.7366, Validation accuracy:  0.4242
[INFO] Epoch: 4


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:29<00:00, 10.77it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 15.76it/s]


              precision    recall  f1-score   support

           0       0.71      0.23      0.35      1349
           1       0.30      0.77      0.43       565

    accuracy                           0.39      1914
   macro avg       0.50      0.50      0.39      1914
weighted avg       0.59      0.39      0.37      1914

Train loss: 0.6925, Train accuracy:  0.5403
Validation loss: 0.7521, Validation accuracy:  0.3929
[INFO] Epoch: 5


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:28<00:00, 11.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 16.58it/s]


              precision    recall  f1-score   support

           0       0.71      0.43      0.53      1349
           1       0.30      0.59      0.40       565

    accuracy                           0.47      1914
   macro avg       0.51      0.51      0.47      1914
weighted avg       0.59      0.47      0.49      1914

Train loss: 0.6919, Train accuracy:  0.5416
Validation loss: 0.7152, Validation accuracy:  0.4744
[INFO] Epoch: 6


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:30<00:00, 10.52it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:04<00:00, 14.92it/s]


              precision    recall  f1-score   support

           0       0.71      0.33      0.45      1349
           1       0.30      0.67      0.41       565

    accuracy                           0.43      1914
   macro avg       0.50      0.50      0.43      1914
weighted avg       0.58      0.43      0.44      1914

Train loss: 0.6884, Train accuracy:  0.5436
Validation loss: 0.7372, Validation accuracy:  0.4300
[INFO] Epoch: 7


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:26<00:00, 11.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:04<00:00, 12.99it/s]


              precision    recall  f1-score   support

           0       0.73      0.17      0.27      1349
           1       0.30      0.85      0.44       565

    accuracy                           0.37      1914
   macro avg       0.51      0.51      0.36      1914
weighted avg       0.60      0.37      0.32      1914

Train loss: 0.6886, Train accuracy:  0.5397
Validation loss: 0.7819, Validation accuracy:  0.3694
[INFO] Epoch: 8


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:37<00:00,  8.61it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:04<00:00, 14.45it/s]


              precision    recall  f1-score   support

           0       0.69      0.36      0.47      1349
           1       0.28      0.61      0.39       565

    accuracy                           0.43      1914
   macro avg       0.48      0.48      0.43      1914
weighted avg       0.57      0.43      0.45      1914

Train loss: 0.6898, Train accuracy:  0.5458
Validation loss: 0.7310, Validation accuracy:  0.4316
[INFO] Epoch: 9


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:25<00:00, 12.36it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 18.46it/s]


              precision    recall  f1-score   support

           0       0.70      0.68      0.69      1349
           1       0.28      0.31      0.30       565

    accuracy                           0.57      1914
   macro avg       0.49      0.49      0.49      1914
weighted avg       0.58      0.57      0.57      1914

Train loss: 0.6896, Train accuracy:  0.5422
Validation loss: 0.6795, Validation accuracy:  0.5674
[INFO] Epoch: 10


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:25<00:00, 12.35it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 18.24it/s]


              precision    recall  f1-score   support

           0       0.70      0.64      0.67      1349
           1       0.29      0.35      0.32       565

    accuracy                           0.56      1914
   macro avg       0.50      0.50      0.50      1914
weighted avg       0.58      0.56      0.57      1914

Train loss: 0.6878, Train accuracy:  0.5386
Validation loss: 0.6866, Validation accuracy:  0.5569
[INFO] Epoch: 11


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:25<00:00, 12.64it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 18.15it/s]


              precision    recall  f1-score   support

           0       0.70      0.49      0.58      1349
           1       0.29      0.51      0.37       565

    accuracy                           0.50      1914
   macro avg       0.50      0.50      0.48      1914
weighted avg       0.58      0.50      0.52      1914

Train loss: 0.6881, Train accuracy:  0.5423
Validation loss: 0.7076, Validation accuracy:  0.4953
[INFO] Epoch: 12


100%|████████████████████████████████████████████████████████████████████████████████| 321/321 [00:25<00:00, 12.75it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 60/60 [00:03<00:00, 17.07it/s]


              precision    recall  f1-score   support

           0       0.71      0.72      0.71      1349
           1       0.30      0.28      0.29       565

    accuracy                           0.59      1914
   macro avg       0.50      0.50      0.50      1914
weighted avg       0.58      0.59      0.59      1914

Train loss: 0.6899, Train accuracy:  0.5401
Validation loss: 0.6685, Validation accuracy:  0.5920
