In [2]:
import PIL.Image
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio
import os

from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.optim.lr_scheduler as lr_scheduler
from pytorch_msssim import ssim, ms_ssim

from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, ZPlus, Pass, Norm

from data.dataset import ImageDataset, CelebHQAttrDataset
from classifier.train_classifier import LinearClassifier, ResNet50Classifier
from util.xai_lrp import xai_zennit, show_attributions


class CheckpointedUNetWrapper(torch.nn.Module):
    def __init__(self, model):
        super(CheckpointedUNetWrapper, self).__init__()
        self.model = model

    def checkpointed_forward(self, module, *inputs):
        def custom_forward(*inputs):
            return module(*inputs)
        return checkpoint(custom_forward, *inputs)

    def forward(self, sample, timestep):

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.model.time_proj(timesteps)
        #t_emb = t_emb.to(dtype=self.dtype)
        emb = self.model.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.model.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.model.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = self.checkpointed_forward(downsample_block, sample, emb, skip_sample)
            else:
                sample, res_samples = self.checkpointed_forward(downsample_block, sample, emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.checkpointed_forward(self.model.mid_block, sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.model.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb, skip_sample)
            else:
                sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb)

        # 6. post-process
        sample = self.model.conv_norm_out(sample)
        sample = self.model.conv_act(sample)
        sample = self.model.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        return {"sample": sample}

def classifier_loss(classifier, images, targets, idx):
    preds = classifier(images)
    if idx % 2 == 0:
        print(f"Classifier prediction: {preds[0][cls_id]}")
    targets = torch.tensor(targets).to(device)
    #error = torch.nn.functional.binary_cross_entropy_with_logits(preds[0][31], targets)
    error = torch.abs(preds[0][cls_id] - targets).mean()
    preds_binary = torch.sigmoid(preds[0][cls_id]) > 0.5

    return error, preds_binary

def minDist_loss(counterfactual_images, original_images):
    # l1 distance
    error = torch.abs(counterfactual_images - original_images).mean()

    # ssim distance
    # normalize images
    #original_images = (original_images + 1) / 2
    #counterfactual_images = (counterfactual_images + 1) / 2
    #error = 1 - ssim(original_images, counterfactual_images, data_range=1.0, size_average=True)
    return error


# data loading with ground truth no smiling
data = ImageDataset('/home/dai/GPU-Student-2/Cederic/DataSciPro/data/misclsData_gt0', image_size=256, exts=['jpg', 'JPG', 'png'], do_augment=False, sort_names=True)
dataloader = DataLoader(data, batch_size=1, shuffle=False)

# create output folders
directory_names = []
for i, _ in enumerate(dataloader):
    img_index = dataloader.dataset.paths[i].name.split('_')[0]
    directory_name = os.path.join("/home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear", f'folder_IMG_{img_index}')
    directory_names.append(directory_name)
    os.makedirs(directory_name, exist_ok=True)
    print(f'Created directory: {directory_name}')

#
device = "cuda" if torch.cuda.is_available() else "cpu"
cls_type = 'linear'
cls_id =  CelebHQAttrDataset.cls_to_id['Smiling']

# load all models
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")

unet.to(device)
vqvae.to(device)

checkpointed_unet = CheckpointedUNetWrapper(unet)

# load all models
if cls_type == 'linear':    
    classifier = LinearClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b128linear2024-06-02 13:08:28.ckpt",
                                            input_dim = data[0]['img'].shape,
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))
elif cls_type == 'res50':
    classifier = ResNet50Classifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b64res502024-06-02 17:06:41.ckpt",
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))


classifier.to(device)
classifier.eval()
# check functionality of classifier
all_outputs = []
with torch.no_grad():
    for batch in dataloader:
        inputs = batch['img'].to(classifier.device)
        outputs = classifier(inputs)
        print(outputs[0][cls_id])

        preds_binary = torch.sigmoid(outputs[:, cls_id].cpu()) > 0.5
        all_outputs.append(preds_binary) 
all_outputs = torch.cat(all_outputs, dim=0)
print(all_outputs)


###### explainable ai lrp
# lrp rules
layer_map_lrp_0 = [
    (torch.nn.ReLU, Pass()),  # ignore activations
    (torch.nn.Linear, Epsilon(epsilon=0)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_zplus = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, ZPlus()),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_eps = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, Epsilon(epsilon=1)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

#before manipulation
for i, batch in enumerate(dataloader):
    inputs = batch['img'].to(classifier.device)
    attr_znt_0 = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[i], attr_znt_0, title='Pre LRP-0')
    #show_attributions(directory_names[i], attr_znt_eps, title='Pre LRP-EPS')
    #show_attributions(directory_names[i], attr_znt_zplus, title='Pre LRP-Z+')
        
## Inversion
def invert(
    start_latents,
    num_inference_steps,
    device=device,
):

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = checkpointed_unet(latent_model_input, t)["sample"]

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = scheduler.alphas_cumprod[current_t]
        alpha_t_next = scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


class LatentNoise(torch.nn.Module):
    """
    The LatentNoise Module makes it easier to update the noise tensor with torch optimizers.
    """

    def __init__(self, noise: torch.Tensor):
        super().__init__()
        self.noise = torch.nn.Parameter(noise)

    def forward(self):
        return self.noise


def diffusion_pipe(noise_module: LatentNoise, num_inference_steps):
        z = noise_module()
        for i in range(start_step, num_inference_steps):
            t = scheduler.timesteps[i]
            z = scheduler.scale_model_input(z, t)
            with torch.no_grad():
                noise_pred = checkpointed_unet(z, t)["sample"]
            z = scheduler.step(noise_pred, t, z).prev_sample
            z0 = scheduler.step(noise_pred, t, z).pred_original_sample
        return z, z0

def plot_images(images, titles=None, figsize=(50, 5), save_path=None):
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(n*5, 5))

    if n == 1:
        axes = [axes]

    #just for image sving
    os.makedirs(save_path, exist_ok=True)

    for i, img in enumerate(images):
        img.save(f"{save_path}/{i}.png")

    #for i, img in enumerate(images):
    #    axes[i].imshow(img)
    #    axes[i].axis('off')
    #    #if titles is not None:
    #    #    axes[i].set_title(titles[i])
#
    #if save_path:
    #    plt.savefig(save_path)
    #plt.show()
    plt.close(fig)

def plot_to_pil(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

def tensor_to_pil_image(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    return image

# conditional sampling
num_inference_steps = 100
start_step = 20


for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    #plot_to_pil(batch['img'])
    with torch.no_grad():
        z = vqvae.encode(batch['img'].to(device))   # encode the image in the latent space
    z = z.latents
    #plot_to_pil(z)
    
    #cond = z.view(1,-1)
    #cond = normalize(cond)
    #cond = cond + 0.5 * math.sqrt(512) * classifier.fc1.weight[31].unsqueeze(0)
    #cond = denormalize(cond)
    #z = cond.view(1,3,64,64)
    #dec_z = vqvae.decode(z)[0]
    #plot_to_pil(dec_z)
    
    inverted_latents = invert(z, num_inference_steps)                  # do the ddim scheduler reversed to add noise to the latents
    z = inverted_latents[-(start_step + 1)].unsqueeze(0)                  # use these latents to start the sampling. better performance when using not the last latent sample
    #plot_to_pil(z)
    noise_module = LatentNoise(z.clone()).to(device)                    # convert latent noise to a parameter module for optimization
    noise_module.noise.requires_grad = True
    intermediate_results = [batch['img'].to(device)]   # list to store the results of the steering
    intermediate_preds = [round(classifier(batch['img'].to(device))[0][cls_id].item(), 5)]
    
    optimizer = torch.optim.Adam(
        noise_module.parameters(), lr=0.01, maximize=False # not minimize gradient ascent
    )
    learning_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    
    x = torch.zeros_like(z)
    current_loss = float('inf')
    preds_binary = False
    current_pred = 0.0
    i = 0
    # for the linear classifier it works perfect to break out of the loop if the prediction switches.
    #while (current_pred < 1.0) & (i < 20) :
    while (preds_binary == False) & (i < 20) :
            
            optimizer.zero_grad()
            x, x0 = diffusion_pipe(noise_module, num_inference_steps) # forward
            #plot_to_pil(x)
            decoded_x = vqvae.decode(x)[0]
            #plot_to_pil(decoded_x)
            current_pred = classifier(decoded_x)[0][cls_id]

            if i % 1 == 0:
                intermediate_results.append(decoded_x)
                intermediate_preds.append(round(current_pred.item(), 5))

            loss, preds_binary = classifier_loss(classifier, decoded_x, 2.0, i)
            l1_dist = minDist_loss(decoded_x, batch['img'].to(device))
            #implementing the ssim and msssim distance
            #ssim_dist = minDist_loss(decoded_x, batch['img'].to(device))

            loss += l1_dist * 10
            #loss += ssim_dist * 20
            
            if i % 2 == 0:
                print(i, "loss:", loss.item(), "lr:", learning_scheduler.get_lr(), "l1-dist", l1_dist.item())
                #print(i, "loss:", loss.item())
            loss.backward()
            optimizer.step()
            learning_scheduler.step()

            current_loss = loss.item()
            i += 1
    
    with torch.no_grad():
        image = vqvae.decode(x)[0]

    print(f"Diffusion Counterfactual generated with loss: {current_loss} | classifier_prediction: {current_pred} | l1_dist: {l1_dist} | in {i} optimization steps")
    
    #lrp after manipulation
    attr_znt_0 = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[step], attr_znt_0, title='Post LRP-0')
    #show_attributions(directory_names[step], attr_znt_eps, title='Post LRP-EPS')
    #show_attributions(directory_names[step], attr_znt_zplus, title='Post LRP-Z+')
    image.requires_grad = False
    
    images = [tensor_to_pil_image(tensor) for tensor in intermediate_results]
    gif_path = f"{directory_names[step]}/GIF.gif"
    imageio.mimsave(gif_path, images, format='GIF', duration=10.0, loop=0)  # duration is in seconds

    row_path = f"{directory_names[step]}/sequence_small"
    plot_images(images, intermediate_preds, save_path=row_path)

    # process image
    image_processed = image.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image_pil = PIL.Image.fromarray(np.array(image_processed[0] * 255).astype(np.uint8))
    ori_processed = batch['img'].cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    ori_image = PIL.Image.fromarray(np.array(ori_processed[0] * 255).astype(np.uint8))

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(image_pil)
    axs[0].axis('off')
    axs[0].set_title('Diffusion Counterfactual Image')
    axs[1].imshow(ori_image)
    axs[1].axis('off')
    axs[1].set_title('Original Image')
    #plt.show()
    fig.savefig(f'{directory_names[step]}/ori_vs_DCE.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    image_pil.save(f"{directory_names[step]}/diffCounter_IMG.png")
    print('finish')

Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_27152
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_27291
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28459
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28510
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28561
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28641
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28749
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28754
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_linear/folder_IMG_28875
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_bad_line

The config attributes {'timestep_values': None, 'timesteps': 1000} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


tensor(-3.0893, device='cuda:0')
tensor(-2.2242, device='cuda:0')
tensor(-1.4349, device='cuda:0')
tensor(-3.0878, device='cuda:0')
tensor(-0.5016, device='cuda:0')
tensor(-0.9032, device='cuda:0')
tensor(-1.8061, device='cuda:0')
tensor(-1.5006, device='cuda:0')
tensor(-1.5091, device='cuda:0')
tensor(-3.0440, device='cuda:0')
tensor(-1.7235, device='cuda:0')
tensor(-2.6975, device='cuda:0')
tensor(-1.1172, device='cuda:0')
tensor(-0.9103, device='cuda:0')
tensor(-1.3892, device='cuda:0')
tensor(-0.6905, device='cuda:0')
tensor(-1.6558, device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False])


100%|██████████| 99/99 [00:02<00:00, 41.40it/s]


Classifier prediction: -3.2522175312042236
0 loss: 5.699406623840332 lr: [0.01] l1-dist 0.0447189137339592
Classifier prediction: -2.5887632369995117
2 loss: 5.0674028396606445 lr: [0.009931100837462445] l1-dist 0.047863952815532684
Classifier prediction: -2.0948095321655273
4 loss: 4.670788764953613 lr: [0.009774869058090914] l1-dist 0.057597946375608444
Classifier prediction: -1.5820859670639038
6 loss: 4.230411052703857 lr: [0.009543642776065642] l1-dist 0.06483250856399536
Classifier prediction: -1.0531268119812012
8 loss: 3.791684150695801 lr: [0.009241066670644704] l1-dist 0.07385572046041489
Classifier prediction: -0.42020437121391296
10 loss: 3.2786717414855957 lr: [0.008871910576983217] l1-dist 0.08584672212600708
Classifier prediction: 0.3648047149181366
12 loss: 2.6706604957580566 lr: [0.008441994226264132] l1-dist 0.10354652255773544
Diffusion Counterfactual generated with loss: 2.6706604957580566 | classifier_prediction: 0.3648047149181366 | l1_dist: 0.10354652255773544 | 

  6%|▌         | 1/17 [00:28<07:41, 28.86s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.92it/s]


Classifier prediction: -2.318789482116699
0 loss: 4.804568290710449 lr: [0.01] l1-dist 0.04857790097594261
Classifier prediction: -1.6162564754486084
2 loss: 4.228621006011963 lr: [0.009931100837462445] l1-dist 0.0612364336848259
Classifier prediction: -1.0625545978546143
4 loss: 3.7597155570983887 lr: [0.009774869058090914] l1-dist 0.06971611082553864
Classifier prediction: -0.4725781977176666
6 loss: 3.3060693740844727 lr: [0.009543642776065642] l1-dist 0.08334910869598389
Classifier prediction: 0.15730783343315125
8 loss: 2.8465020656585693 lr: [0.009241066670644704] l1-dist 0.10038098692893982
Diffusion Counterfactual generated with loss: 2.8465020656585693 | classifier_prediction: 0.15730783343315125 | l1_dist: 0.10038098692893982 | in 9 optimization steps


 12%|█▏        | 2/17 [00:49<05:59, 23.96s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.86it/s]


Classifier prediction: -1.474960207939148
0 loss: 4.118856906890869 lr: [0.01] l1-dist 0.06438963860273361
Classifier prediction: -0.9381029009819031
2 loss: 3.522012233734131 lr: [0.009931100837462445] l1-dist 0.058390937745571136
Classifier prediction: -0.3931589722633362
4 loss: 3.070895195007324 lr: [0.009774869058090914] l1-dist 0.06777362525463104
Classifier prediction: 0.05027990788221359
6 loss: 2.6999075412750244 lr: [0.009543642776065642] l1-dist 0.0750187337398529
Diffusion Counterfactual generated with loss: 2.6999075412750244 | classifier_prediction: 0.05027990788221359 | l1_dist: 0.0750187337398529 | in 7 optimization steps


 18%|█▊        | 3/17 [01:06<04:48, 20.62s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.85it/s]


Classifier prediction: -3.186439275741577
0 loss: 5.549551486968994 lr: [0.01] l1-dist 0.03631121292710304
Classifier prediction: -2.5537922382354736
2 loss: 4.967963218688965 lr: [0.009931100837462445] l1-dist 0.04141714423894882
Classifier prediction: -1.954103946685791
4 loss: 4.5201263427734375 lr: [0.009774869058090914] l1-dist 0.056602220982313156
Classifier prediction: -1.4594134092330933
6 loss: 4.130114555358887 lr: [0.009543642776065642] l1-dist 0.06707008183002472
Classifier prediction: -0.7316893935203552
8 loss: 3.4909865856170654 lr: [0.009241066670644704] l1-dist 0.07592970877885818
Classifier prediction: 0.043284378945827484
10 loss: 2.802475690841675 lr: [0.008871910576983217] l1-dist 0.08457601070404053
Diffusion Counterfactual generated with loss: 2.802475690841675 | classifier_prediction: 0.043284378945827484 | l1_dist: 0.08457601070404053 | in 11 optimization steps


 24%|██▎       | 4/17 [01:30<04:47, 22.15s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.49it/s]


Classifier prediction: -0.4957011044025421
0 loss: 2.8765904903411865 lr: [0.01] l1-dist 0.03808894008398056
Classifier prediction: 0.16523423790931702
2 loss: 2.517380952835083 lr: [0.009931100837462445] l1-dist 0.0682615116238594
Diffusion Counterfactual generated with loss: 2.517380952835083 | classifier_prediction: 0.16523423790931702 | l1_dist: 0.0682615116238594 | in 3 optimization steps


 29%|██▉       | 5/17 [01:39<03:27, 17.32s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.53it/s]


Classifier prediction: -0.8376410007476807
0 loss: 3.197890281677246 lr: [0.01] l1-dist 0.03602492809295654
Classifier prediction: 0.08710312098264694
2 loss: 2.442864418029785 lr: [0.009931100837462445] l1-dist 0.05299675464630127
Diffusion Counterfactual generated with loss: 2.442864418029785 | classifier_prediction: 0.08710312098264694 | l1_dist: 0.05299675464630127 | in 3 optimization steps


 35%|███▌      | 6/17 [01:48<02:39, 14.50s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 50.81it/s]


Classifier prediction: -1.9193310737609863
0 loss: 4.337862491607666 lr: [0.01] l1-dist 0.04185314103960991
Classifier prediction: -1.1707054376602173
2 loss: 3.65706205368042 lr: [0.009931100837462445] l1-dist 0.04863566905260086
Classifier prediction: -0.5746570229530334
4 loss: 3.4270753860473633 lr: [0.009774869058090914] l1-dist 0.08524185419082642
Classifier prediction: 0.06140168756246567
6 loss: 2.6542112827301025 lr: [0.009543642776065642] l1-dist 0.0715613067150116
Diffusion Counterfactual generated with loss: 2.6542112827301025 | classifier_prediction: 0.06140168756246567 | l1_dist: 0.0715613067150116 | in 7 optimization steps


 41%|████      | 7/17 [02:02<02:23, 14.35s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 49.71it/s]


Classifier prediction: -1.6750441789627075
0 loss: 4.079682350158691 lr: [0.01] l1-dist 0.04046384245157242
Classifier prediction: -1.4992420673370361
2 loss: 4.778306007385254 lr: [0.009931100837462445] l1-dist 0.12790636718273163
Diffusion Counterfactual generated with loss: 3.102604866027832 | classifier_prediction: 0.5691964626312256 | l1_dist: 0.1671801209449768 | in 4 optimization steps


 47%|████▋     | 8/17 [02:11<01:54, 12.72s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 49.58it/s]


Classifier prediction: -1.632059097290039
0 loss: 4.041840076446533 lr: [0.01] l1-dist 0.0409780852496624
Classifier prediction: -0.6066117882728577
2 loss: 3.2815661430358887 lr: [0.009931100837462445] l1-dist 0.0674954280257225
Classifier prediction: 0.11276825517416
4 loss: 2.6370694637298584 lr: [0.009774869058090914] l1-dist 0.07498377561569214
Diffusion Counterfactual generated with loss: 2.6370694637298584 | classifier_prediction: 0.11276825517416 | l1_dist: 0.07498377561569214 | in 5 optimization steps


 53%|█████▎    | 9/17 [02:22<01:37, 12.16s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 49.92it/s]


Classifier prediction: -3.263070583343506
0 loss: 5.676961898803711 lr: [0.01] l1-dist 0.04138915613293648
Classifier prediction: -2.706984519958496
2 loss: 5.155869483947754 lr: [0.009931100837462445] l1-dist 0.04488849639892578
Classifier prediction: -2.271000385284424
4 loss: 4.784661293029785 lr: [0.009774869058090914] l1-dist 0.05136607214808464
Classifier prediction: -1.8414192199707031
6 loss: 4.42793083190918 lr: [0.009543642776065642] l1-dist 0.05865117162466049
Classifier prediction: -1.4127620458602905
8 loss: 4.081504821777344 lr: [0.009241066670644704] l1-dist 0.06687428802251816
Classifier prediction: -1.0359368324279785
10 loss: 3.7867374420166016 lr: [0.008871910576983217] l1-dist 0.07508005946874619
Classifier prediction: 3.3919155597686768
12 loss: 2.9619665145874023 lr: [0.008441994226264132] l1-dist 0.15700508654117584
Diffusion Counterfactual generated with loss: 2.9619665145874023 | classifier_prediction: 3.3919155597686768 | l1_dist: 0.15700508654117584 | in 13 o

 59%|█████▉    | 10/17 [02:46<01:51, 15.94s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 49.90it/s]


Classifier prediction: -1.825016736984253
0 loss: 4.185042858123779 lr: [0.01] l1-dist 0.036002591252326965
Classifier prediction: -1.3805158138275146
2 loss: 3.9582600593566895 lr: [0.009931100837462445] l1-dist 0.057774417102336884
Classifier prediction: -1.0711637735366821
4 loss: 3.7886581420898438 lr: [0.009774869058090914] l1-dist 0.07174943387508392
Classifier prediction: -0.6118988394737244
6 loss: 3.451056480407715 lr: [0.009543642776065642] l1-dist 0.08391577005386353
Classifier prediction: -0.047563232481479645
8 loss: 2.9331984519958496 lr: [0.009241066670644704] l1-dist 0.08856351673603058
Diffusion Counterfactual generated with loss: 2.7423624992370605 | classifier_prediction: 0.17037460207939148 | l1_dist: 0.0912737250328064 | in 10 optimization steps


 65%|██████▍   | 11/17 [03:06<01:42, 17.00s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 42.87it/s]


Classifier prediction: -2.889946460723877
0 loss: 5.2585601806640625 lr: [0.01] l1-dist 0.036861352622509
Classifier prediction: -1.5258978605270386
2 loss: 4.340461730957031 lr: [0.009931100837462445] l1-dist 0.08145636320114136
Classifier prediction: -0.1502390205860138
4 loss: 3.2153098583221436 lr: [0.009774869058090914] l1-dist 0.10650709271430969
Diffusion Counterfactual generated with loss: 2.837104320526123 | classifier_prediction: 0.3048820197582245 | l1_dist: 0.11419864743947983 | in 6 optimization steps


 71%|███████   | 12/17 [03:19<01:18, 15.77s/it]

finish


100%|██████████| 99/99 [00:01<00:00, 49.87it/s]


Classifier prediction: -1.1418135166168213
0 loss: 3.5043153762817383 lr: [0.01] l1-dist 0.03625018522143364
Classifier prediction: -0.5309040546417236
2 loss: 3.130749225616455 lr: [0.009931100837462445] l1-dist 0.05998452752828598
Classifier prediction: -0.026737965643405914
4 loss: 2.7229793071746826 lr: [0.009774869058090914] l1-dist 0.06962413340806961
Diffusion Counterfactual generated with loss: 2.527459144592285 | classifier_prediction: 0.2074051797389984 | l1_dist: 0.07348643243312836 | in 6 optimization steps


 76%|███████▋  | 13/17 [03:31<00:59, 14.82s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 49.38it/s]


Classifier prediction: -0.643534779548645
0 loss: 3.346733570098877 lr: [0.01] l1-dist 0.07031987607479095
Diffusion Counterfactual generated with loss: 1.7437050342559814 | classifier_prediction: 1.39066481590271 | l1_dist: 0.1134369820356369 | in 2 optimization steps


 82%|████████▏ | 14/17 [03:37<00:36, 12.12s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 49.35it/s]


Classifier prediction: -1.447609782218933
0 loss: 3.9278740882873535 lr: [0.01] l1-dist 0.048026423901319504
Classifier prediction: -0.043774642050266266
2 loss: 2.6878087520599365 lr: [0.009931100837462445] l1-dist 0.06440342217683792
Diffusion Counterfactual generated with loss: 2.326993942260742 | classifier_prediction: 0.4459528625011444 | l1_dist: 0.07729469239711761 | in 4 optimization steps


 88%|████████▊ | 15/17 [03:47<00:22, 11.26s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 49.25it/s]


Classifier prediction: -0.7701396346092224
0 loss: 3.167670726776123 lr: [0.01] l1-dist 0.03975310176610947
Classifier prediction: -0.21826401352882385
2 loss: 2.6434991359710693 lr: [0.009931100837462445] l1-dist 0.04252350330352783
Classifier prediction: 0.1932341754436493
4 loss: 2.4113826751708984 lr: [0.009774869058090914] l1-dist 0.060461681336164474
Diffusion Counterfactual generated with loss: 2.4113826751708984 | classifier_prediction: 0.1932341754436493 | l1_dist: 0.060461681336164474 | in 5 optimization steps


 94%|█████████▍| 16/17 [03:58<00:11, 11.19s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 49.19it/s]


Classifier prediction: -1.8123350143432617
0 loss: 4.136782646179199 lr: [0.01] l1-dist 0.032444775104522705
Classifier prediction: -0.7623562216758728
2 loss: 3.3204517364501953 lr: [0.009931100837462445] l1-dist 0.05580954998731613
Classifier prediction: -0.04490021616220474
4 loss: 2.9405124187469482 lr: [0.009774869058090914] l1-dist 0.08956122398376465
Diffusion Counterfactual generated with loss: 2.7100725173950195 | classifier_prediction: 0.31074413657188416 | l1_dist: 0.10208168625831604 | in 6 optimization steps


100%|██████████| 17/17 [04:10<00:00, 14.76s/it]

finish



