In [2]:
import PIL.Image
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio
import os

from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.optim.lr_scheduler as lr_scheduler
from torchinfo import summary
from pytorch_msssim import ssim, ms_ssim

from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, ZPlus, Pass, Norm
from util.xai_lrp import xai_zennit, show_attributions

from data.dataset import ImageDataset, CelebHQAttrDataset
from classifier.train_classifier import LinearClassifier, ResNet50Classifier

# ------------------- Activation checkpointed UNet -------------------
class CheckpointedUNetWrapper(torch.nn.Module):
    """Run a wrapped UNet block by block through activation checkpointing.

    The wrapper drives the sub-modules of ``model`` itself (``time_proj``,
    ``time_embedding``, ``conv_in``, ``down_blocks``, ``mid_block``,
    ``up_blocks``, ``conv_norm_out``, ``conv_act``, ``conv_out``) and routes
    every down/mid/up block through ``torch.utils.checkpoint.checkpoint``.

    NOTE(review): the forward pass appears to follow the structure of
    diffusers' ``UNet2DModel.forward`` -- confirm against the installed
    diffusers version when upgrading.
    """

    def __init__(self, model):
        """Store ``model``, whose sub-modules are called by ``forward``."""
        super().__init__()
        self.model = model

    def checkpointed_forward(self, module, *inputs):
        """Call ``module(*inputs)`` under activation checkpointing.

        ``use_reentrant=False`` is passed explicitly: PyTorch asks callers to
        choose a variant, and the non-reentrant one also tracks tensors that
        are nested inside tuples (such as ``res_samples`` for the up blocks).
        """
        return checkpoint(module, *inputs, use_reentrant=False)

    def forward(self, sample, timestep):
        """Predict from ``sample`` at ``timestep``.

        Args:
            sample: input tensor; its first dimension is used as batch size.
            timestep: Python number, 0-dim tensor, or tensor of timesteps.

        Returns:
            A dict with the single key ``"sample"`` holding the output tensor.
        """
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.model.time_proj(timesteps)
        emb = self.model.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.model.conv_in(sample)

        # 3. down: collect the residuals that the up blocks consume later
        down_block_res_samples = (sample,)
        for downsample_block in self.model.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = self.checkpointed_forward(downsample_block, sample, emb, skip_sample)
            else:
                sample, res_samples = self.checkpointed_forward(downsample_block, sample, emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.checkpointed_forward(self.model.mid_block, sample, emb)

        # 5. up: each block takes as many residuals off the end of the stack
        #    as it has resnets
        skip_sample = None
        for upsample_block in self.model.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb, skip_sample)
            else:
                sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb)

        # 6. post-process
        sample = self.model.conv_norm_out(sample)
        sample = self.model.conv_act(sample)
        sample = self.model.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        return {"sample": sample}

# ------------------- Loss functions -------------------
def classifier_loss(classifier, images, targets, idx):
    """Score the first image's prediction for the class selected by ``cls_id``.

    Args:
        classifier: callable returning per-class outputs for ``images``.
        images: batch handed to ``classifier``; only element 0 is scored.
        targets: desired output value(s); converted with ``torch.tensor``.
        idx: step counter; the prediction is printed on even values.

    Returns:
        ``(error, preds_binary)``: the mean absolute difference between the
        prediction and ``targets``, and whether the sigmoid of the prediction
        exceeds 0.5.
    """
    # Reads the module-level globals ``cls_id`` and ``device``.
    logit = classifier(images)[0][cls_id]
    if not idx % 2:
        print(f"Classifier prediction: {logit}")
    target = torch.tensor(targets).to(device)
    return (logit - target).abs().mean(), logit.sigmoid() > 0.5

def minDist_loss(counterfactual_images, original_images):
    """Return the mean absolute error (L1 distance) between two image tensors.

    Args:
        counterfactual_images: generated images.
        original_images: reference images, broadcastable to the same shape.

    Returns:
        A 0-dim tensor holding the mean of the element-wise absolute difference.
    """
    # An SSIM-based alternative was tried here: rescale both inputs with
    # (x + 1) / 2 and use 1 - ssim(original, counterfactual, data_range=1.0,
    # size_average=True). Only the MAE variant is active.
    return (counterfactual_images - original_images).abs().mean()


# ------------------- Data and Directory Preparation -------------------
# Load the test images one at a time, in sorted file-name order.
data = ImageDataset('/home/dai/GPU-Student-2/Cederic/DataSciPro/data/test_counterfactual', image_size=256, exts=['jpg', 'JPG', 'png'], do_augment=False, sort_names=True)
dataloader = DataLoader(data, batch_size=1, shuffle=False)

# Create one output directory per batch, named after the part of the image
# file name before the first underscore. Only the file paths are needed, so
# iterate over batch indices instead of the dataloader itself; enumerating the
# dataloader would load every image just to throw it away.
directory_names = []
for i in range(len(dataloader)):
    img_index = dataloader.dataset.paths[i].name.split('_')[0]
    directory_name = os.path.join("/home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_test", f'folder_IMG_{img_index}')
    directory_names.append(directory_name)
    os.makedirs(directory_name, exist_ok=True)
    print(f'Created directory: {directory_name}')

# ------------------- Model and Classifier Preparation -------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
cls_type = 'linear'  # which classifier checkpoint to load: 'linear' or 'res50'
cls_id = CelebHQAttrDataset.cls_to_id['Smiling']  # attribute steered by the counterfactual

unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
# Loading a scheduler from a hub id goes through from_pretrained; passing a
# model id to from_config is the deprecated spelling.
scheduler = DDIMScheduler.from_pretrained("CompVis/ldm-celebahq-256", subfolder="scheduler")

unet.to(device)
vqvae.to(device)

# The wrapper shares the UNet's modules, so it lives on the same device.
checkpointed_unet = CheckpointedUNetWrapper(unet)
summary(checkpointed_unet)

if cls_type == 'linear':
    classifier = LinearClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b128linear2024-06-02 13:08:28.ckpt",
                                            input_dim = data[0]['img'].shape,
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))
elif cls_type == 'res50':
    classifier = ResNet50Classifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b64res502024-06-02 17:06:41.ckpt",
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))
else:
    # Fail here with a clear message instead of a NameError on `classifier` below.
    raise ValueError(f"Unknown cls_type {cls_type!r}; expected 'linear' or 'res50'")
classifier.to(device)
classifier.eval()


# ------------------- Functionality of the classifier -------------------
# Run the classifier over every input image without tracking gradients, print
# the raw output of the first element for the selected class, and collect the
# thresholded (sigmoid > 0.5) decisions for the whole dataset.
all_outputs = []
with torch.no_grad():
    for batch in dataloader:
        inputs = batch['img'].to(classifier.device)
        outputs = classifier(inputs)
        print(outputs[0][cls_id])

        class_logits = outputs[:, cls_id].cpu()
        preds_binary = class_logits.sigmoid() > 0.5
        all_outputs.append(preds_binary)
all_outputs = torch.cat(all_outputs)  # concatenates along dim 0
print(all_outputs)


# ------------------- LRP Visualization -------------------
def _lrp_layer_map(linear_rule):
    """Build a zennit layer map that differs only in the rule for Linear layers."""
    return [
        (torch.nn.ReLU, Pass()),  # ignore activations
        (torch.nn.Linear, linear_rule),  # this is the dense Linear, not any Linear
        (torch.nn.Conv2d, ZPlus()),
        (torch.nn.BatchNorm2d, Pass()),
        (torch.nn.AdaptiveAvgPool2d, Norm()),
    ]


def _lrp_attribution(images, layer_map):
    """Return the first result of ``xai_zennit`` for the class ``cls_id``."""
    return xai_zennit(
        classifier,
        images,
        RuleComposite=LayerMapComposite(layer_map),
        device=device,
        target=torch.tensor(cls_id).to(device),
    )[0]


layer_map_lrp_0 = _lrp_layer_map(Epsilon(epsilon=0))
layer_map_lrp_zplus = _lrp_layer_map(ZPlus())
layer_map_lrp_eps = _lrp_layer_map(Epsilon(epsilon=1))

# LRP before DVCE generation
for i, batch in enumerate(dataloader):
    inputs = batch['img'].to(classifier.device)
    attr_znt_0 = [_lrp_attribution(inputs, layer_map_lrp_0)]
    attr_znt_eps = [_lrp_attribution(inputs, layer_map_lrp_eps)]
    attr_znt_zplus = [_lrp_attribution(inputs, layer_map_lrp_zplus)]
    # Only the LRP-0 attribution is rendered; the EPS and Z+ variants are
    # computed but their visualisation is currently disabled.
    attr_image_pre = show_attributions(directory_names[i], attr_znt_0, title='Pre LRP-0')
    #show_attributions(directory_names[i], attr_znt_eps, title='Pre LRP-EPS')
    #show_attributions(directory_names[i], attr_znt_zplus, title='Pre LRP-Z+')


# ------------------- DDIM forward diffusion -------------------
# adapted from https://huggingface.co/learn/diffusion-course/unit4/2
@torch.no_grad()
def invert(
    start_latents,
    num_inference_steps,
    device=device,
):
    """Run DDIM inversion on ``start_latents`` and return every intermediate latent.

    Uses the module-level ``scheduler`` and ``checkpointed_unet``. The function
    runs under ``torch.no_grad()``: the inversion is not differentiated, and
    without it every stored latent would keep an autograd graph through the
    UNet alive.

    Args:
        start_latents: latents to invert (cloned, not modified).
        num_inference_steps: number of scheduler timesteps to configure.
        device: device passed to ``scheduler.set_timesteps``.

    Returns:
        ``torch.cat`` of the latents produced at each executed step, in order
        of increasing timestep. The loop runs for ``i`` in
        ``1 .. num_inference_steps - 2``, so ``num_inference_steps - 2``
        latents are concatenated along dim 0; ``torch.cat`` raises if no step
        was executed.
    """
    latents = start_latents.clone()
    intermediate_latents = []
    scheduler.set_timesteps(num_inference_steps, device=device)
    # Walk the schedule in the opposite order to sampling.
    timesteps = reversed(scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # The final index is skipped.
        if i >= num_inference_steps - 1:
            continue
        t = timesteps[i]
        latent_model_input = latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = checkpointed_unet(latent_model_input, t)["sample"]

        # NOTE: the step width assumes a schedule of 1000 training timesteps.
        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = scheduler.alphas_cumprod[current_t]
        alpha_t_next = scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)

# ------------------- Latent Noise as parameters -------------------
class LatentNoise(torch.nn.Module):
    """Hold a noise tensor as a module parameter.

    Registering the tensor as an ``nn.Parameter`` lets it be handed to torch
    optimizers through ``parameters()``.
    """

    def __init__(self, noise: torch.Tensor) -> None:
        """Register ``noise`` as the parameter ``self.noise``."""
        super().__init__()
        self.register_parameter("noise", torch.nn.Parameter(noise))

    def forward(self) -> torch.Tensor:
        """Return the noise parameter."""
        return self.noise

# ------------------- Reverse diffusion step -------------------
def diffusion_pipe(noise_module: LatentNoise, num_inference_steps):
    """Denoise the latent held by ``noise_module`` from ``start_step`` onwards.

    Uses the module-level ``scheduler``, ``checkpointed_unet`` and
    ``start_step``, and expects ``scheduler.timesteps`` to be configured
    already. The UNet prediction is computed under ``torch.no_grad()``, so
    gradients reach the noise parameter only through the scheduler update.

    Args:
        noise_module: module whose call returns the starting latent.
        num_inference_steps: exclusive upper bound of the timestep indices.

    Returns:
        ``(z, z0)``: the latent after the last step and the scheduler's
        predicted original sample from that same step. ``z0`` is ``None`` when
        no step is executed (``start_step >= num_inference_steps``).
    """
    z = noise_module()
    z0 = None
    for i in range(start_step, num_inference_steps):
        t = scheduler.timesteps[i]
        z = scheduler.scale_model_input(z, t)
        with torch.no_grad():
            noise_pred = checkpointed_unet(z, t)["sample"]
        # Step once and read both outputs from the same result. Stepping a
        # second time on the already-updated latent would derive the
        # predicted original sample from the wrong input.
        step_output = scheduler.step(noise_pred, t, z)
        z = step_output.prev_sample
        z0 = step_output.pred_original_sample
    return z, z0

# ------------------- Visualization of image sequences -------------------
def plot_images(images, titles=None, figsize=(50, 5), save_path=None):
    """Draw ``images`` side by side in one row and optionally save them.

    Args:
        images: sequence of images offering ``.save(path)`` and accepted by
            ``imshow`` (PIL images in this file).
        titles: optional per-image titles, indexed like ``images``.
        figsize: accepted for backward compatibility but not used; the figure
            is always sized ``(5 * len(images), 5)``.
        save_path: when given, a directory of that name is created, each image
            is written to ``{save_path}/{i}.png`` and the row figure is saved
            via ``savefig(save_path)``. When ``None`` (the default) nothing is
            written to disk.
            NOTE(review): ``save_path`` has no extension and is also the
            directory name; matplotlib is expected to append its default
            format suffix to the figure file -- confirm.
    """
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(n * 5, 5))
    if n == 1:
        # With a single column, subplots returns one Axes, not an array.
        axes = [axes]

    # Only touch the file system when a target was requested; the default
    # save_path=None used to crash inside os.makedirs.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        for i, img in enumerate(images):
            img.save(f"{save_path}/{i}.png")

    for i, img in enumerate(images):
        axes[i].imshow(img)
        axes[i].axis('off')
        if titles is not None:
            axes[i].set_title(titles[i])

    if save_path:
        fig.savefig(save_path)
    # plt.show()
    plt.close(fig)

# ------------------- Visualization helper functions -------------------
def plot_to_pil(tensor):
    """Display the first element of a 4-D image tensor with matplotlib.

    The tensor is moved to the CPU, axis 1 is moved to the last position,
    values are clipped to [-1, 1] and mapped to [0, 1], then scaled to uint8.
    """
    channels_last = tensor.cpu().permute(0, 2, 3, 1)
    unit_range = channels_last.clip(-1, 1) * 0.5 + 0.5
    pixels = np.array(unit_range[0].detach().numpy() * 255).astype(np.uint8)
    picture = PIL.Image.fromarray(pixels)
    plt.imshow(picture)
    plt.axis('off')
    plt.show()

def tensor_to_pil_image(tensor):
    """Convert the first element of a 4-D image tensor into a PIL image.

    The tensor is moved to the CPU, axis 1 is moved to the last position,
    values are clipped to [-1, 1] and mapped to [0, 1], then scaled to uint8.
    """
    channels_last = tensor.cpu().permute(0, 2, 3, 1)
    unit_range = channels_last.clip(-1, 1) * 0.5 + 0.5
    pixels = np.array(unit_range[0].detach().numpy() * 255).astype(np.uint8)
    return PIL.Image.fromarray(pixels)


# ------------------- Main Loop -------------------
# setting the hyperparameters
num_inference_steps = 100       # scheduler timesteps used for inversion and sampling
num_optimization_steps = 20     # upper bound on optimizer updates per image
start_step = 20                 # index into scheduler.timesteps where sampling starts

for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    # plot_to_pil(batch['img'])
    with torch.no_grad():
        z = vqvae.encode(batch['img'].to(device))   # encode the image into latent space
    z = z.latents

    inverted_latents = invert(z, num_inference_steps)                       # do the ddim scheduler reversed to add noise to the latents
    z = inverted_latents[-(start_step + 1)].unsqueeze(0)                    # use these latents to start the sampling. better performance when using not the last latent sample
    # plot_to_pil(z)
    noise_module = LatentNoise(z.clone()).to(device)                        # convert latent noise to a parameter module for optimization
    noise_module.noise.requires_grad = True
    intermediate_results = [batch['img'].to(device)]                        # list to store the results of the steering
    # The starting prediction is only logged, so no autograd graph is needed.
    with torch.no_grad():
        intermediate_preds = [round(classifier(batch['img'].to(device))[0][cls_id].item(), 5)]

    optimizer = torch.optim.Adam(
        noise_module.parameters(), lr=0.01, maximize=False
    )
    learning_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    x = torch.zeros_like(z)
    current_loss = float('inf')
    preds_binary = False
    current_pred = 0.0
    i = 0
    # optional while loop for the user study to generate DVCEs with longer range
    # while (current_pred < 1.0) & (i < num_optimization_steps) :
    # Optimise until the classifier decision flips or the step budget is used up.
    while not preds_binary and i < num_optimization_steps:

        optimizer.zero_grad()
        x, x0 = diffusion_pipe(noise_module, num_inference_steps)   # forward the latent noise through the reverse diffusion process
        # plot_to_pil(x)
        decoded_x = vqvae.decode(x)[0]
        # plot_to_pil(decoded_x)
        current_pred = classifier(decoded_x)[0][cls_id]

        # Keep a detached copy for the visualisations so the stored frames do
        # not hold on to each iteration's autograd graph.
        intermediate_results.append(decoded_x.detach())
        intermediate_preds.append(round(current_pred.item(), 5))

        loss, preds_binary = classifier_loss(classifier, decoded_x, 1.0, i)
        l1_dist = minDist_loss(decoded_x, batch['img'].to(device))
        # optional other distance metrics: ssim and msssim distance
        # ssim_dist = minDist_loss(decoded_x, batch['img'].to(device))
        loss += l1_dist * 10
        # loss += ssim_dist * 20
        if i % 2 == 0:
            # get_last_lr() reports the rate in use; get_lr() is meant to be
            # called only from inside scheduler.step().
            print(i, "loss:", loss.item(), "lr:", learning_scheduler.get_last_lr(), "l1-dist", l1_dist.item())

        loss.backward()
        optimizer.step()
        learning_scheduler.step()

        current_loss = loss.item()
        i += 1

    # `x` is the latent produced in the last iteration, i.e. the one the
    # reported prediction and loss refer to.
    with torch.no_grad():
        image = vqvae.decode(x)[0]

    print(f"Diffusion Counterfactual generated with loss: {current_loss} | classifier_prediction: {current_pred} | l1_dist: {l1_dist} | in {i} optimization steps")

    # LRP after DVCE generation
    attr_znt_0 = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_image_post = show_attributions(directory_names[step], attr_znt_0, title='Post LRP-0')
    #show_attributions(directory_names[step], attr_znt_eps, title='Post LRP-EPS')
    #show_attributions(directory_names[step], attr_znt_zplus, title='Post LRP-Z+')
    # NOTE(review): presumably xai_zennit enables gradients on its input; the
    # flag is cleared so the tensor can be converted to NumPy below -- confirm.
    image.requires_grad = False

    # GIFs visualization
    images = [tensor_to_pil_image(tensor) for tensor in intermediate_results]
    gif_path = f"{directory_names[step]}/GIF.gif"
    imageio.mimsave(gif_path, images, format='GIF', duration=10.0, loop=0)  # duration is in seconds

    # image sequence visualization
    row_path = f"{directory_names[step]}/sequence_small"
    plot_images(images, intermediate_preds, save_path=row_path)

    # process original image and DVCE image
    image_processed = image.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image_pil = PIL.Image.fromarray(np.array(image_processed[0] * 255).astype(np.uint8))
    ori_processed = batch['img'].cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    ori_image = PIL.Image.fromarray(np.array(ori_processed[0] * 255).astype(np.uint8))

    # Side-by-side comparison of the counterfactual and the original image.
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(image_pil)
    axs[0].axis('off')
    axs[0].set_title('Diffusion Counterfactual Image')
    axs[1].imshow(ori_image)
    axs[1].axis('off')
    axs[1].set_title('Original Image')
    # plt.show()
    fig.savefig(f'{directory_names[step]}/ori_vs_DCE.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    image_pil.save(f"{directory_names[step]}/diffCounter_IMG.png")
    print('finish')

Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_test/folder_IMG_28285


The config attributes {'timestep_values': None, 'timesteps': 1000} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


tensor(-0.5018, device='cuda:0')
tensor([False])


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 99/99 [00:02<00:00, 38.39it/s]


Classifier prediction: -0.48073992133140564
0 loss: 1.940827488899231 lr: [0.01] l1-dist 0.046008750796318054
Diffusion Counterfactual generated with loss: 0.8353585004806519 | classifier_prediction: 0.6785747408866882 | l1_dist: 0.05139332264661789 | in 2 optimization steps


100%|██████████| 1/1 [00:07<00:00,  7.57s/it]

finish



