In [1]:
# Notebook metadata: contributing authors and the course/term tag for this work.
__authors__ = "Anton Gochev, Jaro Habr, Yan Jiang, Samuel Kahn"
__version__ = "XCS224u, Stanford, Spring 2021"

### Experimental notebook demonstrating the extraction of static embeddings 

In [2]:
import os
import torch

from transformers import pipeline
import torch.nn.functional as F
from transformers import (
    BertTokenizer, BertModel,
    XLNetTokenizer, XLNetModel,
    RobertaTokenizer, RobertaModel,
    ElectraTokenizer, ElectraModel,    
)

import utils.model_utils as mu

In [3]:
# Sample free-text colour descriptions reused by every model section below.
# The first (longest) entry is the one inspected token-by-token in each section.
test_colours = [
    "brown. not the yellow one or classic brown one, the weirder one",
    "brown. not the yellow one or classic brown one",
    "some other brown. that one",
]

#### BERT model embeddings extraction

In [4]:
# Load the pretrained 'bert-base-cased' tokenizer and model by checkpoint name.
# Both must come from the same checkpoint so token ids match the model's vocabulary.
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_model = BertModel.from_pretrained('bert-base-cased')

In [5]:
# Encode the first description into token ids, including BERT's special tokens.
# NOTE: unlike the other model sections, no batch dimension is added here, so the
# displayed tensor is 1-D.
input_ids = torch.tensor(bert_tokenizer.encode(test_colours[0], add_special_tokens=True))
input_ids

tensor([ 101, 3058,  119, 1136, 1103, 3431, 1141, 1137, 5263, 3058, 1141,  117,
        1103, 6994, 1200, 1141,  102])

In [6]:
# Map the ids back to token strings to inspect the sub-word segmentation
# (e.g. 'weirder' -> 'weird', '##er') and the surrounding [CLS] / [SEP] markers.
btokens = bert_tokenizer.convert_ids_to_tokens(input_ids)
btokens

['[CLS]',
 'brown',
 '.',
 'not',
 'the',
 'yellow',
 'one',
 'or',
 'classic',
 'brown',
 'one',
 ',',
 'the',
 'weird',
 '##er',
 'one',
 '[SEP]']

Quick test

In [7]:
# Smoke test: run the project helper over all sample descriptions with BERT.
# The result is not displayed; the cell only checks that the call completes.
# NOTE(review): presumably returns the model's input (static) embeddings per
# description -- confirm against utils.model_utils.
e = mu.extract_input_embeddings(test_colours, bert_model, bert_tokenizer)

In [8]:
# Smoke test of the second project helper with BERT; the result is not inspected here.
# NOTE(review): presumably returns position-dependent embeddings -- confirm the exact
# semantics and return type in utils.model_utils.
ce = mu.extract_positional_embeddings(test_colours, bert_model, bert_tokenizer)

#### XLNet model embeddings extraction

In [9]:
# Load the pretrained 'xlnet-base-cased' tokenizer and model by checkpoint name.
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetModel.from_pretrained('xlnet-base-cased')

Quick test

In [10]:
# Encode the first description with XLNet's special tokens and add a leading batch
# dimension; the size shown below is (1, sequence_length).
# Rebinds the notebook-level `input_ids` used in the BERT section.
input_ids = torch.tensor(xlnet_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids.size()

torch.Size([1, 17])

In [11]:
# Inspect XLNet's tokens for the single batch row: word-initial pieces carry a
# leading '▁' marker and the special tokens <sep>, <cls> come at the END of the
# sequence (BERT places [CLS] first).
xtest = xlnet_tokenizer.convert_ids_to_tokens(input_ids[0])
xtest

['▁brown',
 '.',
 '▁not',
 '▁the',
 '▁yellow',
 '▁one',
 '▁or',
 '▁classic',
 '▁brown',
 '▁one',
 ',',
 '▁the',
 '▁weird',
 'er',
 '▁one',
 '<sep>',
 '<cls>']

In [12]:
# Smoke test: same helper call as in the BERT section, now with XLNet.
# Overwrites the previous `e`; the result is not inspected here.
# NOTE(review): return type/shape not visible in this notebook -- see utils.model_utils.
e = mu.extract_input_embeddings(test_colours, xlnet_model, xlnet_tokenizer)

In [13]:
# Smoke test of the positional-embedding helper with XLNet; overwrites the previous `ce`.
# NOTE(review): confirm in utils.model_utils that the helper supports XLNet's
# end-of-sequence special tokens.
ce = mu.extract_positional_embeddings(test_colours, xlnet_model, xlnet_tokenizer)

#### RoBERTa model embeddings extraction

In [14]:
# Load the pretrained 'roberta-base' tokenizer and model by checkpoint name.
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaModel.from_pretrained('roberta-base')

In [15]:
# Encode the first description with RoBERTa's special tokens and add a leading
# batch dimension. Rebinds the notebook-level `input_ids`.
input_ids = torch.tensor(roberta_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)

In [16]:
# Inspect RoBERTa's tokens for the single batch row: pieces that follow a space
# carry a leading 'Ġ' marker, the sequence is wrapped in <s> ... </s>, and
# 'weirder' is split into three pieces ('Ġwe', 'ir', 'der').
rtest = roberta_tokenizer.convert_ids_to_tokens(input_ids[0])
rtest

['<s>',
 'brown',
 '.',
 'Ġnot',
 'Ġthe',
 'Ġyellow',
 'Ġone',
 'Ġor',
 'Ġclassic',
 'Ġbrown',
 'Ġone',
 ',',
 'Ġthe',
 'Ġwe',
 'ir',
 'der',
 'Ġone',
 '</s>']

Quick test

In [17]:
# Smoke test: same helper call as in the earlier sections, now with RoBERTa.
# Overwrites the previous `e`; the result is not inspected here.
# NOTE(review): return type/shape not visible in this notebook -- see utils.model_utils.
e = mu.extract_input_embeddings(test_colours, roberta_model, roberta_tokenizer)

In [18]:
# Smoke test of the positional-embedding helper with RoBERTa; overwrites the previous `ce`.
# NOTE(review): exact semantics of the returned value are defined in utils.model_utils.
ce = mu.extract_positional_embeddings(test_colours, roberta_model, roberta_tokenizer)

#### ELECTRA model embeddings extraction

In [19]:
# Load the pretrained 'google/electra-small-discriminator' tokenizer and model
# by checkpoint name.
electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
electra_model = ElectraModel.from_pretrained('google/electra-small-discriminator')

In [20]:
# Encode the first description with ELECTRA's special tokens and add a leading
# batch dimension (the displayed tensor is 2-D with a single row).
# Rebinds the notebook-level `input_ids`.
input_ids = torch.tensor(electra_tokenizer.encode(test_colours[0], add_special_tokens=True)).unsqueeze(0)
input_ids

tensor([[ 101, 2829, 1012, 2025, 1996, 3756, 2028, 2030, 4438, 2829, 2028, 1010,
         1996, 6881, 2121, 2028,  102]])

In [21]:
# Map the ELECTRA ids of the single batch row back to token strings for inspection.
# Stored under its own name (`etest`) so the RoBERTa tokens held in `rtest` are not
# silently overwritten by a different model's output.
etest = electra_tokenizer.convert_ids_to_tokens(input_ids[0])
etest

['[CLS]',
 'brown',
 '.',
 'not',
 'the',
 'yellow',
 'one',
 'or',
 'classic',
 'brown',
 'one',
 ',',
 'the',
 'weird',
 '##er',
 'one',
 '[SEP]']

Quick test

In [22]:
# Smoke test: same helper call as in the earlier sections, now with ELECTRA.
# Overwrites the previous `e`; the result is not inspected here.
# NOTE(review): return type/shape not visible in this notebook -- see utils.model_utils.
e = mu.extract_input_embeddings(test_colours, electra_model, electra_tokenizer)

In [23]:
# Smoke test of the positional-embedding helper with ELECTRA; overwrites the previous `ce`.
# NOTE(review): exact semantics of the returned value are defined in utils.model_utils.
ce = mu.extract_positional_embeddings(test_colours, electra_model, electra_tokenizer)

#### Generate token sequences from raw colour descriptions converted into model-based tokens

In [24]:
# Tokenize the first colour description through the project helper, using the
# ELECTRA tokenizer, and display the resulting token list.
# NOTE(review): the stored output uses <s> / </s> boundary markers, contains no
# punctuation tokens and no '##' prefix, unlike the raw ELECTRA tokens shown
# earlier -- presumably the helper normalizes tokens; confirm in utils.model_utils.
mu.tokenize_colour_description(test_colours[0], tokenizer=electra_tokenizer)

['<s>',
 'brown',
 'not',
 'the',
 'yellow',
 'one',
 'or',
 'classic',
 'brown',
 'one',
 'the',
 'weird',
 'er',
 'one',
 '</s>']