In [61]:
from transformers import AutoTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import torch.nn as nn
import os
from dotenv import load_dotenv
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from huggingface_hub import HfApi

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)
# Load KEY=VALUE pairs from a local .env file into the environment.
# NOTE(review): presumably credentials for the Hugging Face Hub (HfApi is imported) — confirm.
load_dotenv()


cuda


True

In [100]:
# go_emotions ships predefined splits; only train and test are used below.
train_dataset = load_dataset('go_emotions', split='train')
test_dataset = load_dataset('go_emotions', split='test')

# This cell continues past this line, so a bare `train_dataset[0]` would never be displayed.
# Print the first example explicitly so its 'text' / 'labels' fields can be inspected.
print(train_dataset[0])

class CustomDataset(Dataset):
    """Map-style dataset over three index-aligned sequences.

    Each item is a dict with the keys the Hugging Face Trainer feeds to the model:
    ``input_ids``, ``attention_mask`` and ``labels``.

    Args:
        input_ids: indexable sequence of token-id rows; its length defines ``len(self)``.
        attention_masks: indexable sequence aligned with ``input_ids``.
        labels: indexable sequence aligned with ``input_ids``.
    """

    def __init__(self, input_ids, attention_masks, labels):
        self.input_ids, self.attention_masks, self.labels = input_ids, attention_masks, labels

    def __len__(self):
        # The three sequences are assumed to be the same length; input_ids is authoritative.
        return len(self.input_ids)

    def __getitem__(self, idx):
        sample = dict(
            input_ids=self.input_ids[idx],
            attention_mask=self.attention_masks[idx],
            labels=self.labels[idx],
        )
        return sample

In [63]:
# Tokenizer for the same checkpoint the classification model is loaded from further down.
tokenizer = AutoTokenizer.from_pretrained(
    'distilbert-base-uncased',
)

def convert_to_multihot(label_list, num_labels):
    """Turn a list of class indices into a multi-hot float vector.

    Args:
        label_list: iterable of integer class indices that are present.
        num_labels: total number of classes (length of the returned vector).

    Returns:
        A 1-D ``torch.zeros``-dtype tensor of length ``num_labels`` with 1 at every
        index listed in ``label_list`` and 0 elsewhere.
    """
    present = torch.as_tensor(list(label_list), dtype=torch.long)
    vector = torch.zeros(num_labels)
    # Advanced indexing sets every listed position at once; an empty index is a no-op.
    vector[present] = 1
    return vector

def tokenize_and_encode_dataset(dataset, num_labels=28, text_tokenizer=None, max_length=512, batch_size=256):
    """Tokenize a split and pack it into a ``CustomDataset`` of fixed-width tensors.

    Args:
        dataset: sized, indexable collection whose items are mappings with a
            ``'text'`` string and a ``'labels'`` list of class indices (the format
            of the go_emotions splits loaded in this notebook).
        num_labels: length of each multi-hot label vector.
        text_tokenizer: tokenizer to apply; defaults to the module-level ``tokenizer``.
        max_length: every sequence is truncated and padded to exactly this many tokens.
        batch_size: number of texts passed to the tokenizer per call. This only affects
            speed. Padding to a fixed ``max_length`` makes each row independent of the
            other texts in its batch, so the result matches per-text tokenization.

    Returns:
        CustomDataset whose ``input_ids`` and ``attention_masks`` have shape
        ``(len(dataset), max_length)`` and whose ``labels`` have shape
        ``(len(dataset), num_labels)``.

    Raises:
        ValueError: if ``dataset`` is empty (there would be nothing to stack).
    """
    tok = tokenizer if text_tokenizer is None else text_tokenizer
    n_examples = len(dataset)
    if n_examples == 0:
        raise ValueError('tokenize_and_encode_dataset: dataset is empty')

    input_id_chunks = []
    attention_mask_chunks = []
    label_rows = []

    with tqdm(total=n_examples) as progress:
        for start in range(0, n_examples, batch_size):
            # Fetch each row once and read both fields from it.
            rows = [dataset[i] for i in range(start, min(start + batch_size, n_examples))]

            # NOTE(review): go_emotions texts appear short (presumably far below 512
            # tokens), so most of each row is padding. Check the token-length
            # distribution; a smaller max_length would cut memory and training time.
            encoded = tok(
                [row['text'] for row in rows],
                padding='max_length',
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            )
            input_id_chunks.append(encoded['input_ids'])
            attention_mask_chunks.append(encoded['attention_mask'])
            label_rows.extend(
                convert_to_multihot(row['labels'], num_labels=num_labels) for row in rows
            )
            progress.update(len(rows))

    return CustomDataset(
        torch.cat(input_id_chunks),
        torch.cat(attention_mask_chunks),
        torch.stack(label_rows),
    )
        
# Encode both splits into CustomDataset objects; they become the Trainer's train/eval sets.
training, testing = (
    tokenize_and_encode_dataset(train_dataset),
    tokenize_and_encode_dataset(test_dataset),
)

100%|██████████| 43410/43410 [00:23<00:00, 1834.33it/s]
100%|██████████| 5427/5427 [00:02<00:00, 1863.25it/s]


In [101]:
# 28 output logits, one per go_emotions class. The multi-label problem type matches the
# float multi-hot label vectors built by convert_to_multihot. Each logit is scored
# independently, and predict_sentiment applies a sigmoid to them.
base_model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    problem_type='multi_label_classification',
    num_labels=28,
)

# nn.Module.to moves parameters in place and returns the module itself,
# so `model` and `base_model` refer to the same object.
model = base_model.to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Compilation is deliberately not requested here. `torch.compile(model)` returns a new
# optimized module rather than modifying `model`, so calling it without using the result
# has no effect on training. To compile, pass `torch_compile=True` to TrainingArguments.

training_args = TrainingArguments(
    output_dir='./results',
    # Evaluate and checkpoint every 50 optimizer steps
    # (one step = 20 samples per device x 4 accumulation steps).
    eval_strategy='steps',
    eval_steps=50,
    save_strategy='steps',
    save_steps=50,
    # Saving every 50 steps would otherwise keep dozens of full checkpoints on disk.
    # With load_best_model_at_end=True the best checkpoint is retained in addition to these.
    save_total_limit=2,
    learning_rate=2e-5,
    per_device_train_batch_size=20,
    per_device_eval_batch_size=20,
    num_train_epochs=3,
    weight_decay=.01,
    logging_dir='./logs',
    logging_steps=100,
    # No compute_metrics / metric_for_best_model is set, so "best" falls back to eval loss.
    load_best_model_at_end=True,
    gradient_accumulation_steps=4,
    # fp16 mixed precision requires CUDA; use fp32 when no GPU is available.
    fp16=torch.cuda.is_available(),
)

# NOTE(review): `testing` is the go_emotions *test* split. load_best_model_at_end picks
# a checkpoint by its loss on this set, so test results are optimistically biased.
# go_emotions also provides a 'validation' split — consider using it as eval_dataset.
# NOTE(review): the recorded trainer.train() log shows evaluation only at epochs 1/2/3,
# not every 50 steps. That output seems to come from a different configuration;
# re-run to refresh it.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training,
    eval_dataset=testing
)

In [107]:
# Fine-tune. The returned TrainOutput (global_step, training_loss, metrics) is the cell's result.
train_result = trainer.train()
train_result

                                      
  0%|          | 0/27 [1:05:21<?, ?it/s]          

{'loss': 0.3677, 'grad_norm': 1.5913376808166504, 'learning_rate': 1.8769987699876998e-05, 'epoch': 0.18}


                                        
  0%|          | 0/27 [1:06:22<?, ?it/s]          

{'loss': 0.1615, 'grad_norm': 0.7621739506721497, 'learning_rate': 1.7539975399753997e-05, 'epoch': 0.37}


                                        
  0%|          | 0/27 [1:07:22<?, ?it/s]          

{'loss': 0.1414, 'grad_norm': 0.5251739621162415, 'learning_rate': 1.6309963099630997e-05, 'epoch': 0.55}


                                        
  0%|          | 0/27 [1:08:22<?, ?it/s]          

{'loss': 0.1262, 'grad_norm': 0.8104351162910461, 'learning_rate': 1.5079950799507997e-05, 'epoch': 0.74}


                                        
  0%|          | 0/27 [1:09:22<?, ?it/s]          

{'loss': 0.1167, 'grad_norm': 0.6981897354125977, 'learning_rate': 1.3849938499384994e-05, 'epoch': 0.92}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                               
  0%|          | 0/27 [1:09:59<?, ?it/s]
[A

{'eval_loss': 0.10558442026376724, 'eval_runtime': 10.9875, 'eval_samples_per_second': 493.926, 'eval_steps_per_second': 24.755, 'epoch': 1.0}


                                        
  0%|          | 0/27 [1:10:34<?, ?it/s]          

{'loss': 0.1085, 'grad_norm': 0.7822399139404297, 'learning_rate': 1.2619926199261994e-05, 'epoch': 1.11}


                                        
  0%|          | 0/27 [1:11:33<?, ?it/s]          

{'loss': 0.1023, 'grad_norm': 0.7207314372062683, 'learning_rate': 1.1389913899138992e-05, 'epoch': 1.29}


                                        
  0%|          | 0/27 [1:12:32<?, ?it/s]          

{'loss': 0.1013, 'grad_norm': 0.7514426708221436, 'learning_rate': 1.0159901599015991e-05, 'epoch': 1.47}


                                        
  0%|          | 0/27 [1:13:31<?, ?it/s]          

{'loss': 0.0987, 'grad_norm': 0.8425426483154297, 'learning_rate': 8.92988929889299e-06, 'epoch': 1.66}


                                        
  0%|          | 0/27 [1:14:30<?, ?it/s]           

{'loss': 0.0952, 'grad_norm': 0.908227264881134, 'learning_rate': 7.699876998769989e-06, 'epoch': 1.84}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                                
  0%|          | 0/27 [1:15:31<?, ?it/s]
[A

{'eval_loss': 0.09182830154895782, 'eval_runtime': 10.8735, 'eval_samples_per_second': 499.106, 'eval_steps_per_second': 25.015, 'epoch': 2.0}


                                        
  0%|          | 0/27 [1:15:41<?, ?it/s]           

{'loss': 0.0945, 'grad_norm': 0.9004964828491211, 'learning_rate': 6.469864698646987e-06, 'epoch': 2.03}


                                        
  0%|          | 0/27 [1:16:40<?, ?it/s]           

{'loss': 0.0922, 'grad_norm': 0.9855219125747681, 'learning_rate': 5.2398523985239855e-06, 'epoch': 2.21}


                                        
  0%|          | 0/27 [1:17:39<?, ?it/s]           

{'loss': 0.0907, 'grad_norm': 1.1675050258636475, 'learning_rate': 4.009840098400984e-06, 'epoch': 2.4}


                                        
  0%|          | 0/27 [1:18:38<?, ?it/s]           

{'loss': 0.0918, 'grad_norm': 0.9246616363525391, 'learning_rate': 2.779827798277983e-06, 'epoch': 2.58}


                                        
  0%|          | 0/27 [1:19:37<?, ?it/s]           

{'loss': 0.0908, 'grad_norm': 0.9724268913269043, 'learning_rate': 1.5498154981549817e-06, 'epoch': 2.76}


                                        
  0%|          | 0/27 [1:20:36<?, ?it/s]           

{'loss': 0.0875, 'grad_norm': 1.0728187561035156, 'learning_rate': 3.198031980319803e-07, 'epoch': 2.95}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

                                                 
[A                                                
  0%|          | 0/27 [1:21:04<?, ?it/s]
[A

{'eval_loss': 0.08953775465488434, 'eval_runtime': 10.9095, 'eval_samples_per_second': 497.458, 'eval_steps_per_second': 24.932, 'epoch': 3.0}


                                        
100%|██████████| 1626/1626 [16:44<00:00,  1.62it/s]

{'train_runtime': 1004.3632, 'train_samples_per_second': 129.664, 'train_steps_per_second': 1.619, 'train_loss': 0.12241348334637368, 'epoch': 3.0}





TrainOutput(global_step=1626, training_loss=0.12241348334637368, metrics={'train_runtime': 1004.3632, 'train_samples_per_second': 129.664, 'train_steps_per_second': 1.619, 'total_flos': 1.723934893473792e+16, 'train_loss': 0.12241348334637368, 'epoch': 2.99677567941041})

In [108]:
SAVE_DIR = './saved_model'

# Trainer.save_model writes the model weights and config. The Trainer was built without
# a tokenizer, so it does not write tokenizer files. Save the tokenizer explicitly so
# SAVE_DIR can be reloaded on its own with from_pretrained.
trainer.save_model(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)

In [109]:
def predict_sentiment(text, model, tokenizer, device, threshold=0.3, max_length=512):
    """Score one text against every class of a multi-label sequence classifier.

    Args:
        text: input string.
        model: sequence-classification model whose output exposes ``.logits`` with
            one column per class.
        tokenizer: tokenizer matching ``model``.
        device: device to run inference on; ``model`` is moved there.
        threshold: a class counts as predicted when its sigmoid probability is
            strictly greater than this value.
        max_length: inputs longer than this many tokens are truncated.

    Returns:
        ``(predicted_indices, probabilities)``:
        - ``predicted_indices``: list of class indices whose probability exceeds ``threshold``.
        - ``probabilities``: 1-D tensor (on ``device``) of per-class sigmoid probabilities
          for the first (only) input.
    """
    model.to(device)
    # Disable dropout for inference. This leaves the model in eval mode afterwards.
    model.eval()

    # Dynamic padding: a single text needs no padding at all. Padding to max_length
    # would make the model process hundreds of masked positions for a short input.
    encoded = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt'
    )
    # DistilBERT's forward() takes no token_type_ids, so drop them if the tokenizer emits them.
    inputs = {k: v.to(device) for k, v in encoded.items() if k != 'token_type_ids'}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label: each class gets an independent sigmoid probability (not a softmax).
    probabilities = torch.sigmoid(logits)[0]
    predicted_indices = torch.nonzero(probabilities > threshold, as_tuple=True)[0].tolist()

    return predicted_indices, probabilities
    

In [129]:
# Index -> name for the 28 go_emotions classes, in label-id order (27 is 'neutral').
class_labels = dict(enumerate([
    'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring',
    'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust',
    'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy',
    'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief',
    'remorse', 'sadness', 'surprise', 'neutral',
]))


text = 'ugh this is so annoying'

labels, probabilities = predict_sentiment(text, model, tokenizer, device)

# Show the three highest-scoring classes, independent of the decision threshold.
# sorted() is stable, so ties keep ascending class-index order.
probabilities = probabilities.tolist()
top_3_prob = [  # class indices (not probabilities), best first
    idx for idx, _ in sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)[:3]
]

for i in top_3_prob:
    label_name = class_labels.get(i, f'Label {i}')
    print(f"- {label_name}: {probabilities[i]:.4f}")

- annoyance: 0.3605
- anger: 0.3100
- disgust: 0.1839
