In [1]:
import pandas as pd
from datasets import Dataset,load_dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [3]:
# use pandas to read simCSE-wiki.txt
wiki_text_file = 'wiki1m_for_simcse.txt'
wiki = pd.read_csv(wiki_text_file,sep = '\t',header = None)
print(wiki)
wiki.columns = ['text']
# use Dataset.from_pandas to convert panda dataframe to hugging face dataset
wiki_dataset = Dataset.from_pandas(wiki,split= "train")
wiki_dataset

                                                        0
0                                 YMCA in South Australia
1       South Australia (SA)  has a unique position in...
2       The compound of philosophical radicalism, evan...
3       It was into this social setting that in Februa...
4       for apprentices and others, after their day's ...
...                                                   ...
995442                                  Rubaschow: Roman.
995443                 Typoskript, März 1940, 326 pages."
995444  He deemed the discovery important because ""Da...
995445  In 2018, he reported that Elsinor Verlag (publ...
995446  He also reported a new English translation to ...

[995447 rows x 1 columns]


Dataset({
    features: ['text'],
    num_rows: 995447
})

In [4]:
def prepare_features(examples):
    
    total = len(examples['text'])
    # total = batch_size
    
    # Avoid "None" fields 
    for idx in range(total):
        if examples['text'][idx] is None:
            examples['text'][idx] = " "
        if examples['text'][idx] is None:
            examples['text'][idx] = " "

    sentences = examples['text'] + examples['text']

    # set max_length here:
    sent_features = tokenizer(sentences, max_length=32, truncation=True, padding="max_length")

    features = {}
    for key in sent_features:
        features[key] = [[sent_features[key][i], sent_features[key][i+total]] for i in range(total)]
    
    return features

In [5]:
train_dataset = wiki_dataset.map(prepare_features,batched=True, remove_columns=['text'], batch_size=4000) 
#apply the prepare_features function to the entire dataset:
#take about 15 min

Map: 100%|████████████████████████████████████████████████████████████| 995447/995447 [13:26<00:00, 1234.17 examples/s]


In [7]:
train_dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask'])

In [8]:
train_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 995447
})

In [9]:
train_dataset.save_to_disk("wiki_for_sts_32")

Saving the dataset (1/1 shards): 100%|█████████████████████████████| 995447/995447 [00:00<00:00, 2141661.34 examples/s]
