In [1]:
#|default_exp fine_tune

In [12]:
#|export
import timm
import wandb
import argparse
import torchvision
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

In [13]:
#|hide
from nbdev.showdoc import *

In [14]:
#|export
WANDB_PROJECT = 'paddy-ft'
WANDB_ENTITY = 'bilalcodehub'

In [15]:
#|export
config_defaults = SimpleNamespace(
    batch_size=32,
    epochs=1,
    num_experiments=1,
    learning_rate=2e-3,
    img_size=224,
    resize_method="crop",
    model_name="resnet34",
    pool="concat",
    seed=42,
    wandb_project=WANDB_PROJECT,
    wandb_entity=WANDB_ENTITY,
    split_func="default",
)   

In [16]:
#|export
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=config_defaults.batch_size)
    parser.add_argument('--epochs', type=int, default=config_defaults.epochs)
    parser.add_argument('--num_experiments', type=int, default=config_defaults.num_experiments)
    parser.add_argument('--learning_rate', type=float, default=config_defaults.learning_rate)
    parser.add_argument('--img_size', type=int, default=config_defaults.img_size)
    parser.add_argument('--resize_method', type=str, default=config_defaults.resize_method)
    parser.add_argument('--model_name', type=str, default=config_defaults.model_name)
    parser.add_argument('--split_func', type=str, default=config_defaults.split_func)
    parser.add_argument('--pool', type=str, default=config_defaults.pool)
    parser.add_argument('--seed', type=int, default=config_defaults.seed)
    parser.add_argument('--wandb_project', type=str, default=WANDB_PROJECT)
    parser.add_argument('--wandb_entity', type=str, default=WANDB_ENTITY)
    return parser.parse_args()

In [17]:
#|export
def get_gpu_mem(device=0):
    gpu_mem = torch.cuda.memory_stats_as_nested_dict(device=device)
    return (gpu_mem["reserved_bytes"]["small_pool"]["peak"] + gpu_mem["reserved_bytes"]["large_pool"]["peak"])*1024**-3

In [18]:
#|export
def get_dataset(batch_size, img_size, seed, method="crop"):
    path = Path.home()/'.fastai/data/paddy'
    files = get_image_files(path/'train_images')
    dls = ImageDataLoaders.from_folder(path, files, valid_pct=0.2,
                                       seed=seed, bs=batch_size, item_tfms=Resize(img_size, method=method))
    return dls, [error_rate, accuracy]

In [19]:
#|export
def train(config=config_defaults):
    with wandb.init(project=config.wandb_project, group='timm', entity=config.wandb_entity, config=config):
        config=wandb.config
        dls, metrics = get_dataset(config.batch_size, config.img_size, config.seed, config.resize_method)
        learn = vision_learner(dls, config.model_name, metrics=metrics, concat_pool=(config.pool=="pool"),
                              cbs=WandbCallback(log=None, log_preds=False)).to_fp16()
        ti=time.perf_counter()
        learn.fine_tune(config.epochs, config.learning_rate)
        wandb.summary['GPU_mem'] = get_gpu_mem(learn.dls.device)
        wandb.summary['model_family'] = gconfig.model_name.split('_')[0]
        wandb.summary['fit_time'] = ti.perf_counter()-ti
        

In [20]:
#|export
if __name__=="__main__":
    args=parse_args()
    train(config=args)

usage: ipykernel_launcher.py [-h] [--batch_size BATCH_SIZE] [--epochs EPOCHS]
                             [--num_experiments NUM_EXPERIMENTS]
                             [--learning_rate LEARNING_RATE]
                             [--img_size IMG_SIZE]
                             [--resize_method RESIZE_METHOD]
                             [--model_name MODEL_NAME]
                             [--split_func SPLIT_FUNC] [--pool POOL]
                             [--seed SEED] [--wandb_project WANDB_PROJECT]
                             [--wandb_entity WANDB_ENTITY]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/bilal/.local/share/jupyter/runtime/kernel-00deb071-8000-4693-81de-ae91f05fbff7.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
#|hide
from nbdev.export import notebook2script
notebook2script()