In [1]:
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '3,0,1,2,4,5'
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'

import numpy as np
import pandas as pd
import torch

from tqdm.auto import tqdm

BATCH_SIZE = 96

MIN_LONG_PROMPT_TOKEN_LENGTH = 50
MAX_PROMPT_TOKEN_LENGTH = 150
MAX_PROMPT_LENGTH = 1500
MAX_TAG_LENGTH = 100

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

In [3]:
from prompt_datasets import PromptDataset, MultipleDataset

In [4]:
# # MODEL_PATH = 'crumb/bloom-560m-RLHF-SD2-prompter'
# MODEL_PATH = 'weight/bloom-560m-shuffle'

# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# tokenizer.padding_side = 'right'

In [5]:
# MODEL_PATH = 'FredZhang7/distilgpt2-stable-diffusion-v2'
MODEL_PATH = 'weight/distilgpt2-extend/'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
tokenizer.pad_token = '\x7f'
tokenizer.pad_token_id = tokenizer('\x7f').input_ids[0]

In [6]:
def preprocess(df):
    
    df.fillna('', inplace=True)
    
    df['negative_prompt'] = [', '.join([j.strip() for j in i.split(',') if 'negative' not in j]) for i in df['negative_prompt']]
    
    df['positive_length'] = df['positive_prompt'].str.len()
    df['negative_length'] = df['negative_prompt'].str.len()
    df.query(f'positive_length < {MAX_PROMPT_LENGTH} and negative_length < {MAX_PROMPT_LENGTH}', inplace=True)
    
    df['positive_prompt'] = [i if (type(i) == str and max(map(len, i.split(','))) < MAX_TAG_LENGTH) else None for i in df['positive_prompt']]
    df['negative_prompt'] = [i if (type(i) == str and max(map(len, i.split(','))) < MAX_TAG_LENGTH) else None for i in df['negative_prompt']]
    
    df.drop_duplicates(['positive_prompt', 'negative_prompt'], inplace=True)

## extend dataset

In [7]:
prompt_file_paths = [
    # '../dataset/nonredundant-laion2B_aesthetic.tsv',
    # '../dataset/nonredundant-midjourney_prompts.tsv',
    '../dataset/nonredundant-dalle_captions.tsv',
    '../dataset/nonredundant-dalle_chatgpt_prompts.tsv',
    '../dataset/nonredundant-dalle_discord_prompts.tsv',
    '../dataset/nonredundant-midjourney_prompts-paired.tsv',
    '../dataset/long-laion2B-en-aesthetic.tsv',
]

In [8]:
prompts = list()

for file_path in prompt_file_paths:
    df = pd.read_csv(file_path, sep='\t')
    prompts.append(df)
    
prompts = pd.concat(prompts, axis=0, ignore_index=True)

In [9]:
preprocess(prompts)
prompts.fillna('', inplace=True)

  df.fillna('', inplace=True)


In [10]:
extend_dataset = PromptDataset(prompts, tokenizer, p_shuffle=0.25, max_shuffle=2, p_cut=0.1, max_prompt_length=MAX_PROMPT_TOKEN_LENGTH, overflow_method='split')

  0%|          | 0/1286753 [00:00<?, ?it/s]

In [11]:
# samples = list()
# lengths = list()
# for tokens, is_positive in tqdm(extend_dataset.samples):
#     length = len(tokens)
    
#     if length < 15:
#         p = ((length / 25) ** 3) * 0.02
#     elif length < 40:
#         p = ((length - 10) / 30) ** 2 * 0.4

# #     if length < 25:
# #         p = ((length / 25) ** 2) * 0.1
# #     elif length < 50:
# #         p = ((length - 25) / 25) ** 0.5 * 0.5
        
#     else:
#         p = 1
#     if np.random.rand() < p:
#         lengths.append(length)
#         samples.append((tokens, is_positive))
        
# len(samples)

In [12]:
# from matplotlib import pyplot
# _ = pyplot.hist(lengths[::100], bins=100)

In [13]:
# sampled_aesthetic = np.array([i[0].numpy().astype('int32') for i in samples], dtype='object')
# np.save('sampled_aesthetic.npy', sampled_aesthetic)

In [14]:
# sampled_midjourney = np.array([i[0].numpy().astype('int32') for i in samples], dtype='object')
# np.save('sampled_midjourney.npy', sampled_midjourney)

In [15]:
extend_dataset.samples += [(torch.tensor(i), True) for i in np.load('../dataset/sampled_aesthetic.npy', allow_pickle=True) if len(i) < extend_dataset.max_prompt_length]

In [16]:
extend_dataset.samples += [(torch.tensor(i), True) for i in np.load('../dataset/sampled_midjourney.npy', allow_pickle=True) if len(i) < extend_dataset.max_prompt_length]

## dataset

In [17]:
prompt_file_paths = [
    '../dataset/nonredundant-civitai_prompts.tsv',
    '../dataset/nonredundant-discord_prompts.tsv',
    '../dataset/nonredundant-leonardo_prompts.tsv',
    '../dataset/nonredundant-lexica_prompts-train.tsv',
    '../dataset/nonredundant-lexica_prompts-eval.tsv'
]

In [18]:
prompts = list()

for file_path in prompt_file_paths:
    df = pd.read_csv(file_path, sep='\t')
    prompts.append(df)
    
prompts = pd.concat(prompts, axis=0, ignore_index=True)

In [19]:
preprocess(prompts)
prompts.fillna('', inplace=True)

  df.fillna('', inplace=True)


In [20]:
dataset = PromptDataset(prompts, tokenizer, p_shuffle=0.5, max_shuffle=3, p_cut=0.2, max_prompt_length=MAX_PROMPT_TOKEN_LENGTH, overflow_method='split')

  0%|          | 0/2987064 [00:00<?, ?it/s]

## long dataset

In [21]:
long_dataset = PromptDataset(prompts.iloc[:0], tokenizer, p_shuffle=0.5, max_shuffle=3, p_cut=0., max_prompt_length=MAX_PROMPT_TOKEN_LENGTH, overflow_method='split')

0it [00:00, ?it/s]

In [22]:
long_dataset.samples = [(p, n) for p, n in dataset.samples if p.shape[0] > MIN_LONG_PROMPT_TOKEN_LENGTH]
long_dataset.samples = [(p, n) for p, n in extend_dataset.samples if p.shape[0] > MIN_LONG_PROMPT_TOKEN_LENGTH]

## merge & split dataset

In [23]:
merged_dataset = MultipleDataset([dataset, extend_dataset, long_dataset], probabilities=[0.3, 0.3, 0.4])
len(dataset), len(extend_dataset), len(long_dataset), len(merged_dataset)

(3352184, 5755983, 2003619, 5755983)

In [24]:
len_val_set = int(len(merged_dataset) * 0.001)
train_set, val_set = torch.utils.data.random_split(merged_dataset, (len(merged_dataset) - len_val_set, len_val_set))

# train

In [25]:
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [26]:
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)

In [27]:
args = TrainingArguments(
    output_dir="GPT2-extend",
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    logging_steps=1_000,
    gradient_accumulation_steps=4,
    num_train_epochs=4,
    weight_decay=0.1,
    warmup_steps=1_000,
    lr_scheduler_type="cosine",
    learning_rate=1e-4,
    save_steps=5_000,
    fp16=True,
    
    push_to_hub=False,
    dataloader_drop_last=True,
    # dataloader_num_workers=8,
    # group_by_length=True,
    
    evaluation_strategy="steps",
    eval_steps=1_000,
    do_eval=True,
)

In [28]:
trainer = Trainer(
    model=model,                                                                                                 
    
    tokenizer=dataset.tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=train_set,
    eval_dataset=val_set,
)

In [None]:
trainer.train()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
1000,2.2667,2.166469
2000,2.2946,2.160431
3000,2.2973,2.174053


# save

In [None]:
dataset.tokenizer.save_pretrained('weight/distilgpt2-extend')

In [None]:
model.save_pretrained('weight/distilgpt2-extend')

In [None]:
from matplotlib import pyplot

In [None]:
_ = pyplot.hist([len(train_set[i]['input_ids']) for i in range(0, len(train_set), 100)], bins=100)