# 🧠 MS Lesion Synthesis Inference Notebook

In [163]:
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import os
import numpy as np

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
seed = 17844

In [184]:
def generate_simple_lesion(image, mask, lesion_size, lesion_intensity):
    if isinstance(image, Image.Image):
        image = transforms.ToTensor()(image).unsqueeze(0)
    lesion = torch.zeros_like(image)
    lesion[0, 0, mask[0]:mask[0]+lesion_size[0], mask[1]:mask[1]+lesion_size[1]] = lesion_intensity
    return lesion

def generate_mask(image, mask_size, mask_position, show=True):
    if isinstance(image, Image.Image):
        image = transforms.ToTensor()(image).unsqueeze(0)
    
    mask = torch.zeros_like(image)
    mask[0, 0, mask_position[0]:mask_position[0]+mask_size[0], mask_position[1]:mask_position[1]+mask_size[1]] = 1
    
    if show:
        plt.imshow(mask[0, 0].cpu().numpy(), cmap='gray')
        plt.show()
        print(f"Mask generated. Size: {mask_size}, position: {mask_position}")
    return mask

def generate_lesion(pipe, image, mask, prompt, guidance, device='cuda', seed=17844):
    generator = torch.Generator(device=device).manual_seed(seed)
    return pipe(prompt=prompt,
                image=image,
                mask_image=mask,
                num_inference_steps=25,
                generator=generator,
                guidance=guidance, #guidance_scale??
            ).images[0]

def load_image(path, size=256, show=True):
    image = Image.open(path).convert('RGB')
    image = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size)
    ])(image)

    if show:
        plt.imshow(image)
        plt.show()
        print(f"Image {path} loaded. Size: {image.size}, mode: {image.mode}")
    return image

In [185]:
# # ---- Load pipeline ----
# model_path = "../lesion-inpating-dreambooth-model-new"        # <- Replace this with your fine-tuned model folder
# device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

# pipe = StableDiffusionInpaintPipeline.from_pretrained(
#     model_path,
#     torch_dtype=torch.float32,
#     safety_checker=None,
# )

# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# pipe = pipe.to(device)
# # pipe.set_progress_bar_config(disable=True)
# # inspect.signature(pipe.__call__)

In [186]:
image_path = "./input_data/flair.png"                     # <- Path to FLAIR MRI image (PNG, JPG)
prompt = "Multiple sclerosis lesion inpainting"           # <- Prompt for the model
output_path = "./output_data/lesion_output.png"           # <- Where to save the output
image_size = 512  # Assumes model was trained at 512x512

In [187]:
# Load image and mask
image = load_image(image_path, image_size, show=False)

# Generate mask
mask_size = (27, 33)
mask_position = (150, 150)
mask = generate_mask(image, mask_size, mask_position, show=False)

# Generate simple lesion
lesion_size = (15, 15)
lesion_intensity = 0.5
lesion = generate_simple_lesion(image, mask_position, lesion_size, lesion_intensity)

# Generate lesion
lesion = generate_lesion(pipe, image, mask, prompt, lesion)


ValueError: Incorrect configuration settings! The config of `pipeline.unet`: FrozenDict({'sample_size': 64, 'in_channels': 9, 'out_channels': 4, 'center_input_sample': False, 'flip_sin_to_cos': True, 'freq_shift': 0, 'down_block_types': ['CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'], 'mid_block_type': 'UNetMidBlock2DCrossAttn', 'up_block_types': ['UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'], 'only_cross_attention': False, 'block_out_channels': [320, 640, 1280, 1280], 'layers_per_block': 2, 'downsample_padding': 1, 'mid_block_scale_factor': 1, 'dropout': 0.0, 'act_fn': 'silu', 'norm_num_groups': 32, 'norm_eps': 1e-05, 'cross_attention_dim': 768, 'transformer_layers_per_block': 1, 'reverse_transformer_layers_per_block': None, 'encoder_hid_dim': None, 'encoder_hid_dim_type': None, 'attention_head_dim': 8, 'num_attention_heads': None, 'dual_cross_attention': False, 'use_linear_projection': False, 'class_embed_type': None, 'addition_embed_type': None, 'addition_time_embed_dim': None, 'num_class_embeds': None, 'upcast_attention': False, 'resnet_time_scale_shift': 'default', 'resnet_skip_time_act': False, 'resnet_out_scale_factor': 1.0, 'time_embedding_type': 'positional', 'time_embedding_dim': None, 'time_embedding_act_fn': None, 'timestep_post_act': None, 'time_cond_proj_dim': None, 'conv_in_kernel': 3, 'conv_out_kernel': 3, 'projection_class_embeddings_input_dim': None, 'attention_type': 'default', 'class_embeddings_concat': False, 'mid_block_only_cross_attention': None, 'cross_attention_norm': None, 'addition_embed_type_num_heads': 64, '_class_name': 'UNet2DConditionModel', '_diffusers_version': '0.32.2', '_name_or_path': '../lesion-inpating-dreambooth-model-new/unet'}) expects 9 but received `num_channels_latents`: 4 + `num_channels_mask`: 3 + `num_channels_masked_image`: 4 = 11. Please verify the config of `pipeline.unet` or your `mask_image` or `image` input.