# Visualise Batch
Check and make sure what's coming out of the data loading pipeline is expected.

In [1]:
import datasets
from data import ImageDatasetPrecoded, gen_buckets, AspectBucketSampler
from transformers import CLIPTokenizer
from diffusers import AutoencoderKL
import random
from torch.utils.data import DataLoader
from collections import defaultdict
import torch
import torchvision.transforms.functional as TVF
from tqdm import tqdm

In [2]:
base_model = "stabilityai/stable-diffusion-xl-base-1.0"
base_revision = "462165984030d82259a11f4367a4eed129e94a7b"
device_batch_size = 16
num_workers = 2

vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae", revision="462165984030d82259a11f4367a4eed129e94a7b", torch_dtype=torch.float32, use_safetensors=True)
assert isinstance(vae, AutoencoderKL)

tokenizer = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer", revision=base_revision, use_fast=False)
tokenizer_2 = CLIPTokenizer.from_pretrained(base_model, subfolder="tokenizer_2", revision=base_revision, use_fast=False)
source_ds = datasets.load_dataset("parquet", data_files="../data/dataset.parquet")
assert isinstance(source_ds, datasets.DatasetDict)
source_ds = source_ds['train'].train_test_split(test_size=2048, seed=42)

train_ds = ImageDatasetPrecoded(source_ds['train'], tokenizer, tokenizer_2, datapath='../data/vaes')
test_ds = ImageDatasetPrecoded(source_ds['test'], tokenizer, tokenizer_2, datapath='../data/vaes')

train_buckets = gen_buckets(source_ds['train'])
validation_buckets = gen_buckets(source_ds['test'])

print("Aspect ratio buckets:")
#sorted_buckets = sorted(train_buckets, key=lambda x: (x.resolution, x.n_chunks))
for bucket in sorted(train_buckets, key=lambda x: (x.resolution, x.n_chunks)):
	print(f"{bucket.resolution}, {bucket.n_chunks}: {len(bucket.images)}")
print()


train_sampler = AspectBucketSampler(dataset=train_ds, buckets=train_buckets, batch_size=device_batch_size, num_replicas=1, rank=0, shuffle=True, ragged_batches=False)
validation_sampler = AspectBucketSampler(dataset=test_ds, buckets=validation_buckets, batch_size=device_batch_size, num_replicas=1, rank=0, shuffle=False, ragged_batches=True)

train_dataloader = DataLoader(
	train_ds,
	batch_sampler=train_sampler,
	num_workers=num_workers,
	collate_fn=train_ds.collate_fn,
)

validation_dataloader = DataLoader(
	test_ds,
	batch_sampler=validation_sampler,
	num_workers=num_workers,
	collate_fn=test_ds.collate_fn,
)

Generating train split: 0 examples [00:00, ? examples/s]

6714713it [00:03, 1812288.43it/s]
2048it [00:00, 1516852.30it/s]

Aspect ratio buckets:
(72, 208), 1: 1854
(72, 208), 2: 1114
(72, 208), 3: 31
(72, 216), 1: 2892
(72, 216), 2: 1778
(72, 216), 3: 48
(80, 192), 1: 10092
(80, 192), 2: 6237
(80, 192), 3: 156
(80, 200), 1: 3007
(80, 200), 2: 1771
(80, 200), 3: 57
(88, 168), 1: 19090
(88, 168), 2: 12283
(88, 168), 3: 335
(88, 176), 1: 7424
(88, 176), 2: 4665
(88, 176), 3: 115
(88, 184), 1: 10869
(88, 184), 2: 6947
(88, 184), 3: 175
(96, 160), 1: 35773
(96, 160), 2: 22332
(96, 160), 3: 623
(96, 168), 1: 111379
(96, 168), 2: 71599
(96, 168), 3: 2111
(104, 144), 1: 42279
(104, 144), 2: 26993
(104, 144), 3: 838
(104, 152), 1: 1328367
(104, 152), 2: 829638
(104, 152), 3: 21897
(112, 136), 1: 14660
(112, 136), 2: 9677
(112, 136), 3: 255
(112, 144), 1: 260286
(112, 144), 2: 165814
(112, 144), 3: 4543
(120, 128), 1: 9198
(120, 128), 2: 6372
(120, 128), 3: 175
(120, 136), 1: 10361
(120, 136), 2: 6685
(120, 136), 3: 183
(128, 120), 1: 10389
(128, 120), 2: 7273
(128, 120), 3: 224
(128, 128), 1: 60169
(128, 128), 2: 4




In [7]:
foo_res = defaultdict(int)

for bucket in sorted(train_buckets, key=lambda x: (x.resolution, x.n_chunks)):
	width = bucket.resolution[1] * 8
	height = bucket.resolution[0] * 8
	foo_res[(width, height)] += len(bucket.images)

for resolution, count in sorted(foo_res.items(), key=lambda x: x[1], reverse=True):
	print(f"{resolution[0]}x{resolution[1]}: {count}")

832x1216: 2229287
1216x832: 2179902
832x1152: 762149
1152x896: 430643
896x1152: 198820
1344x768: 185089
768x1344: 145989
1024x1024: 102374
1152x832: 70110
1280x768: 58728
768x1280: 42345
896x1088: 40613
1344x704: 31708
704x1344: 31163
704x1472: 27365
960x1088: 26303
1088x896: 24592
1472x704: 17991
960x1024: 17886
1088x960: 17229
1536x640: 16485
1024x960: 15745
704x1408: 14188
1408x704: 12204
1600x640: 4835
1728x576: 4718
1664x576: 2999
640x1536: 1827
640x1600: 635
576x1664: 456
576x1728: 335


In [9]:
IMPORTANT_TAGS = set(['watermark'])


def build_prompt_from_tags(tag_string: str, n_tags: int) -> str:
	# Split tag string into tags
	# Tags are shuffled, important tags are always included, and the number of tags is limited by n_tags
	tags = set(tag.strip() for tag in tag_string.split(",") if tag.strip())
	important_tags = tags.intersection(IMPORTANT_TAGS)
	n_tags = min(max(n_tags, len(important_tags)), len(tags))
	tags = list(important_tags) + random.sample(list(tags - important_tags), n_tags - len(important_tags))
	assert len(tags) <= n_tags, f"Expected {n_tags} tags, got {len(tags)}"
	random.shuffle(tags)

	# Prompt construction
	tag_type = random.randint(0, 2)   # Use underscores, spaces, or mixed

	prompt = ""
	for tag in tags:
		# Regularize across tags with spaces or underscores, or mixed.
		if tag_type == 1:
			tag = tag.replace("_", " ")
		elif tag_type == 2:
			if random.random() < 0.8:
				tag = tag.replace("_", " ")
		
		if len(prompt) > 0:
			prompt += ","
			# Space between most times
			# NOTE: I don't think this matters because CLIP tokenizer ignores spaces?
			if random.random() < 0.8:
				prompt += ' '
			prompt += tag
		else:
			prompt += tag
			
	return prompt


tag_strings = list(source_ds['train']['tag_string'])
random.shuffle(tag_strings)
tag_strings = tag_strings[:100000]

for n_tags in [8, 16, 32, 64]:
	counts = []
	for row in tqdm(tag_strings):
		prompt = build_prompt_from_tags(row, n_tags)
		foo = tokenizer.encode(prompt, add_special_tokens=False, padding=False)
		counts.append(len(foo))

	print(n_tags, sum(counts) / len(counts))

100%|██████████| 100000/100000 [00:13<00:00, 7273.89it/s]


8 21.3453


 22%|██▏       | 22372/100000 [00:04<00:15, 4894.64it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (82 > 77). Running this sequence through the model will result in indexing errors
100%|██████████| 100000/100000 [00:20<00:00, 4966.27it/s]


16 39.53047


100%|██████████| 100000/100000 [00:27<00:00, 3650.06it/s]


32 61.55775


100%|██████████| 100000/100000 [00:35<00:00, 2840.68it/s]

64 83.5767





In [3]:
indexes = defaultdict(int)
for batch in iter(train_sampler):
	for item in batch:
		indexes[item[1]] += 1

for key, value in indexes.items():
	if value > 1:
		print(f"Index {key} was sampled {value} times")

missing_indexes = set(range(len(train_ds))) - set(indexes.keys())
print(f"Number of missing indexes in training batches: {len(missing_indexes)}")



indexes = defaultdict(int)
for batch in iter(validation_sampler):
	for item in batch:
		indexes[item[1]] += 1

for key, value in indexes.items():
	if value > 1:
		print(f"Index {key} was sampled {value} times")

missing_indexes = set(range(len(test_ds))) - set(indexes.keys())
print(f"Number of missing indexes in validation batches (should be 0): {len(missing_indexes)}")

Number of missing indexes in training batches: 682
Number of missing indexes in validation batches (should be 0): 0


In [4]:
# Ensure that during training we see batches where the prompt is below 77 tokens
# Otherwise the model might always expect long prompts.
lengths = defaultdict(int)
min_original = 9999999999
min_target = 99999999999
for batch in tqdm(iter(train_dataloader)):
	length = batch['prompt'].shape[1]
	lengths[length] += 1

	min_original = min(min_original, batch['original_size'].min().item())
	min_target = min(min_target, batch['target_size'].min().item())

	if sum(lengths.values()) > 10000:
		break

print("Prompt lengths:")
for length, count in lengths.items():
	print(f"{length}: {count}")
print()
print(f"Minimum original size: {min_original}")
print(f"Minimum target size: {min_target}")

  3%|▎         | 10000/359050 [03:21<1:57:23, 49.55it/s]


Prompt lengths:
2: 7127
1: 2291
3: 583

Minimum original size: 527
Minimum target size: 576


In [None]:
train_sampler.set_epoch(random.randint(0, 1000000000))
x = iter(train_dataloader)
batch = next(x)
batch = next(x)
batch = next(x)
batch = next(x)
batch = next(x)

with torch.no_grad():
	for i in range(len(batch['latent'])):
		latent = batch['latent'][i]
		original_size = batch['original_size'][i]
		target_size = batch['target_size'][i]
		crop = batch['crop'][i]
		prompt = batch['prompt'][i]
		prompt_2 = batch['prompt_2'][i]

		# Decode
		latent = latent.float() / vae.config.scaling_factor
		image = vae.decode(latent.unsqueeze(0), return_dict=False)[0][0]
		image_pil = TVF.to_pil_image((image * 0.5 + 0.5).clamp(0, 1)).convert("RGB")

		display(image_pil)
		print(f"Original size (hxw): {original_size}")
		print(f"Target size (hxw): {target_size}")
		print(f"Crop (txl): {crop}")

		for line, line_2 in zip(prompt, prompt_2):
			print(tokenizer.decode(line))
			print(tokenizer_2.decode(line_2))

		print()