In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
import torch
import random
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")


In [2]:
# Seed every RNG the notebook touches, not just Python's `random`:
# DataLoader shuffling draws from torch's global generator, and DataLoader
# worker processes are re-seeded from a base seed produced by that generator.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class TwoImageDataset(Dataset):
    """Pairs a random sketch with a random real-world image of the same category.

    The requested index only selects the category (``idx // class_size``). The
    sketch and the real-world image are then drawn independently at random
    from that category, so reading the same index twice generally yields
    different pairs.
    """

    def __init__(self, quickdraw_dir, realworld_dir, fruit_category, image_size = 255, class_size = 5000, transform = True):
        """
        Args:
            quickdraw_dir (string): Path of the ``.npz`` archive holding the sketch images,
                one array per category name.
            realworld_dir (string): Path of the ``.npz`` archive holding the real world images,
                one array per category name.
            fruit_category: list of category names; a name's position is its class index.
            image_size: size handed to ``transforms.Resize`` for both image kinds.
            class_size: Num of images addressed in each category. Random picks are drawn
                from ``range(class_size)``, so every category array is assumed to hold
                at least this many images.
            transform (bool, optional): When truthy (the default) images are converted with
                ToPILImage -> Resize -> ToTensor and the label with ToTensor; when falsy the
                raw numpy arrays are returned.
        """
        self.quickdraw_dir = quickdraw_dir
        self.realworld_dir = realworld_dir
        # This used to pick up the imported `skimage.transform` module by accident
        # (always truthy); it is now an explicit flag with the same default effect.
        self.transform = transform
        self.fruit_category = fruit_category
        self.class_size = class_size
        self.quickdraw_data_dict = self._load_npz(quickdraw_dir)
        self.realworld_data_dict = self._load_npz(realworld_dir)
        self.image_size = image_size

        self.transform_img = transforms.Compose([transforms.ToPILImage(),
                                                transforms.Resize(image_size),
                                                transforms.ToTensor()])
        self.transform_label = transforms.Compose([transforms.ToTensor()])

    @staticmethod
    def _load_npz(path):
        """Read every array of an ``.npz`` archive into a dict, then close the file."""
        with np.load(path) as archive:
            return {name: archive[name] for name in archive.files}

    def __len__(self):
        """Nominal dataset length: ``class_size`` entries for each category."""
        return self.class_size * len(self.fruit_category)

    def __getitem__(self, idx):
        """Return a dict with keys ``'realworld'``, ``'sketch'`` and ``'label'``.

        ``label`` is a one-hot column of shape ``(n_categories, 1)`` before the
        optional ToTensor conversion.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # The index decides the category only; the images inside it are random.
        class_index = int(idx // self.class_size)
        category = self.fruit_category[class_index]
        category_idx_real = random.choice(range(self.class_size))
        category_idx_sketch = random.choice(range(self.class_size))

        label = np.zeros((len(self.fruit_category), 1))
        label[class_index] = 1

        quickdraw_ary = self.quickdraw_data_dict[category][category_idx_sketch]
        realimage_ary = self.realworld_data_dict[category][category_idx_real]

        sample = {'realworld': realimage_ary, 'sketch': quickdraw_ary, 'label': label}
        if self.transform:
            sample['realworld'] = self.transform_img(sample['realworld'])
            sample['sketch'] = self.transform_img(sample['sketch'])
            sample['label'] = self.transform_label(sample['label'])
        return sample
  

In [3]:
# Category names; a name's position in this list is its class index.
QURIES = [
    'banana',
    'strawberry',
    'pear',
    'watermelon',
    'grapes',
    'pineapple',
    'apple',
]

# Locations of the packed sketch / real-world archives for each split.
train_quickdraw_dir = './data/compressed_quickdraw_train.npz'
train_realworld_dir = './data/compressed_realworld_train.npz'
test_quickdraw_dir = './data/compressed_quickdraw_test.npz'
test_realworld_dir = './data/compressed_realworld_test.npz'

# Paired datasets: 4000 images addressed per category for training, 1000 for testing.
train_both = TwoImageDataset(
    quickdraw_dir=train_quickdraw_dir,
    realworld_dir=train_realworld_dir,
    fruit_category=QURIES,
    image_size=255,
    class_size=4000,
)
test_both = TwoImageDataset(
    quickdraw_dir=test_quickdraw_dir,
    realworld_dir=test_realworld_dir,
    fruit_category=QURIES,
    image_size=255,
    class_size=1000,
)


In [4]:
def convbn(in_channels, out_channels, kernel_size, stride, padding, bias):
    """Build a Conv2d -> BatchNorm2d -> in-place ReLU stack.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution (and normalised).
        kernel_size, stride, padding, bias: forwarded unchanged to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` holding the three modules in that order.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    norm = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)

class CNN(nn.Module):
    """Four-stage convolutional classifier that also exposes its pooled features.

    Each stage is a stride-2 padded ``convbn`` followed by a stride-1 unpadded
    ``convbn`` of the same width (64, 128, 192, 256); the stack ends with global
    average pooling, dropout and a linear head.
    """

    def __init__(self, n_channels=3, n_classes=7, dropout=0.1):
        super().__init__()
        widths = (64, 128, 192, 256)
        fan_ins = (n_channels,) + widths[:-1]

        # All strided convs are instantiated before the stride-1 convs so the
        # order in which parameters get initialised stays the same.
        strided = [
            convbn(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=True)
            for c_in, c_out in zip(fan_ins, widths)
        ]
        same_width = [
            convbn(width, width, kernel_size=3, stride=1, padding=0, bias=True)
            for width in widths
        ]

        # Interleave them: strided conv, then its same-width conv, stage by stage.
        stack = []
        for down, refine in zip(strided, same_width):
            stack.extend((down, refine))
        stack.append(nn.AdaptiveAvgPool2d((1,1)))

        self.layers = nn.Sequential(*stack)
        self.nn = nn.Linear(widths[-1], n_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``(logits, feats)`` where ``feats`` is the flattened pooled feature."""
        pooled = self.layers(x)
        feats = pooled.flatten(1)
        logits = self.nn(self.dropout(feats))
        return logits, feats
    
class CNN2(nn.Module):
    """Three-stage convolutional classifier that also exposes its pooled features.

    Each stage is a stride-2 padded ``convbn`` followed by a stride-1 unpadded
    ``convbn`` of the same width (64, 128, 256); the stack ends with global
    average pooling, dropout and a linear head.
    """

    def __init__(self, n_channels=3, n_classes=7, dropout=0.1):
        super().__init__()
        widths = (64, 128, 256)
        fan_ins = (n_channels,) + widths[:-1]

        # All strided convs are instantiated before the stride-1 convs so the
        # order in which parameters get initialised stays the same.
        strided = [
            convbn(c_in, c_out, kernel_size=3, stride=2, padding=1, bias=True)
            for c_in, c_out in zip(fan_ins, widths)
        ]
        same_width = [
            convbn(width, width, kernel_size=3, stride=1, padding=0, bias=True)
            for width in widths
        ]

        # Interleave them: strided conv, then its same-width conv, stage by stage.
        stack = []
        for down, refine in zip(strided, same_width):
            stack.extend((down, refine))
        stack.append(nn.AdaptiveAvgPool2d((1,1)))

        self.layers = nn.Sequential(*stack)
        self.nn = nn.Linear(widths[-1], n_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Return ``(logits, feats)`` where ``feats`` is the flattened pooled feature."""
        pooled = self.layers(x)
        feats = pooled.flatten(1)
        logits = self.nn(self.dropout(feats))
        return logits, feats

In [5]:
def contrastive_loss(features, labels, temperature):
    """KL divergence between similarity and same-label distributions.

    For every sample, the row of pairwise cosine similarities and the row of
    same-label indicators are both divided by ``temperature`` and turned into
    distributions with a row-wise softmax; the loss is
    ``KL(target || predicted)`` with ``batchmean`` reduction.

    Args:
        features: 2-D tensor ``(batchsize, feature_len)``.
        labels: 1-D tensor of ``batchsize`` class ids.
        temperature: positive scalar dividing both matrices before the softmax.

    Returns:
        Scalar loss tensor.
    """
    # Entry (i, j) is the cosine similarity between sample i and sample j.
    sim_matrix = F.cosine_similarity(features.unsqueeze(1), features.unsqueeze(0), dim=2)

    # Entry (i, j) is 1 when samples i and j share a label, otherwise 0.
    target_matrix = (labels.unsqueeze(1) == labels.unsqueeze(0)).to(sim_matrix.dtype)

    # log_softmax stays finite where softmax(...).log() would underflow to -inf
    # (small temperatures); dim is explicit instead of the deprecated implicit choice.
    log_pred = F.log_softmax(sim_matrix / temperature, dim=1)
    target = F.softmax(target_matrix / temperature, dim=1)
    loss = F.kl_div(log_pred, target, reduction="batchmean", log_target=False)
    return loss


class CNNCL(nn.Module):
    """Two-branch model: ``CNN`` for sketches and ``CNN2`` for real-world images.

    Args:
        n_classes: number of output classes of both branches.
        dropout: dropout probability of both branches.
        t: temperature used by the contrastive terms of :meth:`loss`.
    """

    def __init__(self, n_classes=7, dropout=0.1, t=0.1):
        super(CNNCL, self).__init__()
        self.t = t
        # Forward the constructor arguments; they used to be silently ignored,
        # leaving both branches on their own defaults.
        self.quickdraw_model = CNN(n_classes=n_classes, dropout=dropout)
        self.realworld_model = CNN2(n_classes=n_classes, dropout=dropout)

    def forward(self, quickdraw_x, realworld_x):
        """Run each branch on its input.

        Returns:
            ``(quickdraw_pred, quickdraw_feat, realworld_pred, realworld_feat)``.
        """
        quickdraw_pred, quickdraw_feat = self.quickdraw_model(quickdraw_x)
        realworld_pred, realworld_feat = self.realworld_model(realworld_x)
        return quickdraw_pred, quickdraw_feat, realworld_pred, realworld_feat

    def loss(self, quickdraw_pred, quickdraw_feat, realworld_pred, realworld_feat, y):
        """Sum of two cross-entropy terms and three contrastive terms.

        Args:
            quickdraw_pred, realworld_pred: logits of the two branches.
            quickdraw_feat, realworld_feat: pooled features of the two branches.
            y: per-sample target rows; passed as-is to the cross-entropy and
                reduced with ``argmax(1)`` to class ids for the contrastive terms.

        Returns:
            ``(total_loss, quickdraw_CE, realworld_CE)``.
        """
        CE = nn.CrossEntropyLoss()
        quickdraw_CE = CE(quickdraw_pred, y)
        realworld_CE = CE(realworld_pred, y)

        label = y.argmax(1)
        CL_quickdraw = contrastive_loss(quickdraw_feat, label, self.t)
        CL_realworld = contrastive_loss(realworld_feat, label, self.t)
        # Joint term: contrastive loss on the element-wise product of both features.
        CL_both = contrastive_loss(quickdraw_feat*realworld_feat, label, self.t)

        loss = quickdraw_CE + realworld_CE + CL_quickdraw + CL_realworld + CL_both
        return loss, quickdraw_CE, realworld_CE

In [6]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
CNN_CL = CNNCL().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a CUDA run also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
CNN_CL.load_state_dict(torch.load('./model/CNNCL.pth', map_location=device))

<All keys matched successfully>

In [None]:
# Number of test batches (of 100 images each) encoded into the retrieval gallery.
N_GALLERY_BATCHES = 11

realworldloader = DataLoader(test_both, batch_size= 100, shuffle=True, pin_memory=True, num_workers=32)
CNN_CL.eval()
realworld_pred_lst = []
realworld_lst = []
realworld_y_lst = []
with torch.no_grad():
    # A named counter is used instead of `_`, which IPython reserves for the
    # last cell output.
    for batch_idx, data in enumerate(realworldloader):
        realworld, sketch, y = data['realworld'].to(device), data['sketch'].to(device), data['label'].to(device)
        pred1, feat1, pred2, feat2 = CNN_CL(sketch, realworld)
        realworld_pred_lst.append(pred2)
        realworld_y_lst.append(y.squeeze())
        # Park the images on the CPU right away so the whole gallery is never
        # held in GPU memory; only the predictions stay on `device`.
        realworld_lst.append(realworld.cpu())
        if batch_idx + 1 >= N_GALLERY_BATCHES:
            break
# Gallery predictions (on `device`) and the matching images (on the CPU), row-aligned.
realworld_pred = torch.cat(realworld_pred_lst,0)
realworld = torch.cat(realworld_lst,0)

# Sketches addressed per category when a flat sketch index is split into
# (category, offset) by the demo below.
class_size = 1000
# Read every category array eagerly, then close the archive instead of leaving
# the NpzFile handle open.
with np.load('./data/compressed_quickdraw_test.npz') as _quickdraw_npz:
    test_quickdraw = {name: _quickdraw_npz[name] for name in _quickdraw_npz.files}

def search_engine_demo(choose_sketch_index, model, k = 5):
    """Retrieve the ``k`` gallery images whose predictions best match a test sketch.

    The flat ``choose_sketch_index`` is split into a category
    (``index // class_size``) and an offset inside it (``index % class_size``).
    The sketch is run through ``model`` and its normalised prediction is compared
    by cosine similarity with the module-level ``realworld_pred`` gallery.

    Args:
        choose_sketch_index: flat index into the test sketches.
        model: callable taking ``(sketch_batch, realworld_batch)`` and returning
            ``(pred1, feat1, pred2, feat2)``.
        k: number of matches to return; clamped to the gallery size.

    Returns:
        ``(sketch_array, topk_index)``: the raw sketch array and a CPU tensor of
        indices into the gallery.
    """
    total_sketches = class_size * len(QURIES)
    if choose_sketch_index >= total_sketches:
        print(f'Sorry, we only have {total_sketches} sketch images in test!')
        print('Let us pick the 1st sketch image for you to test!')
        # NOTE(review): index 1 is the second sketch of the first category
        # (0-based); confirm whether 0 was intended.
        choose_sketch_index = 1

    class_index = int(choose_sketch_index // class_size)
    category = QURIES[class_index]
    category_idx = int(choose_sketch_index % class_size)
    sketch_ary = test_quickdraw[category][category_idx]

    # NOTE(review): unlike TwoImageDataset, no Resize is applied here; confirm
    # the sketch reaches the model at the size it was trained on.
    transform_img = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
    # Follow the model's device instead of hard-coding .cuda(), so the demo
    # also runs on a CPU-only machine.
    model_device = next(model.parameters()).device
    quickdraw_ary = transform_img(sketch_ary).unsqueeze(0).to(model_device)

    with torch.no_grad():
        # The sketch is passed twice only to satisfy the two-input forward();
        # just the first (sketch-branch) prediction is used.
        pred1, _, _, _ = model(quickdraw_ary, quickdraw_ary)

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
    quickdraw_pred = nn.functional.normalize(pred1).to(realworld_pred.device)
    cos_similarity = cos(realworld_pred, quickdraw_pred)

    # torch.topk raises when k exceeds the number of candidates.
    k = min(k, cos_similarity.numel())
    topk_index = torch.topk(cos_similarity, k).indices.cpu()
    return sketch_ary, topk_index


In [None]:
def plot_search_engine(real_world, choosed_sketch, topk_index):
    """Show the query sketch followed by the retrieved real-world images.

    Args:
        real_world: indexable collection of channel-first image tensors; each
            selected image is permuted to channel-last before display.
        choosed_sketch: image accepted by ``imshow``, shown as the user input.
        topk_index: sequence of indices into ``real_world`` to display.
    """
    print('\n')
    print('************** Welcome to Our Search Engine Demo **************')
    print('\n')
    print('--------------------- User Input Sketch ----------------------')
    fig, axs = plt.subplots(1,1,figsize=(3,3))
    axs.imshow(choosed_sketch)
    plt.show()

    print('------------------ Search Engine Outputs ---------------------')
    # Use the arguments; the body used to read the notebook globals
    # `realworld` / `top_k_index` and ignore what the caller passed in.
    n_results = len(topk_index)
    if n_results == 1:
        # plt.subplots returns a single Axes (not an array) for a 1x1 grid.
        fig_realworld, axs_realworld = plt.subplots(1,1,figsize=(5,5))
        axs_realworld.imshow(real_world[topk_index[0]].permute(1,2,0))
    elif n_results > 1:
        fig_realworld, axs_realworld = plt.subplots(1,n_results,figsize=(20,80))
        for ax, image_idx in zip(axs_realworld, topk_index):
            ax.imshow(real_world[image_idx].permute(1,2,0))
    plt.show()
    print('************************ THE END *****************************')

In [None]:
# Query the demo with one test sketch (flat index) and display its 5 best matches.
choose_sketch_index = 1
search_result = search_engine_demo(choose_sketch_index, CNN_CL, k=5)
choosed_sketch, top_k_index = search_result
plot_search_engine(realworld, choosed_sketch, top_k_index)