In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from nb_005 import *
from collections import Counter

# Wikitext 2

## Data

Download the dataset [here](https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip) and unzip it so that the token files end up in the folder `data/wikitext`.

In [None]:
# Token appended at the end of every line read from the corpus files.
EOS = '<eos>'
# Folder the token files are read from.
PATH=Path('data/wikitext')

Small helper function to read the tokens.

In [None]:
def read_file(filename):
    """Read the whitespace-separated tokens of `PATH/filename`, one list per line.

    Every line is split on whitespace and terminated with the `EOS` token.

    Returns a 1-D numpy array of dtype `object` holding one list of tokens per line.
    """
    with open(PATH/filename, encoding='utf8') as f:
        lines = [line.split() + [EOS] for line in f]
    # Lines have different lengths, so the lists are stored in an explicit object array:
    # `np.array(lines)` raises on ragged input in recent numpy versions and silently
    # builds a 2-D array whenever all the lines happen to share the same length.
    tokens = np.empty(len(lines), dtype=object)
    for i, line in enumerate(lines): tokens[i] = line
    return tokens

In [None]:
# Tokenize the train, validation and test files in one go.
trn_tok, val_tok, tst_tok = map(read_file, ('wiki.train.tokens', 'wiki.valid.tokens', 'wiki.test.tokens'))

In [None]:
# Number of lines in each split.
tuple(map(len, (trn_tok, val_tok, tst_tok)))

In [None]:
# Peek at the first 20 tokens of one training line (index 4).
' '.join(trn_tok[4][:20])

In [None]:
# Count how often each token appears in the training set, then show the ten most frequent.
cnt = Counter()
cnt.update(word for sent in trn_tok for word in sent)
cnt.most_common(10)

Give an id to each token and add the pad token (just in case we need it).

In [None]:
# id -> token table: the pad token takes id 0, the other tokens follow by decreasing frequency.
itos = ['<pad>'] + [tok for tok, _ in cnt.most_common()]

In [None]:
# Size of the vocabulary: every token counted in the training set plus the pad token.
vocab_size = len(itos); vocab_size

Create the mapping from token to id, then numericalize our datasets.

In [None]:
# Token -> id lookup built from `itos`; a token that is not in `itos` falls back to id 5.
# NOTE(review): 5 is presumably the id of the unknown-word token in this vocab — confirm,
# or look the id up in `itos` instead of hard-coding it.
# NOTE(review): this notebook only imports `Counter` from `collections`, so the `collections`
# name presumably comes from the `nb_005` star import — confirm.
stoi = collections.defaultdict(lambda : 5, {w:i for i,w in enumerate(itos)})

In [None]:
def _numericalize(tok_lines):
    "Map every token of every line in `tok_lines` to its id through `stoi`; returns a 1-D object array of id lists."
    # The lines have different lengths, so the id lists are stored in an explicit object
    # array: `np.array` on ragged nested lists raises in recent numpy versions.
    ids = np.empty(len(tok_lines), dtype=object)
    for i, line in enumerate(tok_lines): ids[i] = [stoi[w] for w in line]
    return ids

trn_ids = _numericalize(trn_tok)
val_ids = _numericalize(val_tok)
tst_ids = _numericalize(tst_tok)

## Model

### 1. Dropout

We want to use the AWD-LSTM from [Stephen Merity](https://arxiv.org/abs/1708.02182). First, we'll need several different kinds of dropout. Dropout consists in replacing some coefficients by 0 with probability p. To ensure that the average of the weights remains constant, we scale the weights that aren't nullified by a factor `1/(1-p)`.

In [None]:
def dropout_mask(x, sz, p):
    """Returns a dropout mask of the same type as x, size sz, with probability p to cancel an element.

    Cancelled elements are 0 and kept elements are `1/(1-p)`. With `p == 1` the mask is
    all zeros.
    """
    mask = x.new_empty(tuple(sz)).bernoulli_(1-p)
    # With p == 1 every element is already cancelled; rescaling would compute 0/0 = nan.
    if p == 1: return mask
    return mask/(1-p)

In [None]:
# Sample a 10x10 mask with p=0.5.
x = torch.randn(10,10)
dropout_mask(x, (10,10), 0.5)

Once we have a dropout mask `m`, applying the dropout to `x` is simply done by `x = x * m`. We create our own dropout mask and don't rely on pytorch dropout because we want to nullify the coefficients on the batch dimension but not the token dimension (aka the same coefficients are replaced by zero for each word in the sentence).

Inside a RNN, a tensor x will have three dimensions: seq_len, bs, vocab_size, so we create a dropout mask for the last two dimensions and broadcast it to the first dimension.

In [None]:
class RNNDropout(nn.Module):
    "Dropout with probability `p` that uses one mask for every index of the first dimension of a 3-d input."

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Dropout is only active in training mode and for a non-zero probability.
        if self.training and self.p:
            # One mask over the last two dimensions, broadcast along the first one.
            mask = dropout_mask(x.data, (1, x.size(1), x.size(2)), self.p)
            return mask * x
        return x

In [None]:
# Show an input next to its dropped-out version.
dp_test = RNNDropout(0.5)
x = torch.randn(2,5,10)
x, dp_test(x)

In [None]:
def noop(x):
    "Identity function: hand `x` back untouched."
    return x

In [None]:
class WeightDropout(nn.Module):
    """A module that wraps another layer in which some weights will be replaced by 0 during training.

    `reset` has to be called once before the first forward pass: it moves every weight
    named in `layer_names` from the wrapped `module` to a `{layer}_raw` parameter of this
    wrapper. Each forward pass then applies dropout with probability `dropout` to the raw
    weights and sets the result on the wrapped module before calling it.
    """

    def __init__(self, module, dropout, layer_names=['weight_hh_l0']):
        super().__init__()
        # Copy the names so that neither the shared default list nor the caller's list is aliased.
        self.module,self.dropout,self.layer_names = module,dropout,list(layer_names)

    def _setweights(self):
        "Apply dropout to the raw weights and set the result on the wrapped module."
        for layer in self.layer_names:
            raw_w = getattr(self, f'{layer}_raw')
            w1 = F.dropout(raw_w, p=self.dropout, training=self.training)
            # Outside of training F.dropout can hand back the raw Parameter itself, which
            # setattr registers on the module. A plain tensor can't be assigned over a
            # registered parameter afterwards, so any stale entry is removed first.
            self.module._parameters.pop(layer, None)
            setattr(self.module, layer, w1)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)

    def reset(self):
        "Create the raw copies of the selected weights (first call only), then reset the wrapped module if it can be."
        for layer in self.layer_names:
            raw_name = f'{layer}_raw'
            # Already set up by an earlier call: the weight no longer is a parameter of the module.
            if raw_name in self._parameters: continue
            #Makes a copy of the weights of the selected layers.
            w = getattr(self.module, layer)
            self.register_parameter(raw_name, nn.Parameter(w.data))
            del self.module._parameters[layer]
        if hasattr(self.module, 'reset'): self.module.reset()

    def update_raw(self):
        "Copy the non-zero weights of the wrapped module back to the raw parameters, scaled by `1-dropout`."
        for layer in self.layer_names:
            w = getattr(self.module, layer)
            raw_w = getattr(self, f'{layer}_raw')
            mask = w != 0.
            # Write through `.data` so that autograd doesn't see an in-place change of a leaf parameter.
            raw_w.data[mask] = w.data[mask] * (1-self.dropout)

In [None]:
# Wrap a small LSTM so that dropout is applied to its `weight_hh_l0` weights.
module = nn.LSTM(20, 20)
dp_module = WeightDropout(module, 0.5)
# reset() creates the `weight_hh_l0_raw` parameter, so it has to run before the optimizer is built.
dp_module.reset()
# SGD with a learning rate of 10.
opt = optim.SGD(dp_module.parameters(), 10)
dp_module.train()

In [None]:
# Look the raw weights up on the wrapper first: without this line `w_raw` is undefined
# when the notebook is run from top to bottom.
w_raw = getattr(dp_module, 'weight_hh_l0_raw')
w = F.dropout(w_raw, p=0.5, training=True)

In [None]:
# Display the dropped-out weights.
w

In [None]:
# Call the wrapped LSTM five times in a row, feeding its output and hidden state back in.
x = torch.randn(2,5,20)
x.requires_grad_(requires_grad=True)
h = (torch.zeros(1,5,20), torch.zeros(1,5,20))
for _ in range(5): x,h = dp_module(x,h)

In [None]:
# Compare the weights currently set on the wrapped LSTM with the raw weights kept by the wrapper.
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')

In [None]:
# One dummy optimization step: random targets in [0, 20) for the 10 rows of the flattened output.
target = torch.randint(0,20,(10,)).long()
loss = F.nll_loss(x.view(-1,20), target)
loss.backward()
opt.step()

In [None]:
# Look at the gradients of the weights set on the LSTM and of the raw weights after the backward pass.
w, w_raw = getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')
w.grad, w_raw.grad

In [None]:
# Same comparison of the two sets of weights, now after the optimizer step.
getattr(dp_module.module, 'weight_hh_l0'),getattr(dp_module,'weight_hh_l0_raw')

In [None]:
# Inspect the parameter groups held by the optimizer.
opt.param_groups

In [None]:
class WeightDrop(torch.nn.Module):
    """Wraps `module` and applies dropout to the weights named in `weights` on every forward pass.

    At construction each selected weight is removed from the parameters of `module` and
    registered again on `module` under `<name>_raw`. Every forward pass rebuilds the
    weight from its raw version with dropout of probability `dropout`. With
    `variational=True` one mask value is drawn per row of the weight and shared by the
    whole row, otherwise every element is dropped independently.
    """
    def __init__(self, module, weights=['weight_hh_l0'], dropout=0, variational=False):
        super().__init__()
        self.module = module
        # Copy the names so that neither the shared default list nor the caller's list is aliased.
        self.weights = list(weights)
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        # Do-nothing replacement for `flatten_parameters` (installed in `_setup`).
        # It must be a function rather than a lambda as otherwise pickling explodes
        return

    def _setup(self):
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if isinstance(self.module, torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', nn.Parameter(w.data))

    def _setweights(self):
        "Rebuild every selected weight from its raw version and set it on the wrapped module."
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            if self.variational:
                # One mask entry per row, created with the dtype and on the device of the weight.
                mask = raw_w.new_ones(raw_w.size(0), 1)
                # Follow the training flag like the other branch, so nothing is dropped in eval mode.
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=self.training)
                w = mask.expand_as(raw_w) * raw_w
            else:
                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
            # Outside of training dropout can hand back the raw Parameter itself, which setattr
            # registers on the module. A plain tensor can't be assigned over a registered
            # parameter afterwards, so any stale entry is removed first.
            self.module._parameters.pop(name_w, None)
            setattr(self.module, name_w, w)

    def forward(self, *args):
        self._setweights()
        return self.module.forward(*args)

In [None]:
# Use the GPU when there is one and fall back to the CPU otherwise, so that this cell
# also runs on a machine without CUDA.
test_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(2, 1, 10, device=test_device)
h0 = None
lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9)
lin.to(test_device)
run1 = [o.sum() for o in lin(x).data]
run2 = [o.sum() for o in lin(x).data]

print('All items should be different')
print('Run 1:', run1)
print('Run 2:', run2)

print('Testing WeightDrop with LSTM')

wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9)
wdrnn.to(test_device)

run1 = [o.sum() for o in wdrnn(x, h0)[0].data]
run2 = [o.sum() for o in wdrnn(x, h0)[0].data]

print('First timesteps should be equal, all others should differ')
print('Run 1:', run1)
print('Run 2:', run2)

print('---')

In [None]:
# Wrap an LSTM with WeightDrop and build an SGD optimizer (learning rate 10) on its parameters.
module = nn.LSTM(10, 20)
dp_module = WeightDrop(module, dropout=0.5)
# No reset() call here: WeightDrop creates its raw weights in __init__ (through _setup).
#dp_module.reset()
opt = optim.SGD(dp_module.parameters(), 10)

In [None]:
# WeightDrop registers the raw weights on the wrapped module (not on the wrapper), and
# `weight_hh_l0` is only set again by a forward pass, hence the `None` default.
getattr(dp_module.module, 'weight_hh_l0', None),getattr(dp_module.module,'weight_hh_l0_raw')

In [None]:
# Single forward pass through the wrapped LSTM, starting from an all-zero hidden state.
x = torch.randn(2,5,10)
x.requires_grad_(requires_grad=True)
h = (torch.zeros(1,5,20), torch.zeros(1,5,20))
out,h = dp_module(x,h)

In [None]:
lstm = nn.LSTM(5, 3)  # Input dim is 5, hidden (output) dim is 3
inputs = torch.randn(7, 2, 5)  # a batch of 2 sequences of length 7

# initialize the hidden state: a pair of tensors of shape (1, 2, 3).
hidden = (torch.randn(1, 2, 3),
          torch.randn(1, 2, 3))
# Run the whole sequence through the LSTM in a single call:
# `out` holds the output of every step, `hidden` the state after the last step.
out, hidden = lstm(inputs.view(7, 2, -1), hidden)