In [1]:
import torch
from transformers import BertTokenizer, BertModel
from sklearn.linear_model import LogisticRegression
import pandas as pd
import numpy as np
import time
#from keras.preprocessing.sequence import pad_sequences


In [2]:
def tok(tokenizer, text):
    return tokenizer.encode_plus(text, 
                                    add_special_tokens=True,
                                    max_length=510,
                                    padding='longest', 
                                    truncation=True,
                                      return_token_type_ids=True,
                                      return_attention_mask=True,
                                      return_tensors='pt'
                                    )

In [3]:
def set_embed(df, model, tokenizer):
    l = []
    for i, review in enumerate(df.review_body):
        tokened = tok(tokenizer, review)
        #print(model(**tokenizer))
        l.append(embed(model_mul, tokened).numpy())
    return l

In [4]:
from typing import Callable, List, Optional, Tuple
from torch import nn
import pandas as pd
from sklearn.base import TransformerMixin, BaseEstimator
import torch

def embed(model, tokens_tensor ):

    with torch.no_grad():
        tokens_tensor = tokens_tensor.to('cuda')
        model.to('cuda')
        outputs = model(**tokens_tensor)

        # Evaluating the model will return a different number of objects based on 
        # how it's  configured in the `from_pretrained` call earlier. In this case, 
        # becase we set `output_hidden_states = True`, the third item will be the 
        # hidden states from all layers. See the documentation for more details:
        # https://huggingface.co/transformers/model_doc/bert.html#bertmodel
        hidden_states = outputs[2]
        token_embeddings = torch.stack(hidden_states, dim=0)
        token_embeddings = torch.squeeze(token_embeddings, dim=1)
        # Stores the token vectors, with shape [6 x 768]
    
    token_vecs_sum = []

    # `token_embeddings` is a [6 x 12 x 768] tensor.

    # For each token in the sentence...
    for token in token_embeddings:

        # `token` is a [6 x 768] tensor

        # Sum the vectors from the last four layers.
        sum_vec = torch.sum(token[-4:], dim=0)

        # Use `sum_vec` to represent `token`.
        token_vecs_sum.append(sum_vec)
        
    token_vecs = hidden_states[-2][0]

    # Calculate the average of all 6 token vectors.
    sentence_embedding = torch.mean(token_vecs, dim=0)
    
    return sentence_embedding

def embed_cls(model, tokens_tensor):
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        tokens_tensor = tokens_tensor.to('cuda')       
        model.to('cuda')
        
        #if torch.cuda.device_count() > 1:
            #print("Let's use", torch.cuda.device_count(), "GPUs!")
            #model = nn.DataParallel(model)
        #else:
            #model.to(device)
        
        outputs = model(**tokens_tensor)
        
        return outputs.pooler_output
        
class BertTransformer(BaseEstimator, TransformerMixin):
    def __init__(
            self,
            bert_tokenizer,
            bert_model,
            max_length: int = 510,
            embedding_func = None,
    ):
        self.tokenizer = bert_tokenizer
        self.model = bert_model
        self.model.eval()
        self.max_length = max_length
        self.embedding_func = embedding_func

        if self.embedding_func is None:
            self.embedding_func = lambda x: x[0][:, 0, :].squeeze()

    def _tokenize(self, text: str) -> Tuple[torch.tensor, torch.tensor]:
        # Tokenize the text with the provided tokenizer
#         tokenized_text = self.tokenizer.encode_plus(text,
#                                                     add_special_tokens=True,
#                                                     max_length=self.max_length
#                                                     )["input_ids"]
        
        tokenized_text = self.tokenizer.encode_plus(text, 
                                    add_special_tokens=True,
                                    max_length=self.max_length,
                                    padding='longest', 
                                    truncation=True,
                                    return_token_type_ids=True,
                                    return_attention_mask=True,
                                    return_tensors='pt'
                                    )
        return tokenized_text


    def _tokenize_and_predict(self, text: str) -> torch.tensor:
        tokenized = self._tokenize(text)

        #embeddings = self.model(**tokenized)
        return self.embedding_func(self.model, tokenized)

    def transform(self, text: List[str]):
        if isinstance(text, pd.Series):
            text = text.tolist()
        
#         return torch.stack([self._tokenize_and_predict(string) for string in text]).cpu()

        return torch.stack([self._tokenize_and_predict(text)]).cpu()

    def fit(self, X, y=None):
        """No fitting necessary so we just return ourselves"""
        return self

# Data sets

### English datset

In [8]:
eng_train = pd.read_csv('../data/train_en').review_body
eng_test = pd.read_csv('../data/test_en').review_body
eng_val = pd.read_csv('../data/val_en').review_body


### Spanish

In [9]:
es_train = pd.read_csv('../data/train_es').review_body
es_test = pd.read_csv('../data/test_es').review_body
es_val = pd.read_csv('../data/val_es').review_body

# LABSE BERT

In [15]:
tokenizer_mul = BertTokenizer.from_pretrained("sentence-transformers/LaBSE")
model_mul = BertModel.from_pretrained("sentence-transformers/LaBSE",
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )

In [13]:
#bert_transformer = BertTransformer(tokenizer_mul, model_mul, embedding_func=embed)
bert_transformer = BertTransformer(tokenizer_mul, model_mul, embedding_func=embed_cls)
  

# Labse BERT

In [16]:
tokenizer_mul = BertTokenizer.from_pretrained("pvl/labse_bert")
model_mul = BertModel.from_pretrained('pvl/labse_bert',
                                  output_hidden_states = True, # Whether the model returns all hidden-states.
                                  )

Downloading:   0%|          | 0.00/5.22M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/62.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

In [17]:
bert_transformer = BertTransformer(tokenizer_mul, model_mul, embedding_func=embed_cls)


In [18]:
import pickle
import itertools

files_name = ['eng_val_labse.pkl', 'eng_test_labse.pkl', 'eng_train_labse.pkl', 'es_val_labse.pkl', 'es_test_labse.pkl', 'es_train_labse.pkl']
texts = [eng_val, eng_test, eng_train, es_val, es_test, es_train]

for i,t in enumerate(texts):
    emebedded = []
    start = time.time()
    for k,com in enumerate(t):
        emebedded.append(bert_transformer.transform(com))
        print(k/len(t))
    
    with open(files_name[i], 'wb') as f:
        torch.save(list(itertools.chain(*emebedded)), f)    
    end = time.time()
    print(end - start)

0.0
0.0005
0.001
0.0015
0.002
0.0025
0.003
0.0035
0.004
0.0045
0.005
0.0055
0.006
0.0065
0.007
0.0075
0.008
0.0085
0.009
0.0095
0.01
0.0105
0.011
0.0115
0.012
0.0125
0.013
0.0135
0.014
0.0145
0.015
0.0155
0.016
0.0165
0.017
0.0175
0.018
0.0185
0.019
0.0195
0.02
0.0205
0.021
0.0215
0.022
0.0225
0.023
0.0235
0.024
0.0245
0.025
0.0255
0.026
0.0265
0.027
0.0275
0.028
0.0285
0.029
0.0295
0.03
0.0305
0.031
0.0315
0.032
0.0325
0.033
0.0335
0.034
0.0345
0.035
0.0355
0.036
0.0365
0.037
0.0375
0.038
0.0385
0.039
0.0395
0.04
0.0405
0.041
0.0415
0.042
0.0425
0.043
0.0435
0.044
0.0445
0.045
0.0455
0.046
0.0465
0.047
0.0475
0.048
0.0485
0.049
0.0495
0.05
0.0505
0.051
0.0515
0.052
0.0525
0.053
0.0535
0.054
0.0545
0.055
0.0555
0.056
0.0565
0.057
0.0575
0.058
0.0585
0.059
0.0595
0.06
0.0605
0.061
0.0615
0.062
0.0625
0.063
0.0635
0.064
0.0645
0.065
0.0655
0.066
0.0665
0.067
0.0675
0.068
0.0685
0.069
0.0695
0.07
0.0705
0.071
0.0715
0.072
0.0725
0.073
0.0735
0.074
0.0745
0.075
0.0755
0.076
0.0765
0.077
0.

In [15]:
1+1

2