In [1]:
import datasets

In [2]:
config = datasets.DownloadConfig(resume_download=True, max_retries=100)

In [4]:
ds = datasets.load_dataset('ranWang/un_pdf_random_preprocessed', cache_dir="./hf_cache",
                           verification_mode="no_checks", download_config=config)

Found cached dataset parquet (/home/jia/workspace/parallel_corpus_mnbvc/notebooks/./hf_cache/ranWang___parquet/ranWang--un_pdf_random_preprocessed-c033500c86c8bab0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
ds

DatasetDict({
    train: Dataset({
        features: ['zh', 'en', 'fr', 'es', 'ru', 'record'],
        num_rows: 15293
    })
})

In [14]:
def convert_text_to_annotated_line_pair(text, is_hard_linebreak):
    text_lines = text.split('\n')
    assert len(text_lines) == len(is_hard_linebreak)+1
    line_pair_list = []
    label_list = []
    for i in range(len(is_hard_linebreak)):
        line = text_lines[i]
#         line = " ".join(line.split(' ')[-16:])
        line_next = text_lines[i+1]
#         line_next = " ".join(line_next.split(' ')[:15])
        label = is_hard_linebreak[i]
        line_pair_list.append(line + "\n" + line_next)
        label_list.append(label)
    return line_pair_list, label_list

In [15]:
import pathlib
import json
label_dir = pathlib.Path("./batch_cache/done/")
record_linebreak_dict = {}
for path in label_dir.glob('*.list'):
    record_id = path.stem
    with path.open("r") as infile:
        is_hard_linebreak = json.load(infile)
    record_linebreak_dict[record_id] = is_hard_linebreak

In [16]:
ds_train = ds['train']
ds_train = ds_train.filter(lambda x: x['record'] in record_linebreak_dict)

Loading cached processed dataset at /home/jia/workspace/parallel_corpus_mnbvc/notebooks/hf_cache/ranWang___parquet/ranWang--un_pdf_random_preprocessed-c033500c86c8bab0/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-fa7fbcca160ead70.arrow


In [26]:
pair_ds_train_dict = {"text": [], "label": []}
pair_ds_test_dict = {"text": [], "label": []}

for sample in ds_train.select(range(5000)):
    raw_text = sample["en"]
    record_id = sample["record"]
    text, label = convert_text_to_annotated_line_pair(raw_text, record_linebreak_dict[record_id])
    pair_ds_train_dict["text"].extend(text)
    pair_ds_train_dict["label"].extend(label)
    
for sample in ds_train.select(range(4000, 4100)):
    raw_text = sample["en"]
    record_id = sample["record"]
    text, label = convert_text_to_annotated_line_pair(raw_text, record_linebreak_dict[record_id])
    pair_ds_test_dict["text"].extend(text)
    pair_ds_test_dict["label"].extend(label)    

In [27]:
pair_ds = datasets.DatasetDict({
    "train": datasets.Dataset.from_dict(pair_ds_train_dict),
    "test": datasets.Dataset.from_dict(pair_ds_test_dict),
})

In [28]:
import pandas as pd

In [29]:
pd.DataFrame(pair_ds["train"][:20])

Unnamed: 0,text,label
0,General Assembly Distr.: General\n8 June 2004,True
1,8 June 2004\nOriginal: English,True
2,Original: English\n2,True
3,2\nUNCITRAL Digest on the CISG,True
4,UNCITRAL Digest on the CISG\n1. This provision...,True
5,1. This provision sets out those sales that ar...,False
6,sphere of application. The exclusions are of t...,False
7,"for which the goods were purchased, those base...",False
8,those based on the kinds of goods sold.1\nCons...,True
9,"Consumer sales\n2. According to Art. 2 (a), a ...",True


In [53]:
pair_ds.push_to_hub('liyongsea/un_linebreak-5000', token="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", )

Pushing split train to the Hub.
Resuming upload of the dataset shards.


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Pushing split test to the Hub.
Resuming upload of the dataset shards.


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
pair_ds['train'][100]

{'text': 'basis, in accordance with accepted United Nations procedures and practices, in order\nto provide specialized human and technical resources for emergency relief and',
 'label': False}

In [15]:
tokenized_datasets = pair_ds.map(tokenize_function, batched=True)

Map:   0%|          | 0/84049 [00:00<?, ? examples/s]

Map:   0%|          | 0/84617 [00:00<?, ? examples/s]

In [17]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 84049
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 84617
    })
})

In [16]:
import numpy as np
len(tokenized_datasets['train'][0]['input_ids'])

512

In [17]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [18]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir="test_trainer")

In [19]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [20]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [21]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")

In [22]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [23]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `DistilBertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `DistilBertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 375
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33mlijia0765[0m ([33mmnbvc[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666880299987194, max=1.0)…

Epoch,Training Loss,Validation Loss



KeyboardInterrupt



In [72]:
1

1

In [75]:
trainer.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 8

KeyboardInterrupt

