In [1]:
from colors import ColorsCorpusReader
import os
import torch

from transformers import pipeline
import torch.nn.functional as F
from transformers import (
    BertTokenizer, BertModel,
    XLNetTokenizer, XLNetModel,
    RobertaTokenizer, RobertaModel,
    ElectraTokenizer, ElectraModel,    
)

import model_utils as mu

In [2]:
# Location of the filtered colors corpus CSV, relative to the notebook's working directory.
COLORS_SRC_FILENAME = os.path.join("data", "colors", "filteredCorpus.csv")

In [3]:
# Reader over the filtered colors CSV, with color normalization enabled.
# `word_count=None` is used here; an earlier run used 2.
# NOTE(review): presumably None disables any word-count filter — confirm in ColorsCorpusReader.
corpus = ColorsCorpusReader(
    COLORS_SRC_FILENAME,
    normalize_colors=True,
    word_count=None,
)

In [4]:
# Materialize the reader's output; the examples are scanned several times below.
examples = [*corpus.read()]

In [5]:
# Fail fast on an empty corpus: the zip-unpacking of example fields further
# down would otherwise raise an unhelpful "not enough values to unpack".
assert examples, f"no examples read from {COLORS_SRC_FILENAME}"
len(examples)

46994

In [6]:
# One list per `condition` value, preserving the original example order.
close_examples, split_examples, far_examples = (
    [ex for ex in examples if ex.condition == condition]
    for condition in ("close", "split", "far")
)

In [7]:
# Report how many examples fall into each condition, one per line.
print(
    f"close: {len(close_examples)}",
    f"split: {len(split_examples)}",
    f"far: {len(far_examples)}",
    sep="\n",
)

close: 15519
split: 15693
far: 15782


In [8]:
# Transpose (colors, contents) pairs into two parallel tuples.
# NOTE(review): despite the `dev_` prefix this covers ALL examples — no dev split is taken.
dev_rawcols, dev_texts = zip(*((ex.colors, ex.contents) for ex in examples))

In [9]:
# A single hand-picked utterance (in a list) used for tokenizer and extractor sanity checks.
test_colours = [
    'brown. not the yellow one or classic brown one, the weirder one',
]

#### BERT model embeddings extraction

In [10]:
# Tokenizer and model both come from the 'bert-base-cased' checkpoint.
bert_tokenizer, bert_model = (
    BertTokenizer.from_pretrained('bert-base-cased'),
    BertModel.from_pretrained('bert-base-cased'),
)

In [11]:
# Encode the test utterance as a plain string. Passing the one-element list
# relies on version-dependent list handling in `encode`: recent transformers
# releases treat list items as already-split tokens, which yields [UNK].
# unsqueeze(0) adds a batch dimension of size 1.
input_ids = torch.tensor(bert_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids

tensor([[ 101, 3058,  119, 1136, 1103, 3431, 1141, 1137, 5263, 3058, 1141,  117,
         1103, 6994, 1200, 1141,  102]])

In [12]:
# Map the (single-row) ids back to token strings to inspect the word-piece split.
btokens = bert_tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
btokens

['[CLS]',
 'brown',
 '.',
 'not',
 'the',
 'yellow',
 'one',
 'or',
 'classic',
 'brown',
 'one',
 ',',
 'the',
 'weird',
 '##er',
 'one',
 '[SEP]']

Quick test

In [None]:
# Smoke-test both extraction helpers on the one test utterance before the full-corpus runs.
(test_embeddings,
 test_vocab) = mu.extract_input_embeddings(test_colours, bert_model, bert_tokenizer)

(test_contextual_embeddings,
 test_contextual_vocab) = mu.extract_contextual_embeddings(test_colours, bert_model, bert_tokenizer)

Extract input embeddings

In [13]:
%time \
bert_embeddings, bert_vocab = mu.extract_input_embeddings(dev_texts, bert_model, bert_tokenizer)

CPU times: user 10.5 s, sys: 55.6 ms, total: 10.6 s
Wall time: 10.6 s


Extract contextual embeddings (pre-trained embedding + position)

In [14]:
%time \
bert_contextual_embeddings, bert_contextual_vocab = mu.extract_contextual_embeddings(dev_texts, bert_model, bert_tokenizer)

CPU times: user 41min 4s, sys: 1min 47s, total: 42min 51s
Wall time: 42min 30s


#### XLNet model embeddings extraction

In [15]:
# Tokenizer and model both come from the 'xlnet-base-cased' checkpoint.
xlnet_tokenizer, xlnet_model = (
    XLNetTokenizer.from_pretrained('xlnet-base-cased'),
    XLNetModel.from_pretrained('xlnet-base-cased'),
)

Quick test

In [16]:
# Encode the test utterance as a plain string. Passing the one-element list
# relies on version-dependent list handling in `encode`: recent transformers
# releases treat list items as already-split tokens.
# unsqueeze(0) adds a batch dimension of size 1.
input_ids = torch.tensor(xlnet_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids

tensor([[3442,    9,   50,   18, 3493,   65,   49, 3523, 3442,   65,   19,   18,
         8189,  118,   65,    4,    3]])

In [17]:
# Map the ids back to token strings. Unlike BERT, XLNet places its special
# tokens (<sep>, <cls>) at the end of the sequence.
xtest = xlnet_tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
xtest

['▁brown',
 '.',
 '▁not',
 '▁the',
 '▁yellow',
 '▁one',
 '▁or',
 '▁classic',
 '▁brown',
 '▁one',
 ',',
 '▁the',
 '▁weird',
 'er',
 '▁one',
 '<sep>',
 '<cls>']

In [None]:
# Smoke-test both extraction helpers on the one test utterance before the full-corpus runs.
(test_embeddings,
 test_vocab) = mu.extract_input_embeddings(test_colours, xlnet_model, xlnet_tokenizer)

(test_contextual_embeddings,
 test_contextual_vocab) = mu.extract_contextual_embeddings(test_colours, xlnet_model, xlnet_tokenizer)

Extract input embeddings

In [18]:
%time \
xlnet_embeddings, xlnet_vocab = mu.extract_input_embeddings(dev_texts, xlnet_model, xlnet_tokenizer)

CPU times: user 9.78 s, sys: 82.6 ms, total: 9.87 s
Wall time: 9.88 s


Extract contextual embeddings (pre-trained embedding + position)

In [19]:
%time \
xlnet_contextual_embeddings, xlnet_contextual_vocab = mu.extract_contextual_embeddings(dev_texts, xlnet_model, xlnet_tokenizer)

CPU times: user 48min 5s, sys: 2min, total: 50min 5s
Wall time: 49min 44s


#### RoBERTa model embeddings extraction

In [20]:
# Tokenizer and model both come from the 'roberta-base' checkpoint.
roberta_tokenizer, roberta_model = (
    RobertaTokenizer.from_pretrained('roberta-base'),
    RobertaModel.from_pretrained('roberta-base'),
)

In [21]:
# Encode the test utterance as a plain string. Passing the one-element list
# relies on version-dependent list handling in `encode`: recent transformers
# releases treat list items as already-split tokens.
# unsqueeze(0) adds a batch dimension of size 1.
input_ids = torch.tensor(roberta_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids

tensor([[    0, 31876,     4,    45,     5,  5718,    65,    50,  4187,  6219,
            65,     6,     5,    52,   853,  3624,    65,     2]])

In [22]:
# Map the ids back to token strings; in RoBERTa's output, tokens that follow
# a space carry a 'Ġ' prefix.
rtest = roberta_tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
rtest

['<s>',
 'brown',
 '.',
 'Ġnot',
 'Ġthe',
 'Ġyellow',
 'Ġone',
 'Ġor',
 'Ġclassic',
 'Ġbrown',
 'Ġone',
 ',',
 'Ġthe',
 'Ġwe',
 'ir',
 'der',
 'Ġone',
 '</s>']

Quick test

In [None]:
# Smoke-test both extraction helpers on the one test utterance before the full-corpus runs.
(test_embeddings,
 test_vocab) = mu.extract_input_embeddings(test_colours, roberta_model, roberta_tokenizer)

(test_contextual_embeddings,
 test_contextual_vocab) = mu.extract_contextual_embeddings(test_colours, roberta_model, roberta_tokenizer)

Extract input embeddings

In [23]:
%time \
roberta_embeddings, roberta_vocab = mu.extract_input_embeddings(dev_texts, roberta_model, roberta_tokenizer)

CPU times: user 10.7 s, sys: 207 ms, total: 10.9 s
Wall time: 11 s


Extract contextual embeddings (pre-trained embedding + position)

In [24]:
%time \
xlnet_contextual_embeddings, xlnet_contextual_vocab = mu.extract_contextual_embeddings(dev_texts, roberta_model, roberta_tokenizer)

CPU times: user 45min 42s, sys: 2min 25s, total: 48min 7s
Wall time: 47min 32s


#### ELECTRA model embeddings extraction

In [25]:
# Tokenizer and model both come from the 'google/electra-small-discriminator' checkpoint.
electra_tokenizer, electra_model = (
    ElectraTokenizer.from_pretrained('google/electra-small-discriminator'),
    ElectraModel.from_pretrained('google/electra-small-discriminator'),
)

In [26]:
# Encode the test utterance as a plain string. Passing the one-element list
# relies on version-dependent list handling in `encode`: recent transformers
# releases treat list items as already-split tokens, which yields [UNK].
# unsqueeze(0) adds a batch dimension of size 1.
input_ids = torch.tensor(electra_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids

tensor([[ 101, 2829, 1012, 2025, 1996, 3756, 2028, 2030, 4438, 2829, 2028, 1010,
         1996, 6881, 2121, 2028,  102]])

In [27]:
# Map the ids back to token strings. Named `etest` so the ELECTRA tokens do
# not overwrite the RoBERTa tokens held in `rtest`.
etest = electra_tokenizer.convert_ids_to_tokens(input_ids[0])
etest

['[CLS]',
 'brown',
 '.',
 'not',
 'the',
 'yellow',
 'one',
 'or',
 'classic',
 'brown',
 'one',
 ',',
 'the',
 'weird',
 '##er',
 'one',
 '[SEP]']

Quick test

In [None]:
# Smoke-test both extraction helpers on the one test utterance before the full-corpus runs.
(test_embeddings,
 test_vocab) = mu.extract_input_embeddings(test_colours, electra_model, electra_tokenizer)

(test_contextual_embeddings,
 test_contextual_vocab) = mu.extract_contextual_embeddings(test_colours, electra_model, electra_tokenizer)

Extract input embeddings

In [28]:
%time \
roberta_embeddings, roberta_vocab = mu.extract_input_embeddings(dev_texts, electra_model, electra_tokenizer)

CPU times: user 12.7 s, sys: 97.3 ms, total: 12.8 s
Wall time: 12.9 s


Extract contextual embeddings (pre-trained embedding + position)

In [29]:
%time \
xlnet_contextual_embeddings, xlnet_contextual_vocab = mu.extract_contextual_embeddings(dev_texts, electra_model, electra_tokenizer)

CPU times: user 12min 43s, sys: 1min 1s, total: 13min 45s
Wall time: 13min 7s
