In [1]:
%matplotlib inline
from densenet import densenet121_attn
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from scipy.misc import imresize
import glob
from visualize import make_dot

Using TensorFlow backend.


## Visualize masks

In [2]:
def plot_data_img(img):
    # Create figure and axes
    fig,ax = plt.subplots(1)

    # Display the image
    ax.imshow(img)
    
    plt.show()
    
def plot_mask(img, m1, m2, m3, title=None, filename=None):
    """
    Plots the given torch tensor as grayscale mask
    """
    _ = plt.figure(figsize = (10,10))
    _ = plt.subplot(2,2,1)
    _ = plt.axis('off')
    if title is not None:
        _ = plt.suptitle(title, verticalalignment='bottom', y=0.9)
    _ = plt.imshow(img)
    _ = plt.subplot(2,2,2)
    _ = plt.axis('off')
    _ = plt.imshow(img)
    m = imresize(m1.data.numpy()[0], (299,299))
    _ = plt.imshow(m, 'jet', interpolation='none', alpha=0.5)
    _ = plt.subplot(2,2,3)
    _ = plt.axis('off')
    _ = plt.imshow(img)
    m = imresize(m2.data.numpy()[0], (299,299))
    _ = plt.imshow(m, 'jet', interpolation='none', alpha=0.5)
    _ = plt.subplot(2,2,4)
    _ = plt.axis('off')
    _ = plt.imshow(img)
    m = imresize(m3.data.numpy()[0], (299,299))
    _ = plt.imshow(m, 'jet', interpolation='none', alpha=0.5)
    _ = plt.show()
    if filename is not None:
        _ = plt.savefig(filename+'_mask.png')
    
def get_img_array(path, target_dim=(299,299)):
    """
    Given path of image, returns it's numpy array
    """
    return img_to_array(load_img(path, target_size=target_dim))/255.0

In [None]:
weights = '../saved/den_121_attn_cub_crop_13_0.80375.pth'
d = densenet121_attn(weights=weights, mask_only=True)

In [None]:
# predict mask for random validation images
# and visualize it
data_dir = '/home/birdsnap/CUB_200_2011/cropped_test/validation'
base_save_path = '../vis/1_simpleFCmaskPred/'
folders = [file for file in glob.glob(data_dir+'*/*')]
for folder in folders[:10]:
    files = [file for file in glob.glob(folder + '/*')]
    file_sample = np.random.choice(files, 3, replace=False)
    sample = [get_img_array(i) for i in file_sample]
    img_v = [torch.autograd.Variable(torch.Tensor(i)).view(1,299, 299, 3) for i in sample]
    img_v = torch.cat(img_v, 0)
    img_v = img_v.permute(0, 3, 2, 1)
    m1, m2, m3 = d(img_v)
    for i in range(3):
        filename = folder.split('/')[-1] + '_' + file_sample[i].split('/')[-1]
        full_path = base_save_path + filename
        plot_mask(sample[i], m1[i], m2[i], m3[i], title=filename, filename=full_path)

In [None]:
# plot sum of all masks
m4 = m1 + m2 + m3
i = 0
plot_mask(sample[i], m1[i], m2[i], m4[i], title="Sum")

In [None]:
# visualize computation graph
make_dot(m1)

## Upsampling Component

In [None]:
class Upsampler(torch.nn.Module):
    def __init__(self, target_dim=(299,299), mode='bilinear'):

        super(Upsampler, self).__init__()
        self.h = target_dim[0]
        self.w = target_dim[1]

        self.upsampler = torch.nn.Upsample(size=target_dim, mode=mode)
        
    def img_crop(self, x, tl_x, tl_y, br_x, br_y, target_size=(299,299)):
        """
        Takes tensor of dimension x: (3, 299, 299) and
        f: (s, 4) containing tl_x, tl_y, br_x, br_y in that
        order. Returns upsampled crops
        """
        # note that the following step is not 
        # a part of the network, taking values
        # out of the tensor here
        tl_x, tl_y, br_x, br_y = int(tl_x.data[0]), int(tl_y.data[0]), int(br_x.data[0]), int(br_y.data[0])
        #cropped = img_set_zero(x, tl_x, tl_y, br_x, br_y)
        #cropped = cropped[:,tl_x:br_x,tl_y:br_y].contiguous()
        cropped = x[:,tl_x:br_x,tl_y:br_y].contiguous()
        cropped = cropped.view(1, 3, cropped.size(1), cropped.size(2))
        bi = torch.nn.Upsample(size=(299,299), mode='bilinear')
        upped = bi(cropped).view(3, 299, 299)
        return upped

    def img_crops(self, x, f):
        """
        x: (3, 299, 299)
        f: (g, 4) tl_x, tl_y, br_x, br_y
        returns cropped and upsampled same as x.size
        """
        out = []
        for f_i in torch.unbind(f):
            out.append(self.img_crop(x, f_i[0], f_i[1], f_i[2], f_i[3]))
        out = torch.stack(out, 0)
        return out

    def imgs_crops(self, x, f):
        """
        x: (s, 3, 299, 299)
        f: (s, g, 4) tl_x, tl_y, br_x, br_y
        returns cropped and upsampled same as x.size
        """
        out = []
        for i,x_i in enumerate(torch.unbind(x)):
            out.append(self.img_crops(x_i, f[i]))
        out = torch.stack(out, 0)
        return out

    def forward(self, x, f):
        return self.imgs_crops(img_v, f)


In [None]:
# load two images into pyt vars
img_path = '/home/birdsnap/CUB_200_2011/cropped_test/validation/200.Common_Yellowthroat/Common_Yellowthroat_0092_190573.jpg'
img = get_img_array(img_path)
print img.shape
plt.imshow(img)
plt.show()
img_v1 = torch.autograd.Variable(torch.Tensor(img).view(1,299,299,3)).permute(0, 3, 1, 2)

img_path = '/home/birdsnap/CUB_200_2011/cropped_test/validation/200.Common_Yellowthroat/Common_Yellowthroat_0075_190900.jpg'
img = get_img_array(img_path)
plot_data_img(img)
img_v2 = torch.autograd.Variable(torch.Tensor(img).view(1,299,299,3)).permute(0, 3, 1, 2)
img_v = torch.cat([img_v1, img_v2], 0)
print img_v.size()


# 2 masks per image, 2 images: (s=2, g=2, 4)
# 4 points in order: tl_x, tl_y, br_x, br_y

f = torch.autograd.Variable(torch.Tensor([[[0, 20, 150, 150], [200, 245, 298, 298]], 
                                          [[50, 35, 275, 200], [200, 150, 280, 280]]]))
print f.size()


In [None]:
up = Upsampler()
cropped = up(img_v, f)
print cropped.size()

In [None]:
for i in range(2):
    for j in range(2):
        plt.imshow(cropped[i][j].permute(1, 2, 0).data.numpy())
        plt.show()

In [None]:
torch.cuda.manual_seed_all(1)
j = torch.autograd.Variable(torch.randn(3, 2, 4)*10)
print j

In [None]:
i = j.clone()
i[:, :, 2] = i[:, :, 2]*1000

print i
i[:, :, 2] = torch.clamp(i[:, :, 2], max=5)
print i

In [None]:
i = i.long()
print i

In [None]:
import time
a = time.time()

In [None]:
b = time.time() -a

In [None]:
print b

In [11]:
import math
def save_plot(imgs, name):
    """ 
    Takes a (299, 299, 3)
    Glimpses: (3, 3, 299, 299)
    """
    plt.figure(figsize=(10,10))
    n = len(imgs)
    x = math.ceil(n/2)
    y = 2

    for i, img in enumerate(imgs):
        plt.subplot(x,y,i+1)
        plt.axis('off')
        plt.imshow(img)
    plt.savefig(name)

def save_glimpses(x, glimpses, epoch, path, exp_name):
    """ 
    x: (s, 3, 299, 299)
    glimpses: (s, g, 3, 299, 299)
    """
    glimpses = glimpses.permute(0, 1, 3, 4, 2)
    glimpses = glimpses.data.numpy()
    x = x.permute(0, 2, 3, 1).data.numpy()
    for s in range(x.shape[0]):
        name = "{}/{}_{}.png".format(path, exp_name, epoch)
        img = x[s]
        save_plot([img, glimpses[s][0], glimpses[s][1], glimpses[s][2]], name)


In [12]:
path = 'visual_attn/vis/2_RACNN'
x = torch.autograd.Variable(torch.randn(2,3,299,299))
i = torch.autograd.Variable(torch.randn(2,2,3,299,299))
save_glimpses(x, i, 3, path, 'test')

IndexError: index 2 is out of bounds for axis 0 with size 2