In [1]:
#export
from __future__ import annotations
import math,numpy as np,matplotlib.pyplot as plt
from operator import itemgetter
from itertools import zip_longest
import fastcore.all as fc

from torch.utils.data import default_collate

try:
    from .training import *
except:
    from src.miniai.training import *

In [2]:
#export
def inplace(f):
    def _f(b):
        f(b)
        return b
    return _f

In [3]:
#export
def collate_dict(ds):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

In [4]:
#export
@fc.delegates(plt.Axes.imshow)
def show_image(im, ax=None, figsize=None, title=None, noframe=True, **kwargs):
    "Show a PIL or PyTorch image on `ax`."
    if fc.hasattrs(im, ('cpu','permute','detach')):
        im = im.detach().cpu()
        if len(im.shape)==3 and im.shape[0]<5: im=im.permute(1,2,0)
    elif not isinstance(im,np.ndarray): im=np.array(im)
    if im.shape[-1]==1: im=im[...,0]
    if ax is None: _,ax = plt.subplots(figsize=figsize)
    ax.imshow(im, **kwargs)
    if title is not None: ax.set_title(title)
    ax.set_xticks([]) 
    ax.set_yticks([]) 
    if noframe: ax.axis('off')
    return ax

In [5]:
#export
@fc.delegates(plt.subplots, keep=True)
def subplots(
    nrows:int=1, # Number of rows in returned axes grid
    ncols:int=1, # Number of columns in returned axes grid
    figsize:tuple=None, # Width, height in inches of the returned figure
    imsize:int=3, # Size (in inches) of images that will be displayed in the returned figure
    suptitle:str=None, # Title to be set to returned figure
    **kwargs
): # fig and axs
    "A figure and set of subplots to display images of `imsize` inches"
    if figsize is None: figsize=(ncols*imsize, nrows*imsize)
    fig,ax = plt.subplots(nrows, ncols, figsize=figsize, **kwargs)
    if suptitle is not None: fig.suptitle(suptitle)
    if nrows*ncols==1: ax = np.array([ax])
    return fig,ax

In [6]:
#export
@fc.delegates(subplots)
def get_grid(
    n:int, # Number of axes
    nrows:int=None, # Number of rows, defaulting to `int(math.sqrt(n))`
    ncols:int=None, # Number of columns, defaulting to `ceil(n/rows)`
    title:str=None, # If passed, title set to the figure
    weight:str='bold', # Title font weight
    size:int=14, # Title font size
    **kwargs,
): # fig and axs
    "Return a grid of `n` axes, `rows` by `cols`"
    if nrows: ncols = ncols or int(np.floor(n/nrows))
    elif ncols: nrows = nrows or int(np.ceil(n/ncols))
    else:
        nrows = int(math.sqrt(n))
        ncols = int(np.floor(n/nrows))
    fig,axs = subplots(nrows, ncols, **kwargs)
    for i in range(n, nrows*ncols): axs.flat[i].set_axis_off()
    if title is not None: fig.suptitle(title, weight=weight, size=size)
    return fig,axs

In [7]:
#export
@fc.delegates(subplots)
def show_images(ims:list, # Images to show
                nrows:int|None=None, # Number of rows in grid
                ncols:int|None=None, # Number of columns in grid (auto-calculated if None)
                titles:list|None=None, # Optional list of titles for each image
                **kwargs):
    "Show all images `ims` as subplots with `rows` using `titles`"
    axs = get_grid(len(ims), nrows, ncols, **kwargs)[1].flat
    for im,t,ax in zip_longest(ims, titles or [], axs): show_image(im, ax=ax, title=t)

In [8]:
#export
class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f))

In [9]:
import nbformat
import os

In [10]:
def export_cells(notebook_file, output_file):
    with open(notebook_file) as f:
        nb = nbformat.read(f, as_version=4)

    exported_cells = [cell for cell in nb['cells'] if cell['source'].startswith('#export')]

    with open(output_file, 'w') as f:
        for cell in exported_cells:
            f.write(cell['source'][len('#export'):].strip() + '\n')

In [11]:
fname = '05_datasets.ipynb'
with open(fname) as f:
    nb = nbformat.read(f, as_version=4)

In [12]:
export_cells = [cell for cell in nb['cells'] if cell['source'].startswith('#export')]

In [13]:
export_cells

[]

In [14]:
nb

{'cells': [{'cell_type': 'code',
   'execution_count': 4,
   'metadata': {},
   'outputs': [],
   'source': '#|export\nfrom __future__ import annotations\nimport math,numpy as np,matplotlib.pyplot as plt\nfrom operator import itemgetter\nfrom itertools import zip_longest\nimport fastcore.all as fc\n\nfrom torch.utils.data import default_collate\n\ntry:\n    from .training import *\nexcept:\n    from src.miniai.training import *'},
  {'cell_type': 'code',
   'execution_count': 5,
   'metadata': {},
   'outputs': [],
   'source': '#|export\ndef inplace(f):\n    def _f(b):\n        f(b)\n        return b\n    return _f'},
  {'cell_type': 'code',
   'execution_count': 6,
   'metadata': {},
   'outputs': [],
   'source': '#|export\ndef collate_dict(ds):\n    get = itemgetter(*ds.features)\n    def _f(b): return get(default_collate(b))\n    return _f'},
  {'cell_type': 'code',
   'execution_count': 7,
   'metadata': {},
   'outputs': [],
   'source': '#|export\n@fc.delegates(plt.Axes.imshow)