In [None]:
cd /root/userspace/public/JSRT/sakka/medical_image_attention/src/

In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.cm as cm
from PIL import Image
import gc
import os
import nltk
nltk.download('punkt')
from tqdm import tqdm
import pickle
import skimage.transform
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.nn.utils.rnn import  pack_padded_sequence

from model import Decoder, EncoderResNet
from kenshin_util import  to_variable, tensor2numpy, Vocabulary

In [None]:
def rescale_feature(img):
    """Min-max rescale an array/tensor into the range [0, 1].

    The minimum is subtracted so the result starts at 0, then the shifted
    values are divided by their maximum so the result ends at 1. A new
    object is returned; the input is left unmodified.

    Args:
        img: numpy array or torch tensor of numeric values.

    Returns:
        Array/tensor with the same shape as ``img``. If ``img`` is constant
        (max == min), the zero-shifted values (all zeros) are returned
        instead of dividing 0 by 0.
    """
    # Subtract the minimum rather than adding abs(min): adding |min| only lands
    # on 0 when the minimum is negative; for positive data it moves values away
    # from 0 and the result is not a true [0, 1] rescale.
    shifted = img - img.min()
    peak = shifted.max()
    if peak > 0:
        return shifted / peak
    # Constant input: nothing to scale, avoid producing NaNs.
    return shifted

In [None]:
def set_transform():
    """Build the preprocessing pipeline applied to every input image.

    Steps, in order: resize to 224x224, convert the PIL image to a tensor,
    then normalise each RGB channel with the mean/std below. These are the
    standard ImageNet statistics used throughout the torchvision docs;
    presumably they match how the encoder was trained (confirm).

    Returns:
        transforms.Compose: callable mapping a PIL image to a normalised tensor.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ]
    return transforms.Compose(steps)

In [None]:
# Directory against which the relative image paths in the label CSV are resolved.
root_img_dir = "/root/userspace/public/JSRT/sakka/medical_image_attention/image/jpg/"
label_df = pd.read_csv("/root/userspace/public/JSRT/sakka/medical_image_attention/refactoring/label/image_freq_thresh_5_test.csv")
# Rewrite the whole column in one assignment. Row-wise
# `label_df.iloc[i]["path"] = ...` is chained assignment: it writes into a
# temporary row Series and is not guaranteed to update label_df at all.
label_df["path"] = label_df["path"].map(lambda rel_path: os.path.join(root_img_dir, rel_path))
print(len(label_df))
label_df.head()

# Specify the paths to the vocabulary, encoder, and decoder, then load them

In [None]:
#vocab_path = "/root/userspace/public/JSRT/sakka/medical_image_attention/data/vocab/vocab_freq_thresh_5.pkl"
# Choose the device first so that checkpoints can be mapped onto it while loading.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE: {}".format(device))

vocab_path = "/root/userspace/public/JSRT/sakka/medical_image_attention/refactoring/vocab/vocab_freq_thresh_5.pkl"
# NOTE: pickle.load can execute arbitrary code -- only load trusted vocab files.
# The `with` block closes the file handle once the vocab has been read.
with open(vocab_path, "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

# Encoder
encoder_model = EncoderResNet()
encoder_model_path = "/root/userspace/public/JSRT/sakka/medical_image_attention/refactoring/model/encoder_freq5.pth"
# Wrap in DataParallel *before* loading: the checkpoint is loaded into the
# wrapper, so its keys are expected to carry DataParallel's "module." prefix.
encoder_model = nn.DataParallel(encoder_model)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
encoder_model.load_state_dict(torch.load(encoder_model_path, map_location=device))


# model setting
vis_dim = 2048  # channels per encoder feature location (TODO confirm against EncoderResNet)
vis_num = 196   # number of feature locations; presumably a 14 x 14 grid
# NOTE(review): embed_dim / vocab_size are hard-coded rather than len(vocab);
# they must match the values the decoder checkpoint was trained with.
embed_dim = 125
hidden_dim = 256
vocab_size = 125
num_layers = 1
# NOTE(review): presumably only matters for training -- dropout should be a
# no-op once decoder_model.eval() is called below; confirm inside Decoder.
dropout_ratio = 1.0

decoder_model = Decoder(vis_dim=vis_dim,
                        vis_num=vis_num,
                        embed_dim=embed_dim,
                        hidden_dim=hidden_dim,
                        vocab_size=vocab_size,
                        num_layers=num_layers,
                        dropout_ratio=dropout_ratio)

decoder_model = nn.DataParallel(decoder_model)

decoder_model_path = "/root/userspace/public/JSRT/sakka/medical_image_attention/refactoring/model/decoder_freq5.pth"
decoder_model.load_state_dict(torch.load(decoder_model_path, map_location=device))

# Move both models to the chosen device and switch to inference mode.
encoder_model = encoder_model.to(device)
encoder_model.eval()
decoder_model = decoder_model.to(device)
decoder_model.eval();  # trailing ';' suppresses the long module repr

In [None]:
# Split each caption into space-separated characters (character-level reference text).
opinion_lst = list(label_df["caption"])
name_lst = ["".join("{} ".format(char) for char in opinion) for opinion in opinion_lst]

# Keep only samples whose image file exists. This keeps img_path_lst,
# names_lst, captions_lst and alphas_lst index-aligned; get_result indexes
# all of them with the same idx.
samples = [(img_path, name)
           for img_path, name in zip(label_df["path"], name_lst)
           if os.path.exists(img_path)]
rm_path_cnt = len(name_lst) - len(samples)
img_path_lst = [img_path for img_path, _ in samples]

# initialize
names_lst = []
captions_lst = []
alphas_lst = []
transform = set_transform()

with torch.no_grad():
    for img_path, name in tqdm(samples):
        # The context manager closes the underlying file once converted.
        with Image.open(img_path) as pil_img:
            img = transform(pil_img.convert("RGB"))
        img = to_variable(img)
        # Add a batch dimension, reshape the encoder output to
        # (batch, vis_dim, vis_num), then transpose to (batch, vis_num, vis_dim).
        fea = encoder_model(img.unsqueeze(0))
        fea = fea.view(fea.size(0), vis_dim, vis_num).transpose(1, 2)

        ids, weights = decoder_model.module.sample(fea)
        names_lst.append(name)
        captions_lst.append(ids)
        alphas_lst.append(weights)

print("Not exist path: {0}".format(rm_path_cnt))

In [None]:
# Quick sanity check of the first processed sample.
print("name: {}".format(names_lst[0]))
print("caption: {}".format(captions_lst[0]))
# Read the shape directly from the tensor: np.array() cannot convert a CUDA tensor.
print("alphas: {}".format(tuple(alphas_lst[0][1].shape)))

In [None]:
def decode_captions(captions, idx_to_word):
    """Turn a 2-D array of token ids into lists of words.

    Each row is read left to right and cut off just before the first
    '<end>' token; '<end>' itself is not included.

    Args:
        captions: 2-D indexable object exposing ``.shape`` as
            (num_captions, max_len), e.g. a numpy array of ids.
        idx_to_word: mapping from token id to word string.

    Returns:
        list[list[str]]: one word list per row of ``captions``.
    """
    num_rows, num_cols = captions.shape

    def _row_to_words(row_idx):
        # Gather words until '<end>' appears or the row runs out.
        collected = []
        for col_idx in range(num_cols):
            token = idx_to_word[captions[row_idx, col_idx]]
            if token == '<end>':
                return collected
            collected.append(token)
        return collected

    return [_row_to_words(row_idx) for row_idx in range(num_rows)]

In [None]:
def attention_visualization(img_path, caption, alphas, grid_size=14, upscale=16,
                            sigma=20, max_words=15):
    """Plot each word's attention map over the image, then the summed map.

    The first figure is a panel grid with 5 columns and at least 4 rows.
    Panel 1 shows the plain image labelled "<start>". Panel t+2 shows
    caption[t] with its attention map overlaid. A second figure overlays
    the sum of all plotted maps.

    Args:
        img_path: path of the image to display.
        caption: sequence of word strings; plotting stops at "<end>".
        alphas: 2-D CPU torch tensor or numpy array holding one row of
            grid_size * grid_size attention weights per word. Row t is drawn
            for caption[t] (NOTE(review): that alignment is up to the caller).
        grid_size: side length of the square attention grid.
        upscale: factor used both to enlarge each attention grid and to size
            the displayed image (grid_size * upscale pixels; 224 by default).
        sigma: Gaussian sigma passed to skimage.transform.pyramid_expand.
        max_words: maximum number of words drawn.
    """
    display_px = grid_size * upscale
    image = Image.open(img_path).convert("RGB")
    image = image.resize((display_px, display_px))

    # Convert once so both CPU tensors and numpy arrays are accepted, and the
    # running total below is accumulated from numpy values only.
    alpha_maps = np.asarray(alphas)
    # Never index past the available attention rows.
    n_words = min(len(caption), max_words, alpha_maps.shape[0])
    n_cols = 5
    # Ceil division; with the default max_words this is the original 4x5 grid.
    n_rows = max(4, -(-(n_words + 1) // n_cols))

    plt.figure(figsize=(12, 12))
    plt.subplot(n_rows, n_cols, 1)
    plt.text(0, 1, "<start>", color='black', backgroundcolor='white', fontsize=8)
    plt.imshow(image)
    plt.axis('off')

    total_alp = np.zeros((grid_size, grid_size))
    for t, word in enumerate(caption[:n_words]):
        if word == "<end>":
            break
        plt.subplot(n_rows, n_cols, t + 2)
        plt.text(0, 1, '%s' % (word,), color='black', backgroundcolor='white', fontsize=14)
        plt.imshow(image)

        alp_curr = alpha_maps[t].reshape(grid_size, grid_size)
        total_alp += alp_curr
        # Smooth-upsample the coarse grid to image resolution for the overlay.
        alp_img = skimage.transform.pyramid_expand(alp_curr, upscale=upscale, sigma=sigma)
        plt.imshow(alp_img, alpha=0.5)
        plt.axis('off')

    plt.figure()
    plt.imshow(image)
    total_alp = skimage.transform.pyramid_expand(total_alp, upscale=upscale, sigma=sigma)
    plt.imshow(total_alp, alpha=0.5)
    plt.show()

In [None]:
def get_result(alphas_lst, captions_lst, img_path_lst, vocab, idx):
    """Print the decoded prediction for sample ``idx`` and plot its attention.

    Args:
        alphas_lst: per-sample attention weights; alphas_lst[idx] is a
            sequence of tensors, of which all but the first are concatenated
            along dim 0.
        captions_lst: per-sample predicted token-id tensors.
        img_path_lst: per-sample image paths; must be index-aligned with
            alphas_lst and captions_lst.
        vocab: vocabulary object exposing ``idx2word``.
        idx: index of the sample to show.
    """
    # NOTE(review): the first attention entry is skipped -- presumably it
    # belongs to the <start> step; confirm against Decoder.sample.
    alps = torch.cat(alphas_lst[idx][1:], 0)
    # Min-max normalise to [0, 1]. Subtracting the minimum (instead of adding
    # abs(min)) is what maps the smallest weight to 0 when all weights are
    # positive; the guard avoids 0/0 for a constant map.
    alps = alps - alps.min()
    peak = alps.max()
    if peak > 0:
        alps = alps / peak
    cap = decode_captions(captions_lst[idx].data.cpu().numpy().reshape(1, -1), vocab.idx2word)[0]
    print("Prediction : {}".format(" ".join(cap)))
    attention_visualization(img_path_lst[idx], cap, alps.data.cpu())

In [None]:
# Presumably needed so Japanese caption characters render in the plots
# (assumes the IPAexGothic font is installed).
plt.rcParams.update({"font.family": "IPAexGothic"})

# Index of the sample to inspect (shared across names/captions/alphas/paths).
test_id = 21
print("ID: {}".format(test_id))
reference_answer = names_lst[test_id]
print("Answer     : {}".format(reference_answer))
get_result(alphas_lst, captions_lst, img_path_lst, vocab, test_id)