In [1]:
import os
# Restrict this process to the listed physical GPU indices. The variable is set
# here, ahead of the `import torch` further down in this cell, so the
# restriction is already in the environment when torch is loaded.
# NOTE(review): the index list is machine-specific — adjust for other hosts.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,2,3,4,5,7'

import numpy as np
import pandas as pd
import torch

from tqdm.auto import tqdm

# Per-device batch size; passed to TrainingArguments for both training and evaluation.
BATCH_SIZE = 56
# Passed as `max_prompt_length` to every PromptDataset built in this notebook.
# NOTE(review): presumably measured in tokens — confirm against PromptDataset.
MAX_PROMPT_LENGTH = 150

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

In [3]:
from prompt_datasets import PromptDataset, MultipleDataset

In [4]:
# # MODEL_PATH = 'crumb/bloom-560m-RLHF-SD2-prompter'
# MODEL_PATH = 'weight/bloom-560m-shuffle'

# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# tokenizer.padding_side = 'right'

In [5]:
# MODEL_PATH = 'FredZhang7/distilgpt2-stable-diffusion-v2'
# Local checkpoint directory; both the tokenizer (here) and the model (in the
# training section) are loaded from it.
MODEL_PATH = 'weight/distilgpt2-long/'

# local_files_only=True: load strictly from disk, never reach out to the Hub.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# Use the DEL control character (0x7f) as the padding token.
tokenizer.pad_token = '\x7f'
# Pin the pad id explicitly to the first id the tokenizer emits for '\x7f'.
# NOTE(review): assumes '\x7f' encodes to a single token; if it encodes to
# several, only the first id is used — confirm.
tokenizer.pad_token_id = tokenizer('\x7f').input_ids[0]

## extend dataset

In [6]:
# TSV prompt sources for the "extend" dataset.
# Every entry — including the disabled ones — ends with a comma, so that
# re-enabling a commented-out line cannot silently merge two adjacent string
# literals into a single (non-existent) path.
prompt_file_paths = [
    # '../dataset/nonredundant-laion2B_aesthetic.tsv',
    # '../dataset/nonredundant-midjourney_prompts.tsv',
    '../dataset/nonredundant-dalle_captions.tsv',
    '../dataset/nonredundant-dalle_chatgpt_prompts.tsv',
    '../dataset/nonredundant-dalle_discord_prompts.tsv',
    '../dataset/nonredundant-midjourney_prompts-paired.tsv',
]

In [7]:
# Read every TSV listed in `prompt_file_paths` (tab-separated) and stack the
# frames row-wise into a single DataFrame with a fresh 0..n-1 index.
prompts = pd.concat(
    [pd.read_csv(file_path, sep='\t') for file_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)

In [8]:
# Build the "extend" dataset from the frame loaded in the previous cell.
# NOTE(review): unlike the other two datasets in this notebook, no
# `overflow_method` is passed here, so PromptDataset's default applies — confirm
# that is intended.
extend_dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.25,
    max_shuffle=2,
    p_cut=0.1,
    max_prompt_length=MAX_PROMPT_LENGTH,
)

  0%|          | 0/1321952 [00:00<?, ?it/s]

In [9]:
# samples = list()
# lengths = list()
# for tokens, is_positive in tqdm(extend_dataset.samples):
#     length = len(tokens)
    
#     if length < 25:
#         p = ((length / 25) ** 3) * 0.2
#     elif length < 50:
#         p = ((length - 25) / 25) ** 0.75

#     # if length < 25:
#     #     p = ((length / 25) ** 2) * 0.1
#     # elif length < 50:
#     #     p = ((length - 25) / 25) ** 0.5
        
#     else:
#         p = 1
#     if np.random.rand() < p:
#         lengths.append(length)
#         samples.append((tokens, is_positive))
        
# len(samples)

In [10]:
# sampled_aesthetic = np.array([i[0].numpy().astype('int32') for i in samples], dtype='object')
# np.save('sampled_aesthetic.npy', sampled_aesthetic)

In [11]:
# sampled_midjourney = np.array([i[0].numpy().astype('int32') for i in samples], dtype='object')
# np.save('sampled_midjourney.npy', sampled_midjourney)

In [12]:
# Append the pre-sampled "aesthetic" prompts to the extend dataset. Each array
# in the .npy file is turned into a tensor and stored as a (tensor, True) pair;
# arrays whose length is not strictly below max_prompt_length are dropped.
# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load files
# you produced yourself.
# NOTE(review): `+=` grows the sample list in place, so re-running this cell
# appends the same samples again.
aesthetic_arrays = np.load('sampled_aesthetic-extend.npy', allow_pickle=True)
aesthetic_samples = [
    (torch.tensor(token_ids), True)
    for token_ids in aesthetic_arrays
    if len(token_ids) < extend_dataset.max_prompt_length
]
extend_dataset.samples += aesthetic_samples

In [13]:
# Append the pre-sampled "midjourney" prompts to the extend dataset. Each array
# in the .npy file is turned into a tensor and stored as a (tensor, True) pair;
# arrays whose length is not strictly below max_prompt_length are dropped.
# NOTE(review): allow_pickle=True unpickles arbitrary objects — only load files
# you produced yourself.
# NOTE(review): `+=` grows the sample list in place, so re-running this cell
# appends the same samples again.
midjourney_arrays = np.load('sampled_midjourney-extend.npy', allow_pickle=True)
midjourney_samples = [
    (torch.tensor(token_ids), True)
    for token_ids in midjourney_arrays
    if len(token_ids) < extend_dataset.max_prompt_length
]
extend_dataset.samples += midjourney_samples

## dataset

In [6]:
# TSV prompt sources for the main dataset: the lexica train and eval files.
# The civitai and discord sources are currently disabled.
prompt_file_paths = [
    # '../dataset/nonredundant-civitai_prompts.tsv',
    # '../dataset/nonredundant-discord_prompts.tsv',
    '../dataset/nonredundant-lexica_prompts-train.tsv',
    '../dataset/nonredundant-lexica_prompts-eval.tsv',
]

In [7]:
# Load each tab-separated source file, then concatenate all frames row-wise,
# discarding the per-file indices in favour of one continuous index.
frames = [pd.read_csv(path, sep='\t') for path in prompt_file_paths]
prompts = pd.concat(frames, axis=0, ignore_index=True)

In [8]:
# Build the main dataset from the frame loaded in the previous cell.
dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.5,
    max_shuffle=3,
    p_cut=0.2,
    max_prompt_length=MAX_PROMPT_LENGTH,
    overflow_method='split',
)

  0%|          | 0/67789 [00:00<?, ?it/s]

## long dataset

In [17]:
# TSV prompt sources for the long-prompt dataset; they are read with a tab
# separator and concatenated in the next cell.
prompt_file_paths = [
    '../dataset/long-laion2B-en-aesthetic.tsv',
    '../dataset/long-midjourney_prompts.tsv',
    '../dataset/long-midjourney_prompts-2.tsv',
    '../dataset/nonredundant-leonardo_prompts.tsv',
]

In [18]:
# Parse every listed TSV into a DataFrame and merge them into one frame whose
# index is renumbered from zero.
prompts = pd.concat(
    [pd.read_csv(tsv_path, sep='\t') for tsv_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)

In [19]:
# Build the long-prompt dataset from the frame loaded in the previous cell.
long_dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.25,
    max_shuffle=1,
    p_cut=0.2,
    max_prompt_length=MAX_PROMPT_LENGTH,
    overflow_method='split',
)

  0%|          | 0/981564 [00:00<?, ?it/s]

## merge & split dataset

In [20]:
# Combine the three datasets; `probabilities` is given in the same order as the
# dataset list (main, extend, long).
merged_dataset = MultipleDataset(
    [dataset, extend_dataset, long_dataset],
    probabilities=[0.6, 0.2, 0.2],
)
# Show the size of each component followed by the size of the merged dataset.
tuple(len(ds) for ds in (dataset, extend_dataset, long_dataset, merged_dataset))

(3549578, 11391945, 1536356, 11391945)

In [21]:
# Hold out a small fraction of the merged dataset for validation.
VAL_FRACTION = 0.001
# Fixed seed so the train/validation index partition is identical on every run;
# without an explicit generator, random_split draws from the global RNG and the
# held-out indices change each time the notebook is executed.
SPLIT_SEED = 42

len_val_set = int(len(merged_dataset) * VAL_FRACTION)
len_train_set = len(merged_dataset) - len_val_set
train_set, val_set = torch.utils.data.random_split(
    merged_dataset,
    (len_train_set, len_val_set),
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# train

In [22]:
# Collator for causal language modelling: mlm=False turns off masked-LM masking.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [23]:
# Load the causal-LM weights from the same local directory the tokenizer was
# loaded from; local_files_only=True prevents any download attempt.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)

In [24]:
# Training configuration, grouped by concern. Values are unchanged; keyword
# order is irrelevant to TrainingArguments.
args = TrainingArguments(
    # --- output / checkpointing ---
    output_dir="GPT2-extend",
    save_steps=5_000,
    push_to_hub=False,

    # --- batching / data loading ---
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    dataloader_drop_last=True,
    # dataloader_num_workers=8,
    # group_by_length=True,

    # --- optimisation schedule ---
    num_train_epochs=2,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=1_000,
    weight_decay=0.1,
    fp16=True,

    # --- logging / evaluation (both every 1 000 steps) ---
    logging_steps=1_000,
    evaluation_strategy="steps",
    eval_steps=1_000,
    do_eval=True,
)

In [25]:
# Wire the model, configuration, data splits and collator into a Trainer.
# NOTE(review): the tokenizer is taken from `dataset.tokenizer` here while the
# collator was built from the notebook-level `tokenizer` — presumably the same
# object; confirm.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
    tokenizer=dataset.tokenizer,
)

In [None]:
# Run the training loop with the configuration above.
# NOTE(review): in the recorded run the training loss rises from 2.03 at step
# 1000 to about 2.16 by step 5000, and validation loss drifts up as well —
# worth checking the learning rate / warmup settings before a full run.
trainer.train()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
1000,2.0328,1.975594
2000,2.1339,2.015074
3000,2.1534,2.017757
4000,2.1544,2.046337
5000,2.1597,2.03652




# save

In [27]:
# Persist the tokenizer files; the returned tuple of written paths is displayed.
# NOTE(review): the target is the same directory MODEL_PATH
# ('weight/distilgpt2-long/') loads from, so the starting checkpoint's tokenizer
# files are overwritten — a later "Restart & Run All" starts from this output.
dataset.tokenizer.save_pretrained('weight/distilgpt2-long')

('weight/distilgpt2-long/tokenizer_config.json',
 'weight/distilgpt2-long/special_tokens_map.json',
 'weight/distilgpt2-long/vocab.json',
 'weight/distilgpt2-long/merges.txt',
 'weight/distilgpt2-long/added_tokens.json',
 'weight/distilgpt2-long/tokenizer.json')

In [28]:
# Persist the model weights next to the tokenizer files.
# NOTE(review): this is the same directory MODEL_PATH
# ('weight/distilgpt2-long/') loads from, so the starting weights are
# overwritten in place — consider a separate output directory if the original
# checkpoint must be kept.
model.save_pretrained('weight/distilgpt2-long')