# Digit AI Trainer
#### How-to:
- Paste a `train_model(...)` call into the last cell
- Run all cells

In [5]:
import torch
from fastai.vision.all import *
import logging
import time
import numpy as np
from sklearn.metrics import accuracy_score

def format_elapsed_time(start_time, end_time):
    """Describe the span from `start_time` to `end_time` as hours, minutes and seconds.

    Both arguments are timestamps in seconds (e.g. from ``time.time()``). Hours and
    minutes are whole numbers; seconds keep two decimals. A negative span is not
    split up and shows only as negative seconds.
    """
    remaining = end_time - start_time

    # Peel off whole hours, then whole minutes. Each step only runs once the
    # remainder is at least one unit, so shorter spans keep a zero count.
    hours = 0
    if remaining >= 3600:
        whole_hours, remaining = divmod(remaining, 3600)
        hours = int(whole_hours)

    minutes = 0
    if remaining >= 60:
        whole_minutes, remaining = divmod(remaining, 60)
        minutes = int(whole_minutes)

    return f"{hours} hours, {minutes} minutes, {remaining:.2f} seconds"

def save_confusion_matrix(interp, pic_path, cmap:str="Blues", title:str="Confusion Matrix"):
    """Plot the confusion matrix held by `interp` and save it to `pic_path`.

    Mainly copied from the sklearn docs confusion-matrix example.

    Args:
        interp: interpretation object exposing ``confusion_matrix()`` and
            ``vocab`` (a fastai ``ClassificationInterpretation`` in this notebook).
        pic_path: destination passed to ``Figure.savefig``.
        cmap: matplotlib colormap used to shade the cells.
        title: figure title.

    The figure is not closed here, so it can still render inline in a notebook.
    """
    # Import explicitly instead of relying on the fastai star import to provide it.
    import itertools

    cm = interp.confusion_matrix()
    vocab = interp.vocab
    n_classes = len(vocab)

    fig = plt.figure()
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    tick_marks = np.arange(n_classes)
    plt.xticks(tick_marks, vocab, rotation=90)
    plt.yticks(tick_marks, vocab, rotation=0)

    # Print each count inside its cell. Use white text on cells above half the
    # maximum (darker shades) so the number stays readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, f'{cm[i, j]}',
                 horizontalalignment="center",
                 verticalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    # Span every row explicitly, with row 0 at the top (bottom=n-0.5, top=-0.5).
    ax = fig.gca()
    ax.set_ylim(n_classes - .5, -.5)

    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.grid(False)
    # Lay out only after the axis labels exist; otherwise they can be clipped
    # in the saved image.
    plt.tight_layout()
    fig.savefig(pic_path)

def _setup_run_logger(log_name):
    """Send root-logger INFO output to stderr and to a freshly truncated `log_name`.

    Returns the file handler so the caller can detach and close it afterwards.
    """
    root = logging.getLogger()
    # Drop (and close) handlers left over from a previous call. Otherwise repeated
    # runs in the same kernel keep old log files open.
    for old_handler in list(root.handlers):
        root.removeHandler(old_handler)
        old_handler.close()
    logging.basicConfig(level=logging.INFO)

    # mode='w' truncates any previous log of the same name. There is no need to
    # reopen the file separately.
    file_handler = logging.FileHandler(log_name, mode='w')
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s', '%H:%M:%S'))
    root.addHandler(file_handler)
    return file_handler


def _teardown_run_logger(file_handler):
    """Detach `file_handler` from the root logger and close its file."""
    logging.getLogger().removeHandler(file_handler)
    file_handler.close()


def train_model(digit_folder="pics/digits", batch_size=32, epochs=12, seed=69, lr=3e-3, model_output="models/test1.pkl", save_conf=False, conf_out=None, base_model="resnet18"):
    """Fine-tune a pretrained ResNet/ResNeXt on a folder of digit images and export it.

    Args:
        digit_folder: image root passed to the DataBlock. Labels come from each
            image's parent folder. The train/valid split comes from its grandparent
            folder (``GrandparentSplitter`` defaults).
        batch_size: DataLoader batch size.
        epochs: number of unfrozen epochs passed to ``Learner.fine_tune``.
        seed: passed to fastai ``set_seed``.
        lr: base learning rate passed to ``Learner.fine_tune``.
        model_output: export path for the learner. ``.pkl`` is appended if missing,
            and the run log is written next to it with a ``.txt`` suffix.
        save_conf: if true, save a confusion-matrix image and log the matrix.
        conf_out: image path for the confusion matrix. ``None`` (or the legacy
            string ``'None'``) derives it from `model_output` with a ``.png`` suffix.
        base_model: key of one of the architectures listed in ``model_map`` below.

    Raises:
        KeyError: if `base_model` is not a supported architecture.
    """
    from pathlib import Path

    # Normalise the export name first, so the log, the confusion matrix and the
    # model share one stem.
    model_output = str(model_output)
    if not model_output.endswith(".pkl"):
        model_output += ".pkl"
    stem = model_output[:-len(".pkl")]
    log_name = stem + ".txt"
    if conf_out is None or conf_out == 'None':
        conf_out = stem + ".png"

    # The log (and later the exported model) live in this folder; create it up front.
    Path(log_name).parent.mkdir(parents=True, exist_ok=True)

    file_handler = _setup_run_logger(log_name)
    try:
        logging.info('Logger setup')

        # Setup device
        available = torch.cuda.is_available()
        logging.info(f'CUDA: {available}')
        device = torch.device('cuda' if available else 'cpu')
        # get_device_name only accepts CUDA devices; on a CPU-only machine it raised.
        device_name = torch.cuda.get_device_name(device) if available else 'CPU'
        logging.info(f"Device: {device_name}")

        model_map = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
            'resnet152': models.resnet152,
            'resnext50_32x4d': models.resnext50_32x4d,
            'resnext101_32x8d': models.resnext101_32x8d,
            'resnext101_64x4d': models.resnext101_64x4d,
        }
        try:
            arch = model_map[base_model]
        except KeyError:
            raise KeyError(
                f"Unknown base_model {base_model!r}; expected one of {sorted(model_map)}"
            ) from None
        logging.info(f'Base model architecture: {base_model}')

        # Load data, set seed, augment training photos
        set_seed(seed)
        logging.info(f'Seed: {seed}')

        fingers = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            splitter=GrandparentSplitter(),
            get_y=parent_label,
            batch_tfms=aug_transforms(mult=1.5, max_zoom=1.))
        logging.info('DataBlock created')

        dls = fingers.dataloaders(digit_folder, batch_size=batch_size)
        logging.info('DataLoader created')

        # Create classifier
        learn = vision_learner(dls, arch, metrics=error_rate, lr=lr)
        learn.model.to(device)
        logging.info(f'Learner created on {device}')

        logging.info(f'Training started: \n\tNumber of epochs = {epochs} \n\tBatch size = {batch_size} \n\tLearning rate = {lr}')
        start = time.perf_counter()
        # fine_tune does not read learn.lr; it uses its own base_lr (default 2e-3).
        # Pass `lr` explicitly so the logged learning rate is the one actually used.
        learn.fine_tune(epochs, base_lr=lr)
        if available:
            torch.cuda.empty_cache()
        elapsed = format_elapsed_time(start, time.perf_counter())
        logging.info(f'Training finished \n\tTime taken = {elapsed}')

        # Analyze classifier
        if save_conf:
            interp = ClassificationInterpretation.from_learner(learn)
            save_confusion_matrix(interp, conf_out)
            logging.info(f'Confusion matrix: \n{interp.confusion_matrix()}')

        # Accuracy score on the validation set (class index = argmax over the class dim).
        preds, targets = learn.get_preds(dl=dls.valid)
        acc = accuracy_score(targets.numpy(), preds.argmax(dim=1).numpy())
        logging.info(f'Validation accuracy = {acc}')

        # Export model
        learn.export(fname=model_output)
    except Exception:
        # Record the failure in the run log before propagating it.
        logging.exception('Training run failed')
        raise
    finally:
        # Always release the log file, even if training failed.
        logging.info("Exiting logger")
        _teardown_run_logger(file_handler)

##### Paste your `train_model(...)` call into the cell below

In [None]:
# Example: train_model(save_conf=True, conf_out="models/test1.png", epochs=0)