In [1]:
from datasets import load_dataset

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

In [3]:
# Emotion classification corpus (train / validation / test splits).
# NOTE(review): newer Hub releases host this dataset as "dair-ai/emotion" — confirm if loading fails.
dataset = load_dataset(path="emotion")

Using custom data configuration default
Reusing dataset emotion (/Users/daohuei/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
dataset  # show split names, row counts and feature columns

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})

In [4]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

# One checkpoint name shared by model and tokenizer so the two cannot drift apart.
MODEL_CHECKPOINT = "distilbert-base-uncased"

distil_bert = AutoModelForMaskedLM.from_pretrained(MODEL_CHECKPOINT)
distil_bert_tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

In [5]:
# Label words, indexed by the dataset's integer label. The same words are the
# candidate fillers for the [MASK] slot in the zero-shot prompt.
class_mapping = [
    "sad",       # 0
    "joy",       # 1
    "love",      # 2
    "anger",     # 3
    "fear",      # 4
    "surprise",  # 5
]

In [9]:
# Vocabulary id of each class word, in class_mapping order. Scoring later reads
# one logit per class at the [MASK] position, so every class word must be a
# single wordpiece; the previous `["input_ids"][1]` lookup would silently keep
# only the first piece of a multi-piece word.
class_idxs = []
for class_adj in class_mapping:
    piece_ids = distil_bert_tokenizer(class_adj, add_special_tokens=False)["input_ids"]
    assert len(piece_ids) == 1, (
        f"{class_adj!r} tokenizes into {len(piece_ids)} wordpieces: {piece_ids}"
    )
    class_idxs.append(piece_ids[0])

In [10]:
class_idxs  # vocabulary ids of the class words, one per entry in class_mapping

[6517, 6569, 2293, 4963, 3571, 4474]

In [18]:
# Prompt templates appended to the input text; the masked LM fills the [MASK]
# slot. Only the first is used below — the rest are alternatives that were tried.
candidate_templates = [
    ". emotion is [MASK].",
    " It has the emotion of [MASK].",
    " It feels [MASK].",
    " It is [MASK].",
    " It has the feeling of [MASK].",
]
template = candidate_templates[0]

sample_idx = 18  # arbitrary example from the training split
# Index the row first: `dataset["train"]["text"]` would materialise the whole column.
sample = dataset["train"][sample_idx]
sample_text = sample["text"]
sample_label = sample["label"]

prompt_input = sample_text + template
tokenized_text = distil_bert_tokenizer(
    prompt_input, truncation=True, padding=True, return_tensors="pt"
)
# Truncation could cut the template off a very long input; fail loudly if the mask was lost.
assert (tokenized_text["input_ids"] == distil_bert_tokenizer.mask_token_id).sum().item() == 1, (
    "prompt must contain exactly one [MASK] token after tokenization"
)

prompt_input, class_mapping[sample_label]

('i started feeling sentimental about dolls i had as a child and so began a collection of vintage barbie dolls from the sixties. emotion is [MASK].',
 'sad')

In [26]:
# Locate the [MASK] slot from the token ids instead of hard-coding index -3,
# which only holds while the template ends in "[MASK]." and nothing is truncated.
mask_positions = (
    tokenized_text["input_ids"][0] == distil_bert_tokenizer.mask_token_id
).nonzero(as_tuple=True)[0]
assert mask_positions.numel() == 1, (
    f"expected exactly one [MASK] token, found {mask_positions.numel()}"
)
mask_position = mask_positions.item()

# Inference only: no_grad avoids building an autograd graph over the whole model.
with torch.no_grad():
    output = distil_bert(**tokenized_text)

# Greedy per-position reconstruction of the input, dropping the first/last special tokens.
output_tokens = output.logits[:, 1:-1, :].argmax(-1).squeeze(0)
# Vocabulary logits at the [MASK] position of the single batch item.
output_sent_logit = output.logits[0, mask_position, :]
# Unconstrained best filler vs. best filler restricted to the class words.
output_pred_token = output_sent_logit.argmax(-1)
output_sent_token = output_sent_logit[class_idxs].argmax(-1)
output_pred_word = distil_bert_tokenizer.decode(output_pred_token)
distil_bert_tokenizer.decode(output_tokens)

'i started feeling sentimental about dolls i had as a child and so began a collection of vintage barbie dolls from the sixties. emotion is overwhelming.'

In [27]:
# Zero-shot prediction vs. gold label for the sample, as (predicted, actual).
(class_mapping[int(output_sent_token)], class_mapping[sample_label])

('love', 'sad')

In [22]:
def get_encoding_from_bert(word, model=None, tokenizer=None):
    """Return the last-layer hidden state at position 0 for `word`.

    Position 0 is the special token the tokenizer prepends, so the result is a
    single vector for the whole input string.

    Args:
        word: text to encode (a single word in this notebook, but any string works).
        model: masked-LM model called with ``output_hidden_states=True``;
            defaults to the notebook's ``distil_bert``.
        tokenizer: tokenizer matching ``model``; defaults to ``distil_bert_tokenizer``.

    Returns:
        1-D tensor of length hidden_size, computed without autograd tracking.
    """
    if model is None:
        model = distil_bert
    if tokenizer is None:
        tokenizer = distil_bert_tokenizer

    tokenized = tokenizer(word, truncation=True, padding=True, return_tensors="pt")
    # Pure inference: don't keep the model's autograd graph alive via the returned tensor.
    with torch.no_grad():
        output = model(**tokenized, output_hidden_states=True)
    # hidden_states[-1] is the final layer; [0, 0] = first batch item, first position.
    return output.hidden_states[-1][0, 0, :]

In [23]:
# Embed the word the model predicted at [MASK] and every class word, then pick
# the class whose embedding is most cosine-similar to the prediction.
encoding = get_encoding_from_bert(output_pred_word)
emotion_encodings = torch.stack(
    [get_encoding_from_bert(emotion) for emotion in class_mapping]
)
# (1, hidden) against (num_classes, hidden) -> one similarity per class.
sent_idx = F.cosine_similarity(encoding.unsqueeze(0), emotion_encodings, dim=1).argmax(-1)
class_mapping[sent_idx], class_mapping[sample_label]

('fear', 'sad')