In [159]:
import re
import cv2
import tensorflow as tf
import h5py
import numpy as np
import imageio
import matplotlib.pyplot as plt
from itertools import chain
from tqdm import tqdm

%matplotlib inline
plt.rcParams["figure.figsize"] = (10.0, 8.0) # set default size of plots
plt.rcParams["image.interpolation"] = "nearest"
plt.rcParams["image.cmap"] = "gray"

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [101]:
%reload_ext autoreload

from sketch2code.data_model import *
from sketch2code.datasets import *
from sketch2code.helpers import *
from sketch2code.methods.lstm import *

### Prepare data

Use PennTreeBank to solve POS tagging

In [32]:
from nltk.corpus import treebank

full_data = [treebank.tagged_sents(x) for x in treebank.fileids()]

train_data, test_data = full_data[:int(len(full_data) * 0.7)], full_data[int(len(full_data) * 0.7):]
train_data, valid_data = train_data[:-int(len(full_data) * 0.1)], train_data[-int(len(full_data) * 0.1):]

print(len(train_data), len(valid_data), len(test_data))

def flatten(xs):
    return [x for a in xs for x in a]

train_data, valid_data, test_data = flatten(train_data), flatten(valid_data), flatten(test_data)

120 19 60


In [47]:
# make vocabulary and predicted labels
vocab = set(w for sent in train_data for w, l in sent)
assert '<pad>' not in vocab and '<unknown>' not in vocab
vocab = ['<pad>'] + sorted(list(vocab)) + ['<unknown>']
vocab_w2i = {w: i for i, w in enumerate(vocab)}

labels = set(l for sent in train_data for w, l in sent)
assert all(l in labels for sent in chain(valid_data, test_data) for w, l in sent)
labels = {l: i for i, l in enumerate(['<pad>'] + sorted(list(labels)))}

In [46]:
# separate input and label
def separate_input_and_label(data):
    global vocab_w2i, labels
    X = []
    Y = []
    unknown_id = vocab_w2i['<unknown>']
    for sent in data:
        X.append([vocab_w2i.get(w, unknown_id) for w, l in sent])
        Y.append([labels[l] for w, l in sent])
    return X, Y


X_train, y_train = separate_input_and_label(train_data)
X_valid, y_valid = separate_input_and_label(valid_data)
X_test, y_test = separate_input_and_label(test_data)

print(len(X_train), len(y_train))
print(len(X_valid), len(y_valid))
print(len(X_test), len(y_test))

2712 2712
356 356
846 846


## Build model

In [151]:
class TestModel(torch.nn.Module):
    
    def __init__(self, lstm, padding_label_idx: int, n_labels: int):
        super().__init__()
        self.padding_label_idx = padding_label_idx
        self.n_labels = n_labels
        self.lstm = lstm
        self.hidden2tag = torch.nn.Linear(self.lstm.hidden_size, n_labels)
    
    def forward(self, X, X_lengths):
        batch_size, T = X.shape
        
        X, hn = self.lstm(X, X_lengths)
        X = X.contiguous()
        X = X.view(-1, X.shape[2])
        X = self.hidden2tag(X)
        X = torch.nn.functional.log_softmax(X, dim=1)
        
        # convert back to batch_size, T, tags
        X = X.view(batch_size, T, -1)
        return X
    
    def loss(self, Y_hat, Y, X_lengths):
        # flatten Y and create a mask if it is 
        Y = Y.view(-1)
        mask = (Y != self.padding_label_idx).float()
        Y_hat = Y_hat.view(-1, self.n_labels)
              
        Y_hat = Y_hat[range(Y_hat.shape[0]), Y] * mask
        # cross entropy loss
        ce_loss = -torch.sum(Y_hat) / int(torch.sum(mask).data[0])
        
        return ce_loss

In [153]:
def prepare_batch(sents: List[List[int]], sent_lbls: List[List[int]]):
    pad_w = vocab_w2i['<pad>']
    pad_l = labels['<pad>']
    
    sentence_index_and_length = sorted([(i, len(s)) for i, s in enumerate(sents)], key=lambda x: x[1], reverse=True)
    
    padded_sents = torch.ones((len(sents), sentence_index_and_length[0][1]), dtype=torch.long) * pad_w
    padded_lbls = torch.ones_like(padded_sents) * pad_l
    
    for i, (j, nw) in enumerate(sentence_index_and_length):
        padded_sents[i, :nw] = torch.tensor(sents[j])
        padded_lbls[i, :nw] = torch.tensor(sent_lbls[j])
    
    return padded_sents, padded_lbls, torch.tensor([nw for i, nw in sentence_index_and_length])


def iter_batch(batch_size: int, X, y, shuffle: bool=False): 
    index = range(len(X))
    if shuffle:
        np.random.shuffle(index)
        
    for i in range(0, len(X), batch_size):
        batch_idx = index[i:i+batch_size]
        bx, by, bxlen = prepare_batch([X[j] for j in batch_idx], [y[j] for j in batch_idx])
        yield (bx, by, bxlen)

## Train the model

In [164]:
device = torch.device('cuda')
lstm = LSTM(vocab_size=len(vocab), embedding_dim=20, hidden_size=20, n_layers=1, padding_token_idx=0)
model = TestModel(lstm, padding_label_idx=0, n_labels=len(labels))
model = model.to(device)

In [166]:
n_epoches = 10
batch_size = 20

loss_func = torch.nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-3)

bx_val, by_val, bx_val_len = prepare_batch(X_valid, y_valid)
bx_val = bx_val.to(device)
by_val = by_val.to(device)
bx_val_len = bx_val_len.to(device)

for epoch in range(n_epoches):
    for bx, by, bxlen in tqdm(iter_batch(batch_size, X_train, y_train)):
        model.zero_grad()
        by_pred = model(bx.to(device), bxlen.to(device))
        loss = loss_func(by_pred.view(-1, len(labels)), by.to(device).view(-1))
        loss.backward()
        optimizer.step()
    
    print("Epoch", epoch, 'validation loss:', loss_func(model(bx_val, bx_val_len).view(-1, len(labels)), by_val.view(-1)).item(), flush=True)

136it [00:01, 88.26it/s]

Epoch 0 validation loss: 2.4000649452209473



136it [00:01, 90.01it/s]

Epoch 1 validation loss: 1.552728533744812



136it [00:01, 91.34it/s]

Epoch 2 validation loss: 0.9840391278266907



136it [00:01, 89.38it/s]

Epoch 3 validation loss: 0.6627480983734131



136it [00:01, 92.71it/s]

Epoch 4 validation loss: 0.5016286373138428



136it [00:01, 90.63it/s]

Epoch 5 validation loss: 0.41821572184562683



136it [00:01, 89.54it/s]

Epoch 6 validation loss: 0.36950239539146423



136it [00:01, 91.68it/s]

Epoch 7 validation loss: 0.34019362926483154



136it [00:01, 92.34it/s]

Epoch 8 validation loss: 0.3250828981399536



136it [00:01, 91.25it/s]

Epoch 9 validation loss: 0.3156919479370117



