## 1. Load dataset

In [1]:
from os.path import join
import pandas as pd


def relabeler(label):
    """Map an NLI label string to its integer class id.

    Args:
        label: Raw label value, expected to be ``'entailment'``,
            ``'neutral'`` or ``'contradiction'``.

    Returns:
        ``0`` for entailment, ``1`` for neutral, ``2`` for contradiction and
        ``None`` for anything else.
    """
    if label == 'entailment':
        return 0
    if label == 'neutral':
        return 1
    if label == 'contradiction':
        return 2
    # Unknown label: return None explicitly. The original ended on a bare
    # `None` expression, which did nothing and looked like a missing `return`.
    return None


def _load(fname, genre=None, dire='data/data135628/'):
    df = pd.read_csv(join(dire, fname))
    if genre:
        genres = ['slate', 'government', 'telephone',
                  'travel', 'fiction']

        if isinstance(genre, str):
            genre = [genre]
        elif isinstance(genre, (list, tuple)):
            assert all(isinstance(g, str) for g in genre)
        else:
            raise TypeError("genre must be a str or a list of str")

        assert all(g in genres for g in genre), f'genre must be in {genres}'
        df = df[df.isin(genre)['genre']]

    df = df.fillna(" ")
    if 'test' in fname:
        df = df[['sentence1', 'sentence2']]
    else:
        df['gold_label'] = df['gold_label'].apply(relabeler)
        df.dropna(subset=['sentence1', 'sentence2', 'gold_label'], inplace=True)
        df = df[['sentence1', 'sentence2', 'gold_label']]

    out = []
    for row in df.values:
        out.append(row.tolist())
    return out


def load_dataset(fname, genre=None, dire='data/data135628/'):
    """Load one CSV file, or several, through ``_load``.

    Args:
        fname: A single file name, or a list/tuple of file names.
        genre: Passed through to ``_load`` unchanged.
        dire: Passed through to ``_load`` unchanged.

    Returns:
        The result of ``_load`` when ``fname`` is a str, otherwise a list
        with one ``_load`` result per file name, in the given order.

    Raises:
        TypeError: If ``fname`` is neither a str nor a list/tuple.
    """
    # Single file: hand it straight to the loader.
    if isinstance(fname, str):
        return _load(fname, genre, dire)

    if not isinstance(fname, (list, tuple)):
        raise TypeError("fname must be a str or a list!")

    loaded = []
    for name in fname:
        loaded.append(_load(name, genre, dire))
    return loaded

In [2]:
# Load the labeled splits. Both dev sets are folded into the training data,
# so no labeled split is held out for validation in this notebook.
train_set, dev1_set, dev2_set = load_dataset(['train.csv', 'dev_matched.csv', 'dev_mismatched.csv'])
train_set = train_set + dev1_set + dev2_set

test1_set, test2_set = load_dataset(['test_matched.csv', 'test_mismatched.csv'])
print("Train set size:", len(train_set))
print("Train set examples:", train_set[:2])

# The two "examples" labels used to read "nTest ..." (a mangled "\n" escape).
print("\nTest set size:", len(test1_set))
print("Test set examples:", test1_set[:2])

print("\nTest set size:", len(test2_set))
print("Test set examples:", test2_set[:2])

Train set size: 412349
Train set examples: [['Conceptually cream skimming has two basic dimensions - product and geography.', 'Product and geography are what make cream skimming work. ', 1], ['you know during the season and i guess at at your level uh you lose them to the next level if if they decide to recall the the parent team the Braves decide to call to recall a guy from triple A then a double A guy goes up to replace him and a single A guy goes up to replace him', 'You lose the things to the following level if the people recall.', 0]]

Test set size: 9796
nTest set examples: [['That which binds together Chinese.', 'This is a shared value among Chinese people.'], ["The actual length of an individual worker's H-2A visa varies depending upon the geographic location of the employer and the nature of the farmwork to be performed.", "The location of the employer effects the length of the worker's H-2A visa."]]

Test set size: 9847
nTest set examples: [['Even here, the channel perspecti

## 2. Transform text

In [3]:
from paddlenlp.datasets import MapDataset
from paddle.io import BatchSampler, DataLoader
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.transformers import BertModel as SeqClfModel
from paddlenlp.transformers import BertTokenizer as PTMTokenizer


# Name of the pretrained checkpoint. The same constant selects the tokenizer
# vocabulary here and the encoder weights when the model is built later.
MODEL_NAME = "bert-base-uncased"
# Module-level tokenizer shared by the conversion and batching helpers below.
tokenizer = PTMTokenizer.from_pretrained(MODEL_NAME)


def example_converter(example, max_seq_length, tokenizer):
    """Tokenize one sentence pair into model inputs.

    Args:
        example: Sequence of exactly three items, ``(text_a, text_b, label)``.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        tokenizer: Callable accepting the ``text``, ``text_pair`` and
            ``max_seq_len`` keywords and returning a mapping that contains
            ``"input_ids"`` and ``"token_type_ids"``.

    Returns:
        The tuple ``(input_ids, token_type_ids, label)``.
    """
    first_text, second_text, target = example
    features = tokenizer(
        text=first_text,
        text_pair=second_text,
        max_seq_len=max_seq_length,
    )
    return features["input_ids"], features["token_type_ids"], target


def get_trans_fn(max_seq_length=128, tokenizer=tokenizer):
    """Build a one-argument converter bound to a length limit and tokenizer.

    Args:
        max_seq_length: Length limit handed to ``example_converter``.
        tokenizer: Tokenizer handed to ``example_converter``; defaults to the
            module-level ``tokenizer`` captured when this function is defined.

    Returns:
        A callable taking a single example and returning
        ``example_converter(example, max_seq_length, tokenizer)``.
    """
    def _convert(ex):
        return example_converter(ex, max_seq_length, tokenizer)

    return _convert


def batchify_fn(samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Stack(dtype="int64"))):
    """Collate a list of converted examples into one batch.

    Args:
        samples: Examples as produced by the transform function, i.e.
            ``(input_ids, token_type_ids, label)`` triples.
        fn: Collator applied to ``samples``. The default is built once, at
            definition time: two ``Pad`` collators using the tokenizer's pad
            token id and pad token type id, and a ``Stack`` for the labels
            with dtype ``int64``.

    Returns:
        Whatever ``fn(samples)`` returns.
    """
    return fn(samples)


def create_dataloader(dataset,
                      trans_fn,
                      batchify_fn,
                      test=False,
                      batch_size=128,
                      shuffle=True,
                      sampler=BatchSampler):
    """Wrap raw examples in a ``DataLoader``.

    Args:
        dataset: Examples to load; wrapped in ``MapDataset`` unless it
            already is one.
        trans_fn: Per-example transform handed to the dataset's ``map``.
        batchify_fn: Collate function handed to the ``DataLoader``.
        test: When true, every example (expected to be a list) gets a
            trailing ``0`` appended as a placeholder label.
        batch_size: Batch size handed to ``sampler``.
        shuffle: Shuffle flag handed to ``sampler``.
        sampler: Batch-sampler factory, called as
            ``sampler(dataset, shuffle=..., batch_size=...)``.

    Returns:
        The constructed ``DataLoader``.
    """
    # Unlabeled rows get a dummy label so they share the three-field layout
    # that the transform function unpacks.
    examples = [row + [0] for row in dataset] if test else dataset

    mapped = examples if isinstance(examples, MapDataset) else MapDataset(examples)
    # The return value is not used: this relies on `map` updating the dataset
    # object itself.
    mapped.map(trans_fn)

    return DataLoader(
        mapped,
        batch_sampler=sampler(mapped, shuffle=shuffle, batch_size=batch_size),
        collate_fn=batchify_fn,
    )

[2022-03-30 01:21:26,811] [    INFO]

 - Downloading https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-uncased-vocab.txt and saved to /home/aistudio/.paddlenlp/models/bert-base-uncased




[2022-03-30 01:21:26,814] [    INFO]

 - Downloading bert-base-uncased-vocab.txt from https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-uncased-vocab.txt




  0%|          | 0/227 [00:00<?, ?it/s]

100%|██████████| 227/227 [00:00<00:00, 23654.24it/s]




In [4]:
max_seq_length = 128
batch_size = 64
trans_fn = get_trans_fn(max_seq_length)

train_loader = create_dataloader(
    train_set, trans_fn, batchify_fn, batch_size=batch_size)
# No dev loader is built: both dev splits were merged into `train_set` when
# the data was loaded.

# Test loaders keep the file order (shuffle=False) so that predictions can be
# matched back to the ids read from the test CSV files.
test1_loader = create_dataloader(
    test1_set, trans_fn, batchify_fn,
    test=True, shuffle=False, batch_size=batch_size)
test2_loader = create_dataloader(
    test2_set, trans_fn, batchify_fn,
    test=True, shuffle=False, batch_size=batch_size)

## 3. Model building

In [5]:
from paddle import nn
import paddle


class PTM(nn.Layer):
    """Classifier made of a pretrained encoder and a two-layer head.

    The head is ``Linear(hidden, hidden // 2) -> ReLU -> Linear(hidden // 2,
    num_class)``, where ``hidden`` is read from the encoder's
    ``config["hidden_size"]``.
    """

    def __init__(self, pretrained_model, dropout=0.1, num_class=3):
        """Build the classification head on top of ``pretrained_model``.

        Args:
            pretrained_model: Encoder exposing ``config["hidden_size"]`` and
                returning a 2-tuple when called with
                ``(input_ids, token_type_ids)``.
            dropout: Dropout probability applied to the encoder output.
            num_class: Number of output classes.
        """
        super().__init__()

        self.ptm = pretrained_model
        hidden_dim = self.ptm.config["hidden_size"]
        head_dim = hidden_dim // 2
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim, head_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(head_dim, num_class)

    def encoder(self, input_ids, token_type_ids):
        """Return the encoder's second output after dropout.

        NOTE(review): for a BERT-style encoder this is presumably the pooled
        sentence representation - confirm against the encoder's documentation.
        """
        _, pooled = self.ptm(input_ids, token_type_ids)
        return self.dropout(pooled)

    def forward(self, input_ids, token_type_ids):
        """Compute unnormalised class scores (logits) for a batch."""
        features = self.encoder(input_ids, token_type_ids)
        return self.fc2(self.relu(self.fc1(features)))

In [6]:
from paddlenlp.transformers import LinearDecayWithWarmup

# Training hyper-parameters, read by `get_model` and by the training cell.
epoch = 3
weight_decay = 0.0
warmup_proportion = 0.0

# Positional arguments: base learning rate, total number of optimisation
# steps (batches per epoch * epochs), warm-up setting.
lr_scheduler = LinearDecayWithWarmup(
    5e-5,
    len(train_loader) * epoch,
    warmup_proportion,
)

def get_model(model):
    """Wrap a network in ``paddle.Model`` ready for training.

    Reads the module-level ``lr_scheduler`` and ``weight_decay``.

    Args:
        model: Network providing ``named_parameters()`` and ``parameters()``.

    Returns:
        A ``paddle.Model`` prepared with an AdamW optimizer, cross-entropy
        loss and an accuracy metric.
    """
    # Weight decay applies only to parameters whose structured name contains
    # neither "bias" nor "norm". A set gives constant-time membership tests
    # in the callback, which the optimizer invokes per parameter.
    no_decay_markers = ("bias", "norm")
    decay_params = {
        param.name for name, param in model.named_parameters()
        if not any(marker in name for marker in no_decay_markers)
    }
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=lr_scheduler,
        weight_decay=weight_decay,
        apply_decay_param_fun=lambda param_name: param_name in decay_params)

    criterion = paddle.nn.CrossEntropyLoss()

    # Use a separate name instead of rebinding the `model` argument to an
    # object of a different type.
    wrapped = paddle.Model(model)
    metric = paddle.metric.Accuracy()
    wrapped.prepare(optimizer, criterion, metric)
    return wrapped

In [7]:
# Pretrained encoder -> classification head -> trainable paddle.Model wrapper.
# Built in one expression so that `model` only ever names the final object.
model = get_model(PTM(SeqClfModel.from_pretrained(MODEL_NAME)))

[2022-03-30 01:21:27,018] [    INFO]

 - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams and saved to /home/aistudio/.paddlenlp/models/bert-base-uncased




[2022-03-30 01:21:27,020] [    INFO]

 - Downloading bert-base-uncased.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams




  0%|          | 0/793257 [00:00<?, ?it/s]

  1%|          | 5068/793257 [00:00<00:15, 50675.58it/s]

  1%|▏         | 11613/793257 [00:00<00:14, 54349.11it/s]

  2%|▏         | 18233/793257 [00:00<00:13, 57432.74it/s]

  3%|▎         | 24851/793257 [00:00<00:12, 59802.92it/s]

  4%|▍         | 31507/793257 [00:00<00:12, 61680.09it/s]

  5%|▍         | 38225/793257 [00:00<00:11, 63233.10it/s]

  6%|▌         | 45769/793257 [00:00<00:11, 66457.22it/s]

  7%|▋         | 53392/793257 [00:00<00:10, 69114.91it/s]

  8%|▊         | 61125/793257 [00:00<00:10, 71389.91it/s]

  9%|▊         | 68763/793257 [00:01<00:09, 72815.34it/s]

 10%|▉         | 76162/793257 [00:01<00:09, 73163.15it/s]

 11%|█         | 83699/793257 [00:01<00:09, 73807.66it/s]

 12%|█▏        | 91361/793257 [00:01<00:09, 74628.11it/s]

 12%|█▏        | 99043/793257 [00:01<00:09, 75266.39it/s]

 13%|█▎        | 106717/793257 [00:01<00:09, 75699.95it/s]

 14%|█▍        | 114277/793257 [00:01<00:08, 75663.03it/s]

 15%|█▌        | 122017/793257 [00:01<00:08, 76174.50it/s]

 16%|█▋        | 129667/793257 [00:01<00:08, 76269.54it/s]

 17%|█▋        | 137425/793257 [00:01<00:08, 76654.78it/s]

 18%|█▊        | 145108/793257 [00:02<00:08, 76706.74it/s]

 19%|█▉        | 152793/793257 [00:02<00:08, 76748.75it/s]

 20%|██        | 160525/793257 [00:02<00:08, 76917.67it/s]

 21%|██        | 168217/793257 [00:02<00:08, 76575.04it/s]

 22%|██▏       | 175875/793257 [00:02<00:08, 76565.76it/s]

 23%|██▎       | 183532/793257 [00:02<00:07, 76498.50it/s]

 24%|██▍       | 191228/793257 [00:02<00:07, 76635.68it/s]

 25%|██▌       | 198892/793257 [00:02<00:07, 76440.72it/s]

 26%|██▌       | 206537/793257 [00:02<00:07, 74140.60it/s]

 27%|██▋       | 213967/793257 [00:02<00:08, 68404.90it/s]

 28%|██▊       | 221523/793257 [00:03<00:08, 70403.27it/s]

 29%|██▉       | 228647/793257 [00:03<00:08, 68726.41it/s]

 30%|██▉       | 235810/793257 [00:03<00:08, 69571.76it/s]

 31%|███       | 243479/793257 [00:03<00:07, 71563.19it/s]

 32%|███▏      | 251086/793257 [00:03<00:07, 72856.62it/s]

 33%|███▎      | 258803/793257 [00:03<00:07, 74095.90it/s]

 34%|███▎      | 266246/793257 [00:03<00:07, 72181.92it/s]

 35%|███▍      | 273777/793257 [00:03<00:07, 73091.52it/s]

 35%|███▌      | 281113/793257 [00:03<00:07, 69516.23it/s]

 36%|███▋      | 288120/793257 [00:03<00:07, 66327.12it/s]

 37%|███▋      | 294823/793257 [00:04<00:07, 63486.04it/s]

 38%|███▊      | 301247/793257 [00:04<00:07, 61982.15it/s]

 39%|███▉      | 307506/793257 [00:04<00:07, 61660.51it/s]

 40%|███▉      | 313772/793257 [00:04<00:07, 61955.17it/s]

 40%|████      | 319998/793257 [00:04<00:07, 61133.11it/s]

 41%|████      | 326135/793257 [00:04<00:07, 58827.72it/s]

 42%|████▏     | 332053/793257 [00:04<00:07, 58228.70it/s]

 43%|████▎     | 337902/793257 [00:04<00:07, 57410.76it/s]

 43%|████▎     | 343664/793257 [00:04<00:07, 56846.23it/s]

 44%|████▍     | 350528/793257 [00:05<00:07, 59935.51it/s]

 45%|████▌     | 358163/793257 [00:05<00:06, 64067.20it/s]

 46%|████▌     | 365868/793257 [00:05<00:06, 67477.24it/s]

 47%|████▋     | 373491/793257 [00:05<00:06, 69881.57it/s]

 48%|████▊     | 381235/793257 [00:05<00:05, 71987.08it/s]

 49%|████▉     | 388930/793257 [00:05<00:05, 73407.38it/s]

 50%|█████     | 396713/793257 [00:05<00:05, 74678.19it/s]

 51%|█████     | 404410/793257 [00:05<00:05, 75349.45it/s]

 52%|█████▏    | 412132/793257 [00:05<00:05, 75899.82it/s]

 53%|█████▎    | 419752/793257 [00:05<00:04, 75959.04it/s]

 54%|█████▍    | 427369/793257 [00:06<00:04, 75247.36it/s]

 55%|█████▍    | 434953/793257 [00:06<00:04, 75423.77it/s]

 56%|█████▌    | 442575/793257 [00:06<00:04, 75658.60it/s]

 57%|█████▋    | 450239/793257 [00:06<00:04, 75948.29it/s]

 58%|█████▊    | 457898/793257 [00:06<00:04, 76138.07it/s]

 59%|█████▊    | 465666/793257 [00:06<00:04, 76592.37it/s]

 60%|█████▉    | 473336/793257 [00:06<00:04, 76624.10it/s]

 61%|██████    | 481001/793257 [00:06<00:04, 76127.04it/s]

 62%|██████▏   | 488655/793257 [00:06<00:03, 76248.31it/s]

 63%|██████▎   | 496329/793257 [00:06<00:03, 76392.56it/s]

 64%|██████▎   | 504043/793257 [00:07<00:03, 76614.80it/s]

 65%|██████▍   | 511741/793257 [00:07<00:03, 76723.36it/s]

 65%|██████▌   | 519425/793257 [00:07<00:03, 76757.90it/s]

 66%|██████▋   | 527123/793257 [00:07<00:03, 76812.85it/s]

 67%|██████▋   | 534822/793257 [00:07<00:03, 76865.16it/s]

 68%|██████▊   | 542515/793257 [00:07<00:03, 76882.10it/s]

 69%|██████▉   | 550230/793257 [00:07<00:03, 76960.04it/s]

 70%|███████   | 558013/793257 [00:07<00:03, 77216.87it/s]

 71%|███████▏  | 565736/793257 [00:07<00:02, 77193.78it/s]

 72%|███████▏  | 573529/793257 [00:07<00:02, 77411.11it/s]

 73%|███████▎  | 581271/793257 [00:08<00:02, 76652.92it/s]

 74%|███████▍  | 588939/793257 [00:08<00:02, 76493.91it/s]

 75%|███████▌  | 596590/793257 [00:08<00:02, 74420.24it/s]

 76%|███████▌  | 604046/793257 [00:08<00:02, 73252.06it/s]

 77%|███████▋  | 611385/793257 [00:08<00:02, 70759.70it/s]

 78%|███████▊  | 618489/793257 [00:08<00:02, 68988.82it/s]

 79%|███████▉  | 625433/793257 [00:08<00:02, 69122.35it/s]

 80%|███████▉  | 632761/793257 [00:08<00:02, 70318.59it/s]

 81%|████████  | 640353/793257 [00:08<00:02, 71907.65it/s]

 82%|████████▏ | 647830/793257 [00:09<00:01, 72741.93it/s]

 83%|████████▎ | 655123/793257 [00:09<00:02, 56065.47it/s]

 84%|████████▎ | 662661/793257 [00:09<00:02, 60733.29it/s]

 84%|████████▍ | 669263/793257 [00:09<00:02, 50646.69it/s]

 85%|████████▌ | 674980/793257 [00:09<00:02, 50753.08it/s]

 86%|████████▌ | 680512/793257 [00:09<00:02, 40559.47it/s]

 86%|████████▋ | 685213/793257 [00:09<00:02, 36775.37it/s]

 87%|████████▋ | 689410/793257 [00:10<00:03, 32216.38it/s]

 87%|████████▋ | 693098/793257 [00:10<00:03, 26864.76it/s]

 88%|████████▊ | 698607/793257 [00:10<00:02, 31743.57it/s]

 89%|████████▉ | 704067/793257 [00:10<00:02, 36299.39it/s]

 89%|████████▉ | 709522/793257 [00:10<00:02, 40348.95it/s]

 90%|█████████ | 714834/793257 [00:10<00:01, 43484.63it/s]

 91%|█████████ | 722035/793257 [00:10<00:01, 49344.62it/s]

 92%|█████████▏| 729355/793257 [00:10<00:01, 54689.98it/s]

 93%|█████████▎| 736506/793257 [00:11<00:00, 58841.72it/s]

 94%|█████████▍| 744229/793257 [00:11<00:00, 63366.88it/s]

 95%|█████████▍| 751964/793257 [00:11<00:00, 67000.30it/s]

 96%|█████████▌| 759824/793257 [00:11<00:00, 70103.14it/s]

 97%|█████████▋| 767546/793257 [00:11<00:00, 72096.18it/s]

 98%|█████████▊| 775270/793257 [00:11<00:00, 73564.19it/s]

 99%|█████████▊| 782794/793257 [00:11<00:00, 73516.07it/s]

100%|█████████▉| 790263/793257 [00:11<00:00, 73714.47it/s]

100%|██████████| 793257/793257 [00:11<00:00, 67497.35it/s]




W0330 01:21:38.848248   255 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0330 01:21:38.853195   255 device_context.cc:465] device: 0, cuDNN Version: 7.6.


[2022-03-30 01:21:49,124] [    INFO]

 - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']




## 4. Model training

In [8]:
# Fine-tune for `epoch` epochs on the merged train + dev data. With verbose=2
# and log_freq=100 one progress line is printed every 100 steps. No
# evaluation data is passed, so the reported accuracy is training accuracy.
model.fit(train_loader, epochs=epoch, verbose=2, log_freq=100)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.




Epoch 1/3




step  100/6443 - loss: 0.7668 - acc: 0.5563 - 361ms/step


step  200/6443 - loss: 0.5352 - acc: 0.6214 - 359ms/step


step  300/6443 - loss: 0.8902 - acc: 0.6503 - 355ms/step


step  400/6443 - loss: 0.6170 - acc: 0.6700 - 355ms/step


step  500/6443 - loss: 0.5381 - acc: 0.6873 - 356ms/step


step  600/6443 - loss: 0.5746 - acc: 0.6992 - 356ms/step


step  700/6443 - loss: 0.5383 - acc: 0.7074 - 356ms/step


step  800/6443 - loss: 0.4520 - acc: 0.7165 - 357ms/step


step  900/6443 - loss: 0.6355 - acc: 0.7215 - 357ms/step


step 1000/6443 - loss: 0.7656 - acc: 0.7262 - 356ms/step


step 1100/6443 - loss: 0.5419 - acc: 0.7316 - 358ms/step


step 1200/6443 - loss: 0.5789 - acc: 0.7353 - 357ms/step


step 1300/6443 - loss: 0.6627 - acc: 0.7391 - 358ms/step


step 1400/6443 - loss: 0.6020 - acc: 0.7427 - 359ms/step


step 1500/6443 - loss: 0.6321 - acc: 0.7455 - 358ms/step


step 1600/6443 - loss: 0.4877 - acc: 0.7488 - 358ms/step


step 1700/6443 - loss: 0.3998 - acc: 0.7510 - 358ms/step


step 1800/6443 - loss: 0.6125 - acc: 0.7531 - 358ms/step


step 1900/6443 - loss: 0.5628 - acc: 0.7553 - 358ms/step


step 2000/6443 - loss: 0.4364 - acc: 0.7574 - 358ms/step


step 2100/6443 - loss: 0.3597 - acc: 0.7589 - 358ms/step


step 2200/6443 - loss: 0.4114 - acc: 0.7609 - 357ms/step


step 2300/6443 - loss: 0.5957 - acc: 0.7622 - 357ms/step


step 2400/6443 - loss: 0.5190 - acc: 0.7637 - 358ms/step


step 2500/6443 - loss: 0.6652 - acc: 0.7656 - 358ms/step


step 2600/6443 - loss: 0.5554 - acc: 0.7669 - 358ms/step


step 2700/6443 - loss: 0.4510 - acc: 0.7684 - 358ms/step


step 2800/6443 - loss: 0.5726 - acc: 0.7696 - 357ms/step


step 2900/6443 - loss: 0.3973 - acc: 0.7708 - 357ms/step


step 3000/6443 - loss: 0.5607 - acc: 0.7722 - 357ms/step


step 3100/6443 - loss: 0.5617 - acc: 0.7733 - 357ms/step


step 3200/6443 - loss: 0.3851 - acc: 0.7746 - 357ms/step


step 3300/6443 - loss: 0.5478 - acc: 0.7757 - 357ms/step


step 3400/6443 - loss: 0.5133 - acc: 0.7766 - 357ms/step


step 3500/6443 - loss: 0.4452 - acc: 0.7777 - 357ms/step


step 3600/6443 - loss: 0.5632 - acc: 0.7787 - 357ms/step


step 3700/6443 - loss: 0.4369 - acc: 0.7796 - 357ms/step


step 3800/6443 - loss: 0.6565 - acc: 0.7804 - 357ms/step


step 3900/6443 - loss: 0.4923 - acc: 0.7811 - 357ms/step


step 4000/6443 - loss: 0.3500 - acc: 0.7821 - 357ms/step


step 4100/6443 - loss: 0.5533 - acc: 0.7829 - 357ms/step


step 4200/6443 - loss: 0.5174 - acc: 0.7836 - 357ms/step


step 4300/6443 - loss: 0.3128 - acc: 0.7845 - 357ms/step


step 4400/6443 - loss: 0.4726 - acc: 0.7851 - 357ms/step


step 4500/6443 - loss: 0.5596 - acc: 0.7861 - 357ms/step


step 4600/6443 - loss: 0.5413 - acc: 0.7868 - 357ms/step


step 4700/6443 - loss: 0.3273 - acc: 0.7873 - 357ms/step


step 4800/6443 - loss: 0.3840 - acc: 0.7879 - 357ms/step


step 4900/6443 - loss: 0.3907 - acc: 0.7886 - 357ms/step


step 5000/6443 - loss: 0.4558 - acc: 0.7893 - 357ms/step


step 5100/6443 - loss: 0.3587 - acc: 0.7901 - 357ms/step


step 5200/6443 - loss: 0.3377 - acc: 0.7907 - 357ms/step


step 5300/6443 - loss: 0.6689 - acc: 0.7913 - 357ms/step


step 5400/6443 - loss: 0.5266 - acc: 0.7918 - 357ms/step


step 5500/6443 - loss: 0.5272 - acc: 0.7922 - 357ms/step


step 5600/6443 - loss: 0.4921 - acc: 0.7926 - 357ms/step


step 5700/6443 - loss: 0.6019 - acc: 0.7932 - 357ms/step


step 5800/6443 - loss: 0.3568 - acc: 0.7938 - 357ms/step


step 5900/6443 - loss: 0.2390 - acc: 0.7943 - 357ms/step


step 6000/6443 - loss: 0.4576 - acc: 0.7948 - 357ms/step


step 6100/6443 - loss: 0.3420 - acc: 0.7952 - 357ms/step


step 6200/6443 - loss: 0.4477 - acc: 0.7956 - 357ms/step


step 6300/6443 - loss: 0.4465 - acc: 0.7961 - 357ms/step


step 6400/6443 - loss: 0.3074 - acc: 0.7966 - 357ms/step


step 6443/6443 - loss: 0.4267 - acc: 0.7968 - 357ms/step


Epoch 2/3




step  100/6443 - loss: 0.3359 - acc: 0.8808 - 357ms/step


step  200/6443 - loss: 0.3285 - acc: 0.8845 - 353ms/step


step  300/6443 - loss: 0.4120 - acc: 0.8840 - 359ms/step


step  400/6443 - loss: 0.3374 - acc: 0.8836 - 359ms/step


step  500/6443 - loss: 0.2878 - acc: 0.8833 - 358ms/step


step  600/6443 - loss: 0.1742 - acc: 0.8834 - 357ms/step


step  700/6443 - loss: 0.3429 - acc: 0.8831 - 358ms/step


step  800/6443 - loss: 0.3265 - acc: 0.8825 - 358ms/step


step  900/6443 - loss: 0.2673 - acc: 0.8829 - 358ms/step


step 1000/6443 - loss: 0.3679 - acc: 0.8828 - 358ms/step


step 1100/6443 - loss: 0.3183 - acc: 0.8821 - 358ms/step


step 1200/6443 - loss: 0.1965 - acc: 0.8818 - 358ms/step


step 1300/6443 - loss: 0.3164 - acc: 0.8812 - 359ms/step


step 1400/6443 - loss: 0.3200 - acc: 0.8810 - 360ms/step


step 1500/6443 - loss: 0.3771 - acc: 0.8812 - 359ms/step


step 1600/6443 - loss: 0.3234 - acc: 0.8812 - 359ms/step


step 1700/6443 - loss: 0.3098 - acc: 0.8811 - 359ms/step


step 1800/6443 - loss: 0.2969 - acc: 0.8811 - 359ms/step


step 1900/6443 - loss: 0.3182 - acc: 0.8806 - 359ms/step


step 2000/6443 - loss: 0.2407 - acc: 0.8804 - 359ms/step


step 2100/6443 - loss: 0.3068 - acc: 0.8804 - 359ms/step


step 2200/6443 - loss: 0.3106 - acc: 0.8803 - 359ms/step


step 2300/6443 - loss: 0.2099 - acc: 0.8805 - 360ms/step


step 2400/6443 - loss: 0.4320 - acc: 0.8807 - 360ms/step


step 2500/6443 - loss: 0.4240 - acc: 0.8805 - 360ms/step


step 2600/6443 - loss: 0.3422 - acc: 0.8804 - 360ms/step


step 2700/6443 - loss: 0.3130 - acc: 0.8802 - 360ms/step


step 2800/6443 - loss: 0.3108 - acc: 0.8803 - 360ms/step


step 2900/6443 - loss: 0.3559 - acc: 0.8802 - 360ms/step


step 3000/6443 - loss: 0.4348 - acc: 0.8801 - 360ms/step


step 3100/6443 - loss: 0.3453 - acc: 0.8802 - 360ms/step


step 3200/6443 - loss: 0.4024 - acc: 0.8800 - 359ms/step


step 3300/6443 - loss: 0.3483 - acc: 0.8800 - 360ms/step


step 3400/6443 - loss: 0.3363 - acc: 0.8800 - 359ms/step


step 3500/6443 - loss: 0.2359 - acc: 0.8802 - 359ms/step


step 3600/6443 - loss: 0.2241 - acc: 0.8804 - 359ms/step


step 3700/6443 - loss: 0.2759 - acc: 0.8804 - 359ms/step


step 3800/6443 - loss: 0.2637 - acc: 0.8805 - 359ms/step


step 3900/6443 - loss: 0.3005 - acc: 0.8804 - 359ms/step


step 4000/6443 - loss: 0.2310 - acc: 0.8804 - 359ms/step


step 4100/6443 - loss: 0.3177 - acc: 0.8806 - 359ms/step


step 4200/6443 - loss: 0.2312 - acc: 0.8807 - 359ms/step


step 4300/6443 - loss: 0.2463 - acc: 0.8808 - 359ms/step


step 4400/6443 - loss: 0.3688 - acc: 0.8808 - 359ms/step


step 4500/6443 - loss: 0.2666 - acc: 0.8809 - 359ms/step


step 4600/6443 - loss: 0.4019 - acc: 0.8810 - 359ms/step


step 4700/6443 - loss: 0.2999 - acc: 0.8811 - 359ms/step


step 4800/6443 - loss: 0.4138 - acc: 0.8811 - 359ms/step


step 4900/6443 - loss: 0.4187 - acc: 0.8810 - 359ms/step


step 5000/6443 - loss: 0.1992 - acc: 0.8812 - 359ms/step


step 5100/6443 - loss: 0.3376 - acc: 0.8811 - 359ms/step


step 5200/6443 - loss: 0.2817 - acc: 0.8810 - 359ms/step


step 5300/6443 - loss: 0.3684 - acc: 0.8810 - 359ms/step


step 5400/6443 - loss: 0.3316 - acc: 0.8811 - 359ms/step


step 5500/6443 - loss: 0.2211 - acc: 0.8812 - 359ms/step


step 5600/6443 - loss: 0.3067 - acc: 0.8812 - 359ms/step


step 5700/6443 - loss: 0.2798 - acc: 0.8812 - 359ms/step


step 5800/6443 - loss: 0.3826 - acc: 0.8812 - 359ms/step


step 5900/6443 - loss: 0.2732 - acc: 0.8812 - 359ms/step


step 6000/6443 - loss: 0.3300 - acc: 0.8813 - 359ms/step


step 6100/6443 - loss: 0.3447 - acc: 0.8815 - 359ms/step


step 6200/6443 - loss: 0.2211 - acc: 0.8817 - 358ms/step


step 6300/6443 - loss: 0.2186 - acc: 0.8818 - 358ms/step


step 6400/6443 - loss: 0.3504 - acc: 0.8818 - 358ms/step


step 6443/6443 - loss: 0.3087 - acc: 0.8818 - 358ms/step


Epoch 3/3




step  100/6443 - loss: 0.0747 - acc: 0.9313 - 358ms/step


step  200/6443 - loss: 0.2076 - acc: 0.9348 - 358ms/step


step  300/6443 - loss: 0.1069 - acc: 0.9357 - 358ms/step


step  400/6443 - loss: 0.1718 - acc: 0.9353 - 358ms/step


step  500/6443 - loss: 0.0991 - acc: 0.9359 - 357ms/step


step  600/6443 - loss: 0.1431 - acc: 0.9357 - 358ms/step


step  700/6443 - loss: 0.2968 - acc: 0.9357 - 358ms/step


step  800/6443 - loss: 0.1273 - acc: 0.9361 - 359ms/step


step  900/6443 - loss: 0.1956 - acc: 0.9357 - 359ms/step


step 1000/6443 - loss: 0.2452 - acc: 0.9360 - 358ms/step


step 1100/6443 - loss: 0.2385 - acc: 0.9354 - 358ms/step


step 1200/6443 - loss: 0.2421 - acc: 0.9356 - 359ms/step


step 1300/6443 - loss: 0.2058 - acc: 0.9355 - 359ms/step


step 1400/6443 - loss: 0.1699 - acc: 0.9358 - 359ms/step


step 1500/6443 - loss: 0.2302 - acc: 0.9360 - 359ms/step


step 1600/6443 - loss: 0.2053 - acc: 0.9361 - 359ms/step


step 1700/6443 - loss: 0.0954 - acc: 0.9361 - 359ms/step


step 1800/6443 - loss: 0.1786 - acc: 0.9360 - 359ms/step


step 1900/6443 - loss: 0.3652 - acc: 0.9358 - 358ms/step


step 2000/6443 - loss: 0.1610 - acc: 0.9361 - 358ms/step


step 2100/6443 - loss: 0.1200 - acc: 0.9363 - 358ms/step


step 2200/6443 - loss: 0.1410 - acc: 0.9363 - 358ms/step


step 2300/6443 - loss: 0.1737 - acc: 0.9364 - 358ms/step


step 2400/6443 - loss: 0.1394 - acc: 0.9365 - 358ms/step


step 2500/6443 - loss: 0.1166 - acc: 0.9365 - 358ms/step


step 2600/6443 - loss: 0.1861 - acc: 0.9365 - 358ms/step


step 2700/6443 - loss: 0.1025 - acc: 0.9365 - 357ms/step


step 2800/6443 - loss: 0.2207 - acc: 0.9366 - 357ms/step


step 2900/6443 - loss: 0.0899 - acc: 0.9367 - 357ms/step


step 3000/6443 - loss: 0.2008 - acc: 0.9368 - 357ms/step


step 3100/6443 - loss: 0.2808 - acc: 0.9367 - 358ms/step


step 3200/6443 - loss: 0.2032 - acc: 0.9367 - 357ms/step


step 3300/6443 - loss: 0.3259 - acc: 0.9367 - 358ms/step


step 3400/6443 - loss: 0.1644 - acc: 0.9367 - 358ms/step


step 3500/6443 - loss: 0.3197 - acc: 0.9368 - 357ms/step


step 3600/6443 - loss: 0.1002 - acc: 0.9368 - 357ms/step


step 3700/6443 - loss: 0.1083 - acc: 0.9369 - 357ms/step


step 3800/6443 - loss: 0.0413 - acc: 0.9369 - 357ms/step


step 3900/6443 - loss: 0.1602 - acc: 0.9370 - 357ms/step


step 4000/6443 - loss: 0.1512 - acc: 0.9370 - 357ms/step


step 4100/6443 - loss: 0.1036 - acc: 0.9371 - 358ms/step


step 4200/6443 - loss: 0.1108 - acc: 0.9372 - 358ms/step


step 4300/6443 - loss: 0.1737 - acc: 0.9374 - 358ms/step


step 4400/6443 - loss: 0.1709 - acc: 0.9373 - 357ms/step


step 4500/6443 - loss: 0.1229 - acc: 0.9372 - 357ms/step


step 4600/6443 - loss: 0.1799 - acc: 0.9371 - 357ms/step


step 4700/6443 - loss: 0.1006 - acc: 0.9371 - 357ms/step


step 4800/6443 - loss: 0.1280 - acc: 0.9372 - 357ms/step


step 4900/6443 - loss: 0.1407 - acc: 0.9372 - 357ms/step


step 5000/6443 - loss: 0.1448 - acc: 0.9372 - 357ms/step


step 5100/6443 - loss: 0.1133 - acc: 0.9372 - 357ms/step


step 5200/6443 - loss: 0.0751 - acc: 0.9374 - 357ms/step


step 5300/6443 - loss: 0.2002 - acc: 0.9373 - 357ms/step


step 5400/6443 - loss: 0.1154 - acc: 0.9374 - 357ms/step


step 5500/6443 - loss: 0.2602 - acc: 0.9373 - 357ms/step


step 5600/6443 - loss: 0.1218 - acc: 0.9373 - 357ms/step


step 5700/6443 - loss: 0.2070 - acc: 0.9374 - 357ms/step


step 5800/6443 - loss: 0.1951 - acc: 0.9374 - 357ms/step


step 5900/6443 - loss: 0.1337 - acc: 0.9374 - 357ms/step


step 6000/6443 - loss: 0.2437 - acc: 0.9375 - 357ms/step


step 6100/6443 - loss: 0.0658 - acc: 0.9376 - 357ms/step


step 6200/6443 - loss: 0.1126 - acc: 0.9376 - 357ms/step


step 6300/6443 - loss: 0.1005 - acc: 0.9375 - 357ms/step


step 6400/6443 - loss: 0.1572 - acc: 0.9375 - 358ms/step


step 6443/6443 - loss: 0.1478 - acc: 0.9375 - 357ms/step


## 5. Prediction

In [9]:
import paddle.nn.functional as F
from tqdm import tqdm


def predict(test_loader, id_inpath, out, predictor=None):
    """Predict NLI labels for a test split and save them as a CSV file.

    Args:
        test_loader: Batched test data, passed unchanged to ``predictor.predict``.
        id_inpath: Path of the CSV whose ``pairID`` column supplies the row ids.
            Ids are read in file order and paired positionally with predictions.
        out: Path of the CSV to write, with columns ``pairID`` and ``gold_label``.
        predictor: Object exposing ``predict(loader)``. Defaults to the
            notebook-level ``model``, so existing three-argument calls keep
            working; pass it explicitly to avoid depending on kernel state.

    Raises:
        ValueError: If the number of predictions differs from the number of ids
            (the rows could not be paired reliably).
        KeyError: If a predicted class index is not 0, 1 or 2.
    """
    import numpy as np  # local import: keeps this cell runnable on its own

    net = model if predictor is None else predictor

    # Only the first model output is used; it is iterated batch by batch.
    # assumes each batch is a 2-D (batch, n_classes) array of logits, as the
    # original per-batch softmax over axis=1 implied -- TODO confirm.
    batch_logits = net.predict(test_loader)[0]

    predictions = []
    for batch in batch_logits:
        # softmax is monotonic, so argmax over the raw logits picks the same
        # class without the extra tensor conversion.
        predictions.extend(np.asarray(batch).argmax(axis=1).tolist())

    ids = pd.read_csv(id_inpath)['pairID'].tolist()
    if len(ids) != len(predictions):
        # zip() would silently truncate and write a misaligned file.
        raise ValueError(
            f"got {len(predictions)} predictions for {len(ids)} ids in {id_inpath}"
        )

    # Inverse of the integer encoding applied by relabeler() at load time.
    label_map = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
    results = pd.DataFrame({
        'pairID': ids,
        'gold_label': [label_map[p] for p in predictions],
    })
    results.to_csv(out, index=False)
    print(f"{out} has been saved!")

In [10]:
# Run prediction on the matched test split and write the labels to
# test_matched_preds.csv.
predict(
    test_loader=test1_loader,
    id_inpath='data/data135628/test_matched.csv',
    out='test_matched_preds.csv',
)

Predict begin...






step   2/154 [..............................]

 - ETA: 35s - 233ms/step





step   4/154 [..............................]

 - ETA: 29s - 198ms/step





step   6/154 [>.............................]

 - ETA: 25s - 173ms/step





step   8/154 [>.............................]

 - ETA: 24s - 167ms/step





step  10/154 [>.............................]

 - ETA: 23s - 163ms/step





step  12/154 [=>............................]

 - ETA: 22s - 158ms/step





step  14/154 [=>............................]

 - ETA: 21s - 154ms/step





step  16/154 [==>...........................]

 - ETA: 21s - 153ms/step





step  18/154 [==>...........................]

 - ETA: 20s - 154ms/step





step  20/154 [==>...........................]

 - ETA: 20s - 152ms/step





step  22/154 [===>..........................]

 - ETA: 19s - 151ms/step





step  24/154 [===>..........................]

 - ETA: 19s - 151ms/step





step  26/154 [====>.........................]

 - ETA: 19s - 152ms/step





step  28/154 [====>.........................]

 - ETA: 19s - 152ms/step





step  30/154 [====>.........................]

 - ETA: 18s - 152ms/step





step  32/154 [=====>........................]

 - ETA: 18s - 153ms/step





step  34/154 [=====>........................]

 - ETA: 18s - 153ms/step







 - ETA: 18s - 153ms/step







 - ETA: 17s - 153ms/step







 - ETA: 17s - 153ms/step







 - ETA: 17s - 153ms/step







 - ETA: 16s - 153ms/step







 - ETA: 16s - 152ms/step







 - ETA: 16s - 153ms/step







 - ETA: 15s - 152ms/step







 - ETA: 15s - 151ms/step







 - ETA: 15s - 151ms/step







 - ETA: 14s - 151ms/step







 - ETA: 14s - 151ms/step







 - ETA: 14s - 150ms/step







 - ETA: 13s - 150ms/step







 - ETA: 13s - 150ms/step







 - ETA: 13s - 150ms/step







 - ETA: 12s - 150ms/step







 - ETA: 12s - 149ms/step







 - ETA: 12s - 149ms/step







 - ETA: 11s - 150ms/step







 - ETA: 11s - 150ms/step







 - ETA: 11s - 150ms/step







 - ETA: 11s - 149ms/step







 - ETA: 10s - 149ms/step







 - ETA: 10s - 149ms/step







 - ETA: 10s - 149ms/step







 - ETA: 9s - 149ms/step 







 - ETA: 9s - 149ms/step







 - ETA: 9s - 149ms/step







 - ETA: 8s - 149ms/step







 - ETA: 8s - 150ms/step







 - ETA: 8s - 150ms/step







 - ETA: 8s - 151ms/step







 - ETA: 7s - 151ms/step







 - ETA: 7s - 152ms/step







 - ETA: 7s - 152ms/step







 - ETA: 6s - 152ms/step







 - ETA: 6s - 152ms/step







 - ETA: 6s - 151ms/step







 - ETA: 6s - 151ms/step







 - ETA: 5s - 151ms/step







 - ETA: 5s - 151ms/step







 - ETA: 5s - 150ms/step







 - ETA: 4s - 151ms/step







 - ETA: 4s - 151ms/step







 - ETA: 4s - 150ms/step







 - ETA: 3s - 150ms/step







 - ETA: 3s - 150ms/step







 - ETA: 3s - 150ms/step







 - ETA: 3s - 150ms/step







 - ETA: 2s - 150ms/step







 - ETA: 2s - 150ms/step







 - ETA: 2s - 150ms/step







 - ETA: 1s - 150ms/step







 - ETA: 1s - 150ms/step







 - ETA: 1s - 150ms/step







 - ETA: 0s - 150ms/step







 - ETA: 0s - 150ms/step







 - ETA: 0s - 149ms/step







 - 148ms/step          


Predict samples: 9796




  0%|          | 0/154 [00:00<?, ?it/s]

100%|██████████| 154/154 [00:00<00:00, 8663.48it/s]




test_matched_preds.csv has been saved!




In [11]:
# Run prediction on the mismatched test split and write the labels to
# test_mismatched_preds.csv.
predict(
    test_loader=test2_loader,
    id_inpath='data/data135628/test_mismatched.csv',
    out='test_mismatched_preds.csv',
)

Predict begin...






step   2/154 [..............................]

 - ETA: 34s - 230ms/step





step   4/154 [..............................]

 - ETA: 29s - 194ms/step





step   6/154 [>.............................]

 - ETA: 25s - 171ms/step





step   8/154 [>.............................]

 - ETA: 24s - 166ms/step





step  10/154 [>.............................]

 - ETA: 23s - 161ms/step





step  12/154 [=>............................]

 - ETA: 22s - 158ms/step





step  14/154 [=>............................]

 - ETA: 22s - 158ms/step





step  16/154 [==>...........................]

 - ETA: 21s - 157ms/step





step  18/154 [==>...........................]

 - ETA: 21s - 156ms/step





step  20/154 [==>...........................]

 - ETA: 20s - 155ms/step





step  22/154 [===>..........................]

 - ETA: 20s - 156ms/step





step  24/154 [===>..........................]

 - ETA: 20s - 156ms/step





step  26/154 [====>.........................]

 - ETA: 20s - 157ms/step





step  28/154 [====>.........................]

 - ETA: 19s - 157ms/step





step  30/154 [====>.........................]

 - ETA: 19s - 157ms/step





step  32/154 [=====>........................]

 - ETA: 19s - 157ms/step





step  34/154 [=====>........................]

 - ETA: 18s - 156ms/step







 - ETA: 18s - 157ms/step







 - ETA: 18s - 156ms/step







 - ETA: 17s - 155ms/step







 - ETA: 17s - 155ms/step







 - ETA: 16s - 154ms/step







 - ETA: 16s - 155ms/step







 - ETA: 16s - 155ms/step







 - ETA: 16s - 155ms/step







 - ETA: 15s - 155ms/step







 - ETA: 15s - 155ms/step







 - ETA: 15s - 155ms/step







 - ETA: 14s - 155ms/step







 - ETA: 14s - 154ms/step







 - ETA: 14s - 154ms/step







 - ETA: 13s - 154ms/step







 - ETA: 13s - 153ms/step







 - ETA: 13s - 153ms/step







 - ETA: 12s - 153ms/step







 - ETA: 12s - 152ms/step







 - ETA: 12s - 151ms/step







 - ETA: 11s - 150ms/step







 - ETA: 11s - 150ms/step







 - ETA: 11s - 149ms/step







 - ETA: 10s - 149ms/step







 - ETA: 10s - 149ms/step







 - ETA: 10s - 148ms/step







 - ETA: 9s - 149ms/step 







 - ETA: 9s - 149ms/step







 - ETA: 9s - 149ms/step







 - ETA: 8s - 149ms/step







 - ETA: 8s - 149ms/step







 - ETA: 8s - 149ms/step







 - ETA: 8s - 149ms/step







 - ETA: 7s - 149ms/step







 - ETA: 7s - 149ms/step







 - ETA: 7s - 149ms/step







 - ETA: 6s - 150ms/step







 - ETA: 6s - 150ms/step







 - ETA: 6s - 149ms/step







 - ETA: 5s - 150ms/step







 - ETA: 5s - 149ms/step







 - ETA: 5s - 149ms/step







 - ETA: 5s - 149ms/step







 - ETA: 4s - 149ms/step







 - ETA: 4s - 149ms/step







 - ETA: 4s - 149ms/step







 - ETA: 3s - 149ms/step







 - ETA: 3s - 149ms/step







 - ETA: 3s - 149ms/step







 - ETA: 2s - 149ms/step







 - ETA: 2s - 148ms/step







 - ETA: 2s - 148ms/step







 - ETA: 2s - 148ms/step







 - ETA: 1s - 149ms/step







 - ETA: 1s - 149ms/step







 - ETA: 1s - 149ms/step







 - ETA: 0s - 149ms/step







 - ETA: 0s - 149ms/step







 - ETA: 0s - 149ms/step







 - 148ms/step          


Predict samples: 9847




  0%|          | 0/154 [00:00<?, ?it/s]

100%|██████████| 154/154 [00:00<00:00, 8578.33it/s]




test_mismatched_preds.csv has been saved!


