In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from sklearn.model_selection import train_test_split
from torch.optim import Adam
import torch.nn.functional as F
from keras.preprocessing.sequence import pad_sequences

Using TensorFlow backend.


In [2]:
train_df = pd.read_json(r'../data/imdb/train.json').drop('text_b', axis=1)
test_df = pd.read_json(r'../data/imdb/test.json').drop('text_b', axis=1)

In [3]:
train_df = train_df[:100]
test_df = test_df[:100]
train_text = train_df.text_a.values
train_label = [1 if i=='pos' else 0 for i in train_df.label.values]
test_text = test_df.text_a.values
test_label = [1 if i=='pos' else 0 for i in test_df.label.values]

In [4]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [5]:
def tokenize_data(data):
    ids = []
    for d in data:
        encoded = tokenizer.encode(d, add_special_tokens=True)
        ids.append(encoded)
        
    print('Max sentence length: ', max([len(sen) for sen in ids]))
    ids = pad_sequences(ids, maxlen=512, dtype="long", 
                        value=0, truncating="post", padding="post") 
    
    attention_masks = []
    n = 0
    for i in ids:
        temp = [float(t>0) for t in i]
        attention_masks.append(temp)
        n += 1
    print('number:'+str(n))
    return ids, attention_masks

In [6]:
train_ids, train_masks = tokenize_data(train_text)

Token indices sequence length is longer than the specified maximum sequence length for this model (514 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1135 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (587 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (886 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (803 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for th

Max sentence length:  1361
number:100


In [7]:
print(len(train_masks))

100


In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print('Using GPU:', torch.cuda.get_device_name(0))
else:
    print('Using CPU')
    device = torch.device("cpu")

# Using 'BertForSequenceClassification'
model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2,
        output_attentions = False, output_hidden_states = False).to(device)
optimizer = Adam(model.parameters(), lr=1e-6)

Using GPU: GeForce RTX 2080 Ti


In [11]:
SEED = 2020
train_inputs, val_inputs, train_labels, val_labels = train_test_split(train_ids, train_label, 
                                                            random_state=SEED, test_size=0.1)
train_masks, val_masks, _, _ = train_test_split(train_masks, train_ids,
                                             random_state=SEED, test_size=0.1)

train_inputs = torch.tensor(train_inputs)
val_inputs = torch.tensor(val_inputs)
train_labels = torch.tensor(train_labels)
val_labels = torch.tensor(val_labels)
train_masks = torch.tensor(train_masks)
val_masks = torch.tensor(val_masks)

In [12]:
BATCH_SIZE = 8

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)
val_data = TensorDataset(val_inputs, val_masks, val_labels)
val_sampler = SequentialSampler(val_data)
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=BATCH_SIZE)

In [11]:
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [12]:
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds.max(1)[1].view(y.size()).data 
               == y.data).float() #convert into float for division 
    acc = correct.sum()/len(correct)
    return acc

In [13]:
def train(model, optimizer, dataloader):
    epoch_loss = 0
    epoch_acc = 0
    total_len = 0
    
    model.train()
    
    for step, batch in enumerate(dataloader):
        # batch: [ids, mask, label]
        b_temp = tuple(b.to(device) for b in batch)
        b_ids, b_mask, b_labels = b_temp
        
        optimizer.zero_grad()
        
        outputs = model(b_ids, token_type_ids=None,
                      attention_mask=b_mask, labels=b_labels)
        loss = outputs[0]
        logits = outputs[1]
        
        pred = torch.argmax(F.softmax(logits), dim=1)
        acc = pred.eq(b_labels).sum().item()
        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc
        
        total_len += len(b_labels)
    return epoch_loss/total_len, epoch_acc/total_len

In [14]:
def evaluate(model, dataloader):
    model.eval()
    
    epoch_loss = 0
    epoch_acc = 0
    total_len = 0
    
    for step, batch in enumerate(dataloader):
        # batch: [ids, mask, label]
        b_temp = tuple(b.to(device) for b in batch)
        b_ids, b_mask, b_labels = b_temp
        
        with torch.no_grad():
            outputs = model(b_ids, token_type_ids=None,
                          attention_mask=b_mask)

        logits = outputs[0]#.detach.cpu().numpy()
        pred = torch.argmax(F.softmax(logits), dim=1)
        acc = pred.eq(b_labels).sum().item()
        
        epoch_acc += acc
        
        total_len += len(b_labels)
    return epoch_acc / total_len

In [15]:
import time
EPOCHES = 5

train_loss = []
for epoch in range(EPOCHES):
    start_time = time.time()
    print('======== Epoch {:} / {:} ========'.format(epoch + 1, EPOCHES))
    
    train_loss, train_acc = train(model, optimizer, train_dataloader)
    val_accuracy = evaluate(model, val_dataloader)
    
    print("Train loss: %.3f | Train acc: %.2f | Val accuracy : %5.2f | Time: %f" 
          %(train_loss,train_acc,val_accuracy, (time.time() - start_time)/60))
    





Train loss: 0.045 | Train acc: 0.85 | Val accuracy :  0.91 | Time: 17.768819
Train loss: 0.023 | Train acc: 0.93 | Val accuracy :  0.92 | Time: 17.684740
Train loss: 0.020 | Train acc: 0.94 | Val accuracy :  0.92 | Time: 17.676667
Train loss: 0.018 | Train acc: 0.95 | Val accuracy :  0.92 | Time: 17.720844


In [9]:
test_ids, test_masks = tokenize_data(test_text)

Token indices sequence length is longer than the specified maximum sequence length for this model (1324 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (545 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (940 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (640 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (571 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for th

Max sentence length:  1326
number:100


In [13]:
test_inputs = torch.tensor(test_ids)
test_labels = torch.tensor(test_label)
test_masks = torch.tensor(test_masks)

test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=BATCH_SIZE)

  This is separate from the ipykernel package so we can avoid doing imports until


In [18]:
test_acc = evaluate(model, test_dataloader)
print("Test acc: %.3f" %(test_acc))



Test acc: 0.930


In [14]:
model = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2,
        output_attentions = True, output_hidden_states = True).to(device)

In [15]:
def get_vectors(model, dataloader):
    model.eval()
    
    epoch_loss = 0
    epoch_acc = 0
    total_len = 0
    preds=None
    for step, batch in enumerate(dataloader):
        # batch: [ids, mask, label]
        b_temp = tuple(b.to(device) for b in batch)
        b_ids, b_mask, b_labels = b_temp
        
        with torch.no_grad():
            outputs = model(b_ids, token_type_ids=None,
                          attention_mask=b_mask, labels=None)

        tmp_eval_loss, logits = outputs[:2]
        
    return outputs

In [16]:
outputs = get_vectors(model, test_dataloader)

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
           8.9008e-01,  1.6575e-01],
         [ 6.9938e-01, -3.5741e-01,  7.5132e-02,  ..., -5.3589e-01,
           2.1940e-01, -1.5962e+00],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-4.2141e-01, -8.3641e-02, -5.2981e-01,  ...,  6.4762e-01,
           2.2886e-01, -2.4024e-01],
         [ 4.6706e-04,  1.6225e-01, -6.4443e-02,  ...,  4.9443e-01,
           6.9413e-01,  3.6286e-01],
         ...,
         [ 9.2648e-01

13 (tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.1036, -0.5528,  0.0397,  ...,  0.2777, -0.3384,  0.5302],
         [ 0.3472,  0.2552,  0.5450,  ...,  0.5794,  0.8818,  0.3089],
         ...,
         [ 1.3832,  0.5990,  0.4331,  ..., -0.3166,  0.9318, -0.9913],
         [ 1.8774,  0.4274,  0.4806,  ..., -0.1606, -0.2800, -0.7667],
         [-0.2793,  0.2133, -0.8499,  ...,  0.6383,  1.2285, -1.4156]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.0923,  0.3753,  0.6408,  ..., -0.3404, -0.9103, -0.1868],
         [-0.4027,  0.4038,  0.2507,  ...,  0.5666, -0.9247, -0.1277],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.6485,  0.6739, -0.0932,  ...

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-1.4889e-01,  5.3714e-01, -5.6871e-01,  ..., -4.4369e-02,
          -4.1481e-01,  6.8564e-01],
         [-2.5807e-01, -2.0038e-01,  7.8156e-02,  ...,  9.2702e-01,
           4.3151e-01, -1.5860e+00],
         ...,
         [-2.8964e-01,  6.2930e-01,  1.2535e-01,  ..., -4.7787e-01,
           2.1157e-01, -1.0309e+00],
         [ 8.4472e-01,  3.7899e-01,  6.4288e-01,  ...,  4.0005e-01,
           7.1707e-02, -5.8268e-01],
         [ 2.5478e-01, -2.0126e-01,  2.1968e-01,  ...,  5.5538e-01,
           8.7054e-01, -9.1801e-01]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4061e-01,  7.0248e-01, -6.4828e-01,  ...,  2.2401e-01,
           7.5120e-01,  2.3857e-01],
         [-6.2884e-01,  4.4862e-01,  6.3140e-01,  ...,  5.9740e-01,
           5.3136e-01,  2.4762e-01],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.2832,  1.0031,  0.7676,  ...,  0.5968, -0.1371, -0.2700],
         [-0.6652,  0.2168,  0.2186,  ...,  0.4740,  0.6926, -1.0691],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.6485,  0.6739, -0.0932,  ...,  0.4475,  0.6696,  0.1820],
         [ 0.0565,  0.3126, -0.2331,  ...,  0.0136,  0.6972, -0.8789],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.6692,  0.6162, -1.0241,  ...

13 (tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.6485,  0.6739, -0.0932,  ...,  0.4475,  0.6696,  0.1820],
         [-0.6270, -0.0633, -0.3143,  ...,  0.3427,  0.4636,  0.4594],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.5467,  0.3301, -0.9227,  ...,  0.9150,  0.8351, -0.2478],
         [ 0.3753,  0.5709, -0.9220,  ...,  0.4509,  0.8474,  0.3426],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.5179,  0.5018, -0.1985,  ...

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 2.3410e-01,  4.4271e-01, -1.3336e-01,  ...,  7.3032e-01,
           9.8038e-01, -4.1566e-01],
         [ 4.3117e-01,  6.7826e-02,  5.9432e-01,  ...,  1.1861e-01,
          -2.4069e-01, -2.0226e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
           8.9008e-01,  1.6575e-01],
         [-7.6046e-01,  2.3549e-01,  3.3033e-01,  ..., -3.0810e-02,
          -8.5563e-01,  4.4695e-02],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-4.3670e-01,  5.3602e-01, -5.1413e-02,  ..., -3.9741e-02,
           6.7825e-01, -5.3183e-01],
         [-1.8317e+00,  6.3027e-01, -9.1186e-01,  ...,  8.8637e-02,
          -2.4055e-01, -6.6227e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-5.7257e-01,  5.0150e-01, -3.4027e-01,  ...,  1.0698e+00,
           7.2472e-01, -6.1854e-01],
         [-4.1489e-01, -1.9865e-01,  1.8604e-01,  ..., -3.8843e-01,
           4.9132e-01, -6.9791e-01],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 1.2749, -0.6534,  0.5339,  ..., -0.0148,  0.2415, -1.4490],
         [-0.6652,  0.2168,  0.2186,  ...,  0.4740,  0.6926, -1.0691],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.4605,  0.1282,  0.2784,  ...,  0.3820, -0.3786, -0.1164],
         [ 0.2908, -0.1538, -0.3800,  ...,  0.3538,  0.8016,  0.3832],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.4227, -0.0289, -0.1456,  ...

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.3402e-01, -9.9856e-02, -4.0651e-01,  ...,  1.8330e-01,
           7.1556e-01, -1.1127e+00],
         [ 5.3030e-02,  3.2892e-01,  1.9099e-01,  ..., -3.8187e-01,
           2.2459e-01,  2.8507e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 6.1660e-01, -2.5942e-01, -4.2591e-01,  ...,  6.6816e-01,
           7.8973e-01,  1.0660e-01],
         [-4.1808e-01,  3.3958e-01, -8.7880e-01,  ...,  3.3994e-01,
           7.1716e-01, -2.7936e-01],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
           8.9008e-01,  1.6575e-01],
         [ 1.1558e+00,  8.5331e-02, -1.1208e-01,  ...,  4.3965e-01,
           8.5903e-01, -3.2685e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 6.5574e-01,  9.5841e-01, -3.8884e-01,  ..., -2.7064e-03,
           1.3282e-01,  3.8959e-01],
         [-7.1147e-01,  2.7876e-01,  1.3426e-02,  ..., -1.0553e-01,
          -3.6994e-01,  3.2157e-01],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.5812,  0.6695,  0.1027,  ...,  0.2468,  0.3671,  0.1771],
         [-0.6860,  0.4399,  0.2051,  ...,  0.4881,  0.3567,  0.2992],
         ...,
         [-0.6177, -1.0351, -0.8221,  ...,  0.7175,  0.0997,  0.3623],
         [ 0.2145, -0.0538,  0.8151,  ..., -0.1228, -0.9260, -0.5776],
         [ 0.2548, -0.2013,  0.2197,  ...,  0.5554,  0.8705, -0.9180]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [-0.6485,  0.6739, -0.0932,  ...,  0.4475,  0.6696,  0.1820],
         [-1.2613, -0.2880,  0.3104,  ..., -0.6596,  1.0511, -0.9745],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],

        [[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.1992,  0.8737, -0.5840,  ...

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
           8.9008e-01,  1.6575e-01],
         [-4.8563e-01, -5.2378e-01,  2.9687e-01,  ...,  1.9118e-01,
          -6.7089e-02,  3.2029e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 6.6921e-01,  6.1623e-01, -1.0241e+00,  ...,  6.2799e-01,
           9.8995e-01,  3.6123e-01],
         [-4.8563e-01, -5.2378e-01,  2.9687e-01,  ...,  1.9118e-01,
          -6.7089e-02,  3.2029e-01],
         ...,
         [ 6.4577e-01

13 (tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [ 6.6864e-01,  1.1354e-01, -9.0192e-01,  ..., -1.5069e-01,
           5.9889e-01, -7.2968e-01],
         [-6.2703e-01, -6.3313e-02, -3.1428e-01,  ...,  3.4265e-01,
           4.6361e-01,  4.5937e-01],
         ...,
         [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
          -3.4658e-01, -4.8250e-01],
         [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
          -6.0378e-01, -4.9350e-01],
         [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
           2.1131e-01, -1.5097e+00]],

        [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
           3.8253e-02,  1.6400e-01],
         [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
           8.9008e-01,  1.6575e-01],
         [-3.1616e-01, -2.1233e-01, -3.5224e-01,  ..., -9.2020e-02,
          -1.3128e-01,  5.3488e-01],
         ...,
         [ 6.4577e-01

In [18]:
tmp_eval_loss, logits = outputs[:2]

In [19]:
logits

(tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
            3.8253e-02,  1.6400e-01],
          [ 6.6864e-01,  1.1354e-01, -9.0192e-01,  ..., -1.5069e-01,
            5.9889e-01, -7.2968e-01],
          [-6.2703e-01, -6.3313e-02, -3.1428e-01,  ...,  3.4265e-01,
            4.6361e-01,  4.5937e-01],
          ...,
          [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
           -3.4658e-01, -4.8250e-01],
          [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
           -6.0378e-01, -4.9350e-01],
          [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
            2.1131e-01, -1.5097e+00]],
 
         [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
            3.8253e-02,  1.6400e-01],
          [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
            8.9008e-01,  1.6575e-01],
          [-3.1616e-01, -2.1233e-01, -3.5224e-01,  ..., -9.2020e-02,
           -1.3128e-01,  5.3488e-01],
          ...,
    

In [28]:
len(logits[0]), logits[0]

(4, tensor([[[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
            3.8253e-02,  1.6400e-01],
          [ 6.6864e-01,  1.1354e-01, -9.0192e-01,  ..., -1.5069e-01,
            5.9889e-01, -7.2968e-01],
          [-6.2703e-01, -6.3313e-02, -3.1428e-01,  ...,  3.4265e-01,
            4.6361e-01,  4.5937e-01],
          ...,
          [ 6.4577e-01, -5.4087e-01, -1.7797e-01,  ..., -4.8202e-02,
           -3.4658e-01, -4.8250e-01],
          [ 7.4169e-01, -7.2706e-01,  3.2783e-01,  ..., -2.0112e-01,
           -6.0378e-01, -4.9350e-01],
          [ 2.9921e-01, -1.0338e+00,  1.2938e-01,  ...,  2.1486e-01,
            2.1131e-01, -1.5097e+00]],
 
         [[ 1.6855e-01, -2.8577e-01, -3.2613e-01,  ..., -2.7571e-02,
            3.8253e-02,  1.6400e-01],
          [-3.4026e-04,  5.3974e-01, -2.8805e-01,  ...,  7.5731e-01,
            8.9008e-01,  1.6575e-01],
          [-3.1616e-01, -2.1233e-01, -3.5224e-01,  ..., -9.2020e-02,
           -1.3128e-01,  5.3488e-01],
          ...,
 

In [22]:
len(logits[0][0]), logits[0][0]

(512, tensor([[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
         [ 0.6686,  0.1135, -0.9019,  ..., -0.1507,  0.5989, -0.7297],
         [-0.6270, -0.0633, -0.3143,  ...,  0.3427,  0.4636,  0.4594],
         ...,
         [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
         [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
         [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],
        device='cuda:0'))

In [25]:
len(logits[0][0][1]), logits[0][0][1]


(768, tensor([ 6.6864e-01,  1.1354e-01, -9.0192e-01, -6.1636e-01,  6.3789e-01,
          1.7374e-01,  1.6871e-01, -1.4412e+00,  1.2920e-01, -5.4721e-01,
         -1.0864e+00,  5.9365e-01,  5.3690e-02,  8.5018e-01, -1.2522e-01,
         -6.5338e-01,  1.4328e-01,  9.8830e-02,  7.0312e-01, -7.3852e-01,
          5.1136e-01, -9.4471e-02, -9.4006e-01,  7.6083e-01, -4.8157e-01,
         -1.3132e+00, -1.3942e+00,  3.7538e-01,  3.4867e-01,  3.8927e-01,
          6.0180e-01, -1.3353e+00, -1.0782e-01, -8.1453e-02, -6.5663e-01,
         -4.0626e-01,  1.9167e-01,  5.1021e-02,  4.2037e-03,  1.1238e+00,
         -1.0301e-01, -8.5701e-01,  3.7425e-01, -3.8552e-01,  5.8815e-01,
         -2.8987e-01, -1.1858e-01,  3.0727e-01, -6.1688e-01,  5.6427e-01,
         -8.6077e-01, -2.2652e-01, -1.3510e+00,  8.2907e-02,  3.4470e-01,
         -2.9087e-01,  2.3631e-03, -9.2381e-01,  2.2390e-01,  3.0528e-01,
          4.7683e-02,  4.1083e-01, -2.8994e-01, -3.4775e-01, -5.1602e-01,
          3.3162e-01, -2.5993e-01

In [30]:
for step, batch in enumerate(test_dataloader):
    # batch: [ids, mask, label]
    b_temp = tuple(b.to(device) for b in batch)
    b_ids, b_mask, b_labels = b_temp
    print(len(b_ids[0]))
    print(b_ids)

512
tensor([[  101,  1045,  2018,  ...,     0,     0,     0],
        [  101, 25665,  1029,  ...,     0,     0,     0],
        [  101,  2023,  3185,  ...,     0,     0,     0],
        ...,
        [  101,  2023,  2003,  ...,     0,     0,     0],
        [  101,  1000, 25193,  ...,     0,     0,     0],
        [  101,  2023,  2143,  ...,     0,     0,     0]], device='cuda:0')
512
tensor([[  101,  1000,  4028,  ...,     0,     0,     0],
        [  101,  2054,  2003,  ..., 25593,  1010,  2005],
        [  101,  2023,  3185,  ...,     0,     0,     0],
        ...,
        [  101,  2339,  2006,  ...,     0,     0,     0],
        [  101,  2009,  1005,  ...,     0,     0,     0],
        [  101,  2026,  2814,  ...,     0,     0,     0]], device='cuda:0')
512
tensor([[  101,  2074,  2004,  ...,  2486,  2017,  2000],
        [  101,  1045,  4149,  ...,     0,     0,     0],
        [  101,  2073,  2000,  ...,     0,     0,     0],
        ...,
        [  101,  2339,  2001,  ...,     0, 

In [34]:
a = logits[:][:][0]

tensor([[ 0.1686, -0.2858, -0.3261,  ..., -0.0276,  0.0383,  0.1640],
        [ 0.6686,  0.1135, -0.9019,  ..., -0.1507,  0.5989, -0.7297],
        [-0.6270, -0.0633, -0.3143,  ...,  0.3427,  0.4636,  0.4594],
        ...,
        [ 0.6458, -0.5409, -0.1780,  ..., -0.0482, -0.3466, -0.4825],
        [ 0.7417, -0.7271,  0.3278,  ..., -0.2011, -0.6038, -0.4935],
        [ 0.2992, -1.0338,  0.1294,  ...,  0.2149,  0.2113, -1.5097]],
       device='cuda:0')