In [6]:
import getpass, ipywidgets as ipw, os, json, shlex, io, re, tempfile, subprocess,unittest
import datetime
import pydicom,numpy as np,csv,warnings,pickle,sys,tensorflow as tf
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from tensorflow.keras.models import Sequential
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from IPython.display import FileLink, JSON
from io import BytesIO

# Silence all Python warnings for the rest of the session.
# NOTE(review): simplefilter('ignore') installs the same catch-all filter as the
# filterwarnings('ignore') call above, so one of the two is redundant.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
# Only show TensorFlow (v1-compat) log messages at ERROR level or above.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

# Re-import edited modules automatically before each cell runs, so changes to
# the helper libraries are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
#%tb

# Make the parent directory importable; presumably that is where juxnat_lib
# (star-imported below) lives — TODO confirm against the repository layout.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
from juxnat_lib.xnat_utils import *


In [9]:
from universal_scan_classifier import *
#from matplotlib import pyplot as plt

In [10]:
class ScanClassificationModelGUI(FrontDesk):
    '''
    GUI page to define a scan nomenclature: the scan types (class labels), the
    DICOM tags and XNAT fields used as classifier inputs, and a nomenclature name.
    The definition can be saved to, and loaded from, a json file.

    All state lives in the ScanClassificationModel self._scm; the widgets only
    mirror it. The page's widget tree is self.main_box.
    '''
    #directory (relative to the notebook) where saved definitions are written.
    _TEMP_DIR='temp'

    def __init__(self):
        #plumbing objects
        self._scm=ScanClassificationModel()
        #handle of the most recently saved definition file (see _save).
        self._tempfile=None

        ######################################################################
        #define interface elements
        st={'description_width':'initial'}

        #User types in the nomenclature (scan types), comma separated
        self._txt_scan_types=ipw.Text(value="MPRAGE,T1hi,SWI,T2lo,T1lo,DWI,T2FLAIR,T2hi,DSC,GRE_qBOLD,DWI_FRAME1,DWI_PART2,TRACEW,FA,FA_WU,MD,MD_WU,DSC_FRAME1,MTT,MTT_WU,CBV,CBV_WU,CBF,CBF_WU,PBP,TTP", description='Scan types, comma separated:',
                            layout={'width':'400pt'}, style=st, disabled=False)
        #only react to edits of the text itself, not to other trait changes
        self._txt_scan_types.observe(self._on_change_scan_types, names='value')
        self._on_change_scan_types(self._txt_scan_types)

        #dropdowns with all supported DICOM tags / XNAT fields.
        self._drop_taglist=ipw.Dropdown(options=self._scm._supported_tags,
                             description='Available DICOM tags:',style=st,layout={'width':'300pt'})
        self._drop_taglist_xnat=ipw.Dropdown(options=self._scm._supported_tags_xnat,
                             description='Available XNAT fields:',style=st,layout={'width':'300pt'})

        #buttons to add the active dropdown entry to the list of model inputs.
        self._btn_addtag=ipw.Button(description="Add",style={},layout={'width':'200pt'})
        self._btn_addtag_xnat=ipw.Button(description="Add",style={},layout={'width':'200pt'})
        self._btn_addtag.on_click(self._addtag)
        self._btn_addtag_xnat.on_click(self._addtag_xnat)

        #clear all tags/fields, start over (easier than add/remove etc buttons)
        self._btn_clearall=ipw.Button(description="Clear DICOM tags",style={},layout={'width':'200pt'})
        self._btn_clearall.on_click(self._cleartags)
        self._btn_clearall_xnat=ipw.Button(description="Clear XNAT fields",style={},layout={'width':'200pt'})
        self._btn_clearall_xnat.on_click(self._cleartags_xnat)

        #read-only boxes listing the currently selected DICOM tags / XNAT fields.
        self._tag_btn_box=ipw.HBox([self._btn_addtag,self._btn_clearall])
        self._txt_used_tags=ipw.Textarea(value='', description='Used DICOM tags:', disabled=True,style=st)
        self._tag_xnat_btn_box=ipw.HBox([self._btn_addtag_xnat,self._btn_clearall_xnat])
        self._txt_used_tags_xnat=ipw.Textarea(value='', description='Used XNAT fields:', disabled=True,style=st)

        #nomenclature definition name (DICOM tags+ontology)
        self._txt_nomenclature_name=ipw.Text(value="NeuroOncologyMRI", description='Nomenclature name:',
                            layout={'width':'200pt'}, style=st, disabled=False)
        self._txt_nomenclature_name.observe(self._on_change_nomenclature_name, names='value')
        self._on_change_nomenclature_name(self._txt_nomenclature_name)

        #save the nomenclature definition to a json file / load it from one.
        self._btn_save=ipw.Button(description="Save definition",style={},layout={'width':'200pt'})
        self._btn_save.on_click(self._save)
        self._btn_load=ipw.FileUpload(accept='.json',multiple=False,style={},layout={'width':'200pt'})
        self._btn_load.observe(self._load, names='value')
        self._loadsave_btn_box=ipw.HBox([self._btn_save,self._btn_load])

        #output widgets: download link, log messages, and _out_json
        #(the latter is placed in the page but not written to by this class).
        self._file_lnk_out=ipw.Output()
        self._out_log = ipw.Output()
        self._out_json = ipw.Output()

        ###########################################
        # split GUI for DICOM tags and XNAT fields.
        dcm_tag_box=ipw.VBox([self._drop_taglist, self._tag_btn_box, self._txt_used_tags])
        xnat_field_box=ipw.VBox([self._drop_taglist_xnat, self._tag_xnat_btn_box, self._txt_used_tags_xnat])
        acc=ipw.Accordion(children=[dcm_tag_box, xnat_field_box],selected_index=None)
        for i,title in enumerate(['Add DICOM tags','Add XNAT fields']):
            acc.set_title(i,title)

        #box containing all elements of nomenclature definition page.
        #name must be 'main_box' to be visible to the GUIPages()
        self.main_box=ipw.VBox([self._txt_nomenclature_name,self._txt_scan_types,
                                acc,self._loadsave_btn_box, self._file_lnk_out,self._out_log,self._out_json])

    def _on_change_nomenclature_name(self,b):
        '''Copy the nomenclature name from its text box into the model.'''
        self._scm._nomenclature_name=self._txt_nomenclature_name.value

    def _addtag(self,b):
        '''
        Add the DICOM tag selected in self._drop_taglist to the model's selected
        tags (duplicates are ignored) and refresh the listing.
        '''
        tag=self._drop_taglist.value
        if tag not in self._scm._selected_tags:
            self._scm._selected_tags.append(tag)
            self._refresh_txt_used_tags()

    def _addtag_xnat(self,b):
        '''
        Add the XNAT field selected in self._drop_taglist_xnat to the model's
        selected fields (duplicates are ignored) and refresh the listing.
        '''
        field=self._drop_taglist_xnat.value
        if field not in self._scm._selected_fields_xnat:
            self._scm._selected_fields_xnat.append(field)
            self._refresh_txt_used_tags_xnat()

    def _on_change_scan_types(self,b):
        '''
        Parse the comma-separated scan type list into the model's sorted scan types.
        Whitespace around each entry is stripped and empty entries (e.g. from a
        trailing comma) are dropped, so "B, A," gives ['A','B'].
        '''
        types=(t.strip() for t in self._txt_scan_types.value.split(','))
        self._scm._scan_types=sorted(t for t in types if t)

    def get_selected_tags_group_element(self):
        '''Selected DICOM tags converted by ScanClassificationModel.tagname_to_group_element.'''
        return [ ScanClassificationModel.tagname_to_group_element(t) for t in self._scm._selected_tags ]

    def _refresh_txt_used_tags(self):
        '''
        Update the used-tags text box from self._scm._selected_tags: one line per
        tag with its keyword and its pydicom (group, element) representation.
        '''
        self._txt_used_tags.value=''.join(
            t+' '+str(pydicom.tag.Tag(pydicom.datadict.tag_for_keyword(t)))+'\n'
            for t in self._scm._selected_tags)

    def _refresh_txt_used_tags_xnat(self):
        '''
        Update the used-fields text box from self._scm._selected_fields_xnat,
        one field per line.
        '''
        self._txt_used_tags_xnat.value=''.join(t+'\n' for t in self._scm._selected_fields_xnat)

    def _cleartags(self,b):
        '''
        clear the selected DICOM tag list and text box.
        '''
        self._scm.clear_selected_tags()
        self._txt_used_tags.value=''

    def _cleartags_xnat(self,b):
        '''
        clear the selected XNAT field list and text box.
        '''
        self._scm.clear_selected_fields_xnat()
        self._txt_used_tags_xnat.value=''

    def _save(self,b):
        '''
        Save the nomenclature definition (scan types, selected DICOM tags, XNAT
        fields and nomenclature name) to a temporary json file under ./temp and
        display a download link for it.

        The file is a NamedTemporaryFile (deleted when closed): it is removed when
        the next save replaces it, or when this object is garbage-collected.
        '''
        if not self._scm.check_validity():
            #output printed from a widget callback is only shown inside an Output widget
            with self._out_log:
                print ("Invalid input, cannot save")
            return

        d={'scan_types':self._scm._scan_types,
           'selected_dcm_tags':self._scm._selected_tags,
           'selected_fields_xnat':self._scm._selected_fields_xnat,
           'nomenclature_name':self._scm._nomenclature_name}

        os.makedirs(self._TEMP_DIR, exist_ok=True)
        #close (and thereby delete) the previously saved definition, if any
        if self._tempfile is not None:
            self._tempfile.close()
        self._tempfile=tempfile.NamedTemporaryFile(dir=self._TEMP_DIR,mode='w',
                                                   prefix='nomenclature',suffix='.json')
        #write through the open handle: re-opening a NamedTemporaryFile by name
        #fails on Windows. Flush so the content is on disk for the download link.
        json.dump(d,self._tempfile)
        self._tempfile.flush()

        #display file link
        out=self._file_lnk_out
        out.outputs=()
        with out:
            display(FileLink(self._TEMP_DIR+'/'+os.path.basename(self._tempfile.name)))

    def refresh(self):
        '''Called when the page is shown: this is the first page, so disable "previous".'''
        self.enable_nav_prev(False)

    def _load(self,b):
        '''
        Load a nomenclature definition (scan types, DICOM tags, XNAT fields and
        nomenclature name) from the uploaded json file and refresh the widgets.
        Errors are reported in the log output widget.
        '''
        files=list(self._btn_load.value)
        if not files: return
        with self._out_log:
            try:
                d=json.loads(bytes(files[0]['content']).decode('utf-8'))
                self._scm.load(d)
                #setting the text boxes also updates the model via their observers
                self._txt_scan_types.value=','.join(sorted(d['scan_types']))
                self._txt_nomenclature_name.value=d['nomenclature_name']
                self._refresh_txt_used_tags()
                self._refresh_txt_used_tags_xnat()
                print('nomenclature uploaded')
            except Exception as e:
                print(e)

class DataSelector(FrontDesk):
    '''
    GUI page for preparing datasets: XNAT login, generation of an unlabeled scan
    list from a csv of experiments, upload of unlabeled (testing) and labeled
    (training) scan lists, and removal of duplicate scans.

    Datasets are exposed as lists of dicts:
      scans_unclassified - testing set (scan type undefined)
      scans_classified   - training/validation set (scan type defined)
    test_src / train_src / val_src describe where each set came from.
    '''
    def __init__(self,serialize_file,classifier_model):
        '''
        serialize_file: file passed to ServerParams.serialize(); it is called with
            True at startup and with False after a successful connect (presumably
            read vs. write of the XNAT connection parameters — TODO confirm).
        classifier_model: the ScanClassificationModelGUI providing the nomenclature
            (its ._scm and get_selected_tags_group_element() are used).
        '''
        #scans read from uploaded files (three independent lists)
        self.scans_unclassified=[]
        self.scans_classified=[]
        self.scans_unclassified_compressed=[]

        #rows of the most recently parsed csv, and of the experiments csv
        self._rows=[]
        self._exp_rows=[]
        self._exp_src='None'

        #file name or other description of test, train and validation set sources
        self.test_src=self.train_src=self.val_src='None'

        self._classifier_model=classifier_model
        self._connected=False
        self._serialize_file=serialize_file
        self._verbosity=1
        st={'description_width':'initial'}

        self.sp=ServerParams()
        self.sp.serialize(serialize_file,{},True)

        ###################################
        #GUI elements for XNAT login.
        self.text1=ipw.Text(value=self.sp.server, description='XNAT server:',
                            layout={'width':'200pt'}, style=st, disabled=False)
        self.text2=ipw.Text(value=self.sp.user,description='user:',
                                disabled=False, style=st, layout={'width':'100pt'})
        self.text3=ipw.Password(value='',description='password:',
                                disabled=False, style=st, layout={'width':'100pt'})
        self.lbl1=ipw.Label('status: not connected', layout={'width':'120pt'}, style=st)
        lbl2=ipw.Label('',layout={'width':'120pt'},style=st)

        self._project_text=ipw.Text(value=self.sp.project,description="XNAT project:",disabled=False,style=st,
                           layout={'width':'120pt'})
        self._project_text.observe(self._change_project, names='value')

        self.btn1=ipw.Button(description="connect",style={},layout={'width':'200pt'})
        self.btn1.on_click(self.on_connect)
        vb1=ipw.HBox([self.text1,self.text2,self.text3])
        vb2=ipw.HBox([self.btn1,lbl2,self.lbl1])
        self.xnat_login_box=ipw.VBox([vb1,self._project_text,vb2])

        self._xi=XnatIterator(self.sp)

        lay1={'width':'200pt'}
        lay2={'width':'250pt'}

        ##########################################
        # GUI elements to prepare a dataset
        prep_lbl1=ipw.Label(value='1. Upload a csv with XNAT experiments, with Subject and Experiment columns',layout=lay1)
        prep_upl1=ipw.FileUpload(accept='.csv',multiple=False)
        self._upl1_status=ipw.Label(value='waiting for upload...',layout=lay1)
        #names='value': run once per upload, not on every trait sync
        prep_upl1.observe(self.read_uploaded_file1, names='value')
        self._fupl1_exp=prep_upl1

        prep_lbl3=ipw.Label(value='2. Generate tagged scan list',layout=lay1)
        self._generate_output=ProcessWithTextProgress('Generate',self.collect_scans)
        prep_btn1=self._generate_output._btn_run
        prep_btn1.disabled=False
        self._out_lnk=ipw.Output()

        #status label, also handed to the XNAT iterator while collecting
        self._coll_status=ipw.Label(value='waiting to generate...',layout=lay1)
        prep_hb1=ipw.HBox([prep_lbl1,prep_upl1,self._upl1_status])
        prep_vb2=ipw.VBox([prep_lbl3,self._coll_status,self._generate_output.main_box,self._out_lnk])
        self._prep_box=ipw.VBox([self.xnat_login_box,prep_hb1,prep_vb2])

        ##############################################
        # GUI elements to upload raw dataset
        prep_lbl5=ipw.Label(value='Upload a csv with tagged scan list, scan type undefined', layout=lay2)
        prep_upl2=ipw.FileUpload(accept='.csv',multiple=False)
        prep_upl2.observe(self.read_uploaded_file2, names='value')
        self._upl2_status=ipw.Label(value='waiting for upload...',layout=lay1)
        self._upl_raw_box=ipw.HBox([prep_lbl5,prep_upl2,self._upl2_status])
        self._fupl2_scans_unclassified=prep_upl2

        ##############################################
        # GUI elements to download compressed raw
        dlraw_lbl1=ipw.Label(value='Remove duplicates from the scan list (type undefined)', layout=lay2)
        self._dlraw_lbl_current_file=ipw.Label(value='Current dataset: undefined',layout=lay2)
        self._dlraw_btn=ipw.Button(description='Generate',style={},layout={'width':'200pt'},disabled=True)
        self._dlraw_btn.on_click(self.on_compress_scans)
        self._dlraw_lnk=ipw.Output()
        self._dlraw_box=ipw.VBox([dlraw_lbl1,self._dlraw_lbl_current_file,self._dlraw_btn,self._dlraw_lnk])

        ###########################################
        # GUI elements to upload training dataset
        train_lbl1=ipw.Label(value='Upload a csv with tagged scan list, scan type defined')
        train_upl1=ipw.FileUpload(accept='.csv',multiple=False)
        train_upl1.observe(self.read_uploaded_file3, names='value')
        self._upl3_status=ipw.Label(value='waiting for upload...',layout=lay1)
        self._train_box=ipw.HBox([train_lbl1,train_upl1,self._upl3_status])
        self._fupl3_scans_classified=train_upl1

        ###########################################
        # GUI put together.
        titles=['Generate unlabeled scan list',
                'Upload unlabeled scan list (testing set)',
                'Generate unlabeled scan list, no duplicates',
                'Upload pre-labeled scan list (training set)']
        acc=ipw.Accordion(children=[self._prep_box, self._upl_raw_box, self._dlraw_box,self._train_box],
                          selected_index=None)
        for i,title in enumerate(titles):
            acc.set_title(i,title)

        self.main_box=ipw.VBox([acc])

    def _change_project(self,b):
        '''Copy the XNAT project from its text box into the server parameters.'''
        self.sp.project=self._project_text.value

    def on_connect(self,b):
        '''
        Connect to XNAT with the entered server, user and password. On success the
        server parameters are serialized to self._serialize_file; on failure the
        "next" navigation is disabled.
        '''
        self.lbl1.value='status: connecting...'
        self.sp.server,self.sp.user,self.sp.password=self.text1.value,self.text2.value,self.text3.value

        if self.sp.connect():
            self.lbl1.value='status: connected'
            self.btn1.description='Reconnect'
            self._connected=True
            self.sp.serialize(self._serialize_file,{},False)
        else:
            self.lbl1.value='status: connection failed'
            if self._verbosity>0:
                print(self.sp.jsession)
            self.enable_nav_next(False)

    def _noxnat_callback(self,b):
        '''
        Disable the connect button while b is checked (b is assumed to be a
        checkbox-like widget with a boolean .value — TODO confirm caller).
        '''
        #Button has no 'enabled' trait; 'disabled' is what the widget honours.
        self.btn1.disabled=bool(b.value)

    def refresh(self):
        '''Called when the page is shown: disable "previous" navigation.'''
        self.enable_nav_prev(False)

    def show_file_link(self,out,file):
        '''Replace the contents of Output widget `out` with a download link to `file`.'''
        out.outputs=()
        with out:
            display(FileLink(file))

    def on_compress_scans(self,b):
        '''
        Remove duplicates from the current testing set (via
        ScanProcessor.compress_scans), write them to scans_compressed.csv and
        show a download link. Does nothing when the testing set is empty.
        '''
        if len(self.scans_unclassified)<1: return
        sp=ScanProcessor(self._classifier_model._scm)
        compressed=sp.compress_scans(self.scans_unclassified)
        self.scans_unclassified_compressed=compressed
        fil='scans_compressed.csv'
        self.write_scans_csv(compressed,fil)
        self.show_file_link(self._dlraw_lnk,fil)

    def collect_scans(self,b):
        '''
        Query XNAT for all scans of the experiments in the uploaded experiments csv
        (Subject and Experiment columns), including the selected DICOM tags and
        XNAT fields. The result becomes the testing set and is written to
        all_scans.csv with a download link.
        '''
        with self._generate_output.out:
            rows=self._exp_rows
            if not rows:
                print('Upload a csv with Subject and Experiment columns first.')
                return
            if self._verbosity>0:
                print (rows)
            subjs=[ s['Subject'] for s in rows ]
            exps=[ s['Experiment'] for s in rows ]
            if self._verbosity>0:
                print("len_subj",len(subjs))
                print("len_exps", len(exps))

            tags=self._classifier_model.get_selected_tags_group_element()
            fields=self._classifier_model._scm.get_selected_fields_xnat()
            if self._verbosity>0:
                print(tags)
                print(fields)

            self.scans_unclassified=self._xi.list_scans_in_experiments(subjs,exps,self._coll_status,
                                                                       include_dcm_tags=tags,
                                                                       include_xnat_fields=fields,verbosity=2)
            #the upload widget is cleared after reading, so use the remembered name
            self.scans_unclassified_set="Collected from "+self._exp_src
            self.test_src=self.scans_unclassified_set
            self._coll_status.value='Status: found {} scans'.format(len(self.scans_unclassified))
            if not self.scans_unclassified:
                return

        fil='all_scans.csv'
        self._dlraw_btn.disabled=False
        self._dlraw_lbl_current_file.value='Current dataset: collected scans'
        self.write_scans_csv(self.scans_unclassified,fil)
        self.show_file_link(self._out_lnk,fil)

    def read_uploaded_file(self,b,status):
        '''
        Parse the csv uploaded in FileUpload widget `b` into self._rows.

        Values are converted to str and the rows are passed through
        ScanProcessor.uncompress_scans. The `status` label shows the row count or
        the parse error.

        Returns (file_name, True) on success, ('', False) if nothing is uploaded
        or the file cannot be parsed.
        '''
        files=list(b.value)
        if self._verbosity>0:
            #print names only; the upload records also hold the full file content
            print([f.name for f in files],len(files))
        if not files: return '',False
        upload=files[0]
        try:
            #utf-8-sig also strips a BOM (e.g. from Excel) that would otherwise
            #corrupt the first column name
            text=io.TextIOWrapper(io.BytesIO(bytes(upload.content)),encoding='utf-8-sig',newline='')
            csv_reader=csv.DictReader(text,skipinitialspace=True)
            sp=ScanProcessor(self._classifier_model._scm)
            rows=sp.uncompress_scans([{k: str(v) for k,v in row.items()} for row in csv_reader])
        except Exception as e:
            status.value='cannot parse csv: {}'.format(e)
            return '',False

        self._rows=rows
        status.value='csv loaded with {} rows'.format(len(self._rows))
        if self._verbosity>0:
            print (len(self._rows))
        return upload.name,True

    def clear_upload(self,fupl):
        '''Reset a FileUpload widget so the next upload starts from scratch.'''
        fupl.value = ()
        fupl._counter=0

    def read_uploaded_file1(self,b):
        '''Read the experiments csv consumed by collect_scans.'''
        src,res=self.read_uploaded_file(self._fupl1_exp,self._upl1_status)
        if res:
            #keep a private copy: self._rows is overwritten by later uploads
            self._exp_rows=self._rows
            self._exp_src=src
        self.clear_upload(self._fupl1_exp)
        return res

    def read_uploaded_file2(self,b):
        '''Read an unlabeled scan list into the testing set.'''
        src,res=self.read_uploaded_file(self._fupl2_scans_unclassified,self._upl2_status)
        if res:
            self.scans_unclassified=self._rows
            self.test_src=src
            self._dlraw_btn.disabled=False
            self._dlraw_lbl_current_file.value='Current dataset: uploaded'
        self.clear_upload(self._fupl2_scans_unclassified)
        return res

    def read_uploaded_file3(self,b):
        '''Read a labeled scan list into the training (and validation) set.'''
        src,res=self.read_uploaded_file(self._fupl3_scans_classified,self._upl3_status)
        if self._verbosity>0:
            print(src,res)
        if res:
            self.scans_classified=self._rows
            self.train_src=self.val_src=src
            if self._verbosity>0:
                print('train src:',src)
        self.clear_upload(self._fupl3_scans_classified)
        return res

    def write_scans_csv(self, scans, file):
        '''
        Write scans (a list of dicts) to csv `file` with a header row. Columns are
        the union of all rows' keys in first-seen order; missing values are written
        as empty strings. An empty list produces an empty file.
        '''
        fieldnames=list(dict.fromkeys(k for scan in scans for k in scan))
        #newline='' is required by the csv module to avoid doubled line endings
        with open(file, 'w', newline='') as output_file:
            if not scans:
                return
            dict_writer = csv.DictWriter(output_file, fieldnames)
            dict_writer.writeheader()
            dict_writer.writerows(scans)

In [11]:
class UniversalScanClassifierGUI(FrontDesk):
    '''
    GUI page to load, train, save, validate and apply a scan classifier.

    data_selector: the DataSelector providing scans_classified (training and
        validation set), scans_unclassified (testing set) and their *_src labels.
    scan_classification_model: the ScanClassificationModel handed to
        UniversalScanClassifier; its nomenclature name is used in model names.
    '''
    def __init__(self,data_selector,scan_classification_model):
        self._ds=data_selector
        self._scm=scan_classification_model
        self._uc=UniversalScanClassifier(self._scm)
        self.verbosity=0

        btn_lay={'width':'200pt'}

        self._current_model='None'
        self._current_model_saved=False

        #1. load model gui.
        self._load_lbl=ipw.Label(value='Upload model file (.pkl or .zip)')
        self._load_btn=ipw.FileUpload(accept='.pkl,.zip',multiple=False)
        #names='value': load once per upload, not on every trait sync
        self._load_btn.observe(self._on_load_model, names='value')
        self._load_lbl_current_model=ipw.Label(value='Current model: None')
        self._load_box=ipw.VBox([self._load_lbl,self._load_btn,self._load_lbl_current_model])

        #2. Train model gui.
        self._train_current_set_lbl=ipw.Label(value='Current training set: {}'.format(self._src_label('train_src')))
        self._train_lbl_current_model=ipw.Label(value='Current model: {}'.format(self._current_model))
        self._train_output=ProcessWithTextProgress('Train model',self._on_train_button)
        self._train_btn=self._train_output._btn_run

        self._train_save_btn=ipw.Button(description='Save trained model',layout=btn_lay,disabled=True)
        self._train_save_btn.on_click(self._on_save_model)
        self._train_lnk1=ipw.Output()
        self._train_lnk2=ipw.Output()
        self._train_box=ipw.VBox([self._train_current_set_lbl,self._train_lbl_current_model,
                                  self._train_output.main_box,self._train_save_btn,
                                  self._train_lnk1,self._train_lnk2])

        #3. validate model gui
        self._val_current_set_lbl=ipw.Label(value='Current validation set: {}'.format(self._src_label('val_src')))
        self._val_lbl_current_model=ipw.Label(value='Current model: {}'.format(self._current_model))
        self._val_output=ProcessWithTextProgress('Validate model',self._on_validate_button)
        self._val_btn=self._val_output._btn_run
        self._val_box=ipw.VBox([self._val_current_set_lbl,self._val_lbl_current_model,
                                self._val_output.main_box])

        #4. test model gui
        self._tst_current_set_lbl=ipw.Label(value='Current testing set: {}'.format(self._src_label('test_src')))
        self._tst_lbl_current_model=ipw.Label(value='Current model: {}'.format(self._current_model))
        self._tst_btn=ipw.Button(description='Classify scans',layout=btn_lay)
        self._tst_btn.on_click(self._on_test_button)
        self._tst_lnk=ipw.Output()
        self._tst_box=ipw.VBox([self._tst_current_set_lbl,self._tst_lbl_current_model,
                                self._tst_btn,self._tst_lnk])

        #Main accordion.
        titles=['Load a previously saved model','Train a model on the training set',
                'Validate model on the validation set','Test model on the testing set']
        acc=ipw.Accordion(children=[self._load_box, self._train_box,
                                    self._val_box,self._tst_box],selected_index=None)
        for i,title in enumerate(titles):
            acc.set_title(i,title)

        self.main_box=ipw.VBox([acc])

    def _src_label(self,attr):
        '''
        Printable description of a data source of the data selector (attr is
        'train_src', 'val_src' or 'test_src'); 'None' if the attribute is missing.
        Upload records (objects with a .name, i.e. FileUpload entries) are reduced
        to their file name.
        '''
        src=getattr(self._ds,attr,'None')
        return getattr(src,'name',src)

    def _on_load_model(self,b):
        '''
        Triggered when a model file is uploaded: write it to the working directory
        and load it (.zip -> neural-net model via load_model_nn, otherwise
        load_model).
        '''
        fupl=self._load_btn
        if len(fupl.value)<1: return
        if self.verbosity>0:
            print([f.name for f in fupl.value],len(fupl.value))
        #basename() keeps the written file inside the working directory
        model_file=os.path.basename(fupl.value[0].name)
        with open(model_file,"wb") as f:
            f.write(bytes(fupl.value[0].content))

        if self.verbosity>0:
            print('model file: '+model_file)
        model_nn=os.path.splitext(model_file)[1]=='.zip'
        if model_nn and self.verbosity>0:
            print('neural net model selected')

        # NOTE(review): .pkl models are presumably unpickled by load_model, which
        # can execute arbitrary code — only upload trusted model files.
        loaded=self._uc.load_model_nn(model_file) if model_nn else self._uc.load_model(model_file)
        if not loaded:
            print('Error loading model '+model_file)
            #prints from widget callbacks may be hidden; also report in the label
            self._load_lbl_current_model.value='Error loading model '+model_file
            return

        self._current_model='{}'.format(model_file)
        self._current_model_saved=True
        self.refresh()

    def _on_train_button(self,b):
        '''
        Runs when 'Train' button is clicked: train a neural-net model on the
        labeled scans and name it <nomenclature>_<training set>_<timestamp>.
        '''
        uc,scans=self._uc,self._ds.scans_classified
        with self._train_output.out:
            if not scans:
                print('No labeled scans loaded; upload a training set first.')
                return
            uc.init_and_run_nn_training(scans)

        self._current_model='{}_{}_{}'.format(self._scm._nomenclature_name, self._src_label('train_src'),
                                              datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))
        if self.verbosity>0:
            print('current model: ',self._current_model)
        self._current_model_saved=False
        self.refresh()

    def _on_save_model(self,b):
        '''Save the current neural-net model as <model name>.zip and show a download link.'''
        fname=self._current_model+'.zip'
        self._uc.save_model_nn(fname)
        self._current_model_saved=True
        self.show_file_link(self._train_lnk1,fname)
        #update "(saved)" labels and disable the save button
        self.refresh()

    def _on_validate_button(self,b):
        '''
        Runs when 'Validate' button is clicked: classify the labeled scans, print
        every mismatch against their 'hof_id', and the overall accuracy.
        '''
        tscans=self._ds.scans_classified
        with self._val_output.out:
            if not tscans:
                print('No labeled scans loaded; nothing to validate.')
                return
            classified_types=self._uc.infer_nn(tscans)
            n_wrong=0
            for i,scan in enumerate(tscans):
                if classified_types[i]!=scan['hof_id']:
                    print('position: {}, predicted: {}, actual: {}'
                          .format(i,classified_types[i],scan['hof_id']))
                    n_wrong+=1
            print('Classification accuracy:',1.-n_wrong/len(tscans))
            print("Done.")

    def _on_test_button(self,b):
        '''
        Runs when 'Test' button is clicked: classify the unlabeled scans, store the
        prediction in each scan's 'hof_id', write them to
        classified_scans_<timestamp>.csv and show a download link.
        '''
        tscans=self._ds.scans_unclassified
        if not tscans:
            self._tst_lnk.outputs=()
            with self._tst_lnk:
                print('No unlabeled scans loaded; nothing to classify.')
            return
        classified_types=self._uc.infer_nn(tscans)
        for scan,ct in zip(tscans,classified_types):
            scan['hof_id']=ct
        #sanitize only the timestamp so the '.csv' extension survives
        timestamp=re.sub('[^0-9a-zA-Z]+', '_', str(datetime.datetime.now()))
        out_file='classified_scans_{}.csv'.format(timestamp)
        self._uc.write_scans_csv(tscans,out_file)
        self.show_file_link(self._tst_lnk,out_file)

    def refresh(self):
        '''
        Update data-source and model labels from the data selector and current
        model, and enable/disable buttons:
          Train    - labeled scans available
          Validate - model present and labeled scans available
          Classify - model present and unlabeled scans available
          Save     - model present and not yet saved
        '''
        if self.verbosity>0:
            print('refresh')
        self._train_current_set_lbl.value='Current training set: {}'.format(self._src_label('train_src'))
        self._val_current_set_lbl.value='Current validation set: {}'.format(self._src_label('val_src'))
        self._tst_current_set_lbl.value='Current testing set: {}'.format(self._src_label('test_src'))

        self.enable_nav_prev(True)
        model_saved="saved" if self._current_model_saved else "unsaved"
        model_loaded=self._current_model!="None"

        cm="Current model: {} ({})".format(self._current_model,model_saved)
        for lbl in (self._load_lbl_current_model,self._train_lbl_current_model,
                    self._val_lbl_current_model,self._tst_lbl_current_model):
            lbl.value=cm

        has_labeled=len(self._ds.scans_classified)>0
        has_unlabeled=len(self._ds.scans_unclassified)>0
        self._train_btn.disabled=not has_labeled
        self._val_btn.disabled=not(model_loaded and has_labeled)
        self._tst_btn.disabled=not(model_loaded and has_unlabeled)
        self._train_save_btn.disabled=self._current_model_saved or not model_loaded

    def show_file_link(self,out,file):
        '''Replace the contents of Output widget `out` with a download link to `file`.'''
        out.outputs=()
        with out:
            display(FileLink(file))

    

In [12]:
# Wire up the three pages: nomenclature definition -> dataset preparation ->
# training/testing. The data selector reads the nomenclature from scm_gui, and
# the classifier page reads datasets from ds.
scm_gui=ScanClassificationModelGUI()
ser_file='universal_scan_classifier_params.json'
ds=DataSelector(ser_file,scm_gui)
sc_gui=UniversalScanClassifierGUI(ds,scm_gui._scm)

# debug: verbose logging from the XNAT iterator and the data selector
DEBUG_VERBOSITY=2
ds._xi._verbosity=DEBUG_VERBOSITY
ds._verbosity=DEBUG_VERBOSITY

# prev/next labels name the adjacent page's title
pages=[
    {'title':'Classifier model','frontdesk':scm_gui,'plumbing':None,'prev_label':None,'next_label':'Dataset preparation'},
    {'title':'Dataset preparation','frontdesk':ds,'plumbing':None,'prev_label':'Classifier model','next_label':'Training/testing'},
    {'title':'Training/testing','frontdesk':sc_gui,'plumbing':None,'prev_label':'Dataset preparation','next_label':None}
]

g=GUIBook(pages)


VBox(children=(HTML(value='<h4>Classifier model</h4>'), VBox(children=(VBox(children=(Text(value='NeuroOncolog…