# Digit AI Trainer
#### How-to:
- Paste a `train_model(...)` call into the last cell (see the example there)
- Run all cells

In [2]:
import torch
from fastai.vision.all import *

In [19]:
def save_confusion_matrix(interp, pic_path, cmap:str="Blues", title:str="Confusion Matrix"):
    """Plot the confusion matrix of ``interp`` using ``cmap`` and save it to ``pic_path``.

    Mainly copied from the sklearn docs example (via fastai's
    ``plot_confusion_matrix``), but the figure is written to disk.

    Args:
        interp: object providing ``confusion_matrix()`` (a 2-D array of counts,
            rows = actual, columns = predicted) and ``vocab`` (the class labels
            used for the tick marks), e.g. a fastai ``ClassificationInterpretation``.
        pic_path: destination passed to ``Figure.savefig``; the image format is
            inferred by matplotlib from the file extension.
        cmap: matplotlib colormap name used to shade the cells.
        title: heading drawn above the matrix.
    """
    cm = interp.confusion_matrix()
    labels = interp.vocab
    n_classes = len(labels)
    tick_marks = np.arange(n_classes)

    fig = plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.xticks(tick_marks, labels, rotation=90)
    plt.yticks(tick_marks, labels, rotation=0)

    # Write each count in its cell; use white text on cells above half the
    # maximum (the darker end of the colormap) so the numbers stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, f'{cm[i, j]}',
                 horizontalalignment="center", verticalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # Make the y-range span exactly the cell rows with row 0 at the top
    # (presumably a guard against clipped first/last rows in some matplotlib versions).
    ax = fig.gca()
    ax.set_ylim(n_classes - .5, -.5)

    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.grid(False)
    # Compute the layout only after every label exists, so the axis labels
    # are included in the padding and not clipped in the saved image.
    plt.tight_layout()
    fig.savefig(pic_path)

In [20]:
def train_model(digit_folder="digits", batch_size=32, epochs=12, seed=69, lr=3e-3, model_output="balls.pkl", save_conf=False, conf_out='None', base_model="resnet18"):
    """Fine-tune a pretrained image classifier on ``digit_folder`` and export it.

    Images are found with ``get_image_files`` and labelled by their parent
    folder name. They are split into train/validation by their grandparent
    folder (fastai ``GrandparentSplitter`` defaults: ``train`` / ``valid``).

    Args:
        digit_folder: root folder of the labelled images.
        batch_size: mini-batch size for the dataloaders.
        epochs: number of unfrozen epochs passed to ``Learner.fine_tune``.
        seed: random seed passed to fastai ``set_seed`` before building the data.
        lr: base learning rate for ``fine_tune``.
        model_output: file name of the exported learner; ``.pkl`` is appended
            if missing. It is written under ``<digit_folder>/models/``.
        save_conf: if True, also save a confusion-matrix image.
        conf_out: path of the confusion-matrix image. The default ``'None'``
            (or ``None``) derives ``<model name>_confusion.png`` from
            ``model_output``.
        base_model: key into the supported torchvision architectures below.

    Returns:
        The trained fastai ``Learner``.

    Raises:
        KeyError: if ``base_model`` is not a supported architecture.
    """
    # Pick the device. ``torch.cuda.get_device_name`` raises when there is no
    # CUDA device, so only query it when CUDA is actually available.
    if torch.cuda.is_available():
        device = torch.device('cuda')
        device_name = torch.cuda.get_device_name(device)
    else:
        device = torch.device('cpu')
        device_name = 'cpu'
    print(f"Device: {device_name}")

    model_map = {
        'resnet18': models.resnet18,
        'resnet34': models.resnet34,
        'resnet50': models.resnet50,
        'resnet101': models.resnet101,
        'resnet152': models.resnet152,
        'resnext50_32x4d': models.resnext50_32x4d,
        'resnext101_32x8d': models.resnext101_32x8d,
        'resnext101_64x4d': models.resnext101_64x4d,
    }
    try:
        arch = model_map[base_model]
    except KeyError:
        raise KeyError(
            f"Unknown base_model {base_model!r}; choose one of {sorted(model_map)}"
        ) from None

    # Load data, set seed, augment training photos
    set_seed(seed)
    fingers = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=GrandparentSplitter(),
        get_y=parent_label,
        batch_tfms=aug_transforms(mult=1.5, max_zoom=1.))
    dls = fingers.dataloaders(digit_folder, batch_size=batch_size)

    # Create classifier
    learn = vision_learner(dls, arch, metrics=error_rate, lr=lr)
    learn.model.to(device)

    # Train classifier. ``fine_tune`` uses its own ``base_lr`` argument rather
    # than ``Learner.lr``, so ``lr`` must be forwarded explicitly to take effect.
    learn.fine_tune(epochs, base_lr=lr)
    torch.cuda.empty_cache()

    # Normalise the export file name so it always ends in .pkl
    if not model_output.endswith(".pkl"):
        model_output += ".pkl"

    # Analyze classifier
    if save_conf:
        # The historical default is the *string* 'None'; treat it like None
        # instead of writing a file literally named "None.png".
        if conf_out in (None, 'None'):
            conf_out = model_output[:-len(".pkl")] + "_confusion.png"
        interp = ClassificationInterpretation.from_learner(learn)
        save_confusion_matrix(interp, conf_out)

    # Export model. ``Learner.export`` saves to ``learn.path / fname`` but does
    # not create missing directories, so make sure the target folder exists.
    fname = "models/" + model_output
    (learn.path / fname).parent.mkdir(parents=True, exist_ok=True)
    learn.export(fname=fname)
    return learn

##### Paste your `train_model(...)` call in the cell below

In [None]:
# Example: train_model(save_conf=True, conf_out="BALLS.png", epochs=1)