<a href="https://colab.research.google.com/github/mbauergit/Mental-Health-Competition/blob/main/train_bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


#### Import Data ####

In [None]:
FEATURES_PATH = '/content/drive/MyDrive/Mental Health Competition/Data/train_features.csv'
LABELS_PATH = '/content/drive/MyDrive/Mental Health Competition/Data/train_labels.csv'

In [None]:
features = pd.read_csv(FEATURES_PATH, index_col=0)
labels = pd.read_csv(LABELS_PATH, index_col=0)
# display(features)
# display(labels)


In [None]:
merged_df = pd.merge(features, labels, left_index=True, right_index=True).reset_index(drop=True)
display(merged_df)

# # Oversample minority class for DepressedMood
# minority_class = merged_df[merged_df['DepressedMood'] == 1]
# majority_class = merged_df[merged_df['DepressedMood'] == 0]
# minority_upsampled = minority_class.sample(n=len(majority_class), replace=True, random_state=42)
# upsampled_df = pd.concat([majority_class, minority_upsampled])
# display(upsampled_df)
# print(upsampled_df['DepressedMood'].value_counts())

Unnamed: 0,NarrativeLE,NarrativeCME,DepressedMood,MentalIllnessTreatmentCurrnt,HistoryMentalIllnessTreatmnt,SuicideAttemptHistory,SuicideThoughtHistory,SubstanceAbuseProblem,MentalHealthProblem,DiagnosisAnxiety,...,Argument,SchoolProblem,RecentCriminalLegalProblem,SuicideNote,SuicideIntentDisclosed,DisclosedToIntimatePartner,DisclosedToOtherFamilyMember,DisclosedToFriend,InjuryLocationType,WeaponType1
0,V (XX XX) shot himself in a motor vehicle.The ...,V (XX XX) shot himself in a motor vehicle.The ...,0,0,0,0,1,0,0,0,...,0,0,0,0,1,0,1,0,2,5
1,V was XXXX. V was found in the basement of his...,V was XXXX. V was found in the basement of hi...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,6
2,V was XXXX. V was found in his residence unres...,V was XXXX. V was found in his residence suffe...,0,0,0,0,1,1,0,0,...,1,0,0,0,1,0,1,0,1,5
3,"The victim, a XX XX who had recently returned ...",On the day of the fatal event in the early mor...,1,0,0,0,1,0,0,0,...,0,0,0,0,0,0,0,0,1,5
4,XX XX V found deceased at home by his grandpar...,XX XX V found deceased at home by his grandpar...,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,1,1,1,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3995,The victim was a XX XX who was discovered at h...,The victim was a XX XX who was discovered at h...,0,1,1,1,0,1,1,0,...,0,0,0,0,1,1,0,0,1,9
3996,The V is a XX XX. The cause of death is acute ...,The V is a XX XX. The Cause of death is Acute ...,0,0,1,1,1,1,1,0,...,0,0,0,0,1,0,0,1,1,9
3997,V was a XX XX. V was found deceased in his bed...,"V was a XX XX. On the day of the incident, V r...",0,0,0,1,1,1,0,0,...,0,0,0,0,1,0,0,0,1,9
3998,"At 0100 hours, local police received a call of...","At approximately 0041 hours, officers responde...",0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,5


#### Train Functions ####

In [None]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}

In [None]:
class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            x = self.dropout(pooled_output)
            logits = self.fc(x)
            return logits

In [None]:
def train(model, data_loader, optimizer, scheduler, device):
    model.train()
    print("Training...")
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        labels = labels - 1
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [None]:
from sklearn.metrics import accuracy_score, classification_report
import torch.nn.functional as F  # Importing functional to use softmax

def evaluate(model, data_loader, device, threshold=0.5):
    model.eval()
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            # Get model outputs (logits)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Apply softmax to get probabilities
            probabilities = F.softmax(outputs, dim=1)  # Use F.sigmoid for binary classification

            # # Use the custom threshold to get predictions
            # preds = (probabilities[:, 1] > threshold).long()  # Change '1' to the index of class `1`

            # For multiclass
            # Use argmax to get the predicted class for multi-class classification
            preds = probabilities.argmax(dim=1) + 1

            predictions.extend(preds.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())

    return accuracy_score(actual_labels, predictions), classification_report(actual_labels, predictions)

In [None]:
def predict_sentiment(text, model, tokenizer, device, max_length=128):
    model.eval()
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
    return "positive" if preds.item() == 1 else "negative"

#### Running Code ####

In [None]:
display(labels)

Unnamed: 0_level_0,DepressedMood,MentalIllnessTreatmentCurrnt,HistoryMentalIllnessTreatmnt,SuicideAttemptHistory,SuicideThoughtHistory,SubstanceAbuseProblem,MentalHealthProblem,DiagnosisAnxiety,DiagnosisDepressionDysthymia,DiagnosisBipolar,...,Argument,SchoolProblem,RecentCriminalLegalProblem,SuicideNote,SuicideIntentDisclosed,DisclosedToIntimatePartner,DisclosedToOtherFamilyMember,DisclosedToFriend,InjuryLocationType,WeaponType1
uid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
aaaf,0,0,0,0,1,0,0,0,0,0,...,0,0,0,0,1,0,1,0,2,5
aaby,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,6
aacl,0,0,0,0,1,1,0,0,0,0,...,1,0,0,0,1,0,1,0,1,5
aacn,1,0,0,0,1,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,5
aadb,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,1,1,1,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
fhri,0,1,1,1,0,1,1,0,1,0,...,0,0,0,0,1,1,0,0,1,9
fhrn,0,0,1,1,1,1,1,0,1,0,...,0,0,0,0,1,0,0,1,1,9
fhsx,0,0,0,1,1,1,0,0,0,0,...,0,0,0,0,1,0,0,0,1,9
fhtq,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,1,5


In [None]:
# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = 12
max_length = 128
batch_size = 16
num_epochs = 4
learning_rate = 2e-5

# Train a model for each variable
for i in range(24, 25):
    print(f"Training model for variable {merged_df.columns[i]}")


    # Oversample minority class for DepressedMood
    counts = merged_df.iloc[:, i].value_counts()
    print(counts)
    majority_class = merged_df[merged_df.iloc[:, i] == counts.idxmax()]
    minority_classes = merged_df[merged_df.iloc[:, i] != counts.idxmax()]
    display(minority_classes)

    upsampled_df = majority_class.copy()
    # up sample the minroity classes
    for j in range(1, len(counts)):
        minority_class = minority_classes[minority_classes.iloc[:, i] == counts.index[j]]
        minority_upsampled = minority_class.sample(n=len(majority_class), replace=True, random_state=42)
        upsampled_df = pd.concat([upsampled_df, minority_upsampled])
    # display(upsampled_df)
    print(upsampled_df.iloc[:, i].value_counts())


    texts = (upsampled_df['NarrativeCME'] + ' ' + upsampled_df['NarrativeCME']).to_numpy()
    new_labels = upsampled_df.iloc[:, i].to_numpy()

    print(texts[0])
    print(new_labels[0])

    train_texts, val_texts, train_labels, val_labels = train_test_split(texts, new_labels, test_size=0.2, random_state=42)

    tokenizer = BertTokenizer.from_pretrained(bert_model_name)
    train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
    val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    print("train_dataloader:", len(train_dataloader))
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = BERTClassifier(bert_model_name, num_classes).to(device)

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_dataloader) * num_epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train(model, train_dataloader, optimizer, scheduler, device)
        accuracy, report = evaluate(model, val_dataloader, device)
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(report)

    torch.save(model.state_dict(), "bert_classifier_" + str(i) + ".pth")


Training model for variable WeaponType1
WeaponType1
5     2008
6     1439
9      286
3      109
7       46
8       45
2       28
10      17
4       11
12       4
1        4
11       3
Name: count, dtype: int64


Unnamed: 0,NarrativeLE,NarrativeCME,DepressedMood,MentalIllnessTreatmentCurrnt,HistoryMentalIllnessTreatmnt,SuicideAttemptHistory,SuicideThoughtHistory,SubstanceAbuseProblem,MentalHealthProblem,DiagnosisAnxiety,...,Argument,SchoolProblem,RecentCriminalLegalProblem,SuicideNote,SuicideIntentDisclosed,DisclosedToIntimatePartner,DisclosedToOtherFamilyMember,DisclosedToFriend,InjuryLocationType,WeaponType1
1,V was XXXX. V was found in the basement of his...,V was XXXX. V was found in the basement of hi...,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,6
4,XX XX V found deceased at home by his grandpar...,XX XX V found deceased at home by his grandpar...,0,0,0,0,0,0,0,0,...,0,0,0,0,1,0,1,1,1,6
7,"Victim XX17 history of depression, anxiety and...","The victim was a XX, XX. The victim committed...",1,1,1,0,1,0,1,0,...,0,0,0,0,0,0,0,0,1,6
9,"V (XX XX) was found by her father, hanging in ...",V (XX XX) was found hanging in the bathroom fr...,1,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,6
10,V (XX XX) was found by her father hanging from...,"V (XX XX) was found by her father, hanging fro...",0,1,1,1,0,0,1,0,...,0,0,0,0,0,0,0,0,1,6
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3986,Officers were dispatched to a residence regard...,The victim was a XX XX who was found in his re...,1,1,1,0,0,0,1,0,...,0,0,0,0,0,0,0,0,1,6
3993,V is a XX XX who died by suicide via acute com...,V is a XX XX who died by suicide via acute com...,0,1,1,0,0,1,1,1,...,0,0,0,1,0,0,0,0,4,9
3995,The victim was a XX XX who was discovered at h...,The victim was a XX XX who was discovered at h...,0,1,1,1,0,1,1,0,...,0,0,0,0,1,1,0,0,1,9
3996,The V is a XX XX. The cause of death is acute ...,The V is a XX XX. The Cause of death is Acute ...,0,0,1,1,1,1,1,0,...,0,0,0,0,1,0,0,1,1,9


WeaponType1
5     2008
6     2008
9     2008
3     2008
7     2008
8     2008
2     2008
10    2008
4     2008
12    2008
1     2008
11    2008
Name: count, dtype: int64
V (XX XX) shot himself in a motor vehicle.The V's mother called law enforcement and reported the V as missing and suicidal with a firearm.The V was located in a vehicle in a retail parking lot.  When law enforcement approached the vehicle the V shot himself.There are no other circumstances. V (XX XX) shot himself in a motor vehicle.The V's mother called law enforcement and reported the V as missing and suicidal with a firearm.The V was located in a vehicle in a retail parking lot.  When law enforcement approached the vehicle the V shot himself.There are no other circumstances.
5




train_dataloader: 1205




Epoch 1/4
Training...
Validation Accuracy: 0.9967
              precision    recall  f1-score   support

           1       1.00      1.00      1.00       404
           2       1.00      1.00      1.00       374
           3       1.00      1.00      1.00       394
           4       1.00      1.00      1.00       405
           5       0.98      0.99      0.99       370
           6       1.00      0.97      0.99       449
           7       1.00      1.00      1.00       426
           8       1.00      1.00      1.00       396
           9       0.98      1.00      0.99       395
          10       1.00      1.00      1.00       381
          11       1.00      1.00      1.00       411
          12       1.00      1.00      1.00       415

    accuracy                           1.00      4820
   macro avg       1.00      1.00      1.00      4820
weighted avg       1.00      1.00      1.00      4820

Epoch 2/4
Training...
Validation Accuracy: 0.9981
              precision    recall