# Contrastive Learning From Scratch - DistilBERT

An attempt to build contrastive learning model from scratch. Parts include:

- Loading and preparing Wiki-1M data for model input
- Contrastive learning model
  - Forward passing using pre-trained model
  - Constrastive layer
  - Calculate loss
- Training procedure
  - Default trainer optimizer
  - Default trainer hyper-parameters

In [1]:
import os

# Set Project home
PROJECT_HOME = os.path.join('/',
                            'Users',
                            'ng-ka',
                            'OMSCS',
                            'DL',
                            'DLProject',
                            'contrastive-learning-in-distilled-models')
%cd {PROJECT_HOME}

# Load project code
%reload_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, './src')

#import distilface
import src.distilface as distilface

C:\Users\ng-ka\OMSCS\DL\DLProject\contrastive-learning-in-distilled-models


In [2]:
os.getcwd()

'C:\\Users\\ng-ka\\OMSCS\\DL\\DLProject\\contrastive-learning-in-distilled-models'

In [28]:
import torch
import torch.nn as nn

from transformers import AutoTokenizer, DistilBertModel, DistilBertPreTrainedModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling

from src.distilface.modules.pooler import Pooler
from src.distilface.modules.similarity import Similarity

class DistilBertCLModel(DistilBertPreTrainedModel):
    def __init__(self, config, pooler_type='avg_last4', temp=0.05):
        super().__init__(config)

        self.config = config
        self.pooler_type = pooler_type
        self.temp = 0.05

        self.distilbert = DistilBertModel(config)
        self.pooler = Pooler(pooler_type)
        self.sim = Similarity(temp=temp)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None):
        if self.training:
            return self.cl_forward(self.distilbert, input_ids, attention_mask)
        else:
            return self.sent_emb(self.distilbert, input_ids, attention_mask)

    def cl_forward(self, encoder, input_ids=None, attention_mask=None):
        batch_size = input_ids.size(0)#64#input_ids.size(0)
        num_sent = input_ids.size(1)  # Number of sentences in one instance: 2 sentences

        # Flatten all input tensors
        input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
        attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)

        # Pre-trained Model Encoder
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Pooling
        pooler_output = self.pooler(attention_mask, outputs)
        pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

        # Separate representation
        z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

        # Cosine similarity
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))

        # Calculate contrastive loss
        criterion = nn.CrossEntropyLoss()
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss = criterion(cos_sim, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def sent_emb(self, encoder, input_ids=None, attention_mask=None):
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = self.pooler(attention_mask, outputs)

        return BaseModelOutputWithPooling(
            pooler_output=pooler_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )


pretrained_model_name = 'distilbert-base-uncased'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(pretrained_model_name)

model = DistilBertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

#model.eval();


loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at C:\Users\ng-ka/.cache\huggingface\transformers\23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.17.0",
  "vocab_size": 30522
}

loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at C:\Users\ng-ka/.cache\huggingface\transfor

In [29]:
model = torch.load('batch128_model.pth')
model.eval()

DistilBertCLModel(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linea

In [30]:
model = model.to('cpu')

In [31]:
quantized_model =torch.quantization.quantize_dynamic(model, qconfig_spec={nn.Linear},  dtype=torch.qint8
)

In [32]:
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

# compare the sizes
f=print_size_of_model(model,"fp32")
q=print_size_of_model(quantized_model,"int8")
print("{0:.2f} times smaller".format(f/q))

model:  fp32  	 Size (KB): 265489.337
model:  int8  	 Size (KB): 138116.329
1.92 times smaller


The following does not succeed due to lack of compatibility of our model with Torch Script

In [27]:
scripted_model = torch.jit.script(quantized_model)

UnsupportedNodeError: GeneratorExp aren't supported:
  File "C:\Users\ng-ka\anaconda3\envs\cl-distilled\lib\site-packages\transformers\modeling_utils.py", line 999
    
        x=False
        for it_element in (hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()):
                          ~ <--- HERE
            if it_element:
                x = True


Attempts with the alternative method of tracing were also not successful

In [36]:
from transformers import BertModel, BertTokenizer, BertConfig

enc = BertTokenizer.from_pretrained("distilbert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]


# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens]).to('cuda')
segments_tensors = torch.tensor([segments_ids]).to('cuda')

traced_model = torch.jit.trace(quantized_model, [tokens_tensor, segments_tensors])

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.


RuntimeError: Tracer cannot infer type of BaseModelOutputWithPooling(last_hidden_state=tensor([[[-2.8095e-01,  1.9375e-01, -2.3601e-01,  ...,  2.4750e-04,
           1.8413e-01,  7.4281e-01],
         [-2.0558e-01,  5.6973e-01, -3.4450e-01,  ...,  3.1061e-01,
           7.4593e-03,  6.6465e-01],
         [-2.6471e-01,  4.7577e-01, -8.7475e-01,  ...,  4.3364e-01,
           1.6303e-02,  1.2238e+00],
         ...,
         [ 2.1979e-01,  2.8544e-02, -4.8508e-01,  ...,  1.5100e-01,
           5.6132e-01,  6.3370e-01],
         [ 8.1528e-02,  5.9031e-01, -1.4732e-01,  ..., -5.9024e-02,
           5.9021e-01,  8.1966e-01],
         [-7.2517e-03,  1.6594e-01, -7.6334e-01,  ...,  5.0187e-01,
           3.8227e-01,  5.1416e-01]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 4.6693e-02,  4.0107e-02, -3.7071e-01, -2.7972e-02, -9.5654e-02,
         -8.7964e-02,  1.8528e-02,  2.7150e-01, -2.4230e-01,  9.8927e-05,
         -3.4853e-01, -5.7866e-01, -1.3977e-01,  3.4567e-01, -3.4365e-01,
         -9.6294e-03, -2.6303e-02, -2.1803e-01, -1.1859e-01,  7.5814e-02,
          2.7154e-01, -2.1654e-01, -3.3980e-01,  3.5713e-01,  7.2184e-02,
          7.7747e-02,  2.0000e-01,  3.3041e-01,  1.5145e-01, -3.1463e-01,
         -8.3670e-03, -6.3714e-02,  4.2273e-01,  2.9969e-01,  5.7845e-02,
          2.3974e-01, -4.7594e-02, -1.6258e-01, -1.6344e-01, -1.2472e-01,
         -1.7768e-01,  5.6963e-03, -2.2802e-01,  6.1100e-04,  4.6322e-01,
         -3.9256e-02,  4.0616e-01, -3.4898e-01,  8.9205e-02, -1.6020e-01,
          9.7896e-02, -1.6947e-01,  2.7750e-02,  3.5056e-01,  2.2463e-01,
         -2.5251e-02, -8.3530e-02,  1.1967e-02,  1.8789e-01,  3.0655e-01,
          6.2255e-02, -3.2581e-02,  2.6660e-02, -9.0701e-02, -3.4137e-01,
         -2.6564e-01, -1.7123e-01,  3.8302e-01, -5.6068e-01,  1.9953e-01,
          1.5175e-01,  2.5360e-01,  1.3732e-01, -1.1475e-01,  5.0170e-02,
          2.9594e-01,  4.7890e-01,  8.9323e-02,  4.1593e-01,  2.3964e-01,
          2.2769e-01,  1.6559e-01,  2.8347e-01,  1.4929e-01,  2.6771e-01,
         -1.7873e-01, -1.1776e-01, -8.1478e-02, -2.2917e-01,  6.6764e-02,
         -1.5156e-01,  1.4597e-01,  1.7807e-01, -7.8407e-02, -1.1961e-02,
         -5.5800e-03,  3.2663e-02, -3.6713e-02, -1.0436e-01, -2.6302e-01,
          2.9339e-01,  2.6338e-01, -8.5827e-03,  2.8918e-01,  2.0783e-01,
          6.6061e-02,  8.6752e-02,  2.6657e-01, -2.1378e-01,  5.5170e-01,
         -3.0502e-01,  1.9279e-02,  2.0699e-01, -4.1495e-01, -1.3450e-03,
         -3.4352e-02, -6.2233e-01,  5.7239e-02, -5.3565e-01, -1.9556e-01,
          2.1748e-01, -2.5492e-02, -1.7693e-01, -9.1041e-02, -8.2664e-02,
          1.1888e-01, -2.6719e-01,  1.9183e-01, -8.6937e-02, -1.1656e-01,
          9.3656e-02, -2.4710e-01,  4.2405e-01, -9.8818e-02, -1.9976e-01,
          5.9881e-02, -2.5661e-01, -4.5558e-04,  5.4826e-02, -9.8980e-02,
         -2.1050e-01, -3.2874e-02, -1.3713e-01,  1.4698e-01,  1.3910e-01,
         -1.8795e-01,  2.5036e-02,  1.5718e-01,  7.8290e-02,  3.4586e-01,
         -1.9084e-01, -2.4636e-01, -2.7752e-01, -3.2277e-01, -2.1060e-01,
          1.5730e-02, -1.6699e-01, -4.1642e-01,  1.0134e-01,  1.4960e-01,
         -1.6346e-01, -5.3434e-01, -2.3700e-01, -4.8405e-01, -4.0551e-01,
          1.3371e-01,  1.7637e-01, -4.6814e-01,  2.1758e-01, -9.9638e-02,
         -6.6738e-02,  5.5015e-01,  1.1654e-01,  1.7745e-02, -4.3454e-01,
         -9.5787e-02,  3.9904e-01, -1.8967e-01,  1.4247e-01,  4.7395e-01,
         -1.0383e-02, -2.8846e-02, -4.6897e-02, -1.0486e-01, -5.7609e-01,
         -3.6152e-01,  2.1397e-01, -3.2449e-01,  3.2640e-01,  1.3473e-01,
         -4.2674e-02,  3.7581e-01, -1.7810e-01, -1.5909e-01,  4.2082e-01,
          2.9741e-01, -1.1551e-01, -3.5975e-01,  3.0413e-01,  2.8718e-01,
          8.3572e-02, -7.6330e-02,  3.6599e-01, -3.5791e-01, -1.1984e-01,
          4.6330e-02, -2.1017e-01,  9.6010e-02,  2.5951e-01,  3.1474e-01,
          3.2159e-01,  1.7671e-01, -8.4449e-02,  1.2721e-01,  6.9955e-02,
          3.3938e-01, -3.6960e-02, -2.6437e-01, -3.6957e-01, -3.1519e-01,
         -2.8044e-01, -3.1960e-02, -3.2118e-01, -3.3568e-01,  5.6361e-02,
         -7.9344e-02, -1.0817e-01, -4.1884e-01,  5.6373e-01,  1.7060e-01,
         -3.7492e-01, -4.9206e-01, -1.1492e-01, -5.3896e-02, -1.0919e-01,
         -2.6218e-01,  2.0154e-01,  4.4679e-02,  6.2231e-02,  8.6094e-02,
         -1.1548e-01,  4.0717e-01,  1.0435e-01,  1.1964e-01,  5.5811e-02,
         -1.3387e-01, -3.0432e-01,  1.9487e-01, -2.3411e-01,  2.3979e-01,
          8.3957e-02, -1.7051e-01,  2.9582e-01,  7.6980e-02, -2.1012e-01,
         -2.2019e-01,  5.1967e-02, -2.0904e-01, -3.1589e-01,  1.3976e-01,
          7.3378e-02,  1.2886e-02, -3.7941e-02, -4.1992e-01, -6.7913e-01,
         -3.8605e-02,  3.3250e-01, -1.4962e-01, -2.3367e-01,  2.1604e-02,
          9.5433e-01, -3.3314e-01,  3.3863e-02, -1.1041e-01,  1.3514e-01,
          1.7344e-01,  3.0827e-01,  1.9480e-02,  1.3813e-01,  7.5715e-03,
          2.6293e-01,  1.4490e-01, -3.7647e-01,  2.0370e-01, -1.2067e-01,
         -1.9894e-01, -2.9209e-01, -1.8561e-01, -2.8015e-01, -1.9971e-01,
          7.9241e-02,  2.7789e-01,  2.0711e-01,  9.2009e-02, -1.2497e-01,
          1.3548e-02, -4.1796e-01, -9.4340e-02,  3.9416e-02,  4.4411e-02,
         -2.0689e-01,  7.8503e-02,  2.8727e-01,  4.8503e-01, -3.9871e-01,
          5.5450e-02, -2.8031e-01, -1.6286e-01,  3.7779e-01, -4.2408e-01,
          4.3749e-03,  1.3447e-01, -7.2981e-02, -2.7551e-01, -1.6305e-01,
         -1.8913e-01, -4.5124e-02,  1.1433e-01, -1.1529e-01,  1.6803e-01,
         -4.9927e-02,  3.2612e-02, -1.4389e-01, -6.5885e-01, -6.7365e-02,
         -1.4673e-01, -4.8238e-02, -1.4534e-01,  6.4086e-01,  3.6175e-01,
         -1.3694e-01, -3.0026e-01,  4.9679e-01, -1.2416e-01, -3.7848e-01,
         -2.7967e-01, -3.2192e-02, -5.8232e-01,  2.5316e-01, -3.7612e-01,
         -7.5341e-02, -1.0061e-01, -4.2381e-01,  2.0430e-01,  3.5421e-01,
          4.7512e-01, -2.4318e-02, -3.1277e-01,  5.9076e-02, -2.8672e-01,
         -2.0519e-01,  1.0690e-01,  2.1698e-01, -4.6779e-01, -4.2751e-01,
         -7.9820e-02, -1.0694e-01, -1.5257e-01, -2.0322e-01,  2.8396e-01,
         -9.1225e-02,  5.7800e-02,  2.5577e-01, -3.4644e-01,  1.6233e-01,
          6.3564e-01, -3.3724e-01,  1.9419e-01, -3.6530e-01,  1.0729e-01,
         -4.0173e-01, -3.8239e-01, -1.6141e-01,  7.4453e-02, -4.3982e-01,
         -3.7165e-01,  9.4542e-02, -3.9154e-01, -2.8783e-01, -2.0081e-01,
          2.1023e-02, -3.1915e-01,  5.1339e-01,  4.0421e-02, -5.2414e-01,
         -1.3860e-01,  5.0436e-01, -4.0707e-01, -5.7195e-01,  2.2586e-02,
         -2.6875e-01,  1.8693e-01, -2.5197e-01,  4.9181e-01, -5.9845e-01,
         -6.9314e-02, -5.7444e-02,  7.7586e-02,  3.3125e-02,  1.4080e-01,
         -1.8913e-01,  1.2483e-01,  1.8294e-01, -1.0081e+00,  2.8459e-01,
          2.2694e-01,  2.5965e-01, -7.4635e-02,  3.8446e-01, -4.5130e-01,
         -1.8509e-01,  8.3799e-02,  1.6537e-01,  3.1572e-01,  7.6182e-02,
          1.5847e-01,  3.7006e-01,  3.2070e-01,  2.1765e-01,  1.5052e-01,
         -4.7995e-01, -4.4966e-01, -3.5002e-01,  1.4406e-01, -8.8948e-02,
          5.1441e-01, -4.1660e-01,  3.4986e-01, -1.1441e-01, -1.0048e-01,
          1.6408e-01,  8.8887e-02,  1.5933e-01,  9.3904e-02,  1.1400e-01,
          1.8206e-01,  3.3361e-01, -1.2313e-01, -3.8495e-01,  2.2998e-01,
          1.0782e-01, -4.6062e-01, -2.9444e-01,  2.4281e-01, -1.2217e-01,
         -2.4393e-01,  1.8185e-01, -4.1725e-01, -3.6603e-01, -7.4732e-02,
          3.1326e-01, -1.5488e-01, -6.4646e-01, -8.3071e-02,  2.3564e-02,
         -3.8589e-01,  1.0421e-01, -1.0521e-01, -1.6525e-01, -1.1996e-01,
          1.4955e-01,  2.4033e-01, -8.6674e-02, -5.4170e-01, -2.5583e-01,
          5.6224e-02, -1.6628e-01, -4.1819e-01, -2.3568e-01, -4.9047e-01,
          1.7005e-02,  3.6465e-04,  2.6325e-01, -2.7289e-01,  1.3864e-01,
          8.5591e-02,  3.5195e-01, -5.7063e-02,  5.9695e-03,  3.2521e-02,
         -1.8778e-01,  1.8255e-01,  2.8537e-01, -1.4293e-02,  1.1457e-02,
         -3.1442e-01,  1.9079e-01, -3.3102e-01, -4.6480e-02, -2.7503e-01,
          4.3797e-01,  2.5024e-01, -1.1668e-01, -1.2751e-01, -1.4074e-01,
          1.5598e-01, -6.7441e-02, -1.0944e-01, -4.0199e-01,  1.6955e-01,
         -3.8893e-01,  1.8294e-01,  1.3988e-01,  3.9773e-01, -1.6795e-01,
          6.2370e-02, -8.5830e-02, -3.9306e-02, -3.4178e-01, -1.6844e-01,
          7.8893e-03, -3.4126e-01,  4.5400e-01,  1.6699e-01, -4.2358e-01,
         -1.6840e-01, -5.7318e-02,  3.9432e-02,  2.4061e-01,  1.4965e-01,
          1.8138e-01, -6.6945e-02, -3.7419e-01, -7.3155e-02, -2.2167e-01,
          1.9476e-01, -9.5210e-02,  2.7536e-01, -1.5343e-01, -5.6016e-01,
         -6.1347e-02,  7.8486e-02,  2.2220e-02,  1.2293e-01,  1.9137e-02,
          3.6239e-01, -3.9407e-03, -3.0480e-01,  1.0109e-01,  1.4238e-01,
         -1.0127e-01,  1.8982e-01, -7.3263e-02, -2.4488e-01,  4.4750e-01,
          3.5579e-01, -1.6979e-01,  1.6875e-01,  2.4465e-02, -1.2170e-01,
         -9.6275e-02,  8.9936e-02,  6.1709e-01,  1.4500e-01, -4.8076e-01,
         -3.0713e-01,  1.4727e-01, -7.0945e-03, -1.5526e-01,  9.5915e-02,
          7.4484e-02,  4.3584e-01,  1.6116e-01, -5.9871e-01, -5.9422e-01,
          7.6837e-02, -1.5509e-01, -3.6611e-01,  2.8282e-02,  3.7231e-01,
          3.2762e-01,  2.2060e-02, -3.5888e-01, -2.4051e-01,  6.4066e-02,
          3.5439e-01,  1.2462e-01,  1.6493e-01, -3.1871e-01,  5.5375e-02,
          1.1820e-01,  8.3088e-01,  8.5695e-02, -1.1434e-02,  5.5705e-02,
         -3.7754e-02,  1.1020e-01, -3.7759e-01,  1.2221e-01,  1.5877e-01,
          3.7045e-01, -2.2702e-01,  4.3940e-01, -1.2291e-01,  6.0639e-02,
          9.0442e-03, -1.1641e-01,  1.6571e-02, -2.3472e-01,  4.0214e-01,
          6.4819e-03,  6.5207e-03, -6.6666e-01,  1.4141e-01, -1.0585e-01,
         -2.4113e-01,  2.0391e-01, -1.0582e-01, -2.7862e-01, -1.3709e-01,
          2.6258e-03, -7.9851e-02,  4.0725e-03, -4.6333e-01, -8.4179e-02,
         -3.8501e-02, -2.1379e-01,  9.3424e-02, -3.2262e-01, -3.8604e-01,
          2.5370e-01,  1.5727e-01, -7.4532e-01, -1.9181e-01,  1.5279e-01,
         -2.5710e-01, -5.3802e-01,  3.6783e-01,  3.5995e-01, -2.0221e-01,
         -2.0596e-01,  5.2864e-01,  2.7097e-01,  1.5769e-01,  4.6152e-01,
         -4.7203e-01, -1.2711e-01,  2.9948e-02, -8.5174e-02, -2.4472e-01,
         -1.4154e-01,  3.8125e-01,  2.8051e-01, -1.8072e-01,  4.8548e-01,
          2.4945e-01,  1.4104e-01,  3.8624e-01, -2.8026e-01, -2.0120e-01,
          2.8732e-02, -1.1231e-02, -3.6426e-01,  2.6251e-02, -3.7851e-02,
          6.4009e-01, -1.5928e-01,  9.5855e-02, -6.6296e-02,  1.7410e-01,
          3.8996e-02,  7.4717e-02, -1.3931e-01,  5.8656e-01,  4.9502e-01,
         -2.6947e-02, -4.5381e-01, -2.7717e-01, -1.0380e-01, -1.1805e-01,
          3.7948e-01,  9.4765e-02,  2.8403e-01,  1.1377e-01, -4.7931e-01,
         -2.5457e-01,  1.8128e-02,  1.6517e-01,  3.1461e-01, -2.5417e-02,
         -2.0248e-01, -4.7060e-02,  6.1397e-02,  5.4805e-01,  1.4530e-01,
          7.1823e-02, -3.3883e-01, -1.8383e-01,  3.2484e-01,  6.6465e-02,
         -2.4925e-01,  1.2773e-01,  3.5344e-02,  3.2701e-01, -8.5392e-02,
          4.4462e-01, -1.3644e-01,  1.4186e-01, -2.9954e-01,  9.3068e-02,
          4.8697e-01, -4.9256e-01,  1.0619e-01,  4.2752e-01, -1.7284e-01,
         -2.3642e-01, -4.4011e-02, -2.9467e-01,  6.1980e-01,  5.8373e-01,
          2.2160e-01, -2.1866e-02,  1.7643e-01, -9.1520e-03,  1.2929e-01,
         -7.3082e-02,  1.6397e-01,  5.8583e-02, -4.6242e-01,  4.3455e-01,
         -1.0087e-01, -3.2727e-01,  2.2628e-01, -3.4582e-02,  1.6172e-01,
         -2.7409e-01,  1.0528e-01, -5.9848e-02,  2.3226e-01,  1.4564e-01,
         -1.2072e-01,  2.3905e-01, -1.9347e-01, -4.6393e-02,  2.3129e-01,
         -5.2212e-01, -3.7706e-01,  1.7923e-01, -1.3713e-01, -9.8761e-02,
         -4.0970e-03,  3.2820e-01, -6.4838e-01,  3.7760e-01,  1.0785e-01,
         -2.2988e-01,  3.9778e-01, -1.8354e-01, -2.0468e-01,  8.5161e-03,
          2.8459e-02, -5.1400e-01,  2.5988e-01, -1.5684e-01,  5.7713e-01,
         -1.2787e-01,  2.5969e-01,  1.0446e-01,  7.4681e-04, -2.0539e-01,
         -1.1128e-01, -1.0686e-01, -1.8856e-02,  2.0551e-01,  3.8597e-01,
          3.2728e-01,  2.5913e-01,  5.2154e-01]], device='cuda:0',
       grad_fn=<DivBackward0>), hidden_states=(tensor([[[ 0.3790, -0.1398, -0.1771,  ...,  0.1405, -0.0283,  0.1839],
         [-0.3286,  0.1314, -0.1459,  ...,  0.1504, -0.0726, -0.9195],
         [-0.4018, -0.7783,  0.3085,  ...,  0.2146, -0.1665,  0.1573],
         ...,
         [ 0.0851, -0.2625, -0.4894,  ...,  0.2734,  0.4745,  1.1782],
         [ 0.1677,  0.7349, -0.2215,  ..., -0.0586,  0.4081,  0.0223],
         [-0.5726,  0.0112,  0.0245,  ..., -0.1072,  0.2156, -0.0299]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0796, -0.0358, -0.2100,  ..., -0.0925,  0.1244,  0.1888],
         [-0.1361,  0.1932, -0.3597,  ..., -0.0295,  0.2344, -0.5821],
         [-0.1606, -0.3572, -0.7043,  ..., -0.0127,  0.1001,  0.1497],
         ...,
         [ 0.5795, -0.2205, -0.4840,  ..., -0.1051,  0.7013,  0.6734],
         [ 0.5804,  0.1233, -0.3604,  ..., -0.4208,  0.3692, -0.3083],
         [ 0.1685, -0.1799, -0.1704,  ..., -0.3674,  0.2134,  0.0064]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0183, -0.3107, -0.1695,  ..., -0.4951, -0.3178,  0.3724],
         [-0.7932,  0.6165, -0.3781,  ..., -0.1443, -0.8263, -0.0594],
         [-0.5486, -0.0839, -0.6945,  ..., -0.4000, -0.2905,  0.2969],
         ...,
         [ 0.8178,  0.1379, -0.1783,  ..., -0.4655, -0.3185,  0.7949],
         [ 0.5836,  0.5550, -0.3810,  ..., -0.6955, -0.9483, -0.0760],
         [ 0.0729, -0.0122, -0.0600,  ..., -0.0596, -0.1522,  0.0302]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.2324, -0.0543, -0.6694,  ..., -0.5096, -0.4335,  0.2488],
         [-1.4307,  1.3544, -0.3003,  ..., -0.7542, -1.0228, -0.9291],
         [-1.2292,  0.1476, -0.9470,  ..., -0.4628, -0.2401,  0.3044],
         ...,
         [ 1.0194,  0.5541, -0.2348,  ..., -0.1541, -0.1016,  0.5341],
         [ 0.8387,  1.2867, -0.0927,  ..., -0.7560, -0.3953, -0.5713],
         [ 0.0315, -0.0189,  0.0095,  ..., -0.0260, -0.0536, -0.0273]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-9.5326e-01, -4.3661e-01, -9.4581e-01,  ..., -1.5457e-01,
          -6.4549e-01, -1.9762e-01],
         [-2.1118e+00,  8.4958e-01,  3.3048e-01,  ...,  1.1778e-01,
          -8.8292e-01, -9.3446e-01],
         [-1.1577e+00, -3.1500e-01, -7.4675e-01,  ...,  2.2091e-01,
          -2.0268e-01,  4.8733e-01],
         ...,
         [ 4.8777e-01,  2.2107e-01,  3.3298e-01,  ...,  1.6233e-01,
           9.3253e-02,  3.7051e-01],
         [ 2.6080e-01,  9.9856e-01,  3.5275e-01,  ..., -2.5345e-02,
          -2.1419e-01, -8.3197e-01],
         [ 2.2233e-02,  2.0702e-03,  3.6527e-02,  ..., -5.7560e-02,
          -1.1700e-01, -5.4931e-02]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.9370, -0.6154,  0.0175,  ..., -0.2686,  0.0505,  0.3943],
         [-1.4244, -0.1362,  0.5284,  ...,  0.1380, -0.2261, -0.3084],
         [-0.7744, -0.5110, -0.2723,  ...,  0.0668,  0.0397,  0.6541],
         ...,
         [ 0.0961,  0.1713,  0.6896,  ..., -0.0356,  0.1312,  0.2198],
         [-0.2215,  0.0862,  0.9392,  ..., -0.1537, -0.0156, -0.4499],
         [-0.1323, -0.0390,  0.1072,  ...,  0.0682, -0.0397, -0.2653]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-2.8095e-01,  1.9375e-01, -2.3601e-01,  ...,  2.4750e-04,
           1.8413e-01,  7.4281e-01],
         [-2.0558e-01,  5.6973e-01, -3.4450e-01,  ...,  3.1061e-01,
           7.4593e-03,  6.6465e-01],
         [-2.6471e-01,  4.7577e-01, -8.7475e-01,  ...,  4.3364e-01,
           1.6303e-02,  1.2238e+00],
         ...,
         [ 2.1979e-01,  2.8544e-02, -4.8508e-01,  ...,  1.5100e-01,
           5.6132e-01,  6.3370e-01],
         [ 8.1528e-02,  5.9031e-01, -1.4732e-01,  ..., -5.9024e-02,
           5.9021e-01,  8.1966e-01],
         [-7.2517e-03,  1.6594e-01, -7.6334e-01,  ...,  5.0187e-01,
           3.8227e-01,  5.1416e-01]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>)), attentions=None)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

In [38]:
opt_model = optimize_for_mobile(traced_model)

RuntimeError: Tracer cannot infer type of BaseModelOutputWithPooling(last_hidden_state=tensor([[[-2.8095e-01,  1.9375e-01, -2.3601e-01,  ...,  2.4750e-04,
           1.8413e-01,  7.4281e-01],
         [-2.0558e-01,  5.6973e-01, -3.4450e-01,  ...,  3.1061e-01,
           7.4593e-03,  6.6465e-01],
         [-2.6471e-01,  4.7577e-01, -8.7475e-01,  ...,  4.3364e-01,
           1.6303e-02,  1.2238e+00],
         ...,
         [ 2.1979e-01,  2.8544e-02, -4.8508e-01,  ...,  1.5100e-01,
           5.6132e-01,  6.3370e-01],
         [ 8.1528e-02,  5.9031e-01, -1.4732e-01,  ..., -5.9024e-02,
           5.9021e-01,  8.1966e-01],
         [-7.2517e-03,  1.6594e-01, -7.6334e-01,  ...,  5.0187e-01,
           3.8227e-01,  5.1416e-01]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[ 4.6693e-02,  4.0107e-02, -3.7071e-01, -2.7972e-02, -9.5654e-02,
         -8.7964e-02,  1.8528e-02,  2.7150e-01, -2.4230e-01,  9.8927e-05,
         -3.4853e-01, -5.7866e-01, -1.3977e-01,  3.4567e-01, -3.4365e-01,
         -9.6294e-03, -2.6303e-02, -2.1803e-01, -1.1859e-01,  7.5814e-02,
          2.7154e-01, -2.1654e-01, -3.3980e-01,  3.5713e-01,  7.2184e-02,
          7.7747e-02,  2.0000e-01,  3.3041e-01,  1.5145e-01, -3.1463e-01,
         -8.3670e-03, -6.3714e-02,  4.2273e-01,  2.9969e-01,  5.7845e-02,
          2.3974e-01, -4.7594e-02, -1.6258e-01, -1.6344e-01, -1.2472e-01,
         -1.7768e-01,  5.6963e-03, -2.2802e-01,  6.1100e-04,  4.6322e-01,
         -3.9256e-02,  4.0616e-01, -3.4898e-01,  8.9205e-02, -1.6020e-01,
          9.7896e-02, -1.6947e-01,  2.7750e-02,  3.5056e-01,  2.2463e-01,
         -2.5251e-02, -8.3530e-02,  1.1967e-02,  1.8789e-01,  3.0655e-01,
          6.2255e-02, -3.2581e-02,  2.6660e-02, -9.0701e-02, -3.4137e-01,
         -2.6564e-01, -1.7123e-01,  3.8302e-01, -5.6068e-01,  1.9953e-01,
          1.5175e-01,  2.5360e-01,  1.3732e-01, -1.1475e-01,  5.0170e-02,
          2.9594e-01,  4.7890e-01,  8.9323e-02,  4.1593e-01,  2.3964e-01,
          2.2769e-01,  1.6559e-01,  2.8347e-01,  1.4929e-01,  2.6771e-01,
         -1.7873e-01, -1.1776e-01, -8.1478e-02, -2.2917e-01,  6.6764e-02,
         -1.5156e-01,  1.4597e-01,  1.7807e-01, -7.8407e-02, -1.1961e-02,
         -5.5800e-03,  3.2663e-02, -3.6713e-02, -1.0436e-01, -2.6302e-01,
          2.9339e-01,  2.6338e-01, -8.5827e-03,  2.8918e-01,  2.0783e-01,
          6.6061e-02,  8.6752e-02,  2.6657e-01, -2.1378e-01,  5.5170e-01,
         -3.0502e-01,  1.9279e-02,  2.0699e-01, -4.1495e-01, -1.3450e-03,
         -3.4352e-02, -6.2233e-01,  5.7239e-02, -5.3565e-01, -1.9556e-01,
          2.1748e-01, -2.5492e-02, -1.7693e-01, -9.1041e-02, -8.2664e-02,
          1.1888e-01, -2.6719e-01,  1.9183e-01, -8.6937e-02, -1.1656e-01,
          9.3656e-02, -2.4710e-01,  4.2405e-01, -9.8818e-02, -1.9976e-01,
          5.9881e-02, -2.5661e-01, -4.5558e-04,  5.4826e-02, -9.8980e-02,
         -2.1050e-01, -3.2874e-02, -1.3713e-01,  1.4698e-01,  1.3910e-01,
         -1.8795e-01,  2.5036e-02,  1.5718e-01,  7.8290e-02,  3.4586e-01,
         -1.9084e-01, -2.4636e-01, -2.7752e-01, -3.2277e-01, -2.1060e-01,
          1.5730e-02, -1.6699e-01, -4.1642e-01,  1.0134e-01,  1.4960e-01,
         -1.6346e-01, -5.3434e-01, -2.3700e-01, -4.8405e-01, -4.0551e-01,
          1.3371e-01,  1.7637e-01, -4.6814e-01,  2.1758e-01, -9.9638e-02,
         -6.6738e-02,  5.5015e-01,  1.1654e-01,  1.7745e-02, -4.3454e-01,
         -9.5787e-02,  3.9904e-01, -1.8967e-01,  1.4247e-01,  4.7395e-01,
         -1.0383e-02, -2.8846e-02, -4.6897e-02, -1.0486e-01, -5.7609e-01,
         -3.6152e-01,  2.1397e-01, -3.2449e-01,  3.2640e-01,  1.3473e-01,
         -4.2674e-02,  3.7581e-01, -1.7810e-01, -1.5909e-01,  4.2082e-01,
          2.9741e-01, -1.1551e-01, -3.5975e-01,  3.0413e-01,  2.8718e-01,
          8.3572e-02, -7.6330e-02,  3.6599e-01, -3.5791e-01, -1.1984e-01,
          4.6330e-02, -2.1017e-01,  9.6010e-02,  2.5951e-01,  3.1474e-01,
          3.2159e-01,  1.7671e-01, -8.4449e-02,  1.2721e-01,  6.9955e-02,
          3.3938e-01, -3.6960e-02, -2.6437e-01, -3.6957e-01, -3.1519e-01,
         -2.8044e-01, -3.1960e-02, -3.2118e-01, -3.3568e-01,  5.6361e-02,
         -7.9344e-02, -1.0817e-01, -4.1884e-01,  5.6373e-01,  1.7060e-01,
         -3.7492e-01, -4.9206e-01, -1.1492e-01, -5.3896e-02, -1.0919e-01,
         -2.6218e-01,  2.0154e-01,  4.4679e-02,  6.2231e-02,  8.6094e-02,
         -1.1548e-01,  4.0717e-01,  1.0435e-01,  1.1964e-01,  5.5811e-02,
         -1.3387e-01, -3.0432e-01,  1.9487e-01, -2.3411e-01,  2.3979e-01,
          8.3957e-02, -1.7051e-01,  2.9582e-01,  7.6980e-02, -2.1012e-01,
         -2.2019e-01,  5.1967e-02, -2.0904e-01, -3.1589e-01,  1.3976e-01,
          7.3378e-02,  1.2886e-02, -3.7941e-02, -4.1992e-01, -6.7913e-01,
         -3.8605e-02,  3.3250e-01, -1.4962e-01, -2.3367e-01,  2.1604e-02,
          9.5433e-01, -3.3314e-01,  3.3863e-02, -1.1041e-01,  1.3514e-01,
          1.7344e-01,  3.0827e-01,  1.9480e-02,  1.3813e-01,  7.5715e-03,
          2.6293e-01,  1.4490e-01, -3.7647e-01,  2.0370e-01, -1.2067e-01,
         -1.9894e-01, -2.9209e-01, -1.8561e-01, -2.8015e-01, -1.9971e-01,
          7.9241e-02,  2.7789e-01,  2.0711e-01,  9.2009e-02, -1.2497e-01,
          1.3548e-02, -4.1796e-01, -9.4340e-02,  3.9416e-02,  4.4411e-02,
         -2.0689e-01,  7.8503e-02,  2.8727e-01,  4.8503e-01, -3.9871e-01,
          5.5450e-02, -2.8031e-01, -1.6286e-01,  3.7779e-01, -4.2408e-01,
          4.3749e-03,  1.3447e-01, -7.2981e-02, -2.7551e-01, -1.6305e-01,
         -1.8913e-01, -4.5124e-02,  1.1433e-01, -1.1529e-01,  1.6803e-01,
         -4.9927e-02,  3.2612e-02, -1.4389e-01, -6.5885e-01, -6.7365e-02,
         -1.4673e-01, -4.8238e-02, -1.4534e-01,  6.4086e-01,  3.6175e-01,
         -1.3694e-01, -3.0026e-01,  4.9679e-01, -1.2416e-01, -3.7848e-01,
         -2.7967e-01, -3.2192e-02, -5.8232e-01,  2.5316e-01, -3.7612e-01,
         -7.5341e-02, -1.0061e-01, -4.2381e-01,  2.0430e-01,  3.5421e-01,
          4.7512e-01, -2.4318e-02, -3.1277e-01,  5.9076e-02, -2.8672e-01,
         -2.0519e-01,  1.0690e-01,  2.1698e-01, -4.6779e-01, -4.2751e-01,
         -7.9820e-02, -1.0694e-01, -1.5257e-01, -2.0322e-01,  2.8396e-01,
         -9.1225e-02,  5.7800e-02,  2.5577e-01, -3.4644e-01,  1.6233e-01,
          6.3564e-01, -3.3724e-01,  1.9419e-01, -3.6530e-01,  1.0729e-01,
         -4.0173e-01, -3.8239e-01, -1.6141e-01,  7.4453e-02, -4.3982e-01,
         -3.7165e-01,  9.4542e-02, -3.9154e-01, -2.8783e-01, -2.0081e-01,
          2.1023e-02, -3.1915e-01,  5.1339e-01,  4.0421e-02, -5.2414e-01,
         -1.3860e-01,  5.0436e-01, -4.0707e-01, -5.7195e-01,  2.2586e-02,
         -2.6875e-01,  1.8693e-01, -2.5197e-01,  4.9181e-01, -5.9845e-01,
         -6.9314e-02, -5.7444e-02,  7.7586e-02,  3.3125e-02,  1.4080e-01,
         -1.8913e-01,  1.2483e-01,  1.8294e-01, -1.0081e+00,  2.8459e-01,
          2.2694e-01,  2.5965e-01, -7.4635e-02,  3.8446e-01, -4.5130e-01,
         -1.8509e-01,  8.3799e-02,  1.6537e-01,  3.1572e-01,  7.6182e-02,
          1.5847e-01,  3.7006e-01,  3.2070e-01,  2.1765e-01,  1.5052e-01,
         -4.7995e-01, -4.4966e-01, -3.5002e-01,  1.4406e-01, -8.8948e-02,
          5.1441e-01, -4.1660e-01,  3.4986e-01, -1.1441e-01, -1.0048e-01,
          1.6408e-01,  8.8887e-02,  1.5933e-01,  9.3904e-02,  1.1400e-01,
          1.8206e-01,  3.3361e-01, -1.2313e-01, -3.8495e-01,  2.2998e-01,
          1.0782e-01, -4.6062e-01, -2.9444e-01,  2.4281e-01, -1.2217e-01,
         -2.4393e-01,  1.8185e-01, -4.1725e-01, -3.6603e-01, -7.4732e-02,
          3.1326e-01, -1.5488e-01, -6.4646e-01, -8.3071e-02,  2.3564e-02,
         -3.8589e-01,  1.0421e-01, -1.0521e-01, -1.6525e-01, -1.1996e-01,
          1.4955e-01,  2.4033e-01, -8.6674e-02, -5.4170e-01, -2.5583e-01,
          5.6224e-02, -1.6628e-01, -4.1819e-01, -2.3568e-01, -4.9047e-01,
          1.7005e-02,  3.6465e-04,  2.6325e-01, -2.7289e-01,  1.3864e-01,
          8.5591e-02,  3.5195e-01, -5.7063e-02,  5.9695e-03,  3.2521e-02,
         -1.8778e-01,  1.8255e-01,  2.8537e-01, -1.4293e-02,  1.1457e-02,
         -3.1442e-01,  1.9079e-01, -3.3102e-01, -4.6480e-02, -2.7503e-01,
          4.3797e-01,  2.5024e-01, -1.1668e-01, -1.2751e-01, -1.4074e-01,
          1.5598e-01, -6.7441e-02, -1.0944e-01, -4.0199e-01,  1.6955e-01,
         -3.8893e-01,  1.8294e-01,  1.3988e-01,  3.9773e-01, -1.6795e-01,
          6.2370e-02, -8.5830e-02, -3.9306e-02, -3.4178e-01, -1.6844e-01,
          7.8893e-03, -3.4126e-01,  4.5400e-01,  1.6699e-01, -4.2358e-01,
         -1.6840e-01, -5.7318e-02,  3.9432e-02,  2.4061e-01,  1.4965e-01,
          1.8138e-01, -6.6945e-02, -3.7419e-01, -7.3155e-02, -2.2167e-01,
          1.9476e-01, -9.5210e-02,  2.7536e-01, -1.5343e-01, -5.6016e-01,
         -6.1347e-02,  7.8486e-02,  2.2220e-02,  1.2293e-01,  1.9137e-02,
          3.6239e-01, -3.9407e-03, -3.0480e-01,  1.0109e-01,  1.4238e-01,
         -1.0127e-01,  1.8982e-01, -7.3263e-02, -2.4488e-01,  4.4750e-01,
          3.5579e-01, -1.6979e-01,  1.6875e-01,  2.4465e-02, -1.2170e-01,
         -9.6275e-02,  8.9936e-02,  6.1709e-01,  1.4500e-01, -4.8076e-01,
         -3.0713e-01,  1.4727e-01, -7.0945e-03, -1.5526e-01,  9.5915e-02,
          7.4484e-02,  4.3584e-01,  1.6116e-01, -5.9871e-01, -5.9422e-01,
          7.6837e-02, -1.5509e-01, -3.6611e-01,  2.8282e-02,  3.7231e-01,
          3.2762e-01,  2.2060e-02, -3.5888e-01, -2.4051e-01,  6.4066e-02,
          3.5439e-01,  1.2462e-01,  1.6493e-01, -3.1871e-01,  5.5375e-02,
          1.1820e-01,  8.3088e-01,  8.5695e-02, -1.1434e-02,  5.5705e-02,
         -3.7754e-02,  1.1020e-01, -3.7759e-01,  1.2221e-01,  1.5877e-01,
          3.7045e-01, -2.2702e-01,  4.3940e-01, -1.2291e-01,  6.0639e-02,
          9.0442e-03, -1.1641e-01,  1.6571e-02, -2.3472e-01,  4.0214e-01,
          6.4819e-03,  6.5207e-03, -6.6666e-01,  1.4141e-01, -1.0585e-01,
         -2.4113e-01,  2.0391e-01, -1.0582e-01, -2.7862e-01, -1.3709e-01,
          2.6258e-03, -7.9851e-02,  4.0725e-03, -4.6333e-01, -8.4179e-02,
         -3.8501e-02, -2.1379e-01,  9.3424e-02, -3.2262e-01, -3.8604e-01,
          2.5370e-01,  1.5727e-01, -7.4532e-01, -1.9181e-01,  1.5279e-01,
         -2.5710e-01, -5.3802e-01,  3.6783e-01,  3.5995e-01, -2.0221e-01,
         -2.0596e-01,  5.2864e-01,  2.7097e-01,  1.5769e-01,  4.6152e-01,
         -4.7203e-01, -1.2711e-01,  2.9948e-02, -8.5174e-02, -2.4472e-01,
         -1.4154e-01,  3.8125e-01,  2.8051e-01, -1.8072e-01,  4.8548e-01,
          2.4945e-01,  1.4104e-01,  3.8624e-01, -2.8026e-01, -2.0120e-01,
          2.8732e-02, -1.1231e-02, -3.6426e-01,  2.6251e-02, -3.7851e-02,
          6.4009e-01, -1.5928e-01,  9.5855e-02, -6.6296e-02,  1.7410e-01,
          3.8996e-02,  7.4717e-02, -1.3931e-01,  5.8656e-01,  4.9502e-01,
         -2.6947e-02, -4.5381e-01, -2.7717e-01, -1.0380e-01, -1.1805e-01,
          3.7948e-01,  9.4765e-02,  2.8403e-01,  1.1377e-01, -4.7931e-01,
         -2.5457e-01,  1.8128e-02,  1.6517e-01,  3.1461e-01, -2.5417e-02,
         -2.0248e-01, -4.7060e-02,  6.1397e-02,  5.4805e-01,  1.4530e-01,
          7.1823e-02, -3.3883e-01, -1.8383e-01,  3.2484e-01,  6.6465e-02,
         -2.4925e-01,  1.2773e-01,  3.5344e-02,  3.2701e-01, -8.5392e-02,
          4.4462e-01, -1.3644e-01,  1.4186e-01, -2.9954e-01,  9.3068e-02,
          4.8697e-01, -4.9256e-01,  1.0619e-01,  4.2752e-01, -1.7284e-01,
         -2.3642e-01, -4.4011e-02, -2.9467e-01,  6.1980e-01,  5.8373e-01,
          2.2160e-01, -2.1866e-02,  1.7643e-01, -9.1520e-03,  1.2929e-01,
         -7.3082e-02,  1.6397e-01,  5.8583e-02, -4.6242e-01,  4.3455e-01,
         -1.0087e-01, -3.2727e-01,  2.2628e-01, -3.4582e-02,  1.6172e-01,
         -2.7409e-01,  1.0528e-01, -5.9848e-02,  2.3226e-01,  1.4564e-01,
         -1.2072e-01,  2.3905e-01, -1.9347e-01, -4.6393e-02,  2.3129e-01,
         -5.2212e-01, -3.7706e-01,  1.7923e-01, -1.3713e-01, -9.8761e-02,
         -4.0970e-03,  3.2820e-01, -6.4838e-01,  3.7760e-01,  1.0785e-01,
         -2.2988e-01,  3.9778e-01, -1.8354e-01, -2.0468e-01,  8.5161e-03,
          2.8459e-02, -5.1400e-01,  2.5988e-01, -1.5684e-01,  5.7713e-01,
         -1.2787e-01,  2.5969e-01,  1.0446e-01,  7.4681e-04, -2.0539e-01,
         -1.1128e-01, -1.0686e-01, -1.8856e-02,  2.0551e-01,  3.8597e-01,
          3.2728e-01,  2.5913e-01,  5.2154e-01]], device='cuda:0',
       grad_fn=<DivBackward0>), hidden_states=(tensor([[[ 0.3790, -0.1398, -0.1771,  ...,  0.1405, -0.0283,  0.1839],
         [-0.3286,  0.1314, -0.1459,  ...,  0.1504, -0.0726, -0.9195],
         [-0.4018, -0.7783,  0.3085,  ...,  0.2146, -0.1665,  0.1573],
         ...,
         [ 0.0851, -0.2625, -0.4894,  ...,  0.2734,  0.4745,  1.1782],
         [ 0.1677,  0.7349, -0.2215,  ..., -0.0586,  0.4081,  0.0223],
         [-0.5726,  0.0112,  0.0245,  ..., -0.1072,  0.2156, -0.0299]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0796, -0.0358, -0.2100,  ..., -0.0925,  0.1244,  0.1888],
         [-0.1361,  0.1932, -0.3597,  ..., -0.0295,  0.2344, -0.5821],
         [-0.1606, -0.3572, -0.7043,  ..., -0.0127,  0.1001,  0.1497],
         ...,
         [ 0.5795, -0.2205, -0.4840,  ..., -0.1051,  0.7013,  0.6734],
         [ 0.5804,  0.1233, -0.3604,  ..., -0.4208,  0.3692, -0.3083],
         [ 0.1685, -0.1799, -0.1704,  ..., -0.3674,  0.2134,  0.0064]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[ 0.0183, -0.3107, -0.1695,  ..., -0.4951, -0.3178,  0.3724],
         [-0.7932,  0.6165, -0.3781,  ..., -0.1443, -0.8263, -0.0594],
         [-0.5486, -0.0839, -0.6945,  ..., -0.4000, -0.2905,  0.2969],
         ...,
         [ 0.8178,  0.1379, -0.1783,  ..., -0.4655, -0.3185,  0.7949],
         [ 0.5836,  0.5550, -0.3810,  ..., -0.6955, -0.9483, -0.0760],
         [ 0.0729, -0.0122, -0.0600,  ..., -0.0596, -0.1522,  0.0302]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.2324, -0.0543, -0.6694,  ..., -0.5096, -0.4335,  0.2488],
         [-1.4307,  1.3544, -0.3003,  ..., -0.7542, -1.0228, -0.9291],
         [-1.2292,  0.1476, -0.9470,  ..., -0.4628, -0.2401,  0.3044],
         ...,
         [ 1.0194,  0.5541, -0.2348,  ..., -0.1541, -0.1016,  0.5341],
         [ 0.8387,  1.2867, -0.0927,  ..., -0.7560, -0.3953, -0.5713],
         [ 0.0315, -0.0189,  0.0095,  ..., -0.0260, -0.0536, -0.0273]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-9.5326e-01, -4.3661e-01, -9.4581e-01,  ..., -1.5457e-01,
          -6.4549e-01, -1.9762e-01],
         [-2.1118e+00,  8.4958e-01,  3.3048e-01,  ...,  1.1778e-01,
          -8.8292e-01, -9.3446e-01],
         [-1.1577e+00, -3.1500e-01, -7.4675e-01,  ...,  2.2091e-01,
          -2.0268e-01,  4.8733e-01],
         ...,
         [ 4.8777e-01,  2.2107e-01,  3.3298e-01,  ...,  1.6233e-01,
           9.3253e-02,  3.7051e-01],
         [ 2.6080e-01,  9.9856e-01,  3.5275e-01,  ..., -2.5345e-02,
          -2.1419e-01, -8.3197e-01],
         [ 2.2233e-02,  2.0702e-03,  3.6527e-02,  ..., -5.7560e-02,
          -1.1700e-01, -5.4931e-02]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>), tensor([[[-0.9370, -0.6154,  0.0175,  ..., -0.2686,  0.0505,  0.3943],
         [-1.4244, -0.1362,  0.5284,  ...,  0.1380, -0.2261, -0.3084],
         [-0.7744, -0.5110, -0.2723,  ...,  0.0668,  0.0397,  0.6541],
         ...,
         [ 0.0961,  0.1713,  0.6896,  ..., -0.0356,  0.1312,  0.2198],
         [-0.2215,  0.0862,  0.9392,  ..., -0.1537, -0.0156, -0.4499],
         [-0.1323, -0.0390,  0.1072,  ...,  0.0682, -0.0397, -0.2653]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>), tensor([[[-2.8095e-01,  1.9375e-01, -2.3601e-01,  ...,  2.4750e-04,
           1.8413e-01,  7.4281e-01],
         [-2.0558e-01,  5.6973e-01, -3.4450e-01,  ...,  3.1061e-01,
           7.4593e-03,  6.6465e-01],
         [-2.6471e-01,  4.7577e-01, -8.7475e-01,  ...,  4.3364e-01,
           1.6303e-02,  1.2238e+00],
         ...,
         [ 2.1979e-01,  2.8544e-02, -4.8508e-01,  ...,  1.5100e-01,
           5.6132e-01,  6.3370e-01],
         [ 8.1528e-02,  5.9031e-01, -1.4732e-01,  ..., -5.9024e-02,
           5.9021e-01,  8.1966e-01],
         [-7.2517e-03,  1.6594e-01, -7.6334e-01,  ...,  5.0187e-01,
           3.8227e-01,  5.1416e-01]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>)), attentions=None)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]

In [3]:
# import torch
# torch.save(quantized_model, './quantized_model.pth')

  from .autonotebook import tqdm as notebook_tqdm


NameError: name 'quantized_model' is not defined

Compare with default BERT

## 1. Loading and Preparing Wiki-1M data

Use huggingface `datasets` library to load local file data.

In [4]:
import numpy as np

from datasets import load_dataset

data_files = {'train': 'data/training/wiki1m_for_simcse.txt'}
datasets = load_dataset('text', data_files=data_files)

Using custom data configuration default-84caea1147087fa9
Reusing dataset text (C:\Users\ng-ka\.cache\huggingface\datasets\text\default-84caea1147087fa9\0.0.0\4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8)
100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  5.44it/s]


In [5]:
# Unsupervised / Self-supervised dataset

column_names = datasets["train"].column_names
sent0_cname = column_names[0]
sent1_cname = column_names[0]

print('column_names:', column_names)
print('sent0_cname:', sent0_cname, '| sent1_cname:', sent1_cname)

column_names: ['text']
sent0_cname: text | sent1_cname: text


In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [7]:
def prepare_features(examples):
    total = len(examples[sent0_cname])

    # Avoid "None" fields 
    for idx in range(total):
        if examples[sent0_cname][idx] is None:
            examples[sent0_cname][idx] = " "
        if examples[sent1_cname][idx] is None:
            examples[sent1_cname][idx] = " "
    
    sentences = examples[sent0_cname] + examples[sent1_cname]

    sent_features = tokenizer(
        sentences,
        max_length=32,
        truncation=True,
        padding=True,
    )

    features = {}
    for key in sent_features:
        features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]

    return features

In [8]:
train_dataset = datasets["train"].map(prepare_features,
                                      batched=True,
                                    #   num_proc=24,
                                      remove_columns=column_names)

Loading cached processed dataset at C:\Users\ng-ka\.cache\huggingface\datasets\text\default-84caea1147087fa9\0.0.0\4b86d314f7236db91f0a0f5cda32d4375445e64c5eda2692655dd99c2dac68e8\cache-dd04d9dec26aeca9.arrow


In [7]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

Sentence 1 and Sentence 2 are the same sentence

## 2. Contrastive Learning Model

In [9]:
import torch
import torch.nn as nn

from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling

from distilface.modules.pooler import Pooler
from distilface.modules.similarity import Similarity


class BertCLModel(BertPreTrainedModel):
    def __init__(self, config, pooler_type='avg_first_last', temp=0.05):
        super().__init__(config)

        self.config = config
        self.pooler_type = pooler_type
        self.temp = 0.05

        self.bert = BertModel(config, add_pooling_layer=False)
        self.pooler = Pooler(pooler_type)
        self.sim = Similarity(temp=temp)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        if self.training:
            return self.cl_forward(self.bert, input_ids, attention_mask, token_type_ids)
        else:
            return self.sent_emb(self.bert, input_ids, attention_mask, token_type_ids)

    def cl_forward(self, encoder, input_ids=None, attention_mask=None, token_type_ids=None):
        batch_size = input_ids.size(0)
        num_sent = input_ids.size(1)  # Number of sentences in one instance: 2 sentences

        # Flatten all input tensors
        input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
        attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)
        token_type_ids = token_type_ids.view((-1, token_type_ids.size(-1))) # (bs * num_sent, len)

        # Pre-trained Model Encoder
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )

        # Pooling
        pooler_output = self.pooler(attention_mask, outputs)
        pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

        # Separate representation
        z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

        # Cosine similarity
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))

        # Calculate contrastive loss
        criterion = nn.CrossEntropyLoss()
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss = criterion(cos_sim, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def sent_emb(self, encoder, input_ids=None, attention_mask=None, token_type_ids=None):
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=True,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = self.pooler(attention_mask, outputs)

        return BaseModelOutputWithPooling(
            pooler_output=pooler_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )


pretrained_model_name = 'bert-base-uncased'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(pretrained_model_name)

model = BertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

model.eval();


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertCLModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertCLModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertCLModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## 2.1 Initial BERT embeddings performance

In [10]:
import senteval

def prepare(params, samples):
    return

def batcher(params, batch):
    sentences = [" ".join(s) for s in batch]
    batch = tokenizer.batch_encode_plus(
        sentences,
        return_tensors="pt",
        padding=True,
    )

    for k in batch:
        batch[k] = batch[k].to(device)

    with torch.no_grad():
        outputs = model(**batch)

    pooled_result = outputs.pooler_output.cpu()

    return pooled_result


def evaluate_model():
    PATH_TO_DATA = "./data"

    params = {"task_path": PATH_TO_DATA, "usepytorch": True, "kfold": 10}
    tasks = ["STSBenchmark", 'STS12', 'STS13', 'STS14', 'STS15']

    se = senteval.engine.SE(params, batcher, prepare)
    #se = se_engine.SE(params, batcher, prepare)
    results = se.eval(tasks)

    print('STS12: ', results["STS12"]["all"]["spearman"]["all"])
    print('STS13: ', results["STS13"]["all"]["spearman"]["all"])
    print('STS14: ', results["STS14"]["all"]["spearman"]["all"])
    print('STS15: ', results["STS15"]["all"]["spearman"]["all"])
    print('STSB: ', results["STSBenchmark"]["test"]["spearman"][0])

    return results

## 3. Trainer

In [15]:
import mlflow

from transformers import Trainer, TrainingArguments
from transformers import default_data_collator

training_args = TrainingArguments(
    output_dir='output',
    overwrite_output_dir=True,
    learning_rate=1e-05,
    per_device_train_batch_size= 128,
    per_device_eval_batch_size = 128,
    weight_decay=0.0,
    num_train_epochs=2,
    max_steps= 30000,
    logging_steps=5000,
    save_steps=5000
)

model.train()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
max_steps is given, it will override any value given in num_train_epochs


In [21]:
os.getcwd()

'C:\\Users\\ng-ka\\OMSCS\\DL\\DLProject\\contrastive-learning-in-distilled-models'

In [None]:
model_path = 'trained_model/bert_cl'

train_result = trainer.train()
torch.save(model, './bert_model_best_params.pth')

***** Running training *****
  Num examples = 1000000
  Num Epochs = 4
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 30000


Step,Training Loss
5000,0.0002


Saving model checkpoint to output\checkpoint-5000
Configuration saved in output\checkpoint-5000\config.json
Model weights saved in output\checkpoint-5000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-5000\tokenizer_config.json
Special tokens file saved in output\checkpoint-5000\special_tokens_map.json


In [14]:
def print_size_of_model(model, label=""):
    torch.save(model.state_dict(), "temp.p")
    size=os.path.getsize("temp.p")
    print("model: ",label,' \t','Size (KB):', size/1e3)
    os.remove('temp.p')
    return size

# compare the sizes
f=print_size_of_model(model,"fp32")

model:  fp32  	 Size (KB): 435650.293


In [None]:
q=print_size_of_model(quantized_model,"int8")
print("{0:.2f} times smaller".format(f/q))

In [37]:
mlflow.end_run()

## 4. Evaluate DistilBert CL Model performance

In [16]:
model.eval()

results = evaluate_model()
results

  sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
  sent2 = np.array([s.split() for s in sent2])[not_empty_idx]


STS12:  0.6522348679699254
STS13:  0.7245264269994489
STS14:  0.6591385536542372
STS15:  0.7803083813296277
STSB:  0.7317242200449551


{'STSBenchmark': {'train': {'pearson': (0.7701978269277896, 0.0),
   'spearman': SpearmanrResult(correlation=0.7444988754489594, pvalue=0.0),
   'nsamples': 5749},
  'dev': {'pearson': (0.792601430214866, 0.0),
   'spearman': SpearmanrResult(correlation=0.7923150663580373, pvalue=0.0),
   'nsamples': 1500},
  'test': {'pearson': (0.7395838046909169, 4.979487826512024e-239),
   'spearman': SpearmanrResult(correlation=0.7317242200449551, pvalue=1.7326856147338482e-231),
   'nsamples': 1379},
  'all': {'pearson': {'all': 0.7683647566030786,
    'mean': 0.7674610206111909,
    'wmean': 0.7691997588084071},
   'spearman': {'all': 0.7548919636159135,
    'mean': 0.7561793872839839,
    'wmean': 0.7507700897004076}}},
 'STS12': {'MSRpar': {'pearson': (0.6587949262407169, 1.5689202234706753e-94),
   'spearman': SpearmanrResult(correlation=0.6366967124028455, pvalue=1.678438485809478e-86),
   'nsamples': 750},
  'MSRvid': {'pearson': (0.8536988445568228, 3.8250578501088855e-214),
   'spearman':

# 64 batch size

In [21]:
mlflow.end_run()

In [18]:
import torch
import torch.nn as nn

from transformers import AutoTokenizer, DistilBertModel, DistilBertPreTrainedModel, AutoConfig
from transformers.modeling_outputs import SequenceClassifierOutput, BaseModelOutputWithPooling

from src.distilface.modules.pooler import Pooler
from src.distilface.modules.similarity import Similarity

from torch.cuda.amp import autocast 

#scaler = GradScaler()

class DistilBertCLModel(DistilBertPreTrainedModel):
    def __init__(self, config, pooler_type='avg_first_last', temp=0.05):
        super().__init__(config)

        self.config = config
        self.pooler_type = pooler_type
        self.temp = 0.05

        self.distilbert = DistilBertModel(config)
        self.pooler = Pooler(pooler_type)
        self.sim = Similarity(temp=temp)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None):
        with autocast():
            if self.training:
                return self.cl_forward(self.distilbert, input_ids, attention_mask)
            else:
                return self.sent_emb(self.distilbert, input_ids, attention_mask)

    def cl_forward(self, encoder, input_ids=None, attention_mask=None):
        batch_size = input_ids.size(0)#64#input_ids.size(0)
        num_sent = input_ids.size(1)  # Number of sentences in one instance: 2 sentences

        # Flatten all input tensors
        input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
        attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)

        # Pre-trained Model Encoder
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Pooling
        pooler_output = self.pooler(attention_mask, outputs)
        pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

        # Separate representation
        z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

        # Cosine similarity
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))

        # Calculate contrastive loss
        criterion = nn.CrossEntropyLoss()
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss = criterion(cos_sim, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def sent_emb(self, encoder, input_ids=None, attention_mask=None):
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = self.pooler(attention_mask, outputs)

        return BaseModelOutputWithPooling(
            pooler_output=pooler_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )


pretrained_model_name = 'distilbert-base-uncased'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = AutoConfig.from_pretrained(pretrained_model_name)

model2 = DistilBertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)

#model.eval();


loading configuration file https://huggingface.co/distilbert-base-uncased/resolve/main/config.json from cache at C:\Users\ng-ka/.cache\huggingface\transformers\23454919702d26495337f3da04d1655c7ee010d5ec9d77bdb9e399e00302c0a1.91b885ab15d631bf9cee9dc9d25ece0afd932f2f5130eba28f2055b2220c0333
Model config DistilBertConfig {
  "_name_or_path": "distilbert-base-uncased",
  "activation": "gelu",
  "architectures": [
    "DistilBertForMaskedLM"
  ],
  "attention_dropout": 0.1,
  "dim": 768,
  "dropout": 0.1,
  "hidden_dim": 3072,
  "initializer_range": 0.02,
  "max_position_embeddings": 512,
  "model_type": "distilbert",
  "n_heads": 12,
  "n_layers": 6,
  "pad_token_id": 0,
  "qa_dropout": 0.1,
  "seq_classif_dropout": 0.2,
  "sinusoidal_pos_embds": false,
  "tie_weights_": true,
  "transformers_version": "4.17.0",
  "vocab_size": 30522
}

loading weights file https://huggingface.co/distilbert-base-uncased/resolve/main/pytorch_model.bin from cache at C:\Users\ng-ka/.cache\huggingface\transfor

In [19]:
#model2 = DistilBertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)

training_args2 = TrainingArguments(
    output_dir='output',
    overwrite_output_dir=False,
    learning_rate=5e-05,
    per_device_train_batch_size= 64,
    per_device_eval_batch_size = 64,
    weight_decay=0.0,
    num_train_epochs=2,
    max_steps= 30000,
    logging_steps=5000,
    save_steps=5000,
    fp16=True
)

model2.train()

trainer2 = Trainer(
    model=model2,
    args=training_args2,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


In [22]:
model_path = 'trained_model/distilbert_cl'

#train_result = trainer.train(model_path=model_path)
train_result2 = trainer2.train()
torch.save(model2, './batch64_fp16_model.pth')

***** Running training *****
  Num examples = 1000000
  Num Epochs = 2
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 30000


Step,Training Loss
5000,0.0001
10000,0.0001
15000,0.0001
20000,0.0001
25000,0.0001
30000,0.0001


Saving model checkpoint to output\checkpoint-5000
Configuration saved in output\checkpoint-5000\config.json
Model weights saved in output\checkpoint-5000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-5000\tokenizer_config.json
Special tokens file saved in output\checkpoint-5000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-10000
Configuration saved in output\checkpoint-10000\config.json
Model weights saved in output\checkpoint-10000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-10000\tokenizer_config.json
Special tokens file saved in output\checkpoint-10000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-15000
Configuration saved in output\checkpoint-15000\config.json
Model weights saved in output\checkpoint-15000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-15000\tokenizer_config.json
Special tokens file saved in output\checkpoint-15000\special_tokens_map.json
Saving model check

In [23]:
import senteval
#import SentEval.senteval as senteval
#import SentEval_simcse.senteval as senteval
#import SentEval_simcse.senteval.engine as se_engine


def prepare(params, samples):
    return

def batcher(params, batch):
    sentences = [" ".join(s) for s in batch]
    batch = tokenizer.batch_encode_plus(
        sentences,
        return_tensors="pt",
        padding=True,
    )

    for k in batch:
        batch[k] = batch[k].to(device)

    with torch.no_grad():
        outputs = model2(**batch)

    pooled_result = outputs.pooler_output.cpu()

    return pooled_result


def evaluate_model():
    PATH_TO_DATA = "./data"

    params = {"task_path": PATH_TO_DATA, "usepytorch": True, "kfold": 10}
    tasks = ["STSBenchmark", 'STS12', 'STS13', 'STS14', 'STS15']

    se = senteval.engine.SE(params, batcher, prepare)
    #se = se_engine.SE(params, batcher, prepare)
    results = se.eval(tasks)

    print('STS12: ', results["STS12"]["all"]["spearman"]["all"])
    print('STS13: ', results["STS13"]["all"]["spearman"]["all"])
    print('STS14: ', results["STS14"]["all"]["spearman"]["all"])
    print('STS15: ', results["STS15"]["all"]["spearman"]["all"])
    print('STSB: ', results["STSBenchmark"]["test"]["spearman"][0])

    return results

In [24]:
model2.eval()

results2 = evaluate_model()
results2

  sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
  sent2 = np.array([s.split() for s in sent2])[not_empty_idx]


STS12:  0.5860086657732572
STS13:  0.7462740979049128
STS14:  0.6726454938423136
STS15:  0.7632500735789172
STSB:  0.7239425725655111


{'STSBenchmark': {'train': {'pearson': (0.7514974091222588, 0.0),
   'spearman': SpearmanrResult(correlation=0.728916173975978, pvalue=0.0),
   'nsamples': 5749},
  'dev': {'pearson': (0.7515928663596871, 5.587693491694605e-273),
   'spearman': SpearmanrResult(correlation=0.7571062733767996, pvalue=2.91924659407765e-279),
   'nsamples': 1500},
  'test': {'pearson': (0.7324146663602896, 3.863341707867918e-232),
   'spearman': SpearmanrResult(correlation=0.7239425725655111, pvalue=2.7955081361595227e-224),
   'nsamples': 1379},
  'all': {'pearson': {'all': 0.745320014694186,
    'mean': 0.7451683139474118,
    'wmean': 0.7484640391161609},
   'spearman': {'all': 0.7376444193730962,
    'mean': 0.7366550066394296,
    'wmean': 0.7330221722091952}}},
 'STS12': {'MSRpar': {'pearson': (0.6420318665513999, 2.2231580917506402e-88),
   'spearman': SpearmanrResult(correlation=0.6250375288907182, pvalue=1.5936298376149268e-82),
   'nsamples': 750},
  'MSRvid': {'pearson': (0.7821232084995746, 8.1

In [31]:
mlflow.end_run()

# 256 batch size

In [10]:
import mlflow

from transformers import Trainer, TrainingArguments
from transformers import default_data_collator

model3 = DistilBertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)

training_args3 = TrainingArguments(
    output_dir='output',
    overwrite_output_dir=True,
    learning_rate=5e-05,
    per_device_train_batch_size= 256,
    per_device_eval_batch_size = 256,
    weight_decay=0.0,
    num_train_epochs=2,
    max_steps= 30000,
    logging_steps=10000,
    save_steps=10000
)

model3.train()

trainer3 = Trainer(
    model=model3,
    args=training_args3,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertCLModel: ['vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertCLModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertCLModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
max_steps is given, it will override any value given in num_train_epochs


In [11]:
model_path = 'trained_model/distilbert_cl'

#train_result = trainer.train(model_path=model_path)
train_result3 = trainer3.train()
torch.save(model3, './batch256_model.pth')

***** Running training *****
  Num examples = 1000000
  Num Epochs = 8
  Instantaneous batch size per device = 256
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 1
  Total optimization steps = 30000


Step,Training Loss
10000,0.0005
20000,0.0004
30000,0.0003


Saving model checkpoint to output\checkpoint-10000
Configuration saved in output\checkpoint-10000\config.json
Model weights saved in output\checkpoint-10000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-10000\tokenizer_config.json
Special tokens file saved in output\checkpoint-10000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-20000
Configuration saved in output\checkpoint-20000\config.json
Model weights saved in output\checkpoint-20000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-20000\tokenizer_config.json
Special tokens file saved in output\checkpoint-20000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-30000
Configuration saved in output\checkpoint-30000\config.json
Model weights saved in output\checkpoint-30000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-30000\tokenizer_config.json
Special tokens file saved in output\checkpoint-30000\special_tokens_map.json


Training co

In [26]:
#mlflow.end_run()

In [12]:
model3.eval()

results3 = evaluate_model()
results3

  sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
  sent2 = np.array([s.split() for s in sent2])[not_empty_idx]


STS12:  0.5697409578579856
STS13:  0.6769343320004665
STS14:  0.6357083664225791
STS15:  0.744575188106335
STSB:  0.6320978112448171


{'STSBenchmark': {'train': {'pearson': (0.7108974767715875, 0.0),
   'spearman': SpearmanrResult(correlation=0.6957934629797473, pvalue=0.0),
   'nsamples': 5749},
  'dev': {'pearson': (0.7092411378218643, 1.0445253840454627e-229),
   'spearman': SpearmanrResult(correlation=0.7145470998781457, pvalue=1.0790202728598763e-234),
   'nsamples': 1500},
  'test': {'pearson': (0.6362564654216417, 2.357894834089188e-157),
   'spearman': SpearmanrResult(correlation=0.6320978112448171, pvalue=1.0318637044184543e-154),
   'nsamples': 1379},
  'all': {'pearson': {'all': 0.694542147265105,
    'mean': 0.6854650266716978,
    'wmean': 0.6986797596788475},
   'spearman': {'all': 0.692993221621729,
    'mean': 0.68081279136757,
    'wmean': 0.6888734527346301}}},
 'STS12': {'MSRpar': {'pearson': (0.5685692938261935, 1.922991733409069e-65),
   'spearman': SpearmanrResult(correlation=0.5768180019887325, pvalue=9.865649661539039e-68),
   'nsamples': 750},
  'MSRvid': {'pearson': (0.7231459307017963, 2.46

In [17]:
mlflow.end_run()

In [11]:
class DistilBertCLModel(DistilBertPreTrainedModel):
    def __init__(self, config, pooler_type='avg_first_last', temp=0.05):
        super().__init__(config)

        self.config = config
        self.pooler_type = pooler_type
        self.temp = 0.05

        self.distilbert = DistilBertModel(config)
        self.pooler = Pooler(pooler_type)
        self.sim = Similarity(temp=temp)

        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None):
        if self.training:
            return self.cl_forward(self.distilbert, input_ids, attention_mask)
        else:
            return self.sent_emb(self.distilbert, input_ids, attention_mask)

    def cl_forward(self, encoder, input_ids=None, attention_mask=None):
        batch_size = input_ids.size(0)#64#input_ids.size(0)
        num_sent = input_ids.size(1)  # Number of sentences in one instance: 2 sentences

        # Flatten all input tensors
        input_ids = input_ids.view((-1, input_ids.size(-1))) # (bs * num_sent, len)
        attention_mask = attention_mask.view((-1, attention_mask.size(-1))) # (bs * num_sent len)

        # Pre-trained Model Encoder
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )

        # Pooling
        pooler_output = self.pooler(attention_mask, outputs)
        pooler_output = pooler_output.view((batch_size, num_sent, pooler_output.size(-1)))  # (bs, num_sent, hidden)

        # Separate representation
        z1, z2 = pooler_output[:, 0], pooler_output[:, 1]

        # Cosine similarity
        cos_sim = self.sim(z1.unsqueeze(1), z2.unsqueeze(0))

        # Calculate contrastive loss
        criterion = nn.CrossEntropyLoss()
        labels = torch.arange(cos_sim.size(0)).long().to(self.device)
        loss = criterion(cos_sim, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=cos_sim,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def sent_emb(self, encoder, input_ids=None, attention_mask=None):
        outputs = encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        pooler_output = self.pooler(attention_mask, outputs)

        return BaseModelOutputWithPooling(
            pooler_output=pooler_output,
            last_hidden_state=outputs.last_hidden_state,
            hidden_states=outputs.hidden_states,
        )

In [15]:
import mlflow

from transformers import Trainer, TrainingArguments
from transformers import default_data_collator

model4 = DistilBertCLModel.from_pretrained(pretrained_model_name, config=config).to(device)

training_args4 = TrainingArguments(
    output_dir='output',
    overwrite_output_dir=True,
    learning_rate=5e-05,
    per_device_train_batch_size= 256,
    per_device_eval_batch_size = 256,
    weight_decay=0.0,
    num_train_epochs=2,
    max_steps= 30000,
    logging_steps=10000,
    save_steps=10000,
    fp16=True
)

model4.train()

trainer4 = Trainer(
    model=model4,
    args=training_args4,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    data_collator=default_data_collator
)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertCLModel: ['vocab_transform.weight', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertCLModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertCLModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
max_steps is given, it will override any value given in num_train_epochs
Using amp half precision backend


In [16]:
model_path = 'trained_model/distilbert_cl'

#train_result = trainer.train(model_path=model_path)
train_result4 = trainer4.train()
torch.save(model4, './batch256_16bit_model.pth')

***** Running training *****
  Num examples = 1000000
  Num Epochs = 8
  Instantaneous batch size per device = 256
  Total train batch size (w. parallel, distributed & accumulation) = 256
  Gradient Accumulation steps = 1
  Total optimization steps = 30000


Step,Training Loss
10000,0.0004
20000,0.0003
30000,0.0003


Saving model checkpoint to output\checkpoint-10000
Configuration saved in output\checkpoint-10000\config.json
Model weights saved in output\checkpoint-10000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-10000\tokenizer_config.json
Special tokens file saved in output\checkpoint-10000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-20000
Configuration saved in output\checkpoint-20000\config.json
Model weights saved in output\checkpoint-20000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-20000\tokenizer_config.json
Special tokens file saved in output\checkpoint-20000\special_tokens_map.json
Saving model checkpoint to output\checkpoint-30000
Configuration saved in output\checkpoint-30000\config.json
Model weights saved in output\checkpoint-30000\pytorch_model.bin
tokenizer config file saved in output\checkpoint-30000\tokenizer_config.json
Special tokens file saved in output\checkpoint-30000\special_tokens_map.json


Training co

In [17]:
#import senteval
#import SentEval.senteval as senteval
#import SentEval_simcse.senteval as senteval
#import SentEval_simcse.senteval.engine as se_engine


def prepare(params, samples):
    return

def batcher(params, batch):
    sentences = [" ".join(s) for s in batch]
    batch = tokenizer.batch_encode_plus(
        sentences,
        return_tensors="pt",
        padding=True,
    )

    for k in batch:
        batch[k] = batch[k].to(device)

    with torch.no_grad():
        outputs = model4(**batch)

    pooled_result = outputs.pooler_output.cpu()

    return pooled_result


def evaluate_model():
    PATH_TO_DATA = "./data"

    params = {"task_path": PATH_TO_DATA, "usepytorch": True, "kfold": 10}
    tasks = ["STSBenchmark", 'STS12', 'STS13', 'STS14', 'STS15']

    se = senteval.engine.SE(params, batcher, prepare)
    #se = se_engine.SE(params, batcher, prepare)
    results = se.eval(tasks)

    print('STS12: ', results["STS12"]["all"]["spearman"]["all"])
    print('STS13: ', results["STS13"]["all"]["spearman"]["all"])
    print('STS14: ', results["STS14"]["all"]["spearman"]["all"])
    print('STS15: ', results["STS15"]["all"]["spearman"]["all"])
    print('STSB: ', results["STSBenchmark"]["test"]["spearman"][0])

    return results

In [18]:
model4.eval()

results4 = evaluate_model()
results4

  sent1 = np.array([s.split() for s in sent1])[not_empty_idx]
  sent2 = np.array([s.split() for s in sent2])[not_empty_idx]


STS12:  0.6028650839372993
STS13:  0.740729745204717
STS14:  0.692465335997525
STS15:  0.779843190057093
STSB:  0.7226870749431175


{'STSBenchmark': {'train': {'pearson': (0.7654012355643011, 0.0),
   'spearman': SpearmanrResult(correlation=0.7482016388305586, pvalue=0.0),
   'nsamples': 5749},
  'dev': {'pearson': (0.7670311528467801, 5.261345910632041e-291),
   'spearman': SpearmanrResult(correlation=0.7708464868002052, pvalue=1.119990676071557e-295),
   'nsamples': 1500},
  'test': {'pearson': (0.7318742172557201, 1.2511244745112543e-231),
   'spearman': SpearmanrResult(correlation=0.7226870749431175, pvalue=3.856394414212571e-223),
   'nsamples': 1379},
  'all': {'pearson': {'all': 0.7581791623069682,
    'mean': 0.7547688685556003,
    'wmean': 0.7603260289899136},
   'spearman': {'all': 0.7518720772834617,
    'mean': 0.7472450668579604,
    'wmean': 0.7480605503226412}}},
 'STS12': {'MSRpar': {'pearson': (0.6638885051937329, 1.767074037152879e-96),
   'spearman': SpearmanrResult(correlation=0.6412098115217912, pvalue=4.352737200439051e-88),
   'nsamples': 750},
  'MSRvid': {'pearson': (0.7927758724538693, 5.