Commit e1508c1: Merge branch 'master' into madhava/timm_unet

madhavajay committed Jul 16, 2022
2 parents: f10ac6d + c31cd24

Showing 52 changed files with 813 additions and 559 deletions.
15 changes: 15 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,21 @@

<!-- do not remove -->

## 2.7.6

### New Features

- Initial Mac GPU (mps) support ([#3719](https://github.com/fastai/fastai/issues/3719))


## 2.7.5

### New Features

- auto-normalize timm models ([#3716](https://github.com/fastai/fastai/issues/3716))
- PyTorch 1.12 support


## 2.7.4

### New Features
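The changelog entries above are terse, so here is a rough usage sketch (not part of this commit; it assumes fastai ≥ 2.7.6 with PyTorch ≥ 1.12 and the `timm` package installed, and the dataset and model names are illustrative only):

```python
from fastai.vision.all import *

# 2.7.6: on Apple Silicon, fastai's default device is expected to resolve to
# the mps backend when it is available (per the changelog entry above)
print(default_device())

# 2.7.5: a timm architecture can be passed by name; normalization statistics
# come from the pretrained timm config rather than hard-coded ImageNet stats
path = untar_data(URLs.PETS)/'images'
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2,
    label_func=lambda f: f[0].isupper(), item_tfms=Resize(224))
learn = vision_learner(dls, 'convnext_tiny', metrics=error_rate)
learn.fine_tune(1)
```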
2 changes: 1 addition & 1 deletion fastai/__init__.py
@@ -1,2 +1,2 @@
__version__ = "2.7.5"
__version__ = "2.7.8"

2 changes: 1 addition & 1 deletion fastai/callback/core.py
@@ -176,7 +176,7 @@ def __init__(self,
dl:DataLoader=None, # `DataLoader` used for fetching `Learner` predictions
with_input:bool=False, # Whether to return inputs in `GatherPredsCallback`
with_decoded:bool=False, # Whether to return decoded predictions
cbs:(Callback,list)=None, # `Callback` to temporarily remove from `Learner`
cbs:Callback|list=None, # `Callback` to temporarily remove from `Learner`
reorder:bool=True # Whether to sort prediction results
):
self.cbs = L(cbs)
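Most hunks below repeat the same mechanical change as this one: fastai's docment-style parameter annotations move from the informal tuple-of-types shorthand to PEP 604 union syntax, which type checkers and tooling understand. A minimal illustration of the two forms (the names are hypothetical, not from the repository; the union form needs Python 3.10+ or postponed evaluation of annotations):

```python
from __future__ import annotations  # lets `A|B` appear in signatures before 3.10

def old_style(cbs:(list, str)=None):  # old shorthand: "either a list or a str"
    ...

def new_style(cbs:list|str=None):     # PEP 604 union, same intent
    ...
```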
2 changes: 1 addition & 1 deletion fastai/callback/tensorboard.py
@@ -166,7 +166,7 @@ def tensorboard_log(x:TensorImage, y: TensorCategory, samples, outs, writer, ste

# Cell
@typedispatch
def tensorboard_log(x:TensorImage, y: (TensorImageBase, TensorPoint, TensorBBox), samples, outs, writer, step):
def tensorboard_log(x:TensorImage, y: TensorImageBase|TensorPoint|TensorBBox, samples, outs, writer, step):
fig,axs = get_grid(len(samples), return_fig=True, double=True)
for i in range(2):
axs[::2] = [b.show(ctx=c) for b,c in zip(samples.itemgot(i),axs[::2])]
4 changes: 2 additions & 2 deletions fastai/callback/wandb.py
@@ -284,7 +284,7 @@ def _unlist(l):

# Cell
@typedispatch
def wandb_process(x:TensorImage, y:(TensorCategory,TensorMultiCategory), samples, outs, preds):
def wandb_process(x:TensorImage, y:TensorCategory|TensorMultiCategory, samples, outs, preds):
table = wandb.Table(columns=["Input image", "Ground_Truth", "Predictions"])
for (image, label), pred_label in zip(samples,outs):
table.add_data(wandb.Image(image.permute(1,2,0)), label, _unlist(pred_label))
@@ -310,7 +310,7 @@ def wandb_process(x:TensorImage, y:TensorMask, samples, outs, preds):

# Cell
@typedispatch
def wandb_process(x:TensorText, y:(TensorCategory,TensorMultiCategory), samples, outs, preds):
def wandb_process(x:TensorText, y:TensorCategory|TensorMultiCategory, samples, outs, preds):
data = [[s[0], s[1], o[0]] for s,o in zip(samples,outs)]
return {"Prediction_Samples": wandb.Table(data=data, columns=["Text", "Target", "Prediction"])}

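For context on the `@typedispatch` hunks in this file and in tensorboard.py: fastcore dispatches on the annotated types of the first two arguments, so each `wandb_process` overload handles one (x, y) type combination. A toy sketch of the mechanism (made-up function, not fastai code; assumes fastcore is installed):

```python
from fastcore.dispatch import typedispatch

@typedispatch
def describe(x:int, y:str): return "int plus str"

@typedispatch
def describe(x:float, y:str): return "float plus str"

print(describe(1, "a"))    # -> "int plus str"
print(describe(1.0, "a"))  # -> "float plus str"
```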
4 changes: 2 additions & 2 deletions fastai/data/block.py
@@ -31,7 +31,7 @@ def __init__(self,

# Cell
def CategoryBlock(
vocab:(list, pd.Series)=None, # List of unique class names
vocab:list|pd.Series=None, # List of unique class names
sort:bool=True, # Sort the classes alphabetically
add_na:bool=False, # Add `#na#` to `vocab`
):
@@ -41,7 +41,7 @@ def CategoryBlock(
# Cell
def MultiCategoryBlock(
encoded:bool=False, # Whether the data comes in one-hot encoded
vocab:(list,pd.Series)=None, # List of unique class names
vocab:list|pd.Series=None, # List of unique class names
add_na:bool=False, # Add `#na#` to `vocab`
):
"`TransformBlock` for multi-label categorical targets"
16 changes: 8 additions & 8 deletions fastai/data/core.py
@@ -202,7 +202,7 @@ class DataLoaders(GetAttr):
_default='train'
def __init__(self,
*loaders, # `DataLoader` objects to wrap
path:(str,Path)='.', # Path to store export objects
path:str|Path='.', # Path to store export objects
device=None # Device to put `DataLoaders`
):
self.loaders,self.path = list(loaders),Path(path)
@@ -257,7 +257,7 @@ def cpu(self): return self.to(device=torch.device('cpu'))
@classmethod
def from_dsets(cls,
*ds, # `Datasets` object(s)
path:(str,Path)='.', # Path to put in `DataLoaders`
path:str|Path='.', # Path to put in `DataLoaders`
bs:int=64, # Size of batch
device=None, # Device to put `DataLoaders`
dl_type=TfmdDL, # Type of `DataLoader`
@@ -274,7 +274,7 @@ def from_dsets(cls,
def from_dblock(cls,
dblock, # `DataBlock` object
source, # Source of data. Can be `Path` to files
path:(str, Path)='.', # Path to put in `DataLoaders`
path:str|Path='.', # Path to put in `DataLoaders`
bs:int=64, # Size of batch
val_bs:int=None, # Size of batch for validation `DataLoader`
shuffle:bool=True, # Whether to shuffle data
@@ -290,7 +290,7 @@ def from_dblock(cls,
valid_ds="Validation `Dataset`",
to="Use `device`",
add_tfms="Add `tfms` to `loaders` for `event",
cuda="Use the gpu if available",
cuda="Use accelerator if available",
cpu="Use the cpu",
new_empty="Create a new empty version of `self` with the same transforms",
from_dblock="Create a dataloaders from a given `dblock`")
@@ -315,7 +315,7 @@ def dataloaders(self,
shuffle:bool=True, # Shuffle training `DataLoader`
val_shuffle:bool=False, # Shuffle validation `DataLoader`
n:int=None, # Size of `Datasets` used to create `DataLoader`
path:(str, Path)='.', # Path to put in `DataLoaders`
path:str|Path='.', # Path to put in `DataLoaders`
dl_type:TfmdDL=None, # Type of `DataLoader`
dl_kwargs:list=None, # List of kwargs to pass to individual `DataLoader`s
device:torch.device=None, # Device to put `DataLoaders`
@@ -346,7 +346,7 @@ class TfmdLists(FilteredBase, L, GetAttr):
_default='tfms'
def __init__(self,
items:list, # Items to apply `Transform`s to
tfms:(list,Pipeline), # `Transform`(s) or `Pipeline` to apply
tfms:list|Pipeline, # `Transform`(s) or `Pipeline` to apply
use_list:bool=None, # Use `list` in `L`
do_setup:bool=True, # Call `setup()` for `Transform`
split_idx:int=None, # Apply `Transform`(s) to training or validation set. `0` for training set and `1` for validation set
@@ -444,7 +444,7 @@ class Datasets(FilteredBase):
"A dataset that creates a tuple from each `tfms`"
def __init__(self,
items:list=None, # List of items to create `Datasets`
tfms:(list,Pipeline)=None, # List of `Transform`(s) or `Pipeline` to apply
tfms:list|Pipeline=None, # List of `Transform`(s) or `Pipeline` to apply
tls:TfmdLists=None, # If None, `self.tls` is generated from `items` and `tfms`
n_inp:int=None, # Number of elements in `Datasets` tuple that should be considered part of input
dl_type=None, # Default type of `DataLoader` used when function `FilteredBase.dataloaders` is called
@@ -502,7 +502,7 @@ def set_split_idx(self, i):

# Cell
def test_set(
dsets:(Datasets, TfmdLists), # Map- or iterable-style dataset from which to load the data
dsets:Datasets|TfmdLists, # Map- or iterable-style dataset from which to load the data
test_items, # Items in test dataset
rm_tfms=None, # Start index of `Transform`(s) from validation set in `dsets` to apply
with_labels:bool=False # Whether the test items contain labels
2 changes: 1 addition & 1 deletion fastai/data/load.py
@@ -25,7 +25,7 @@ class _FakeLoader:
def _fn_noops(self, x=None, *args, **kwargs): return x

_IterableDataset_len_called,_auto_collation,collate_fn,drop_last = None,False,_fn_noops,False
_index_sampler,generator,prefetch_factor = Inf.count,None,2
_index_sampler,generator,prefetch_factor,_get_shared_seed = Inf.count,None,2,noop
dataset_kind = _dataset_kind = _DatasetKind.Iterable

def __init__(self, d, pin_memory, num_workers, timeout, persistent_workers,pin_memory_device):
4 changes: 2 additions & 2 deletions fastai/interpret.py
@@ -57,7 +57,7 @@ def from_learner(cls,
return cls(learn, dl, losses, act)

def top_losses(self,
k:(int,None)=None, # Return `k` losses, defaults to all
k:int|None=None, # Return `k` losses, defaults to all
largest:bool=True, # Sort losses by largest or smallest
items:bool=False # Whether to return input items
):
@@ -67,7 +67,7 @@ def top_losses(self,
else: return losses, idx

def plot_top_losses(self,
k:(int,list), # Number of losses to plot
k:int|list, # Number of losses to plot
largest:bool=True, # Sort losses by largest or smallest
**kwargs
):
28 changes: 22 additions & 6 deletions fastai/learner.py
@@ -97,17 +97,32 @@ def before_epoch(self):
# Cell
class Learner(GetAttr):
_default='model'
def __init__(self, dls, model, loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=trainable_params, cbs=None,
metrics=None, path=None, model_dir='models', wd=None, wd_bn_bias=False, train_bn=True,
moms=(0.95,0.85,0.95)):
def __init__(self,
dls, # `DataLoaders` containing data for each dataset needed for `model`
model:callable, # The model to train or use for inference
loss_func:callable|None=None, # Loss function for training
opt_func=Adam, # Optimisation function for training
lr=defaults.lr, # Learning rate
splitter:callable=trainable_params, # Used to split parameters into layer groups
cbs=None, # Callbacks
metrics=None, # Printed after each epoch
path=None, # Parent directory to save, load, and export models
model_dir='models', # Subdirectory to save and load models
wd=None, # Weight decay
wd_bn_bias=False, # Apply weight decay to batchnorm bias params?
train_bn=True, # Always train batchnorm layers?
moms=(0.95,0.85,0.95), # Momentum
default_cbs:bool=True # Include default callbacks?
):
path = Path(path) if path is not None else getattr(dls, 'path', Path('.'))
if loss_func is None:
loss_func = getattr(dls.train_ds, 'loss_func', None)
assert loss_func is not None, "Could not infer loss function from the data, please pass a loss function."
self.dls,self.model = dls,model
store_attr(but='dls,model,cbs')
self.training,self.create_mbar,self.logger,self.opt,self.cbs = False,True,print,None,L()
self.add_cbs(L(defaults.callbacks)+L(cbs))
if default_cbs: self.add_cbs(L(defaults.callbacks))
self.add_cbs(cbs)
self.lock = threading.Lock()
self("after_create")

@@ -200,7 +215,8 @@ def _do_one_batch(self):
self.opt.zero_grad()

def _set_device(self, b):
model_device = torch.device(torch.cuda.current_device()) if next(self.model.parameters()).is_cuda else torch.device('cpu')
# model_device = torch.device(torch.cuda.current_device()) if next(self.model.parameters()).is_cuda else torch.device('cpu')
model_device = next(self.model.parameters()).device
dls_device = getattr(self.dls, 'device', default_device())
if model_device == dls_device: return to_device(b, dls_device)
else: return to_device(b, model_device)
@@ -412,7 +428,7 @@ def load_learner(fname, cpu=True, pickle_module=pickle):
map_loc = 'cpu' if cpu else default_device()
try: res = torch.load(fname, map_location=map_loc, pickle_module=pickle_module)
except AttributeError as e:
e.args = [f"Custom classes or functions exported with your `Learner` are not available in the namespace currently.\nPlease re-declare or import them before calling `load_learner`:\n\t{e.args[0]}"]
e.args = [f"Custom classes or functions exported with your `Learner` not available in namespace.\Re-declare/import before loading:\n\t{e.args[0]}"]
raise
if cpu:
res.dls.cpu()
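The refactored `Learner.__init__` above documents every parameter and adds `default_cbs`; `_set_device` now follows whatever device the model parameters live on instead of assuming CUDA. A rough construction sketch under those assumptions (not from the commit; dataset and model choices are illustrative):

```python
from fastai.vision.all import *

path = untar_data(URLs.MNIST_SAMPLE)
dls = ImageDataLoaders.from_folder(path)

# Opt out of the stock callbacks at construction time, then add them back
# explicitly when they are wanted again
learn = Learner(dls, resnet18(num_classes=2), loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, default_cbs=False)
learn.add_cbs(L(defaults.callbacks))
```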
4 changes: 2 additions & 2 deletions fastai/losses.py
@@ -46,8 +46,8 @@ def _contiguous(self, x:Tensor) -> TensorBase:
return TensorBase(x.transpose(self.axis,-1).contiguous()) if isinstance(x,torch.Tensor) else x

def __call__(self,
inp:(Tensor,list), # Predictions from a `Learner`
targ:(Tensor,list), # Actual y label
inp:Tensor|list, # Predictions from a `Learner`
targ:Tensor|list, # Actual y label
**kwargs
) -> TensorBase: # `loss_cls` calculated on `inp` and `targ`
inp,targ = map(self._contiguous, (inp,targ))
8 changes: 4 additions & 4 deletions fastai/medical/imaging.py
@@ -49,7 +49,7 @@ class TensorDicom(TensorImage):
class PILDicom(PILBase):
_open_args,_tensor_cls,_show_args = {},TensorDicom,TensorDicom._show_args
@classmethod
def create(cls, fn:(Path,str,bytes), mode=None)->None:
def create(cls, fn:Path|str|bytes, mode=None)->None:
"Open a `DICOM file` from path `fn` or bytes `fn` and load it as a `PIL Image`"
if isinstance(fn,bytes): im = Image.fromarray(pydicom.dcmread(pydicom.filebase.DicomBytesIO(fn)).pixel_array)
if isinstance(fn,(Path,str)): im = Image.fromarray(pydicom.dcmread(fn).pixel_array)
@@ -301,7 +301,7 @@ def to_3chan(x:DcmDataset, win1, win2, bins=None):

# Cell
@patch
def save_jpg(x:(Tensor,DcmDataset), path, wins, bins=None, quality=90):
def save_jpg(x:Tensor|DcmDataset, path, wins, bins=None, quality=90):
"Save tensor or dicom image into `jpg` format"
fn = Path(path).with_suffix('.jpg')
x = (x.to_nchan(wins, bins)*255).byte()
@@ -310,14 +310,14 @@ def save_jpg(x:(Tensor,DcmDataset), path, wins, bins=None, quality=90):

# Cell
@patch
def to_uint16(x:(Tensor,DcmDataset), bins=None):
def to_uint16(x:Tensor|DcmDataset, bins=None):
"Convert into a unit16 array"
d = x.hist_scaled(bins).clamp(0,1) * 2**16
return d.numpy().astype(np.uint16)

# Cell
@patch
def save_tif16(x:(Tensor,DcmDataset), path, bins=None, compress=True):
def save_tif16(x:Tensor|DcmDataset, path, bins=None, compress=True):
"Save tensor or dicom image into `tiff` format"
fn = Path(path).with_suffix('.tif')
Image.fromarray(x.to_uint16(bins)).save(str(fn), compression='tiff_deflate' if compress else None)
9 changes: 1 addition & 8 deletions fastai/optimizer.py
@@ -18,7 +18,7 @@
class _BaseOptimizer():
"Common functionality between `Optimizer` and `OptimWrapper`"
def all_params(self,
n:(slice, int)=slice(None), # Extended slicing over the optimizer `param_lists`
n:slice|int=slice(None), # Extended slicing over the optimizer `param_lists`
with_grad:bool=False # Get all param tuples. If `True` select only those with a gradient
):
res = L((p,pg,self.state[p],hyper) for pg,hyper in zip(self.param_lists[n],self.hypers[n]) for p in pg)
@@ -45,13 +45,6 @@ def freeze(self):
assert(len(self.param_lists)>1)
self.freeze_to(-1)

def set_freeze(self,
n:int,
rg:bool, # Whether grad is required
ignore_force_train=False # Overwrites "force_train" or batch norm always trains even if frozen
):
for p in self.param_lists[n]: p.requires_grad_(rg or (state.get('force_train', False) and not ignore_force_train))

def set_hypers(self, **kwargs): L(kwargs.items()).starmap(self.set_hyper)
def _set_hyper(self,
k, # Hyperparameter key
4 changes: 2 additions & 2 deletions fastai/tabular/data.py
@@ -19,7 +19,7 @@ class TabularDataLoaders(DataLoaders):
@delegates(Tabular.dataloaders, but=["dl_type", "dl_kwargs"])
def from_df(cls,
df:pd.DataFrame,
path:(str,Path)='.', # Location of `df`, defaults to current working directory
path:str|Path='.', # Location of `df`, defaults to current working directory
procs:list=None, # List of `TabularProc`s
cat_names:list=None, # Column names pertaining to categorical variables
cont_names:list=None, # Column names pertaining to continuous variables
@@ -37,7 +37,7 @@ def from_df(cls,

@classmethod
def from_csv(cls,
csv:(str,Path,io.BufferedReader), # A csv of training data
csv:str|Path|io.BufferedReader, # A csv of training data
skipinitialspace:bool=True, # Skip spaces after delimiter
**kwargs
):
2 changes: 1 addition & 1 deletion fastai/tabular/learner.py
@@ -35,7 +35,7 @@ def tabular_learner(
emb_szs:list=None, # Tuples of `n_unique, embedding_size` for all categorical features
config:dict=None, # Config params for TabularModel from `tabular_config`
n_out:int=None, # Final output size of the model
y_range:(float, float)=None, # Low and high for the final sigmoid function
y_range:Tuple[float,float]=None, # Low and high for the final sigmoid function
**kwargs
):
"Get a `Learner` using `dls`, with `metrics`, including a `TabularModel` created using the remaining params."
4 changes: 2 additions & 2 deletions fastai/tabular/model.py
@@ -28,7 +28,7 @@ def _one_emb_sz(classes, n, sz_dict=None):

# Cell
def get_emb_sz(
to:(Tabular, TabularPandas),
to:Tabular|TabularPandas,
sz_dict:dict=None # Dictionary of {'class_name' : size, ...} to override default `emb_sz_rule`
) -> list: # List of embedding sizes for each category
"Get embedding size for each cat_name in `Tabular` or `TabularPandas`, or populate embedding size manually using sz_dict"
@@ -42,7 +42,7 @@ def __init__(self,
n_cont:int, # Number of continuous variables
out_sz:int, # Number of outputs for final `LinBnDrop` layer
layers:list, # Sequence of ints used to specify the input and output size of each `LinBnDrop` layer
ps:(float, list)=None, # Sequence of dropout probabilities for `LinBnDrop`
ps:float|list=None, # Sequence of dropout probabilities for `LinBnDrop`
embed_p:float=0., # Dropout probability for `Embedding` layer
y_range=None, # Low and high for `SigmoidRange` activation
use_bn:bool=True, # Use `BatchNorm1d` in `LinBnDrop` layers
6 changes: 3 additions & 3 deletions fastai/text/learner.py
@@ -83,7 +83,7 @@ def load_model_text(
model, # Model architecture
opt:Optimizer, # `Optimizer` used to fit the model
with_opt:bool=None, # Enable to load `Optimizer` state
device:(int,str,torch.device)=None, # Sets the device, uses 'cpu' if unspecified
device:int|str|torch.device=None, # Sets the device, uses 'cpu' if unspecified
strict:bool=True # Whether to strictly enforce the keys of `file`s state dict match with the model `Module.state_dict`
):
"Load `model` from `file` along with `opt` (if available, and if `with_opt`)"
@@ -126,7 +126,7 @@ def save_encoder(self,

def load_encoder(self,
file:str, # Filename of the saved encoder
device:(int,str,torch.device)=None # Device used to load, defaults to `dls` device
device:int|str|torch.device=None # Device used to load, defaults to `dls` device
):
"Load the encoder `file` from the model directory, optionally ensuring it's on `device`"
encoder = get_model(self.model)[0]
@@ -159,7 +159,7 @@ def load_pretrained(self,
def load(self,
file:str, # Filename of saved model
with_opt:bool=None, # Enable to load `Optimizer` state
device:(int,str,torch.device)=None, # Device used to load, defaults to `dls` device
device:int|str|torch.device=None, # Device used to load, defaults to `dls` device
**kwargs
):
if device is None: device = self.dls.device
2 changes: 1 addition & 1 deletion fastai/text/models/awdlstm.py
@@ -37,7 +37,7 @@ class WeightDropout(Module):
def __init__(self,
module:nn.Module, # Wrapped module
weight_p:float, # Weight dropout probability
layer_names:(str,list)='weight_hh_l0' # Name(s) of the parameters to apply dropout to
layer_names:str|list='weight_hh_l0' # Name(s) of the parameters to apply dropout to
):
self.module,self.weight_p,self.layer_names = module,weight_p,L(layer_names)
for layer in self.layer_names:
