In [3]:
import PIL.Image
import numpy as np
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import imageio
import os

from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
from torch.utils.data import DataLoader
from torch.utils.checkpoint import checkpoint
import torch.optim.lr_scheduler as lr_scheduler

from zennit.composites import LayerMapComposite
from zennit.rules import Epsilon, ZPlus, Pass, Norm

from data.dataset import ImageDataset, CelebHQAttrDataset
from init_classifier import LinearClassifier, VQVAEClassifier, ResNet50Classifier, ViTClassifier
from xai_lrp import xai_zennit, show_attributions


class CheckpointedUNetWrapper(torch.nn.Module):
    def __init__(self, model):
        super(CheckpointedUNetWrapper, self).__init__()
        self.model = model

    def checkpointed_forward(self, module, *inputs):
        def custom_forward(*inputs):
            return module(*inputs)
        return checkpoint(custom_forward, *inputs)

    def forward(self, sample, timestep):

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps * torch.ones(sample.shape[0], dtype=timesteps.dtype, device=timesteps.device)

        t_emb = self.model.time_proj(timesteps)
        #t_emb = t_emb.to(dtype=self.dtype)
        emb = self.model.time_embedding(t_emb)

        # 2. pre-process
        skip_sample = sample
        sample = self.model.conv_in(sample)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.model.down_blocks:
            if hasattr(downsample_block, "skip_conv"):
                sample, res_samples, skip_sample = self.checkpointed_forward(downsample_block, sample, emb, skip_sample)
            else:
                sample, res_samples = self.checkpointed_forward(downsample_block, sample, emb)

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.checkpointed_forward(self.model.mid_block, sample, emb)

        # 5. up
        skip_sample = None
        for upsample_block in self.model.up_blocks:
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            if hasattr(upsample_block, "skip_conv"):
                sample, skip_sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb, skip_sample)
            else:
                sample = self.checkpointed_forward(upsample_block, sample, res_samples, emb)

        # 6. post-process
        sample = self.model.conv_norm_out(sample)
        sample = self.model.conv_act(sample)
        sample = self.model.conv_out(sample)

        if skip_sample is not None:
            sample += skip_sample

        return {"sample": sample}

def classifier_loss(classifier, images, targets, idx):
    preds = classifier(images)
    if idx % 2 == 0:
        print(f"Classifier prediction: {preds[0][cls_id]}")
    targets = torch.tensor(targets).to(device)
    #error = torch.nn.functional.binary_cross_entropy_with_logits(preds[0][31], targets)
    error = torch.abs(preds[0][cls_id] - targets).mean()
    preds_binary = torch.sigmoid(preds[0][cls_id]) > 0.5

    return error, preds_binary

def minDist_loss(images, original_images):
    error = torch.abs(images - original_images).mean()
    return error


# data loading with ground truth no smiling
data = ImageDataset('/home/dai/GPU-Student-2/Cederic/DataSciPro/data/misclsData_gt1', image_size=256, exts=['jpg', 'JPG', 'png'], do_augment=False, sort_names=True)
dataloader = DataLoader(data, batch_size=1, shuffle=False)

# create output folders
directory_names = []
for i, _ in enumerate(dataloader):
    img_index = dataloader.dataset.paths[i].name.split('_')[0]
    directory_name = os.path.join("/home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear", f'folder_IMG_{img_index}')
    directory_names.append(directory_name)
    os.makedirs(directory_name, exist_ok=True)
    print(f'Created directory: {directory_name}')

#
device = "cuda" if torch.cuda.is_available() else "cpu"
cls_type = 'linear'
cls_id =  CelebHQAttrDataset.cls_to_id['Smiling']

# load all models
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")

unet.to(device)
vqvae.to(device)

checkpointed_unet = CheckpointedUNetWrapper(unet)

# load all models
if cls_type == 'linear':    
    classifier = LinearClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b128linear2024-06-02 13:08:28.ckpt",
                                            input_dim = data[0]['img'].shape,
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))
elif cls_type == 'vqvae':
    classifier = VQVAEClassifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b32vqvae2024-06-01 08:48:59.ckpt",
                                       num_classes = len(CelebHQAttrDataset.id_to_cls))

elif cls_type == 'res50':
    classifier = ResNet50Classifier.load_from_checkpoint("/home/dai/GPU-Student-2/Cederic/DataSciPro/cls_checkpoints/ffhq256.b64res502024-06-02 17:06:41.ckpt",
                                            num_classes = len(CelebHQAttrDataset.id_to_cls))


classifier.to(device)
classifier.eval()
# check functionality of classifier
all_outputs = []
with torch.no_grad():
    for batch in dataloader:
        inputs = batch['img'].to(classifier.device)
        outputs = classifier(inputs)
        print(outputs[0][cls_id])

        preds_binary = torch.sigmoid(outputs[:, cls_id].cpu()) > 0.5
        all_outputs.append(preds_binary) 
all_outputs = torch.cat(all_outputs, dim=0)
print(all_outputs)


###### explainable ai lrp
# lrp rules
layer_map_lrp_0 = [
    (torch.nn.ReLU, Pass()),  # ignore activations
    (torch.nn.Linear, Epsilon(epsilon=0)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_zplus = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, ZPlus()),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

layer_map_lrp_eps = [
    (torch.nn.ReLU, Pass()),
    (torch.nn.Linear, Epsilon(epsilon=1)),  # this is the dense Linear, not any Linear
    (torch.nn.Conv2d, ZPlus()),
    (torch.nn.BatchNorm2d, Pass()),
    (torch.nn.AdaptiveAvgPool2d, Norm()),
]

#before manipulation
for i, batch in enumerate(dataloader):
    inputs = batch['img'].to(classifier.device)
    attr_znt_0 = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, inputs, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[i], attr_znt_0, title='Pre Zennit LRP-0')
    show_attributions(directory_names[i], attr_znt_eps, title='Pre Zennit LRP-EPS')
    show_attributions(directory_names[i], attr_znt_zplus, title='Pre Zennit LRP-Z+')
        
## Inversion
def invert(
    start_latents,
    num_inference_steps,
    device=device,
):

    # Latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps - 1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1:
            continue

        t = timesteps[i]

        # Expand the latents if we are doing classifier free guidance
        latent_model_input = latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        noise_pred = checkpointed_unet(latent_model_input, t)["sample"]

        current_t = max(0, t.item() - (1000 // num_inference_steps))  # t
        next_t = t  # min(999, t.item() + (1000//num_inference_steps)) # t+1
        alpha_t = scheduler.alphas_cumprod[current_t]
        alpha_t_next = scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents)
        latents = (latents - (1 - alpha_t).sqrt() * noise_pred) * (alpha_t_next.sqrt() / alpha_t.sqrt()) + (
            1 - alpha_t_next
        ).sqrt() * noise_pred

        # Store
        intermediate_latents.append(latents)

    return torch.cat(intermediate_latents)


class LatentNoise(torch.nn.Module):
    """
    The LatentNoise Module makes it easier to update the noise tensor with torch optimizers.
    """

    def __init__(self, noise: torch.Tensor):
        super().__init__()
        self.noise = torch.nn.Parameter(noise)

    def forward(self):
        return self.noise


def diffusion_pipe(noise_module: LatentNoise, num_inference_steps):
        z = noise_module()
        for i in range(start_step, num_inference_steps):
            t = scheduler.timesteps[i]
            z = scheduler.scale_model_input(z, t)
            with torch.no_grad():
                noise_pred = checkpointed_unet(z, t)["sample"]
            z = scheduler.step(noise_pred, t, z).prev_sample
            z0 = scheduler.step(noise_pred, t, z).pred_original_sample
        return z, z0

def plot_images(images, titles=None, figsize=(50, 5), save_path=None):
    n = len(images)
    fig, axes = plt.subplots(1, n, figsize=(n*5, 5))

    if n == 1:
        axes = [axes]

    #for i, img in enumerate(images):
    #    img.save(f"{save_path}/{i}.png")

    for i, img in enumerate(images):
        axes[i].imshow(img)
        axes[i].axis('off')
        #if titles is not None:
        #    axes[i].set_title(titles[i])

    if save_path:
        plt.savefig(save_path)
    #plt.show()
    plt.close(fig)

def plot_to_pil(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    plt.imshow(image)
    plt.axis('off')
    plt.show()

def tensor_to_pil_image(tensor):
    image = tensor.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image = PIL.Image.fromarray(np.array(image[0].detach().numpy() * 255).astype(np.uint8))
    return image

# conditional sampling
num_inference_steps = 100
start_step = 20


for step, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    #plot_to_pil(batch['img'])
    with torch.no_grad():
        z = vqvae.encode(batch['img'].to(device))   # encode the image in the latent space
    z = z.latents
    #plot_to_pil(z)
    
    #cond = z.view(1,-1)
    #cond = normalize(cond)
    #cond = cond + 0.5 * math.sqrt(512) * classifier.fc1.weight[31].unsqueeze(0)
    #cond = denormalize(cond)
    #z = cond.view(1,3,64,64)
    #dec_z = vqvae.decode(z)[0]
    #plot_to_pil(dec_z)
    
    inverted_latents = invert(z, num_inference_steps)                  # do the ddim scheduler reversed to add noise to the latents
    z = inverted_latents[-(start_step + 1)].unsqueeze(0)                  # use these latents to start the sampling. better performance when using not the last latent sample
    #plot_to_pil(z)
    noise_module = LatentNoise(z.clone()).to(device)                    # convert latent noise to a parameter module for optimization
    noise_module.noise.requires_grad = True
    intermediate_results = [batch['img'].to(device)]   # list to store the results of the steering
    intermediate_preds = [round(classifier(batch['img'].to(device))[0][cls_id].item(), 5)]
    
    optimizer = torch.optim.Adam(
        noise_module.parameters(), lr=0.01, maximize=False # not minimize gradient ascent
    )
    learning_scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
    
    x = torch.zeros_like(z)
    current_loss = float('inf')
    preds_binary = False
    current_pred = 0.0
    i = 0
    # for the linear classifier it works perfect to break out of the loop if the prediction switches.
    while (current_pred < 1.0) & (i < 20) :
    #while (preds_binary == False) & (i < 20) :
            
            optimizer.zero_grad()
            x, x0 = diffusion_pipe(noise_module, num_inference_steps) # forward
            #plot_to_pil(x)
            decoded_x = vqvae.decode(x)[0]
            #plot_to_pil(decoded_x)
            current_pred = classifier(decoded_x)[0][cls_id]

            if i % 1 == 0:
                intermediate_results.append(decoded_x)
                intermediate_preds.append(round(current_pred.item(), 5))

            loss, preds_binary = classifier_loss(classifier, decoded_x, 1.0, i)
            l1_dist = minDist_loss(decoded_x, batch['img'].to(device))
            loss += l1_dist * 10
            
            if i % 2 == 0:
                print(i, "loss:", loss.item(), "lr:", learning_scheduler.get_lr(), "l1-dist", l1_dist.item())
                #print(i, "loss:", loss.item())
            loss.backward()
            optimizer.step()
            learning_scheduler.step()

            current_loss = loss.item()
            i += 1
    
    with torch.no_grad():
        image = vqvae.decode(x)[0]

    print(f"Diffusion Counterfactual generated with loss: {current_loss} | classifier_prediction: {current_pred} | l1_dist: {l1_dist} | in {i} optimization steps")
    
    #lrp after manipulation
    attr_znt_0 = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_0), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_eps = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_eps), device=device, target=torch.tensor(cls_id).to(device))[0]]
    attr_znt_zplus = [xai_zennit(classifier, image, RuleComposite=LayerMapComposite(layer_map_lrp_zplus), device=device, target=torch.tensor(cls_id).to(device))[0]]
    show_attributions(directory_names[step], attr_znt_0, title='Post Zennit LRP-0')
    show_attributions(directory_names[step], attr_znt_eps, title='Post Zennit LRP-EPS')
    show_attributions(directory_names[step], attr_znt_zplus, title='Post Zennit LRP-Z+')
    image.requires_grad = False
    
    images = [tensor_to_pil_image(tensor) for tensor in intermediate_results]
    gif_path = f"{directory_names[step]}/GIF.gif"
    imageio.mimsave(gif_path, images, format='GIF', duration=2.0, loop=0)  # duration is in seconds

    row_path = f"{directory_names[step]}/sequence.png"
    plot_images(images, intermediate_preds, save_path=row_path)

    # process image
    image_processed = image.cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    image_pil = PIL.Image.fromarray(np.array(image_processed[0] * 255).astype(np.uint8))
    ori_processed = batch['img'].cpu().permute(0, 2, 3, 1).clip(-1,1) * 0.5 + 0.5
    ori_image = PIL.Image.fromarray(np.array(ori_processed[0] * 255).astype(np.uint8))

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(image_pil)
    axs[0].axis('off')
    axs[0].set_title('Diffusion Counterfactual Image')
    axs[1].imshow(ori_image)
    axs[1].axis('off')
    axs[1].set_title('Original Image')
    #plt.show()
    fig.savefig(f'{directory_names[step]}/ori_vs_DCE.png', dpi=300, bbox_inches='tight')
    plt.close(fig)

    image_pil.save(f"{directory_names[step]}/diffCounter_IMG.png")
    print('finish')

Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_27300
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_27398
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_27591
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_27611
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_27931
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_28113
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_28125
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smiling_Linear/sequence/folder_IMG_28285
Created directory: /home/dai/GPU-Student-2/Cederic/DataSciPro/data_output_Smilin

  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
The config attributes {'timestep_values': None, 'timesteps': 1000} were passed to DDIMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


tensor(-1.2704, device='cuda:0')
tensor(-1.6096, device='cuda:0')
tensor(-2.0562, device='cuda:0')
tensor(-1.1663, device='cuda:0')
tensor(-1.7197, device='cuda:0')
tensor(-0.7761, device='cuda:0')
tensor(-1.8571, device='cuda:0')
tensor(-0.5018, device='cuda:0')
tensor(-1.8218, device='cuda:0')
tensor(-2.5475, device='cuda:0')
tensor(-1.2291, device='cuda:0')
tensor(-1.5236, device='cuda:0')
tensor(-2.1163, device='cuda:0')
tensor(-1.3400, device='cuda:0')
tensor(-1.9986, device='cuda:0')
tensor(-1.7572, device='cuda:0')
tensor(-1.3951, device='cuda:0')
tensor(-1.2508, device='cuda:0')
tensor(-1.5588, device='cuda:0')
tensor(-1.4129, device='cuda:0')
tensor(-0.3220, device='cuda:0')
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False])


100%|██████████| 99/99 [00:02<00:00, 40.09it/s]


Classifier prediction: -1.37044358253479
0 loss: 2.778909921646118 lr: [0.01] l1-dist 0.040846627205610275
Classifier prediction: 0.2841572165489197
2 loss: 1.4180927276611328 lr: [0.009931100837462445] l1-dist 0.07022498548030853
Classifier prediction: 1.2118208408355713
4 loss: 1.0750572681427002 lr: [0.009774869058090914] l1-dist 0.08632364124059677
Diffusion Counterfactual generated with loss: 1.0750572681427002 | classifier_prediction: 1.2118208408355713 | l1_dist: 0.08632364124059677 | in 5 optimization steps


  5%|▍         | 1/21 [00:55<18:21, 55.09s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.86it/s]


Classifier prediction: -1.7357804775238037
0 loss: 3.1483583450317383 lr: [0.01] l1-dist 0.04125778749585152
Classifier prediction: 0.14300379157066345
2 loss: 1.7048193216323853 lr: [0.009931100837462445] l1-dist 0.08478231728076935
Classifier prediction: 1.5644819736480713
4 loss: 1.6893174648284912 lr: [0.009774869058090914] l1-dist 0.11248354613780975
Diffusion Counterfactual generated with loss: 1.6893174648284912 | classifier_prediction: 1.5644819736480713 | l1_dist: 0.11248354613780975 | in 5 optimization steps


 10%|▉         | 2/21 [01:07<09:33, 30.19s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.05it/s]


Classifier prediction: -2.0542335510253906
0 loss: 3.383463144302368 lr: [0.01] l1-dist 0.03292296826839447
Classifier prediction: -1.362311840057373
2 loss: 3.008636474609375 lr: [0.009931100837462445] l1-dist 0.06463246047496796
Classifier prediction: 1.2724076509475708
4 loss: 1.078696846961975 lr: [0.009774869058090914] l1-dist 0.08062891662120819
Diffusion Counterfactual generated with loss: 1.078696846961975 | classifier_prediction: 1.2724076509475708 | l1_dist: 0.08062891662120819 | in 5 optimization steps


 14%|█▍        | 3/21 [01:20<06:39, 22.22s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.06it/s]


Classifier prediction: -1.245290756225586
0 loss: 2.685704469680786 lr: [0.01] l1-dist 0.04404136538505554
Classifier prediction: -0.2049323320388794
2 loss: 1.78743577003479 lr: [0.009931100837462445] l1-dist 0.05825034901499748
Classifier prediction: 0.19136324524879456
4 loss: 1.507222294807434 lr: [0.009774869058090914] l1-dist 0.06985855102539062
Classifier prediction: 0.5631318092346191
6 loss: 1.2534486055374146 lr: [0.009543642776065642] l1-dist 0.08165804296731949
Classifier prediction: 1.0101597309112549
8 loss: 0.9846431612968445 lr: [0.009241066670644704] l1-dist 0.09744834154844284
Diffusion Counterfactual generated with loss: 0.9846431612968445 | classifier_prediction: 1.0101597309112549 | l1_dist: 0.09744834154844284 | in 9 optimization steps


 19%|█▉        | 4/21 [01:41<06:08, 21.65s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.35it/s]


Classifier prediction: -1.8131568431854248
0 loss: 3.146090269088745 lr: [0.01] l1-dist 0.03329335153102875
Classifier prediction: 0.5467492938041687
2 loss: 0.9504333138465881 lr: [0.009931100837462445] l1-dist 0.049718260765075684
Diffusion Counterfactual generated with loss: 0.8481041789054871 | classifier_prediction: 1.237642526626587 | l1_dist: 0.061046164482831955 | in 4 optimization steps


 24%|██▍       | 5/21 [01:52<04:45, 17.82s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.20it/s]


Classifier prediction: -0.7897571325302124
0 loss: 2.173189640045166 lr: [0.01] l1-dist 0.038343243300914764
Classifier prediction: 1.167104721069336
2 loss: 0.9912639856338501 lr: [0.009931100837462445] l1-dist 0.08241592347621918
Diffusion Counterfactual generated with loss: 0.9912639856338501 | classifier_prediction: 1.167104721069336 | l1_dist: 0.08241592347621918 | in 3 optimization steps


 29%|██▊       | 6/21 [02:01<03:40, 14.72s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 41.34it/s]


Classifier prediction: -1.9079196453094482
0 loss: 3.2196199893951416 lr: [0.01] l1-dist 0.03117002546787262
Classifier prediction: -1.2846229076385498
2 loss: 2.700162649154663 lr: [0.009931100837462445] l1-dist 0.041553981602191925
Classifier prediction: -0.7463563680648804
4 loss: 2.270904541015625 lr: [0.009774869058090914] l1-dist 0.05245480686426163
Classifier prediction: 0.8227582573890686
6 loss: 0.8845946192741394 lr: [0.009543642776065642] l1-dist 0.07073529064655304
Diffusion Counterfactual generated with loss: 0.9941493272781372 | classifier_prediction: 1.2195571660995483 | l1_dist: 0.07745921611785889 | in 8 optimization steps


 33%|███▎      | 7/21 [02:20<03:46, 16.18s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.42it/s]


Classifier prediction: -0.48073992133140564
0 loss: 1.940827488899231 lr: [0.01] l1-dist 0.046008750796318054
Classifier prediction: 1.7068097591400146
2 loss: 1.3464093208312988 lr: [0.009931100837462445] l1-dist 0.06395995616912842
Diffusion Counterfactual generated with loss: 1.3464093208312988 | classifier_prediction: 1.7068097591400146 | l1_dist: 0.06395995616912842 | in 3 optimization steps


 38%|███▊      | 8/21 [02:29<02:59, 13.84s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.22it/s]


Classifier prediction: -1.9188576936721802
0 loss: 3.249070167541504 lr: [0.01] l1-dist 0.033021267503499985
Classifier prediction: -1.4899877309799194
2 loss: 2.883683204650879 lr: [0.009931100837462445] l1-dist 0.03936952352523804
Classifier prediction: -1.0859318971633911
4 loss: 2.5724844932556152 lr: [0.009774869058090914] l1-dist 0.04865526407957077
Classifier prediction: -0.703542947769165
6 loss: 2.207361936569214 lr: [0.009543642776065642] l1-dist 0.05038190633058548
Classifier prediction: -0.26784756779670715
8 loss: 1.7843916416168213 lr: [0.009241066670644704] l1-dist 0.051654405891895294
Classifier prediction: 0.20653674006462097
10 loss: 1.3570643663406372 lr: [0.008871910576983217] l1-dist 0.05636011064052582
Classifier prediction: 0.6826080679893494
12 loss: 0.9454736113548279 lr: [0.008441994226264132] l1-dist 0.06280817091464996
Classifier prediction: 1.227134346961975
14 loss: 0.9569183588027954 lr: [0.007958095421518055] l1-dist 0.07297839969396591
Diffusion Counter

 43%|████▎     | 9/21 [03:02<03:59, 19.98s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.74it/s]


Classifier prediction: -2.676177501678467
0 loss: 4.092867374420166 lr: [0.01] l1-dist 0.041669003665447235
Classifier prediction: 1.6016203165054321
2 loss: 1.6480499505996704 lr: [0.009931100837462445] l1-dist 0.10464295744895935
Diffusion Counterfactual generated with loss: 1.6480499505996704 | classifier_prediction: 1.6016203165054321 | l1_dist: 0.10464295744895935 | in 3 optimization steps


 48%|████▊     | 10/21 [03:11<03:01, 16.54s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 35.21it/s]


Classifier prediction: -1.2535080909729004
0 loss: 2.513293504714966 lr: [0.01] l1-dist 0.025978539139032364
Classifier prediction: -0.19536444544792175
2 loss: 1.964280128479004 lr: [0.009931100837462445] l1-dist 0.07689157128334045
Classifier prediction: 0.3826928734779358
4 loss: 1.5920321941375732 lr: [0.009774869058090914] l1-dist 0.09747251123189926
Classifier prediction: 1.0411041975021362
6 loss: 1.0880900621414185 lr: [0.009543642776065642] l1-dist 0.10469858348369598
Diffusion Counterfactual generated with loss: 1.0880900621414185 | classifier_prediction: 1.0411041975021362 | l1_dist: 0.10469858348369598 | in 7 optimization steps


 52%|█████▏    | 11/21 [03:28<02:48, 16.80s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.66it/s]


Classifier prediction: -1.5595449209213257
0 loss: 3.143853187561035 lr: [0.01] l1-dist 0.05843082070350647
Classifier prediction: 0.9460344314575195
2 loss: 0.8956513404846191 lr: [0.009931100837462445] l1-dist 0.08416857570409775
Diffusion Counterfactual generated with loss: 1.178856372833252 | classifier_prediction: 1.2044706344604492 | l1_dist: 0.09743857383728027 | in 4 optimization steps


 57%|█████▋    | 12/21 [03:39<02:14, 15.00s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.32it/s]


Classifier prediction: -2.242992401123047
0 loss: 3.6384997367858887 lr: [0.01] l1-dist 0.03955074027180672
Classifier prediction: 0.31294694542884827
2 loss: 1.325820803642273 lr: [0.009931100837462445] l1-dist 0.06387677043676376
Classifier prediction: 0.36876270174980164
4 loss: 1.3447773456573486 lr: [0.009774869058090914] l1-dist 0.07135400176048279
Classifier prediction: 1.6399176120758057
6 loss: 2.0076284408569336 lr: [0.009543642776065642] l1-dist 0.1367710828781128
Diffusion Counterfactual generated with loss: 2.0076284408569336 | classifier_prediction: 1.6399176120758057 | l1_dist: 0.1367710828781128 | in 7 optimization steps


 62%|██████▏   | 13/21 [03:56<02:04, 15.61s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.82it/s]


Classifier prediction: -1.5507910251617432
0 loss: 3.0091726779937744 lr: [0.01] l1-dist 0.04583815857768059
Diffusion Counterfactual generated with loss: 1.088670253753662 | classifier_prediction: 1.3349753618240356 | l1_dist: 0.07536948472261429 | in 2 optimization steps


 67%|██████▋   | 14/21 [04:03<01:30, 12.94s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.59it/s]


Classifier prediction: -2.0247175693511963
0 loss: 3.405848503112793 lr: [0.01] l1-dist 0.03811308741569519
Classifier prediction: -1.374620795249939
2 loss: 2.978684663772583 lr: [0.009931100837462445] l1-dist 0.06040636822581291
Classifier prediction: -0.5683680176734924
4 loss: 2.334089994430542 lr: [0.009774869058090914] l1-dist 0.07657220959663391
Classifier prediction: 1.0921351909637451
6 loss: 1.0457799434661865 lr: [0.009543642776065642] l1-dist 0.09536447376012802
Diffusion Counterfactual generated with loss: 1.0457799434661865 | classifier_prediction: 1.0921351909637451 | l1_dist: 0.09536447376012802 | in 7 optimization steps


 71%|███████▏  | 15/21 [04:20<01:24, 14.14s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.59it/s]


Classifier prediction: -2.0431790351867676
0 loss: 3.5878217220306396 lr: [0.01] l1-dist 0.054464273154735565
Classifier prediction: -0.8416629433631897
2 loss: 2.567584991455078 lr: [0.009931100837462445] l1-dist 0.07259220629930496
Classifier prediction: 0.04257919639348984
4 loss: 1.9767436981201172 lr: [0.009774869058090914] l1-dist 0.10193228721618652
Classifier prediction: 1.2578797340393066
6 loss: 1.5982271432876587 lr: [0.009543642776065642] l1-dist 0.13403473794460297
Diffusion Counterfactual generated with loss: 1.5982271432876587 | classifier_prediction: 1.2578797340393066 | l1_dist: 0.13403473794460297 | in 7 optimization steps


 76%|███████▌  | 16/21 [04:37<01:15, 15.01s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.47it/s]


Classifier prediction: -1.441793441772461
0 loss: 2.925058364868164 lr: [0.01] l1-dist 0.048326484858989716
Classifier prediction: -0.20588275790214539
2 loss: 1.9482216835021973 lr: [0.009931100837462445] l1-dist 0.07423388212919235
Classifier prediction: 1.1445190906524658
4 loss: 1.0179715156555176 lr: [0.009774869058090914] l1-dist 0.08734524250030518
Diffusion Counterfactual generated with loss: 1.0179715156555176 | classifier_prediction: 1.1445190906524658 | l1_dist: 0.08734524250030518 | in 5 optimization steps


 81%|████████  | 17/21 [04:50<00:57, 14.49s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.75it/s]


Classifier prediction: -1.2903563976287842
0 loss: 2.64945912361145 lr: [0.01] l1-dist 0.035910263657569885
Classifier prediction: -0.279375284910202
2 loss: 1.8135093450546265 lr: [0.009931100837462445] l1-dist 0.053413406014442444
Classifier prediction: 0.44443801045417786
4 loss: 1.2426519393920898 lr: [0.009774869058090914] l1-dist 0.06870898604393005
Classifier prediction: 1.358891487121582
6 loss: 1.2175981998443604 lr: [0.009543642776065642] l1-dist 0.0858706682920456
Diffusion Counterfactual generated with loss: 1.2175981998443604 | classifier_prediction: 1.358891487121582 | l1_dist: 0.0858706682920456 | in 7 optimization steps


 86%|████████▌ | 18/21 [05:07<00:45, 15.24s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.89it/s]


Classifier prediction: -1.6596801280975342
0 loss: 3.0025620460510254 lr: [0.01] l1-dist 0.034288205206394196
Classifier prediction: -1.2291425466537476
2 loss: 2.5688636302948 lr: [0.009931100837462445] l1-dist 0.03397208824753761
Classifier prediction: -0.3057168424129486
4 loss: 1.823319435119629 lr: [0.009774869058090914] l1-dist 0.05176025629043579
Diffusion Counterfactual generated with loss: 0.6584835052490234 | classifier_prediction: 1.052412509918213 | l1_dist: 0.060607098042964935 | in 6 optimization steps


 90%|█████████ | 19/21 [05:22<00:30, 15.15s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.62it/s]


Classifier prediction: -1.351494312286377
0 loss: 2.885748863220215 lr: [0.01] l1-dist 0.05342545360326767
Classifier prediction: -0.20971660315990448
2 loss: 1.9263710975646973 lr: [0.009931100837462445] l1-dist 0.071665458381176
Diffusion Counterfactual generated with loss: 1.061234712600708 | classifier_prediction: 1.0635846853256226 | l1_dist: 0.09976499527692795 | in 4 optimization steps


 95%|█████████▌| 20/21 [05:33<00:13, 13.86s/it]

finish


100%|██████████| 99/99 [00:02<00:00, 40.27it/s]


Classifier prediction: -0.3589981496334076
0 loss: 1.7364102602005005 lr: [0.01] l1-dist 0.03774121031165123
Classifier prediction: 2.9557957649230957
2 loss: 2.707279682159424 lr: [0.009931100837462445] l1-dist 0.07514839619398117
Diffusion Counterfactual generated with loss: 2.707279682159424 | classifier_prediction: 2.9557957649230957 | l1_dist: 0.07514839619398117 | in 3 optimization steps


100%|██████████| 21/21 [05:42<00:00, 16.30s/it]

finish



