In [1]:
import os

os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

import torch
from matplotlib import pyplot as plt
import numpy as np
import yaml
from torch.autograd import Variable

from models import cb_gan
from utils.utils import get_dataset

# Use the first visible GPU when CUDA is usable, otherwise fall back to the CPU
# so the notebook still runs on hosts without a GPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)


In [2]:
# Read the experiment configuration from its YAML file
config_file = 'config/cb_gan/libero_90.yaml'
with open(config_file, 'r') as stream:
    config = yaml.safe_load(stream)

# CUDA tensors are used only when CUDA is available AND the config enables it.
# The config key is looked up only if CUDA is available (short-circuit `and`).
use_cuda = bool(torch.cuda.is_available() and config["train_config"]["use_cuda"])
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor


In [3]:
# Build the two data loaders returned by get_dataset for this config
dataloader, test_loader = get_dataset(config)

In [4]:
# Instantiate the model and restore the trained weights.
model = cb_gan.cbGAN(config)
# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# saved from a different GPU index (or on a GPU, when loading on CPU) still
# deserializes instead of raising at load time.
model.load_state_dict(
    torch.load("trained_models/cb_gan_libero_90_20.pt", map_location=device)
)
model = model.to(device)
# NOTE(review): the model is kept in training mode exactly as before; confirm
# whether eval() is wanted for these visualizations.
# The trailing semicolon keeps the notebook from echoing the module repr.
model.train();

In [5]:
# Fetch a single batch; only its first element is inspected below.
imgs, concepts = next(iter(dataloader))
# Cast to float32 and place the batch on the same device the model was moved
# to, so the inputs can never end up on a different device than the weights.
imgs = imgs.float().to(device)
concept_list = [int(item) for item in concepts[0]]
print(concept_list)

# Encode the batch and decode it again. This is inference only, so autograd
# bookkeeping is disabled to avoid holding a graph in memory.
with torch.no_grad():
    _, _, latent = model.enc(imgs)
    fake_data, logits, _, _ = model.dec(latent, return_all=True)
    gen_imgs = model.dec(latent)

[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]


  prob_gumbel = F.softmax(logits)


In [12]:
# Concept name -> concept index; a name's position in the tuple is its index.
concept_dict = {
    name: position
    for position, name in enumerate((
        'white_yellow_mug',
        'butter',
        'wine_bottle',
        'yellow_book',
        'ketchup',
        'tomato_sauce',
        'orange_juice',
        'porcelain_mug',
        'chefmate_8_frypan',
        'cream_cheese',
        'plate',
        'chocolate_pudding',
        'red_coffee_mug',
        'moka_pot',
        'basket',
        'milk',
        'white_bowl',
        'wooden_tray',
        'akita_black_bowl',
        'alphabet_soup',
        'black_book',
        'new_salad_dressing',
    ))
}

def intervention(logits, latent, device, concept, positive=True):
    """Override one concept for the first sample and save the decoded image.

    The flat per-sample logits are regrouped into (even-position, odd-position)
    pairs, one pair per concept. For the first sample in the batch, the pair
    belonging to ``concept`` is overwritten with ``(1, 0)`` when ``positive``
    is true and ``(0, 1)`` otherwise. The result is passed to the global
    ``model``'s decoder as ``probs`` and the first decoded image is written to
    ``interventions/<concept>_positive.png`` or ``..._negative.png``.

    Relies on the notebook-level names ``model`` and ``concept_dict``.

    Args:
        logits: Per-sample sequences of scalars with an even number of entries
            per sample (a 2-D tensor, or an iterable of 1-D sequences). The
            input is never modified.
        latent: Latent batch forwarded unchanged to ``model.dec``.
        device: Device on which the modified pairs are placed.
        concept: Key of ``concept_dict`` naming the concept to override.
        positive: Truthy to write ``(1, 0)``, falsy to write ``(0, 1)``.

    Raises:
        KeyError: If ``concept`` is not a key of ``concept_dict``.
    """
    if concept not in concept_dict:
        raise KeyError(
            f"unknown concept {concept!r}; expected one of {sorted(concept_dict)}"
        )
    concept_idx = concept_dict[concept]

    # Accept a list of per-sample sequences as well as a single 2-D tensor.
    if not torch.is_tensor(logits):
        logits = torch.stack([torch.as_tensor(sample) for sample in logits])

    # Row layout [p0, n0, p1, n1, ...] -> (batch, n_concepts, 2): column 0 holds
    # the even positions and column 1 the odd ones. This replaces a per-element
    # Python loop that synchronized with the device once per scalar.
    # clone() guarantees private storage, so the in-place writes below can
    # never leak into the caller's logits across repeated calls.
    paired = (
        logits.detach()
        .to(device=device, dtype=torch.get_default_dtype())
        .reshape(logits.shape[0], -1, 2)
        .clone()
    )
    # One (batch, 2) tensor per concept.
    # NOTE(review): these mix raw logits with hard 0/1 values under the `probs`
    # keyword; confirm this is what model.dec expects.
    mod_logits = list(torch.swapaxes(paired, 0, 1))

    # Only batch element 0 is overridden; it is also the only one plotted.
    first_value, second_value = (1, 0) if positive else (0, 1)
    mod_logits[concept_idx][0][0] = first_value
    mod_logits[concept_idx][0][1] = second_value

    # Pure inference: no autograd graph is needed for a visualization.
    with torch.no_grad():
        fake_data_mod, _, _, _ = model.dec(latent, probs=mod_logits, return_all=True)

    # Swap axes 0 and 2, then rotate with np.rot90 to get the plotted layout.
    fake_img_mod = np.rot90(
        torch.swapaxes(fake_data_mod[0], 0, 2).cpu().detach().numpy()
    )

    # Create the output directory here so the function works when called alone.
    os.makedirs("interventions/", exist_ok=True)
    suffix = "_positive.png" if positive else "_negative.png"

    # A dedicated figure avoids drawing onto whatever figure is current.
    fig, ax = plt.subplots()
    ax.imshow(fake_img_mod)
    fig.savefig("interventions/" + concept + suffix)
    plt.close(fig)


In [13]:
def _to_display_array(image):
    """Convert one image tensor to the NumPy array this notebook plots.

    Swaps axes 0 and 2, moves the tensor to the CPU, detaches it, converts it
    to NumPy and rotates the result with ``np.rot90``.
    """
    return np.rot90(torch.swapaxes(image, 0, 2).cpu().detach().numpy())


def _save_image(array, path):
    """Draw ``array`` with ``imshow`` on its own figure and save it to ``path``."""
    fig, ax = plt.subplots()
    ax.imshow(array)
    fig.savefig(path)
    plt.close(fig)


os.makedirs("interventions/", exist_ok=True)

# Save the first input image and its reconstruction for reference.
real_img = _to_display_array(imgs[0])
fake_img = _to_display_array(fake_data[0])
_save_image(real_img, "interventions/real_image.png")
_save_image(fake_img, "interventions/generated_image.png")

# Derive the ordered name list from concept_dict instead of repeating it, so
# the two can never drift apart; position i holds the name with index i.
concept_names = sorted(concept_dict, key=concept_dict.get)

# Names of the concepts flagged with 1 for the first sample.
indices = [ind for ind, ele in enumerate(concept_list) if ele == 1]
concepts = [concept_names[i] for i in indices]
print(concepts)

# Render a forced-on and a forced-off image for every concept.
for concept_name in concept_names:
    intervention(logits, latent, device, concept_name, positive=True)
    intervention(logits, latent, device, concept_name, positive=False)

['moka_pot']
white_yellow_mug POSITIVE
white_yellow_mug NEGATIVE
butter POSITIVE
butter NEGATIVE
wine_bottle POSITIVE
wine_bottle NEGATIVE
yellow_book POSITIVE
yellow_book NEGATIVE
ketchup POSITIVE
ketchup NEGATIVE
tomato_sauce POSITIVE
tomato_sauce NEGATIVE
orange_juice POSITIVE
orange_juice NEGATIVE
porcelain_mug POSITIVE
porcelain_mug NEGATIVE
chefmate_8_frypan POSITIVE
chefmate_8_frypan NEGATIVE
cream_cheese POSITIVE
cream_cheese NEGATIVE
plate POSITIVE
plate NEGATIVE
chocolate_pudding POSITIVE
chocolate_pudding NEGATIVE
red_coffee_mug POSITIVE
red_coffee_mug NEGATIVE
moka_pot POSITIVE
moka_pot NEGATIVE
basket POSITIVE
basket NEGATIVE
milk POSITIVE
milk NEGATIVE
white_bowl POSITIVE
white_bowl NEGATIVE
wooden_tray POSITIVE
wooden_tray NEGATIVE
akita_black_bowl POSITIVE
akita_black_bowl NEGATIVE
alphabet_soup POSITIVE
alphabet_soup NEGATIVE
black_book POSITIVE
black_book NEGATIVE
new_salad_dressing POSITIVE
new_salad_dressing NEGATIVE
