In [None]:
# Install prerequesite libraries
!pip install torch torchvision transformers datasets

In [73]:
# Load dataset
from datasets import load_dataset

# Load the GoEmotions dataset
dataset = load_dataset("go_emotions", "simplified")

# Access the train, validation, and test splits
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]

train_dataset[0]



  0%|          | 0/3 [00:00<?, ?it/s]

{'text': "My favourite food is anything I didn't have to cook myself.",
 'labels': [27],
 'id': 'eebbqej'}

In [74]:
from transformers import AutoTokenizer

# Replace 'bert-base-uncased' with the pre-trained model of your choice
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Max sequence length in dataset = 30
def preprocess_dataset(example):
    # Tokenize the input text and return the encoded result
    encoding = tokenizer(example["text"], truncation=True, padding="max_length", max_length=30)
    return encoding

from torch.utils.data import DataLoader

# Preprocess the dataset
train_dataset = train_dataset.map(preprocess_dataset, batched=True)
val_dataset = val_dataset.map(preprocess_dataset, batched=True)
test_dataset = test_dataset.map(preprocess_dataset, batched=True)

# Set dataset format to PyTorch tensors
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)




In [43]:
import torch
from transformers import AutoModelForSequenceClassification

# Replace 'bert-base-uncased' with the pre-trained model of your choice
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=14)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [75]:
label_map = {
    18: "love",
    11: "disgust",
    17: "joy",
    25: "sadness",
    13: "excitement",
    14: "fear",
    2: "anger",
    15: "gratitude",
    12: "embarrassment",
    4: "approval",
    10: "disapproval",
    7: "curiosity",
    23: "relief",
}

def filter_selected_labels(example):
    for i, val in enumerate(example["labels"]):
        if val == 1 and i in label_map.keys():
            return True
    return False

train_dataset = train_dataset.filter(filter_selected_labels)
val_dataset = val_dataset.filter(filter_selected_labels)
test_dataset = test_dataset.filter(filter_selected_labels)

def map_selected_labels(example):
    example["labels"] = [i for i, val in enumerate(example["labels"]) if val == 1 and i in label_map.keys()]
    example["labels"] = [label_map[i] for i in example["labels"]]
    return example

train_dataset = train_dataset.map(map_selected_labels)
val_dataset = val_dataset.map(map_selected_labels)
test_dataset = test_dataset.map(map_selected_labels)


Filter:   0%|          | 0/43410 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5426 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5427 [00:00<?, ? examples/s]

In [76]:
label_counts = {label: 0 for label in label_map.values()}

for example in train_dataset:
    for label_idx in example["labels"]:
        label = label_map[int(label_idx)]
        label_counts[label] += 1

print("Label statistics:")
for label, count in label_counts.items():
    print(f"{label}: {count}")


Label statistics:
love: 0
disgust: 0
joy: 0
sadness: 0
excitement: 0
fear: 0
anger: 0
gratitude: 0
embarrassment: 0
approval: 0
disapproval: 0
curiosity: 0
relief: 0


In [77]:
print("Number of examples in filtered train dataset:", len(train_dataset))
print("Number of examples in filtered validation dataset:", len(val_dataset))
print("Number of examples in filtered test dataset:", len(test_dataset))

Number of examples in filtered train dataset: 0
Number of examples in filtered validation dataset: 0
Number of examples in filtered test dataset: 0
