In [53]:
import datasets
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
import numpy as np
import torch

In [4]:
ds = datasets.load_dataset('snli')
ds.save_to_disk('data/snli')

Saving the dataset (1/1 shards): 100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 142919.10 examples/s]
Saving the dataset (1/1 shards): 100%|█████████████████████████████████| 10000/10000 [00:00<00:00, 188978.58 examples/s]
Saving the dataset (1/1 shards): 100%|███████████████████████████████| 550152/550152 [00:01<00:00, 495085.25 examples/s]


In [None]:
model_name = 'bert-base-uncased'
token = AutoTokenizer.from_pretrained(model_name)
token.save_pretrained('data/token')
model = AutoModel.from_pretrained(model_name)
model.save_pretrained('data/bert_normal')

In [6]:
model = AutoModelForSequenceClassification.from_pretrained(model_name)
model.save_pretrained('data/bert_class')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [62]:
tokenizer = AutoTokenizer.from_pretrained('data/token')
headless = AutoModel.from_pretrained('data/bert_normal')
classifier = AutoModelForSequenceClassification.from_pretrained('data/bert_class')

In [26]:
ds = datasets.load_from_disk('data/snli')

In [9]:
def encode(examples):
        return tokenizer(examples["premise"], examples["hypothesis"], truncation=True)

In [41]:
dataset = ds['train']

In [28]:
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 550152
})

In [29]:
np.unique(dataset['label'])

array([-1,  0,  1,  2])

In [42]:
dataset = dataset.select(np.arange(1000))

In [43]:
dataset = dataset.filter(lambda batch: np.array(batch["label"]) != -1, batched=True)
dataset = dataset.map(encode, batched=True)
dataset = dataset.map(lambda examples: {"labels": examples["label"]}, batched=True)
dataset.set_format(
    type="torch",
    columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
)

Filter: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 10902.57 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████| 998/998 [00:00<00:00, 3377.54 examples/s]
Map: 100%|██████████████████████████████████████████████████████████████████| 998/998 [00:00<00:00, 23568.02 examples/s]


In [47]:
dataset[0]

{'input_ids': tensor([  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
          2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
          2005,  1037,  2971,  1012,   102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]),
 'labels': tensor(1)}

In [51]:
(dataset.remove_columns('labels')[0])

{'input_ids': tensor([  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
          2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
          2005,  1037,  2971,  1012,   102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1])}

In [63]:
headless.eval()
classifier.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [64]:
with torch.no_grad():
    print(headless(**(dataset.remove_columns('labels')[0:1])))

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.8701,  0.1044, -0.8448,  ..., -0.7244,  0.2107,  0.1728],
         [-0.8112,  0.1907, -0.9818,  ..., -0.3051,  0.4335,  0.0024],
         [-1.1000, -0.3378, -0.5038,  ..., -0.3010,  0.1676, -0.4993],
         ...,
         [ 0.2257, -0.6626,  0.3625,  ..., -0.3866, -0.8420, -0.8623],
         [ 0.6308,  0.0563, -0.5607,  ..., -0.0981, -0.4543, -0.2846],
         [ 0.6324,  0.0537, -0.5595,  ..., -0.0900, -0.4814, -0.2747]]]), pooler_output=tensor([[-0.9884, -0.9076, -0.9998,  0.9898,  0.9931, -0.8266,  0.9948,  0.8217,
         -0.9989, -1.0000, -0.9813,  0.9996,  0.9930,  0.9704,  0.9863, -0.9851,
         -0.9729, -0.9264,  0.8106, -0.9316,  0.9574,  1.0000, -0.8754,  0.8351,
          0.9131,  1.0000, -0.9784,  0.9803,  0.9808,  0.8880, -0.9737,  0.7851,
         -0.9968, -0.7635, -0.9998, -0.9993,  0.9240, -0.9121, -0.7272, -0.6074,
         -0.9706,  0.8622,  1.0000,  0.6741,  0.9449, -0.8274, -1.0000,  0.

In [66]:
with torch.no_grad():
    print(classifier(**(dataset.remove_columns('labels')[0:1]),output_hidden_states=True))

SequenceClassifierOutput(loss=None, logits=tensor([[0.0348, 0.1623]]), hidden_states=(tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.5956,  0.5420,  0.0412,  ...,  0.4376,  0.5639,  0.3365],
         [-0.5427,  0.3790, -0.4397,  ...,  0.5385,  1.1806, -0.9872],
         ...,
         [ 0.6080,  0.1944, -0.4991,  ...,  0.1576,  0.0924,  0.0526],
         [-0.1864,  0.2774, -0.2251,  ...,  0.7510,  0.4162,  0.3754],
         [-0.4083, -0.0742, -0.2198,  ..., -0.0169,  0.0754, -0.1792]]]), tensor([[[ 0.0145,  0.0151, -0.2307,  ...,  0.2228, -0.1197, -0.0114],
         [ 0.7825,  0.7195,  0.0461,  ...,  0.5636,  0.5113,  0.3183],
         [-0.5592,  0.4928, -0.4244,  ..., -0.2110,  0.5988, -0.9690],
         ...,
         [ 0.8251,  0.8316, -0.2125,  ...,  0.2319, -0.2733, -0.4672],
         [-0.1086,  0.1292, -0.1864,  ...,  0.3960,  0.1578,  0.1404],
         [-0.1343,  0.0762, -0.0862,  ...,  0.1826,  0.3630, -0.0585]]]), tensor([[[-0.1319, -0.2054, 

In [86]:
dataset[0]

{'input_ids': tensor([  101,  1037,  2711,  2006,  1037,  3586, 14523,  2058,  1037,  3714,
          2091, 13297,  1012,   102,  1037,  2711,  2003,  2731,  2010,  3586,
          2005,  1037,  2971,  1012,   102]),
 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1]),
 'labels': tensor(1)}

In [101]:
def get_representation(examples):
    # print(examples)
    with torch.no_grad():
        return {'z': headless(**tokenizer(examples['premise'],examples['hypothesis'],return_tensors='pt',truncation=True,padding=True)).pooler_output}

In [115]:
dataset = dataset.map(get_representation,batched=True)

In [121]:
np.where(dataset['labels']==0)[0]

array([  2,   4,   7,  14,  17,  18,  23,  25,  29,  31,  35,  37,  40,
        43,  44,  45,  49,  52,  56,  59,  60,  63,  67,  70,  74,  77,
        79,  83,  84,  88,  91,  96,  99, 103, 104, 106, 109, 112, 116,
       117, 122, 125, 127, 131, 133, 135, 138, 142, 144, 148, 150, 154,
       157, 158, 162, 165, 168, 171, 172, 173, 175, 183, 186, 187, 190,
       194, 198, 202, 205, 208, 210, 214, 216, 218, 222, 223, 227, 231,
       232, 237, 241, 242, 246, 248, 250, 257, 261, 262, 265, 267, 270,
       272, 277, 279, 282, 284, 288, 293, 294, 299, 300, 302, 305, 306,
       308, 312, 318, 322, 323, 329, 333, 334, 335, 338, 343, 346, 349,
       351, 355, 357, 360, 363, 367, 368, 371, 374, 378, 381, 385, 388,
       391, 393, 397, 400, 401, 405, 408, 411, 414, 417, 420, 422, 425,
       428, 432, 434, 437, 442, 443, 448, 450, 451, 454, 455, 458, 461,
       465, 469, 473, 477, 480, 483, 486, 490, 493, 495, 499, 500, 504,
       506, 510, 512, 516, 519, 520, 522, 525, 531, 535, 537, 54

In [117]:
dataset.select(np.where(dataset['labels']==1)[0])['z']

tensor([[-0.9884, -0.9076, -0.9998,  ..., -0.9965, -0.9490,  0.9859],
        [-0.9822, -0.8889, -0.9999,  ..., -0.9971, -0.9156,  0.9796],
        [-0.9899, -0.9047, -0.9997,  ..., -0.9913, -0.9477,  0.9897],
        ...,
        [-0.9829, -0.8738, -0.9998,  ..., -0.9962, -0.9162,  0.9777],
        [-0.9935, -0.9265, -0.9999,  ..., -0.9970, -0.9585,  0.9931],
        [-0.9811, -0.8282, -0.9977,  ..., -0.9596, -0.9331,  0.9816]])

In [118]:
torch.mean(dataset.select(np.where(dataset['labels']==1)[0])['z'],axis=0)

tensor([-0.9724, -0.8671, -0.9867,  0.9667,  0.9691, -0.7423,  0.9770,  0.7845,
        -0.9811, -0.9882, -0.9270,  0.9860,  0.9881,  0.9449,  0.9792, -0.9513,
        -0.9219, -0.8780,  0.7194, -0.8989,  0.9415,  0.9880, -0.8232,  0.7591,
         0.8617,  0.9870, -0.9500,  0.9676,  0.9758,  0.8635, -0.9409,  0.7241,
        -0.9941, -0.7040, -0.9868, -0.9904,  0.8789, -0.8784, -0.6542, -0.5599,
        -0.9595,  0.7716,  0.9880,  0.6509,  0.8973, -0.7636, -0.9880,  0.7787,
        -0.9456,  0.9857,  0.9840,  0.9829,  0.7463,  0.8924,  0.8858, -0.8915,
         0.5963,  0.6994, -0.7453, -0.9108, -0.8574,  0.8161, -0.9843, -0.9525,
         0.9845,  0.9814, -0.7861, -0.7937, -0.7410,  0.5514,  0.9772,  0.7374,
        -0.7262, -0.9265,  0.9782,  0.8218, -0.8628,  0.9880, -0.9105, -0.9882,
         0.9842,  0.9817,  0.8601, -0.9563,  0.9445, -0.9880,  0.9293, -0.6590,
        -0.9927,  0.7371,  0.9310, -0.7772,  0.9698,  0.8797, -0.9228, -0.9150,
        -0.8631, -0.9820, -0.7904, -0.88

In [123]:
def calculate_class_center(dataset):
    centers = {}
    for c in np.unique(dataset['labels']):
        centers[c] = torch.mean(dataset.select(np.where(dataset['labels']==c)[0])['z'],axis=0)
    return centers

In [125]:
centers = calculate_class_center(dataset)

In [129]:
cc = torch.stack([centers[c.item()] for c in dataset[0:10]['labels']])

In [130]:
cc

tensor([[-0.9724, -0.8671, -0.9867,  ..., -0.9722, -0.9231,  0.9788],
        [-0.9448, -0.8105, -0.9530,  ..., -0.9239, -0.8883,  0.9686],
        [-0.9772, -0.8664, -0.9912,  ..., -0.9740, -0.9257,  0.9812],
        ...,
        [-0.9772, -0.8664, -0.9912,  ..., -0.9740, -0.9257,  0.9812],
        [-0.9724, -0.8671, -0.9867,  ..., -0.9722, -0.9231,  0.9788],
        [-0.9724, -0.8671, -0.9867,  ..., -0.9722, -0.9231,  0.9788]])

In [138]:
def get_dist(examples):
    cc = torch.stack([centers[c.item()] for c in examples['labels']])
    return {'dist': torch.linalg.norm(examples['z'] - cc,dim=1)}

In [139]:
dataset = dataset.map(get_dist,batched=True)

Map: 100%|███████████████████████████████████████████████████████████████████| 998/998 [00:00<00:00, 1517.78 examples/s]


In [141]:
dataset['dist'].median()

tensor(1.6096)

In [151]:
torch.argsort(dataset['dist'])

tensor([861,  56, 933, 312, 936, 993, 307,  93, 929,  12, 171, 576, 210, 503,
        962, 537, 853, 388, 530,  53, 354, 701,  71, 265, 357, 250, 306, 367,
        968, 209, 198, 378, 191, 116, 291, 631, 458, 872, 305, 945, 104, 302,
        970,  40, 661, 162,  89, 418, 680, 289, 922, 826,  29, 924, 316, 205,
        350, 677, 967, 543, 419, 230, 117, 828, 214, 166, 238, 601, 451, 359,
        674, 261, 889, 926, 798, 202, 642, 975, 930, 859, 897, 624, 915, 645,
         63,  84, 170, 262, 896, 594, 882, 901, 297, 989, 118, 536, 183, 583,
        766, 393, 442, 847, 587, 365, 374, 258, 430, 479, 712, 637, 126, 473,
        165, 149, 839, 285, 996, 820, 288, 448,  45, 484, 932,  60, 525, 893,
        110, 633, 389,  30, 831, 401, 356, 635, 483, 443, 698, 586, 469,  57,
        618, 688, 628, 876, 703, 921, 445, 925, 848, 978,  52, 540, 142,  88,
        606, 883, 641, 404, 127, 776, 775, 822,   8,  34, 362, 405, 420, 836,
        369, 434, 480, 403, 160, 600, 386,  35,  44, 630,  15, 1

In [149]:
def final_select(dataset,fraction):
    N = len(dataset)
    indices = torch.argsort(dataset['dist'])
    start = int(N*(1-fraction)/2)
    end = int(N - N*(1-fraction)/2)+1
    return indices[start:end]

In [156]:
len(final_select(dataset,1))

998

In [None]:
import copy
def moderate(dataset,fraction):
    dataset = copy.deepcopy(dataset)
    dataset = dataset.map(get_representation,batched=True)
    centers = calculate_class_center(dataset)
    dataset = dataset.map(get_dist,batched=True)
    return final_select(dataset,1)