## 1. Load dataset

In [1]:
from os.path import join
import pandas as pd


def relabeler(label):
    if label == 'entailment':
        return 0
    if label == 'neutral':
        return 1
    if label == 'contradiction':
        return 2
    else:
        None


def _load(fname, genre=None, dire='data/data135628/'):
    df = pd.read_csv(join(dire, fname))
    if genre:
        genres = ['slate', 'government', 'telephone',
                  'travel', 'fiction']

        if isinstance(genre, str):
            genre = [genre]
        elif isinstance(genre, (list, tuple)):
            assert all(isinstance(g, str) for g in genre)
        else:
            raise TypeError("genre must be a str or a list of str")

        assert all(g in genres for g in genre), f'genre must be in {genres}'
        df = df[df.isin(genre)['genre']]

    df = df.fillna(" ")
    if 'test' in fname:
        df = df[['sentence1', 'sentence2']]
    else:
        df['gold_label'] = df['gold_label'].apply(relabeler)
        df.dropna(subset=['sentence1', 'sentence2', 'gold_label'], inplace=True)
        df = df[['sentence1', 'sentence2', 'gold_label']]

    out = []
    for row in df.values:
        out.append(row.tolist())
    return out


def load_dataset(fname, genre=None, dire='data/data135628/'):
    if isinstance(fname, str):
        return _load(fname, genre, dire)

    elif isinstance(fname, (list, tuple)):
        return [_load(f, genre, dire) for f in fname]

    else:
        raise TypeError("fname must be a str or a list!")

In [2]:
train_set, dev1_set, dev2_set = load_dataset(['train.csv', 'dev_matched.csv', 'dev_mismatched.csv'])
dev_set = dev1_set + dev2_set

test1_set, test2_set = load_dataset(['test_matched.csv', 'test_mismatched.csv'])
print("Train set size:", len(train_set))
print("Train set examples:", train_set[:2])

print("\nDev set size:", len(dev_set))
print("nDev set examples:", dev_set[:2])

print("\nTest1 set size:", len(test1_set))
print("nTest2 set examples:", test1_set[:2])

print("\nTest2 set size:", len(test2_set))
print("nTest2 set examples:", test2_set[:2])

Train set size: 392702
Train set examples: [['Conceptually cream skimming has two basic dimensions - product and geography.', 'Product and geography are what make cream skimming work. ', 1], ['you know during the season and i guess at at your level uh you lose them to the next level if if they decide to recall the the parent team the Braves decide to call to recall a guy from triple A then a double A guy goes up to replace him and a single A guy goes up to replace him', 'You lose the things to the following level if the people recall.', 0]]

Dev set size: 19647
nDev set examples: [['The new rights are nice enough', 'Everyone really likes the newest benefits ', 1.0], ['This site includes a list of all award winners and a searchable database of Government Executive articles.', 'The Government Executive articles housed on the website are not able to be searched.', 2.0]]

Test1 set size: 9796
nTest2 set examples: [['That which binds together Chinese.', 'This is a shared value among Chinese

## 2. Transform text

In [3]:
from paddlenlp.datasets import MapDataset
from paddle.io import BatchSampler, DataLoader
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.transformers import BertModel as SeqClfModel
from paddlenlp.transformers import BertTokenizer as PTMTokenizer


MODEL_NAME = "bert-base-uncased"
tokenizer = PTMTokenizer.from_pretrained(MODEL_NAME)


def example_converter(example, max_seq_length, tokenizer):
    text_a, text_b, label = example
    encoded = tokenizer(text=text_a, text_pair=text_b, max_seq_len=max_seq_length)
    input_ids = encoded["input_ids"]
    token_type_ids = encoded["token_type_ids"]
    return input_ids, token_type_ids, label


def get_trans_fn(max_seq_length=128, tokenizer=tokenizer):
    return lambda ex: example_converter(ex, max_seq_length, tokenizer)


batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    Stack(dtype="int64")
    ): fn(samples)


def create_dataloader(dataset,
                      trans_fn,
                      batchify_fn,
                      test=False,
                      batch_size=128,
                      shuffle=True,
                      sampler=BatchSampler):

    if test:
        dataset = [d + [0] for d in dataset]

    if not isinstance(dataset, MapDataset):
        dataset = MapDataset(dataset)

    dataset.map(trans_fn)
    batch_sampler = sampler(dataset,
                            shuffle=shuffle,
                            batch_size=batch_size)

    dataloder = DataLoader(dataset,
                           batch_sampler=batch_sampler,
                           collate_fn=batchify_fn)

    return dataloder

[2022-04-01 03:14:31,156] [    INFO]

 - Downloading https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-uncased-vocab.txt and saved to /home/aistudio/.paddlenlp/models/bert-base-uncased




[2022-04-01 03:14:31,158] [    INFO]

 - Downloading bert-base-uncased-vocab.txt from https://paddle-hapi.bj.bcebos.com/models/bert/bert-base-uncased-vocab.txt




  0%|          | 0/227 [00:00<?, ?it/s]

100%|██████████| 227/227 [00:00<00:00, 24811.90it/s]




In [4]:
max_seq_length = 128; batch_size = 128
trans_fn = get_trans_fn(max_seq_length)
train_loader = create_dataloader(train_set, trans_fn, batchify_fn, batch_size=batch_size)
dev_loader = create_dataloader(dev_set, trans_fn, batchify_fn, batch_size=batch_size)
test1_loader = create_dataloader(test1_set, trans_fn, batchify_fn, test=True, shuffle=False, batch_size=batch_size)
test2_loader = create_dataloader(test2_set, trans_fn, batchify_fn, test=True, shuffle=False, batch_size=batch_size)

## 3. Model building

In [5]:
from paddle import nn
import paddle


class PTM(nn.Layer):

    def __init__(self, pretrained_model, dropout=0.1, num_class=3):
        super().__init__()

        self.ptm = pretrained_model
        ptm_out_dim = self.ptm.config["hidden_size"]
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(ptm_out_dim, ptm_out_dim // 2)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(ptm_out_dim // 2, num_class)

    def encoder(self, input_ids, token_type_ids):
        _, embd = self.ptm(input_ids, token_type_ids)
        embd = self.dropout(embd)
        return embd

    def forward(self, input_ids, token_type_ids):
        embd = self.encoder(input_ids, token_type_ids)
        hidden = self.relu(self.fc1(embd))
        logits = self.fc2(hidden)
        return logits

In [6]:
from paddlenlp.transformers import LinearDecayWithWarmup

epoch = 10
weight_decay = 0.0
warmup_proportion = 0.0
lr_scheduler = LinearDecayWithWarmup(5e-5, len(train_loader) * epoch,
                                         warmup_proportion)

def get_model(model):
    decay_params = [
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    ]
    optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(),
    learning_rate=lr_scheduler,
    weight_decay=weight_decay,
    apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.CrossEntropyLoss()

    model = paddle.Model(model)
    metric = paddle.metric.Accuracy()
    model.prepare(optimizer, criterion, metric)
    return model

In [7]:
model = SeqClfModel.from_pretrained(MODEL_NAME)
model = PTM(model)
model = get_model(model)

[2022-04-01 03:14:31,329] [    INFO]

 - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams and saved to /home/aistudio/.paddlenlp/models/bert-base-uncased




[2022-04-01 03:14:31,331] [    INFO]

 - Downloading bert-base-uncased.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/bert-base-uncased.pdparams




  0%|          | 0/793257 [00:00<?, ?it/s]

  1%|          | 5312/793257 [00:00<00:14, 53114.73it/s]

  2%|▏         | 12665/793257 [00:00<00:13, 57937.94it/s]

  3%|▎         | 20176/793257 [00:00<00:12, 62202.82it/s]

  3%|▎         | 27725/793257 [00:00<00:11, 65669.22it/s]

  4%|▍         | 34869/793257 [00:00<00:11, 67300.01it/s]

  5%|▌         | 42419/793257 [00:00<00:10, 69559.78it/s]

  6%|▋         | 49933/793257 [00:00<00:10, 71143.07it/s]

  7%|▋         | 57547/793257 [00:00<00:10, 72570.63it/s]

  8%|▊         | 65173/793257 [00:00<00:09, 73637.14it/s]

  9%|▉         | 72787/793257 [00:01<00:09, 74367.45it/s]

 10%|█         | 80338/793257 [00:01<00:09, 74705.90it/s]

 11%|█         | 87877/793257 [00:01<00:09, 74909.31it/s]

 12%|█▏        | 95314/793257 [00:01<00:09, 74620.18it/s]

 13%|█▎        | 102814/793257 [00:01<00:09, 74731.93it/s]

 14%|█▍        | 110342/793257 [00:01<00:09, 74894.63it/s]

 15%|█▍        | 117814/793257 [00:01<00:09, 74831.88it/s]

 16%|█▌        | 125368/793257 [00:01<00:08, 75039.63it/s]

 17%|█▋        | 132921/793257 [00:01<00:08, 75185.56it/s]

 18%|█▊        | 140463/793257 [00:01<00:08, 75253.72it/s]

 19%|█▊        | 147985/793257 [00:02<00:08, 75032.54it/s]

 20%|█▉        | 155615/793257 [00:02<00:08, 75407.18it/s]

 21%|██        | 163155/793257 [00:02<00:08, 75396.90it/s]

 22%|██▏       | 170848/793257 [00:02<00:08, 75847.46it/s]

 22%|██▏       | 178433/793257 [00:02<00:08, 75283.13it/s]

 23%|██▎       | 185963/793257 [00:02<00:08, 73763.43it/s]

 24%|██▍       | 193347/793257 [00:02<00:08, 73043.76it/s]

 25%|██▌       | 200659/793257 [00:02<00:08, 72597.62it/s]

 26%|██▌       | 207925/793257 [00:02<00:08, 72030.86it/s]

 27%|██▋       | 215133/793257 [00:02<00:08, 65413.15it/s]

 28%|██▊       | 221795/793257 [00:03<00:11, 50752.28it/s]

 29%|██▉       | 228635/793257 [00:03<00:10, 55008.47it/s]

 30%|██▉       | 236151/793257 [00:03<00:09, 59818.86it/s]

 31%|███       | 243425/793257 [00:03<00:08, 63185.73it/s]

 32%|███▏      | 250905/793257 [00:03<00:08, 66271.81it/s]

 33%|███▎      | 258373/793257 [00:03<00:07, 68586.55it/s]

 34%|███▎      | 265887/793257 [00:03<00:07, 70427.98it/s]

 34%|███▍      | 273477/793257 [00:03<00:07, 71983.76it/s]

 35%|███▌      | 280987/793257 [00:03<00:07, 72889.89it/s]

 36%|███▋      | 288402/793257 [00:04<00:06, 73262.18it/s]

 37%|███▋      | 295797/793257 [00:04<00:06, 73445.33it/s]

 38%|███▊      | 303333/793257 [00:04<00:06, 74008.99it/s]

 39%|███▉      | 310784/793257 [00:04<00:06, 74157.02it/s]

 40%|████      | 318224/793257 [00:04<00:06, 73370.51it/s]

 41%|████      | 325736/793257 [00:04<00:06, 73884.58it/s]

 42%|████▏     | 333139/793257 [00:04<00:06, 73882.78it/s]

 43%|████▎     | 340538/793257 [00:04<00:06, 73841.44it/s]

 44%|████▍     | 348128/793257 [00:04<00:05, 74446.48it/s]

 45%|████▍     | 355640/793257 [00:04<00:05, 74646.12it/s]

 46%|████▌     | 363109/793257 [00:05<00:05, 74042.63it/s]

 47%|████▋     | 370662/793257 [00:05<00:05, 74481.73it/s]

 48%|████▊     | 378225/793257 [00:05<00:05, 74820.22it/s]

 49%|████▊     | 385710/793257 [00:05<00:05, 74731.97it/s]

 50%|████▉     | 393235/793257 [00:05<00:05, 74880.73it/s]

 51%|█████     | 400725/793257 [00:05<00:05, 72883.02it/s]

 51%|█████▏    | 408069/793257 [00:05<00:05, 73046.54it/s]

 52%|█████▏    | 415635/793257 [00:05<00:05, 73810.00it/s]

 53%|█████▎    | 423243/793257 [00:05<00:04, 74473.96it/s]

 54%|█████▍    | 430698/793257 [00:05<00:04, 73466.12it/s]

 55%|█████▌    | 438053/793257 [00:06<00:04, 72388.76it/s]

 56%|█████▌    | 445571/793257 [00:06<00:04, 73200.37it/s]

 57%|█████▋    | 452993/793257 [00:06<00:04, 73500.98it/s]

 58%|█████▊    | 460479/793257 [00:06<00:04, 73901.62it/s]

 59%|█████▉    | 468054/793257 [00:06<00:04, 74445.25it/s]

 60%|█████▉    | 475702/793257 [00:06<00:04, 75043.13it/s]

 61%|██████    | 483211/793257 [00:06<00:04, 75025.51it/s]

 62%|██████▏   | 490717/793257 [00:06<00:04, 74912.89it/s]

 63%|██████▎   | 498229/793257 [00:06<00:03, 74974.60it/s]

 64%|██████▍   | 505728/793257 [00:06<00:03, 74956.19it/s]

 65%|██████▍   | 513225/793257 [00:07<00:03, 74693.53it/s]

 66%|██████▌   | 520696/793257 [00:07<00:03, 72750.90it/s]

 67%|██████▋   | 528033/793257 [00:07<00:03, 72934.72it/s]

 68%|██████▊   | 535461/793257 [00:07<00:03, 73331.52it/s]

 68%|██████▊   | 542907/793257 [00:07<00:03, 73664.39it/s]

 69%|██████▉   | 550279/793257 [00:07<00:03, 67806.68it/s]

 70%|███████   | 557155/793257 [00:07<00:03, 61785.38it/s]

 71%|███████   | 563504/793257 [00:07<00:03, 59454.43it/s]

 72%|███████▏  | 569588/793257 [00:07<00:03, 58138.73it/s]

 73%|███████▎  | 576747/793257 [00:08<00:03, 61610.98it/s]

 74%|███████▎  | 584076/793257 [00:08<00:03, 64704.05it/s]

 75%|███████▍  | 591651/793257 [00:08<00:02, 67663.65it/s]

 76%|███████▌  | 599252/793257 [00:08<00:02, 69966.62it/s]

 77%|███████▋  | 606857/793257 [00:08<00:02, 71686.08it/s]

 77%|███████▋  | 614291/793257 [00:08<00:02, 72460.90it/s]

 78%|███████▊  | 621861/793257 [00:08<00:02, 73400.37it/s]

 79%|███████▉  | 629248/793257 [00:08<00:02, 68048.60it/s]

 80%|████████  | 636164/793257 [00:08<00:02, 66180.41it/s]

 81%|████████  | 642871/793257 [00:09<00:02, 61325.08it/s]

 82%|████████▏ | 649137/793257 [00:09<00:02, 57999.23it/s]

 83%|████████▎ | 655068/793257 [00:09<00:02, 49881.47it/s]

 83%|████████▎ | 660348/793257 [00:09<00:02, 47265.84it/s]

 84%|████████▍ | 665307/793257 [00:09<00:02, 45578.77it/s]

 84%|████████▍ | 670041/793257 [00:09<00:03, 39947.80it/s]

 85%|████████▌ | 674285/793257 [00:09<00:03, 38666.31it/s]

 86%|████████▌ | 681329/793257 [00:09<00:02, 44717.21it/s]

 87%|████████▋ | 688507/793257 [00:10<00:02, 50419.66it/s]

 88%|████████▊ | 695918/793257 [00:10<00:01, 55767.41it/s]

 89%|████████▊ | 703241/793257 [00:10<00:01, 60063.64it/s]

 90%|████████▉ | 710643/793257 [00:10<00:01, 63661.65it/s]

 91%|█████████ | 718010/793257 [00:10<00:01, 66365.49it/s]

 91%|█████████▏| 725405/793257 [00:10<00:00, 68470.82it/s]

 92%|█████████▏| 732703/793257 [00:10<00:00, 69762.45it/s]

 93%|█████████▎| 740230/793257 [00:10<00:00, 71326.32it/s]

 94%|█████████▍| 747831/793257 [00:10<00:00, 72669.45it/s]

 95%|█████████▌| 755195/793257 [00:10<00:00, 72676.28it/s]

 96%|█████████▌| 762654/793257 [00:11<00:00, 73238.38it/s]

 97%|█████████▋| 770154/793257 [00:11<00:00, 73756.11it/s]

 98%|█████████▊| 777596/793257 [00:11<00:00, 73953.63it/s]

 99%|█████████▉| 785135/793257 [00:11<00:00, 74377.21it/s]

100%|█████████▉| 792595/793257 [00:11<00:00, 74442.55it/s]

100%|██████████| 793257/793257 [00:11<00:00, 69262.06it/s]




W0401 03:14:42.866938   257 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0401 03:14:42.871160   257 device_context.cc:465] device: 0, cuDNN Version: 7.6.


[2022-04-01 03:14:53,557] [    INFO]

 - Weights from pretrained model not used in BertModel: ['cls.predictions.decoder_weight', 'cls.predictions.decoder_bias', 'cls.predictions.transform.weight', 'cls.predictions.transform.bias', 'cls.predictions.layer_norm.weight', 'cls.predictions.layer_norm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']




## 4. Model training

In [8]:
from paddle.callbacks import EarlyStopping


earlystop = EarlyStopping(patience=3)
model.fit(train_loader, dev_loader, epochs=epoch, verbose=2, log_freq=100, callbacks=[earlystop])

The loss value printed in the log is the current step, and the metric is the average value of previous steps.




Epoch 1/10




step  100/3068 - loss: 0.7838 - acc: 0.5599 - 741ms/step


step  200/3068 - loss: 0.5774 - acc: 0.6351 - 742ms/step


step  300/3068 - loss: 0.5790 - acc: 0.6703 - 741ms/step


step  400/3068 - loss: 0.5250 - acc: 0.6931 - 740ms/step


step  500/3068 - loss: 0.6629 - acc: 0.7078 - 738ms/step


step  600/3068 - loss: 0.4875 - acc: 0.7189 - 740ms/step


step  700/3068 - loss: 0.6679 - acc: 0.7274 - 741ms/step


step  800/3068 - loss: 0.4296 - acc: 0.7343 - 741ms/step


step  900/3068 - loss: 0.4747 - acc: 0.7409 - 742ms/step


step 1000/3068 - loss: 0.5366 - acc: 0.7460 - 743ms/step


step 1100/3068 - loss: 0.6329 - acc: 0.7509 - 744ms/step


step 1200/3068 - loss: 0.5263 - acc: 0.7545 - 744ms/step


step 1300/3068 - loss: 0.6339 - acc: 0.7577 - 742ms/step


step 1400/3068 - loss: 0.4434 - acc: 0.7609 - 743ms/step


step 1500/3068 - loss: 0.5852 - acc: 0.7639 - 744ms/step


step 1600/3068 - loss: 0.6341 - acc: 0.7660 - 744ms/step


step 1700/3068 - loss: 0.4575 - acc: 0.7684 - 744ms/step


step 1800/3068 - loss: 0.5728 - acc: 0.7703 - 743ms/step


step 1900/3068 - loss: 0.5242 - acc: 0.7722 - 743ms/step


step 2000/3068 - loss: 0.5130 - acc: 0.7742 - 743ms/step


step 2100/3068 - loss: 0.4959 - acc: 0.7759 - 743ms/step


step 2200/3068 - loss: 0.3243 - acc: 0.7776 - 743ms/step


step 2300/3068 - loss: 0.4933 - acc: 0.7791 - 743ms/step


step 2400/3068 - loss: 0.5193 - acc: 0.7807 - 743ms/step


step 2500/3068 - loss: 0.5234 - acc: 0.7821 - 743ms/step


step 2600/3068 - loss: 0.5142 - acc: 0.7832 - 743ms/step


step 2700/3068 - loss: 0.5594 - acc: 0.7840 - 743ms/step


step 2800/3068 - loss: 0.5498 - acc: 0.7852 - 744ms/step


step 2900/3068 - loss: 0.4518 - acc: 0.7862 - 744ms/step


step 3000/3068 - loss: 0.3956 - acc: 0.7875 - 744ms/step


step 3068/3068 - loss: 0.4777 - acc: 0.7882 - 743ms/step


Eval begin...




step 100/154 - loss: 0.4897 - acc: 0.8319 - 311ms/step


step 154/154 - loss: 0.2996 - acc: 0.8317 - 304ms/step


Eval samples: 19647




Epoch 2/10




step  100/3068 - loss: 0.4709 - acc: 0.8742 - 743ms/step


step  200/3068 - loss: 0.3352 - acc: 0.8703 - 749ms/step


step  300/3068 - loss: 0.3576 - acc: 0.8698 - 747ms/step


step  400/3068 - loss: 0.2614 - acc: 0.8706 - 747ms/step


step  500/3068 - loss: 0.4212 - acc: 0.8718 - 745ms/step


step  600/3068 - loss: 0.3960 - acc: 0.8712 - 745ms/step


step  700/3068 - loss: 0.3967 - acc: 0.8709 - 744ms/step


step  800/3068 - loss: 0.4181 - acc: 0.8703 - 744ms/step


step  900/3068 - loss: 0.4204 - acc: 0.8700 - 746ms/step


step 1000/3068 - loss: 0.3436 - acc: 0.8700 - 745ms/step


step 1100/3068 - loss: 0.5002 - acc: 0.8698 - 745ms/step


step 1200/3068 - loss: 0.4327 - acc: 0.8700 - 745ms/step


step 1300/3068 - loss: 0.3308 - acc: 0.8694 - 745ms/step


step 1400/3068 - loss: 0.4334 - acc: 0.8692 - 746ms/step


step 1500/3068 - loss: 0.3106 - acc: 0.8692 - 745ms/step


step 1600/3068 - loss: 0.3062 - acc: 0.8692 - 744ms/step


step 1700/3068 - loss: 0.3123 - acc: 0.8692 - 744ms/step


step 1800/3068 - loss: 0.3616 - acc: 0.8689 - 744ms/step


step 1900/3068 - loss: 0.3642 - acc: 0.8686 - 744ms/step


step 2000/3068 - loss: 0.4625 - acc: 0.8685 - 744ms/step


step 2100/3068 - loss: 0.3224 - acc: 0.8684 - 743ms/step


step 2200/3068 - loss: 0.3532 - acc: 0.8684 - 743ms/step


step 2300/3068 - loss: 0.3970 - acc: 0.8682 - 742ms/step


step 2400/3068 - loss: 0.3193 - acc: 0.8682 - 743ms/step


step 2500/3068 - loss: 0.2957 - acc: 0.8680 - 743ms/step


step 2600/3068 - loss: 0.3881 - acc: 0.8679 - 743ms/step


step 2700/3068 - loss: 0.3674 - acc: 0.8680 - 743ms/step


step 2800/3068 - loss: 0.3816 - acc: 0.8679 - 742ms/step


step 2900/3068 - loss: 0.3445 - acc: 0.8680 - 743ms/step


step 3000/3068 - loss: 0.3494 - acc: 0.8678 - 744ms/step


step 3068/3068 - loss: 0.3549 - acc: 0.8679 - 744ms/step


Eval begin...




step 100/154 - loss: 0.4668 - acc: 0.8430 - 307ms/step


step 154/154 - loss: 0.3350 - acc: 0.8409 - 306ms/step


Eval samples: 19647




Epoch 3/10




step  100/3068 - loss: 0.1537 - acc: 0.9210 - 732ms/step


step  200/3068 - loss: 0.1959 - acc: 0.9220 - 738ms/step


step  300/3068 - loss: 0.2540 - acc: 0.9221 - 737ms/step


step  400/3068 - loss: 0.2166 - acc: 0.9223 - 739ms/step


step  500/3068 - loss: 0.2577 - acc: 0.9214 - 741ms/step


step  600/3068 - loss: 0.2583 - acc: 0.9206 - 743ms/step


step  700/3068 - loss: 0.1490 - acc: 0.9206 - 744ms/step


step  800/3068 - loss: 0.1987 - acc: 0.9203 - 743ms/step


step  900/3068 - loss: 0.2329 - acc: 0.9198 - 745ms/step


step 1000/3068 - loss: 0.3312 - acc: 0.9193 - 746ms/step


step 1100/3068 - loss: 0.1789 - acc: 0.9187 - 745ms/step


step 1200/3068 - loss: 0.1393 - acc: 0.9186 - 745ms/step


step 1300/3068 - loss: 0.2628 - acc: 0.9185 - 745ms/step


step 1400/3068 - loss: 0.2384 - acc: 0.9183 - 746ms/step


step 1500/3068 - loss: 0.2655 - acc: 0.9184 - 746ms/step


step 1600/3068 - loss: 0.2076 - acc: 0.9185 - 745ms/step


step 1700/3068 - loss: 0.2253 - acc: 0.9186 - 745ms/step


step 1800/3068 - loss: 0.2975 - acc: 0.9183 - 745ms/step


step 1900/3068 - loss: 0.2470 - acc: 0.9181 - 744ms/step


step 2000/3068 - loss: 0.1712 - acc: 0.9179 - 744ms/step


step 2100/3068 - loss: 0.1915 - acc: 0.9176 - 745ms/step


step 2200/3068 - loss: 0.2746 - acc: 0.9175 - 744ms/step


step 2300/3068 - loss: 0.1378 - acc: 0.9174 - 745ms/step


step 2400/3068 - loss: 0.2144 - acc: 0.9171 - 744ms/step


step 2500/3068 - loss: 0.1921 - acc: 0.9170 - 744ms/step


step 2600/3068 - loss: 0.2603 - acc: 0.9169 - 745ms/step


step 2700/3068 - loss: 0.2007 - acc: 0.9167 - 744ms/step


step 2800/3068 - loss: 0.3456 - acc: 0.9165 - 745ms/step


step 2900/3068 - loss: 0.2530 - acc: 0.9164 - 745ms/step


step 3000/3068 - loss: 0.2102 - acc: 0.9164 - 745ms/step


step 3068/3068 - loss: 0.1773 - acc: 0.9164 - 745ms/step


Eval begin...




step 100/154 - loss: 0.3725 - acc: 0.8345 - 310ms/step


step 154/154 - loss: 0.6539 - acc: 0.8401 - 303ms/step


Eval samples: 19647




Epoch 4/10




step  100/3068 - loss: 0.1205 - acc: 0.9562 - 743ms/step


step  200/3068 - loss: 0.1381 - acc: 0.9565 - 749ms/step


step  300/3068 - loss: 0.0722 - acc: 0.9549 - 744ms/step


step  400/3068 - loss: 0.1292 - acc: 0.9545 - 744ms/step


step  500/3068 - loss: 0.0695 - acc: 0.9553 - 740ms/step


step  600/3068 - loss: 0.1514 - acc: 0.9547 - 741ms/step


step  700/3068 - loss: 0.1944 - acc: 0.9538 - 742ms/step


step  800/3068 - loss: 0.0716 - acc: 0.9534 - 743ms/step


step  900/3068 - loss: 0.0873 - acc: 0.9533 - 742ms/step


step 1000/3068 - loss: 0.1902 - acc: 0.9528 - 743ms/step


step 1100/3068 - loss: 0.1102 - acc: 0.9523 - 743ms/step


step 1200/3068 - loss: 0.1941 - acc: 0.9520 - 744ms/step


step 1300/3068 - loss: 0.1534 - acc: 0.9518 - 744ms/step


step 1400/3068 - loss: 0.1946 - acc: 0.9513 - 744ms/step


step 1500/3068 - loss: 0.1228 - acc: 0.9510 - 744ms/step


step 1600/3068 - loss: 0.1300 - acc: 0.9510 - 745ms/step


step 1700/3068 - loss: 0.1397 - acc: 0.9510 - 746ms/step


step 1800/3068 - loss: 0.1945 - acc: 0.9508 - 745ms/step


step 1900/3068 - loss: 0.1754 - acc: 0.9504 - 746ms/step


step 2000/3068 - loss: 0.1699 - acc: 0.9501 - 746ms/step


step 2100/3068 - loss: 0.1964 - acc: 0.9500 - 745ms/step


step 2200/3068 - loss: 0.0838 - acc: 0.9498 - 744ms/step


step 2300/3068 - loss: 0.1931 - acc: 0.9497 - 744ms/step


step 2400/3068 - loss: 0.1712 - acc: 0.9495 - 745ms/step


step 2500/3068 - loss: 0.1660 - acc: 0.9493 - 745ms/step


step 2600/3068 - loss: 0.1941 - acc: 0.9494 - 744ms/step


step 2700/3068 - loss: 0.1653 - acc: 0.9492 - 744ms/step


step 2800/3068 - loss: 0.1279 - acc: 0.9491 - 744ms/step


step 2900/3068 - loss: 0.1054 - acc: 0.9491 - 744ms/step


step 3000/3068 - loss: 0.1829 - acc: 0.9490 - 743ms/step


step 3068/3068 - loss: 0.1874 - acc: 0.9489 - 744ms/step


Eval begin...




step 100/154 - loss: 0.4570 - acc: 0.8334 - 305ms/step


step 154/154 - loss: 0.9378 - acc: 0.8313 - 302ms/step


Eval samples: 19647




Epoch 4: Early stopping.




## 5. Prediction

In [9]:
import paddle.nn.functional as F
from tqdm import tqdm


def predict(test_loader, id_inpath, out):
    predictions = []
    logits = model.predict(test_loader)

    for batch in tqdm(logits[0]):
        batch = paddle.to_tensor(batch)
        probs = F.softmax(batch, axis=1)
        preds = paddle.argmax(probs, axis=1).numpy().tolist()
        predictions.extend(preds)

    df = pd.read_csv(id_inpath)
    ids = df['pairID'].tolist()
    label_map = {0: 'entailment', 1: 'neutral', 2: 'contradiction'}
    results = [[idx, label_map[p]] for idx, p in zip(ids, predictions)]
    columns = ['pairID', 'gold_label']
    pd.DataFrame(results, columns=columns).to_csv(out, index=False)
    print(f"{out} has been saved!")

In [10]:
predict(test1_loader, 'data/data135628/test_matched.csv', 'test_matched_preds.csv')

Predict begin...






step  2/77 [..............................]

 - ETA: 32s - 430ms/step





step  4/77 [>.............................]

 - ETA: 25s - 349ms/step





step  6/77 [=>............................]

 - ETA: 23s - 324ms/step





step  8/77 [==>...........................]

 - ETA: 22s - 321ms/step





step 10/77 [==>...........................]

 - ETA: 21s - 317ms/step





step 12/77 [===>..........................]

 - ETA: 20s - 315ms/step





step 14/77 [====>.........................]

 - ETA: 19s - 317ms/step





step 16/77 [=====>........................]

 - ETA: 19s - 319ms/step







 - ETA: 19s - 323ms/step







 - ETA: 18s - 325ms/step







 - ETA: 17s - 325ms/step







 - ETA: 17s - 323ms/step







 - ETA: 16s - 319ms/step







 - ETA: 15s - 318ms/step







 - ETA: 14s - 315ms/step







 - ETA: 14s - 317ms/step







 - ETA: 13s - 316ms/step







 - ETA: 12s - 314ms/step







 - ETA: 12s - 313ms/step







 - ETA: 11s - 311ms/step







 - ETA: 10s - 308ms/step







 - ETA: 10s - 309ms/step







 - ETA: 9s - 309ms/step 







 - ETA: 8s - 310ms/step







 - ETA: 8s - 311ms/step







 - ETA: 7s - 311ms/step







 - ETA: 7s - 310ms/step







 - ETA: 6s - 308ms/step







 - ETA: 5s - 309ms/step







 - ETA: 5s - 307ms/step







 - ETA: 4s - 307ms/step







 - ETA: 3s - 306ms/step







 - ETA: 3s - 307ms/step







 - ETA: 2s - 306ms/step







 - ETA: 2s - 305ms/step







 - ETA: 1s - 305ms/step







 - ETA: 0s - 304ms/step







 - ETA: 0s - 303ms/step







 - 300ms/step          


Predict samples: 9796




  0%|          | 0/77 [00:00<?, ?it/s]

100%|██████████| 77/77 [00:00<00:00, 8171.07it/s]




test_matched_preds.csv has been saved!




In [11]:
predict(test2_loader, 'data/data135628/test_mismatched.csv', 'test_mismatched_preds.csv')

Predict begin...






step  2/77 [..............................]

 - ETA: 33s - 452ms/step





step  4/77 [>.............................]

 - ETA: 26s - 366ms/step





step  6/77 [=>............................]

 - ETA: 24s - 344ms/step





step  8/77 [==>...........................]

 - ETA: 23s - 344ms/step





step 10/77 [==>...........................]

 - ETA: 22s - 334ms/step





step 12/77 [===>..........................]

 - ETA: 21s - 333ms/step





step 14/77 [====>.........................]

 - ETA: 20s - 329ms/step





step 16/77 [=====>........................]

 - ETA: 19s - 328ms/step







 - ETA: 19s - 325ms/step







 - ETA: 18s - 319ms/step







 - ETA: 17s - 317ms/step







 - ETA: 16s - 319ms/step







 - ETA: 16s - 317ms/step







 - ETA: 15s - 319ms/step







 - ETA: 14s - 317ms/step







 - ETA: 14s - 317ms/step







 - ETA: 13s - 315ms/step







 - ETA: 12s - 311ms/step







 - ETA: 11s - 307ms/step







 - ETA: 11s - 306ms/step







 - ETA: 10s - 304ms/step







 - ETA: 10s - 304ms/step







 - ETA: 9s - 305ms/step 







 - ETA: 8s - 304ms/step







 - ETA: 8s - 304ms/step







 - ETA: 7s - 305ms/step







 - ETA: 7s - 304ms/step







 - ETA: 6s - 304ms/step







 - ETA: 5s - 303ms/step







 - ETA: 5s - 303ms/step







 - ETA: 4s - 303ms/step







 - ETA: 3s - 304ms/step







 - ETA: 3s - 303ms/step







 - ETA: 2s - 302ms/step







 - ETA: 2s - 301ms/step







 - ETA: 1s - 302ms/step







 - ETA: 0s - 301ms/step







 - ETA: 0s - 299ms/step







 - 298ms/step          


Predict samples: 9847




  0%|          | 0/77 [00:00<?, ?it/s]

100%|██████████| 77/77 [00:00<00:00, 8004.60it/s]




test_mismatched_preds.csv has been saved!


