In [None]:
import os
import torch
import torch.nn as nn
import random
import argparse
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os.path as osp
import glob

from PIL import Image
from tqdm import tqdm
from torch.utils import data
from torchvision import transforms

import sys
sys.path.insert(0,'/home/argusm/lang/')
from LDVCE.data.imagenet_classnames import name_map

# create dataset to read the counterfactual results images
class CFDataset():
    """Index of (original, counterfactual) PNG path pairs found under ``path``.

    Expected layout::

        path/bucket*/original/*.png
        path/bucket*/counterfactual/*.png

    Within each bucket both folders are sorted by filename and paired
    positionally, so they must hold the same number of images (presumably
    with matching filenames -- TODO confirm). Pairs from all buckets are then
    sorted by original path.

    Items are ``(original_path, counterfactual_path)`` string tuples; no image
    data is loaded here.

    Raises
    ------
    ValueError
        If a bucket's ``original`` and ``counterfactual`` folders contain a
        different number of PNG files.
    """

    def __init__(self, path):
        self.path = path
        pairs = []
        for bucket_folder in glob.glob(os.path.join(self.path, "bucket*")):
            originals = sorted(glob.glob(os.path.join(bucket_folder, "original", "*.png")))
            counterfactuals = sorted(glob.glob(os.path.join(bucket_folder, "counterfactual", "*.png")))
            # zip() would silently truncate the longer list, and positional
            # pairing would be misaligned after any missing file, so refuse
            # mismatched buckets instead of returning wrong pairs.
            if len(originals) != len(counterfactuals):
                raise ValueError(
                    f"{bucket_folder}: {len(originals)} original vs "
                    f"{len(counterfactuals)} counterfactual images"
                )
            pairs.extend(zip(originals, counterfactuals))
        self.images = sorted(pairs)

    def __len__(self):
        """Number of (original, counterfactual) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(original_path, counterfactual_path)`` tuple at ``idx``."""
        original_path, counterfactual_path = self.images[idx]
        return original_path, counterfactual_path
        
def load_img(path):
    """Read an image file with PIL and return its pixels as a uint8 array.

    Parameters
    ----------
    path : str or os.PathLike
        Image file to open.

    Returns
    -------
    np.ndarray
        uint8 pixel array; its shape follows the image mode
        (e.g. (H, W, 3) for RGB, (H, W) for grayscale).
    """
    # Image.open is lazy and holds the file handle open; the context manager
    # closes it once the pixels have been copied into the NumPy array.
    with Image.open(path) as img:
        return np.array(img, dtype=np.uint8)



In [None]:
# Root folder holding bucket*/original and bucket*/counterfactual PNGs.
path = "/misc/lmbraid21/faridk/LDCE_w382_cc23"
dataset = CFDataset(path)

# Sequential (unshuffled) loader over (original_path, counterfactual_path) pairs.
batch_size = 2
loader = data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
import skimage.filters

%matplotlib inline
dataset = CFDataset(path)
def image_l1(a, b, ord=2):
    """Per-pixel norm of the channel difference ``b - a``.

    Despite the name, this definition defaults to the L2 (Euclidean) norm
    over axis 2 (``ord=2``); pass ``ord=1`` for a true L1 distance.

    Parameters
    ----------
    a, b : array-like
        Images of identical shape with channels on axis 2, e.g. (H, W, C).
    ord : int, optional
        Norm order forwarded to ``np.linalg.norm`` (default 2).

    Returns
    -------
    np.ndarray
        float64 array of per-pixel distances, (H, W) for (H, W, C) inputs.
    """
    # Promote to float before subtracting: load_img returns uint8, and uint8
    # subtraction wraps around (0 - 10 == 246) instead of going negative.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.linalg.norm(b - a, axis=2, ord=ord)
    
SAMPLE_IDX = 9000 + 10  # which (original, counterfactual) pair to inspect

cl_path, cf_path = dataset[SAMPLE_IDX]
cl = load_img(cl_path)
cf = load_img(cf_path)
# Optional: blur the counterfactual before diffing, e.g.
# cf = skimage.filters.gaussian(cf, sigma=10, truncate=1/5)

print(cf.shape)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax in axes:
    ax.set_axis_off()
axes[0].imshow(cl)
axes[1].imshow(cf)
# The images stay uint8 for display; diff on float copies so b - a cannot
# wrap around modulo 256.
diff = image_l1(cl.astype(np.float64), cf.astype(np.float64))
print(np.mean(diff))
dih = axes[2].imshow(diff)
fig.colorbar(dih, orientation='vertical')
plt.show()
plt.close(fig)

In [None]:
import skimage.filters

%matplotlib inline
dataset = CFDataset(path)
def image_l1(a, b, ord=1):
    """Per-pixel L1 distance between two images, taken over axis 2.

    Parameters
    ----------
    a, b : array-like
        Images of identical shape with channels on axis 2, e.g. (H, W, C).
    ord : int, optional
        Norm order forwarded to ``np.linalg.norm`` (default 1, i.e. the sum of
        absolute channel differences).

    Returns
    -------
    np.ndarray
        float64 array of per-pixel distances, (H, W) for (H, W, C) inputs.
    """
    # Promote to float so uint8 inputs (as returned by load_img) cannot wrap
    # around on subtraction; already-float inputs are unaffected.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return np.linalg.norm(b - a, axis=2, ord=ord)

cl_path, cf_path = dataset[100]
# Scale to [0, 1] floats: imshow expects float RGB in that range, and float
# arithmetic avoids uint8 wrap-around in the difference below.
cl = load_img(cl_path) / 255.
cf = load_img(cf_path) / 255.
print("cl range", cl.min(), cl.max())

# Hook for perturbing the counterfactual before diffing, e.g.
# skimage.filters.gaussian(cf, sigma=sigma, truncate=1/5); currently identity.
cf_mod = cf
diff = image_l1(cl, cf_mod)
print(np.round(np.mean(diff), 4))

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
for ax in axes:
    ax.set_axis_off()
axes[0].imshow(cl)
axes[1].imshow(cf_mod)
dih = axes[2].imshow(diff)
fig.colorbar(dih, orientation='vertical')
plt.show()
plt.close(fig)

In [None]:
from pathlib import Path
from skimage.transform import resize
from skimage.morphology import binary_dilation, disk
from skimage.filters import gaussian

# Precomputed foreground masks, one .npz per image stem with key "fg_mask".
SEG_DIR = Path("/misc/lmbraid21/faridk/seg_dino")
# NOTE(review): 244 looks like it may be meant to match the image size
# (e.g. 224) -- confirm before relying on mask/image alignment.
MASK_SIZE = (244, 244)

for i in range(100, 110):
    image_path1 = dataset[i][0]
    save_path = SEG_DIR / (Path(image_path1).stem + '.npz')
    print(image_path1)
    fig, axes = plt.subplots(1, 2, figsize=(18, 18))
    axes[0].imshow(load_img(image_path1))
    # NpzFile keeps its archive open; the context manager closes it.
    with np.load(save_path) as npz:
        fg_mask_small = npz["fg_mask"]
    fg_mask_small = binary_dilation(fg_mask_small, np.ones((3, 3)))  # dilate mask
    fg_mask_small = gaussian(fg_mask_small, sigma=1, preserve_range=True)  # smooth mask
    # Bicubic upsample of the soft mask (threshold with `> .5` for a binary mask).
    fg_mask = np.array(Image.fromarray(fg_mask_small.astype(float)).resize(MASK_SIZE, Image.BICUBIC))
    fg_mask = fg_mask.clip(0, 1.0)  # bicubic can overshoot [0, 1]
    axes[1].imshow(fg_mask)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations
