In [1]:
# Extract the PAWS-X (zh) archive into /home/aistudio/data/.
# The equivalent commands for lcqmc.zip and bq_corpus.zip are kept commented out;
# only paws-x-zh is unpacked by this cell.
# !unzip /home/aistudio/data/data78992/lcqmc.zip -d /home/aistudio/data/
!unzip /home/aistudio/data/data78992/paws-x-zh.zip -d /home/aistudio/data/
# !unzip /home/aistudio/data/data78992/bq_corpus.zip -d /home/aistudio/data/

Archive:  /home/aistudio/data/data78992/paws-x-zh.zip


   creating: /home/aistudio/data/paws-x-zh/
  inflating: /home/aistudio/data/paws-x-zh/train.tsv  


   creating: /home/aistudio/data/__MACOSX/
   creating: /home/aistudio/data/__MACOSX/paws-x-zh/
  inflating: /home/aistudio/data/__MACOSX/paws-x-zh/._train.tsv  
  inflating: /home/aistudio/data/paws-x-zh/dev.tsv  
  inflating: /home/aistudio/data/__MACOSX/paws-x-zh/._dev.tsv  
  inflating: /home/aistudio/data/paws-x-zh/License.pdf  
  inflating: /home/aistudio/data/__MACOSX/paws-x-zh/._License.pdf  
  inflating: /home/aistudio/data/paws-x-zh/test.tsv  


  inflating: /home/aistudio/data/__MACOSX/paws-x-zh/._test.tsv  
  inflating: /home/aistudio/data/__MACOSX/._paws-x-zh  


## 1. Load dataset

In [2]:
def load_dataset(fpath, num_row_skip=0):
    """Load one or several tab-separated sentence-pair files into memory.

    Every line is stripped and split on tabs. A file whose path contains the
    substring ``"test"`` is treated as unlabeled and yields
    ``(text_a, text_b)`` tuples; blank lines in such a file are skipped. Any
    other file yields ``(text_a, text_b, int(label))`` tuples, and lines that
    do not have exactly three fields are silently skipped.

    Args:
        fpath: A path string, or a list/tuple of path strings.
        num_row_skip: Number of leading lines to discard in each file
            (e.g. a header row). If a file has fewer lines than this, the
            result for that file is simply empty.

    Returns:
        A list of tuples when ``fpath`` is a string; a list of such lists
        (one per path, in order) when ``fpath`` is a list or tuple.

    Raises:
        TypeError: If ``fpath`` is neither a str nor a list/tuple.
    """

    def read(fp):
        # Explicit UTF-8 so decoding does not depend on the machine's locale,
        # and a with-block so the handle is closed once the generator is
        # exhausted (the original never closed the file).
        with open(fp, encoding="utf-8") as data:
            # next(..., None) keeps a short file from raising StopIteration
            # inside this generator, which Python turns into RuntimeError.
            for _ in range(num_row_skip):
                next(data, None)

            if "test" in fp:
                for line in data:
                    fields = line.strip().split('\t')
                    # A blank line splits to [''] and would otherwise crash
                    # on fields[1].
                    if fields == ['']:
                        continue
                    yield fields[0], fields[1]
            else:
                for line in data:
                    fields = line.strip().split('\t')
                    if len(fields) == 3:
                        yield fields[0], fields[1], int(fields[2])

    if isinstance(fpath, str):
        return list(read(fpath))
    elif isinstance(fpath, (list, tuple)):
        return [list(read(fp)) for fp in fpath]
    else:
        raise TypeError("Input fpath must be a str or a list/tuple of str")

In [3]:
# Read the three PAWS-X (zh) splits in one call; the result order follows the
# order of the paths given here.
train_set, dev_set, test_set = load_dataset([
    './data/paws-x-zh/train.tsv',
    './data/paws-x-zh/dev.tsv',
    './data/paws-x-zh/test.tsv',
])
# len(train_set), len(dev_set), len(test_set)
# The dev examples are appended to the training data, so no split is held out
# for validation from this point on.
train_set = train_set + dev_set

## 2. Transform text

In [4]:
from paddlenlp.datasets import MapDataset
from paddle.io import BatchSampler, DataLoader
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.transformers import RobertaModel as SeqClfModel
from paddlenlp.transformers import RobertaTokenizer as PTMTokenizer
import numpy as np


# Identifier of the pretrained checkpoint; it is passed to from_pretrained()
# for the tokenizer here and reused when the backbone is loaded.
MODEL_NAME = "roberta-wwm-ext-large"
# Module-level tokenizer; it is read by get_trans_fn's default argument and by
# batchify_fn (for the padding ids).
tokenizer = PTMTokenizer.from_pretrained(MODEL_NAME)


def example_converter(example, max_seq_length, tokenizer):
    """Turn one ``(text_a, text_b, label)`` example into model inputs.

    Args:
        example: A 3-item sequence ``(text_a, text_b, label)``.
        max_seq_length: Forwarded to the tokenizer as ``max_seq_len``.
        tokenizer: Callable accepting ``text``, ``text_pair`` and
            ``max_seq_len`` keywords and returning a mapping that has
            ``"input_ids"`` and ``"token_type_ids"`` entries.

    Returns:
        ``(input_ids, token_type_ids, label)`` where the first two come
        straight from the tokenizer output and ``label`` is a one-element
        int64 numpy array.
    """
    first_text, second_text, target = example
    features = tokenizer(text=first_text,
                         text_pair=second_text,
                         max_seq_len=max_seq_length)
    return (
        features["input_ids"],
        features["token_type_ids"],
        np.array([target], dtype="int64"),
    )


def get_trans_fn(max_seq_length=128, tokenizer=tokenizer):
    """Build a single-argument transform around ``example_converter``.

    Args:
        max_seq_length: Sequence-length limit handed to ``example_converter``.
        tokenizer: Tokenizer handed to ``example_converter``. The default is
            the module-level ``tokenizer`` as it existed when this function
            was defined.

    Returns:
        A callable taking one example and returning
        ``example_converter(example, max_seq_length, tokenizer)``.
    """

    def _convert(ex):
        return example_converter(ex, max_seq_length, tokenizer)

    return _convert


def batchify_fn(samples,
                fn=Tuple(
                    Pad(axis=0, pad_val=tokenizer.pad_token_id),
                    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
                    Stack(dtype="int64"),
                )):
    """Collate a list of samples into a batch.

    The default ``fn`` is built once, at definition time: the first field of
    each sample goes through ``Pad`` with the tokenizer's pad token id, the
    second through ``Pad`` with the pad token-type id, and the third through
    ``Stack`` as int64. Presumably these fields are the
    ``(input_ids, token_type_ids, label)`` triple produced by
    ``example_converter`` -- confirm against the caller.
    """
    return fn(samples)


def create_dataloader(dataset,
                      trans_fn,
                      batchify_fn,
                      test=False,
                      batch_size=128,
                      shuffle=True,
                      sampler=BatchSampler):
    """Wrap ``dataset`` in a ``DataLoader`` with the given transform and collate.

    Args:
        dataset: Sequence of example tuples, or an existing ``MapDataset``.
        trans_fn: Per-example transform registered through ``dataset.map``.
        batchify_fn: Collate function passed to the ``DataLoader``.
        test: When true, a dummy label ``0`` is appended to every example so
            unlabeled data has the same tuple length as labeled data.
        batch_size: Batch size passed to ``sampler``.
        shuffle: Shuffle flag passed to ``sampler``. It is NOT switched off
            automatically for ``test=True``; callers that need ordered
            predictions must pass ``shuffle=False`` themselves.
        sampler: Batch-sampler class, called as
            ``sampler(dataset, shuffle=..., batch_size=...)``.

    Returns:
        The constructed ``DataLoader``.
    """
    if test:
        # Tuple concatenation: each example is expected to be a tuple here.
        dataset = [sample + (0,) for sample in dataset]

    mapped = dataset if isinstance(dataset, MapDataset) else MapDataset(dataset)
    # The return value of map() is ignored, so this relies on map() updating
    # the dataset object in place. A MapDataset passed in by the caller is
    # therefore modified as well.
    mapped.map(trans_fn)

    return DataLoader(
        mapped,
        batch_sampler=sampler(mapped, shuffle=shuffle, batch_size=batch_size),
        collate_fn=batchify_fn,
    )

[2022-03-29 01:32:57,059] [    INFO]

 - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/vocab.txt and saved to /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large




[2022-03-29 01:32:57,062] [    INFO]

 - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/vocab.txt




  0%|          | 0/107 [00:00<?, ?it/s]

100%|██████████| 107/107 [00:00<00:00, 39015.09it/s]




In [5]:
# Tokenization length limit and the batch size shared by both loaders.
max_seq_length = 128
batch_size = 32

trans_fn = get_trans_fn(max_seq_length)

# Training loader keeps create_dataloader's default shuffle=True.
train_loader = create_dataloader(
    train_set, trans_fn, batchify_fn, batch_size=batch_size)
# dev_loader = create_dataloader(dev_set, trans_fn, batchify_fn, batch_size=batch_size)

# Test loader: test=True adds a dummy label to each example, and shuffle=False
# keeps the batches in file order so predictions line up with row indices.
test_loader = create_dataloader(
    test_set, trans_fn, batchify_fn,
    shuffle=False, test=True, batch_size=batch_size)

## 3. Model building

In [6]:
from paddle import nn
import paddle


class PTM(nn.Layer):
    """Pretrained encoder followed by a two-layer classification head.

    The head is ``Linear(hidden, hidden // 2) -> ReLU -> Linear(hidden // 2,
    num_class)``, where ``hidden`` is read from the pretrained model's
    ``config["hidden_size"]``.
    """

    def __init__(self, pretrained_model, dropout=0.1, num_class=2):
        """Store the backbone and build the head.

        Args:
            pretrained_model: Backbone exposing ``config["hidden_size"]`` and
                returning a 2-tuple when called with
                ``(input_ids, token_type_ids)``.
            dropout: Dropout probability applied to the encoder output.
            num_class: Number of output classes.
        """
        super().__init__()

        self.ptm = pretrained_model
        hidden_size = self.ptm.config["hidden_size"]
        half_size = hidden_size // 2

        # Sub-layers are created in the same order as before so that any
        # auto-generated parameter names stay the same.
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_size, half_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(half_size, num_class)

    def encoder(self, input_ids, token_type_ids):
        """Run the backbone and return its second output after dropout."""
        # The backbone must return exactly two values. The second one is
        # presumably the pooled sentence representation -- confirm against
        # the RobertaModel documentation.
        _unused, pooled = self.ptm(input_ids, token_type_ids)
        return self.dropout(pooled)

    def forward(self, input_ids, token_type_ids):
        """Return unnormalized class scores (logits) for the inputs."""
        features = self.encoder(input_ids, token_type_ids)
        return self.fc2(self.relu(self.fc1(features)))

In [7]:
from paddlenlp.transformers import LinearDecayWithWarmup

# Training hyper-parameters; get_model() reads weight_decay and lr_scheduler
# from module scope, and the fit call reads epoch.
epoch = 4
weight_decay = 0.001
warmup_proportion = 0.1

# Positional arguments: 4e-5, then len(train_loader) * epoch (number of
# batches per epoch times number of epochs), then the warmup proportion.
lr_scheduler = LinearDecayWithWarmup(
    4e-5,
    len(train_loader) * epoch,
    warmup_proportion,
)

def get_model(model):
    """Wrap a network in ``paddle.Model`` ready for ``fit`` / ``predict``.

    Builds an AdamW optimizer, a cross-entropy loss and an accuracy metric,
    then calls ``prepare`` on the wrapper. The learning-rate schedule and the
    weight-decay coefficient are read from the module-level ``lr_scheduler``
    and ``weight_decay``.

    Args:
        model: Network exposing ``named_parameters()`` and ``parameters()``.

    Returns:
        The prepared ``paddle.Model`` wrapper around ``model``.
    """
    # Weight decay is applied only to parameters whose structured name
    # contains neither "bias" nor "norm". A set is used because the callback
    # below is a membership test that presumably runs once per parameter; the
    # original list made that lookup linear each time.
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        parameters=model.parameters(),
        learning_rate=lr_scheduler,
        weight_decay=weight_decay,
        apply_decay_param_fun=lambda x: x in decay_params)

    criterion = paddle.nn.CrossEntropyLoss()

    # Separate name for the wrapper so the argument is not shadowed.
    wrapped = paddle.Model(model)
    metric = paddle.metric.Accuracy()
    wrapped.prepare(optimizer, criterion, metric)
    return wrapped

In [8]:
# Load the pretrained backbone, put the classification head on top of it and
# hand the result to get_model(), which returns the prepared paddle.Model.
ptm = SeqClfModel.from_pretrained(MODEL_NAME)
model = get_model(PTM(ptm))

[2022-03-29 01:32:57,233] [    INFO]

 - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/roberta_chn_large.pdparams and saved to /home/aistudio/.paddlenlp/models/roberta-wwm-ext-large




[2022-03-29 01:32:57,235] [    INFO]

 - Downloading roberta_chn_large.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/roberta_large/roberta_chn_large.pdparams




  0%|          | 0/1271615 [00:00<?, ?it/s]

  0%|          | 275/1271615 [00:00<42:57, 493.22it/s]

  0%|          | 5382/1271615 [00:00<30:04, 701.70it/s]

  1%|          | 11569/1271615 [00:00<21:03, 997.57it/s]

  1%|▏         | 17799/1271615 [00:00<14:45, 1415.38it/s]

  2%|▏         | 24137/1271615 [00:00<10:22, 2002.80it/s]

  2%|▏         | 30411/1271615 [00:01<07:19, 2822.53it/s]

  3%|▎         | 36931/1271615 [00:01<05:11, 3958.71it/s]

  3%|▎         | 43223/1271615 [00:01<03:43, 5506.80it/s]

  4%|▍         | 49485/1271615 [00:01<02:41, 7581.14it/s]

  4%|▍         | 55699/1271615 [00:01<01:58, 10292.02it/s]

  5%|▍         | 61946/1271615 [00:01<01:28, 13733.18it/s]

  5%|▌         | 68524/1271615 [00:01<01:06, 18007.57it/s]

  6%|▌         | 74899/1271615 [00:01<00:52, 22946.30it/s]

  6%|▋         | 81171/1271615 [00:01<00:42, 27868.87it/s]

  7%|▋         | 88316/1271615 [00:01<00:34, 34110.63it/s]

  8%|▊         | 95851/1271615 [00:02<00:28, 40811.28it/s]

  8%|▊         | 103150/1271615 [00:02<00:24, 47031.26it/s]

  9%|▊         | 110695/1271615 [00:02<00:21, 53021.58it/s]

  9%|▉         | 117790/1271615 [00:02<00:20, 56940.28it/s]

 10%|▉         | 124821/1271615 [00:02<00:19, 59212.95it/s]

 10%|█         | 131702/1271615 [00:02<00:18, 61325.56it/s]

 11%|█         | 138525/1271615 [00:02<00:18, 62726.08it/s]

 11%|█▏        | 145289/1271615 [00:02<00:17, 64089.04it/s]

 12%|█▏        | 152819/1271615 [00:02<00:16, 67084.93it/s]

 13%|█▎        | 160481/1271615 [00:02<00:15, 69684.12it/s]

 13%|█▎        | 168114/1271615 [00:03<00:15, 71551.50it/s]

 14%|█▍        | 175573/1271615 [00:03<00:15, 72435.05it/s]

 14%|█▍        | 182959/1271615 [00:03<00:14, 72856.03it/s]

 15%|█▍        | 190451/1271615 [00:03<00:14, 73461.50it/s]

 16%|█▌        | 197858/1271615 [00:03<00:15, 69872.73it/s]

 16%|█▌        | 204925/1271615 [00:03<00:15, 68692.20it/s]

 17%|█▋        | 211855/1271615 [00:03<00:15, 66716.29it/s]

 17%|█▋        | 218582/1271615 [00:03<00:16, 65573.46it/s]

 18%|█▊        | 226019/1271615 [00:03<00:15, 67978.48it/s]

 18%|█▊        | 233528/1271615 [00:04<00:14, 69964.57it/s]

 19%|█▉        | 241190/1271615 [00:04<00:14, 71835.02it/s]

 20%|█▉        | 248422/1271615 [00:04<00:14, 70626.75it/s]

 20%|██        | 255523/1271615 [00:04<00:14, 69995.33it/s]

 21%|██        | 262551/1271615 [00:04<00:14, 69950.54it/s]

 21%|██        | 270155/1271615 [00:04<00:13, 71671.94it/s]

 22%|██▏       | 277763/1271615 [00:04<00:13, 72937.94it/s]

 22%|██▏       | 285294/1271615 [00:04<00:13, 73632.56it/s]

 23%|██▎       | 292822/1271615 [00:04<00:13, 74119.03it/s]

 24%|██▎       | 300533/1271615 [00:04<00:12, 74987.99it/s]

 24%|██▍       | 308083/1271615 [00:05<00:12, 75135.75it/s]

 25%|██▍       | 315769/1271615 [00:05<00:12, 75644.28it/s]

 25%|██▌       | 323455/1271615 [00:05<00:12, 76003.56it/s]

 26%|██▌       | 331061/1271615 [00:05<00:12, 75641.28it/s]

 27%|██▋       | 338629/1271615 [00:05<00:13, 69667.39it/s]

 27%|██▋       | 345692/1271615 [00:05<00:17, 53979.40it/s]

 28%|██▊       | 352445/1271615 [00:05<00:16, 57434.99it/s]

 28%|██▊       | 359459/1271615 [00:05<00:15, 60735.08it/s]

 29%|██▉       | 366293/1271615 [00:05<00:14, 62832.26it/s]

 29%|██▉       | 373204/1271615 [00:06<00:13, 64591.50it/s]

 30%|██▉       | 380613/1271615 [00:06<00:13, 67172.76it/s]

 31%|███       | 387908/1271615 [00:06<00:12, 68803.87it/s]

 31%|███       | 395433/1271615 [00:06<00:12, 70616.57it/s]

 32%|███▏      | 403123/1271615 [00:06<00:11, 72390.17it/s]

 32%|███▏      | 410643/1271615 [00:06<00:11, 73200.54it/s]

 33%|███▎      | 418181/1271615 [00:06<00:11, 73840.33it/s]

 33%|███▎      | 425807/1271615 [00:06<00:11, 74548.25it/s]

 34%|███▍      | 433464/1271615 [00:06<00:11, 75141.55it/s]

 35%|███▍      | 441003/1271615 [00:06<00:11, 72554.33it/s]

 35%|███▌      | 448295/1271615 [00:07<00:11, 71450.76it/s]

 36%|███▌      | 455470/1271615 [00:07<00:11, 70705.36it/s]

 36%|███▋      | 462563/1271615 [00:07<00:11, 70366.06it/s]

 37%|███▋      | 469616/1271615 [00:07<00:11, 69254.69it/s]

 37%|███▋      | 476561/1271615 [00:07<00:11, 69309.97it/s]

 38%|███▊      | 483531/1271615 [00:07<00:11, 69423.62it/s]

 39%|███▊      | 490481/1271615 [00:07<00:11, 68948.83it/s]

 39%|███▉      | 497388/1271615 [00:07<00:11, 68984.96it/s]

 40%|███▉      | 504291/1271615 [00:07<00:11, 68910.00it/s]

 40%|████      | 511185/1271615 [00:08<00:11, 68831.95it/s]

 41%|████      | 518112/1271615 [00:08<00:10, 68962.26it/s]

 41%|████▏     | 525010/1271615 [00:08<00:10, 68904.14it/s]

 42%|████▏     | 531911/1271615 [00:08<00:10, 68935.49it/s]

 42%|████▏     | 538806/1271615 [00:08<00:10, 68846.17it/s]

 43%|████▎     | 545692/1271615 [00:08<00:10, 68549.38it/s]

 43%|████▎     | 552548/1271615 [00:08<00:10, 67856.05it/s]

 44%|████▍     | 559336/1271615 [00:08<00:10, 67761.16it/s]

 45%|████▍     | 566214/1271615 [00:08<00:10, 68061.72it/s]

 45%|████▌     | 573171/1271615 [00:08<00:10, 68504.46it/s]

 46%|████▌     | 580099/1271615 [00:09<00:10, 68734.49it/s]

 46%|████▌     | 586974/1271615 [00:09<00:09, 68510.79it/s]

 47%|████▋     | 593955/1271615 [00:09<00:09, 68895.37it/s]

 47%|████▋     | 600921/1271615 [00:09<00:09, 69120.36it/s]

 48%|████▊     | 607901/1271615 [00:09<00:09, 69322.83it/s]

 48%|████▊     | 614835/1271615 [00:09<00:09, 69282.32it/s]

 49%|████▉     | 621793/1271615 [00:09<00:09, 69368.60it/s]

 49%|████▉     | 628731/1271615 [00:09<00:09, 69187.38it/s]

 50%|████▉     | 635651/1271615 [00:09<00:09, 68840.31it/s]

 51%|█████     | 642541/1271615 [00:09<00:09, 68856.11it/s]

 51%|█████     | 649428/1271615 [00:10<00:09, 68805.38it/s]

 52%|█████▏    | 656447/1271615 [00:10<00:08, 69213.99it/s]

 52%|█████▏    | 663370/1271615 [00:10<00:08, 69204.21it/s]

 53%|█████▎    | 670291/1271615 [00:10<00:08, 69140.20it/s]

 53%|█████▎    | 677213/1271615 [00:10<00:08, 69164.04it/s]

 54%|█████▍    | 684130/1271615 [00:10<00:08, 69115.48it/s]

 54%|█████▍    | 691459/1271615 [00:10<00:08, 70315.07it/s]

 55%|█████▍    | 698661/1271615 [00:10<00:08, 70815.05it/s]

 56%|█████▌    | 705747/1271615 [00:10<00:08, 70438.11it/s]

 56%|█████▌    | 712795/1271615 [00:10<00:07, 70163.31it/s]

 57%|█████▋    | 719815/1271615 [00:11<00:07, 69593.59it/s]

 57%|█████▋    | 726778/1271615 [00:11<00:07, 69542.08it/s]

 58%|█████▊    | 733755/1271615 [00:11<00:07, 69608.98it/s]

 58%|█████▊    | 740718/1271615 [00:11<00:07, 69563.77it/s]

 59%|█████▉    | 747751/1271615 [00:11<00:07, 69791.25it/s]

 59%|█████▉    | 754752/1271615 [00:11<00:07, 69852.72it/s]

 60%|█████▉    | 761786/1271615 [00:11<00:07, 69995.67it/s]

 60%|██████    | 768787/1271615 [00:11<00:07, 69256.07it/s]

 61%|██████    | 775715/1271615 [00:11<00:07, 67035.95it/s]

 62%|██████▏   | 782436/1271615 [00:11<00:07, 65344.75it/s]

 62%|██████▏   | 788992/1271615 [00:12<00:07, 65138.33it/s]

 63%|██████▎   | 795521/1271615 [00:12<00:07, 60229.36it/s]

 63%|██████▎   | 801628/1271615 [00:12<00:07, 59819.52it/s]

 64%|██████▎   | 808402/1271615 [00:12<00:07, 61993.32it/s]

 64%|██████▍   | 814664/1271615 [00:12<00:07, 58162.73it/s]

 65%|██████▍   | 820572/1271615 [00:12<00:08, 54469.59it/s]

 65%|██████▍   | 826131/1271615 [00:12<00:08, 51845.34it/s]

 66%|██████▌   | 832939/1271615 [00:12<00:07, 55839.30it/s]

 66%|██████▌   | 839537/1271615 [00:12<00:07, 58537.11it/s]

 66%|██████▋   | 845575/1271615 [00:13<00:07, 59077.45it/s]

 67%|██████▋   | 853036/1271615 [00:13<00:06, 63012.29it/s]

 68%|██████▊   | 859483/1271615 [00:13<00:06, 60597.87it/s]

 68%|██████▊   | 867020/1271615 [00:13<00:06, 64382.28it/s]

 69%|██████▉   | 874645/1271615 [00:13<00:05, 67534.13it/s]

 69%|██████▉   | 882243/1271615 [00:13<00:05, 69860.92it/s]

 70%|██████▉   | 889881/1271615 [00:13<00:05, 71695.93it/s]

 71%|███████   | 897435/1271615 [00:13<00:05, 72806.60it/s]

 71%|███████   | 905100/1271615 [00:13<00:04, 73917.64it/s]

 72%|███████▏  | 912776/1271615 [00:13<00:04, 74747.49it/s]

 72%|███████▏  | 920373/1271615 [00:14<00:04, 75109.43it/s]

 73%|███████▎  | 928060/1271615 [00:14<00:04, 75627.44it/s]

 74%|███████▎  | 935722/1271615 [00:14<00:04, 75921.18it/s]

 74%|███████▍  | 943330/1271615 [00:14<00:04, 75706.10it/s]

 75%|███████▍  | 950912/1271615 [00:14<00:04, 73059.16it/s]

 75%|███████▌  | 958246/1271615 [00:14<00:04, 68894.99it/s]

 76%|███████▌  | 965204/1271615 [00:14<00:04, 68006.18it/s]

 76%|███████▋  | 972055/1271615 [00:14<00:04, 66193.03it/s]

 77%|███████▋  | 978721/1271615 [00:14<00:04, 63976.31it/s]

 77%|███████▋  | 985273/1271615 [00:15<00:04, 64429.25it/s]

 78%|███████▊  | 992547/1271615 [00:15<00:04, 66715.93it/s]

 79%|███████▊  | 1000005/1271615 [00:15<00:03, 68893.79it/s]

 79%|███████▉  | 1007594/1271615 [00:15<00:03, 70851.34it/s]

 80%|███████▉  | 1015270/1271615 [00:15<00:03, 72524.87it/s]

 80%|████████  | 1022958/1271615 [00:15<00:03, 73778.11it/s]

 81%|████████  | 1030654/1271615 [00:15<00:03, 74703.79it/s]

 82%|████████▏ | 1038258/1271615 [00:15<00:03, 75097.92it/s]

 82%|████████▏ | 1045859/1271615 [00:15<00:02, 75368.29it/s]

 83%|████████▎ | 1053557/1271615 [00:15<00:02, 75843.36it/s]

 83%|████████▎ | 1061187/1271615 [00:16<00:02, 75975.20it/s]

 84%|████████▍ | 1068831/1271615 [00:16<00:02, 76113.96it/s]

 85%|████████▍ | 1076509/1271615 [00:16<00:02, 76311.60it/s]

 85%|████████▌ | 1084221/1271615 [00:16<00:02, 76549.80it/s]

 86%|████████▌ | 1091909/1271615 [00:16<00:02, 76647.00it/s]

 86%|████████▋ | 1099576/1271615 [00:16<00:02, 76495.01it/s]

 87%|████████▋ | 1107251/1271615 [00:16<00:02, 76571.08it/s]

 88%|████████▊ | 1114910/1271615 [00:16<00:02, 76576.12it/s]

 88%|████████▊ | 1122643/1271615 [00:16<00:01, 76794.32it/s]

 89%|████████▉ | 1130324/1271615 [00:16<00:01, 76578.40it/s]

 89%|████████▉ | 1138047/1271615 [00:17<00:01, 76770.67it/s]

 90%|█████████ | 1145725/1271615 [00:17<00:01, 76458.24it/s]

 91%|█████████ | 1153396/1271615 [00:17<00:01, 76532.47it/s]

 91%|█████████▏| 1161050/1271615 [00:17<00:01, 76376.99it/s]

 92%|█████████▏| 1168704/1271615 [00:17<00:01, 76422.46it/s]

 93%|█████████▎| 1176390/1271615 [00:17<00:01, 76551.09it/s]

 93%|█████████▎| 1184046/1271615 [00:17<00:01, 76548.15it/s]

 94%|█████████▎| 1191701/1271615 [00:17<00:01, 76246.44it/s]

 94%|█████████▍| 1199349/1271615 [00:17<00:00, 76314.64it/s]

 95%|█████████▍| 1206981/1271615 [00:17<00:00, 75763.82it/s]

 96%|█████████▌| 1214559/1271615 [00:18<00:00, 75671.83it/s]

 96%|█████████▌| 1222127/1271615 [00:18<00:00, 75454.93it/s]

 97%|█████████▋| 1229776/1271615 [00:18<00:00, 75759.08it/s]

 97%|█████████▋| 1237531/1271615 [00:18<00:00, 76287.26it/s]

 98%|█████████▊| 1245162/1271615 [00:18<00:00, 72561.90it/s]

 98%|█████████▊| 1252458/1271615 [00:18<00:00, 72239.01it/s]

 99%|█████████▉| 1259710/1271615 [00:18<00:00, 70999.74it/s]

100%|█████████▉| 1266835/1271615 [00:18<00:00, 69517.79it/s]

100%|██████████| 1271615/1271615 [00:18<00:00, 67510.48it/s]




W0329 01:33:16.140797   256 device_context.cc:447] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1
W0329 01:33:16.144894   256 device_context.cc:465] device: 0, cuDNN Version: 7.6.


## 4. Model training

In [9]:
# Train for `epoch` epochs, logging every 100 steps. No evaluation data is
# passed, so the reported accuracy is measured on the training batches only.
model.fit(
    train_loader,
    epochs=epoch,
    verbose=2,
    log_freq=100,
)

The loss value printed in the log is the current step, and the metric is the average value of previous steps.




Epoch 1/4




step  100/1598 - loss: 0.6788 - acc: 0.5416 - 634ms/step


step  200/1598 - loss: 0.4744 - acc: 0.5670 - 632ms/step


step  300/1598 - loss: 0.3809 - acc: 0.6301 - 631ms/step


step  400/1598 - loss: 0.4158 - acc: 0.6837 - 634ms/step


step  500/1598 - loss: 0.1070 - acc: 0.7194 - 635ms/step


step  600/1598 - loss: 0.5350 - acc: 0.7435 - 636ms/step


step  700/1598 - loss: 0.1247 - acc: 0.7607 - 636ms/step


step  800/1598 - loss: 0.3263 - acc: 0.7761 - 637ms/step


step  900/1598 - loss: 0.2310 - acc: 0.7881 - 636ms/step


step 1000/1598 - loss: 0.4084 - acc: 0.7973 - 636ms/step


step 1100/1598 - loss: 0.2407 - acc: 0.8062 - 637ms/step


step 1200/1598 - loss: 0.2225 - acc: 0.8136 - 637ms/step


step 1300/1598 - loss: 0.4447 - acc: 0.8195 - 637ms/step


step 1400/1598 - loss: 0.2245 - acc: 0.8253 - 637ms/step


step 1500/1598 - loss: 0.2974 - acc: 0.8304 - 636ms/step


step 1598/1598 - loss: 0.1756 - acc: 0.8352 - 636ms/step


Epoch 2/4




step  100/1598 - loss: 0.0413 - acc: 0.9287 - 638ms/step


step  200/1598 - loss: 0.1997 - acc: 0.9295 - 635ms/step


step  300/1598 - loss: 0.1591 - acc: 0.9325 - 636ms/step


step  400/1598 - loss: 0.3017 - acc: 0.9316 - 637ms/step


step  500/1598 - loss: 0.1402 - acc: 0.9333 - 637ms/step


step  600/1598 - loss: 0.1055 - acc: 0.9347 - 637ms/step


step  700/1598 - loss: 0.2448 - acc: 0.9361 - 637ms/step


step  800/1598 - loss: 0.1487 - acc: 0.9374 - 637ms/step


step  900/1598 - loss: 0.1321 - acc: 0.9375 - 637ms/step


step 1000/1598 - loss: 0.0818 - acc: 0.9373 - 638ms/step


step 1100/1598 - loss: 0.1934 - acc: 0.9361 - 638ms/step


step 1200/1598 - loss: 0.2014 - acc: 0.9353 - 639ms/step


step 1300/1598 - loss: 0.1117 - acc: 0.9354 - 639ms/step


step 1400/1598 - loss: 0.2010 - acc: 0.9357 - 639ms/step


step 1500/1598 - loss: 0.0722 - acc: 0.9359 - 639ms/step


step 1598/1598 - loss: 0.1657 - acc: 0.9360 - 639ms/step


Epoch 3/4




step  100/1598 - loss: 0.0575 - acc: 0.9631 - 638ms/step


step  200/1598 - loss: 0.0455 - acc: 0.9673 - 638ms/step


step  300/1598 - loss: 0.1715 - acc: 0.9663 - 638ms/step


step  400/1598 - loss: 0.0355 - acc: 0.9652 - 640ms/step


step  500/1598 - loss: 0.0380 - acc: 0.9657 - 639ms/step


step  600/1598 - loss: 0.1044 - acc: 0.9660 - 639ms/step


step  700/1598 - loss: 0.0357 - acc: 0.9658 - 637ms/step


step  800/1598 - loss: 0.0767 - acc: 0.9660 - 637ms/step


step  900/1598 - loss: 0.0581 - acc: 0.9657 - 638ms/step


step 1000/1598 - loss: 0.1073 - acc: 0.9656 - 638ms/step


step 1100/1598 - loss: 0.0824 - acc: 0.9657 - 638ms/step


step 1200/1598 - loss: 0.1111 - acc: 0.9654 - 638ms/step


step 1300/1598 - loss: 0.1194 - acc: 0.9657 - 638ms/step


step 1400/1598 - loss: 0.0708 - acc: 0.9656 - 639ms/step


step 1500/1598 - loss: 0.2322 - acc: 0.9661 - 638ms/step


step 1598/1598 - loss: 0.0545 - acc: 0.9663 - 638ms/step


Epoch 4/4




step  100/1598 - loss: 0.0307 - acc: 0.9831 - 631ms/step


step  200/1598 - loss: 0.1388 - acc: 0.9838 - 633ms/step


step  300/1598 - loss: 0.0073 - acc: 0.9842 - 634ms/step


step  400/1598 - loss: 0.0067 - acc: 0.9832 - 635ms/step


step  500/1598 - loss: 0.0227 - acc: 0.9833 - 637ms/step


step  600/1598 - loss: 0.0877 - acc: 0.9833 - 638ms/step


step  700/1598 - loss: 0.0078 - acc: 0.9842 - 638ms/step


step  800/1598 - loss: 0.0056 - acc: 0.9842 - 639ms/step


step  900/1598 - loss: 0.0049 - acc: 0.9846 - 638ms/step


step 1000/1598 - loss: 0.0164 - acc: 0.9846 - 638ms/step


step 1100/1598 - loss: 0.0517 - acc: 0.9841 - 638ms/step


step 1200/1598 - loss: 0.0132 - acc: 0.9841 - 638ms/step


step 1300/1598 - loss: 0.0036 - acc: 0.9845 - 638ms/step


step 1400/1598 - loss: 0.0278 - acc: 0.9846 - 639ms/step


step 1500/1598 - loss: 0.0254 - acc: 0.9846 - 638ms/step


step 1598/1598 - loss: 0.0070 - acc: 0.9846 - 638ms/step


## 5. Prediction

In [10]:
import paddle.nn.functional as F


# Predicted class index for every test example, in loader order.
predictions = []
logits = model.predict(test_loader)

# logits[0] is iterated batch by batch; each batch is expected to be a
# (batch, num_class) array of scores for the network's single output.
for batch in logits[0]:
    # Softmax is strictly monotonic, so the argmax of the raw logits is the
    # same class as the argmax of the probabilities. This drops the softmax
    # and the numpy -> tensor -> numpy round-trip.
    predictions.extend(np.argmax(np.asarray(batch), axis=1).tolist())

Predict begin...






step  2/63 [..............................]

 - ETA: 17s - 291ms/step





step  4/63 [>.............................]

 - ETA: 15s - 265ms/step





step  6/63 [=>............................]

 - ETA: 14s - 258ms/step





step  8/63 [==>...........................]

 - ETA: 14s - 255ms/step





step 10/63 [===>..........................]

 - ETA: 13s - 254ms/step





step 12/63 [====>.........................]

 - ETA: 12s - 253ms/step





step 14/63 [=====>........................]

 - ETA: 12s - 252ms/step







 - ETA: 11s - 250ms/step







 - ETA: 11s - 248ms/step







 - ETA: 10s - 248ms/step







 - ETA: 10s - 247ms/step







 - ETA: 9s - 246ms/step 







 - ETA: 9s - 246ms/step







 - ETA: 8s - 246ms/step







 - ETA: 8s - 246ms/step







 - ETA: 7s - 245ms/step







 - ETA: 7s - 244ms/step







 - ETA: 6s - 243ms/step







 - ETA: 6s - 242ms/step







 - ETA: 5s - 242ms/step







 - ETA: 5s - 242ms/step







 - ETA: 4s - 242ms/step







 - ETA: 4s - 242ms/step







 - ETA: 3s - 243ms/step







 - ETA: 3s - 243ms/step







 - ETA: 2s - 243ms/step







 - ETA: 2s - 244ms/step







 - ETA: 1s - 243ms/step







 - ETA: 1s - 243ms/step







 - ETA: 0s - 243ms/step







 - ETA: 0s - 242ms/step







 - 240ms/step          


Predict samples: 2000




In [11]:
# Write the submission file: a header row, then one "index<TAB>prediction"
# row per test example. Each row is prefixed with its newline, so the file
# has no trailing newline (same bytes as before).
# The with-block closes the file, so no explicit close() is needed; the
# encoding is pinned so the output does not depend on the locale.
with open('paws-x.tsv', 'w', encoding='utf-8') as f:
    f.write("index\tprediction")
    f.writelines(f"\n{idx}\t{p}" for idx, p in enumerate(predictions))