In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3,0,1,2,4,5'

import numpy as np
import pandas as pd
import torch

from tqdm.auto import tqdm

# Per-device batch size; passed to TrainingArguments for both training and evaluation.
BATCH_SIZE = 64
# Character-length threshold: the "long single" table keeps rows whose positive or
# negative prompt is longer than this.
MIN_LONG_PROMPT_LENGTH = 150
# Token-count threshold for the "long" datasets (passed as min_prompt_length and used
# to filter the paired samples).
MIN_LONG_PROMPT_TOKEN_LENGTH = 50
# Token budget per prompt: used as-is for the paired datasets and as `* 2 - 1` for the
# single-prompt datasets.
MAX_PROMPT_TOKEN_LENGTH = 150
# Character-length cap: preprocess() keeps only rows whose positive and negative
# prompts are both shorter than this.
MAX_PROMPT_LENGTH = 1500
# Character-length cap per comma-separated tag: preprocess() sets a prompt to None
# when any of its tags is this long or longer.
MAX_TAG_LENGTH = 100

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments
from transformers import DataCollatorForLanguageModeling

# from peft import get_peft_model, LoraConfig, TaskType, prepare_model_for_int8_training

In [3]:
from prompt_datasets import PromptDataset, PairedPromptDataset, MultipleDataset

In [4]:
# # MODEL_PATH = 'crumb/bloom-560m-RLHF-SD2-prompter'
# MODEL_PATH = 'weight/bloom-560m-shuffle'

# tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# tokenizer.padding_side = 'right'

In [5]:
# Checkpoint used for both the tokenizer and the model below; the commented lines are
# alternative checkpoints/paths.
MODEL_PATH = 'FredZhang7/distilgpt2-stable-diffusion-v2'
# MODEL_PATH = 'weight/distilgpt2-paired'
# MODEL_PATH = 'GPT2-paired/checkpoint-5000/'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
# Use the DEL character ('\x7f') as the padding token.
# NOTE(review): presumably chosen because it does not occur in real prompts — confirm.
tokenizer.pad_token = '\x7f'
# Set the id explicitly to the first token id the tokenizer produces for '\x7f'.
# Keep this after the pad_token assignment: the order of these two lines matters.
tokenizer.pad_token_id = tokenizer('\x7f').input_ids[0]

In [20]:
def preprocess(df, max_prompt_length=None, max_tag_length=None):
    """Clean a prompt table in place.

    Steps, in order:
      1. Replace every missing value in the frame with ''.
      2. In `negative_prompt`, replace each comma-separated tag containing the
         substring "negative" with a single space, then append ';' to every
         non-empty prompt that does not already end with one.
      3. Add `positive_length` / `negative_length` (character counts) and keep
         only rows where both are below `max_prompt_length`.
      4. Set a prompt to None when any of its comma-separated tags has
         `max_tag_length` characters or more (or when it is not a string).
      5. Drop rows that duplicate an earlier (positive, negative) pair.

    The length columns are computed in step 3 and are NOT refreshed after
    step 4, so a prompt nulled in step 4 keeps its original length value.

    Args:
        df: DataFrame with `positive_prompt` and `negative_prompt` columns.
            It is modified in place.
        max_prompt_length: exclusive upper bound on prompt length in
            characters. Defaults to the notebook constant MAX_PROMPT_LENGTH.
        max_tag_length: exclusive upper bound on the length of a single
            comma-separated tag. Defaults to the notebook constant
            MAX_TAG_LENGTH.

    Returns:
        None. Callers follow up with dropna() or fillna('') to decide what
        happens to the prompts nulled in step 4.
    """
    # Resolve the defaults at call time so plain `preprocess(df)` keeps using the
    # notebook-level limits, while explicit arguments make the function reusable.
    if max_prompt_length is None:
        max_prompt_length = MAX_PROMPT_LENGTH
    if max_tag_length is None:
        max_tag_length = MAX_TAG_LENGTH

    def _keep_if_tags_short(prompt):
        """Return `prompt` unchanged, or None if it is not a str or has an overlong tag."""
        # ''.split(',') yields [''], so max() always has at least one element.
        if isinstance(prompt, str) and max(map(len, prompt.split(','))) < max_tag_length:
            return prompt
        return None

    df.fillna('', inplace=True)

    df['negative_prompt'] = (
        df['negative_prompt']
        .str.replace(r'[^,]*negative[^,]*,?\s?', ' ', regex=True)
        .str.replace(r'([^;])$', r'\1;', regex=True)
    )

    df['positive_length'] = df['positive_prompt'].str.len()
    df['negative_length'] = df['negative_prompt'].str.len()
    df.query(f'positive_length < {max_prompt_length} and negative_length < {max_prompt_length}', inplace=True)

    df['positive_prompt'] = [_keep_if_tags_short(prompt) for prompt in df['positive_prompt']]
    df['negative_prompt'] = [_keep_if_tags_short(prompt) for prompt in df['negative_prompt']]

    df.drop_duplicates(['positive_prompt', 'negative_prompt'], inplace=True)

## paired

In [8]:
# TSV sources for the paired (positive + negative prompt) dataset.
prompt_file_paths = [
    '../dataset/nonredundant-civitai_prompts.tsv',
    '../dataset/nonredundant-midjourney_prompts-paired.tsv',
    '../dataset/nonredundant-leonardo_prompts.tsv',
]

In [9]:
# Read every TSV listed in prompt_file_paths and stack them into a single table
# with a fresh 0..n-1 index, then drop rows that have a missing value in any column.
prompts = pd.concat(
    [pd.read_csv(file_path, sep='\t') for file_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)
prompts.dropna(inplace=True)

In [10]:
# preprocess() cleans `prompts` in place and sets any prompt with an overlong tag to
# None; dropna() then removes those rows.
preprocess(prompts)
prompts.dropna(inplace=True)

In [12]:
# Build the paired dataset from the cleaned table.
paired_dataset = PairedPromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.5,
    max_shuffle=3,
    p_cut=0.2,
    max_prompt_length=MAX_PROMPT_TOKEN_LENGTH,
    overflow_method='split',
)

  0%|          | 0/1215202 [00:00<?, ?it/s]

## long paired

In [13]:
# Construct the dataset from a zero-row slice, so it starts out with no rows;
# its samples are assigned separately from paired_dataset.
long_paired_dataset = PairedPromptDataset(
    prompts.iloc[:0],
    tokenizer,
    p_shuffle=0.5,
    max_shuffle=3,
    p_cut=0.,
    min_prompt_length=MIN_LONG_PROMPT_TOKEN_LENGTH,
    max_prompt_length=MAX_PROMPT_TOKEN_LENGTH,
    overflow_method='split',
)

0it [00:00, ?it/s]

In [14]:
# Reuse paired_dataset's samples, keeping only pairs in which BOTH the positive and
# the negative side have more than MIN_LONG_PROMPT_TOKEN_LENGTH entries along dim 0.
long_paired_dataset.samples = [
    (positive, negative)
    for positive, negative in paired_dataset.samples
    if min(positive.shape[0], negative.shape[0]) > MIN_LONG_PROMPT_TOKEN_LENGTH
]

## single

In [18]:
# TSV sources for the single-prompt dataset.
# NOTE(review): the lexica "-eval" split is loaded for training here — confirm that
# this is intended and that the file is not used as a held-out set elsewhere.
prompt_file_paths = [
    '../dataset/nonredundant-discord_prompts.tsv',
    '../dataset/nonredundant-lexica_prompts-train.tsv',
    '../dataset/nonredundant-lexica_prompts-eval.tsv',
    '../dataset/nonredundant-civitai_prompts.tsv',
    '../dataset/nonredundant-leonardo_prompts.tsv'
]

In [19]:
# Read every TSV listed in prompt_file_paths and stack them into a single table
# with a fresh 0..n-1 index. Missing values are kept here; preprocess() fills them.
prompts = pd.concat(
    [pd.read_csv(file_path, sep='\t') for file_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)

In [21]:
# preprocess() cleans `prompts` in place and sets any prompt with an overlong tag to
# None; fillna('') turns those back into empty strings, so such rows are kept
# (with an empty prompt) rather than dropped.
preprocess(prompts)
prompts.fillna('', inplace=True)

  df.fillna('', inplace=True)


In [22]:
# Build the single-prompt dataset; the token budget is twice the paired budget minus one.
single_dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.5,
    max_shuffle=3,
    p_cut=0.2,
    max_prompt_length=MAX_PROMPT_TOKEN_LENGTH*2-1,
    overflow_method='split',
)

  0%|          | 0/2989134 [00:00<?, ?it/s]

## long single

In [23]:
# TSV sources for the long single-prompt dataset: the single-prompt sources plus the
# DALL-E, paired Midjourney and long LAION aesthetic files.
prompt_file_paths = [
    '../dataset/nonredundant-discord_prompts.tsv',
    '../dataset/nonredundant-lexica_prompts-train.tsv',
    '../dataset/nonredundant-lexica_prompts-eval.tsv',
    '../dataset/nonredundant-civitai_prompts.tsv',
    '../dataset/nonredundant-leonardo_prompts.tsv',
    '../dataset/nonredundant-dalle_captions.tsv',
    '../dataset/nonredundant-dalle_chatgpt_prompts.tsv',
    '../dataset/nonredundant-dalle_discord_prompts.tsv',
    '../dataset/nonredundant-midjourney_prompts-paired.tsv',
    '../dataset/long-laion2B-en-aesthetic.tsv',
]

In [24]:
# Read every TSV listed in prompt_file_paths and stack them into a single table
# with a fresh 0..n-1 index. Missing values are kept here; preprocess() fills them.
prompts = pd.concat(
    [pd.read_csv(file_path, sep='\t') for file_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)

In [25]:
# preprocess() cleans `prompts` in place and sets any prompt with an overlong tag to
# None; fillna('') turns those back into empty strings, so such rows are kept
# (with an empty prompt) rather than dropped.
preprocess(prompts)
prompts.fillna('', inplace=True)

  df.fillna('', inplace=True)


In [26]:
# Keep only rows where at least one side is longer than MIN_LONG_PROMPT_LENGTH characters.
# NOTE(review): positive_length / negative_length were computed inside preprocess()
# before overlong-tag prompts were nulled, so a prompt that is now '' can still pass
# this filter on its old length — confirm this is acceptable.
prompts.query(f'(positive_length > {MIN_LONG_PROMPT_LENGTH} or negative_length > {MIN_LONG_PROMPT_LENGTH})', inplace=True)

In [27]:
# Build the long single-prompt dataset from the rows that passed the character-length
# filter; p_cut is 0 here and a minimum token length is supplied.
long_single_dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.5,
    max_shuffle=3,
    p_cut=0.,
    min_prompt_length=MIN_LONG_PROMPT_TOKEN_LENGTH,
    max_prompt_length=MAX_PROMPT_TOKEN_LENGTH*2-1,
    overflow_method='split',
)

  0%|          | 0/1710446 [00:00<?, ?it/s]

## extend

In [28]:
# TSV sources for the "extend" dataset (DALL-E captions/prompts and paired Midjourney prompts).
prompt_file_paths = [
    '../dataset/nonredundant-dalle_captions.tsv',
    '../dataset/nonredundant-dalle_chatgpt_prompts.tsv',
    '../dataset/nonredundant-dalle_discord_prompts.tsv',
    '../dataset/nonredundant-midjourney_prompts-paired.tsv',
]

In [29]:
# Read every TSV listed in prompt_file_paths and stack them into a single table
# with a fresh 0..n-1 index. Missing values are kept here; preprocess() fills them.
prompts = pd.concat(
    [pd.read_csv(file_path, sep='\t') for file_path in prompt_file_paths],
    axis=0,
    ignore_index=True,
)

In [30]:
# preprocess() cleans `prompts` in place and sets any prompt with an overlong tag to
# None; fillna('') turns those back into empty strings, so such rows are kept
# (with an empty prompt) rather than dropped.
preprocess(prompts)
prompts.fillna('', inplace=True)

  df.fillna('', inplace=True)


In [31]:
# Build the "extend" dataset; it uses lower shuffle/cut settings than the other datasets.
extend_dataset = PromptDataset(
    prompts,
    tokenizer,
    p_shuffle=0.25,
    max_shuffle=2,
    p_cut=0.1,
    max_prompt_length=MAX_PROMPT_TOKEN_LENGTH*2-1,
    overflow_method='split',
)

  0%|          | 0/1283329 [00:00<?, ?it/s]

In [32]:
# Append extra samples from the two .npy dumps (aesthetic first, then midjourney),
# keeping only entries shorter than the dataset's max_prompt_length. Each entry is
# stored as a (tensor, True) tuple.
# NOTE(review): allow_pickle=True unpickles the file contents — only load .npy files
# from a trusted source.
for npy_path in ('sampled_aesthetic.npy', 'sampled_midjourney.npy'):
    extend_dataset.samples += [
        (torch.tensor(entry), True)
        for entry in np.load(npy_path, allow_pickle=True)
        if len(entry) < extend_dataset.max_prompt_length
    ]

# merge

In [48]:
# Combine the five datasets into one; the probabilities list has one entry per dataset
# (presumably matched by position — verify against MultipleDataset).
merged_dataset = MultipleDataset(
    [paired_dataset, long_paired_dataset, single_dataset, long_single_dataset, extend_dataset],
    probabilities=[0.15, 0.2, 0.25, 0.3, 0.1],
)
# Display the size of each component dataset followed by the merged dataset.
tuple(
    len(dataset)
    for dataset in (
        paired_dataset,
        long_paired_dataset,
        single_dataset,
        long_single_dataset,
        extend_dataset,
        merged_dataset,
    )
)

(1506488, 715428, 3142830, 1256639, 5936744, 5936744)

In [49]:
# Hold out 0.1% of the merged dataset for evaluation.
len_val_set = int(len(merged_dataset) * 0.001)
# Seed the split explicitly so the same indices land in the validation set on every
# run; without a generator, random_split draws from the global RNG and the held-out
# set changes between runs.
train_set, val_set = torch.utils.data.random_split(
    merged_dataset,
    (len(merged_dataset) - len_val_set, len_val_set),
    generator=torch.Generator().manual_seed(42),
)

In [50]:
# Batch collator for language modelling; mlm=False turns off masked-LM masking, which
# matches the causal LM (AutoModelForCausalLM) trained below. It receives the tokenizer
# whose pad token was configured above.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

In [51]:
# Load the causal LM from the same checkpoint as the tokenizer (local files only).
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True)

In [52]:
# model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, local_files_only=True, load_in_8bit=True, device_map='auto')
# model = prepare_model_for_int8_training(model)

In [53]:
# peft_config = LoraConfig(
#     task_type=TaskType.CAUSAL_LM, inference_mode=False, r=16, lora_alpha=16, lora_dropout=0.1, bias="all"
# )

In [54]:
# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

In [55]:
# Training configuration, grouped by concern. Values are unchanged.
args = TrainingArguments(
    # Output and checkpointing
    output_dir="GPT2-paired",
    save_steps=5_000,
    push_to_hub=False,

    # Batching
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,

    # Optimisation
    num_train_epochs=3,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_steps=1_000,
    weight_decay=0.1,
    fp16=True,

    # Logging and evaluation
    logging_steps=1_000,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=1_000,
)

In [56]:
# Wire the model, training configuration, data splits, collator and tokenizer together.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_set,
    eval_dataset=val_set,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning with the configuration in `args` (output_dir "GPT2-paired",
# save_steps=5_000, eval_steps=1_000).
trainer.train()

In [58]:
# Save the tokenizer (including the pad-token setup) to the same directory the model
# is saved to, so both can be loaded from one path.
tokenizer.save_pretrained('weight/distilgpt2-paired')

('weight/distilgpt2-paired/tokenizer_config.json',
 'weight/distilgpt2-paired/special_tokens_map.json',
 'weight/distilgpt2-paired/vocab.json',
 'weight/distilgpt2-paired/merges.txt',
 'weight/distilgpt2-paired/added_tokens.json',
 'weight/distilgpt2-paired/tokenizer.json')

In [59]:
# Save the model weights to the same directory as the tokenizer.
model.save_pretrained('weight/distilgpt2-paired')