In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from collections import defaultdict
import itertools
import os
import json
import numpy as np
import os.path as osp
%cd ..

All experiments were run on a single A100 GPU.

In [None]:


def _read_first_match(log_path, marker, parse):
    """Return ``parse(line)`` for the first line of ``log_path`` containing ``marker``.

    Raises:
        ValueError: if no line contains ``marker``. Failing loudly prevents a
            missing metric from silently reusing the value of a previous run.
    """
    with open(log_path) as log_file:
        for line in log_file:
            if marker in line:
                return parse(line)
    raise ValueError(f"{marker!r} not found in {log_path}")


def read_from_dir(dir_path, data_name, model_name):
    """Collect throughput / CLIP / FID results for every run under ``dir_path``.

    Each subdirectory of ``dir_path`` is treated as one run and must contain
    ``benchmark_image_generation/`` with three log files:
    ``image_clip_score_evaluation.log`` (line with ``Average CLIP Score: <float>``),
    ``image_fid_score_evaluation.log`` (line with ``FID Score: <float>``) and
    ``text_to_image_generation.log`` (line with
    ``Total time: ... throughput: <float> iterations/second``).

    Args:
        dir_path: directory whose subdirectories are individual runs.
        data_name: dataset tag copied into every returned record.
        model_name: model tag copied into every returned record.

    Returns:
        List of dicts with keys ``data_name``, ``model_name``, ``throughput``,
        ``clip_score`` and ``fid_score``; one per run subdirectory, in sorted
        subdirectory-name order.

    Raises:
        ValueError: if a log file lacks its expected metric line.
        OSError: if a run subdirectory is missing one of the log files.
    """
    exps = []
    # Sorted so the record order is deterministic (os.listdir order is arbitrary).
    for subdir in sorted(os.listdir(dir_path)):
        run_path = osp.join(dir_path, subdir)
        # Skip stray files (e.g. .DS_Store); only directories are runs.
        if not osp.isdir(run_path):
            continue
        subdir_path = osp.join(run_path, 'benchmark_image_generation')

        # e.g. "2024-12-05 19:32:02,322 - __main__ - INFO - Average CLIP Score: 20.951017379760742"
        clip_score = _read_first_match(
            osp.join(subdir_path, 'image_clip_score_evaluation.log'),
            "Average CLIP Score",
            lambda line: float(line.split(":")[-1].strip()),
        )

        # e.g. "2024-12-05 19:41:59,625 - __main__ - INFO - FID Score: 344.50787353515625"
        fid_score = _read_first_match(
            osp.join(subdir_path, 'image_fid_score_evaluation.log'),
            "FID Score",
            lambda line: float(line.split(":")[-1].strip()),
        )

        # e.g. "... Total time: 68.27 seconds for 97 iterations, throughput: 1.42 iterations/second"
        # The second-to-last whitespace token is the throughput value.
        throughput = _read_first_match(
            osp.join(subdir_path, 'text_to_image_generation.log'),
            "Total time:",
            lambda line: float(line.split()[-2]),
        )

        exps.append({
            "data_name": data_name,
            "model_name": model_name,
            "throughput": throughput,
            "clip_score": clip_score,
            "fid_score": fid_score,
        })
    return exps

In [None]:
# (result directory, evaluation dataset, model) for every run group to load.
# Each resulting dict is unpacked into read_from_dir(**kwargs).
read_kwargs = [
    dict(dir_path=dir_path, data_name=data_name, model_name=model_name)
    for dir_path, data_name, model_name in (
        ("RESULTS/coco_cc12m_no_cfg", "coco", "cc12m"),
        ("RESULTS/coco_finetuned_20_no_cfg", "coco", "finetuned_20"),
        ("RESULTS/coco_finetuned_40_no_cfg", "coco", "finetuned_40"),
        ("RESULTS/coco_midjv6_no_cfg", "coco", "midjv6"),
        ("RESULTS/coco_origin_baseline_no_cfg", "coco", "baseline-no-cfg"),
        ("RESULTS/coco_origin_baseline_cfg", "coco", "baseline-cfg"),
        ("RESULTS/midjv6_cc12m_no_cfg", "midjv6", "cc12m"),
        ("RESULTS/midjv6_finetuned_20_no_cfg", "midjv6", "finetuned_20"),
        ("RESULTS/midjv6_finetuned_40_no_cfg", "midjv6", "finetuned_40"),
        ("RESULTS/midjv6_midjv6_no_cfg", "midjv6", "midjv6"),
        ("RESULTS/midjv6_origin_baseline_no_cfg", "midjv6", "baseline-no-cfg"),
        ("RESULTS/midjv6_origin_baseline_cfg", "midjv6", "baseline-cfg"),
    )
]

In [None]:
# Sanity check: number of result directories configured above (12 are listed).
len(read_kwargs)

In [None]:
# Load each configured result directory (one list of run records per directory),
# then flatten everything into a single list of run records.
readed = list(map(lambda kw: read_from_dir(**kw), read_kwargs))
datas = [record for records in readed for record in records]

In [None]:
# Display every loaded run record (one dict per run subdirectory).
datas

In [None]:
def select_datas(datas, select_fn):
    """Return, in original order, the items of ``datas`` for which ``select_fn`` is truthy."""
    kept = []
    for item in datas:
        if select_fn(item):
            kept.append(item)
    return kept

def sort_plot_line(ax, xs, ys, **kwargs):
    """Plot ``ys`` against ``xs`` on ``ax`` after ordering the points by x.

    The sort is stable, so points with equal x keep their input order. Extra
    keyword arguments are forwarded unchanged to ``ax.plot``.
    """
    order = sorted(range(len(xs)), key=xs.__getitem__)
    ax.plot([xs[i] for i in order], [ys[i] for i in order], **kwargs)

def plot(datas, group_by_fn, select_fn=lambda x: True, group_to_label_fn=lambda x: str(x)):
    """Plot throughput vs. CLIP score (left panel) and vs. FID score (right panel).

    Records passing ``select_fn`` are grouped by ``group_by_fn``; each group is
    drawn as one line per panel, with points ordered by throughput.

    Args:
        datas: list of run records with ``throughput``, ``clip_score`` and
            ``fid_score`` keys.
        group_by_fn: maps a record to a hashable group key; one line per group.
        select_fn: predicate; only records for which it is truthy are plotted.
        group_to_label_fn: maps a group key to its legend label.
    """
    title2datas = defaultdict(list)
    for data in select_datas(datas, select_fn):
        title2datas[group_by_fn(data)].append(data)

    # (record key for y, y-axis label, panel title) for the left and right panels.
    panels = [
        ("clip_score", "Clip Score", "Throughput -- Clip Score"),
        ("fid_score", "FID Score", "Throughput -- FID Score"),
    ]
    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (metric, ylabel, title) in zip(axes, panels):
        # `group_datas` rather than `datas` so the function argument is not shadowed.
        for group, group_datas in title2datas.items():
            sort_plot_line(
                ax,
                [data["throughput"] for data in group_datas],
                [data[metric] for data in group_datas],
                label=group_to_label_fn(group),
            )
        ax.set_xlabel("Throughput (images/sec)")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.set_title(title)
    plt.show()


# All Experiments in One Plot
- Training helps improve generation quality at high throughput.
- There is a trade-off between the quality metrics (CLIP, FID) and throughput.

In [None]:
def plot_quadra_grouped():
    """Placeholder, not implemented yet.

    Presumably meant for the "all experiments in one plot" view named in the
    section heading — TODO confirm. Currently a no-op that returns None.
    """

In [None]:
# One pair of panels per evaluation dataset; each line is one model.
midjvt6_datas = select_datas(datas, lambda data: data["data_name"] == "midjv6")
coco_datas = select_datas(datas, lambda data: data["data_name"] == "coco")
for header, subset in (("select midjv6", midjvt6_datas), ("select coco", coco_datas)):
    print(header)
    plot(
        subset,
        group_by_fn=lambda data: (data["data_name"], data["model_name"]),
        group_to_label_fn=lambda key: key[1],
    )

## Finer-Grained Plots: Baselines vs. Throughput-Binned Experiments

In [None]:
# Total number of loaded run records across all result directories.
len(datas)


In [None]:
def plot_box_plot_line(datas, bin_size=0.2, box_width=0.1):
    """Plot both baselines as lines and all other runs as throughput-binned box plots.

    Two panels: throughput vs. CLIP score (left) and throughput vs. FID score
    (right). Records whose ``model_name`` is ``"baseline-cfg"`` or
    ``"baseline-no-cfg"`` are each drawn as a line. Every other record is put
    into a throughput bin of width ``bin_size``. Each bin is drawn as a red box
    plot (fliers hidden) at the bin's mean throughput rounded to 2 decimals. A
    red "experiments" line connects the per-bin mean scores.

    Args:
        datas: list of run records with ``model_name``, ``throughput``,
            ``clip_score`` and ``fid_score`` keys.
        bin_size: width of a throughput bin, in throughput units.
        box_width: width of each box, in x-axis units.
    """
    baseline_names = ["baseline-no-cfg", "baseline-cfg"]
    baseline_datas = select_datas(datas, lambda data: data["model_name"] == "baseline-cfg")
    baseline_no_cfg_datas = select_datas(datas, lambda data: data["model_name"] == "baseline-no-cfg")
    experiment_datas = select_datas(datas, lambda data: data["model_name"] not in baseline_names)

    # Bin the non-baseline runs once; both panels share the same bins.
    group_by_throughput_bin = defaultdict(list)
    for data in experiment_datas:
        group_by_throughput_bin[data["throughput"] // bin_size].append(data)
    # Visit bins in ascending throughput order so boxes are added left to right.
    binned = [group_by_throughput_bin[key] for key in sorted(group_by_throughput_bin)]
    # Box x-position = bin's mean throughput rounded to 2 decimals. With boxplot's
    # default manage_ticks=True these positions also become the x tick locations.
    positions = [
        np.round(np.mean([data["throughput"] for data in grouped]), 2)
        for grouped in binned
    ]

    # (record key for y, y-axis label, panel title) for the left and right panels.
    panels = [
        ("clip_score", "Clip Score", "Throughput -- Clip Score"),
        ("fid_score", "FID Score", "Throughput -- FID Score"),
    ]
    baselines = [("baseline-cfg", baseline_datas), ("baseline-no-cfg", baseline_no_cfg_datas)]

    _, axes = plt.subplots(1, 2, figsize=(14, 5))
    for ax, (metric, ylabel, title) in zip(axes, panels):
        for label, runs in baselines:
            sort_plot_line(
                ax,
                [data["throughput"] for data in runs],
                [data[metric] for data in runs],
                label=label,
            )

        line_scores = []
        for grouped, position in zip(binned, positions):
            scores = [data[metric] for data in grouped]
            ax.boxplot(
                scores,
                positions=[position],
                widths=box_width,
                showfliers=False,
                capprops=dict(color="red"),
                medianprops=dict(color="red"),
            )
            line_scores.append(np.mean(scores))
        sort_plot_line(ax, positions, line_scores, label="experiments", color="red")

        ax.set_xlabel("Throughput (images/sec)")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.set_title(title)

    plt.show()
    

# Same baseline-vs-experiments comparison, once per evaluation dataset.
for header, data_name in (("select coco", "coco"), ("select midjv6", "midjv6")):
    print(header)
    plot_box_plot_line(
        select_datas(datas, lambda data, name=data_name: data["data_name"] == name)
    )


# Comparison: Training on a Narrow vs. a General Data Distribution


In [None]:
# Compare only the CC12M- and MidJv6-trained models.
# To include the baselines too, add 'baseline-no-cfg' and 'baseline-cfg'.
model_select_names = ['cc12m', 'midjv6', ]
for header, data_name in (("select midjv6", "midjv6"), ("select coco", "coco")):
    print(header)
    plot(
        datas,
        group_by_fn=lambda data: (data["data_name"], data["model_name"]),
        select_fn=lambda data, name=data_name: data["model_name"] in model_select_names and data["data_name"] == name,
        group_to_label_fn=lambda key: key[1],
    )

In [None]:
def plot_in_one(datas, group_by_fn, select_fn=lambda x: True, group_to_label_fn=lambda x: str(x)):
    """Overlay CLIP and FID score vs. throughput on one axes with twin y-axes.

    Records passing ``select_fn`` are grouped by ``group_by_fn``. Each group
    gets one colour: its CLIP curve (left y-axis) is solid and its FID curve
    (right y-axis) is dashed. Both curves get a legend entry suffixed with the
    metric name, so the combined legend distinguishes them.

    Args:
        datas: list of run records with ``throughput``, ``clip_score`` and
            ``fid_score`` keys.
        group_by_fn: maps a record to a hashable group key; one curve pair per group.
        select_fn: predicate; only records for which it is truthy are plotted.
        group_to_label_fn: maps a group key to its legend label.
    """
    title2datas = defaultdict(list)
    for data in select_datas(datas, select_fn):
        title2datas[group_by_fn(data)].append(data)

    fig, ax1 = plt.subplots(figsize=(8, 5))
    ax2 = ax1.twinx()

    for index, (group, group_datas) in enumerate(title2datas.items()):
        # The two axes keep independent colour cycles; pin one cycle colour per
        # group so that group's CLIP and FID curves always match.
        color = f"C{index}"
        label = group_to_label_fn(group)
        throughputs = [data["throughput"] for data in group_datas]
        sort_plot_line(
            ax1, throughputs, [data["clip_score"] for data in group_datas],
            color=color, label=f"{label} (CLIP)",
        )
        sort_plot_line(
            ax2, throughputs, [data["fid_score"] for data in group_datas],
            color=color, linestyle="--", label=f"{label} (FID)",
        )

    ax1.set_xlabel("Throughput (images/sec)")
    # Colour encodes the group, so the metric is identified by line style instead.
    ax1.set_ylabel("CLIP Score (solid)")
    ax2.set_ylabel("FID Score (dashed)")
    ax1.set_title("Throughput vs CLIP & FID Scores")

    # twinx axes collect legend handles separately; merge them into one legend.
    lines1, labels1 = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left")

    fig.tight_layout()
    plt.show()


map_model_name_to_label = {
    "cc12m": "Trained on CC12M",
    "midjv6": "Trained on MidJv6",
}
# Only the two models named above are plotted.
# To include the baselines too, add 'baseline-no-cfg' and 'baseline-cfg'.
model_select_names = ['cc12m', 'midjv6', ]
for header, data_name in (("select midjv6", "midjv6"), ("select coco", "coco")):
    print(header)
    plot_in_one(
        datas,
        group_by_fn=lambda data: (data["data_name"], data["model_name"]),
        select_fn=lambda data, name=data_name: data["model_name"] in model_select_names and data["data_name"] == name,
        group_to_label_fn=lambda key: map_model_name_to_label[key[1]],
    )

# Pretraining on CC12M, Then Fine-tuning on MidJv6

In [None]:
# Display every loaded run record again for reference (one dict per run subdirectory).
datas

In [None]:

map_model_name_to_label = {
    "finetuned_20": "Pretrained on CC12M, Finetuned on MidJv6",
    "midjv6": "Trained on MidJv6",
}
# Compare the fine-tuned model against the model trained only on MidJv6.
model_select_names = ['finetuned_20', 'midjv6', ]
for header, data_name in (("select midjv6", "midjv6"), ("select coco", "coco")):
    print(header)
    plot_in_one(
        datas,
        group_by_fn=lambda data: (data["data_name"], data["model_name"]),
        select_fn=lambda data, name=data_name: data["model_name"] in model_select_names and data["data_name"] == name,
        group_to_label_fn=lambda key: map_model_name_to_label[key[1]],
    )

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def f(sigma, shift):
    """Return ``(shift * sigma / (1 + (shift - 1) * sigma)) ** 1.5``.

    Works element-wise on numpy arrays. For ``sigma`` in [0, 1] and
    ``shift >= 1`` the inner ratio maps 0 -> 0 and 1 -> 1; ``shift == 1``
    reduces it to ``sigma ** 1.5``.
    """
    numerator = shift * sigma
    denominator = 1 + (shift - 1) * sigma
    return (numerator / denominator) ** 1.5

# Evaluate f on a dense grid of sigma values in [0, 1] for several shifts.
sigmas = np.linspace(0, 1, 200)
shifts = [1, 2, 3, 4, 5]

# One curve per shift, all on the same axes.
fig, ax = plt.subplots(figsize=(8, 6))
for s in shifts:
    y = f(sigmas, s)
    ax.plot(sigmas, y, label=f'shift={s}')

ax.set_xlabel('sigma')
ax.set_ylabel('f(sigma)')
# f raises the ratio to the 1.5 power, so the title states the exponent as well.
ax.set_title('f(sigma) = (shift * sigma / [1 + (shift - 1)*sigma])^1.5 for multiple shifts')
ax.grid(True)
ax.legend()
plt.show()
