In [1]:
import argparse
import torch
from utils import get_model, LoadEncoder
from models.engine import DDIMSampler, DDIMSamplerEncoder
from torchvision.utils import save_image, make_grid
from collections import OrderedDict
from types import SimpleNamespace
import os

In [2]:
class Args(argparse.Namespace):
    """Hard-coded sampling configuration standing in for parsed CLI arguments.

    Defaults are declared as class attributes (so ``Args.arch`` etc. keep
    working) and are copied onto every instance in ``__init__``. As a result,
    ``vars(args)`` / ``print(args)`` show the full configuration, and
    list-valued defaults are not shared between instances. Keyword arguments
    override individual defaults, e.g. ``Args(w=3)``.
    """

    arch = "unet"
    img_size = 64
    num_timestep = 1000          # passed to the sampler as `T`
    beta = (0.0001, 0.02)        # passed to the sampler as `beta`; presumably (beta_1, beta_T) -- TODO confirm
    num_condition = [3, 3]       # presumably [#attributes, #objects] (3 each below) -- TODO confirm in get_model
    emb_size = 128
    channel_mult = [1, 2, 2, 2]
    num_res_blocks = 2
    use_spatial_transformer = False
    num_heads = 4
    num_sample = 500             # batch size of every sampler call
    w = 5                        # passed to the sampler as `w`; presumably the guidance weight -- TODO confirm
    projection_dim = 512
    only_table = False
    concat = False
    only_encoder = False         # forwarded to DDIMSamplerEncoder
    num_head_channels = -1
    encoder_path = None          # when not None, a DDIMSamplerEncoder is used instead of DDIMSampler

    def __init__(self, **overrides):
        # Collect the public, non-callable class attributes as defaults; copy
        # lists so mutating one instance's config never leaks into another.
        defaults = {
            name: list(value) if isinstance(value, list) else value
            for name, value in vars(Args).items()
            if not name.startswith("_") and not callable(value)
        }
        defaults.update(overrides)
        super().__init__(**defaults)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args = Args()

In [3]:
# Build the model, load the epoch-`epoch` checkpoint, freeze it and wrap it in a DDIM sampler.
model = get_model(args)
epoch = 90
# map_location lets a checkpoint saved from GPU be loaded on a CPU-only host.
ckpt = torch.load(
    f"checkpoint/phison_cmlip/NoMiss/model_{epoch}.pth", map_location=device
)["model"]

# Strip the "module." prefix that DataParallel/DDP-wrapped models put on parameter
# names (presumably how this checkpoint was saved). Match the full "module." token so
# keys that merely begin with the letters "module" are not mangled.
prefix = "module."
new_dict = OrderedDict(
    (k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in ckpt.items()
)

try:
    model.load_state_dict(new_dict)
    print("All keys successfully match")
except RuntimeError as err:
    # Best-effort as before (the notebook keeps going), but show *which* keys or
    # shapes mismatched instead of silently swallowing every possible exception.
    print("some keys are missing!")
    print(err)

# Inference only: freeze all parameters.
model.requires_grad_(False)
model.eval()
model.to(device)

if args.encoder_path is not None:
    encoder = LoadEncoder(args).to(device)
    sampler = DDIMSamplerEncoder(
        model=model,
        encoder=encoder,
        beta=args.beta,
        T=args.num_timestep,
        w=args.w,
        only_encoder=args.only_encoder,
    ).to(device)
else:
    sampler = DDIMSampler(
        model=model,
        beta=args.beta,
        T=args.num_timestep,
        w=args.w,
    ).to(device)

All keys successfully match


In [4]:
# Label vocabularies; the indices presumably must match those used at training time -- TODO confirm.
# An earlier dataset used OBJ2IDX = {'F1210': 0, 'L2016': 1, 'SOT23': 2} with the same ATR2IDX.
ATR2IDX = {
    'good': 0,
    'broke': 1,
    'shift': 2,
}

OBJ2IDX = {
    'group1': 0,
    'group3': 1,
    'group7': 2,
}

# Generate N_BATCHES batches of args.num_sample images for one "<attribute> <object>"
# target and save every image as its own PNG.
target = "broke group3"
N_BATCHES = 8
NUM_STEPS = 100  # DDIM sampling steps per batch
# NOTE(review): "*" is a literal directory name (save_image does not glob); presumably a
# placeholder for the real output folder -- set it before running.
OUT_DIR = "SampledImg/*"
# Create the folder up front so a missing directory fails fast, not after minutes of sampling.
os.makedirs(OUT_DIR, exist_ok=True)

target_parts = target.split(" ")
# The conditioning labels are the same for every batch, so build them once.
atr_labels = torch.full(
    (args.num_sample,), ATR2IDX[target_parts[0]], dtype=torch.long, device=device
)
obj_labels = torch.full(
    (args.num_sample,), OBJ2IDX[target_parts[-1]], dtype=torch.long, device=device
)

for i in range(N_BATCHES):
    x_i = torch.randn(args.num_sample, 3, args.img_size, args.img_size, device=device)
    x0 = sampler(x_i, atr_labels, obj_labels, steps=NUM_STEPS)
    # Presumably rescales the sampler's [-1, 1] output to [0, 1] for save_image -- TODO confirm range.
    x0 = x0 * 0.5 + 0.5
    for idx, x in enumerate(x0):
        save_image(x, f"{OUT_DIR}/epochs_{epoch}_{i}_{idx}.png")


100%|██████████| 100/100 [04:17<00:00,  2.57s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]
100%|██████████| 100/100 [04:17<00:00,  2.58s/it, step=1, sample=1]


In [None]:
# Label vocabularies, repeated here so this cell can run without re-running the slow
# per-image sampling cell. The indices presumably must match training -- TODO confirm.
# An earlier dataset used OBJ2IDX = {'F1210': 0, 'L2016': 1, 'SOT23': 2} with the same ATR2IDX.
ATR2IDX = {
    'good': 0,
    'broke': 1,
    'shift': 2,
}

OBJ2IDX = {
    'group1': 0,
    'group3': 1,
    'group7': 2,
}

# Each target becomes one row of args.num_sample images in every saved grid.
targets = ["broke group3"] * 5
N_GRIDS = 10
NUM_STEPS = 100  # DDIM sampling steps per batch
# NOTE(review): "*" is a literal directory name (save_image does not glob); presumably a
# placeholder for the real output folder -- set it before running.
OUT_DIR = "SampledImg/*"
# Create the folder up front so a missing directory fails fast, not after minutes of sampling.
os.makedirs(OUT_DIR, exist_ok=True)

# Conditioning labels depend only on the target, so build them once instead of per grid.
target_labels = []
for target in targets:
    parts = target.split(" ")
    target_labels.append((
        torch.full((args.num_sample,), ATR2IDX[parts[0]], dtype=torch.long, device=device),
        torch.full((args.num_sample,), OBJ2IDX[parts[-1]], dtype=torch.long, device=device),
    ))

for i in range(N_GRIDS):
    rows = []
    for atr_labels, obj_labels in target_labels:
        x_i = torch.randn(args.num_sample, 3, args.img_size, args.img_size, device=device)
        x0 = sampler(x_i, atr_labels, obj_labels, steps=NUM_STEPS)
        # Presumably rescales the sampler's [-1, 1] output to [0, 1] -- TODO confirm range.
        rows.append(x0 * 0.5 + 0.5)
    # nrow=args.num_sample places each target's consecutive batch on its own grid row.
    grid = make_grid(torch.cat(rows, dim=0), nrow=args.num_sample)
    # NOTE(review): the per-image cell names files with `epoch`, this one with `epoch+1`,
    # although both sample from the same model_{epoch}.pth checkpoint -- confirm intent.
    save_image(grid, f"{OUT_DIR}/epochs_{epoch+1}_{i}.png")