In [None]:
# default_exp dataloaders
# default_cls_lvl 3

In [None]:
#hide
%matplotlib notebook
from fastai2.callback.progress import *
from fastai2.callback.tracker import *
from fastai2.callback.schedule import *

In [None]:
#export
from seqdata.core import *
from seqdata.model import *
from seqdata.learner import *
from fastai2.basics import *

import math

## Custom Dataloaders
> PyTorch dataloaders for training models on sequential data

# Truncated Backpropagation Through Time

The TBPTT dataloader splits each minibatch it creates into several smaller minibatches (consecutive windows along the time axis). These are returned sequentially before the next minibatch is created.

In [None]:
#export
@delegates()
class TbpttDl(TfmdDL):
    "`TfmdDL` that splits every minibatch into consecutive windows along dim 1 for truncated backpropagation through time"

    def __init__(self, dataset, sub_seq_len=None,max_batches=None, seq_len = None ,shuffle=True,num_workers=0, **kwargs):
        # sub_seq_len: window length along dim 1; None disables splitting
        # max_batches: optional cap on the number of (sub-)batches yielded per epoch
        # seq_len: full sequence length; inferred lazily from the first item when None
        # num_workers defaults to 0 because parallel workers would scramble the window order
        store_attr(self,'sub_seq_len,max_batches,seq_len')
        super().__init__(dataset=dataset, shuffle=shuffle, num_workers=num_workers, **kwargs)
        self.rnn_reset = sub_seq_len is None #always reset stateful rnns if there are no subsequences

    @property
    def n_sub_seq(self):
        "Number of windows per minibatch (the last may be shorter); caches `seq_len` from item 0 on first use"
        # NOTE(review): assumes every item has the same length as item 0 — confirm for the dataset
        if self.seq_len is None: self.seq_len = self.do_item(0)[0].shape[0]
        return math.ceil(self.seq_len / self.sub_seq_len)

    def __len__(self):
        "Number of minibatches per epoch, counting every window and honouring `max_batches`"
        l = super().__len__()
        if self.sub_seq_len is not None: l *= self.n_sub_seq
        if self.max_batches is not None: l = min(l,self.max_batches)
        return l

    def create_batches(self, samps):
        "Wrap the parent's batch generator so each batch is split into windows"
        yield from self._tbptt_generator(super().create_batches(samps))

    def _tbptt_generator(self,batch_iter):
        '''generator function that splits batches in smaller windows and truncates batch count if max_batches is set.
        Sets `self.rnn_reset` before each yield: True on the first window of a batch, or on every batch when not splitting.'''
        for idx,b in enumerate(batch_iter):
            if self.sub_seq_len is None:
                self.rnn_reset = True
                if self.max_batches is not None and idx >= self.max_batches: return
                yield b
                continue
            # evaluate the property once per batch instead of twice per window
            n_sub = self.n_sub_seq
            for i in range(n_sub):
                if self.max_batches is not None and idx*n_sub+i >= self.max_batches: return
                self.rnn_reset = i == 0
                start = i*self.sub_seq_len
                # slice dim 1 (presumably the time axis — TODO confirm) and keep each element's
                # tensor subclass via retain_type, returned as a tuple, otherwise later transforms may not work
                yield tuple(retain_type(x[:,start:start+self.sub_seq_len],x) for x in b)
                    

In [None]:
# input: current and voltage columns; output: voltage column (shifts as configured below)
seq = DataBlock(
    blocks=(
        SequenceBlock.from_hdf(['current','voltage'], TensorSequencesInput, clm_shift=[-1,-1]),
        SequenceBlock.from_hdf(['voltage'], TensorSequencesOutput, clm_shift=[1]),
    ),
    get_items=CreateDict([DfHDFCreateWindows(win_sz=1000+1, stp_sz=1000, clm='current')]),
    splitter=ApplyToDict(ParentSplitter()),
)
# use the TBPTT loader: windows of 10 steps, at most 1 (sub-)batch per epoch
db = seq.dataloaders(get_hdf_files('test_data/'), dl_type=TbpttDl, sub_seq_len=10, max_batches=1)

In [None]:
tuple(dl.one_batch()[0].shape for dl in (db.train, db.valid))

(torch.Size([64, 10, 2]), torch.Size([64, 1000, 2]))

`num_workers` has to be 0. With parallel workers, the order of the minibatches would be corrupted, so consecutive windows of a sequence would no longer arrive in order.

## TBPTT_Reset_Callback
The stateful model needs to reset its hidden state when a new sequence begins. Before each batch, the callback reads the dataloader's `rnn_reset` flag and resets the model accordingly.

In [None]:
#export
class TbpttResetCB(Callback):
    "`Callback` resets the rnn model with every new sequence for tbptt"

    def _current_dl(self):
        "The dataloader currently being iterated; falls back to train/valid by mode if the learner has none"
        # learn.dl also covers get_preds/validate with a custom dl, which is neither dls.train nor dls.valid
        dl = getattr(self.learn, 'dl', None)
        if dl is not None: return dl
        return self.learn.dls.train if self.training else self.learn.dls.valid

    def begin_batch(self):
        "Reset the model's hidden state when the active dataloader flags the start of a new sequence"
        dl = self._current_dl()
        if getattr(dl, 'rnn_reset', False) and hasattr(self.model, 'reset'):
            self.model.reset()

    def after_fit(self):
        "Leave the model with a cleared hidden state once fitting ends"
        if hasattr(self.model, 'reset'): self.model.reset()

## Example

In [None]:
# NOTE(review): stateful=False — confirm whether hidden state is meant to carry across TBPTT windows here
lrn = RNNLearner(db,
                 num_layers=1,
                 rnn_type='gru',
                 stateful=False,
                 metrics=[SkipNLoss(fun_rmse, 100)])
lrn.add_cb(TbpttResetCB())

<fastai2.learner.Learner at 0x7f8778937668>

In [None]:
lrn.fit_one_cycle(n_epoch=1, lr_max=3e-2)

epoch,train_loss,valid_loss,fun_rmse,time
0,14.254913,14.269542,3.765729,00:00


In [None]:
db.train.max_batches = 100  # raise the per-epoch cap on training (sub-)batches from the 1 used above

In [None]:
db.train.sub_seq_len = 10  # window length for TBPTT; same value that was passed to dataloaders above

In [None]:
lrn.fit_one_cycle(n_epoch=1, lr_max=3e-2)

epoch,train_loss,valid_loss,fun_rmse,time
0,1.72559,0.01813,0.110914,00:02


# Weighted Sampling Dataloader

A weighted sampling dataloader for non-uniformly distributed data. A factory function receives a base dataloader class and returns a subclass of it that samples items according to the given weights.

In [None]:
def WeightedDL_Factory(cls):
    "Return a subclass of `cls` (a `TfmdDL`) that, when shuffling, samples items with replacement according to `wgts`"
    assert issubclass(cls, TfmdDL)

    class WeightedDL(cls):
        def __init__(self, dataset, wgts=None,shuffle=True, **kwargs):
            # forward `shuffle` instead of hard-coding True, so a loader requested with
            # shuffle=False (e.g. a validation loader) iterates sequentially
            super().__init__(dataset=dataset, shuffle=shuffle, **kwargs)
            # wgts=None means uniform weights; weights are normalised to probabilities
            wgts = np.array([1.]*len(dataset) if wgts is None else wgts, dtype=float)
            total = wgts.sum()
            if not total > 0: raise ValueError(f"sum of wgts must be positive, got {total}")
            self.wgts = wgts/total

        def get_idxs(self):
            "Indices for one epoch: weighted draw with replacement when shuffling, else the parent's order"
            if self.n==0: return []
            if not self.shuffle: return super().get_idxs()
            # len(self.wgts) must equal self.n for np.random.choice
            return list(np.random.choice(self.n, self.n, p=self.wgts))
    return WeightedDL

In [None]:
# items equal to 1 get weight 2, items equal to 2 get weight 1
dl = WeightedDL_Factory(TfmdDL)(dataset=[1,2]*5, bs=10, wgts=[2,1]*5)

In [None]:
dl.one_batch()  # 1s should appear roughly twice as often as 2s; random draw, no seed set

tensor([2, 1, 1, 1, 1, 1, 2, 1, 1, 2])

In [None]:
#hide
from nbdev.export import *
notebook2script()  # convert the project's notebooks into library modules (see output below)

Converted 00_core.ipynb.
Converted 01_model.ipynb.
Converted 02_learner.ipynb.
Converted 03_dataloaders.ipynb.
Converted 11_dualrnn.ipynb.
Converted 12_TensorQuaternions.ipynb.
Converted 13_HPOpt.ipynb.
Converted index.ipynb.
