## imports

In [None]:
import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.vision.all import *
import torchvision.models as models
import pandas as pd
import numpy as np
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from jetcam.csi_camera import CSICamera
import ipywidgets
import traitlets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg

import warnings
warnings.filterwarnings("ignore")

## camera

In [None]:
# Open CSI capture device 0 at 224x224 and switch capturing on; later cells
# read frames from `camera.value`.
camera = CSICamera(width=224, height=224, capture_device=0)
camera.running = True

print("camera created")

## camera_widget

In [None]:
# NOTE(review): unobserve_all() removes every observer registered on `camera`,
# which may include the camera class's own observer on `running` — confirm
# this is intended (it is presumably here to drop links from earlier runs).
camera.unobserve_all()

# Preview widget: each new `camera.value` frame is JPEG-encoded by
# bgr8_to_jpeg and pushed one-way into the Image widget.
camera_widget = ipywidgets.Image()
traitlets.dlink(
    (camera, 'value'),
    (camera_widget, 'value'),
    transform=bgr8_to_jpeg,
)

camera_view_widget = ipywidgets.VBox([ipywidgets.HBox([camera_widget])])
print("camera_view_widget created")

## model dataset

In [None]:
# Class names; also used as the CategoryBlock vocab in get_data.
LABEL_COLS = ['background', 'maya', 'neg_basil', 'neg_coco', 'pos_chain', 'pos_cluster']

# Images are laid out as gram_stain/<label>/<file>.jpg, so the parent
# directory name is the label.
filepath_list = glob.glob('gram_stain/*/*.jpg')
labels = [str(fp).split("/")[-2] for fp in filepath_list]
filepath = pd.Series(filepath_list, name='filepath').astype(str)
label = pd.Series(labels, name='label')
train_df = pd.concat([label, filepath], axis=1)
# Deterministic shuffle of the rows.
train_df = train_df.sample(frac=1, random_state=0).reset_index(drop=True)

N_FOLDS = 5
train_df['fold'] = -1
# MultilabelStratifiedKFold needs a label-indicator matrix. train_df's other
# columns are `filepath`/`fold`, so build a one-hot matrix of `label`
# (one column per LABEL_COLS entry) to get folds with a similar class mix.
label_onehot = (
    pd.get_dummies(train_df['label'], dtype=int)
    .reindex(columns=LABEL_COLS, fill_value=0)
    .values
)
strat_kfold = MultilabelStratifiedKFold(n_splits=N_FOLDS, random_state=43, shuffle=True)
for i, (_, test_index) in enumerate(strat_kfold.split(train_df.filepath.values, label_onehot)):
    # train_df has a RangeIndex here, so positional test_index values are
    # also valid row labels for .loc.
    train_df.loc[test_index, 'fold'] = i
train_df['fold'] = train_df['fold'].astype('int')
train_df = train_df.reset_index(drop=True)
# Batch-level transforms passed to the DataBlock (currently none).
augs_train = []


def get_data(fold):
    """Build DataLoaders over ``train_df`` holding out the rows of one fold.

    Args:
        fold: value of ``train_df['fold']`` whose rows form the validation set.

    Returns:
        fastai DataLoaders (batch size 2) with images read from the
        ``filepath`` column, labels from ``label`` and vocab ``LABEL_COLS``.
    """
    # Row labels of the held-out fold. train_df has a RangeIndex (it is
    # reset after fold assignment), so these are also positional indices,
    # which is what IndexSplitter expects. Do not reset_index on the filtered
    # frame first: that renumbers to 0..k-1 and selects the first k rows of
    # train_df instead of this fold's rows.
    valid_idx = train_df.index[train_df['fold'] == fold]
    dblock = DataBlock(blocks=(ImageBlock(cls=PILImage), CategoryBlock(vocab=LABEL_COLS)),
                       splitter=IndexSplitter(valid_idx),
                       get_x=ColReader('filepath'),
                       get_y=ColReader('label'),
                       item_tfms=Resize(300, method="squish"),
                       batch_tfms=augs_train,
                       )
    return dblock.dataloaders(train_df, bs=2)

## live_execution_widget

In [None]:
import threading
import time
import torch.nn.functional as F
import torch

# Toggle read by the inference loop: it runs while the value is 'live'.
state_widget = ipywidgets.ToggleButtons(
    options=['stop', 'live'],
    description='state',
    value='stop',
)
# Text box that the inference loop writes the predicted label into.
prediction_widget = ipywidgets.Text(description='prediction')

def live(state_widget, datal, model, camera, prediction_widget):
    """Classify camera frames repeatedly while ``state_widget`` reads 'live'.

    Each pass grabs ``camera.value``, runs it through ``datal.test_dl`` and
    back through ``datal.train.decode`` to obtain a ``TensorImage``, predicts
    on that with ``model.predict`` and writes the predicted label into
    ``prediction_widget.value``. The toggle is checked before every frame, so
    switching it away from 'live' makes the function return.

    NOTE(review): the preview converts frames with bgr8_to_jpeg, so
    ``camera.value`` looks like BGR; confirm the training images were stored
    with the same channel order, since PIL treats arrays as RGB.
    """
    def keep_running():
        return state_widget.value == 'live'

    while keep_running():
        frame = camera.value
        with torch.no_grad():
            (batch_img,) = first(datal.test_dl([frame]))
            decoded = datal.train.decode((batch_img,))[0][0]
            predicted_label, _label_idx, _probs = model.predict(TensorImage(decoded))
            prediction_widget.value = predicted_label
            
def start_live(change):
    """``state_widget`` observer: load the model and start ``live`` on 'live'.

    Args:
        change: traitlets change notification; only ``change['new']`` is read.

    Transitions to any other state do nothing here. An already-running
    ``live`` loop exits by itself once ``state_widget.value`` is not 'live'.
    """
    # Guard first: building the DataLoaders and loading the checkpoint is
    # expensive and only needed when inference is about to start.
    if change['new'] != 'live':
        return

    dls = get_data(0)
    # NOTE(review): the checkpoint name says xresnet18 but the architecture is
    # mobilenet_v2 — confirm the saved weights match this model.
    learner_cnn = cnn_learner(dls, models.mobilenet_v2, cut=-1, normalize=True,
                              loss_func=CrossEntropyLossFlat(), opt_func=Adam,
                              metrics=[accuracy]).load('xresnet18_fold_0')
    # Alternative backbone kept for reference:
    #learner_cnn = cnn_learner(dls, xresnet50, normalize=True, loss_func=CrossEntropyLossFlat(), opt_func=Adam, metrics=[accuracy] ).load('xresnet50')

    execute_thread = threading.Thread(
        target=live,
        args=(state_widget, dls, learner_cnn, camera, prediction_widget),
    )
    execute_thread.start()

# Call start_live on every change of the toggle's value.
state_widget.observe(start_live, names='value')

# Prediction text above the stop/live toggle.
live_execution_widget = ipywidgets.VBox([prediction_widget, state_widget])
print("live_execution_widget created")

## prediction

In [None]:
# One row: camera preview and the live-inference controls side by side.
all_widget = ipywidgets.VBox(
    [ipywidgets.HBox([camera_view_widget, live_execution_widget])]
)
display(all_widget)

In [None]:
import os
import IPython
if isinstance(camera, CSICamera):
    print("Ignore 'Exception in thread' tracebacks\n")
    # Stop capturing before releasing the device so the capture loop is not
    # reading from a handle that has just been released.
    camera.running = False
    # Assumption (confirm for the installed jetcam version): the capture
    # thread is kept on `camera.thread`. Join it with a timeout in case the
    # `running` observer was removed and did not join it.
    capture_thread = getattr(camera, 'thread', None)
    if capture_thread is not None:
        capture_thread.join(timeout=1.0)
    camera.cap.release()
# Hard-exit the kernel process (skips normal interpreter cleanup).
os._exit(0)