In [None]:
#default_exp data.source

In [None]:
#export
from local.imports import *
from local.test import *
from local.core import *
from local.data.core import *
from local.data.pipeline import *
from local.notebook.showdoc import show_doc

# Data source
> Base container for all the items

## Convenience functions

In [None]:
#export core
def all_union(sets):
    "Set of union of all `sets` (each `setified` if needed)"
    # Accumulate into one result set, converting every item with `setify` first
    res = set()
    for s in map(setify,sets): res.update(s)
    return res

In [None]:
# Two lists that share the element 2; the union keeps it only once
sets = [[1,2],[2,3]]
test_eq(all_union(sets), {1,2,3})

In [None]:
#export core
def all_disjoint(sets):
    "`True` iff no element appears in more than one item of `sets`"
    # Setify each item up front (same conversion `all_union` applies), so that:
    #  - a value repeated *inside* one item is not mistaken for overlap *between* items
    #  - `sets` is only traversed once, which keeps generators/iterators working
    sets = [setify(s) for s in sets]
    # Items are pairwise disjoint exactly when their sizes add up to the size of their union
    return sum(map(len,sets))==len(all_union(sets))

In [None]:
assert not all_disjoint(sets)         # `sets` is [[1,2],[2,3]]: 2 appears in both lists
assert all_disjoint([[1,2],[3,4]])    # no shared element
assert all_disjoint([[1,2],[]])       # an empty item
assert all_disjoint([[1,2]])          # a single item
assert all_disjoint([])               # no items at all

## Classes

### DataSource -

In [None]:
# export
class DataSource(PipedList):
    "Applies a `Pipeline` of `tfms` to filtered subsets of `items`"
    def __init__(self, items, tfms=None, filts=None):
        # Without explicit `filts`, use a single filter built from `range_of(items)`
        if filts is None: filts = [range_of(items)]
        # Every filter is passed through `mask2idxs` before being stored
        self.filts = listify(mask2idxs(filt) for filt in filts)
        # An item may belong to at most one subset
        assert all_disjoint(self.filts)
        # `filt_idx[n]` holds the subset index of item `n` (`None` if no filter selects it)
        self.filt_idx = ListContainer([None]*len(items))
        for filt_id,idxs in enumerate(self.filts): self.filt_idx[idxs] = filt_id
        super().__init__(items, tfms)

    # Number of items in subset `filt`
    def len(self, filt):
        return len(self.filts[filt])

    # Wrapper giving access to the items of subset `i` only
    def subset(self, i):
        return _DsrcSubset(self, i)

    # `filt` is a required argument here; it is forwarded to the parent class by keyword
    def __call__(self, x, filt, **kwargs):
        return super().__call__(x, filt=filt, **kwargs)

    def decode(self, x, filt, **kwargs):
        return super().decode(x, filt=filt, **kwargs)

    def __getitem__(self, i):
        "Transformed item(s) at `i`"
        # Look up the raw item(s) together with the subset index recorded for each of them
        raw,filt_ids = self.items[i],self.filt_idx[i]
        # Scalar index: a single call with that item's own subset index
        if not is_iter(i): return self(raw, filt_ids)
        # Iterable index: one call per item, each with its own subset index
        return [self(item,filt_id) for item,filt_id in zip(raw,filt_ids)]

# `train` and `valid` are built from `add_props` and return `subset(i)` for the first two subset indices
DataSource.train,DataSource.valid = add_props(lambda i,x: x.subset(i), 2)

In [None]:
# export
# Docs for the `DataSource` methods that are defined above without a docstring
add_docs(DataSource,
         len="`len` of subset `filt`",
         subset="Filtered subset `i`",
         decode="Transform decode")

In [None]:
#export
class _DsrcSubset:
    "The items of `dsrc` selected by filter number `filt`"
    def __init__(self, dsrc, filt):
        # Resolve the index list first, so a bad `filt` fails before any attribute is set
        idxs = dsrc.filts[filt]
        self.dsrc = dsrc
        self.filt = filt
        # Indices into `dsrc` of the items that belong to this subset
        self.filts = idxs

    def __len__(self):
        return len(self.filts)

    def __getitem__(self, i):
        # Translate the subset-relative index into an index on the underlying source
        return self.dsrc[self.filts[i]]

    def __eq__(self, b):
        return all_equal(b,self)

    def __repr__(self):
        return coll_repr(self)

    def decode(self, o, **kwargs):
        # Always decode with this subset's own filter index
        return self.dsrc.decode(o, self.filt, **kwargs)

    def decode_at(self, i, **kwargs):
        item = self[i]
        return self.decode(item, **kwargs)

    def show_at(self, i, **kwargs):
        decoded = self.decode_at(i)
        return self.dsrc.show(decoded, **kwargs)

A `DataSource` provides filtering and transformation capabilities to a list of items. If you don't pass any filters or transforms, it simply provides a single subset with the same behavior as a `ListContainer`.

In [None]:
inp = [0,1,2,3,4]
dsrc = DataSource(inp)

In [None]:
test_eq(dsrc,inp)               # No filters and no tfms, so equal to input items
test_eq(dsrc.subset(0), inp)     # Only one subset, and it holds every item
test_ne(dsrc, [0,1,2,3,5])
test_eq(dsrc[2], 2)         # Retrieve one item (indexing the `DataSource` itself addresses all items, not one subset)
test_eq(dsrc[1,2], [1,2]) # Retrieve two items by index
mask = [True,False,False,True,False]
test_eq(dsrc[mask], [0,3])  # Retrieve two items by mask
dsrc

DataSource: ListContainer (5 items) [0,1,2,3,4]
tfms - []

Passing `filts` allows you to create multiple subsets.

In [None]:
# filts can be indices (given here once as a tensor and once as a list)
dsrc = DataSource(range(5), filts=[tensor([0,2]), [1,3,4]])
test_eq(dsrc.subset(0), [0,2])
test_eq(dsrc.subset(1), [1,3,4])
test_eq(dsrc.subset(1)[2], 4)          # item 2 of subset 1

# filts can be boolean masks (they don't have to cover all items, but must be disjoint)
# here item 3 is selected by neither mask, so it appears in no subset
filts = [[False,True,True,False,True], [True,False,False,False,False]]
dsrc = DataSource(range(5), filts=filts)
test_eq(dsrc.subset(0), [1,2,4])
test_eq(dsrc.subset(1), [0])
dsrc

DataSource: ListContainer (5 items) [0,1,2,3,4]
tfms - []

Pass `tfms` to have transformations applied before returning items.

In [None]:
# apply transforms to all items: each item x comes back as x*2+1
tfms = [lambda x: x*2,lambda x: x+1]
filts = [[1,2],[0,3,4]]
dsrc = DataSource(range(5), tfms, filts=filts)
test_eq(dsrc.subset(0),[3,5])
test_eq(dsrc.subset(1),[1,7,9])
test_eq(dsrc.subset(0)[False,True], [5])   # a mask indexes within the subset

The index of the subset an item belongs to is also passed to your transforms, so a `Transform` created with a `filt` argument is only applied to items whose subset index matches that `filt`.

In [None]:
# only transform subset 1
class _Tfm(Transform):
    # Toy transform: doubles on encode, floor-halves on decode
    def encodes(self, x):
        return x*2

    def decodes(self, x):
        return x//2

    def show(self, x):
        return f" * {x}"
        
tfm = _Tfm(filt=1)
dsrc = DataSource(range(5), tfm, filts=[[1,2],[0,3,4]])
test_eq(dsrc.subset(0),[1,2])            # subset 0: items come back unchanged
test_eq(dsrc.subset(1),[0,6,8])          # subset 1: items 0,3,4 are doubled
test_eq(dsrc.train[False,True], [2])     # `train` gives the same items as `subset(0)`

In [None]:
show_doc(DataSource.__getitem__)

<h4 id="DataSource.__getitem__" class="doc_header"><code>__getitem__</code><a href="https://nbviewer.jupyter.org/github/fastai/fastai_docs/blob/master/dev/05_data_source.ipynb#DataSource--" class="source_link" style="float:right">[source]</a></h4>

> <code>__getitem__</code>(**`i`**)

Transformed item(s) at `i`

`i` can be an int, or list of ints, or a boolean mask.

In [None]:
dsrc[False,True], dsrc[1], dsrc[1]

([1], 1, 1)

In [None]:
show_doc(DataSource.decode_at)

<h4 id="Pipeline.decode_at" class="doc_header"><code>decode_at</code><a href="https://nbviewer.jupyter.org/github/fastai/fastai_docs/blob/master/dev/02_data_pipeline.ipynb#Pipeline--" class="source_link" style="float:right">[source]</a></h4>

> <code>decode_at</code>(**`idx`**)

Decoded version of `get`

In [None]:
test_eq(dsrc.valid[1], 6)             # item 1 of subset 1 is 3, which `_Tfm` encodes to 6
test_eq(dsrc.valid.decode_at(1), 3)   # decoding that item gives back 3

In [None]:
show_doc(DataSource.len)

<h4 id="DataSource.len" class="doc_header"><code>len</code><a href="https://nbviewer.jupyter.org/github/fastai/fastai_docs/blob/master/dev/05_data_source.ipynb#DataSource--" class="source_link" style="float:right">[source]</a></h4>

> <code>len</code>(**`filt`**)

`len` of subset `filt`

In [None]:
[dsrc.len(i) for i in range(2)]

[2, 3]

In [None]:
show_doc(DataSource.decode)

<h4 id="DataSource.decode" class="doc_header"><code>decode</code><a href="https://nbviewer.jupyter.org/github/fastai/fastai_docs/blob/master/dev/05_data_source.ipynb#DataSource--" class="source_link" style="float:right">[source]</a></h4>

> <code>decode</code>(**`x`**, **`filt`**, **\*\*`kwargs`**)

Transform decode

In [None]:
# `t` is the encoded item (6); `decode` needs the subset index passed explicitly as `filt`
t = dsrc.valid[1]
test_eq(dsrc.decode(t,filt=1), 3)

In [None]:
show_doc(DataSource.show_at)

<h4 id="Pipeline.show_at" class="doc_header"><code>show_at</code><a href="https://nbviewer.jupyter.org/github/fastai/fastai_docs/blob/master/dev/02_data_pipeline.ipynb#Pipeline--" class="source_link" style="float:right">[source]</a></h4>

> <code>show_at</code>(**`idx`**)

Call `tfm.show` for item `idx`/`filt`

In [None]:
test_eq(dsrc.valid.show_at(1), ' * 3')

## Export -

In [None]:
#hide
from local.notebook.export import notebook2script
notebook2script(all_fs=True)

Converted 00_test.ipynb.
Converted 01_core.ipynb.
Converted 02_data_pipeline.ipynb.
Converted 03_data_external.ipynb.
Converted 04_data_core.ipynb.
Converted 05_data_source.ipynb.
Converted 06_vision_core.ipynb.
Converted 07_pets_tutorial.ipynb.
Converted 90_notebook_core.ipynb.
Converted 91_notebook_export.ipynb.
Converted 92_notebook_showdoc.ipynb.
Converted 93_notebook_export2html.ipynb.
Converted 94_index.ipynb.
