# <a id='toc1_'></a>[Inference on CE conditions](#toc0_)

**Table of contents**<a id='toc0_'></a>    
- [Inference on CE conditions](#toc1_)    
- [Inference on CAM map](#toc2_)    
    - [0: Setup](#toc2_1_1_)    
    - [1: Single Image Inference](#toc2_1_2_)    
    - [2: Batch Inference](#toc2_1_3_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

In [1]:
from share import *
import config

import cv2
import einops
import numpy as np
import torch
import random
import os

from pytorch_lightning import seed_everything
from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler


apply_canny = CannyDetector()

model = create_model("./models/cldm_v15.yaml").cpu()
# model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
model.load_state_dict(
    load_state_dict("./checkpoints/fusar_v1/fusar_v1_ce-epoch=9.ckpt", location="cuda")
)

model = model.cuda()
ddim_sampler = DDIMSampler(model)



No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.3.layer_norm2.bias', 'vision_model.encoder.layers.10.self_attn.v_proj.bias', ...]

Loaded model config from [./models/cldm_v15.yaml]
Loaded state_dict from [./checkpoints/fusar_v1/fusar_v1_ce-epoch=9.ckpt]


In [47]:
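def random_rotate_image(image, random_angle=0):
    """Rotate `image` about its centre by `random_angle` degrees.

    Positive angles rotate counter-clockwise (OpenCV convention). Despite the
    name, the angle is chosen by the caller; to randomise it, pass e.g.
    `np.random.uniform(min_angle, max_angle)`. Corners that leave the frame
    are cropped and uncovered areas are filled with black.
    """
    rows, cols = image.shape[:2]
    rotation_matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), random_angle, 1)
    return cv2.warpAffine(image, rotation_matrix, (cols, rows))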
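For reference, the OpenCV documentation gives the 2x3 matrix built by `cv2.getRotationMatrix2D(center, angle, scale)` in closed form. This pure-Python sketch reproduces that formula so you can check where a pixel lands without running OpenCV; the helper names `rotation_matrix_2d` and `apply_affine` are hypothetical and not part of this repo.

```python
import math


def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """2x3 affine matrix with the same form as cv2.getRotationMatrix2D."""
    cx, cy = center
    alpha = scale * math.cos(math.radians(angle_deg))
    beta = scale * math.sin(math.radians(angle_deg))
    return [
        [alpha, beta, (1 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1 - alpha) * cy],
    ]


def apply_affine(matrix, x, y):
    """Map pixel coordinates (x, y) through a 2x3 affine matrix."""
    (a, b, tx), (c, d, ty) = matrix
    return a * x + b * y + tx, c * x + d * y + ty


# A 60 degree rotation about the centre of a 512x512 image (as in the
# commented-out call inside `process`) leaves the centre pixel where it is.
m = rotation_matrix_2d((256, 256), 60)
print(apply_affine(m, 256, 256))
```

In image coordinates (y pointing down), a 90 degree rotation moves the pixel just right of the centre to the pixel just above it, i.e. a positive angle turns the picture counter-clockwise on screen.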


def process(
    input_image,
    prompt,
    output_dir="./gen/debug/output.png",  # output *file* path (needs an image extension), not a directory
    a_prompt="SAR image",
    n_prompt="colored image",
    guess_mode=False,
    num_samples=1,
    image_resolution=512,
    ddim_steps=20,
    strength=1.0,
    scale=9.0,
    seed=12345,
    eta=0.0,
    low_threshold=100,
    high_threshold=200,
):
    with torch.no_grad():
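        image_path = input_image
        input_image = cv2.imread(image_path)  # BGR uint8, or None if unreadable
        if input_image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        img = resize_image(HWC3(input_image), image_resolution)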
        # img = random_rotate_image(img, random_angle=60)
        H, W, C = img.shape

        detected_map = apply_canny(img, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, "b h w c -> b c h w").clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        cond = {
            "c_concat": [control],
            "c_crossattn": [
                model.get_learned_conditioning([prompt + ", " + a_prompt] * num_samples)
            ],
        }
        un_cond = {
            "c_concat": None if guess_mode else [control],
            "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)],
        }
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

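        # Per-block ControlNet weights (13 residuals). In guess mode they decay
        # geometrically from strength * 0.825**12 (about 0.1 * strength) at index 0
        # up to strength at index 12. The factor 0.825 is an upstream magic number:
        # 0.825**12 is about 0.099, just under 0.1, while 0.826**12 is just over it.
        model.control_scales = (
            [strength * (0.825 ** float(12 - i)) for i in range(13)]
            if guess_mode
            else ([strength] * 13)
        )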
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            num_samples,
            shape,
            cond,
            verbose=False,
            eta=eta,
            unconditional_guidance_scale=scale,
            unconditional_conditioning=un_cond,
        )

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (
            (einops.rearrange(x_samples, "b c h w -> b h w c") * 127.5 + 127.5)
            .cpu()
            .numpy()
            .clip(0, 255)
            .astype(np.uint8)
        )

        results = [x_samples[i] for i in range(num_samples)]
        output_stack = []
        output_stack.append(img)
        output_stack.append(detected_map)
        output_stack.extend(results)
        combined_image = np.hstack(output_stack)

        # Set the text, font, size, and color
        text = prompt
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (255, 255, 255)  # White color
        font_thickness = 2

        # Measure the size of the text
        text_size, _ = cv2.getTextSize(text, font, font_scale, font_thickness)

        # Create a new image with extra space at the bottom for the text
        new_height = combined_image.shape[0] + text_size[1] * 2
        new_image = np.zeros((new_height, combined_image.shape[1], 3), dtype=np.uint8)
        new_image[: combined_image.shape[0], :, :] = combined_image

        # Add the text to the new image
        text_position = (10, int(combined_image.shape[0] + text_size[1] * 1.5))
        cv2.putText(
            new_image, text, text_position, font, font_scale, font_color, font_thickness
        )

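        # `output_dir` is the full output *file* path (e.g. ".../name.png").
        # cv2.imwrite returns False instead of raising when the parent folder is missing.
        os.makedirs(os.path.dirname(output_dir) or ".", exist_ok=True)
        if not cv2.imwrite(output_dir, new_image):
            raise IOError(f"Failed to write {output_dir}")

        # Return the generated samples (uint8 HxWx3 arrays) so callers can reuse them
        return results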
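`process` sets `model.control_scales` to 13 weights, one per ControlNet residual. This sketch recomputes that schedule in isolation (`control_scales` is a hypothetical helper name) and checks the arithmetic behind the 0.825 factor: 0.825**12 is about 0.0994, just under 0.1, while 0.826**12 is about 0.1009, just over it, so in guess mode the smallest weight is roughly a tenth of `strength`.

```python
def control_scales(strength, guess_mode, n_blocks=13):
    """Per-residual ControlNet weights, mirroring the expression in `process`."""
    if guess_mode:
        # Geometric decay: the last index gets `strength`, index 0 gets strength * 0.825**12
        return [strength * (0.825 ** float(n_blocks - 1 - i)) for i in range(n_blocks)]
    return [strength] * n_blocks


plain = control_scales(0.5, guess_mode=False)
guess = control_scales(1.0, guess_mode=True)
print(plain[:3], round(guess[0], 4), guess[-1])  # [0.5, 0.5, 0.5] 0.0994 1.0
```

Outside guess mode, `strength=0.5` (as in the call below) simply halves every residual uniformly.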

In [49]:
input_image = "./training/fusar_mix/target/Ship_C03S02N0004.png"
prompt = "a Dredger ship in a top-down grayscale SAR image"
output_dir = "./gen/str/Ship_C03S02N0004.png"
get_ims = process(input_image, prompt, output_dir, num_samples=4, strength=0.5)

Global seed set to 12345


Data shape for DDIM sampling is (4, 4, 64, 64), eta 0.0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:17<00:00,  1.15it/s]
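Assuming `annotator.util.resize_image` matches the upstream ControlNet helper (scale so the shorter side equals `image_resolution`, then round both sides to a multiple of 64), the logged `Data shape for DDIM sampling is (4, 4, 64, 64)` follows from `process` requesting a latent of shape `(4, H // 8, W // 8)` per sample. A sketch of that arithmetic; the helper names are hypothetical and the resize rule should be checked against the local `annotator/util.py`.

```python
def resized_hw(h, w, resolution):
    """Output size of a ControlNet-style resize_image (assumed upstream behaviour)."""
    k = resolution / min(h, w)
    # Python's round() rounds half to even, like np.round
    return int(round(h * k / 64.0)) * 64, int(round(w * k / 64.0)) * 64


def ddim_data_shape(h, w, resolution, num_samples, latent_channels=4, downsample=8):
    """(batch, C, H/8, W/8) latent shape handed to the DDIM sampler."""
    H, W = resized_hw(h, w, resolution)
    return (num_samples, latent_channels, H // downsample, W // downsample)


print(ddim_data_shape(512, 512, 512, num_samples=4))  # (4, 4, 64, 64)
print(ddim_data_shape(256, 256, 256, num_samples=6))  # (6, 4, 32, 32)
print(ddim_data_shape(300, 500, 256, num_samples=1))  # non-square input
```

The 256-pixel case is consistent with the `(6, 4, 32, 32)` shape logged by the CAM inference further down, where `image_resolution=256`.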


# <a id='toc2_'></a>[Inference on CAM map](#toc0_)

### <a id='toc2_1_1_'></a>[0: Setup](#toc0_)

In [9]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys

sys.path.append("/workspace/dso/ControlSAR")
import importlib
from inf_dir import inference_n
from share import setup_config
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
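`CUDA_VISIBLE_DEVICES` is read when CUDA is first initialised, so it must be set before any GPU work in the kernel (PyTorch initialises CUDA lazily, so importing `torch` alone is not enough to lock it in). With it set to `"0"`, the chosen GPU is the only visible device and is addressed as `cuda:0`, consistent with the checkpoint cell that follows loading with `location="cuda:0"`; the commented-out `cuda:5` variant would not resolve in that configuration. A small guard, sketched with a hypothetical helper name:

```python
import os
import sys
import warnings


def set_visible_gpus(devices: str) -> None:
    """Expose only `devices` (e.g. "0" or "0,1") to CUDA in this process.

    Hypothetical helper: warns if CUDA was already initialised by an earlier
    cell, because the new value will then be ignored until a kernel restart.
    """
    torch = sys.modules.get("torch")
    if torch is not None and torch.cuda.is_initialized():
        warnings.warn(
            "CUDA is already initialised; CUDA_VISIBLE_DEVICES will not "
            "take effect until the kernel restarts"
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = devices


set_visible_gpus("0")  # the selected GPU is then addressed as "cuda:0"
```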

In [5]:
setup_config()
model = create_model("../models/cldm_v15.yaml").cpu()

No module 'xformers'. Proceeding without it.
ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loaded model config from [../models/cldm_v15.yaml]


### <a id='toc2_1_2_'></a>[1: Single Image Inference](#toc0_)

In [7]:
# model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
# model.load_state_dict(load_state_dict("./checkpoints/fusrs_v2/256_cam_0/fusrs_epoch=5.ckpt", location="cuda"))
# model.load_state_dict(load_state_dict("./checkpoints/fusrs_v2/256_cam_0/fusrs_epoch=13.ckpt", location="cuda"))
model.load_state_dict(
    load_state_dict(
        # "./checkpoints/fusar_v1/fusar_v1_ce-epoch=9.ckpt", location="cuda:5"
        "../checkpoints/fusrs_v2_lc/fusrs_epoch=89.ckpt",
        location="cuda:0",
    )
)

model = model.cuda()
ddim_sampler = DDIMSampler(model)

Loaded state_dict from [../checkpoints/fusrs_v2_lc/fusrs_epoch=89.ckpt]
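The commented-out lines show several epochs being tried (`epoch=5`, `epoch=13`, `epoch=89`). If the goal is simply the newest checkpoint in a folder, pick it by the numeric epoch in the filename rather than by string order, because `"epoch=9"` sorts after `"epoch=10"` as text. A minimal sketch, assuming the `..._epoch=<N>.ckpt` naming seen in this notebook; `latest_checkpoint` is a hypothetical helper.

```python
import os
import re

# Matches names such as "fusrs_epoch=89.ckpt" or "fusar_v1_ce-epoch=9.ckpt"
EPOCH_RE = re.compile(r"epoch=(\d+)\.ckpt$")


def latest_checkpoint(paths):
    """Return the path with the highest epoch number, or None if none match."""
    best_epoch, best_path = -1, None
    for path in paths:
        match = EPOCH_RE.search(os.path.basename(path))
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


names = ["fusrs_epoch=9.ckpt", "fusrs_epoch=10.ckpt", "fusrs_epoch=89.ckpt"]
print(max(names))                # string order picks "fusrs_epoch=9.ckpt"
print(latest_checkpoint(names))  # numeric order picks "fusrs_epoch=89.ckpt"
```

In the notebook it could be fed with `glob.glob("../checkpoints/fusrs_v2_lc/*.ckpt")` and the result passed to `load_state_dict`.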


In [13]:
# input_img = "./training/fusrs_v2_256_cam/source/L274_8.png"
# prompt = "a high-quality Dredger ship in a grayscale SAR radar satellite image, captured during a search and rescue operation"
# # prompt = "a bird's-eye view of a Dredger ship navigating in the sea, captured in a SAR grayscale radar image"
# save_prefix = "./gen/debug/epoch23_step20/L274_8"

input_img = "../training/fusrs_v2_256_cam_v2/source/Ship_C04S02N0087.png"
# prompt = "a high-quality Dredger ship in a grayscale SAR radar satellite image, captured during a search and rescue operation"
# prompt = "Fishing ship in SAR image, fishing vessel, maneuverable hull, distinct activity pattern, marine resources, shorter and narrower profile, lower radar backscatter, unique operational behaviors, longer lingering time, rapid changes in direction, suspended fishing equipment, frequent turns, designated operation zones, efficient hull design, fishing-specific equipment, high-intensity scattering points, marine life tracking, compact ship construction, specialized storage compartments, handling fishing activities."
# NOTE: SD 1.5's CLIP text encoder keeps at most 77 tokens (start/end tokens included);
# long keyword lists like this one may be silently truncated.
prompt = "Fishing ship in SAR image, smaller hull size, trawler, narrow profile, fishing gear, masts, booms, radar scattering points, reflection intensity, fishing nets, longlines, fishing pots, shorter length, sea activity patterns, storage sections, operational behavior, designated fishing areas, seasonal migration patterns, artisanal boats, commercial fishing boats."
save_path = "../gen/debug/230508/P2"

_ = inference_n(
    input_img,
    prompt,
    model,
    ddim_sampler,
    save_path,
    guess_mode=False,
    num_samples=6,
    image_resolution=256,
    ddim_steps=20,
    strength=1,
    scale=9.0,
    seed=0,
    eta=0,
)

Global seed set to 0


Data shape for DDIM sampling is (6, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:06<00:00,  3.10it/s]
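`process` passes `scale` to the DDIM sampler as `unconditional_guidance_scale`, together with a negative (`n_prompt`) conditioning, and the `inference_n` call above also uses `scale=9.0` (its internals live in `inf_dir.py` and are not shown here). Classifier-free guidance combines the two noise predictions at every step as `eps = eps_uncond + scale * (eps_cond - eps_uncond)`, the formulation used by the DDIM sampler that `cldm.ddim_hacked` derives from. A scalar sketch of that update, with plain lists standing in for tensors and `cfg_combine` as a hypothetical name:

```python
def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: move from the unconditional prediction
    towards (and, for scale > 1, past) the conditional one."""
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


eps_uncond = [0.0, 0.2]
eps_cond = [0.1, 0.2]
print(cfg_combine(eps_uncond, eps_cond, 1.0))  # scale 1 gives the conditional prediction
print(cfg_combine(eps_uncond, eps_cond, 9.0))  # differences are amplified 9x
```

With `guess_mode=True`, `process` drops the control image from the unconditional branch (`"c_concat": None`), so, roughly, guidance then amplifies the effect of the control map as well as the prompt.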


### <a id='toc2_1_3_'></a>[2: Batch Inference](#toc0_)

In [13]:
import importlib
import inf_cam

importlib.reload(inf_cam)
from inf_cam import inference
from share import setup_config
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler

setup_config()
model = create_model("./models/cldm_v15.yaml").cpu()

ControlLDM: Running in eps-prediction mode
DiffusionWrapper has 859.52 M params.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loaded model config from [./models/cldm_v15.yaml]


In [14]:
# model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
# model.load_state_dict(load_state_dict("./checkpoints/fusrs_v2/256_cam_0/fusrs_epoch=5.ckpt", location="cuda"))
# model.load_state_dict(load_state_dict("./checkpoints/fusrs_v2/256_cam_0/fusrs_epoch=13.ckpt", location="cuda"))
model.load_state_dict(
    load_state_dict(
        "./checkpoints/fusrs_v2/256_cam_1/fusrs_epoch=10.ckpt", location="cuda"
    )
)

model = model.cuda()
ddim_sampler = DDIMSampler(model)

Loaded state_dict from [./checkpoints/fusrs_v2/256_cam_1/fusrs_epoch=10.ckpt]
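The batch run uses `ddim_steps=20` and `eta=0` (matching `step20_eta0` in the output folder name). Assuming the stock latent-diffusion DDIM code underneath `DDIMSampler` and the SD 1.5 noise schedule (1000 training steps, scaled-linear betas from 0.00085 to 0.012, as in the upstream `cldm_v15.yaml`), the sampler visits every 50th timestep and uses a per-step noise level `sigma = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev))`, where `a` is the cumulative alpha product. With `eta = 0` every sigma is zero, so once the seeded initial latent is drawn the sampling path is deterministic. A pure-Python sketch; the helper names are hypothetical.

```python
import math


def sd15_alphas_cumprod(n_steps=1000, linear_start=0.00085, linear_end=0.012):
    """Cumulative alpha products for a 'scaled linear' beta schedule (assumed SD 1.5 values)."""
    lo, hi = math.sqrt(linear_start), math.sqrt(linear_end)
    alphas_cumprod, prod = [], 1.0
    for i in range(n_steps):
        beta = (lo + (hi - lo) * i / (n_steps - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def ddim_schedule(ddim_steps, eta, n_steps=1000):
    """Uniform DDIM timesteps and per-step sigmas, following the latent-diffusion formulas."""
    ac = sd15_alphas_cumprod(n_steps)
    stride = n_steps // ddim_steps
    timesteps = [t + 1 for t in range(0, n_steps, stride)]
    sigmas = []
    for i, t in enumerate(timesteps):
        a_t = ac[t]
        a_prev = ac[0] if i == 0 else ac[timesteps[i - 1]]
        sigmas.append(eta * math.sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)))
    return timesteps, sigmas


steps, sigmas_eta0 = ddim_schedule(20, eta=0.0)
_, sigmas_eta1 = ddim_schedule(20, eta=1.0)
print(len(steps), steps[:3], steps[-1])  # 20 [1, 51, 101] 951
print(max(sigmas_eta0), min(sigmas_eta1) > 0)  # 0.0 True
```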


In [15]:
import json
import os
from itertools import islice


def generate_batch(
    prompts_file, input_dir, output_dir, model, ddim_sampler, max_items=50
):
    """Run `inference` on the first `max_items` entries of a JSON Lines prompts file.

    Each line is a JSON object with a "condition" key (CAM image path relative
    to `input_dir`) and a "prompt" key. The save prefix for line `idx` is
    `<output_dir>/P<idx>_<condition stem>`.
    """
    os.makedirs(output_dir, exist_ok=True)
    with open(prompts_file, "r") as f:
        # islice stops cleanly at end of file, so files with fewer than
        # `max_items` lines are handled without json.loads("") errors
        for idx, line in enumerate(islice(f, max_items)):
            if not line.strip():
                continue  # skip blank lines
            data = json.loads(line)
            condition_path = os.path.join(input_dir, data["condition"])
            prompt = data["prompt"]
            stem = os.path.splitext(os.path.basename(condition_path))[0]
            save_prefix = os.path.join(output_dir, f"P{idx}_{stem}")
            print(condition_path, prompt, save_prefix)
            _ = inference(
                condition_path,
                prompt,
                model,
                ddim_sampler,
                save_prefix,
                guess_mode=False,
                num_samples=1,
                image_resolution=256,
                ddim_steps=20,
                strength=1,
                scale=9.0,
                seed=0,
                eta=0,
            )


# Set the paths
# prompts_file = "/workspace/dso/ControlSAR/gen/fusrs_v2_cam/cond+prompt.json"
prompts_file = "/workspace/dso/ControlSAR/gen/fusrs_v2_cam/dredger.json"
input_dir = "./training/fusrs_v2_256_cam/source"
# NOTE: the folder name says "epoch23", but the checkpoint loaded above is
# 256_cam_1/fusrs_epoch=10.ckpt; rename the folder if it should record the checkpoint.
output_dir = "./gen/fusrs_v2_cam/epoch23_step20_eta0"

# Generate images for the first 50 prompts in the file
generate_batch(prompts_file, input_dir, output_dir, model, ddim_sampler)

Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a high-resolution grayscale SAR image depicting a Fishing ship sailing in rough seas ./gen/fusrs_v2_cam/epoch23_step20_eta0/P0_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.65it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a bird's-eye view of a Fishing ship navigating in the sea, captured in a SAR grayscale radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P1_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.78it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a Fishing ship seen from a top-down grayscale satellite SAR radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P2_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:03<00:00,  5.03it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png an aerial grayscale SAR image of a Fishing ship maneuvering on the sea ./gen/fusrs_v2_cam/epoch23_step20_eta0/P3_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.71it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a Fishing ship in a grayscale SAR radar satellite image, captured during a search and rescue operation ./gen/fusrs_v2_cam/epoch23_step20_eta0/P4_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.70it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a top-down view of a Fishing ship leaving a harbor in a high-resolution grayscale SAR satellite radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P5_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.63it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a Fishing ship conducting a maritime operation, as seen in a grayscale overhead SAR radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P6_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.73it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a bird's-eye view of a Fishing ship passing through an ice field, visible in a grayscale SAR image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P7_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.80it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a grayscale SAR satellite radar image capturing a Fishing ship near an offshore platform ./gen/fusrs_v2_cam/epoch23_step20_eta0/P8_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.96it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L416_2.png a Fishing ship in a top-down grayscale SAR image, with visible deck equipment and structure ./gen/fusrs_v2_cam/epoch23_step20_eta0/P9_L416_2
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.68it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a high-resolution grayscale SAR image depicting a Cargo ship sailing in rough seas ./gen/fusrs_v2_cam/epoch23_step20_eta0/P10_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.60it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a bird's-eye view of a Cargo ship navigating in the sea, captured in a SAR grayscale radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P11_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.92it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a Cargo ship seen from a top-down grayscale satellite SAR radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P12_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.72it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png an aerial grayscale SAR image of a Cargo ship maneuvering on the sea ./gen/fusrs_v2_cam/epoch23_step20_eta0/P13_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.73it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a Cargo ship in a grayscale SAR radar satellite image, captured during a search and rescue operation ./gen/fusrs_v2_cam/epoch23_step20_eta0/P14_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.73it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a top-down view of a Cargo ship leaving a harbor in a high-resolution grayscale SAR satellite radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P15_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.75it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a Cargo ship conducting a maritime operation, as seen in a grayscale overhead SAR radar image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P16_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.86it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a bird's-eye view of a Cargo ship passing through an ice field, visible in a grayscale SAR image ./gen/fusrs_v2_cam/epoch23_step20_eta0/P17_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.89it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a grayscale SAR satellite radar image capturing a Cargo ship near an offshore platform ./gen/fusrs_v2_cam/epoch23_step20_eta0/P18_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.81it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L443_4.png a Cargo ship in a top-down grayscale SAR image, with visible deck equipment and structure ./gen/fusrs_v2_cam/epoch23_step20_eta0/P19_L443_4
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler: 100%|██████████| 20/20 [00:04<00:00,  4.88it/s]
Global seed set to 0


./training/fusrs_v2_256_cam/source/./L274_8.png a high-resolution grayscale SAR image depicting a Dredger ship sailing in rough seas ./gen/fusrs_v2_cam/epoch23_step20_eta0/P20_L274_8
Data shape for DDIM sampling is (1, 4, 32, 32), eta 0
Running DDIM Sampling with 20 timesteps


DDIM Sampler:  80%|████████  | 16/20 [00:03<00:00,  4.16it/s]


KeyboardInterrupt: 
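`generate_batch` reads a JSON Lines file: one JSON object per line with a `condition` path (relative to `input_dir`) and a `prompt`. The log above suggests each CAM condition was paired with a fixed set of prompt templates in which only the ship class changes. A minimal sketch of producing such a file; the two templates are copied from logged prompts, while `build_prompt_lines` and its `(condition, ship_class)` input format are hypothetical.

```python
import json

# Two of the templates visible in the batch log; the full set used for the run is not shown here.
TEMPLATES = [
    "a high-resolution grayscale SAR image depicting a {cls} ship sailing in rough seas",
    "a {cls} ship seen from a top-down grayscale satellite SAR radar image",
]


def build_prompt_lines(items, templates=TEMPLATES):
    """Return JSON Lines strings, one per (condition, template) pair."""
    lines = []
    for condition, ship_class in items:
        for template in templates:
            record = {"condition": condition, "prompt": template.format(cls=ship_class)}
            lines.append(json.dumps(record))
    return lines


lines = build_prompt_lines([("./L416_2.png", "Fishing"), ("./L443_4.png", "Cargo")])
print(len(lines))  # 4
print(lines[0])

# Writing the file (hypothetical path):
# with open("prompts.jsonl", "w") as f:
#     f.write("\n".join(lines) + "\n")
```

Because `generate_batch` numbers outputs `P<idx>` by line position, keeping the file order stable keeps output names stable across interrupted and resumed runs.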