In [4]:
!pip install datasets
from datasets import load_dataset

Collecting datasets
  Downloading datasets-2.2.1-py3-none-any.whl (342 kB)
[K     |████████████████████████████████| 342 kB 3.9 MB/s eta 0:00:01
[?25hCollecting multiprocess
  Downloading multiprocess-0.70.12.2-py38-none-any.whl (128 kB)
[K     |████████████████████████████████| 128 kB 14.1 MB/s eta 0:00:01
[?25hCollecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.3.0-py3-none-any.whl (136 kB)
[K     |████████████████████████████████| 136 kB 16.5 MB/s eta 0:00:01
[?25hCollecting aiohttp
  Downloading aiohttp-3.8.1-cp38-cp38-macosx_11_0_arm64.whl (551 kB)
[K     |████████████████████████████████| 551 kB 19.3 MB/s eta 0:00:01
[?25hCollecting pyarrow>=6.0.0
  Downloading pyarrow-8.0.0-cp38-cp38-macosx_11_0_arm64.whl (16.2 MB)
[K     |████████████████████████████████| 16.2 MB 15.9 MB/s eta 0:00:01
Collecting xxhash
  Downloading xxhash-3.0.0-cp38-cp38-macosx_11_0_arm64.whl (30 kB)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting 

In [5]:
import torch
from torch import nn
from torch.nn import functional as F

In [6]:
dataset = load_dataset("emotion")

Downloading builder script:   0%|          | 0.00/1.66k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

Using custom data configuration default


Downloading and preparing dataset emotion/default (download: 1.97 MiB, generated: 2.07 MiB, post-processed: Unknown size, total: 4.05 MiB) to /Users/daohuei/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705...


Downloading data:   0%|          | 0.00/1.66M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/204k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/207k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/16000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Dataset emotion downloaded and prepared to /Users/daohuei/.cache/huggingface/datasets/emotion/default/0.0.0/348f63ca8e27b3713b6c04d723efe6d824a56fb3d1449794716c0f0296072705. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

distil_bert = AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased")
distil_bert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

In [8]:
class_mapping = ["sad", "joy", "love", "anger", "fear", "surprise"]

In [9]:
class_idxs = [
    distil_bert_tokenizer(class_adj)["input_ids"][1] for class_adj in class_mapping
]

In [10]:
class_idxs

[6517, 6569, 2293, 4963, 3571, 4474]

In [18]:
template = ". emotion is [MASK]."
# template = " It has the emotion of [MASK]."
# template = " It feels [MASK]."
# template = " It is [MASK]."
# template = " It has the feeling of [MASK]."

sample_idx = 18
sample_text = dataset["train"]["text"][sample_idx]
sample_label = dataset["train"]["label"][sample_idx]

prompt_input = sample_text + template
tokenized_text = distil_bert_tokenizer(
    prompt_input, truncation=True, padding=True, return_tensors="pt"
)

prompt_input, class_mapping[sample_label]

('i started feeling sentimental about dolls i had as a child and so began a collection of vintage barbie dolls from the sixties. emotion is [MASK].',
 'sad')

In [26]:
output = distil_bert(**tokenized_text, output_hidden_states=True)
output_tokens = output.logits[:, 1:-1, :].argmax(-1).squeeze(0)
output_sent_logit = output.logits[:, -3, :].squeeze(0)
output_pred_token = output_sent_logit.argmax(-1)
output_sent_token = output_sent_logit[class_idxs].argmax(-1)
output_pred_word = distil_bert_tokenizer.decode(output_pred_token)
distil_bert_tokenizer.decode(output_tokens)

'i started feeling sentimental about dolls i had as a child and so began a collection of vintage barbie dolls from the sixties. emotion is overwhelming.'

In [27]:
class_mapping[output_sent_token], class_mapping[sample_label]

('love', 'sad')

In [22]:
def get_encoding_from_bert(word):
    tokenized_text = distil_bert_tokenizer(
        word, truncation=True, padding=True, return_tensors="pt"
    )
    output = distil_bert(**tokenized_text, output_hidden_states=True)
    cls_output = output.hidden_states[-1][0, 0, :]
    
    return cls_output

In [23]:
encoding = get_encoding_from_bert(output_pred_word)
emotion_encodings = []
for emotion in class_mapping:
    emotion_encoding = get_encoding_from_bert(emotion)
    emotion_encodings.append(emotion_encoding)
emotion_encodings = torch.stack(emotion_encodings)
sent_idx = F.cosine_similarity(encoding, emotion_encodings, dim=1).argmax(-1)
class_mapping[sent_idx], class_mapping[sample_label]

('fear', 'sad')