In [13]:
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from torch.autograd.variable import Variable
from typing import List
from PIL import Image
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader

In [2]:
class Encoder(nn.Module):
    def __init__(self, in_channels=3, dropout=0.2):
        """
        Encoder for Image2NodeNet.
        :param dropout: dropout
        """
        super(Encoder, self).__init__()
        self.p = dropout
        self.conv1 = nn.Conv2d(in_channels, 8, 3, padding=(1, 1))
        self.conv2 = nn.Conv2d(8, 16, 3, padding=(1, 1))
        self.conv3 = nn.Conv2d(16, 32, 3, padding=(1, 1))
        self.conv4 = nn.Conv2d(32, 64, 3, padding=(1,1))
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = F.max_pool2d(self.drop(F.relu(self.conv1(x))), (2, 2))
        x = F.max_pool2d(self.drop(F.relu(self.conv2(x))), (2, 2))
        x = F.max_pool2d(self.drop(F.relu(self.conv3(x))), (2, 2))
        x = F.max_pool2d(F.relu(self.conv4(x)), (2, 2))
        return x

In [3]:
class ResnetEncoder(nn.Module):
    def __init__(self, arch_name, in_channels=3, pretrained=False):
        super(ResnetEncoder, self).__init__()
        self.arch_name = arch_name
        self.pretrained = pretrained
        self.in_channels = in_channels
        self.build_model()

    def build_model(self):
        print(f'Building {self.arch_name} model (pretrained={self.pretrained})!!')
                    
        if self.arch_name == 'resnet50':
            base_model = models.resnet50(pretrained=self.pretrained)
            self.features = nn.Sequential(*list(base_model.children())[:-1])

        elif self.arch_name == 'resnet18':
            base_model = models.resnet18(pretrained=self.pretrained)
            self.features = nn.Sequential(*list(base_model.children())[:-1]) 
                
        else:
            raise('This architecture is not supported!!')
        
        if self.in_channels != 3:
            self.features[0] =  nn.Conv2d(self.in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        if self.pretrained:
            i = 4 if self.in_channel != 3 else 0
            # Freeze all weights before CB4
            for param in self.features[i:6].parameters():
                param.requires_grad = False
                                        

    def forward(self, x):
        return self.features(x)

    

In [4]:
im = Image.open('data/tables/Table_1/Table_1_Image_10.png').convert('RGB')
im_t = transforms.functional.to_tensor(transforms.functional.center_crop(im, 256)).unsqueeze(0)
im_t.shape

torch.Size([1, 3, 256, 256])

In [5]:
enc = ResnetEncoder(arch_name='resnet18')

Building resnet18 model (pretrained=False)!!


In [6]:
encoded = enc(im_t)
encoded.shape

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


torch.Size([1, 512, 1, 1])

In [7]:
class Image2NodeNet(nn.Module):
    def __init__(self,
                 hd_sz,
                 input_size,
                 inp_op_sz,
                 encoder,
                 num_layers=1,
                 time_steps=3,
                 dropout=0.5):
        """
        Defines RNN structure that takes features encoded by CNN and produces program
        instructions at every time step.
        :inp_op_sz: total number of unique operations
        :param dropout: dropout
        :param hd_sz: rnn hidden size
        :param input_size: input_size (CNN feature size) to rnn
        :param encoder: Feature extractor network object
        :param num_layers: Number of layers to rnn
        :param time_steps: max length of program
        """
        super(Image2NodeNet, self).__init__()
        self.hd_sz = hd_sz
        self.in_sz = input_size
        self.input_op_sz = inp_op_sz
        self.num_layers = num_layers
        self.encoder = encoder
        self.time_steps = time_steps

        self.rnn = nn.GRU(
            input_size=self.in_sz + self.input_op_sz,
            hidden_size=self.hd_sz,
            num_layers=self.num_layers,
            batch_first=False)


        self.logsoftmax = nn.LogSoftmax(1)
        self.softmax = nn.Softmax(1)

        self.dense_fc_1 = nn.Linear(
            in_features=self.hd_sz, out_features=self.hd_sz)
        self.dense_output = nn.Linear(
            in_features=self.hd_sz, out_features=(self.input_op_sz))
        self.drop = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x: List):
 
        data, input_op, program_len = x

        #assert data.size()[0] == program_len + 1, "Incorrect stack size!!"
        batch_size = data.size()[0]
        h = Variable(torch.zeros(self.num_layers, batch_size, self.hd_sz))
        x_f = self.encoder(data)
        x_f = x_f.view(1, batch_size, self.in_sz)
        outputs = []
        for timestep in range(0, program_len + 1):
            # X_f is always input to the RNN at every time step
            # along with previous predicted label
            input_op_rnn = input_op[:, timestep, :]
            input_op_rnn = input_op_rnn.view(1, batch_size,
                                                self.input_op_sz)
            input = torch.cat((self.drop(x_f), input_op_rnn), 2)
            print(input.shape)
            out, h = self.rnn(input, h)
            hd = self.relu(self.dense_fc_1(self.drop(out[0])))
            output = self.logsoftmax(self.dense_output(self.drop(hd)))
            outputs.append(output)
        return outputs
    
    def test(self, x: List):
 
        data, input_op, program_len = x

        batch_size = data.size()[0]
        h = Variable(torch.zeros(self.num_layers, batch_size, self.hd_sz))
        x_f = self.encoder(data)
        x_f = x_f.view(1, batch_size, self.in_sz)
        last_output = input_op[:,0,:]
        outputs = []
        for timestep in range(0, program_len + 1):
            input_op_rnn = last_output.view(1, batch_size,
                                                self.input_op_sz)
            input = torch.cat((self.drop(x_f), input_op_rnn), 2)
            out, h = self.rnn(input, h)
            hd = self.relu(self.dense_fc_1(self.drop(out[0])))
            output = self.logsoftmax(self.dense_output(self.drop(hd)))
            next_input_op = torch.max(output, 1)[1].view(batch_size, 1)
            arr = Variable(
                    torch.zeros(batch_size, self.input_op_sz).scatter_(
                        1, next_input_op.data.cpu(), 1.0)).cuda()

            last_output = arr
            outputs.append(output)
        return outputs

    def beam_search(self, data: List, w: int, max_time: int):
        """
        Implements beam search for different models.
        :param data: Input data
        :param w: beam width
        :param max_time: Maximum length till the program has to be generated
        :return all_beams: all beams to find out the indices of all the
        """
        data, input_op = data

        # Beam, dictionary, with elements as list. Each element of list
        # containing index of the selected output and the corresponding
        # probability.
        batch_size = data.size()[0]
        h = Variable(torch.zeros(1, batch_size, self.hd_sz))
        # Last beams' data
        B = {0: {"input": input_op, "h": h}, 1: None}
        next_B = {}
        x_f = self.encoder(data)
        x_f = x_f.view(1, batch_size, self.in_sz)
        # List to store the probs of last time step
        prev_output_prob = [
            Variable(torch.ones(batch_size, self.input_op_sz))
        ]
        all_beams = []
        all_inputs = []
        for timestep in range(0, max_time):
            outputs = []
            for b in range(w):
                if not B[b]:
                    break
                input_op = B[b]["input"]

                h = B[b]["h"]
                print(input_op.shape)
                input_op_rnn = input_op[:,0,:].view(1, batch_size,
                                                 self.input_op_sz)
                input = torch.cat((x_f, input_op_rnn), 2)
                out, h = self.rnn(input, h)
                hd = self.relu(self.dense_fc_1(self.drop(out[0])))
                dense_output = self.dense_output(self.drop(hd))
                output = self.logsoftmax(dense_output)
                # Element wise multiply by previous probabs
                output = torch.nn.Softmax(1)(output)

                output = output * prev_output_prob[b]
                outputs.append(output)
                next_B[b] = {}
                next_B[b]["h"] = h
            if len(outputs) == 1:
                outputs = outputs[0]
            else:
                outputs = torch.cat(outputs, 1)

            next_beams_index = torch.topk(outputs, w, 1, sorted=True)[1]
            next_beams_prob = torch.topk(outputs, w, 1, sorted=True)[0]
            # print (next_beams_prob)
            current_beams = {
                "parent":
                next_beams_index.data.cpu().numpy() // (self.input_op_sz),
                "index": next_beams_index % (self.input_op_sz)
            }
            # print (next_beams_index % (self.num_draws))
            next_beams_index %= (self.input_op_sz)
            all_beams.append(current_beams)

            # Update previous output probabilities
            temp = Variable(torch.zeros(batch_size, 1))
            prev_output_prob = []
            for i in range(w):
                for index in range(batch_size):
                    temp[index, 0] = next_beams_prob[index, i]
                prev_output_prob.append(temp.repeat(1, self.input_op_sz))
            # hidden state for next step
            B = {}
            for i in range(w):
                B[i] = {}
                temp = Variable(torch.zeros(h.size()))
                for j in range(batch_size):
                    temp[0, j, :] = next_B[current_beams["parent"][j, i]]["h"][
                        0, j, :]
                B[i]["h"] = temp

            # one_hot for input to the next step
            for i in range(w):
                arr = Variable(
                    torch.zeros(batch_size, self.input_op_sz).scatter_(
                        1, next_beams_index[:, i:i + 1].data.cpu(),
                        1.0))
                B[i]["input"] = arr.unsqueeze(1)
            all_inputs.append(B)

        return all_beams, next_beams_prob, all_inputs



In [8]:
data = torch.randn((4,8, 3, 256, 256)) #L, B, C, H, W
input_op_idx = torch.randint(0, 7, (8, 4, 1))
input_op = torch.zeros((8, 4, 7))
input_op = input_op.scatter_(2, input_op_idx, 1) #B, L, S
input_op.shape

torch.Size([8, 4, 7])

In [9]:
program_len = 3
x = [data, input_op, program_len]


In [10]:
model = Image2NodeNet(256, 512, 8, enc, num_layers=1, time_steps=5)

In [23]:
def string_tuple_to_list(string_tuple):
    non_bracket = string_tuple[1:-1]
    sequence = non_bracket.split(',')
    return sequence

def create_vocabulary(labels):
    sequences = list(labels.values())
    unique_symbols = sorted(list(set(sym for seq in sequences for sym in seq)))
    return unique_symbols

def map_sym2idx(sym, uniq):
    return uniq.index(sym)

def map_idx2sym(idx, uniq):
    return uniq[idx]


In [24]:
with open('data/tables/Tables_new.txt', 'r') as label_file:
    labels = {}
    for line in label_file:
        stripped_line = line.strip()
        table, label = stripped_line.split(' ')
        sequence = ['<s>'] + string_tuple_to_list(label) + ['</s>']
        labels[table] = sequence
labels

{'Table_1': ['<s>', '1', '3', '4', '5', '</s>'],
 'Table_2': ['<s>', '2', '3', '4', '5', '</s>'],
 'Table_3': ['<s>', '7', '3', '4', '5', '</s>'],
 'Table_4': ['<s>', '8', '3', '4', '5', '</s>'],
 'Table_5': ['<s>', '9', '3', '4', '5', '</s>'],
 'Table_6': ['<s>', '10', '3', '4', '5', '</s>'],
 'Table_7': ['<s>', '11', '3', '4', '6', '</s>'],
 'Table_8': ['<s>', '12', '3', '4', '6', '</s>'],
 'Table_9': ['<s>', '13', '3', '4', '6', '</s>'],
 'Table_10': ['<s>', '14', '3', '4', '6', '</s>'],
 'Table_11': ['<s>', '15', '3', '4', '6', '</s>'],
 'Table_12': ['<s>', '16', '3', '4', '6', '</s>']}

In [25]:
uniq = create_vocabulary(labels)
uniq

['1',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '</s>',
 '<s>']

In [24]:
labels_idx = {}
for key, val in labels.items():
    labels_idx[key] = [map_sym2idx(sym, uniq) for sym in val]
labels_idx

{'Table_1': [17, 0, 9, 10, 11, 16],
 'Table_2': [17, 8, 9, 10, 11, 16],
 'Table_3': [17, 13, 9, 10, 11, 16],
 'Table_4': [17, 14, 9, 10, 11, 16],
 'Table_5': [17, 15, 9, 10, 11, 16],
 'Table_6': [17, 1, 9, 10, 11, 16],
 'Table_7': [17, 2, 9, 10, 12, 16],
 'Table_8': [17, 3, 9, 10, 12, 16],
 'Table_9': [17, 4, 9, 10, 12, 16],
 'Table_10': [17, 5, 9, 10, 12, 16],
 'Table_11': [17, 6, 9, 10, 12, 16],
 'Table_12': [17, 7, 9, 10, 12, 16]}

In [25]:
labels_sym = {}
for key, val in labels_idx.items():
    labels_sym[key] = [map_idx2sym(idx, uniq) for idx in val]
labels_sym

{'Table_1': ['<s>', '1', '3', '4', '5', '</s>'],
 'Table_2': ['<s>', '2', '3', '4', '5', '</s>'],
 'Table_3': ['<s>', '7', '3', '4', '5', '</s>'],
 'Table_4': ['<s>', '8', '3', '4', '5', '</s>'],
 'Table_5': ['<s>', '9', '3', '4', '5', '</s>'],
 'Table_6': ['<s>', '10', '3', '4', '5', '</s>'],
 'Table_7': ['<s>', '11', '3', '4', '6', '</s>'],
 'Table_8': ['<s>', '12', '3', '4', '6', '</s>'],
 'Table_9': ['<s>', '13', '3', '4', '6', '</s>'],
 'Table_10': ['<s>', '14', '3', '4', '6', '</s>'],
 'Table_11': ['<s>', '15', '3', '4', '6', '</s>'],
 'Table_12': ['<s>', '16', '3', '4', '6', '</s>']}

In [26]:
import json

In [27]:
with open('data/complete_symbol_sequence.json', 'w') as com_sym_file:
    json.dump(labels, com_sym_file)

with open('data/complete_index_sequence.json', 'w') as com_idx_file:
    json.dump(labels_idx, com_idx_file)

In [28]:
import random
def train_valid_test_split(num_images, num_test, num_dev):
    all_indices = np.arange(num_images)
    total_set_aside = num_test + num_dev
    set_aside_indices = []
    for i in range(total_set_aside):
        if i==0:
            idx = i+1
        else:
            idx = set_aside_indices[-1] + 2
        set_aside_indices.append(all_indices[idx])
    
    test_indices = random.sample(set_aside_indices, 2)
    dev_indices = list(np.setdiff1d(set_aside_indices, test_indices))
    
    return test_indices, dev_indices

In [29]:

test_indices, dev_indices = train_valid_test_split(10, 2, 2)
print('Test_indices: ', test_indices)
print('Dev_indices: ', dev_indices)

Test_indices:  [1, 3]
Dev_indices:  [5, 7]


In [30]:
import glob
import os
import json

In [31]:
def get_train_dev_test_examples(path, test_indices, dev_indices):
    examples = sorted(glob.glob(os.path.join(path, '*.png')))
    test_examples = [ex for ex in examples if int(os.path.splitext(ex)[0][-1]) in test_indices]
    dev_examples = [ex for ex in examples if int(os.path.splitext(ex)[0][-1]) in dev_indices]
    train_examples = list(np.setdiff1d(examples, test_examples+dev_examples))
    return test_examples, dev_examples, train_examples

In [32]:
test_ex, dev_ex, examples = get_train_dev_test_examples('./data/tables/Table_1', test_indices, dev_indices)
print(test_ex)
print(dev_ex)
print(examples)

['./data/tables/Table_1/Table_1_Image_1.png', './data/tables/Table_1/Table_1_Image_3.png']
['./data/tables/Table_1/Table_1_Image_5.png', './data/tables/Table_1/Table_1_Image_7.png']
['./data/tables/Table_1/Table_1_Image_10.png', './data/tables/Table_1/Table_1_Image_2.png', './data/tables/Table_1/Table_1_Image_4.png', './data/tables/Table_1/Table_1_Image_6.png', './data/tables/Table_1/Table_1_Image_8.png', './data/tables/Table_1/Table_1_Image_9.png']


In [33]:
root_dir = './data/tables'
table_folders = [folder for folder in os.listdir(root_dir) if not folder.endswith('.txt') and not folder.startswith('.')]
table_folders

['Table_3',
 'Table_4',
 'Table_11',
 'Table_5',
 'Table_2',
 'Table_10',
 'Table_12',
 'Table_9',
 'Table_7',
 'Table_1',
 'Table_6',
 'Table_8']

In [34]:
train_set = []
test_set = []
dev_set = []
for folder in table_folders:
    path = os.path.join(root_dir, folder)
    test_indices, dev_indices = train_valid_test_split(10, 2, 2)
    test_ex, dev_ex, train_ex = get_train_dev_test_examples(path, test_indices, dev_indices)
    train_set += train_ex
    test_set += test_ex
    dev_set += dev_ex

In [35]:
print('Total train_examples: ', len(train_set))
print('Total test_examples: ', len(test_set))
print('Total dev_examples: ', len(dev_set))

Total train_examples:  72
Total test_examples:  24
Total dev_examples:  24


In [173]:
def dump_json(filepath, ex):
    with open(filepath, 'w') as file:
        json.dump(ex, file)

In [190]:
dump_json('./data/train.json', train_set)
dump_json('./data/test.json', test_set)
dump_json('./data/dev.json', dev_set)

In [12]:
import json
def load_json(json_path):
    with open(json_path, 'r') as j:
        content = json.load(j)
    return content

class Image2NodeDataset(Dataset):
    def __init__(self, im_json, label_json):
        self.images = load_json(im_json)
        self.labels = load_json(label_json)
        transform_list = [transforms.ToTensor(),
                          transforms.Normalize((0.5,0.5, 0.5), (0.5,0.5, 0.5))]

        self.transform = transforms.Compose(transform_list)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image = Image.open(self.images[idx]).convert('RGB')
        image_256 = transforms.functional.resize(image, 256, transforms.InterpolationMode.BICUBIC)
        image_t = self.transform(image_256)

        #label
        filename = os.path.split(self.images[idx])[1]
        object_name = '_'.join(filename.split('_')[:2]) 
        sequence = torch.tensor(self.labels[object_name])
        ip_op = sequence[:-1]
        label = sequence[1:]
        return {'image':image_t, 'inp_op':ip_op, 'label': label, 'program_len': len(sequence) - 2}

In [13]:
dataset = Image2NodeDataset('./data/train.json', './data/complete_index_sequence.json')
len(dataset)

72

In [29]:
i = dataset[1]
i['image'].shape

torch.Size([3, 256, 455])

In [14]:
loader = DataLoader(dataset, 8)

In [15]:
batch = next(iter(loader))
print(batch['image'].shape)
print(batch['inp_op'].shape)
print(batch['label'].shape)
print(batch['program_len'])

torch.Size([8, 3, 256, 455])
torch.Size([8, 5])
torch.Size([8, 5])
tensor([4, 4, 4, 4, 4, 4, 4, 4])


In [35]:
op = model.beam_search([batch['image'], input_op], )

torch.Size([1, 8, 519])
torch.Size([1, 8, 519])
torch.Size([1, 8, 519])
torch.Size([1, 8, 519])


In [29]:
image, input_op_idx, label, program_len = batch['image'], batch['inp_op'], batch['label'], batch['program_len']
input_op_idx = input_op_idx.unsqueeze(2)
input_op_idx.shape

torch.Size([8, 5, 1])

In [30]:
input_op = torch.zeros((input_op_idx.shape[0], input_op_idx.shape[1], 8))
input_op = input_op.scatter_(2, input_op_idx, 1) #B, L, S
input_op.shape

torch.Size([8, 5, 8])

In [53]:
input_op_idx[0,:,:]

tensor([[7],
        [0],
        [2],
        [3],
        [4]])

In [56]:
input_op[:,0,:]

tensor([[0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 1.]])

In [34]:
input_op.shape

torch.Size([8, 5, 8])

In [71]:
all_beams, next_beams_prob, all_inputs = model.beam_search([image, input_op], 3, 5)

torch.Size([8, 5, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])
torch.Size([8, 1, 8])


In [68]:
torch.max(input_op[:,0,:], 1)

torch.return_types.max(
values=tensor([1., 1., 1., 1., 1., 1., 1., 1.]),
indices=tensor([7, 7, 7, 7, 7, 7, 7, 7]))

In [16]:
def beams_parser(all_beams, batch_size, beam_width=5):
    # all_beams = [all_beams[k].data.numpy() for k in all_beams.keys()]
    all_expression = {}
    W = beam_width
    T = len(all_beams)
    for batch in range(batch_size):
        all_expression[batch] = []
        for w in range(W):
            temp = []
            parent = w
            for t in range(T - 1, -1, -1):
                temp.append(all_beams[t]["index"][batch, parent].data.cpu()
                            .numpy())
                parent = all_beams[t]["parent"][batch, parent]
            temp = temp[::-1]
            all_expression[batch].append(np.array(temp))
        all_expression[batch] = np.squeeze(np.array(all_expression[batch]))
    return all_expression

In [63]:
all_beams[4]["index"][0, 0].data.cpu().numpy()

array(2)

In [75]:
exp = beams_parser(all_beams, 8, 3)

In [76]:
exp

{0: array([[5, 5, 4, 4, 5],
        [5, 5, 4, 4, 2],
        [2, 4, 1, 4, 5]]),
 1: array([[4, 3, 5, 5, 4],
        [4, 3, 5, 5, 1],
        [4, 3, 5, 5, 2]]),
 2: array([[4, 2, 5, 2, 1],
        [4, 2, 5, 1, 4],
        [4, 2, 5, 1, 5]]),
 3: array([[0, 1, 4, 5, 4],
        [0, 1, 4, 5, 2],
        [7, 5, 5, 1, 2]]),
 4: array([[1, 2, 1, 4, 1],
        [1, 2, 1, 5, 2],
        [1, 2, 1, 5, 1]]),
 5: array([[4, 4, 4, 1, 2],
        [4, 4, 4, 1, 4],
        [4, 4, 4, 1, 1]]),
 6: array([[4, 1, 1, 2, 2],
        [2, 2, 4, 4, 2],
        [2, 2, 4, 4, 7]]),
 7: array([[4, 2, 5, 2, 4],
        [4, 2, 5, 2, 2],
        [5, 2, 1, 4, 2]])}

In [74]:
all_beams

[{'parent': array([[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]]),
  'index': tensor([[4, 5, 2],
          [4, 2, 1],
          [4, 1, 2],
          [7, 0, 4],
          [1, 4, 7],
          [4, 7, 1],
          [2, 4, 6],
          [4, 0, 5]])},
 {'parent': array([[0, 2, 1],
         [0, 0, 1],
         [0, 0, 0],
         [1, 0, 2],
         [0, 1, 0],
         [0, 0, 1],
         [0, 0, 1],
         [0, 2, 1]]),
  'index': tensor([[1, 4, 5],
          [4, 3, 4],
          [2, 5, 1],
          [1, 5, 4],
          [2, 1, 0],
          [4, 0, 1],
          [2, 4, 1],
          [2, 2, 2]])},
 {'parent': array([[0, 2, 1],
         [1, 0, 0],
         [0, 1, 1],
         [0, 1, 0],
         [0, 0, 2],
         [0, 1, 1],
         [0, 0, 2],
         [0, 1, 2]]),
  'index': tensor([[4, 4, 1],
          [5, 5, 1],
          [5, 4, 2],
          [4, 5, 2],
          [1, 2, 5],
          

In [14]:
import json
from Dataset import Image2NodeDataset
from models import Image2NodeNet, ResnetEncoder
from utils import beams_parser


In [15]:
test_dataset = Image2NodeDataset('./data/test.json', './data/complete_index_sequence.json')
len(test_dataset)

24

In [16]:
loader = DataLoader(test_dataset, shuffle=False, batch_size=1)

In [17]:
ckpt = torch.load('./model/tables_12/best_model.pth', map_location=torch.device('cpu'))
ckpt.keys()

dict_keys(['epoch', 'model_state', 'encoder_state', 'best_val_loss'])

In [18]:
im_encoder = ResnetEncoder(arch_name='resnet18', in_channels=3, 
                           pretrained=False)
im_encoder.load_state_dict(ckpt['encoder_state'])
im_encoder.eval()
image2node_net = Image2NodeNet(hd_sz=256, input_size=512, inp_op_sz=16+2,
                             encoder=im_encoder)
image2node_net.load_state_dict(ckpt['model_state'])
image2node_net.eval()

Building resnet18 model (pretrained=False)!!


Image2NodeNet(
  (encoder): ResnetEncoder(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(

In [19]:
from tqdm import tqdm

In [20]:
expressions = []
for batch in tqdm(loader):
        image = batch['image']
        input_op_idx, label, program_lens = batch['inp_op'], batch['label'], batch['program_len']
        # Reshaping and getting one hot encoding of input operations
        input_op = torch.zeros((input_op_idx.shape[0], input_op_idx.shape[1], 16+2))
        input_op = input_op.scatter_(2, input_op_idx.unsqueeze(2), 1)
        all_beams, next_beams_prob, all_inputs = image2node_net.beam_search([image, input_op], 3, 5)
        expression = beams_parser(all_beams, 1, 1)
        expressions.append(expression)
        

  4%|▍         | 1/24 [00:00<00:06,  3.79it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 12%|█▎        | 3/24 [00:00<00:04,  5.24it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 21%|██        | 5/24 [00:00<00:03,  5.67it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 29%|██▉       | 7/24 [00:01<00:02,  5.84it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 33%|███▎      | 8/24 [00:01<00:02,  6.06it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 38%|███▊      | 9/24 [00:01<00:02,  5.01it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 46%|████▌     | 11/24 [00:02<00:02,  4.79it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 54%|█████▍    | 13/24 [00:02<00:02,  5.23it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 62%|██████▎   | 15/24 [00:02<00:01,  5.67it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 67%|██████▋   | 16/24 [00:03<00:01,  5.66it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])


 75%|███████▌  | 18/24 [00:03<00:01,  5.49it/s]

torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 83%|████████▎ | 20/24 [00:03<00:00,  5.77it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


 92%|█████████▏| 22/24 [00:04<00:00,  5.80it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])


100%|██████████| 24/24 [00:04<00:00,  5.43it/s]

torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 5, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])
torch.Size([1, 1, 18])





In [21]:
expressions

[{0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 1,  9, 10, 11, 16])},
 {0: array([14,  9, 10, 11, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 6,  9, 10, 12, 16])},
 {0: array([ 1,  9, 10, 11, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 1,  9, 10, 12, 16])},
 {0: array([ 3,  9, 10, 12, 16])},
 {0: array([ 7,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 7,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 6,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 6,  9, 10, 12, 16])},
 {0: array([ 6,  9, 10, 12, 16])},
 {0: array([ 2,  9, 10, 12, 16])},
 {0: array([ 6,  9, 10, 12, 16])},
 {0: array([ 1,  9, 10, 11, 16])},
 {0: array([ 3,  9, 10, 12, 16])},
 {0: array([ 3,  9, 10, 12, 16])}]

In [29]:
labels_idx = []
for val in expressions:
    labels_idx.append([map_idx2sym(sym, uniq) for sym in val[0]])
labels_idx

[['11', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['10', '3', '4', '5', '</s>'],
 ['8', '3', '4', '5', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['15', '3', '4', '6', '</s>'],
 ['10', '3', '4', '5', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['10', '3', '4', '6', '</s>'],
 ['12', '3', '4', '6', '</s>'],
 ['16', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['16', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['15', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['15', '3', '4', '6', '</s>'],
 ['15', '3', '4', '6', '</s>'],
 ['11', '3', '4', '6', '</s>'],
 ['15', '3', '4', '6', '</s>'],
 ['10', '3', '4', '5', '</s>'],
 ['12', '3', '4', '6', '</s>'],
 ['12', '3', '4', '6', '</s>']]

In [41]:
i = 22
labels_idx[i:i+2]

[['12', '3', '4', '6', '</s>'], ['12', '3', '4', '6', '</s>']]