A deep learning project that combines DistilBERT with Convolutional Neural Networks (CNN) for multi-label emotion classification using the GoEmotions dataset.
This project implements a hybrid neural network that combines the contextual representations of DistilBERT with the local feature extraction of a 1D CNN to classify text into multiple emotion categories simultaneously. The model scores each input against 28 emotion labels, any number of which may apply to a single piece of text.
The project uses the GoEmotions dataset, which contains:
- ~58,000 carefully curated Reddit comments
- 28 emotion categories including neutral, joy, sadness, anger, fear, surprise, and more
- Multi-label annotations (each text can have multiple emotions)
- Base Model: DistilBERT (distilbert-base-uncased)
- CNN Layer: 1D convolution with 128 filters and kernel size of 5
- Pooling: Global max pooling
- Output: Fully connected layer with sigmoid activation for multi-label classification
- Partial Fine-tuning: Only the last 3 layers of DistilBERT are fine-tuned
- Multi-label Support: Uses BCEWithLogitsLoss for handling multiple emotions per text
- Dropout Regularization: 42% dropout rate to prevent overfitting
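Putting the pieces above together, the model class might look roughly like the sketch below. This is an illustration of the described architecture, not the code from `main.py`: layer names, the convolution padding, and the config-only (undownloaded) DistilBERT are assumptions made so the sketch is self-contained.

```python
import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertModel

class DistilBertCNN(nn.Module):
    """Sketch of the DistilBERT + 1D CNN hybrid described above."""

    def __init__(self, num_labels=28, dropout=0.42):
        super().__init__()
        # The project uses pretrained weights via
        # DistilBertModel.from_pretrained("distilbert-base-uncased");
        # a config-only model is used here so the sketch runs without a download.
        self.bert = DistilBertModel(DistilBertConfig())
        # 128 filters, kernel size 5 (padding=2 is an assumption to preserve length)
        self.conv = nn.Conv1d(self.bert.config.dim, 128, kernel_size=5, padding=2)
        self.dropout = nn.Dropout(dropout)  # 42% dropout
        self.fc = nn.Linear(128, num_labels)

    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state
        # Conv1d expects (batch, channels, seq_len)
        features = torch.relu(self.conv(hidden.transpose(1, 2)))
        pooled = features.max(dim=2).values   # global max pooling over the sequence
        return self.fc(self.dropout(pooled))  # raw logits; pair with BCEWithLogitsLoss
```

The partial fine-tuning mentioned above is not shown; one common way to get it is to set `requires_grad = False` on all DistilBERT parameters except those of the last three transformer layers before building the optimizer.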
```bash
pip install torch transformers pandas numpy scikit-learn matplotlib seaborn kagglehub
```

- Download and preprocess the dataset:
```python
import kagglehub

path = kagglehub.dataset_download("debarshichanda/goemotions")
# Data preprocessing and label encoding happen automatically
```

- Train the model:
```python
# Model trains for 5 epochs with a cosine annealing learning rate scheduler
# Training metrics are displayed for each epoch
```

- Run inference with the saved model:

```python
import torch
from transformers import DistilBertTokenizer

# Load the saved tokenizer and model (DistilBertCNN is defined in main.py)
tokenizer = DistilBertTokenizer.from_pretrained("distil_bert_tokenizer_dir")
model = DistilBertCNN(num_labels=28)
model.load_state_dict(torch.load("GoEmotionsDistilBertCNN.pth"))
model.eval()

# Make predictions
query = "I'm so excited about this new opportunity!"
encoded_input = tokenizer(query, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    output = model(encoded_input["input_ids"], encoded_input["attention_mask"])
probabilities = torch.sigmoid(output)

# Get top 5 emotions (`emotions` is the list of the 28 label names)
top_values, top_indices = torch.topk(probabilities[0], k=5)
for idx, value in zip(top_indices, top_values):
    emotion = emotions[idx.item()]
    confidence = value.item() * 100
    print(f"{emotion}: {confidence:.2f}%")
```

The trained model achieves the following metrics on the development set:
- F1 Score: 0.4576 (macro-average)
- Precision: 0.6113 (macro-average)
- Recall: 0.3933 (macro-average)
- AUC: 0.6908 (macro-average)
- MCC: 0.5676 (Matthews Correlation Coefficient)
- Batch Size: 32
- Learning Rate: 7e-5 (AdamW optimizer)
- Max Sequence Length: 128 tokens
- Epochs: 5
- Loss Function: BCEWithLogitsLoss
- Scheduler: CosineAnnealingLR
- Prediction Threshold: 0.51
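The training configuration above can be sketched as follows. A stand-in `nn.Linear` model and random data are used so the snippet runs end to end; the project trains its DistilBERT+CNN model on the real DataLoaders instead.

```python
import torch
import torch.nn as nn

# Stand-in model and data; hyperparameters match the table above.
model = nn.Linear(10, 28)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=7e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)

for epoch in range(5):
    inputs = torch.randn(32, 10)                      # batch size 32
    targets = torch.randint(0, 2, (32, 28)).float()   # multi-hot labels
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    scheduler.step()  # anneal the learning rate once per epoch
```

`BCEWithLogitsLoss` applies the sigmoid internally, which is why the model outputs raw logits and sigmoid is only applied explicitly at inference time.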
```
├── GoEmotionsDistilBertCNN.pth    # Trained model weights
├── distil_bert_tokenizer_dir/     # Saved tokenizer files
│   ├── tokenizer_config.json
│   ├── vocab.txt
│   └── ...
├── main.py                        # Main training and evaluation script
└── README.md                      # This file
```
- Converts comma-separated emotion labels to multi-hot encoded vectors
- Tokenizes text using DistilBERT tokenizer with padding and truncation
- Creates PyTorch DataLoaders for efficient batch processing
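The label conversion in the first step could look roughly like the hypothetical helper below, assuming the raw labels are comma-separated emotion indices (the helper name and exact input format are assumptions, not taken from `main.py`):

```python
import numpy as np

def to_multi_hot(label_str, num_labels=28):
    """Hypothetical helper: turn a comma-separated index string such as
    "2,17" into a 28-dimensional multi-hot vector."""
    vec = np.zeros(num_labels, dtype=np.float32)
    for idx in label_str.split(","):
        vec[int(idx)] = 1.0
    return vec

to_multi_hot("2,17")  # positions 2 and 17 are set to 1.0
```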
- Inherits from `nn.Module`
- Combines DistilBERT embeddings with a 1D CNN for feature extraction
- Uses global max pooling and dropout for regularization
- Comprehensive evaluation including F1, Precision, Recall, AUC, and MCC
- Multi-label confusion matrices for each emotion category
- Visualization of model performance across all emotion classes
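In the multi-label setting, these metrics are computed after thresholding the sigmoid probabilities (the project uses a 0.51 threshold). A toy illustration with scikit-learn, using made-up 3-label data rather than the real 28-label predictions:

```python
import numpy as np
from sklearn.metrics import f1_score

# Toy ground truth and predicted probabilities, shape (n_samples, n_labels);
# in the real pipeline the shapes are (n_samples, 28).
y_true = np.array([[1, 0, 1], [0, 1, 0]])
probs  = np.array([[0.9, 0.2, 0.6], [0.3, 0.7, 0.6]])
y_pred = (probs >= 0.51).astype(int)  # apply the prediction threshold

print(f1_score(y_true, y_pred, average="macro"))  # macro F1 ≈ 0.889 on this toy data
```

Macro averaging computes each metric per emotion label and then averages, so rare emotions count as much as frequent ones.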
- Experiment with different CNN architectures and filter sizes
- Implement class weighting to handle label imbalance
- Add attention mechanisms between BERT and CNN layers
- Explore ensemble methods with multiple model architectures
- Fine-tune the prediction threshold for better precision-recall balance
If you use this code or the GoEmotions dataset, please cite:
```bibtex
@inproceedings{demszky2020goemotions,
  title={GoEmotions: A Dataset of Fine-Grained Emotions},
  author={Demszky, Dorottya and Movshovitz-Attias, Dana and Ko, Jeongwoo and Cowen, Alan and Nemade, Gaurav and Ravi, Sujith},
  booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
  year={2020}
}
```