In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from mingpt.model import GPT
from mingpt.utils import set_seed
from datasets import load_dataset
import torch
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import numpy as np
from scipy.stats import norm
set_seed(3407)

In [15]:
# Load in UL2 tokenizer to see what's going on
from transformers import AutoTokenizer, GPT2Tokenizer

tokenizer = AutoTokenizer.from_pretrained("google/ul2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

tokenizer

PreTrainedTokenizer(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True)})

In [2]:
dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", 'plain_text', cache_dir='datasets')
dataset = dataset['train']

# Custom dataset class for the Red Pajama dataset
class RedPajamaDataset(Dataset):
    def __init__(self, data, max_length=1024):
        self.data = data
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.add_tokens([f'new_id_{i}' for i in range(200)])
        self.tokenizer.add_tokens(['[S2S]', '[NLU]', '[NLG]'])
        self.tokenizer.pad_token_id = 50256
        self.max_length = max_length
        self.vocab_size = self.tokenizer.vocab_size

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]['text']
        # Tokenize the text
        tokens = self.tokenizer.encode(text, add_special_tokens=True, max_length=self.max_length, truncation=True, return_tensors='pt', padding=True)
        # Split the tokens into chunks of max_length
        # Shift the tokens to get targets (excluding the [CLS] token)
        target_tokens = tokens[:, 1:].clone()  # Exclude the [CLS] token
        tokens = tokens[:, :-1]  # Exclude the last token to match the shifted targets
        return tokens, target_tokens
    

# Create an instance of the custom dataset
red_pajama_dataset = RedPajamaDataset(dataset)

Found cached dataset red_pajama-data-1_t-sample (/Users/dylanskinner/Desktop/CS 674 Projects/MinGPT_UL2/datasets/togethercomputer___red_pajama-data-1_t-sample/plain_text/1.0.0/6ea3bc8ec2e84ec6d2df1930942e9028ace8c5b9d9143823cf911c50bbd92039)


  0%|          | 0/1 [00:00<?, ?it/s]

In [22]:
ff = [0,1,2,3,4,5,6,7,8,9,10]

ff = ff[:3] + ['nice'] + ff[4:]

ff

[0, 1, 2, 'nice', 4, 5, 6, 7, 8, 9, 10]

In [88]:
new_tokens = [f'new_id_{i}' for i in range(200)]
tokenizer.add_tokens(new_tokens)
tokenizer.add_tokens(['[S2S]', '[NLU]', '[NLG]'])

def r_denoising(ids, tokenizer, corruption_pct=0.15, span_length=np.arange(2,6)):
    chance = (corruption_pct / np.mean(span_length)) * (1 + np.max(span_length) / len(ids))
    old_toks = []

    steps_to_skip = 0
    tokens_used = 0

    for i in range(1, len(ids)):
        if steps_to_skip > 0:
            steps_to_skip -= 1
            continue

        if np.random.random() < chance:
            # Get the token we are using for this space
            mask_token = tokenizer.convert_tokens_to_ids(new_tokens[tokens_used])
            tokens_used += 1

            span = np.random.choice(span_length)
            old_toks.append(mask_token)
            old_toks.extend(ids[i:i+span].copy())
            ids = ids[:i] + [mask_token] + ids[i+span:]

            steps_to_skip = span
        
    
    return ids, old_toks

def s_denoising(ids, tokenizer):
    # Get the length of our input
    len_ids = len(ids)

    # Build Gaussian distribution of probabilities for each token
    p = norm.pdf(np.arange(len_ids)/len_ids, loc=np.mean(np.arange(len_ids)/len_ids), scale=np.std(np.arange(len_ids)/len_ids))

    # Normalize the probabilities and get the index: to remove.
    remove_index = np.random.choice(np.arange(len_ids), p=p/p.sum())

    # Get the token we are using for this space
    mask_token = tokenizer.convert_tokens_to_ids(new_tokens[0])

    # Get the tokens we are removing
    old_tok = [mask_token]
    old_tok.extend(ids[remove_index:].copy())

    # Mask the tokens
    ids = ids[:remove_index]
    ids[-1] = mask_token
    return ids, old_tok

def x_denoising(ids, tokenizer, corruption_pct=0.50, span_length=np.arange(2,6)):
    return r_denoising(ids, tokenizer, corruption_pct, span_length)



item = next(iter(dataset))['text']

'''Testing R-Denoising'''
# ids, old_toks = r_denoising(ids, tokenizer, corruption_pct=0.15, span_length=np.arange(2,6))
# old_toks

'''Testing S-Denoising'''
# ids, old_tok = s_denoising(ids, tokenizer)
# tokenizer.decode(ids)

'''Testing X-Denoising'''
# ids, old_toks = x_denoising(ids, tokenizer, corruption_pct=0.50, span_length=np.arange(2,6))
# tokenizer.decode(ids)

'''Testing all at once'''
token_dict = {'s': ['[S2S]', s_denoising], 'r': ['[NLU]', r_denoising], 'x': ['[NLG]', x_denoising]}

# Get the token to prepend and function to use.
begin_id, func = token_dict[np.random.choice(['s', 'r', 'x'], size=1, p=[0.5, 0.25, 0.25])[0]]

# Prepend token to string and tokenize
item = begin_id + ' ' + item
ids = tokenizer(item, truncation=True, max_length=1024)['input_ids']


ids, old_toks = func(ids, tokenizer)
tokenizer.decode(ids)
print('Total non-tokenized', len(ids))
print(f'Total tokenized {len(old_toks)}')
print(f'Total tokens {len(ids) + len(old_toks)}')
print(tokenizer.decode(ids[0]))
tokenizer.decode(old_toks)


# old_toks_dec = [tokenizer.decode(tok) for tok in old_toks]
# old_toks_dec

Total non-tokenized 953
Total tokenized 139
Total tokens 1092
[NLU]


'new_id_0 \\section new_id_1 }\n\\label new_id_2. In\nparticular new_id_3 measures for \\em new_id_4 ~\\cite new_id_5 afus new_id_6,\n  new_id_7  that stems from participants new_id_8  In particular, new_id_9 -scale studies of the new_id_10 2014surveydatas new_id_11 scale longitudinal studies new_id_12 } and new_id_13 Contributions new_id_14  world regions new_id_15 researchquestion} new_id_16 hipres new_id_17 \nWe geol new_id_18 first/last) new_id_19  of North new_id_20.}\n\nA replication new_id_21 \n\\ new_id_22 replication new_id_23  david2008foss new_id_24  Software (FOSS) new_id_25 In 2008 Barah new_id_26 geodiversity} conducted new_id_27  similar to ours new_id_28  (7 years) in new_id_29 ona2008geodiversity new_id_30 new_id_31 new_id_32 new_id_33'

In [71]:
tokenizer.decode(ids)

'[NLU] \\section{Introduction}\n\\label new_id_0 intro}\n\n\\emph{Gender diversity}, or more often its lack thereof, among participants to\nsoftware development activities has been thoroughly studied in recent years. new_id_1 particular, the presence of, effects of, and countermeasures for \\emph{gender\n  bias} in Free/Open new_id_2 OSS) have received a lot of attention\nover the past decade~\\cite{david2008f new_id_3, qiu2010kdewomen new_id_4  nafus2012patches, kuechler2012genderfoss, new_id_5 cu2014gender, new_id_6  oneil2016debiansurvey, robles2016womeninfoss, terrell2017gender,\n  zac new_id_7 2021gender}.  \\em new_id_8 } is on the other hand the\nkind new_id_9  that stems from participants in some global activity coming\nfrom different world regions and cultures.\n\n new_id_10  FOSS has received relatively little attention in scholarly\nworks. In particular, while seminal survey-based and\npoint-in-time medium-scale new_id_11  the geographic origins of FOSS\ncontributors exist n

In [57]:
tokenizer.convert_tokens_to_ids(new_tokens)

[50260,
 50261,
 50262,
 50263,
 50264,
 50265,
 50266,
 50267,
 50268,
 50269,
 50270,
 50271,
 50272,
 50273,
 50274,
 50275,
 50276,
 50277,
 50278,
 50279,
 50280,
 50281,
 50282,
 50283,
 50284,
 50285,
 50286,
 50287,
 50288,
 50289,
 50290,
 50291,
 50292,
 50293,
 50294,
 50295,
 50296,
 50297,
 50298,
 50299,
 50300,
 50301,
 50302,
 50303,
 50304,
 50305,
 50306,
 50307,
 50308,
 50309,
 50310,
 50311,
 50312,
 50313,
 50314,
 50315,
 50316,
 50317,
 50318,
 50319,
 50320,
 50321,
 50322,
 50323,
 50324,
 50325,
 50326,
 50327,
 50328,
 50329,
 50330,
 50331,
 50332,
 50333,
 50334,
 50335,
 50336,
 50337,
 50338,
 50339,
 50340,
 50341,
 50342,
 50343,
 50344,
 50345,
 50346,
 50347,
 50348,
 50349,
 50350,
 50351,
 50352,
 50353,
 50354,
 50355,
 50356,
 50357,
 50358,
 50359,
 50360,
 50361,
 50362,
 50363,
 50364,
 50365,
 50366,
 50367,
 50368,
 50369,
 50370,
 50371,
 50372,
 50373,
 50374,
 50375,
 50376,
 50377,
 50378,
 50379,
 50380,
 50381,
 50382,
 50383,
 50384,
