In [57]:
import os

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

from text_processing import tokenize_text, untokenize, pad_text, Toks

In [55]:
# --- Runtime configuration ---------------------------------------------------
gpu_id = 0                                  # CUDA device index used when a GPU is present
BATCH_SIZE = 3                              # images per DataLoader batch
model_path = "./models/"                    # directory containing the checkpoint
test_model_fname = "img_to_txt_state.tar"   # checkpoint file name loaded by setup_test()
test_folder = "./test_img/"                 # folder scanned for images to caption

# Use the configured GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device(
    'cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')

In [6]:
class NopModule(torch.nn.Module):
    """Pass-through layer: ``forward`` returns its argument unchanged.

    It has no parameters or buffers. ``get_cnn`` assigns an instance to the
    Inception network's ``fc`` attribute so that the network outputs whatever
    would otherwise have been fed to that final layer.
    """

    def forward(self, input):
        # Identity: the very same object is handed back to the caller.
        return input

def get_cnn():
    """Create the image feature extractor and its preprocessing transform.

    Loads torchvision's Inception-v3 with pretrained (ImageNet) weights,
    replaces its final ``fc`` layer with a ``NopModule`` so the network returns
    the activations that would have fed the classifier, moves the network to
    the notebook-level ``device`` and switches it to eval mode.

    Returns:
        tuple: ``(cnn, transform)`` where ``transform`` resizes the image to
        299, center-crops to 299x299, converts it to a tensor and normalises it
        with the per-channel mean/std given below.
    """
    cnn = models.inception_v3(pretrained=True)  # ImageNet-pretrained weights
    cnn.fc = NopModule()                        # drop the classification head
    cnn = cnn.to(device)
    cnn.eval()

    preprocess = transforms.Compose([
        transforms.Resize(299),
        transforms.CenterCrop(299),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return cnn, preprocess

Paper: The image feature is extracted from the second last layer of the Inception-v3 [51] CNN pre-trained on ImageNet [45].

In [17]:
class ImgEmb(nn.Module):
    """Image-feature encoder: a single linear layer followed by ReLU.

    Args:
        input_size: size of the last dimension of the incoming features.
        hidden_size: size of the last dimension of the produced embedding.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size

        # Attribute names are part of the checkpoint's state_dict keys.
        self.mlp = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()

    def forward(self, input):
        """Return ``relu(mlp(input))``."""
        projected = self.mlp(input)
        return self.relu(projected)

In [18]:
class Decoder(nn.Module):
    """GRU decoder that maps token ids to log-probabilities over the output vocabulary.

    Pipeline: embedding -> dropout(0.5) -> batch-first GRU -> dropout(0.5)
    -> linear -> log-softmax over the last dimension (dim 2).

    Args:
        input_size: number of rows in the embedding table.
        hidden_size: width of both the embedding and the GRU state.
        output_size: number of scores produced per time step.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Submodule names and order determine the state_dict layout.
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.emb_drop = nn.Dropout(0.5)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.gru_drop = nn.Dropout(0.5)
        self.mlp = nn.Linear(hidden_size, output_size)
        self.logsoftmax = nn.LogSoftmax(dim=2)

    def forward(self, input_, hidden_in):
        """Run the decoder on ``input_`` starting from ``hidden_in``.

        Returns:
            tuple: ``(log_probs, hidden)`` - the log-softmax scores for every
            time step and the GRU's final hidden state.
        """
        embedded = self.emb_drop(self.embedding(input_))
        gru_out, hidden = self.gru(embedded, hidden_in)
        scores = self.mlp(self.gru_drop(gru_out))
        return self.logsoftmax(scores), hidden

In [19]:
def build_model(dec_vocab_size, img_feat_size=2048,
                hid_size=512, loaded_state=None):
    """Construct the image encoder and caption decoder and move them to ``device``.

    Args:
        dec_vocab_size: used as both the decoder's embedding-table size and
            its output size.
        img_feat_size: width of the image features fed to the encoder.
        hid_size: hidden width shared by encoder output and decoder.
        loaded_state: optional mapping; when given, its ``'enc'`` and ``'dec'``
            entries are loaded as the two modules' state dicts.

    Returns:
        tuple: ``(enc, dec)``.
    """
    enc = ImgEmb(img_feat_size, hid_size)
    dec = Decoder(dec_vocab_size, hid_size, dec_vocab_size)

    if loaded_state is not None:
        for module, key in ((enc, 'enc'), (dec, 'dec')):
            module.load_state_dict(loaded_state[key])

    # nn.Module.to returns the module itself, so this hands back the same objects.
    return enc.to(device), dec.to(device)

In [24]:
def setup_test():
    """Load everything needed for inference.

    Builds the CNN feature extractor, reads the checkpoint at
    ``model_path + test_model_fname`` and restores the encoder/decoder from it.

    Returns:
        dict with keys ``'cnn'``, ``'trans'``, ``'enc'``, ``'dec'`` and
        ``'loaded_state'`` (the raw checkpoint contents).
    """
    cnn, trans = get_cnn()

    # `device` is already the CPU device when CUDA is unavailable, so one
    # map_location covers both the GPU and the CPU case.
    checkpoint_path = model_path + test_model_fname
    loaded_state = torch.load(checkpoint_path, map_location=device)

    # The decoder vocabulary size is taken from the checkpoint's index->word
    # table (the original author noted 10,004 for their checkpoint).
    dec_vocab_size = len(loaded_state['dec_idx_to_word'])
    enc, dec = build_model(dec_vocab_size, loaded_state=loaded_state)

    return dict(cnn=cnn, trans=trans, enc=enc, dec=dec,
                loaded_state=loaded_state)

In [33]:
def has_image_ext(path):
    """Return True when ``path`` ends in a recognised image extension.

    Only the final extension, as split by ``os.path.splitext``, is examined and
    the comparison is case-insensitive. Both TIFF spellings (``.tif`` and
    ``.tiff``) are accepted.

    Args:
        path: file name or path string.

    Returns:
        bool
    """
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                      '.tif', '.tiff')
    return os.path.splitext(path)[1].lower() in IMG_EXTENSIONS

def list_image_folder(root):
    """Return the paths of the image files located directly inside ``root``.

    ``~`` in ``root`` is expanded. Entries are visited in sorted name order;
    sub-directories are skipped (no recursion) and only names accepted by
    ``has_image_ext`` are kept.

    Args:
        root: directory to scan.

    Returns:
        list of str: ``os.path.join(expanded_root, name)`` for each image file.

    Raises:
        OSError: propagated from ``os.listdir`` when ``root`` cannot be listed.
    """
    # Named `folder` rather than `dir` so the builtin is not shadowed.
    folder = os.path.expanduser(root)
    candidates = (os.path.join(folder, name)
                  for name in sorted(os.listdir(folder)))
    return [path for path in candidates
            if not os.path.isdir(path) and has_image_ext(path)]

def safe_pil_loader(path, from_memory=False):
    """Load an image as RGB, substituting a black 299x299 image on failure.

    Args:
        path: a filesystem path, or - when ``from_memory`` is true - whatever
            object should be handed to ``Image.open`` directly.
        from_memory: when false (default) ``path`` is opened in binary mode and
            the file handle is passed to ``Image.open``; when true ``path``
            itself is passed.

    Returns:
        The image converted to mode ``'RGB'``. If opening or converting raises
        an ``Exception``, an all-black 299x299 RGB image is returned instead so
        that one unreadable file does not abort a whole batch.
    """
    try:
        if from_memory:
            res = Image.open(path).convert('RGB')
        else:
            with open(path, 'rb') as f:
                res = Image.open(f).convert('RGB')
    except Exception:
        # Deliberate best-effort fallback. Catching `Exception` instead of a
        # bare `except:` lets KeyboardInterrupt and SystemExit propagate, so
        # the notebook can still be interrupted while images are loading.
        res = Image.new('RGB', (299, 299), color=0)
    return res

class ImageTestFolder(torch.utils.data.Dataset):
    """Dataset over the image files found directly inside a folder.

    Each item is ``(transformed_image, path)``; there are no labels.

    Args:
        root: folder passed to ``list_image_folder`` to collect the samples.
        transform: callable applied to every loaded image.
    """

    def __init__(self, root, transform):
        self.root = root
        self.transform = transform
        self.loader = safe_pil_loader
        # File list is computed once, at construction time.
        self.samples = list_image_folder(root)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        """Load and transform the ``index``-th image; also return its path."""
        img_path = self.samples[index]
        return self.transform(self.loader(img_path)), img_path
    
def get_image_reader(dirpath, transform, batch_size, workers=4):
    """Wrap the images in ``dirpath`` in a non-shuffling ``DataLoader``.

    Args:
        dirpath: folder handed to ``ImageTestFolder``.
        transform: per-image transform handed to ``ImageTestFolder``.
        batch_size: number of images per batch.
        workers: number of loader worker processes (``num_workers``).

    Returns:
        torch.utils.data.DataLoader yielding ``(images, paths)`` batches in
        dataset order, with ``pin_memory=True``.
    """
    dataset = ImageTestFolder(dirpath, transform)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=workers,
        pin_memory=True,
    )

In [47]:
def generate(enc, dec, feats, L=20):
    """Greedily decode a token sequence for each row of ``feats``.

    Both modules are switched to eval mode and the whole computation runs under
    ``torch.no_grad()``. The encoder output, with a leading dimension added, is
    used as the decoder's initial hidden state. At each of the ``L`` steps the
    previously chosen token is fed to the decoder and the arg-max of its
    first-time-step output becomes the next token.

    Args:
        enc: module mapping ``feats`` to the initial decoder state.
        dec: module called as ``dec(tokens, hidden) -> (scores, hidden)``.
        feats: tensor whose first dimension is the batch; the working buffer is
            allocated on ``feats.device`` so the function does not depend on
            any notebook-level device variable.
        L: number of decoding steps.

    Returns:
        numpy.ndarray of shape ``(batch, L + 1)`` with integer token ids.
        Column 0 is always 0: it is the first decoder input and is never
        overwritten (presumably the start-of-sequence id - confirm against
        ``text_processing``).
    """
    enc.eval()
    dec.eval()
    with torch.no_grad():
        hidden = enc(feats).unsqueeze(0)

        # Allocate the token buffer directly on the device the features live on.
        tokens = torch.zeros(feats.shape[0], L + 1,
                             dtype=torch.long, device=feats.device)

        # Run the decoder one step at a time, feeding back its own prediction.
        for step in range(L):
            # Call the module (not .forward) so registered hooks are honoured.
            scores, hidden = dec(tokens[:, step].unsqueeze(1), hidden)
            tokens[:, step + 1] = torch.argmax(scores[:, 0], dim=1)

    return tokens.cpu().numpy()

In [50]:
# main
# Build the CNN, read the checkpoint and restore the captioning model once,
# then expose each piece under its own notebook-level name for later cells.
setup_data = setup_test()

enc, dec, cnn, trans, loaded_state = (
    setup_data[key] for key in ('enc', 'dec', 'cnn', 'trans', 'loaded_state'))

In [58]:
img_reader = get_image_reader(test_folder, trans, BATCH_SIZE)
using_images = True

all_text = []
for input_, text_data in img_reader:
    # input_:    [batch, 3, 299, 299]
    # text_data: list of image paths (file name included)
    input_ = input_.to(device)

    # Image features from the CNN; no gradients are needed at inference time.
    with torch.no_grad():
        batch_feats_tensor = cnn(input_)  # [batch, 2048]
    print("batch_feats_tensor: ", batch_feats_tensor.shape)

    # Greedy decoding of token ids, one row per image.
    dec_tensor = generate(enc, dec, batch_feats_tensor)  # [batch, 21]
    print("dec_tensor: ",dec_tensor.shape)

    # Map every row of token ids back to words (kept as lists, not joined text).
    untok = [
        untokenize(row, loaded_state['dec_idx_to_word'], to_text=False)
        for row in dec_tensor
    ]
    print("untok:", untok)
    # then, send untok to seq2seq

    # Only the first batch is processed: the loop exits after one iteration.
    break

batch_feats_tensor:  torch.Size([3, 2048])
dec_tensor:  (3, 21)
untok: [['manNOUNNOUNNOUN', 'FRAMENETPosture', 'tennisNOUNNOUNNOUN', 'courtNOUNNOUNNOUN', 'FRAMENETContaining', 'racquetNOUNNOUNNOUN'], ['manNOUNNOUNNOUN', 'FRAMENETPosture', 'fieldNOUNNOUNNOUN'], ['bearNOUNNOUNNOUN', 'FRAMENETPlacing', 'topNOUNNOUNNOUN']]
