Commit cb216e1

metrics now require tensors; nlp updates

jph00 committed Jan 27, 2018
1 parent 8eaf0d5 commit cb216e1
Showing 11 changed files with 30 additions and 33 deletions.
4 changes: 2 additions & 2 deletions courses/dl1/nlp.ipynb
@@ -82,8 +82,8 @@
"PATH='data/aclImdb/'\n",
"\n",
"names = ['neg','pos']\n",
"trn,trn_y = texts_from_folders(f'{PATH}train',names)\n",
"val,val_y = texts_from_folders(f'{PATH}test',names)"
"trn,trn_y = texts_labels_from_folders(f'{PATH}train',names)\n",
"val,val_y = texts_labels_from_folders(f'{PATH}test',names)"
]
},
{
4 changes: 2 additions & 2 deletions courses/ml1/lesson5-nlp.ipynb
@@ -131,8 +131,8 @@
 },
 "outputs": [],
 "source": [
-"trn,trn_y = texts_from_folders(f'{PATH}train',names)\n",
-"val,val_y = texts_from_folders(f'{PATH}test',names)"
+"trn,trn_y = texts_labels_from_folders(f'{PATH}train',names)\n",
+"val,val_y = texts_labels_from_folders(f'{PATH}test',names)"
 ]
 },
 {
2 changes: 1 addition & 1 deletion fastai/core.py
@@ -14,7 +14,7 @@ def T(a):
     if a.dtype in (np.int8, np.int16, np.int32, np.int64):
         res = torch.LongTensor(a.astype(np.int64))
     elif a.dtype in (np.float32, np.float64):
-        return torch.FloatTensor(a.astype(np.float32))
+        res = torch.FloatTensor(a.astype(np.float32))
     else: raise NotImplementedError(a.dtype)
     return to_gpu(res, async=True)
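
The one-word fix above matters: the float branch used to return a CPU tensor directly, skipping the to_gpu call at the bottom, so only integer arrays ended up on the GPU. A minimal standalone sketch of the corrected function (to_gpu stubbed as a hypothetical pass-through, and the async=True pin dropped since async later became a reserved word in Python):

import numpy as np
import torch

def to_gpu(x):
    # hypothetical stand-in for fastai.core.to_gpu: move to CUDA when available
    return x.cuda() if torch.cuda.is_available() else x

def T(a):
    # both branches now assign to res, so floats take the same to_gpu path as ints
    if a.dtype in (np.int8, np.int16, np.int32, np.int64):
        res = torch.LongTensor(a.astype(np.int64))
    elif a.dtype in (np.float32, np.float64):
        res = torch.FloatTensor(a.astype(np.float32))
    else: raise NotImplementedError(a.dtype)
    return to_gpu(res)

print(T(np.array([1., 2.])).type())  # torch.FloatTensor, or torch.cuda.FloatTensor on a GPU box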

3 changes: 2 additions & 1 deletion fastai/dataloader.py
@@ -1,6 +1,7 @@
 import torch, queue
 from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler
 from .imports import *
+from .core import *
 import collections,sys,traceback,threading
 
 string_classes = (str, bytes)
@@ -27,7 +28,7 @@ def np_collate(batch, pad_idx):

 def get_tensor(batch, pin):
     if isinstance(batch, (np.ndarray, np.generic)):
-        batch = torch.from_numpy(batch).contiguous()
+        batch = to_gpu(torch.from_numpy(batch).contiguous())
         return batch.pin_memory() if pin else batch
     elif isinstance(batch, string_classes): return batch
     elif isinstance(batch, collections.Mapping):
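
With from .core import * in scope, get_tensor can now move collated numpy batches to the GPU as they are produced. A self-contained sketch of the changed branch (to_gpu stubbed as above; the recursive string/mapping/sequence handling is elided):

import numpy as np
import torch

def to_gpu(x):
    # hypothetical stand-in for fastai.core.to_gpu
    return x.cuda() if torch.cuda.is_available() else x

def get_tensor(batch, pin):
    if isinstance(batch, (np.ndarray, np.generic)):
        batch = to_gpu(torch.from_numpy(batch).contiguous())
        return batch.pin_memory() if pin else batch
    return batch  # strings, mappings and sequences recurse in the real version

x = get_tensor(np.zeros((2, 3), dtype=np.float32), pin=False)
print(x.size(), x.is_cuda)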
3 changes: 2 additions & 1 deletion fastai/imports.py
@@ -21,7 +21,8 @@
 matplotlib.rc('animation', html='html5')
 np.set_printoptions(precision=5, linewidth=110, suppress=True)
 
-def in_notebook(): return 'ipykernel' in sys.modules
+from ipykernel.kernelapp import IPKernelApp
+def in_notebook(): return IPKernelApp.initialized()
 
 import tqdm as tq
 from tqdm import tqdm_notebook, tnrange
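
The old check reported True in any process that merely had ipykernel in sys.modules; IPKernelApp.initialized() instead asks whether an IPython kernel is actually running in this process. A side-by-side sketch (the ImportError guard is my addition, not part of the commit):

import sys

def in_notebook_old():
    # true whenever ipykernel has been imported, even in a plain script
    return 'ipykernel' in sys.modules

def in_notebook_new():
    try:
        from ipykernel.kernelapp import IPKernelApp
        # true only when a kernel app has been initialized in this process
        return IPKernelApp.initialized()
    except ImportError:
        return False

print(in_notebook_old(), in_notebook_new())  # False False when run as a plain script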
6 changes: 3 additions & 3 deletions fastai/learner.py
@@ -152,7 +152,7 @@ def fit_gen(self, model, data, layer_opt, n_cycle, cycle_len=None, cycle_mult=1,
         elif not self.sched: self.sched=LossRecorder(layer_opt)
         callbacks+=[self.sched]
         n_epoch = sum_geom(cycle_len if cycle_len else 1, cycle_mult, n_cycle)
-        fit(model, data, n_epoch, layer_opt.opt, self.crit,
+        return fit(model, data, n_epoch, layer_opt.opt, self.crit,
             metrics=metrics, callbacks=callbacks, reg_fn=self.reg_fn, clip=self.clip, **kwargs)
 
     def get_layer_groups(self): return self.models.get_layer_groups()
@@ -206,12 +206,12 @@ def fit(self, lrs, n_cycle, wds=None, **kwargs):
"""
self.sched = None
layer_opt = self.get_layer_opt(lrs, wds)
self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)
return self.fit_gen(self.model, self.data, layer_opt, n_cycle, **kwargs)

def warm_up(self, lr, wds=None):
layer_opt = self.get_layer_opt(lr/4, wds)
self.sched = LR_Finder(layer_opt, len(self.data.trn_dl), lr, linear=True)
self.fit_gen(self.model, self.data, layer_opt, 1)
return self.fit_gen(self.model, self.data, layer_opt, 1)

def lr_find(self, start_lr=1e-5, end_lr=10, wds=None, linear=False):
"""Helps you find an optimal learning rate for a model.
19 changes: 3 additions & 16 deletions fastai/lm_rnn.py
@@ -130,32 +130,19 @@ def forward(self, input):
             outputs.append(o)
         return self.concat(raw_outputs), self.concat(outputs)
 
-
-class LinearRNNOutput(nn.Module):
+class LinearDecoder(nn.Module):
     initrange=0.1
-    def __init__(self, n_out, nhid, dropout):
+    def __init__(self, n_out, nhid, dropout, tie_encoder=None):
         super().__init__()
         self.decoder = nn.Linear(nhid, n_out)
         self.decoder.bias.data.fill_(0)
         self.decoder.weight.data.uniform_(-self.initrange, self.initrange)
         self.dropout = LockedDropout(dropout)
+        if tie_encoder: self.decoder.weight = tie_encoder.weight
 
     def forward(self, input):
         raw_outputs, outputs = input
         output = self.dropout(outputs[-1])
-        return output, raw_outputs, outputs
-
-
-class LinearDecoder(LinearRNNOutput):
-    """ A custom Linear layer that reads the signals from the output of the RNN_Encoder layer,
-    and decodes to a output of size n_tokens.
-    """
-    def __init__(self, n_out, nhid, dropout, tie_encoder=None):
-        super().__init__(n_out, nhid, dropout)
-        if tie_encoder: self.decoder.weight = tie_encoder.weight
-
-    def forward(self, input):
-        output, raw_outputs, outputs = super().forward(input)
         decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
         result = decoded.view(-1, decoded.size(1))
         return result, raw_outputs, outputs
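
The interim LinearRNNOutput class is folded back into a single LinearDecoder, and the tie_encoder hook is standard embedding/decoder weight tying: one parameter matrix serves as both the input embedding and the output projection. A minimal sketch of what that single assignment does (sizes illustrative):

import torch.nn as nn

vocab_sz, emb_sz = 10, 4
encoder = nn.Embedding(vocab_sz, emb_sz)  # weight shape: (vocab_sz, emb_sz)
decoder = nn.Linear(emb_sz, vocab_sz)     # weight shape: (vocab_sz, emb_sz) as well
decoder.weight = encoder.weight           # the tie: one shared Parameter

assert decoder.weight is encoder.weight   # any update to one updates the other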
6 changes: 3 additions & 3 deletions fastai/metrics.py
@@ -2,12 +2,12 @@
 from .torch_imports import *
 
 def accuracy(preds, targs):
-    preds = np.argmax(preds, axis=1)
-    return (preds==targs).mean()
+    preds = torch.max(preds, dim=1)[1]
+    return (preds==targs).float().mean()
 
 def accuracy_thresh(thresh):
     return lambda preds,targs: accuracy_multi(preds, targs, thresh)
 
 def accuracy_multi(preds, targs, thresh):
-    return ((preds>thresh)==targs).mean()
+    return ((preds>thresh)==targs).float().mean()
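
This is the change the commit title announces: accuracy and accuracy_multi now expect torch tensors rather than numpy arrays. The added .float() is load-bearing, since elementwise == yields a byte/bool tensor whose mean torch refuses to compute. A small self-contained check (values illustrative):

import torch

def accuracy(preds, targs):
    preds = torch.max(preds, dim=1)[1]      # index of the max logit per row
    return (preds == targs).float().mean()  # == yields a byte/bool tensor; cast before mean

preds = torch.FloatTensor([[0.1, 0.9],     # predicts class 1
                           [0.8, 0.2],     # predicts class 0
                           [0.3, 0.7]])    # predicts class 1
targs = torch.LongTensor([1, 0, 0])
print(accuracy(preds, targs))              # 2 of 3 correct -> ~0.6667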

12 changes: 10 additions & 2 deletions fastai/model.py
@@ -78,10 +78,15 @@ def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, **kwargs):
     avg_mom=0.98
     batch_num,avg_loss=0,0.
     for cb in callbacks: cb.on_train_begin()
+    num_batch = len(data.trn_dl)
+    if epochs<1:
+        num_batch = int(num_batch*epochs)
+        epochs = 1
 
     for epoch in tnrange(epochs, desc='Epoch'):
         stepper.reset(True)
-        t = tqdm(iter(data.trn_dl), leave=False, total=len(data.trn_dl))
+        t = tqdm(iter(data.trn_dl), leave=False, total=num_batch)
+        i = 0
         for (*x,y) in t:
             batch_num += 1
             for cb in callbacks: cb.on_batch_begin()
@@ -92,6 +97,8 @@ def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, **kwargs):
             stop=False
             for cb in callbacks: stop = stop or cb.on_batch_end(debias_loss)
             if stop: return
+            if i>num_batch: break
+            i += 1
 
         vals = validate(stepper, data.val_dl, metrics)
         print(np.round([epoch, debias_loss] + vals, 6))
@@ -100,6 +107,7 @@ def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, **kwargs):
         if stop: break
 
     for cb in callbacks: cb.on_train_end()
+    return vals
 
 
 def validate(stepper, dl, metrics):
@@ -108,7 +116,7 @@ def validate(stepper, dl, metrics):
     for (*x,y) in iter(dl):
         preds,l = stepper.evaluate(VV(x), VV(y))
         loss.append(to_np(l))
-        res.append([f(to_np(preds),to_np(y)) for f in metrics])
+        res.append([f(preds.data,y) for f in metrics])
     return [np.mean(loss)] + list(np.mean(np.stack(res),0))
 
 def get_prediction(x):
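
Two behavioral changes land in fit here. First, the results of validate are returned (and the matching return statements in learner.py pass them up through fit_gen and Learner.fit), so callers receive [val_loss, *metrics] instead of only a printed log; that is also why validate now hands the metric functions raw tensors via preds.data. Second, epochs may be fractional: a value below 1 truncates a single pass to that fraction of the training batches. A sketch of just the fractional-epoch arithmetic (names follow the diff):

def batches_to_run(num_batch, epochs):
    # epochs < 1 means: do one pass, but stop after a fraction of the batches
    if epochs < 1:
        num_batch = int(num_batch * epochs)
        epochs = 1
    return num_batch, epochs

print(batches_to_run(1000, 0.1))  # (100, 1): the first 100 batches of one epoch
print(batches_to_run(1000, 3))    # (1000, 3): three full epochs, unchanged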
2 changes: 1 addition & 1 deletion fastai/nlp.py
@@ -356,7 +356,7 @@ def to_model(self, m, opt_fn):
         return RNN_Learner(self, model, opt_fn=opt_fn)
 
     def get_model(self, opt_fn, max_sl, bptt, emb_sz, n_hid, n_layers, dropout, **kwargs):
-        m = get_rnn_classifer(max_sl, bptt, self.bs, self.c, self.nt,
+        m = get_rnn_classifer(bptt, max_sl, self.c, self.nt,
             layers=[emb_sz*3, self.c], drops=[dropout],
             emb_sz=emb_sz, n_hid=n_hid, n_layers=n_layers, pad_token=self.pad_idx, **kwargs)
         return self.to_model(m, opt_fn)
2 changes: 1 addition & 1 deletion fastai/text.py
@@ -152,7 +152,7 @@ def get_batch(self, i, seq_len):
 class LanguageModel(BasicModel):
     def get_layer_groups(self):
         m = self.model[0]
-        return [(self.model[1], m.dropouti), *zip(m.rnns, m.dropouths)]
+        return [*zip(m.rnns, m.dropouths), (self.model[1], m.dropouti)]
 
 
 class LanguageModelData():
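
Layer groups are what fastai's differential learning rates are spread across, in order, so moving the (decoder, dropouti) pair from the front to the back changes which parameters receive the largest rate. A toy illustration of the pairing (the spread-in-order behavior is the assumption here; group labels are illustrative):

lrs = [1e-4, 1e-3, 1e-2]  # smallest for the earliest group, largest for the last

groups_before = ['decoder+dropouti', 'rnn_0', 'rnn_1']
groups_after  = ['rnn_0', 'rnn_1', 'decoder+dropouti']

# after this commit the decoder/input-dropout pair gets the largest rate
print(list(zip(groups_after, lrs)))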
