#### Downloading the dataset from the Hugging Face Hub

In [1]:
from datasets import load_dataset

# Load the "train" split and hold out 10% of it for evaluation
ds = load_dataset("daspartho/stable-diffusion-prompts", split="train")
ds = ds.train_test_split(test_size=0.1, shuffle=True)
ds

Using custom data configuration daspartho--stable-diffusion-prompts-71a447bb593151ef


Downloading and preparing dataset None/None to /home/.cache/huggingface/datasets/daspartho___parquet/daspartho--stable-diffusion-prompts-71a447bb593151ef/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...

Dataset parquet downloaded and prepared to /home/.cache/huggingface/datasets/daspartho___parquet/daspartho--stable-diffusion-prompts-71a447bb593151ef/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


DatasetDict({
    train: Dataset({
        features: ['prompt'],
        num_rows: 1637827
    })
    test: Dataset({
        features: ['prompt'],
        num_rows: 181981
    })
})
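The split sizes follow directly from `test_size=0.1`. As a minimal sketch, assuming the test split is rounded up and the train split takes every remaining row (the numbers above are consistent with this rule), the sizes can be reproduced from the 1,819,808 rows of the original `train` split. The helper name `split_sizes` is hypothetical.

```python
import math

def split_sizes(n_rows: int, test_size: float) -> tuple[int, int]:
    """Return (n_train, n_test) for a fractional test_size.

    Assumption: the test split is rounded up and the train split
    receives all remaining rows.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a fraction between 0 and 1")
    n_test = math.ceil(test_size * n_rows)
    return n_rows - n_test, n_test

n_train, n_test = split_sizes(1_819_808, 0.1)
print(n_train, n_test)
```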

#### Tokenizing the dataset

In [2]:
from transformers import AutoTokenizer

context_length = 128
tokenizer = AutoTokenizer.from_pretrained('daspartho/prompt-tokenizer')

def tokenize(element):
    return tokenizer(
        element["prompt"],
        truncation=True,
        max_length=context_length,
    )

tok_ds = ds.map(
    tokenize, 
    batched=True,
)
tok_ds

DatasetDict({
    train: Dataset({
        features: ['prompt', 'input_ids', 'attention_mask'],
        num_rows: 1637827
    })
    test: Dataset({
        features: ['prompt', 'input_ids', 'attention_mask'],
        num_rows: 181981
    })
})
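With `batched=True`, `Dataset.map` passes `tokenize` a dict of columns holding a batch of rows (1,000 by default) instead of a single example, then merges the returned columns back into each row. That is why `input_ids` and `attention_mask` appear next to `prompt` above. The sketch below imitates that contract in plain Python; `toy_tokenizer` is a hypothetical whitespace tokenizer standing in for `daspartho/prompt-tokenizer`, and `batched_map` is an illustration of the idea, not how Hugging Face Datasets is implemented. It shows how `truncation=True` caps every sequence at `context_length` tokens.

```python
CONTEXT_LENGTH = 128

vocab: dict[str, int] = {}  # hypothetical vocabulary, grown on the fly

def toy_tokenizer(texts: list[str], max_length: int) -> dict[str, list[list[int]]]:
    """Hypothetical stand-in for a real tokenizer: one id per whitespace-separated word."""
    input_ids = []
    for text in texts:
        ids = [vocab.setdefault(word, len(vocab)) for word in text.split()]
        input_ids.append(ids[:max_length])  # truncation keeps the first max_length ids
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

def batched_map(rows: list[dict], fn, batch_size: int = 1000) -> list[dict]:
    """Call fn on column-oriented batches and merge its new columns into each row."""
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_cols = fn(batch)
        for i, row in enumerate(chunk):
            merged = dict(row)
            merged.update({key: values[i] for key, values in new_cols.items()})
            out.append(merged)
    return out

def tokenize(element):
    return toy_tokenizer(element["prompt"], max_length=CONTEXT_LENGTH)

rows = [{"prompt": "munchkin village house"}, {"prompt": " ".join(["word"] * 300)}]
tok_rows = batched_map(rows, tokenize)
print([len(r["input_ids"]) for r in tok_rows])
```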

#### Initializing the model

In [3]:
from transformers import AutoConfig, GPT2LMHeadModel

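# Reuse GPT-2's architecture hyperparameters only; no pretrained weights are loaded
config = AutoConfig.from_pretrained(
    'gpt2',
    vocab_size=len(tokenizer),  # match the custom tokenizer's vocabulary
    n_ctx=context_length,  # GPT2Config sizes position embeddings with n_positions, not n_ctx
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

# Randomly initialized GPT-2, trained from scratch below
model = GPT2LMHeadModel(config)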
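The training log further down reports 125,778,432 trainable parameters. As a sanity check, here is a sketch that counts GPT-2 parameters layer by layer, assuming the GPT-2 small shape inherited from the `gpt2` config (12 layers, 768-dimensional embeddings, a 4x MLP, 1,024 position embeddings) and an LM head tied to the token embeddings, so it adds no parameters of its own. The function `gpt2_param_count` is hypothetical. With the original 50,257-token vocabulary it gives the familiar GPT-2 small count; under these assumptions the reported count corresponds to a 52,000-token vocabulary.

```python
def gpt2_param_count(vocab_size: int, n_positions: int = 1024, n_embd: int = 768,
                     n_layer: int = 12, tied_lm_head: bool = True) -> int:
    """Count GPT-2 weights and biases for a given shape (assumed GPT-2 small defaults)."""
    n_inner = 4 * n_embd                                    # MLP hidden width
    embeddings = vocab_size * n_embd + n_positions * n_embd  # token (wte) + position (wpe)
    per_layer = (
        2 * n_embd                          # ln_1 weight + bias
        + n_embd * 3 * n_embd + 3 * n_embd  # attn.c_attn: fused q, k, v projection
        + n_embd * n_embd + n_embd          # attn.c_proj
        + 2 * n_embd                        # ln_2 weight + bias
        + n_embd * n_inner + n_inner        # mlp.c_fc
        + n_inner * n_embd + n_embd         # mlp.c_proj
    )
    final_ln = 2 * n_embd                   # ln_f weight + bias
    lm_head = 0 if tied_lm_head else vocab_size * n_embd  # tied to wte by default
    return embeddings + n_layer * per_layer + final_ln + lm_head

print(gpt2_param_count(50_257))  # original GPT-2 small vocabulary
print(gpt2_param_count(52_000))
```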

#### Set up a data collator to take care of creating the batches

In [4]:
from transformers import DataCollatorForLanguageModeling

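# Reuse the EOS token as the padding token
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(
    tokenizer,
    mlm=False,  # causal LM: labels are the inputs themselves, not masked tokens
)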
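With `mlm=False`, `DataCollatorForLanguageModeling` pads each batch to its longest sequence and copies `input_ids` into `labels`, replacing every position that holds the pad token id with `-100` so the loss ignores it; `GPT2LMHeadModel` shifts the labels by one position internally. Below is a minimal pure-Python sketch of that behaviour, assuming right padding. `collate_clm` is a hypothetical name, and `PAD_ID = 0` is taken from the generation log later in the notebook, which reports `eos_token_id` 0. Because `pad_token` is set to `eos_token`, a genuine EOS token in the labels is masked the same way as padding.

```python
PAD_ID = 0  # pad id == eos id here, since pad_token = eos_token

def collate_clm(features: list[dict], pad_id: int = PAD_ID) -> dict[str, list[list[int]]]:
    """Right-pad a batch and build causal-LM labels (sketch of the mlm=False path)."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        ids = list(f["input_ids"])
        n_pad = max_len - len(ids)
        padded = ids + [pad_id] * n_pad
        batch["input_ids"].append(padded)
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
        # labels copy input_ids; every pad_id position becomes -100 (ignored by the loss)
        batch["labels"].append([-100 if t == pad_id else t for t in padded])
    return batch

batch = collate_clm([{"input_ids": [5, 6, 7]}, {"input_ids": [8, 9]}])
print(batch["labels"])
```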

#### Training time!

In [5]:
from transformers import Trainer, TrainingArguments

bs = 128
epochs = 5
lr = 1e-4

args = TrainingArguments(
    output_dir="prompt-extend",
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=bs*2,
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    logging_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=epochs,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    learning_rate=lr,
    fp16=True,
    report_to='none',
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=tok_ds["train"],
    eval_dataset=tok_ds["test"],
)

trainer.train()

Cloning https://huggingface.co/daspartho/prompt-extend into local empty directory.

Using cuda_amp half precision backend
The following columns in the training set don't have a corresponding argument in `GPT2LMHeadModel.forward` and have been ignored: prompt. If prompt are not expected by `GPT2LMHeadModel.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1637827
  Num Epochs = 5
  Instantaneous batch size per device = 128
  Total train batch size (w. parallel, distributed & accumulation) = 128
  Gradient Accumulation steps = 1
  Total optimization steps = 63980
  Number of trainable parameters = 125778432
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


| Epoch | Training Loss | Validation Loss |
|------:|--------------:|----------------:|
| 1 | 3.7436 | 2.542871 |
| 2 | 2.3292 | 2.071066 |
| 3 | 1.9439 | 1.844723 |
| 4 | 1.7059 | 1.732472 |
| 5 | 1.5775 | 1.710982 |
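The reported loss is (approximately) the mean token-level cross-entropy in nats, so `exp(loss)` gives perplexity. This quick sketch converts the per-epoch validation losses above. The final perplexity is roughly 5.5, and the last epoch improves it by only about 2%, which suggests training was close to plateauing.

```python
import math

# Validation losses per epoch, copied from the training results above
val_loss = {1: 2.542871, 2: 2.071066, 3: 1.844723, 4: 1.732472, 5: 1.710982}

# Perplexity is the exponential of the mean cross-entropy (natural log)
perplexity = {epoch: math.exp(loss) for epoch, loss in val_loss.items()}
for epoch, ppl in perplexity.items():
    print(f"epoch {epoch}: perplexity {ppl:.2f}")

# Relative improvement from epoch 4 to epoch 5
gain = 1 - perplexity[5] / perplexity[4]
print(f"last-epoch improvement: {gain:.1%}")
```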


The following columns in the evaluation set don't have a corresponding argument in `GPT2LMHeadModel.forward` and have been ignored: prompt. If prompt are not expected by `GPT2LMHeadModel.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 181981
  Batch size = 256
Saving model checkpoint to prompt-extend/checkpoint-12796
Configuration saved in prompt-extend/checkpoint-12796/config.json
Model weights saved in prompt-extend/checkpoint-12796/pytorch_model.bin
tokenizer config file saved in prompt-extend/checkpoint-12796/tokenizer_config.json
Special tokens file saved in prompt-extend/checkpoint-12796/special_tokens_map.json
tokenizer config file saved in prompt-extend/tokenizer_config.json
Special tokens file saved in prompt-extend/special_tokens_map.json

TrainOutput(global_step=63980, training_loss=2.260038171108159, metrics={'train_runtime': 15829.7899, 'train_samples_per_second': 517.324, 'train_steps_per_second': 4.042, 'total_flos': 4.25950644615552e+17, 'train_loss': 2.260038171108159, 'epoch': 5.0})
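The 63,980 optimization steps and the first checkpoint name, `checkpoint-12796`, follow from the dataset and batch sizes: on a single device with no gradient accumulation, one epoch is ceil(1,637,827 / 128) = 12,796 steps, since the last partial batch is kept. The sketch below reproduces that arithmetic and the learning-rate curve implied by `warmup_ratio=0.1` and `lr_scheduler_type="cosine"`: linear warmup over the first 10% of steps (rounded up), then cosine decay to zero. It is a standalone re-implementation of the shape of Transformers' cosine-with-warmup schedule, not the library code, and the function names are hypothetical.

```python
import math

def steps_per_epoch(n_examples: int, batch_size: int) -> int:
    # the final partial batch is kept, so round up
    return math.ceil(n_examples / batch_size)

def cosine_with_warmup(step: int, total_steps: int, peak_lr: float, warmup_ratio: float) -> float:
    """Learning rate at `step`: linear warmup, then half-cosine decay to zero."""
    warmup = math.ceil(total_steps * warmup_ratio)
    if step < warmup:
        return peak_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_steps - warmup)
    return peak_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

epoch_steps = steps_per_epoch(1_637_827, 128)
total = epoch_steps * 5
warmup = math.ceil(total * 0.1)
print(epoch_steps, total, warmup)
for step in (0, warmup, total // 2, total):
    print(step, cosine_with_warmup(step, total, 1e-4, 0.1))
```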

#### Let's try it out

In [6]:
from transformers import TextGenerationPipeline

text_pipe = TextGenerationPipeline(
    model=model, 
    tokenizer=tokenizer,
    device=0,
)

prompt = "munchkin village house"
extended_prompt = text_pipe(prompt+',', num_return_sequences=1)[0]["generated_text"]
extended_prompt

Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


'munchkin village house, by thomas kinkade, trending on artstation, photorealistic, wild vegetation, overgrown'
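The pipeline returns a list with one dict per generated sequence, and `generated_text` includes the input prompt. For repeated use, the call can be wrapped in a small helper. `extend_prompt` below is a hypothetical convenience wrapper, not part of the original notebook; it accepts any callable with the pipeline's call signature, so here it is exercised with a stub (`fake_pipe`, also hypothetical) instead of the trained model. It normalizes any trailing comma before appending one, as the notebook does.

```python
from typing import Callable

def extend_prompt(generate: Callable[..., list], prompt: str) -> str:
    """Append a single trailing comma and return the pipeline's extended text."""
    seed = prompt.rstrip().rstrip(",") + ","  # end with exactly one comma
    return generate(seed, num_return_sequences=1)[0]["generated_text"]

# Stub standing in for `text_pipe`: echoes the input plus a fixed continuation
def fake_pipe(text: str, num_return_sequences: int = 1) -> list:
    return [{"generated_text": text + " trending on artstation"}] * num_return_sequences

print(extend_prompt(fake_pipe, "munchkin village house"))
```

With the trained model, the same helper would be called as `extend_prompt(text_pipe, "munchkin village house")`.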

#### Push the model to the Hub

In [7]:
trainer.push_to_hub()

Saving model checkpoint to prompt-extend
Configuration saved in prompt-extend/config.json
Model weights saved in prompt-extend/pytorch_model.bin
tokenizer config file saved in prompt-extend/tokenizer_config.json
Special tokens file saved in prompt-extend/special_tokens_map.json
Several commits (2) will be pushed upstream.
The progress bars may be unreliable.

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/daspartho/prompt-extend
   dbdc999..1d9c127  main -> main

Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}
To https://huggingface.co/daspartho/prompt-extend
   1d9c127..776d554  main -> main

'https://huggingface.co/daspartho/prompt-extend/commit/1d9c127412ec2965c2d43ff41e157dc4327bd0f7'