In [1]:
# !pip3 install wandb duckduckgo_search -qq

In [5]:
from fastcore.all import *
from fastai.vision.widgets import *
from fastdownload import download_url
from fastai.vision.all import *
from time import sleep
import pandas as pd
import wandb
import params
import utils
import torchvision.models as tvmodels
from fastai.callback.wandb import WandbCallback

In [6]:
# Default hyper-parameters for `train`: sent to wandb.init and read back as wandb.config.
train_config = SimpleNamespace(**{
    "framework": "fastai",
    "img_size": (180, 320),   # passed to Resize in get_data
    "batch_size": 8,
    "augment": True,          # use data augmentation (aug_transforms)
    "epochs": 10,
    "lr": 2e-3,
    "arch": "resnet18",       # name of a torchvision.models constructor
    "pretrained": True,       # whether to use pretrained encoder
    "seed": 42,
    "log_preds": True,        # log a prediction table after training
})

In [1]:
def download_data(run=None):
    """Download the latest processed-dataset artifact and return its local directory.

    Args:
        run: W&B run used to declare the artifact dependency. Defaults to the
            currently active run (``wandb.run``), i.e. the one ``train`` opens
            with ``wandb.init`` before calling this function.

    Returns:
        ``Path`` of the directory the artifact was downloaded into.

    Raises:
        RuntimeError: if no ``run`` is given and no W&B run is active.
    """
    # `train` keeps its run object local, so there is no notebook-level `run`
    # to rely on; fall back to the run W&B tracks as active.
    if run is None:
        run = wandb.run
    if run is None:
        raise RuntimeError(
            "download_data() needs an active W&B run: call wandb.init() first or pass run=..."
        )
    processed_data_at = run.use_artifact(f'{params.PROCESSED_DATA_AT}:latest')
    return Path(processed_data_at.download())

In [3]:
def get_df(processed_dataset_dir, is_test=False, fname_col=None):
    """Load ``data_split.csv`` and attach an image path and a label to every row.

    Args:
        processed_dataset_dir: ``Path`` of the processed-data artifact. It must contain
            ``data_split.csv`` (with a ``stage`` column) and a ``bcc_images`` folder.
        is_test: if False, keep every row whose ``stage`` is not ``'test'`` and add a
            boolean ``is_valid`` column (``stage == 'valid'``) for ``ColSplitter``;
            if True, keep only the ``'test'`` rows (no ``is_valid`` column).
        fname_col: optional name of a CSV column holding each row's image file name
            relative to ``bcc_images``. When given, paths are built from that column,
            so every row is paired with its own image. When omitted (the original
            behaviour), the files listed by ``get_image_files`` are assigned to the
            rows by position.

    Returns:
        A new DataFrame with ``image_fname`` and ``label`` columns and a reset index.

    Raises:
        ValueError: in positional mode, if the number of images found differs from
            the number of CSV rows.
    """
    df = pd.read_csv(processed_dataset_dir / 'data_split.csv')
    path = processed_dataset_dir / 'bcc_images'

    if fname_col is not None:
        df['image_fname'] = [path / fname for fname in df[fname_col]]
    else:
        # NOTE(review): fastai lists files via os.walk without sorting, so this order is
        # not guaranteed to match the CSV row order; prefer passing `fname_col`.
        image_files = list(get_image_files(path))
        if len(image_files) != len(df):
            raise ValueError(
                f"found {len(image_files)} images under {path} "
                f"but data_split.csv has {len(df)} rows"
            )
        df['image_fname'] = image_files
    df['label'] = [utils.label_func(f) for f in df['image_fname']]

    if is_test:
        return df[df.stage == 'test'].reset_index(drop=True)

    df = df[df.stage != 'test'].reset_index(drop=True)
    df['is_valid'] = df.stage == 'valid'
    return df

In [11]:
def get_data(df, bs=4, img_size=(180, 320), augment=True):
    """Build fastai DataLoaders for single-label image classification.

    ``df`` must provide an ``image_fname`` column (inputs), a ``label`` column
    (targets) and the boolean column ``ColSplitter()`` splits on (fastai's
    default ``is_valid``, which ``get_df`` adds when ``is_test=False``).

    Args:
        df: dataframe describing the images.
        bs: batch size.
        img_size: size passed to ``Resize`` for every item.
        augment: when True, apply ``aug_transforms()`` as batch transforms.

    Returns:
        The DataLoaders created by ``DataBlock.dataloaders``.
    """
    batch_transforms = aug_transforms() if augment else None
    datablock = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_x=ColReader("image_fname"),
        get_y=ColReader("label"),
        splitter=ColSplitter(),
        item_tfms=Resize(img_size),
        batch_tfms=batch_transforms,
    )
    return datablock.dataloaders(df, bs=bs, shuffle=True)

In [17]:
def train(config):
    """Train an image classifier with fastai and track the run in Weights & Biases.

    Args:
        config: namespace of hyper-parameters (fields as in ``train_config``). It
            seeds the RNGs, is sent to ``wandb.init``, and is then read back as
            ``wandb.config``.

    The W&B run is always closed with ``wandb.finish()``, even when data loading
    or training raises, so a failed cell does not leave a dangling run behind.
    """
    set_seed(config.seed, reproducible=True)
    wandb.init(project=params.WANDB_PROJECT, entity=params.ENTITY,
               job_type="training", config=config)
    try:
        cfg = wandb.config
        processed_dataset_dir = download_data()
        df = get_df(processed_dataset_dir)
        dls = get_data(df, bs=cfg.batch_size, img_size=cfg.img_size, augment=cfg.augment)

        # Order matters: log_metrics names the validate() results in this order.
        metrics = [accuracy, error_rate, F1Score(average='weighted'), HammingLoss()]
        learn = vision_learner(dls, arch=getattr(tvmodels, cfg.arch),
                               pretrained=cfg.pretrained, metrics=metrics)

        callbacks = [
            # Tracks the best epoch by accuracy; fastai reloads it after fitting,
            # so the final metrics below describe that checkpoint.
            SaveModelCallback(monitor='accuracy'),
            # The prediction table is logged separately by log_predictions.
            WandbCallback(log_preds=False, log_model=True),
        ]
        learn.fit_one_cycle(cfg.epochs, cfg.lr, cbs=callbacks)

        if cfg.log_preds:
            log_predictions(learn)
        log_metrics(learn)
    finally:
        wandb.finish()

In [18]:
def log_predictions(learner):
    """Log a W&B table of ``learner``'s predictions with class probabilities.

    Predictions are computed by ``utils.get_predictions`` and turned into a table
    by ``utils.create_prob_table`` using ``params.BIG_CAT_CLASSES`` as class names.
    The table is logged under the ``"pred_table"`` key.

    Returns:
        The ``(samples, outputs, predictions)`` tuple from ``utils.get_predictions``,
        so the notebook can inspect individual predictions afterwards.
    """
    samples, outputs, predictions = utils.get_predictions(learner)
    table = utils.create_prob_table(samples, outputs, predictions, params.BIG_CAT_CLASSES)
    wandb.log({"pred_table": table})
    return samples, outputs, predictions


# Backward-compatible alias for the original (misspelled) name.
log_predicitons = log_predictions

In [19]:
def log_metrics(learner):
    """Validate ``learner`` and store each score in the W&B run summary.

    ``learner.validate()`` is expected to return the loss followed by the metric
    values in the order ``train`` builds its ``metrics`` list. Each value is
    stored under a ``final_``-prefixed key.
    """
    suffixes = ('loss', 'accuracy', 'error_rate', 'f1score_weighted', 'hamming_loss')
    keys = [f'final_{suffix}' for suffix in suffixes]
    scores = learner.validate()
    # Build the full mapping before touching the summary. If validate() returns
    # more values than there are keys, the IndexError fires before any write.
    summary_updates = {keys[idx]: score for idx, score in enumerate(scores)}
    for key, value in summary_updates.items():
        wandb.summary[key] = value

In [20]:
# Launch one training run with the default hyper-parameters defined above.
train(config=train_config)

0,1
accuracy,▁▇▆▅█▇▆▅▅▅
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
eps_2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
error_rate,█▂▃▅▁▂▃▄▄▅
f1_score,▁▇▇▅█▇▇▆▆▅
hamming_loss,█▂▃▄▁▂▃▄▄▄
lr_0,▁▁▂▃▄▅▆▇▇██████▇▇▇▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
lr_1,▁▁▂▃▄▅▆▇▇██████▇▇▇▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁

0,1
accuracy,0.83333
epoch,10.0
eps_0,1e-05
eps_1,1e-05
eps_2,1e-05
error_rate,0.16667
f1_score,0.84278
final_accuracy,0.96667
final_error_rate,0.03333
final_f1score_weighted,0.96633


In [26]:
# NOTE(review): no cell above defines `samples`, `outputs` or `predictions` at notebook
# level; they only exist as locals inside `log_predicitons`. On a fresh kernel
# (Restart & Run All), this cell raises NameError, so it relies on hidden kernel state.
samples[18], outputs[18][0], predictions[18]

((TensorImage([[[ 95,  95,  96,  ...,  92,  90,  84],
                [ 93,  93,  96,  ...,  96,  87,  86],
                [ 92,  93,  95,  ...,  94,  89,  90],
                ...,
                [115, 120, 128,  ..., 137, 138, 143],
                [120, 115, 120,  ..., 134, 135, 139],
                [123, 120, 122,  ..., 143, 137, 132]],
  
               [[105, 105, 105,  ..., 105, 104,  97],
                [103, 104, 106,  ..., 106,  98,  97],
                [102, 103, 105,  ..., 103,  98,  98],
                ...,
                [104, 109, 117,  ..., 122, 123, 128],
                [109, 104, 109,  ..., 119, 120, 124],
                [112, 109, 111,  ..., 128, 122, 117]],
  
               [[ 72,  72,  75,  ...,  70,  69,  61],
                [ 72,  74,  77,  ...,  72,  63,  60],
                [ 76,  77,  79,  ...,  68,  61,  60],
                ...,
                [ 58,  63,  72,  ...,  89,  90,  95],
                [ 63,  59,  64,  ...,  86,  87,  91],
           