In [1]:
from fastai.vision.all import *
from fastai.learner import *
from fastai.data.all import *
from fastai.callback.tracker import SaveModelCallback
import pandas as pd
import matplotlib.pyplot as plt
from pathlib2 import Path
import numpy as np
import random
from torch.nn import MSELoss

In [2]:
%%time

# Load the full examples table from CSV into `df`; the trailing bare expression
# makes the cell display its (rows, columns) shape.
df = pd.read_csv('data/examples.csv')
df.shape

CPU times: user 25.5 s, sys: 2.45 s, total: 28 s
Wall time: 28 s


(17937758, 9)

In [3]:
# Fetch the web page from which the list of most common English nouns is scraped.

from urllib.request import urlopen

url = "https://www.thoughtco.com/learn-the-most-important-english-nouns-4087688"
# `page` is the open HTTP response object; its body is read in a later cell.
page = urlopen(url)

In [4]:
import re

# Read the whole response body once, then release the underlying connection.
html_bytes = page.read()
page.close()
html = html_bytes.decode("utf-8")

# Raw string: `\d`, `\.` and `\w` are regex escapes, not string escapes. In a plain
# string literal they are invalid escape sequences and trigger a warning.
# The pattern captures the run of word characters that follows "<digit>. ".
matches = re.findall(r'\d\. (\w*)', html)

In [5]:
# Drop empty captures and every word whose last character is 's'.
matches = [word for word in matches if word and not word.endswith('s')]

In [6]:
# Number of scraped words left after the filtering step.
len(matches)

567

In [7]:
# De-duplicated noun list, used below for `isin` lookups against `df.target_word`.
singular = set(matches)

In [8]:
# Shape of the subset of rows whose lower-cased target word is in the singular noun set.
df[df.target_word.str.lower().isin(singular)].shape

(1143044, 9)

In [9]:
# Naive pluralisation: append 's' to every noun. Irregular plurals are therefore
# not produced (e.g. a noun ending in 'y' becomes '...ys').
plural = {noun + 's' for noun in matches}

In [10]:
# Shape of the subset of rows whose lower-cased target word is in the plural noun set.
df[df.target_word.str.lower().isin(plural)].shape

(282133, 9)

In [11]:
%%time
# Lookup indexed by the `fn` key in `prepare_features`; each entry is sliced and
# copied into a 13-column array there.
# NOTE: unpickling can execute code embedded in the file - only load pickles from a trusted source.
fn2features = pd.read_pickle('data/fn2feature.pkl')

CPU times: user 15.5 s, sys: 9.95 s, total: 25.5 s
Wall time: 25.5 s


In [12]:
# Fixed offset and scale applied to every feature array.
dataset_mean = -5
dataset_std = 15


def normalize_data(ary):
    """Shift `ary` by `dataset_mean` and divide the result by `dataset_std`."""
    centered = ary - dataset_mean
    return centered / dataset_std

In [13]:
def empty_list():
    """Return a new, empty list on every call."""
    # NOTE(review): presumably kept as a named module-level factory so that pickled
    # objects referring to it can be loaded - confirm before renaming or removing.
    return []

In [14]:
# NOTE(review): judging by the name this maps a word to row indices of `df`; it is not
# used in the cells visible here - confirm.
# NOTE: unpickling can execute code embedded in the file - only load pickles from a trusted source.
word2row_idxs = pd.read_pickle('data/word2row_idxs.pkl')

In [15]:
def prepare_features(fn, pad_to=291, pad_left=False):
    """Return the features stored for `fn` as a fixed-size float32 array.

    Looks up `fn2features[fn]`, keeps at most its first `pad_to` rows and copies
    them into a zero-filled array of shape `(pad_to, 13)`.

    fn: key into `fn2features`.
    pad_to: number of rows of the returned array; longer inputs are truncated.
    pad_left: when True the zeros come first and the data is right-aligned,
        otherwise the data comes first and the zeros follow.
    """
    ary = fn2features[fn][:pad_to]
    n_rows = ary.shape[0]
    example = np.zeros((pad_to, 13))
    if pad_left:
        # Use an explicit start offset. The negative-slice form `[-n_rows:]` would
        # select the *whole* array when n_rows == 0 and the assignment would fail.
        example[pad_to - n_rows:, :] = ary
    else:
        example[:n_rows, :] = ary
    return example.astype(np.float32)

In [16]:
# Boolean flag: True where the lower-cased target word is one of the singular nouns.
df['singular'] = df.target_word.str.lower().isin(singular)

In [17]:
# Boolean flag: True where the lower-cased target word is one of the plural nouns.
df['plural'] = df.target_word.str.lower().isin(plural)

In [18]:
# Sanity check: shapes of the rows flagged singular and of the rows flagged plural.
df[df.singular].shape, df[df.plural].shape

((1143044, 11), (282133, 11))

In [19]:
# Index labels of the flagged rows, copied into fresh (writable) NumPy arrays.
singular_idxs = np.array(df.index[df.singular.to_numpy()])
plural_idxs = np.array(df.index[df.plural.to_numpy()])

In [20]:
# Shuffle both index arrays in place with a seeded generator, so that the slices
# taken from them for the train/validation split are the same on every fresh run.
# NOTE: the shuffle is in place - re-running only this cell permutes the arrays again.
SPLIT_SEED = 42
split_rng = np.random.default_rng(SPLIT_SEED)
split_rng.shuffle(singular_idxs)
split_rng.shuffle(plural_idxs)

In [58]:
class Dataset():
    """Map-style dataset yielding `(features, label)` pairs for the flagged noun rows.

    The split is taken from the module-level `singular_idxs` and `plural_idxs`
    arrays: the last `valid_size` entries of each form the validation set and
    everything before them forms the training set.
    """
    def __init__(self, train=True, valid_size=40_000):
        """
        train: select the training split when True, the validation split otherwise.
        valid_size: number of rows held out for validation from *each* index array.
        """
        # `[:-0]` is empty and `[-0:]` is the whole array, and a hold-out at least as
        # large as an index array leaves nothing to train on - reject both up front.
        if not 0 < valid_size < min(len(singular_idxs), len(plural_idxs)):
            raise ValueError(
                f'valid_size must be between 1 and the smaller index array size - 1, got {valid_size}'
            )
        self.train = train
        self.valid_size = valid_size
        if train:
            self.idxs = np.concatenate((singular_idxs[:-valid_size], plural_idxs[:-valid_size]))
        else:
            self.idxs = np.concatenate((singular_idxs[-valid_size:], plural_idxs[-valid_size:]))

    def __len__(self):
        """Number of rows in the selected split."""
        return len(self.idxs)

    def __getitem__(self, idx):
        """Return normalized, left-padded features and a float32 label array of shape (1,).

        The label is 1.0 when the target word ends in 's' and 0.0 otherwise.
        """
        row_idx = self.idxs[idx]
        target_fn = df.at[row_idx, 'target_fn']
        x = normalize_data(prepare_features(target_fn, pad_left=True))
        # Case-insensitive, like every other use of `target_word` in this notebook.
        is_plural = df.at[row_idx, 'target_word'].lower().endswith('s')
        return x, np.array([1 if is_plural else 0], dtype=np.float32)

In [59]:
# Training hyper-parameters: batch size, learning rate, loader worker processes.
BS, LR, NUM_WORKERS = 2048, 1e-3, 8

# Only the training loader shuffles; the validation loader keeps dataset order.
train_dl = DataLoader(Dataset(train=True), bs=BS, num_workers=NUM_WORKERS, shuffle=True)
valid_dl = DataLoader(Dataset(train=False), bs=BS, num_workers=NUM_WORKERS)

dls = DataLoaders(train_dl, valid_dl)

In [88]:
# Bidirectional LSTM encoder (`num_layers_encoder` stacked layers, 3 by default) whose
# last two final hidden states are concatenated and fed to a linear layer.
class Model(Module):
    """Sequence classifier that emits one logit per input sequence."""
    def __init__(self, hidden_size=25, num_layers_encoder=3):
        """
        hidden_size: hidden units per LSTM direction.
        num_layers_encoder: number of stacked LSTM layers.
        """
        # NOTE(review): `return_embeddings` is stored but never read in this class.
        self.return_embeddings = False
        self.num_layers_encoder = num_layers_encoder
        self.hidden_size = hidden_size

        # Expects batch-first input with 13 features per time step.
        lstm_config = dict(
            input_size=13,
            hidden_size=hidden_size,
            num_layers=self.num_layers_encoder,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.encoder = nn.LSTM(**lstm_config)
        # Two hidden states of width `hidden_size` are concatenated before the head.
        self.classifier = nn.Linear(2 * hidden_size, 1)

    def forward(self, x):
        """Encode `x` and return the classifier output for each sequence in the batch."""
        _, (h_n, _) = self.encoder(x)
        # Last two entries of the final hidden state; per the PyTorch LSTM docs these
        # belong to the two directions of the last layer.
        last_two = (h_n[-1], h_n[-2])
        return self.classifier(torch.cat(last_two, dim=1))

In [91]:
# Binary classification on a single logit: BCE-with-logits loss and the Adam optimizer.
# Data loaders and model are both moved to the GPU; `LR` is the shared 1e-3 constant.
learn = Learner(
    dls.cuda(),
    Model().cuda(),
    loss_func=BCEWithLogitsLossFlat(),
    lr=LR,
    opt_func=Adam,
    metrics=[accuracy_multi],
)

In [92]:
# Train for 10 epochs at the shared learning rate (`LR` is 1e-3).
learn.fit(n_epoch=10, lr=LR)

epoch,train_loss,valid_loss,accuracy_multi,time
0,0.138274,0.231876,0.914925,01:35
1,0.108827,0.182405,0.93195,01:35
2,0.092519,0.160802,0.940262,01:36
3,0.079705,0.180992,0.93095,01:35
4,0.074716,0.159256,0.94105,01:35
5,0.068807,0.112312,0.9611,01:38
6,0.06598,0.102621,0.964737,01:36
7,0.059957,0.124236,0.95505,01:36
8,0.057515,0.104107,0.964163,01:36
9,0.053419,0.111485,0.959875,01:35
