In [1]:
from fastai.vision.all import *
import csv
import glob
from collections import namedtuple
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
a1_attr = 'path description'
A1Row = namedtuple('A1Row', a1_attr)
csv_path = 'a1.csv'


In [3]:
rows = []
description_map = {}
A1_input = []

with open(csv_path, newline='') as csvfile:
    spamreader = csv.reader(csvfile, delimiter=',', quotechar='"')
    for i, (partial_path, description) in enumerate(spamreader):
        if i == 0: continue
        path = Path('A/A/A1') / partial_path
        row = A1Row(path, description)
        rows.append(row)
        description_map[row.path.name] = row
        A1_input.append(row.path)


In [4]:
def plot_metrics(name, self: Recorder, nrows=None, ncols=None, figsize=None, **kwargs):
    metrics = np.stack(self.values)
    names = self.metric_names[1:-1]
    n = len(names) - 1
    if nrows is None and ncols is None:
        nrows = int(math.sqrt(n))
        ncols = int(np.ceil(n / nrows))
    elif nrows is None: nrows = int(np.ceil(n / ncols))
    elif ncols is None: ncols = int(np.ceil(n / nrows))
    figsize = figsize or (ncols * 6, nrows * 4)
    fig, axs = subplots(nrows, ncols, figsize=figsize, **kwargs)
    axs = [ax if i < n else ax.set_axis_off() for i, ax in enumerate(axs.flatten())][:n]
    for i, (name, ax) in enumerate(zip(names, [axs[0]] + axs)):
        ax.plot(metrics[:, i], color='#1f77b4' if i == 0 else '#ff7f0e', label='valid' if i > 0 else 'train')
        ax.set_title(name if i > 1 else 'losses')
        ax.legend(loc='best')
    plt.savefig(f'{name}.png')

In [5]:
def description_label(f):
    return description_map[f].description
path = Path('A/A/A1')

In [6]:
def train(name, label_fn):
    print('Training', name)
    dls = ImageDataLoaders.from_name_func(path, A1_input, label_fn, item_tfms=Resize(224))
    learn = vision_learner(dls, resnet18, metrics=[error_rate, accuracy, F1Score(average='micro'), Recall(average='micro'), Precision(average='micro')])
    learn.fine_tune(2)
    plot_metrics(name, learn.recorder)
    learn.show_results()
    return learn
    
    

In [7]:
learn = train('Purple', description_label)

Training Purple




epoch,train_loss,valid_loss,error_rate,accuracy,f1_score,recall_score,precision_score,time
0,0.704352,0.212884,0.074883,0.925117,0.925117,0.925117,0.925117,01:34


epoch,train_loss,valid_loss,error_rate,accuracy,f1_score,recall_score,precision_score,time
0,0.054805,0.018661,0.00624,0.99376,0.99376,0.99376,0.99376,02:14


In [23]:
learn.export('purple')

In [None]:
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix()