In [23]:
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from datasets import load_dataset
import tqdm
from transformers import GPT2TokenizerFast, BertTokenizerFast
import re
import tiktoken

In [27]:
data = load_dataset('stanfordnlp/snli')

In [26]:
tok = BertTokenizerFast.from_pretrained('google-bert/bert-base-uncased')

In [None]:
def tokenize_function(examples):
    return tok("" + examples["premise"] + "HYP" + examples['hypothesis'] + "", padding=False, return_tensors='pt')

In [55]:
test_tok = data['test'].map(tokenize_function)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [56]:
test_tok[0]

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.',
 'hypothesis': 'The church has cracks in the ceiling.',
 'label': 1,
 'input_ids': [[101,
   2023,
   2277,
   6596,
   10955,
   2000,
   1996,
   11678,
   2004,
   2027,
   6170,
   6569,
   3560,
   2774,
   2013,
   1996,
   2338,
   2012,
   1037,
   2277,
   1012,
   1996,
   2277,
   2038,
   15288,
   1999,
   1996,
   5894,
   1012,
   102]],
 'token_type_ids': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0]],
 'attention_mask': [[1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1]]}

In [38]:
from torch.utils.data import DataLoader

In [None]:
def col_fn(x):
    ids, lab = [], []
    for i in x:
        ids.append(i['input_ids'])
        lab.append(i['label'])
    return torch.tensor(ids), torch.tensor(lab)

In [52]:
DL = DataLoader(test_tok, batch_size=32, shuffle=True, collate_fn=col_fn)

In [53]:
for x in DL:
    print(x)
    break

[{'premise': 'A man is celebrating his victory while smiling and shooting champagne in the air with his teammate.', 'hypothesis': 'A man is celebrating his victory while smiling and shooting champagne in the air', 'label': 0, 'input_ids': [101, 1037, 2158, 2003, 12964, 2010, 3377, 2096, 5629, 1998, 5008, 12327, 1999, 1996, 2250, 2007, 2010, 10809, 1012, 1037, 2158, 2003, 12964, 2010, 3377, 2096, 5629, 1998, 5008, 12327, 1999, 1996, 2250, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}, {'premise': 'An emergency worker directs a man pulling a sled with emergency equipment on a snowy path.', 'hypothesis': 'An emergency worker works at a crash scene.', 'label': 1, 'input_ids': [101, 2019, 5057, 7309, 23303, 1037, 2158, 4815, 1037, 22889, 2098, 2007, 5057, 3941, 2006, 1037, 20981, 4130, 1012

TypeError: list indices must be integers or slices, not str

In [21]:
# data = load_dataset('Salesforce/wikitext', 'wikitext-103-raw-v1')
# data1 = load_dataset('Skylion007/openwebtext')
data = load_dataset('bookcorpus/bookcorpus')


In [40]:
torch.cuda.is_bf16_supported()

True

In [22]:
data

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 74004228
    })
})

In [23]:
data['train'].__getitem__(3)

{'text': "he 'd seen the movie almost by mistake , considering he was a little young for the pg cartoon , but with older cousins , along with her brothers , mason was often exposed to things that were older ."}

In [24]:
train = data['train']
# validation = data['validation']
# test = data['test']
# validation_owt = data1['train']['text'][int(0.8*len(data1['train']['text']))+1:int(0.9*len(data1['train']['text']))]
# test_owt = data1['train']['text'][int(0.9*len(data1['train']['text']))+1:]
# train_owt = data1['train']['text'][:int(0.8*len(data1['train']['text']))]

In [26]:
train

Dataset({
    features: ['text'],
    num_rows: 74004228
})

In [27]:
bos_token = "<|BOS|>"

In [None]:

def data_clean(input: list[str], seq_len=135) -> str:
    ret = ""
    for line in input:
        if len(line) == 0:  continue
        # remove @'s surrounding some characters
        line = re.sub(r' @([.,\-])@ ', r'\1', line)
        # find titles of articles and add bos_token
        matches = re.match(r'^ = ?(.+?) =?\n', line)    # this finds all title and subsection text
        if matches != None:
            c = line.count('=')
            if c == 2:
                # start new article
                ret += " " + bos_token
        ret += line

    ret = ret.split(" ")
    chunks = []
    curr_chunk = []
    cur_len = 0
    
    for word in ret:
        if cur_len > seq_len:
            chunks.append(" ". join(curr_chunk))
            curr_chunk = [word]
            cur_len = 1

        else:
            curr_chunk.append(word)
            cur_len += 1

    return chunks

In [39]:
tokenizer = tiktoken.get_encoding('gpt2')

In [None]:
train_join = data_clean(train['text']) + data_clean(train_owt)
# val_join = data_clean(validation['text']) + data_clean(validation_owt)
# test_join = data_clean(test['text']) + data_clean(test_owt)

In [29]:
def token_all(ex):
    return tokenizer(ex['text'], truncation=False, padding=False)

In [32]:
tok_dataset = train.map(token_all, batched=True, remove_columns=train.column_names)

Map:   0%|          | 0/74004228 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1119 > 1024). Running this sequence through the model will result in indexing errors


KeyboardInterrupt: 

In [None]:
class CombinedSentenceDataset(torch.utils.data.Dataset):
    def __init__(self, tokenized_dataset, max_length=128):
        self.tokenized_dataset = tokenized_dataset
        self.max_length = max_length
        self.samples = self._create_samples()

    def _create_samples(self):
        samples = []
        current_sample = []
        current_length = 0

        for sentence in self.tokenized_dataset["input_ids"]:
            sentence_length = len(sentence)

            if current_length + sentence_length <= self.max_length:
                current_sample.extend(sentence)
                current_length += sentence_length
            else:
                if current_length > 0:
                    # Pad the current sample if needed
                    padding_length = self.max_length - current_length
                    current_sample.extend([tokenizer.pad_token_id] * padding_length)
                    samples.append(current_sample)

                # Start a new sample with the current sentence
                current_sample = sentence[:self.max_length]
                current_length = min(sentence_length, self.max_length)

            # If we've reached exactly max_length, add the sample and reset
            if current_length == self.max_length:
                samples.append(current_sample)
                current_sample = []
                current_length = 0

        # Add the last sample if it's not empty
        if current_length > 0:
            padding_length = self.max_length - current_length
            current_sample.extend([tokenizer.pad_token_id] * padding_length)
            samples.append(current_sample)

        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.samples[idx])
        attention_mask = (input_ids != tokenizer.pad_token_id).long()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask
        }

# Create the combined sentence dataset
combined_dataset = CombinedSentenceDataset(tokenized_dataset)


In [None]:
train_tok = [tokenizer(chunk, truncation=False, return_tensors='pt')['input_ids'] for chunk in tqdm.tqdm(train)]
# val_tok = [tokenizer(chunk, max_length=129, truncation=True, return_tensors='pt')['input_ids'] for chunk in val_join]
# test_tok = [tokenizer(chunk, max_length=129, truncation=True, return_tensors='pt')['input_ids'] for chunk in test_join]

100%|██████████| 737084/737084 [05:41<00:00, 2156.80it/s]


In [None]:
torch.save(train_tok, 'data/train_data_bookcorp.pt')
# torch.save(val_tok, 'data/val_data_token_owt.pt')
# torch.save(test_tok, 'data/test_data_token_owt.pt')

In [None]:
len(train_tok)

737084