In [31]:
import os, sys
import paddle
sys.path.append('/workspace/fnet_paddle/PaddleNLP')
from paddlenlp.datasets import load_dataset

In [32]:
test_ds = load_dataset("glue", name="cola", splits=("test"))

INFO:paddle.utils.download:unique_endpoints {''}


In [33]:
len(test_ds)

1063

In [34]:
def convert_example(example,
                    tokenizer,
                    max_seq_length=512,
                    is_test=False):
    text = example["sentence"]
    text_pair = None
    encoded_inputs = tokenizer(
        text=text, text_pair=text_pair, max_seq_len=max_seq_length)
    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if is_test:
        return input_ids, token_type_ids
    label = np.array([example["labels"]], dtype="int64")
    return input_ids, token_type_ids, label

In [35]:
def create_dataloader(dataset,
                      mode='train',
                      batch_size=1,
                      batchify_fn=None,
                      trans_fn=None):
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == 'train' else False
    if mode == 'train':
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        batch_sampler = paddle.io.BatchSampler(
            dataset, batch_size=batch_size, shuffle=shuffle)

    return paddle.io.DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        return_list=True)

In [36]:
import argparse
import os

import paddle
import paddle.nn.functional as F
import paddlenlp as ppnlp
from paddlenlp.data import Tuple, Pad
from functools import partial

In [37]:
parser = argparse.ArgumentParser()
parser.add_argument("--params_path", type=str, required=False, default="checkpoints/model_900/model_state.pdparams", help="The path to model parameters to be loaded.")
parser.add_argument("--max_seq_length", type=int, default=128, help="The maximum total input sequence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size per GPU/CPU for training.")
parser.add_argument('--device', choices=['cpu', 'gpu', 'xpu', 'npu'], default="gpu", help="Select which device to train model, defaults to gpu.")
args = parser.parse_args([])

In [38]:
args.params_path = 'checkpoints/model_1200/model_state.pdparams'

In [39]:
fnet = ppnlp.transformers.FNetModel.from_pretrained('pretrained_model/paddle/large')
model = ppnlp.transformers.FNetForSequenceClassification(fnet, num_classes=len(test_ds.label_list))
tokenizer = ppnlp.transformers.FNetTokenizer.from_pretrained('fnet-large')

In [40]:
if args.params_path and os.path.isfile(args.params_path):
    state_dict = paddle.load(args.params_path)
    model.set_dict(state_dict)
    print("Loaded parameters from %s" % args.params_path)

In [41]:
trans_func = partial(
    convert_example,
    tokenizer=tokenizer,
    max_seq_length=args.max_seq_length,
    is_test=True)
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # segment
): [data for data in fn(samples)]

test_data_loader = create_dataloader(
    test_ds,
    mode='test',
    batch_size=args.batch_size,
    batchify_fn=batchify_fn,
    trans_fn=trans_func)

In [42]:
results = []
model.eval()
for batch in test_data_loader:
    input_ids, token_type_ids = batch
    logits = model(input_ids, token_type_ids)
    probs = F.softmax(logits, axis=1)
    idx = paddle.argmax(probs, axis=1).numpy()
    idx = idx.tolist()
    results.extend(idx)

In [43]:
results[:3]

[1, 1, 1]

In [44]:
import pandas as pd
res_df = pd.DataFrame()

In [45]:
res_df['prediction'] = results
res_df.index.name = 'index'

In [46]:
res_df.head(5)

Unnamed: 0_level_0,prediction
index,Unnamed: 1_level_1
0,1
1,1
2,1
3,1
4,1


In [47]:
(res_df['prediction'] == 1).sum()

1063

In [48]:
res_df.to_csv('CoLA.tsv', sep='\t')