In [16]:
import PIL.Image
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio
import os

from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim, ms_ssim

from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, ZPlus, Pass, Norm

from data.dataset import ImageDataset, CelebHQAttrDataset
from init_classifier import LinearClassifier, VQVAEClassifier, ResNet50Classifier, ViTClassifier
from xai_lrp import xai_zennit, show_attributions


class CheckpointedUNetWrapper(torch.nn.Module):
    def __init__(self, model):
        super(CheckpointedUNetWrapper, self).__init__()
        self.model = model

    def checkpointed_forward(self, module, *inputs):
        def custom_forward(*inputs):
            return module(*inputs)
        return checkpoint(custom_forward, *inputs)

    def forward(self, sample, timestep):

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.model.time_proj(timesteps)
        #t_emb = t_emb.to(dtype=self.dtype)
        emb = self.model.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.model.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.model.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = self.checkpointed_forward(downsample_block, sample, emb, skip_sample)
            else:
                sample, res_samples = self.checkpointed_forward(downsample_block, sample, emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.checkpointed_forward(self.model.mid_block, sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.model.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb, skip_sample)
            else:
                sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb)

        # 6. post-process
        sample = self.model.conv_norm_out(sample)
        sample = self.model.conv_act(sample)
        sample = self.model.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        return {"sample": sample}

def classifier_loss(classifier, images, targets, idx):
    preds = classifier(images)
    if idx % 2 == 0:
        print(f"Classifier prediction: {preds[0][cls_id]}")
    targets = torch.tensor(targets).to(device)
    #error = torch.nn.functional.binary_cross_entropy_with_logits(preds[0][31], targets)
    error = torch.abs(preds[0][cls_id] - targets).mean()
    preds_binary = torch.sigmoid(preds[0][cls_id]) > 0.5

    return error, preds_binary

def minDist_loss(counterfactual_images, original_images):
    # l1 distance
    error = torch.abs(counterfactual_images - original_images).mean()

    # ssim distance
    # normalize images
    #original_images = (original_images + 1) / 2
    #counterfactual_images = (counterfactual_images + 1) / 2
    #error = 1 - ssim(original_images, counterfactual_images, data_range=1.0, size_average=True)
    return error


# data loading with ground truth no smiling
data = ImageDataset('/home/dai/GPU-Student-2/Cederic/DataSciPro/data/misclsData_gt0', image_size=256, exts=['jpg', 'JPG', 'png'], do_augment=False, sort_names=True)
dataloader = DataLoader(data, batch_size=1, shuffle=False)

# create output folders
directory_names = []
for i, _ in enumerate(dataloader):
    img_index = dataloader.dataset.paths[i].name.split('_')[0]
    directory_name = os.path.join("/home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear", f'folder_IMG_{img_index}')
    directory_names.append(directory_name)
    os.makedirs(directory_name, exist_ok=True)
    print(f'Created directory: {directory_name}')

#
device = "cuda" if torch.cuda.is_available() else "cpu"
cls_type = 'linear'
cls_id =  CelebHQAttrDataset.cls_to_id['Smiling']

# load all models
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")

unet.to(device)
vqvae.to(device)

checkpointed_unet = CheckpointedUNetWrapper(unet)

# load all models
if cls_type == 'linear':    
    classifier = LinearClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b128linear2024-06-02 13:08:28.ckpt",
                                            input_dim = data[0]['img'].shape,
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))
elif cls_type == 'vqvae':
    classifier = VQVAEClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b32vqvae2024-06-01 08:48:59.ckpt",
                                       num_classes = len(CelebHQAttrDataset.id_to_cls))

elif cls_type == 'res50':
    classifier = ResNet50Classifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b64res502024-06-02 17:06:41.ckpt",
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))


classifier.to(device)
classifier.eval()
# check functionality of classifier
all_outputs = []
with torch.no_grad():
    for batch in dataloader:
        inputs = batch['img'].to(classifier.device)
        outputs = classifier(inputs)
        print(outputs[0][cls_id])

        preds_binary = torch.sigmoid(outputs[:, cls_id].cpu()) > 0.5
        all_outputs.append(preds_binary) 
all_outputs = torch.cat(all_outputs, dim=0)
print(all_outputs)


###### explainable ai lrp
# lrp rules
layer_map_lrp_0 = [
    (torch.nn.ReLU, Pass()),  # ignore activations
    (torch.nn.Linear, Epsilon(epsilon=0)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_zplus = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, ZPlus()),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_eps = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, Epsilon(epsilon=1)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

#before manipulation
for i, batch in enumerate(dataloader):
    inputs = batch['img'].to(classifier.device)
    attr_znt_0 = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[i], attr_znt_0, title='Pre LRP-0')
    #show_attributions(directory_names[i], attr_znt_eps, title='Pre LRP-EPS')
    #show_attributions(directory_names[i], attr_znt_zplus, title='Pre LRP-Z+')
        
## Inversion
def invert(
    start_latents,
    num_inference_steps,
    device=device,
):

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = checkpointed_unet(latent_model_input, t)["sample"]

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = scheduler.alphas_cumprod[current_t]
        alpha_t_next = scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


class LatentNoise(torch.nn.Module):
    """
    The LatentNoise Module makes it easier to update the noise tensor with torch optimizers.
    """

    def __init__(self, noise: torch.Tensor):
        super().__init__()
        self.noise = torch.nn.Parameter(noise)

    def forward(self):
        return self.noise


def diffusion_pipe(noise_module: LatentNoise, num_inference_steps):
        z = noise_module()
        for i in range(start_step, num_inference_steps):
            t = scheduler.timesteps[i]
            z = scheduler.scale_model_input(z, t)
            with torch.no_grad():
                noise_pred = checkpointed_unet(z, t)["sample"]
            z = scheduler.step(noise_pred, t, z).prev_sample
            z0 = scheduler.step(noise_pred, t, z).pred_original_sample
        return z, z0

def plot_images(images, titles=None, figsize=(50, 5), save_path=None):
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(n*5, 5))

    if n == 1:
        axes = [axes]

    #just for image sving
    os.makedirs(save_path, exist_ok=True)

    for i, img in enumerate(images):
        img.save(f"{save_path}/{i}.png")

    #for i, img in enumerate(images):
    #    axes[i].imshow(img)
    #    axes[i].axis('off')
    #    #if titles is not None:
    #    #    axes[i].set_title(titles[i])
#
    #if save_path:
    #    plt.savefig(save_path)
    #plt.show()
    plt.close(fig)

def plot_to_pil(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

def tensor_to_pil_image(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    return image

# conditional sampling
num_inference_steps = 100
start_step = 20


for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    #plot_to_pil(batch['img'])
    with torch.no_grad():
        z = vqvae.encode(batch['img'].to(device))   # encode the image in the latent space
    z = z.latents
    #plot_to_pil(z)
    
    #cond = z.view(1,-1)
    #cond = normalize(cond)
    #cond = cond + 0.5 * math.sqrt(512) * classifier.fc1.weight[31].unsqueeze(0)
    #cond = denormalize(cond)
    #z = cond.view(1,3,64,64)
    #dec_z = vqvae.decode(z)[0]
    #plot_to_pil(dec_z)
    
    inverted_latents = invert(z, num_inference_steps)                  # do the ddim scheduler reversed to add noise to the latents
    z = inverted_latents[-(start_step + 1)].unsqueeze(0)                  # use these latents to start the sampling. better performance when using not the last latent sample
    #plot_to_pil(z)
    noise_module = LatentNoise(z.clone()).to(device)                    # convert latent noise to a parameter module for optimization
    noise_module.noise.requires_grad = True
    intermediate_results = [batch['img'].to(device)]   # list to store the results of the steering
    intermediate_preds = [round(classifier(batch['img'].to(device))[0][cls_id].item(), 5)]
    
    optimizer = torch.optim.Adam(
        noise_module.parameters(), lr=0.01, maximize=False # not minimize gradient ascent
    )
    learning_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    
    x = torch.zeros_like(z)
    current_loss = float('inf')
    preds_binary = False
    current_pred = 0.0
    i = 0
    # for the linear classifier it works perfect to break out of the loop if the prediction switches.
    #while (current_pred < 1.0) & (i < 20) :
    while (preds_binary == False) & (i < 20) :
            
            optimizer.zero_grad()
            x, x0 = diffusion_pipe(noise_module, num_inference_steps) # forward
            #plot_to_pil(x)
            decoded_x = vqvae.decode(x)[0]
            #plot_to_pil(decoded_x)
            current_pred = classifier(decoded_x)[0][cls_id]

            if i % 1 == 0:
                intermediate_results.append(decoded_x)
                intermediate_preds.append(round(current_pred.item(), 5))

            loss, preds_binary = classifier_loss(classifier, decoded_x, 2.0, i)
            l1_dist = minDist_loss(decoded_x, batch['img'].to(device))
            #implementing the ssim and msssim distance
            #ssim_dist = minDist_loss(decoded_x, batch['img'].to(device))

            loss += l1_dist * 10
            #loss += ssim_dist * 20
            
            if i % 2 == 0:
                print(i, "loss:", loss.item(), "lr:", learning_scheduler.get_lr(), "l1-dist", l1_dist.item())
                #print(i, "loss:", loss.item())
            loss.backward()
            optimizer.step()
            learning_scheduler.step()

            current_loss = loss.item()
            i += 1
    
    with torch.no_grad():
        image = vqvae.decode(x)[0]

    print(f"Diffusion Counterfactual generated with loss: {current_loss} | classifier_prediction: {current_pred} | l1_dist: {l1_dist} | in {i} optimization steps")
    
    #lrp after manipulation
    attr_znt_0 = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[step], attr_znt_0, title='Post LRP-0')
    #show_attributions(directory_names[step], attr_znt_eps, title='Post LRP-EPS')
    #show_attributions(directory_names[step], attr_znt_zplus, title='Post LRP-Z+')
    image.requires_grad = False
    
    images = [tensor_to_pil_image(tensor) for tensor in intermediate_results]
    gif_path = f"{directory_names[step]}/GIF.gif"
    imageio.mimsave(gif_path, images, format='GIF', duration=10.0, loop=0)  # duration is in seconds

    row_path = f"{directory_names[step]}/sequence_small"
    plot_images(images, intermediate_preds, save_path=row_path)

    # process image
    image_processed = image.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image_pil = PIL.Image.fromarray(np.array(image_processed[0] * 255).astype(np.uint8))
    ori_processed = batch['img'].cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    ori_image = PIL.Image.fromarray(np.array(ori_processed[0] * 255).astype(np.uint8))

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(image_pil)
    axs[0].axis('off')
    axs[0].set_title('Diffusion Counterfactual Image')
    axs[1].imshow(ori_image)
    axs[1].axis('off')
    axs[1].set_title('Original Image')
    #plt.show()
    fig.savefig(f'{directory_names[step]}/ori_vs_DCE.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    image_pil.save(f"{directory_names[step]}/diffCounter_IMG.png")
    print('finish')

Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_27152
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_27291
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28459
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28510
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28561
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28641
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28749
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28754
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28875
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_line

The config attributes {'timestep_values': None, 'timesteps': 1000} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


tensor(-3.0893, device='cuda:0')
tensor(-2.2242, device='cuda:0')
tensor(-1.4349, device='cuda:0')
tensor(-3.0878, device='cuda:0')
tensor(-0.5016, device='cuda:0')
tensor(-0.9032, device='cuda:0')
tensor(-1.8061, device='cuda:0')
tensor(-1.5006, device='cuda:0')
tensor(-1.5091, device='cuda:0')
tensor(0.5627, device='cuda:0')
tensor(-3.0440, device='cuda:0')
tensor(-1.7235, device='cuda:0')
tensor(-2.6975, device='cuda:0')
tensor(-1.1172, device='cuda:0')
tensor(-0.9103, device='cuda:0')
tensor(0.8591, device='cuda:0')
tensor(-1.3892, device='cuda:0')
tensor(-0.6905, device='cuda:0')
tensor(-1.6558, device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False,  True,
        False, False, False, False, False,  True, False, False, False])


100%|██████████| 99/99 [00:02<00:00, 38.68it/s]


Classifier prediction: -3.2522175312042236
0 loss: 5.699406623840332 lr: [0.01] l1-dist 0.0447189137339592
Classifier prediction: -2.5887632369995117
2 loss: 5.0674028396606445 lr: [0.009931100837462445] l1-dist 0.047863952815532684
Classifier prediction: -2.0944790840148926
4 loss: 4.670379638671875 lr: [0.009774869058090914] l1-dist 0.0575900636613369
Classifier prediction: -1.5859897136688232
6 loss: 4.23384952545166 lr: [0.009543642776065642] l1-dist 0.06478600203990936
Classifier prediction: -1.0579445362091064
8 loss: 3.7961385250091553 lr: [0.009241066670644704] l1-dist 0.07381940633058548
Classifier prediction: -0.43256035447120667
10 loss: 3.2947678565979004 lr: [0.008871910576983217] l1-dist 0.08622074127197266
Classifier prediction: 0.36656710505485535
12 loss: 2.682431697845459 lr: [0.008441994226264132] l1-dist 0.10489988327026367
Diffusion Counterfactual generated with loss: 2.682431697845459 | classifier_prediction: 0.36656710505485535 | l1_dist: 0.10489988327026367 | in

  5%|▌         | 1/19 [00:29<08:53, 29.65s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.61it/s]


Classifier prediction: -2.318789482116699
0 loss: 4.804568290710449 lr: [0.01] l1-dist 0.04857790097594261
Classifier prediction: -1.6148781776428223
2 loss: 4.227785110473633 lr: [0.009931100837462445] l1-dist 0.061290718615055084
Classifier prediction: -1.0618886947631836
4 loss: 3.7591583728790283 lr: [0.009774869058090914] l1-dist 0.06972696632146835
Classifier prediction: -0.47366341948509216
6 loss: 3.306386947631836 lr: [0.009543642776065642] l1-dist 0.08327235281467438
Classifier prediction: 0.17199984192848206
8 loss: 2.8316407203674316 lr: [0.009241066670644704] l1-dist 0.10036405920982361
Diffusion Counterfactual generated with loss: 2.8316407203674316 | classifier_prediction: 0.17199984192848206 | l1_dist: 0.10036405920982361 | in 9 optimization steps


 11%|█         | 2/19 [00:50<07:00, 24.71s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.01it/s]


Classifier prediction: -1.474960207939148
0 loss: 4.118856906890869 lr: [0.01] l1-dist 0.06438963860273361
Classifier prediction: -0.9383646845817566
2 loss: 3.5223708152770996 lr: [0.009931100837462445] l1-dist 0.05840061604976654
Classifier prediction: -0.40189120173454285
4 loss: 3.0802111625671387 lr: [0.009774869058090914] l1-dist 0.06783197820186615
Classifier prediction: 0.04470425099134445
6 loss: 2.707106590270996 lr: [0.009543642776065642] l1-dist 0.07518108934164047
Diffusion Counterfactual generated with loss: 2.707106590270996 | classifier_prediction: 0.04470425099134445 | l1_dist: 0.07518108934164047 | in 7 optimization steps


 16%|█▌        | 3/19 [01:08<05:40, 21.27s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 35.89it/s]


Classifier prediction: -3.186439275741577
0 loss: 5.549551486968994 lr: [0.01] l1-dist 0.03631121292710304
Classifier prediction: -2.553800582885742
2 loss: 4.9681196212768555 lr: [0.009931100837462445] l1-dist 0.04143190383911133
Classifier prediction: -1.9538724422454834
4 loss: 4.52001953125 lr: [0.009774869058090914] l1-dist 0.05661468952894211
Classifier prediction: -1.4631935358047485
6 loss: 4.133827209472656 lr: [0.009543642776065642] l1-dist 0.06706339865922928
Classifier prediction: -0.7307223677635193
8 loss: 3.490025520324707 lr: [0.009241066670644704] l1-dist 0.07593031972646713
Classifier prediction: 0.03512641042470932
10 loss: 2.811046838760376 lr: [0.008871910576983217] l1-dist 0.08461733162403107
Diffusion Counterfactual generated with loss: 2.811046838760376 | classifier_prediction: 0.03512641042470932 | l1_dist: 0.08461733162403107 | in 11 optimization steps


 21%|██        | 4/19 [01:33<05:45, 23.06s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.85it/s]


Classifier prediction: -0.4957011044025421
0 loss: 2.8765904903411865 lr: [0.01] l1-dist 0.03808894008398056
Classifier prediction: 0.1650649607181549
2 loss: 2.517275333404541 lr: [0.009931100837462445] l1-dist 0.06823403388261795
Diffusion Counterfactual generated with loss: 2.517275333404541 | classifier_prediction: 0.1650649607181549 | l1_dist: 0.06823403388261795 | in 3 optimization steps


 26%|██▋       | 5/19 [01:43<04:12, 18.05s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.28it/s]


Classifier prediction: -0.8376410007476807
0 loss: 3.197890281677246 lr: [0.01] l1-dist 0.03602492809295654
Classifier prediction: 0.08676400035619736
2 loss: 2.4435505867004395 lr: [0.009931100837462445] l1-dist 0.05303145572543144
Diffusion Counterfactual generated with loss: 2.4435505867004395 | classifier_prediction: 0.08676400035619736 | l1_dist: 0.05303145572543144 | in 3 optimization steps


 32%|███▏      | 6/19 [01:52<03:14, 14.97s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.41it/s]


Classifier prediction: -1.9193310737609863
0 loss: 4.337862491607666 lr: [0.01] l1-dist 0.04185314103960991
Classifier prediction: -1.170490026473999
2 loss: 3.656827926635742 lr: [0.009931100837462445] l1-dist 0.04863378405570984
Classifier prediction: -0.5744560360908508
4 loss: 3.4269518852233887 lr: [0.009774869058090914] l1-dist 0.08524959534406662
Classifier prediction: 0.06149240583181381
6 loss: 2.6546671390533447 lr: [0.009543642776065642] l1-dist 0.07161596417427063
Diffusion Counterfactual generated with loss: 2.6546671390533447 | classifier_prediction: 0.06149240583181381 | l1_dist: 0.07161596417427063 | in 7 optimization steps


 37%|███▋      | 7/19 [02:08<03:06, 15.54s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.39it/s]


Classifier prediction: -1.6750441789627075
0 loss: 4.079682350158691 lr: [0.01] l1-dist 0.04046384245157242
Classifier prediction: -1.4980902671813965
2 loss: 4.777265548706055 lr: [0.009931100837462445] l1-dist 0.1279175579547882
Diffusion Counterfactual generated with loss: 3.1036183834075928 | classifier_prediction: 0.567772388458252 | l1_dist: 0.16713908314704895 | in 4 optimization steps


 42%|████▏     | 8/19 [02:19<02:34, 14.03s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 37.45it/s]


Classifier prediction: -1.632059097290039
0 loss: 4.041840076446533 lr: [0.01] l1-dist 0.0409780852496624
Classifier prediction: -0.6064081192016602
2 loss: 3.2814807891845703 lr: [0.009931100837462445] l1-dist 0.06750728189945221
Classifier prediction: 0.11454679816961288
4 loss: 2.635000705718994 lr: [0.009774869058090914] l1-dist 0.07495476305484772
Diffusion Counterfactual generated with loss: 2.635000705718994 | classifier_prediction: 0.11454679816961288 | l1_dist: 0.07495476305484772 | in 5 optimization steps


 47%|████▋     | 9/19 [02:32<02:17, 13.73s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.34it/s]


Classifier prediction: 0.5189474821090698
0 loss: 1.802379846572876 lr: [0.01] l1-dist 0.03213272988796234
Diffusion Counterfactual generated with loss: 1.802379846572876 | classifier_prediction: 0.5189474821090698 | l1_dist: 0.03213272988796234 | in 1 optimization steps


 53%|█████▎    | 10/19 [02:37<01:38, 11.00s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.46it/s]


Classifier prediction: -3.263070583343506
0 loss: 5.676961898803711 lr: [0.01] l1-dist 0.04138915613293648
Classifier prediction: -2.7082138061523438
2 loss: 5.157123565673828 lr: [0.009931100837462445] l1-dist 0.04489099979400635
Classifier prediction: -2.270517587661743
4 loss: 4.784476280212402 lr: [0.009774869058090914] l1-dist 0.05139588564634323
Classifier prediction: -1.84171462059021
6 loss: 4.428225994110107 lr: [0.009543642776065642] l1-dist 0.05865113437175751
Classifier prediction: -1.4122196435928345
8 loss: 4.080976486206055 lr: [0.009241066670644704] l1-dist 0.06687569618225098
Classifier prediction: -1.0332292318344116
10 loss: 3.7843105792999268 lr: [0.008871910576983217] l1-dist 0.0751081183552742
Classifier prediction: 3.389502763748169
12 loss: 2.959108829498291 lr: [0.008441994226264132] l1-dist 0.156960591673851
Diffusion Counterfactual generated with loss: 2.959108829498291 | classifier_prediction: 3.389502763748169 | l1_dist: 0.156960591673851 | in 13 optimizati

 58%|█████▊    | 11/19 [03:05<02:10, 16.34s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.51it/s]


Classifier prediction: -1.825016736984253
0 loss: 4.185042858123779 lr: [0.01] l1-dist 0.036002591252326965
Classifier prediction: -1.3805158138275146
2 loss: 3.9582600593566895 lr: [0.009931100837462445] l1-dist 0.057774417102336884
Classifier prediction: -1.067775011062622
4 loss: 3.786181926727295 lr: [0.009774869058090914] l1-dist 0.07184068858623505
Classifier prediction: -0.6110727787017822
6 loss: 3.450578212738037 lr: [0.009543642776065642] l1-dist 0.08395053446292877
Classifier prediction: -0.04625455290079117
8 loss: 2.9322566986083984 lr: [0.009241066670644704] l1-dist 0.08860021829605103
Diffusion Counterfactual generated with loss: 2.7515997886657715 | classifier_prediction: 0.1602388322353363 | l1_dist: 0.0911838635802269 | in 10 optimization steps


 63%|██████▎   | 12/19 [03:28<02:07, 18.24s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.88it/s]


Classifier prediction: -2.889946460723877
0 loss: 5.2585601806640625 lr: [0.01] l1-dist 0.036861352622509
Classifier prediction: -1.5254398584365845
2 loss: 4.340208530426025 lr: [0.009931100837462445] l1-dist 0.08147688955068588
Classifier prediction: -0.14848747849464417
4 loss: 3.214406728744507 lr: [0.009774869058090914] l1-dist 0.10659191012382507
Diffusion Counterfactual generated with loss: 2.835378646850586 | classifier_prediction: 0.30666348338127136 | l1_dist: 0.11420422792434692 | in 6 optimization steps


 68%|██████▊   | 13/19 [03:43<01:43, 17.25s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.15it/s]


Classifier prediction: -1.1418135166168213
0 loss: 3.5043153762817383 lr: [0.01] l1-dist 0.03625018522143364
Classifier prediction: -0.5310899615287781
2 loss: 3.1307365894317627 lr: [0.009931100837462445] l1-dist 0.059964656829833984
Classifier prediction: -0.027157403528690338
4 loss: 2.724557638168335 lr: [0.009774869058090914] l1-dist 0.06974003463983536
Diffusion Counterfactual generated with loss: 2.527944564819336 | classifier_prediction: 0.20703992247581482 | l1_dist: 0.0734984427690506 | in 6 optimization steps


 74%|███████▎  | 14/19 [03:58<01:23, 16.68s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.87it/s]


Classifier prediction: -0.643534779548645
0 loss: 3.346733570098877 lr: [0.01] l1-dist 0.07031987607479095
Diffusion Counterfactual generated with loss: 1.7437050342559814 | classifier_prediction: 1.39066481590271 | l1_dist: 0.1134369820356369 | in 2 optimization steps


 79%|███████▉  | 15/19 [04:05<00:54, 13.75s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.66it/s]


Classifier prediction: 0.886259138584137
0 loss: 1.829973816871643 lr: [0.01] l1-dist 0.0716232880949974
Diffusion Counterfactual generated with loss: 1.829973816871643 | classifier_prediction: 0.886259138584137 | l1_dist: 0.0716232880949974 | in 1 optimization steps


 84%|████████▍ | 16/19 [04:10<00:33, 11.10s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.97it/s]


Classifier prediction: -1.447609782218933
0 loss: 3.9278740882873535 lr: [0.01] l1-dist 0.048026423901319504
Classifier prediction: -0.045012153685092926
2 loss: 2.6891040802001953 lr: [0.009931100837462445] l1-dist 0.06440918147563934
Diffusion Counterfactual generated with loss: 2.326993942260742 | classifier_prediction: 0.4459528625011444 | l1_dist: 0.07729469239711761 | in 4 optimization steps


 89%|████████▉ | 17/19 [04:21<00:22, 11.02s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.36it/s]


Classifier prediction: -0.7701396346092224
0 loss: 3.167670726776123 lr: [0.01] l1-dist 0.03975310176610947
Classifier prediction: -0.21880844235420227
2 loss: 2.6441030502319336 lr: [0.009931100837462445] l1-dist 0.042529474943876266
Classifier prediction: 0.19607123732566833
4 loss: 2.4090850353240967 lr: [0.009774869058090914] l1-dist 0.0605156235396862
Diffusion Counterfactual generated with loss: 2.4090850353240967 | classifier_prediction: 0.19607123732566833 | l1_dist: 0.0605156235396862 | in 5 optimization steps


 95%|█████████▍| 18/19 [04:34<00:11, 11.57s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.59it/s]


Classifier prediction: -1.8123350143432617
0 loss: 4.136782646179199 lr: [0.01] l1-dist 0.032444775104522705
Classifier prediction: -0.7637918591499329
2 loss: 3.321521759033203 lr: [0.009931100837462445] l1-dist 0.05577300116419792
Classifier prediction: -0.052337922155857086
4 loss: 2.9469642639160156 lr: [0.009774869058090914] l1-dist 0.08946263790130615
Diffusion Counterfactual generated with loss: 2.7157416343688965 | classifier_prediction: 0.3060149848461151 | l1_dist: 0.10217565298080444 | in 6 optimization steps


100%|██████████| 19/19 [04:49<00:00, 15.23s/it]

finish



