<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2020_0703four_in_one_network2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 2020年度駒澤大学心理学特講IIIA 

浅川伸一

## One network, many uses

This notebook follows the tutorial: [One neural network, many uses: image captioning, image search, similar images and similar words using one model](https://towardsdatascience.com/one-neural-network-many-uses-build-image-search-image-captioning-similar-words-and-similar-1e22080ce73d) 

Made by [@paraschopra](https://twitter.com/paraschopra)

MIT License.

In [None]:
# ipympl provides the interactive `%matplotlib widget` backend (see the commented-out magic in the imports cell).
!pip install -U ipympl > /dev/null

In [None]:
# Install PyDrive and authenticate so the course files can be fetched from Google Drive by file id.
!pip install -U -q PyDrive

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# 1. Authenticate and create the PyDrive client.
# auth.authenticate_user() runs the Colab login; the application-default credentials
# obtained afterwards are handed to PyDrive.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

### Flickr8k Dataset
- Flickr8k データセットを使用します。本来は [この書類](https://forms.illinois.edu/sec/1713398) から申請し、使用許可を得る必要があります。
- このデータセットでは、約 8000 枚の画像それぞれに 5 つのキャプション（説明文）が付いています。

In [None]:
# 実習のための画像データの ID を入れて，データを入手
download = drive.CreateFile({'id': '1y2P-Z8ZlpEyNbUq2fuDAIcb4Sjp4tyUE'})
download.GetContentFile('caption_datasets.zip')
!unzip caption_datasets.zip

download = drive.CreateFile({'id': '1FyModcTYRiHaoXHU_wloTIz13S6wpIW4'})
download.GetContentFile('Flicker8k_Dataset.zip')
!unzip Flicker8k_Dataset.zip > /dev/null

In [None]:
#https://drive.google.com/open?id=1-bFD13-6GWgDSxgBWQE33vPBLQfKz7U3
download = drive.CreateFile({'id': '1rYWwq26aECq-Xmq3sOQO6oFEheeB69AH'})
download.GetContentFile('models.zip')
!unzip models.zip > /dev/null    

In [None]:
import matplotlib.pyplot as plt
import random
import json
#import ipympl 
%matplotlib inline
#%matplotlib widget

from scipy import ndimage
import numpy as np
from copy import deepcopy
from PIL import Image
import IPython.display
from math import floor

import torch
import torch.nn as nn                     # neural networks
import torch.nn.functional as F           # layers, activations and more
import torch.optim as optim  
import torchvision.transforms.functional as TF
import torchvision
from torchvision import datasets, models, transforms

In [None]:
# Check whether a CUDA GPU is available (the value is shown as the cell output).
is_cuda = torch.cuda.is_available()
is_cuda

In [None]:
# Use the GPU whenever CUDA was detected in the previous cell.
USE_GPU = bool(is_cuda)

In [None]:
# Special tokens. Flickr8KImageCaptionDataset gives ENDWORD id 0 and STARTWORD id 1.
# NOTE(review): PADWORD is not used by any cell shown in this notebook.
ENDWORD = '<END>'
STARTWORD = '<START>'
PADWORD = '<PAD>'
# Images are resized to HEIGHT x WIDTH before being fed to Inception v3 (its expected 299x299 input).
HEIGHT = 299
WIDTH = 299
INPUT_EMBEDDING = 300   # output size of inception's replaced fc layer and of the word embeddings
HIDDEN_SIZE = 300       # GRU hidden-state size in IC_V6
OUTPUT_EMBEDDING = 300  # width of IC_V6's linear layer before the vocabulary projection

# Presumably extracted from caption_datasets.zip / Flicker8k_Dataset.zip above.
CAPTION_FILE = 'caption_datasets/dataset_flickr8k.json'
IMAGE_DIR = 'Flicker8k_Dataset/'

In [None]:
# Fetch a pretrained Inception v3 from the [PyTorch model zoo](https://pytorch.org/docs/stable/torchvision/models.html).
# NOTE(review): `string` appears unused in this notebook.
import string
inception = models.inception_v3(pretrained=True)

In [None]:
# Display the inception model's module tree.
inception

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched
    (parameters that are already frozen are NOT unfrozen).
    """
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False

In [None]:
# Number of input features of inception's final fully connected layer.
num_ftrs = inception.fc.in_features
num_ftrs

In [None]:
# Freeze all pretrained weights FIRST, then replace the final layer. The new nn.Linear
# is created after freezing, so its parameters keep the default requires_grad=True.
# It projects inception's features to INPUT_EMBEDDING dimensions.
set_parameter_requires_grad(inception, True)
num_ftrs = inception.fc.in_features
inception.fc = nn.Linear(num_ftrs,INPUT_EMBEDDING)

In [None]:
# Load the trained inception weights. The load is strict, so the checkpoint must match the
# modified architecture, including the replaced fc layer.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only runtime;
# load_state_dict then copies the tensors onto whatever device `inception` is on.
inception.load_state_dict(torch.load('models/inception_epochs_40.pth', map_location='cpu'))

In [None]:
# Move the image encoder to the GPU when one is available (Module.cuda() moves it in place).
if USE_GPU:
    inception.cuda()

## Class for holding data

In [None]:
class Flickr8KImageCaptionDataset:
    """Flickr8k captions: vocabulary, train/test split and captioning helpers.

    On construction the caption JSON at ``CAPTION_FILE`` is read. Images whose
    ``split`` is ``'train'`` go to ``training_data``; all others go to ``test_data``.
    A word <-> index vocabulary (``w2i`` / ``i2w``) and per-word counts
    (``word_frequency``) are built from the tokens of every caption.
    Index 0 is reserved for ``ENDWORD`` and index 1 for ``STARTWORD``.
    """

    def __init__(self):
        # `with` closes the file (the original json.load(open(...)) leaked the handle).
        with open(CAPTION_FILE, 'r') as caption_file:
            all_data = json.load(caption_file)['images']

        self.training_data = []
        self.test_data = []
        self.w2i = {ENDWORD: 0, STARTWORD: 1}
        self.word_frequency = {ENDWORD: 0, STARTWORD: 0}
        self.i2w = {0: ENDWORD, 1: STARTWORD}
        self.tokens = 2  # current vocabulary size; ids 0/1 are END/START
        self.batch_index = 0  # not read by any method below

        for data in all_data:
            # Any split other than 'train' is treated as test data.
            if data['split'] == 'train':
                self.training_data.append(data)
            else:
                self.test_data.append(data)

            # The vocabulary covers the captions of every split, not only training.
            for sentence in data['sentences']:
                for token in sentence['tokens']:
                    if token in self.w2i:
                        self.word_frequency[token] += 1
                    else:
                        self.w2i[token] = self.tokens
                        self.i2w[self.tokens] = token
                        self.tokens += 1
                        self.word_frequency[token] = 1

    def image_to_tensor(self, filename):
        """Load an image file as a normalized float tensor of shape (1, 3, HEIGHT, WIDTH).

        The image is converted to RGB so grayscale/RGBA files also produce 3 channels.
        It is normalized with the ImageNet channel statistics used by torchvision's pretrained models.
        """
        with Image.open(filename) as img:
            image = img.convert('RGB')
        image = TF.resize(img=image, size=(HEIGHT, WIDTH))
        image = TF.to_tensor(pic=image)
        image = TF.normalize(image, mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
        return torch.unsqueeze(image, 0)

    def return_train_batch(self):
        """Yield one training example per image, in dataset order (batch size is always 1).

        For each image, one of its captions is chosen at random and ``ENDWORD`` is appended.
        Yields ``(image_tensor, target_ids, target_tokens, index)``.
        """
        for index, data in enumerate(self.training_data):
            sentences = data['sentences']
            sentence_index = np.random.randint(len(sentences))
            output_sentence_tokens = deepcopy(sentences[sentence_index]['tokens'])
            output_sentence_tokens.append(ENDWORD)
            image = self.image_to_tensor(IMAGE_DIR + data['filename'])
            yield image, [self.w2i[x] for x in output_sentence_tokens], output_sentence_tokens, index

    def convert_tensor_to_word(self, output_tensor):
        """Return the vocabulary word with the highest score in ``output_tensor``.

        The tensor is moved to the CPU first, so CUDA outputs work too.
        """
        output = F.log_softmax(output_tensor.detach().cpu().squeeze(), dim=0).numpy()
        return self.i2w[np.argmax(output)]

    def convert_sentence_to_tokens(self, sentence):
        """Map a space-separated sentence to token ids and append the ``ENDWORD`` id.

        Raises KeyError for words that are not in the vocabulary.
        """
        converted_tokens = [self.w2i[word] for word in sentence.split(" ")]
        converted_tokens.append(self.w2i[ENDWORD])
        return converted_tokens

    def caption_image_greedy(self, net, image_filename, max_words=15):
        """Caption an image by repeatedly taking the most likely next word.

        No beam search and no temperature. Stops at ``ENDWORD`` or after ``max_words`` words.
        Leaves ``net`` and ``inception`` in eval mode.
        """
        net.eval()
        inception.eval()
        # inception runs inside net.forward, so the image must be on inception's device.
        image_tensor = self.image_to_tensor(image_filename).to(next(inception.parameters()).device)
        hidden = None
        words = []
        input_tensor = torch.tensor(self.w2i[STARTWORD]).type(torch.LongTensor)
        with torch.no_grad():
            for i in range(max_words):
                if i == 0:
                    # The first step feeds the image instead of a word.
                    out, hidden = net(input_tensor, hidden=image_tensor, process_image=True)
                else:
                    out, hidden = net(input_tensor, hidden)

                word = self.convert_tensor_to_word(out)
                if word == ENDWORD:
                    break
                words.append(word)
                input_tensor = torch.tensor(self.w2i[word]).type(torch.LongTensor)

        return ' '.join(words)

    def forward_beam(self, net, hidden, process_image,
                     partial_sentences, sentences, topn_words=5, max_sentences=10):
        """Run a simple beam search and return the final beam of partial captions.

        Args:
            net: the decoder, called as ``net(token, hidden)`` or ``net(token, image, process_image=True)``.
            hidden: for hypotheses that have not started decoding, the image tensor
                (if ``process_image``) or the initial decoder state (otherwise).
            process_image: whether ``hidden`` is an image rather than a decoder state.
            partial_sentences: the starting beam, a list of ``[tokens, log_prob]``.
            sentences: number of finished sentences counted so far.
            topn_words: number of words expanded per hypothesis; also the beam width.
            max_sentences: stop once this many ``ENDWORD`` predictions have been counted.

        Returns:
            The final beam as a list of ``[tokens, log_prob]``. Hypotheses that predict
            ``ENDWORD`` are only counted, not returned. If every expansion of a step
            predicts ``ENDWORD``, the previous beam is returned.
        """
        max_words = 50
        # Each hypothesis carries its own decoder state. Keying states by the last word
        # (as before) let two hypotheses ending in the same word overwrite each other.
        # A state of None means "not started yet": feed `hidden` first.
        beam = [[list(p[0]), p[1], None] for p in partial_sentences]
        with torch.no_grad():
            while sentences < max_sentences and beam:
                if len(beam[-1][0]) > max_words:
                    break
                candidates = []
                for tokens, log_prob, state in beam:
                    token_tensor = torch.tensor(self.w2i[tokens[-1]]).type(torch.LongTensor)
                    if state is None and process_image:
                        out, new_state = net(token_tensor, hidden, process_image=True)
                    else:
                        out, new_state = net(token_tensor, hidden if state is None else state)

                    # Take the top-n next words as children of this hypothesis.
                    word_log_probs = F.log_softmax(out.detach().cpu().squeeze(), dim=0).numpy()
                    for out_index in np.argsort(word_log_probs)[::-1][:topn_words]:
                        word = self.i2w[out_index]
                        if word == ENDWORD:
                            sentences += 1
                        else:
                            candidates.append([tokens + [word],
                                               float(word_log_probs[out_index]) + log_prob,
                                               new_state.detach()])
                if not candidates:
                    break  # every expansion ended; keep the previous beam
                # Keep the top-n candidates by total log-probability. A plain list is used
                # because np.array over ragged [tokens, logp] pairs fails on modern numpy.
                top_indexes = np.argsort([candidate[1] for candidate in candidates])[::-1][:topn_words]
                beam = [candidates[i] for i in top_indexes]
        return [[tokens, log_prob] for tokens, log_prob, _ in beam]

    def caption_image_beam_search(self, net, image_filename, topn_words=10, max_sentences=10):
        """Caption an image with ``forward_beam``; returns ``[[caption, log_prob], ...]``.

        ``STARTWORD`` is stripped from each caption. Leaves ``net`` and ``inception`` in eval mode.
        """
        net.eval()
        inception.eval()
        image_tensor = self.image_to_tensor(image_filename).to(next(inception.parameters()).device)
        partial_sentences = self.forward_beam(net,
                                              image_tensor,
                                              True,
                                              [[[STARTWORD], 0.0]],
                                              0,
                                              topn_words,
                                              max_sentences)
        return [[' '.join(tokens[1:]), log_prob] for tokens, log_prob in partial_sentences]

    def print_beam_caption(self, net, train_filename, num_captions=0):
        """Print up to ``num_captions`` beam-search captions (0 means all) with their log-probabilities."""
        # Use self, not the notebook-global `f`, so this works on any instance.
        beam_sentences = self.caption_image_beam_search(net, train_filename)
        if num_captions == 0:
            num_captions = len(beam_sentences)
        for sentence, log_prob in beam_sentences[:num_captions]:
            print(sentence + " [", log_prob, "]")

## Network class

In [None]:
class IC_V6(nn.Module):
    """GRU caption decoder whose first input is an image embedding.

    Version history:
    V2: fed the image vector directly as the first input and fed generated words back as inputs.
    V3: added an embedding layer between word inputs and the GRU/LSTM.
    """

    def __init__(self, token_dict_size):
        super(IC_V6, self).__init__()
        # Word inputs are embedded to INPUT_EMBEDDING dims; image inputs come from
        # inception's fc layer, which also outputs INPUT_EMBEDDING dims.
        self.embedding_size = INPUT_EMBEDDING
        self.hidden_state_size = HIDDEN_SIZE
        self.token_dict_size = token_dict_size
        self.output_size = OUTPUT_EMBEDDING
        # Not used in forward; kept so existing state_dicts (which contain it) still load strictly.
        self.batchnorm = nn.BatchNorm1d(self.embedding_size)
        self.input_embedding = nn.Embedding(self.token_dict_size, self.embedding_size)
        self.embedding_dropout = nn.Dropout(p=0.22)
        self.gru_layers = 3
        self.gru = nn.GRU(input_size=self.embedding_size, hidden_size=self.hidden_state_size,
                          num_layers=self.gru_layers, dropout=0.22)
        self.linear = nn.Linear(self.hidden_state_size, self.output_size)
        self.out = nn.Linear(self.output_size, token_dict_size)

    def forward(self, input_tokens, hidden, process_image=False, use_inception=True):
        """Run one decoding step.

        Args:
            input_tokens: a single token id (tensor); ignored when ``process_image`` is true.
            hidden: the previous GRU state, or the image when ``process_image`` is true.
                With ``use_inception`` the image is passed through the global ``inception``;
                otherwise ``hidden`` is used directly as the (already embedded) input.
            process_image: start a new caption from an image; the GRU state starts at zeros.

        Returns:
            ``(out, hidden)``: vocabulary scores of shape (1, 1, token_dict_size) and the new GRU state.
        """
        # Follow the module's actual device instead of the global USE_GPU flag, so the
        # model works whether or not .cuda() has been called yet.
        device = next(self.parameters()).device

        if process_image:
            if use_inception:
                image = hidden.to(next(inception.parameters()).device)
                inp = self.embedding_dropout(inception(image)).to(device)
            else:
                inp = hidden.to(device)
            hidden = torch.zeros((self.gru_layers, 1, self.hidden_state_size), device=device)
        else:
            inp = self.embedding_dropout(self.input_embedding(
                input_tokens.view(1).type(torch.LongTensor).to(device)))
        # A state handed back from CPU code (e.g. beam search) is moved onto the model's device.
        hidden = hidden.view(self.gru_layers, 1, -1).to(device)
        inp = inp.view(1, 1, -1)  # (seq_len=1, batch=1, features)
        out, hidden = self.gru(inp, hidden)
        out = self.out(self.linear(out))
        return out, hidden

In [None]:
# Read the Flickr8k caption JSON, split train/test and build the vocabulary.
f = Flickr8KImageCaptionDataset()

In [None]:
# Decoder with one output per vocabulary entry (f.tokens is the vocabulary size).
net = IC_V6(f.tokens)

In [None]:
# Load the trained decoder weights. map_location='cpu' lets a checkpoint saved from a GPU
# load on a CPU-only runtime; load_state_dict copies the tensors onto net's current device.
net.load_state_dict(torch.load('models/epochs_40_loss_2_841_v6.pth', map_location='cpu'))

In [None]:
# Move the decoder and then the image encoder to the GPU when one is available.
if USE_GPU:
    net.cuda()
    inception.cuda()

In [None]:
# Evaluation mode (dropout disabled) before captioning.
net.eval()

In [None]:
# Caption a random training image with greedy and beam search, then display the image.
random_train_index = np.random.randint(len(f.training_data))
train_example = f.training_data[random_train_index]
train_filename = IMAGE_DIR + train_example['filename']
print("Original caption: ", train_example['sentences'][0]['raw'])
print("")
print("Greedy caption:", f.caption_image_greedy(net, train_filename))
print("")
print("Beam caption:")
f.print_beam_caption(net, train_filename)

IPython.display.Image(filename=train_filename)

## Train the network

In [None]:
# Cross-entropy without reduction; the training loop sums it over a caption and divides by the caption length.
l = torch.nn.CrossEntropyLoss(reduction='none')

In [None]:
# Adam over the decoder's parameters only; inception (including its replaced fc layer) is not optimized here.
o = optim.Adam(net.parameters(), lr=0.0001)

In [None]:
# Fine-tune the caption decoder `net` with teacher forcing. `inception` stays in eval
# mode, and only net's parameters are in the optimizer `o`.
epochs = 5  # the original setting of 20 was overridden to 5 here; raise it for a longer run
REPORT_PROBABILITY = 0.002  # ~0.2% of steps print sample captions and save checkpoints
inception.eval()
net.train()
loss_so_far = 0.0
total_samples = len(f.training_data)
image_device = next(inception.parameters()).device

for epoch in range(epochs):
    for (image_tensor, tokens, _, index) in f.return_train_batch():
        o.zero_grad()
        net.zero_grad()

        # inception runs inside net.forward, so the image must live on its device.
        image_tensor = image_tensor.to(image_device)
        loss = 0.
        input_tensor = torch.tensor(f.w2i[STARTWORD])
        for step, token in enumerate(tokens):
            if step == 0:
                # The first step feeds the image instead of a word.
                out, hidden = net(input_tensor, image_tensor, process_image=True)
            else:
                out, hidden = net(input_tensor, hidden)

            out = out.squeeze().view(1, -1)
            class_label = torch.tensor(token, device=out.device).view(1)
            loss += l(out, class_label)
            # Teacher forcing: the ground-truth token is the next input.
            input_tensor = torch.tensor(token)
        loss = loss / len(tokens)
        loss.backward()
        o.step()
        loss_so_far += loss.detach().item()
        if np.random.rand() < REPORT_PROBABILITY:
            print("Epoch: ", epoch,
                  ", index: ", index,
                  " loss: ", round(loss.detach().item(), 3),
                  " | running avg loss: ", round(loss_so_far / ((epoch * total_samples) + (index + 1)), 3))
            torch.save(net.state_dict(), 'models/running_save_v6.pth')
            # Save inception's own weights (this file previously received net's weights).
            torch.save(inception.state_dict(), 'models/running_inception_save_v6.pth')
            net.eval()

            # Caption the current training image as a qualitative check.
            train_filename = IMAGE_DIR + f.training_data[index]['filename']
            print("Original caption: ")
            for caption in f.training_data[index]['sentences']:
                print(caption['raw'].lower())
            print("")
            print("Greedy caption:", f.caption_image_greedy(net, train_filename))
            print("")
            print("Beam caption:")
            f.print_beam_caption(net, train_filename, 3)
            with Image.open(train_filename, 'r') as pil_im:
                plt.figure()
                plt.imshow(np.asarray(pil_im))
            plt.show()
            net.train()

    print("\n\n")
    print("==== EPOCH DONE. === ")
    print("\n\n")

## Helper functions for visualizations

In [None]:
def rand_cmap(nlabels, type='bright', first_color_black=True, last_color_black=False, verbose=True):
    """
    Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks
    :param nlabels: Number of labels (size of colormap)
    :param type: 'bright' for strong colors, 'soft' for pastel colors
    :param first_color_black: Option to use first color as black, True or False
    :param last_color_black: Option to use last color as black, True or False
    :param verbose: Prints the number of labels and shows the colormap. True or False
    :return: colormap for matplotlib, or None (after printing a message) for an invalid ``type``
    """
    # Validate first, so an invalid ``type`` does not even need matplotlib.
    if type not in ('bright', 'soft'):
        print ('Please choose "bright" or "soft" for type')
        return

    from matplotlib.colors import LinearSegmentedColormap
    import colorsys
    import numpy as np

    if verbose:
        print('Number of labels: ' + str(nlabels))

    if type == 'bright':
        # Random HSV colors with a high value (0.9-1) look bright; convert them to RGB.
        randHSVcolors = [(np.random.uniform(low=0.0, high=1),
                          np.random.uniform(low=0.2, high=1),
                          np.random.uniform(low=0.9, high=1)) for _ in range(nlabels)]
        randRGBcolors = [colorsys.hsv_to_rgb(h, s, v) for h, s, v in randHSVcolors]
    else:
        # Soft pastel colors: restrict every RGB channel to a light range.
        # (`range`, not Python 2's `xrange`, which is a NameError on Python 3.)
        low = 0.6
        high = 0.95
        randRGBcolors = [(np.random.uniform(low=low, high=high),
                          np.random.uniform(low=low, high=high),
                          np.random.uniform(low=low, high=high)) for _ in range(nlabels)]

    if first_color_black:
        randRGBcolors[0] = [0, 0, 0]
    if last_color_black:
        randRGBcolors[-1] = [0, 0, 0]

    random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels)

    # Display colorbar
    if verbose:
        from matplotlib import colors, colorbar
        from matplotlib import pyplot as plt
        fig, ax = plt.subplots(1, 1, figsize=(15, 0.5))

        bounds = np.linspace(0, nlabels, nlabels + 1)
        norm = colors.BoundaryNorm(bounds, nlabels)

        colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None,
                              boundaries=bounds, format='%1i', orientation=u'horizontal')

    return random_colormap

## Visualizations (starting with word embeddings)

In [None]:
frequency_threshold = 50  # the word should have appeared at least this many times for us to visualize

all_word_embeddings = []
all_words = []

# Look up embeddings on the embedding layer's own device and bring the results back to the CPU,
# so this also works after net.cuda(). No autograd graph is needed for visualization.
embedding_device = net.input_embedding.weight.device
with torch.no_grad():
    for word, frequency in f.word_frequency.items():
        if frequency >= frequency_threshold:
            index_tensor = torch.tensor(f.w2i[word], device=embedding_device)
            all_word_embeddings.append(net.input_embedding(index_tensor).cpu().numpy())
            all_words.append(word)

In [None]:
# Number of words that passed the frequency threshold.
len(all_words)

In [None]:
from sklearn.manifold import TSNE
# 2-D t-SNE for the word embeddings (fixed random_state for reproducibility).
tsne = TSNE(n_components=2, random_state=0)

In [None]:
# Project the selected word embeddings to 2-D.
X_2d = tsne.fit_transform(all_word_embeddings)

In [None]:
# Random 10-color colormap, shown as a colorbar. NOTE(review): `new_cmap` is not used by the plots below.
new_cmap = rand_cmap(10, type='bright', first_color_black=True, last_color_black=False, verbose=True)

In [None]:
# Interactive t-SNE scatter of the word embeddings: hovering over a point shows its word(s).
# NOTE: the callbacks only fire with an interactive backend such as `%matplotlib widget`
# (ipympl); with `%matplotlib inline` the figure is static.
fig, ax = plt.subplots(figsize=(7, 7))
sc = ax.scatter(X_2d[:, 0], X_2d[:, 1])
ax.set_title('t-SNE of word embeddings')

annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->", color='red'))
annot.set_visible(False)


def update_annot(ind):
    """Point the annotation at the first hovered point and list every hovered word."""
    pos = sc.get_offsets()[ind["ind"][0]]
    annot.xy = pos
    annot.set_text(" ".join(all_words[n] for n in ind["ind"]))
    annot.get_bbox_patch().set_facecolor('white')
    annot.get_bbox_patch().set_alpha(0.9)


def hover(event):
    """Show the annotation while the mouse is over a point; hide it otherwise."""
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = sc.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        elif vis:
            annot.set_visible(False)
            fig.canvas.draw_idle()


def onpick(event):
    """Label a picked point with its word(s). Needs `picker=True` on the scatter and the
    'pick_event' connection below. (This previously used an undefined `y`.)"""
    ind = event.ind
    print(ind)
    annot.xy = (event.mouseevent.xdata, event.mouseevent.ydata)
    annot.set_text(" ".join(all_words[n] for n in ind))
    annot.set_visible(True)
    ax.figure.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)
#fig.canvas.mpl_connect('pick_event', onpick)
plt.show()

## Find the top 5 closest words by cosine similarity

In [None]:
from scipy import spatial
def return_cosine_sorted(target_word_embedding, embeddings=None, words=None):
    """Rank words by cosine similarity to ``target_word_embedding``.

    Args:
        target_word_embedding: 1-D vector to compare against.
        embeddings: candidate vectors; defaults to the global ``all_word_embeddings``.
        words: labels aligned with ``embeddings``; defaults to the global ``all_words``.

    Returns:
        An (N, 2) numpy array of ``[word, cosine]`` rows sorted by decreasing similarity.
        Because np.vstack mixes strings and floats, every entry is a string.
    """
    if embeddings is None:
        embeddings = all_word_embeddings
    if words is None:
        words = all_words
    cosines = [1 - spatial.distance.cosine(target_word_embedding, embedding) for embedding in embeddings]
    sorted_indexes = np.argsort(cosines)[::-1]
    return np.vstack((np.array(words)[sorted_indexes], np.array(cosines)[sorted_indexes])).T

In [None]:
# Peek at the first ten visualized words.
all_words[:10]
#print(all_words)

In [None]:
def return_similar_words(word, top_n=5):
    """Return the ``top_n`` words most cosine-similar to ``word`` as ``[word, cosine]`` rows.

    The best match is skipped because it is normally ``word`` itself.

    Raises:
        KeyError: if ``word`` is not among the visualized (frequent enough) words.
    """
    embedding = return_embedding(word)
    if embedding is None:
        # Fail clearly instead of deep inside scipy with a None vector.
        raise KeyError('{!r} is not among the visualized words'.format(word))
    return return_cosine_sorted(embedding)[1:top_n + 1]

In [None]:
def return_embedding(word, words=None, embeddings=None):
    """Return the embedding of ``word``, or None if it is not among the visualized words.

    Args:
        word: exact word to look up.
        words: vocabulary to search; defaults to the global ``all_words``.
        embeddings: vectors aligned with ``words``; defaults to the global ``all_word_embeddings``.
    """
    if words is None:
        words = all_words
    if embeddings is None:
        embeddings = all_word_embeddings
    if word not in words:
        return None
    # Exact match. The former substring test (`word in s`) could return the embedding
    # of a different word that merely contains `word` (e.g. 'in' matching 'climbing').
    return embeddings[list(words).index(word)]

In [None]:
def return_analogy(source_word_1, source_word_2, target_word_1, top_n=5):
    """Solve 'source_word_1 : source_word_2 :: target_word_1 : ?' with vector arithmetic.

    The query vector is ``target_word_1 + (source_word_2 - source_word_1)``. Returns the
    ``top_n`` most cosine-similar words (skipping the best match), or 0 when any of the
    three words has no embedding.
    """
    looked_up = [return_embedding(w) for w in (source_word_1, source_word_2, target_word_1)]
    if any(vector is None for vector in looked_up):
        return 0

    em_sw_1, em_sw_2, em_tw_1 = looked_up
    query = em_tw_1 + (em_sw_2 - em_sw_1)
    return return_cosine_sorted(query)[1:top_n + 1]

In [None]:
# Five words whose embeddings are closest to 'boy'.
return_similar_words('boy')

In [None]:
# Analogy: green is to grass as red is to ? (query = red + grass - green).
return_analogy('green', 'grass', 'red')

## Visualize the image embeddings

In [None]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar ``(rho, phi)``, with phi in radians."""
    radius = np.sqrt(x ** 2 + y ** 2)
    angle = np.arctan2(y, x)
    return radius, angle

def pol2cart(rho, phi):
    """Convert polar ``(rho, phi)`` (phi in radians) to Cartesian ``(x, y)``."""
    return rho * np.cos(phi), rho * np.sin(phi)

In [None]:
import itertools
# Embed every training image with inception; each stored embedding keeps the model's (1, N) output shape.
all_image_embeddings = []
all_image_filenames = []

# Eval mode gives deterministic outputs. In training mode, torchvision's Inception v3 returns a
# tuple that includes auxiliary logits, on which .cpu()/.numpy() would fail.
inception.eval()
inception_device = next(inception.parameters()).device
with torch.no_grad():
    for data in f.training_data:
        image_tensor = f.image_to_tensor(IMAGE_DIR + data['filename']).to(inception_device)
        all_image_embeddings.append(inception(image_tensor).cpu().numpy())
        all_image_filenames.append(data['filename'])

In [None]:
# Shallow copies used by the image t-SNE plot and its hover callback.
all_image_embeddings_temp = list(all_image_embeddings)
all_image_filenames_temp = list(all_image_filenames)

In [None]:
from matplotlib.offsetbox import (TextArea, 
                                  DrawingArea, 
                                  OffsetImage,
                                  AnnotationBbox)

In [None]:
from sklearn.manifold import TSNE
# Separate 2-D t-SNE instance for the image embeddings (fixed random_state).
tsne_images = TSNE(n_components=2, random_state=0)

In [None]:
# Project the image embeddings to 2-D with the image-specific t-SNE instance
# (previously the word t-SNE object `tsne` was reused). np.squeeze turns N arrays of shape (1, D) into (N, D).
X_2d = tsne_images.fit_transform(np.squeeze(all_image_embeddings_temp))

In [None]:
# Interactive t-SNE scatter of the image embeddings: hovering over a point shows up to
# MAX_HOVER_IMAGES thumbnails arranged on a circle around it.
# NOTE: the callbacks only fire with an interactive backend such as `%matplotlib widget`
# (ipympl); with `%matplotlib inline` the figure is static.
MAX_HOVER_IMAGES = 4   # at most this many thumbnails per hover
THUMBNAIL_RADIUS = 10  # distance (data units) from the point to each thumbnail centre

fig, ax = plt.subplots(figsize=(10, 10))
sc = ax.scatter(X_2d[:, 0], X_2d[:, 1])
ax.set_title('t-SNE of image embeddings')
annot = ax.annotate("", xy=(0, 0), xytext=(20, 20), textcoords="offset points",
                    bbox=dict(boxstyle="round", fc="w"),
                    arrowprops=dict(arrowstyle="->", color='red'))
annot.set_visible(False)


def remove_all_images():
    """Remove every thumbnail (AnnotationBbox) previously added to ``ax``."""
    for obj in ax.findobj(match=AnnotationBbox):
        obj.remove()


def update_annot(ind):
    """Move the annotation to the hovered point and draw thumbnails of the hovered images around it."""
    pos = sc.get_offsets()[ind["ind"][0]]
    annot.xy = pos
    # Motion events fire repeatedly while the mouse stays over the same point; clear the
    # previous thumbnails so they do not pile up.
    remove_all_images()

    num_images = min(len(ind["ind"]), MAX_HOVER_IMAGES)
    radians_offset = 2 * np.pi / num_images
    for i in range(num_images):
        hovered_filename = IMAGE_DIR + all_image_filenames_temp[ind["ind"][i]]
        with Image.open(hovered_filename, 'r') as pil_im:
            arr_img = np.asarray(pil_im)
        imagebox = OffsetImage(arr_img, zoom=0.3)
        offset = pol2cart(THUMBNAIL_RADIUS, i * radians_offset)
        new_xy = (pos[0] + offset[0], pos[1] + offset[1])
        ax.add_artist(AnnotationBbox(imagebox, new_xy))
    annot.get_bbox_patch().set_facecolor('white')
    annot.get_bbox_patch().set_alpha(0.9)


def hover(event):
    """Show thumbnails while the mouse is over a point; hide them otherwise."""
    vis = annot.get_visible()
    if event.inaxes == ax:
        cont, ind = sc.contains(event)
        if cont:
            update_annot(ind)
            annot.set_visible(True)
            fig.canvas.draw_idle()
        elif vis:
            annot.set_visible(False)
            remove_all_images()
            fig.canvas.draw_idle()


fig.canvas.mpl_connect("motion_notify_event", hover)
#fig.canvas.mpl_connect('pick_event', onpick)
plt.show()

## Similar images to a given image

In [None]:
def plot_image(filename):
    """Display the image stored at ``filename`` in a new matplotlib figure."""
    # Read the pixels inside `with` so the file handle is closed right away.
    with Image.open(filename, 'r') as pil_im:
        pixels = np.asarray(pil_im)
    plt.figure()
    plt.imshow(pixels)
    plt.show()

In [None]:
def return_similar_images(image_filename, top_n=5):
    """Return the ``top_n`` training images most similar to ``image_filename`` as ``[filename, cosine]`` rows.

    The best match is skipped; it is normally the query image itself when that image belongs to the training set.
    """
    query_embedding = return_embedding_image(image_filename)
    ranked = return_cosine_sorted_image(query_embedding)
    return ranked[1:top_n + 1]

In [None]:
def return_cosine_sorted_image(target_image_embedding, embeddings=None, filenames=None):
    """Rank images by cosine similarity of their embeddings to ``target_image_embedding``.

    Args:
        target_image_embedding: query embedding (any shape with one non-trivial axis).
        embeddings: candidate embeddings; defaults to the global ``all_image_embeddings``.
        filenames: labels aligned with ``embeddings``; defaults to the global ``all_image_filenames``.

    Returns:
        An (N, 2) string array of ``[filename, cosine]`` rows sorted by decreasing similarity.
    """
    if embeddings is None:
        embeddings = all_image_embeddings
    if filenames is None:
        filenames = all_image_filenames
    # The stored embeddings have shape (1, D). Recent SciPy rejects non-1-D input to
    # distance.cosine, so flatten both sides.
    target = np.ravel(target_image_embedding)
    cosines = [1 - spatial.distance.cosine(target, np.ravel(embedding)) for embedding in embeddings]
    sorted_indexes = np.argsort(cosines)[::-1]
    return np.vstack((np.array(filenames)[sorted_indexes], np.array(cosines)[sorted_indexes])).T

In [None]:
def return_embedding_image(image_filename):
    """Return inception's embedding of ``image_filename`` as a squeezed numpy array.

    No autograd graph is built. The image is moved to inception's device and the result
    back to the CPU, so this also works when the model is on a GPU.
    """
    device = next(inception.parameters()).device
    with torch.no_grad():
        embedding = inception(f.image_to_tensor(image_filename).to(device))
    return embedding.cpu().numpy().squeeze()