In [3]:
import os

# Enumerate GPUs in PCI bus order and expose only physical GPU 1 to this process.
# This has to run before CUDA is initialised; the exposed GPU is then addressed as "cuda:0".
os.environ.update({'CUDA_DEVICE_ORDER': "PCI_BUS_ID", 'CUDA_VISIBLE_DEVICES': "1"})

import torch
from torch import nn
from matplotlib import pyplot as plt
import numpy as np
import torchvision
from torchvision import transforms
from torchvision import datasets
from ast import literal_eval
import yaml

from torch.autograd import Variable

from models import cb_gan
from utils.datasets import ColoredMNIST, Libero
from utils.utils import get_dataset

# Only one GPU is made visible via CUDA_VISIBLE_DEVICES, so it is addressed as cuda:0.
# Fall back to the CPU when no CUDA device is available so the notebook still runs.
dev = "cuda:0"
device = torch.device(dev if torch.cuda.is_available() else "cpu")


In [2]:
# Load the experiment configuration and shrink the batch sizes so one small
# batch is enough for the visual inspection below.
config_file = 'config/cb_gan/libero_90.yaml'
with open(config_file, 'r') as stream:
    config = yaml.safe_load(stream)

# Use CUDA only if the config asks for it AND a CUDA device is actually available.
use_cuda = bool(torch.cuda.is_available() and config["train_config"]["use_cuda"])
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
# Keep model/data placement consistent with the config's `use_cuda` flag.
device = torch.device(dev if use_cuda else "cpu")

BATCH_SIZE = 5  # small batch for a quick qualitative check
config['dataset']['batch_size'] = BATCH_SIZE
config['dataset']['test_batch_size'] = BATCH_SIZE

In [4]:
# Build the train and test loaders; get_dataset receives the config whose
# batch sizes were overridden above.
(dataloader, test_loader) = get_dataset(config)

In [5]:
# Build the concept-bottleneck GAN from the config and restore trained weights.
CHECKPOINT_PATH = "models/cb_gan_libero_90_15.pt"
model = cb_gan.cbGAN(config)
# map_location lets a checkpoint saved on a GPU load onto whatever `device` is
# (e.g. a CPU-only machine). NOTE: torch.load unpickles the file — only load
# trusted checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()
model = model.to(device)

In [12]:
# Take the first batch from the test loader and move the images to `device`
# as float32. (torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.)
imgs, concepts = next(iter(test_loader))
imgs = imgs.to(device=device, dtype=torch.float32)
print(concepts.size())
# Ground-truth concept vector of the first sample, as plain Python ints.
concept_list = [int(item) for item in concepts[0]]
print(concept_list)

torch.Size([5, 22])
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]


In [13]:
# Encode the batch into latents, then decode twice: once with return_all=True
# to also obtain `logits` (presumably the per-concept logits), and once plainly
# to obtain the generated images. This is pure inference, so skip autograd.
with torch.no_grad():
    _, _, latent = model.enc(imgs)
    print(latent.size())
    fake_data, logits, _, _ = model.dec(latent, return_all=True)
    gen_imgs = model.dec(latent)

torch.Size([5, 64])


In [14]:
# Split the first sample's flat logit vector into pairs: even indices form the
# first column (`pos`), odd indices the second (`neg`). Each pair is then tiled
# across the batch, so every entry of `mod_logits` is a (batch, 2) tensor on
# `device`. This list is later passed to model.dec(..., probs=...).
# NOTE(review): assumes logits[0] interleaves the two values of each concept — confirm against cbGAN.
flat_logits = logits[0].detach().float().flatten()
pos, neg = flat_logits[0::2], flat_logits[1::2]
n_concepts = min(pos.numel(), neg.numel())  # like zip(): drop an unpaired trailing value
pair_logits = torch.stack([pos[:n_concepts], neg[:n_concepts]], dim=1).to(device)

batch_size = latent.size(0)  # one row per sample actually in the batch
# repeat() copies, so each entry can later be edited in place independently.
mod_logits = [pair.unsqueeze(0).repeat(batch_size, 1) for pair in pair_logits]
mod_logits[0].size()

torch.Size([5, 2])

In [15]:
# Intervention: for concept 1 of the first sample in the batch, overwrite its
# (pos, neg) pair with (0, 1); every other entry keeps the decoder's values.
# NOTE(review): model.dec receives these via `probs=`, while the untouched entries
# come from `logits` — confirm which scale the decoder expects.
CONCEPT_IDX, SAMPLE_IDX = 1, 0
mod_logits[CONCEPT_IDX][SAMPLE_IDX, 0] = 0
mod_logits[CONCEPT_IDX][SAMPLE_IDX, 1] = 1
len(mod_logits)

22

In [16]:
# Decode the same latents again, overriding the concept values with the
# intervened `mod_logits`, and save the result alongside the unmodified output.
from torchvision.utils import save_image

with torch.no_grad():
    gen_imgs_latent = model.dec(latent, probs=mod_logits)

# Row 1: decoder output without intervention (gen_imgs).
# Row 2: decoder output after the intervention (gen_imgs_latent).
# Saving both makes the effect of the intervention directly visible.
comparison = torch.cat([gen_imgs.detach(), gen_imgs_latent.detach()], dim=0)
save_image(comparison, "intervention.png", nrow=gen_imgs.size(0), normalize=True)