<a href="https://colab.research.google.com/github/mengxiaozhibo/AlgoNotes/blob/main/notebooks/text2im.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Run this line in Colab to install the package if it is
# not already installed.
!pip install git+https://github.com/openai/glide-text2im

In [None]:
import tensorflow as tf
import numpy as np

def apply_category_scatter(seq, category_seq):
    """
    Category scatter: reorder each row so items of the same category are
    spread out round-robin.

    Within each row, the valid items (``seq >= 0``) are grouped by category.
    Groups are ordered by the first appearance of their category, and items
    inside a group keep their original relative order. Items are then emitted
    round-robin: the first item of every group, then the second item of every
    group (skipping exhausted groups), and so on. The reordered items are
    written back into the positions that held valid items; invalid positions
    (``seq < 0``) keep their original values, and rows without any valid item
    are returned unchanged.

    Example (one row):
        seq          = [101, 102, 201, 202, 301, 302, 103, 203]
        category_seq = [  1,   1,   2,   2,   3,   3,   1,   2]
        result       = [101, 201, 301, 102, 202, 302, 103, 203]

    Only TensorFlow ops are used, so this also works inside ``tf.function``
    (graph mode).

    Args:
        seq: item sequence, shape [batch_size, seq_len]; negative entries are
            treated as padding.
        category_seq: category of each item, shape [batch_size, seq_len].

    Returns:
        scattered_seq: reordered item sequence, shape [batch_size, seq_len],
            same dtype as ``seq``.
    """
    seq = tf.convert_to_tensor(seq)
    category_seq = tf.convert_to_tensor(category_seq)

    def scatter_single_row(args):
        row_seq, row_category = args
        # Positions of valid items, shape [n_valid, 1].
        valid_indices = tf.where(tf.greater_equal(row_seq, 0))

        def no_valid_items():
            # Nothing to reorder: return the row unchanged.
            return row_seq

        def has_valid_items():
            valid_items = tf.gather_nd(row_seq, valid_indices)
            valid_categories = tf.gather_nd(row_category, valid_indices)

            # group_ids[i]: index of item i's category in first-appearance order.
            unique_categories, group_ids = tf.unique(valid_categories)
            num_groups = tf.size(unique_categories)

            # rank_in_group[i]: number of earlier valid items with the same
            # category, i.e. the round-robin round in which item i is emitted.
            one_hot = tf.one_hot(group_ids, depth=num_groups, dtype=tf.int32)
            earlier_counts = tf.cumsum(one_hot, axis=0, exclusive=True)
            rank_in_group = tf.reduce_sum(earlier_counts * one_hot, axis=1)

            # Round-robin order == sort by (round, group). The key is unique per
            # item, so no tie-breaking is needed. int64 avoids overflow on long rows.
            sort_key = (tf.cast(rank_in_group, tf.int64) * tf.cast(num_groups, tf.int64)
                        + tf.cast(group_ids, tf.int64))
            order = tf.argsort(sort_key)
            scattered_items = tf.gather(valid_items, order)

            # Write back into the valid positions only; padding keeps its value.
            return tf.tensor_scatter_nd_update(row_seq, valid_indices, scattered_items)

        return tf.cond(tf.size(valid_indices) > 0, has_valid_items, no_valid_items)

    scattered_seq = tf.map_fn(
        scatter_single_row, (seq, category_seq), fn_output_signature=seq.dtype
    )
    return scattered_seq


# ===== 调试示例代码 =====

def test_apply_category_scatter():
    """Print-based demo of apply_category_scatter on four hand-written cases (no assertions)."""
    print("=" * 60)
    print("测试类目打散函数")
    print("=" * 60)

    # Test case 1: simple category scatter.
    # Under round-robin scattering, batch 0 would become 1,3,5,2,4,6 and
    # batch 1 (already alternating) would stay 7..12.
    print("\n【测试用例1：简单的类目打散】")
    seq1 = tf.constant([
        [1, 2, 3, 4, 5, 6],  # batch 0
        [7, 8, 9, 10, 11, 12]  # batch 1
    ], dtype=tf.int32)

    category_seq1 = tf.constant([
        [1, 1, 2, 2, 3, 3],  # batch 0: categories 1,1,2,2,3,3
        [1, 2, 1, 2, 1, 2]   # batch 1: categories 1,2,1,2,1,2
    ], dtype=tf.int32)

    print("原始序列:")
    print(seq1.numpy())
    print("类目序列:")
    print(category_seq1.numpy())

    result1 = apply_category_scatter(seq1, category_seq1)
    print("打散后序列:")
    print(result1.numpy())
    print("打散后类目:")
    # The scattered categories are not recomputed here; only a placeholder
    # message is printed.
    print("(需要根据打散后的索引重新映射类目)")

    # Test case 2: rows containing invalid entries (here -1).
    # Note: only negative values are treated as invalid (the mask is `>= 0`),
    # so an item id of 0 would count as a valid item.
    print("\n【测试用例2：包含无效元素】")
    seq2 = tf.constant([
        [1, 2, 3, -1, 5, -1],  # batch 0: contains -1
        [7, -1, 9, 10, -1, 12]  # batch 1: contains -1
    ], dtype=tf.int32)

    category_seq2 = tf.constant([
        [1, 1, 2, -1, 3, -1],
        [1, -1, 2, 2, -1, 3]
    ], dtype=tf.int32)

    print("原始序列:")
    print(seq2.numpy())
    print("类目序列:")
    print(category_seq2.numpy())

    result2 = apply_category_scatter(seq2, category_seq2)
    print("打散后序列:")
    print(result2.numpy())

    # Test case 3: one category per row, so round-robin leaves the order as is.
    print("\n【测试用例3：单个类目】")
    seq3 = tf.constant([
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10]
    ], dtype=tf.int32)

    category_seq3 = tf.constant([
        [1, 1, 1, 1, 1],  # all category 1
        [2, 2, 2, 2, 2]   # all category 2
    ], dtype=tf.int32)

    print("原始序列:")
    print(seq3.numpy())
    print("类目序列:")
    print(category_seq3.numpy())

    result3 = apply_category_scatter(seq3, category_seq3)
    print("打散后序列:")
    print(result3.numpy())

    # Test case 4: a more realistic scenario.
    print("\n【测试用例4：模拟实际场景】")
    # Three categories per row, each with several items.
    seq4 = tf.constant([
        [101, 102, 201, 202, 301, 302, 103, 203],  # batch 0
        [401, 402, 501, 502, 601, 602, 403, 503]   # batch 1
    ], dtype=tf.int32)

    category_seq4 = tf.constant([
        [1, 1, 2, 2, 3, 3, 1, 2],  # category layout
        [4, 4, 5, 5, 6, 6, 4, 5]
    ], dtype=tf.int32)

    print("原始序列:")
    print(seq4.numpy())
    print("类目序列:")
    print(category_seq4.numpy())

    result4 = apply_category_scatter(seq4, category_seq4)
    print("打散后序列:")
    print(result4.numpy())
    # The printed expectation covers batch 0 only; by the same round-robin rule,
    # batch 1 would be 401,501,601,402,502,602,403,503.
    print("预期结果：应该按类目轮询排列，如 101,201,301,102,202,302,103,203")

    print("\n" + "=" * 60)
    print("测试完成")
    print("=" * 60)


# 简化版本的类目打散函数（更易于调试）
def apply_category_scatter_simple(seq, category_seq):
    """
    Reference implementation of the category scatter in plain Python/NumPy,
    intended for debugging and for understanding the algorithm.

    For each row, the valid items (``>= 0``) are grouped by category. Groups
    are ordered by the first appearance of their category, and items keep
    their original relative order inside a group. Items are then taken
    round-robin (first item of each group, then second item of each group,
    ..., skipping exhausted groups) and written back into the valid
    positions. Invalid positions (``< 0``) and rows without any valid item
    are left unchanged.

    Note: this runs eagerly in Python and cannot be used inside a TensorFlow
    graph.

    Args:
        seq: item sequence [batch_size, seq_len] (eager tf.Tensor or an array
            exposing ``dtype``).
        category_seq: category sequence [batch_size, seq_len].

    Returns:
        A tf.Tensor with the same dtype as ``seq``, each row reordered.
    """
    seq_np = np.asarray(seq)
    category_np = np.asarray(category_seq)
    # Start from a copy so padding and all-invalid rows keep their values.
    scattered_seq = seq_np.copy()

    for b in range(seq_np.shape[0]):
        batch_seq = seq_np[b]
        batch_category = category_np[b]

        valid_indices = np.flatnonzero(batch_seq >= 0)
        if valid_indices.size == 0:
            continue

        # Group by category; dicts keep insertion order, so groups follow the
        # first appearance of each category.
        category_groups = {}
        for item, cat in zip(batch_seq[valid_indices], batch_category[valid_indices]):
            category_groups.setdefault(cat, []).append(item)

        # Round-robin: round r takes the r-th item of every group that has one.
        groups = list(category_groups.values())
        longest = max(len(group) for group in groups)
        scattered_items = [
            group[r] for r in range(longest) for group in groups if r < len(group)
        ]

        # Write back into the original valid positions.
        scattered_seq[b, valid_indices] = scattered_items

    return tf.constant(scattered_seq, dtype=seq.dtype)


def test_simple_version():
    """Demo: run apply_category_scatter_simple on a fixed two-row example and print inputs and output."""
    rule = "=" * 60
    print("\n" + rule)
    print("测试简化版类目打散函数")
    print(rule)

    items = tf.constant(
        [[101, 102, 201, 202, 301, 302, 103, 203],
         [401, 402, 501, 502, 601, 602, 403, 503]],
        dtype=tf.int32,
    )
    categories = tf.constant(
        [[1, 1, 2, 2, 3, 3, 1, 2],
         [4, 4, 5, 5, 6, 6, 4, 5]],
        dtype=tf.int32,
    )

    # Echo both inputs under their labels.
    for label, tensor in (("原始序列:", items), ("类目序列:", categories)):
        print(label)
        print(tensor.numpy())

    scattered = apply_category_scatter_simple(items, categories)
    print("打散后序列:")
    print(scattered.numpy())


if __name__ == "__main__":
    # Make TensorFlow's logger emit INFO-level messages.
    tf.get_logger().setLevel('INFO')

    # Only the pure-Python reference implementation is exercised here;
    # test_apply_category_scatter() is not called from this entry point.
    test_simple_version()

    # Two closing hints, one per line (same output as two separate prints).
    print(
        "\n提示：完整的TensorFlow版本需要使用tf.while_loop等操作",
        "建议先使用简化版本理解算法逻辑，再转换为TensorFlow操作",
        sep="\n",
    )


In [None]:
from PIL import Image
from IPython.display import display
import torch as th

from glide_text2im.download import load_checkpoint
from glide_text2im.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    model_and_diffusion_defaults_upsampler
)

In [None]:
# Runs on either CPU or GPU. On CPU a single sample may take on the order of
# 20 minutes; on a GPU it should take under a minute.

has_cuda = th.cuda.is_available()
device = th.device('cuda' if has_cuda else 'cpu')

In [None]:
# Create base model.
options = model_and_diffusion_defaults()
options.update(
    use_fp16=has_cuda,         # half precision only when CUDA is available
    timestep_respacing='100',  # use 100 diffusion steps for fast sampling
)
model, diffusion = create_model_and_diffusion(**options)
model.eval()
if has_cuda:
    model.convert_to_fp16()
model.to(device)
model.load_state_dict(load_checkpoint('base', device))
print('total base parameters', sum(p.numel() for p in model.parameters()))

In [None]:
# Create upsampler model.
options_up = model_and_diffusion_defaults_upsampler()
options_up.update(
    use_fp16=has_cuda,            # half precision only when CUDA is available
    timestep_respacing='fast27',  # use 27 diffusion steps for very fast sampling
)
model_up, diffusion_up = create_model_and_diffusion(**options_up)
model_up.eval()
if has_cuda:
    model_up.convert_to_fp16()
model_up.to(device)
model_up.load_state_dict(load_checkpoint('upsample', device))
print('total upsampler parameters', sum(p.numel() for p in model_up.parameters()))

In [None]:
def show_images(batch: th.Tensor):
    """Display a batch of 3-channel images (values in [-1, 1]) side by side inline.

    The batch is expected as (N, 3, H, W); images are tiled horizontally
    into a single H x (N*W) RGB picture.
    """
    # Map [-1, 1] to 8-bit pixel values.
    pixels = ((batch + 1) * 127.5).round().clamp(0, 255).to(th.uint8).cpu()
    height = batch.shape[2]
    # (N, C, H, W) -> (H, N, W, C) -> (H, N*W, 3): concatenate along width.
    strip = pixels.permute(2, 0, 3, 1).reshape([height, -1, 3])
    display(Image.fromarray(strip.numpy()))

In [None]:
# Sampling parameters: text prompt, number of images, classifier-free guidance scale.
prompt, batch_size, guidance_scale = "an oil painting of a corgi", 1, 3.0

# Sharpness knob for the 256x256 upsampled images: the upsampler's initial
# noise is multiplied by this factor. 1.0 is sharper but sometimes gives
# grainy artifacts.
upsample_temp = 0.997

In [None]:
##############################
# Sample from the base model #
##############################

# Tokenize the prompt and pad it to the base model's text context length.
tokens, mask = model.tokenizer.padded_tokens_and_mask(
    model.tokenizer.encode(prompt), options['text_ctx']
)

# Classifier-free guidance needs an unconditional (empty prompt) copy of every
# sample, so the model batch is twice the requested batch size.
full_batch_size = batch_size * 2
uncond_tokens, uncond_mask = model.tokenizer.padded_tokens_and_mask(
    [], options['text_ctx']
)

# Conditional rows first, unconditional rows second; model_fn splits the
# model output in the same order.
model_kwargs = {
    'tokens': th.tensor(
        [tokens] * batch_size + [uncond_tokens] * batch_size, device=device
    ),
    'mask': th.tensor(
        [mask] * batch_size + [uncond_mask] * batch_size,
        dtype=th.bool,
        device=device,
    ),
}

# Classifier-free guidance wrapper passed to diffusion.p_sample_loop.
def model_fn(x_t, ts, **kwargs):
    """Run the base ``model`` with classifier-free guidance.

    Only the first half of ``x_t`` is used: it is duplicated so the first half
    of the model batch sees the conditional kwargs and the second half the
    unconditional ones. The first 3 output channels (eps) are combined as
    ``uncond + guidance_scale * (cond - uncond)``, and the guided eps is
    returned for both halves. The remaining channels are passed through
    unchanged.
    """
    n_half = len(x_t) // 2
    x_half = x_t[:n_half]
    out = model(th.cat((x_half, x_half), dim=0), ts, **kwargs)
    eps_all, extra = out[:, :3], out[:, 3:]
    eps_cond, eps_uncond = th.split(eps_all, len(eps_all) // 2, dim=0)
    guided = eps_uncond + guidance_scale * (eps_cond - eps_uncond)
    return th.cat((th.cat((guided, guided), dim=0), extra), dim=1)

# Sample from the base model with guidance, then keep the first batch_size rows.
model.del_cache()
samples = diffusion.p_sample_loop(
    model_fn,
    (full_batch_size, 3, options["image_size"], options["image_size"]),
    model_kwargs=model_kwargs,
    device=device,
    clip_denoised=True,
    progress=True,
    cond_fn=None,
)
samples = samples[:batch_size]
model.del_cache()

# Show the output
show_images(samples)

In [None]:
##############################
# Upsample the 64x64 samples #
##############################

# Tokenize the prompt again, this time for the upsampler's text context length.
tokens, mask = model_up.tokenizer.padded_tokens_and_mask(
    model_up.tokenizer.encode(prompt), options_up['text_ctx']
)

# Upsampler conditioning: the base samples rounded to the 256 levels of an
# 8-bit image (mapped back to [-1, 1]), plus one row of text tokens/mask per sample.
model_kwargs = {
    'low_res': ((samples + 1) * 127.5).round() / 127.5 - 1,
    'tokens': th.tensor([tokens] * batch_size, device=device),
    'mask': th.tensor([mask] * batch_size, dtype=th.bool, device=device),
}

# Sample from the upsampler model, conditioned on the low-res base samples.
model_up.del_cache()
up_shape = (batch_size, 3, options_up["image_size"], options_up["image_size"])
up_samples = diffusion_up.ddim_sample_loop(
    model_up,
    up_shape,
    # Starting noise scaled by upsample_temp (values below 1.0 trade a little
    # sharpness for fewer grainy artifacts).
    noise=th.randn(up_shape, device=device) * upsample_temp,
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
model_up.del_cache()

# Show the output
show_images(up_samples)