In [12]:
import dtext as dt
import tensorflow as tf
import getpass
import ipywidgets as ipw
import os
import json
import shlex
import re
from PIL import Image
import logging
import time
import threading
from IPython.display import FileLink, FileLinks
import csv
import tempfile
from datetime import datetime
import threading
import math

class ServerParams:
    '''
    Connection parameters for an XNAT server, the current project/subject/experiment
    selection, and the session id obtained by connect().
    '''
    def __init__(self,server=None, user=None, password=None, project=None,subject=None,experiment=None):
        self.server,self.user,self.password,self.project,self.subject,self.experiment= \
            server,user,password,project,subject,experiment
        # JSESSIONID cookie value; '' until connect() has been called.
        self.jsession=''
        self.connected=False

    def __str__(self):
        # the password is intentionally not part of the string form.
        return "server:{}, user: {}, project: {}, subject: {}, experiment: {}, connected: {}".\
            format(self.server,self.user,self.project,self.subject,self.experiment,self.connected)

    def connect(self):
        '''
        Request a session from <server>/data/JSESSION using basic auth.

        curl writes the response body to ./jsession.txt; a 32-character response
        is taken as a valid session id and stored in self.jsession.

        Returns:
            True if a session id was obtained (the same value is stored in
            self.connected); False if parameters are missing, curl cannot be run,
            nothing was downloaded, or the response is not a session id.
        '''
        import subprocess
        self.jsession=''
        self.connected=False
        if self.server is None or self.user is None or self.password is None:
            return False
        # remove a session file left by a previous call so a failed request is not read as success.
        try:
            os.remove('jsession.txt')
        except OSError:
            pass
        # argv list instead of a shell string: credentials containing shell metacharacters
        # are passed to curl verbatim.
        try:
            subprocess.run(['curl','-o','jsession.txt','-k','-u',self.user+':'+self.password,
                            self.server+'/data/JSESSION'])
        except OSError:
            return False
        try:
            with open('jsession.txt') as f:
                self.jsession=f.read().strip()
        except OSError:
            return False
        self.connected=(len(self.jsession)==32)
        return self.connected
    
class XnatIterator:
    '''
    curl-based client for listing and downloading data of one XNAT project.

    Every listing query is downloaded to ./temp_query.json and parsed from there,
    so queries from one working directory must not run concurrently.
    '''
    def __init__(self,sp):
        # sp: ServerParams-like object providing server, project and jsession.
        self.sp=sp
        self._subjects=[]
        self._experiments=[]
        self._scans=[]

    def _curl_cmd_prefix(self):
        '''Shell prefix of a curl call that carries the session cookie.'''
        return "curl -k --cookie "+shlex.quote("JSESSIONID="+self.sp.jsession)

    def _curl_cmd_path(self,path):
        '''Shell-quoted URL of `path` under the current project.'''
        return shlex.quote(self.sp.server+"/data/archive/projects/"+self.sp.project+path)

    def _curl_cmd(self,path):
        '''Download `path` (relative to the current project) into ./temp_query.json.'''
        # drop the previous answer so a failed request cannot be parsed as this one's.
        try:
            os.remove('temp_query.json')
        except OSError:
            pass
        os.system(self._curl_cmd_prefix()+' -o temp_query.json '+self._curl_cmd_path(path))

    def _query_json(self,path):
        '''
        Run `path` through _curl_cmd and return the parsed ['ResultSet']['Result']
        value, or None when nothing was downloaded or the response does not
        have that structure.
        '''
        self._curl_cmd(path)
        try:
            with open('temp_query.json') as tq:
                return json.load(tq)['ResultSet']['Result']
        except (OSError,ValueError,KeyError,TypeError):
            return None

    def curl_download_single_file(self,path,dest):
        '''Download <server><path> (e.g. a URI returned by list_scan_files) to `dest`.'''
        os.system(self._curl_cmd_prefix()+' -o '+shlex.quote(dest)+' '+shlex.quote(self.sp.server+path))

    def set_project(self,pr):
        '''Switch the project used by all subsequent queries.'''
        self.sp.project=pr

    def list_subjects(self):
        '''Return the project's subject labels sorted by label; [] on failure.'''
        result=self._query_json('/subjects?format=json')
        if result is None:
            return []
        self._subjects=[f['label'] for f in sorted(result,key=lambda k:k['label'])]
        return self._subjects

    def scan_file_loader(self,scans,tdir,lock):
        '''
        Download the first DICOM file of every scan into `tdir` and convert it to PNG.

        Meant to run in a background thread. For every scan dict, scan['png'] is
        set (under `lock`) to the PNG path, or to 'N/A' when the scan has no
        DICOM files, no PNG was produced, or any step raised. The key is always
        set because the consumer polls for it and would otherwise wait forever.

        Args:
            scans: scan dicts with 'subject', 'experiment' and 'ID' keys.
            tdir: directory receiving the PNG files.
            lock: threading.Lock shared with the consumer of scan['png'].
        '''
        for s in scans:
            png='N/A'
            try:
                files=self.list_scan_files(s['subject'],s['experiment'],s['ID'])
                if files:
                    t=tdir+'/'+s['subject']+'_'+s['experiment']+'_'+s['ID']
                    dcm,out=t+'.dcm',t+'.png'
                    self.curl_download_single_file(files[0],dcm)
                    # DICOM -> PNG conversion with dcmj2pnm; the .dcm is discarded afterwards.
                    os.system("dcmj2pnm +G +Wn +on "+shlex.quote(dcm)+" "+shlex.quote(out))
                    try:
                        os.remove(dcm)
                    except OSError:
                        pass
                    if os.path.exists(out):
                        png=out
            except Exception:
                # never let the worker thread die silently: log and mark the scan as unavailable.
                logging.exception('cannot prepare preview for scan %s',s.get('ID'))
            finally:
                with lock:
                    s['png']=png

    def list_experiments(self,subject):
        '''Return labels of the subject's image sessions sorted by date; [] on failure.'''
        result=self._query_json('/subjects/'+subject+"/experiments?xsiType=xnat:imageSessionData&format=json")
        if result is None:
            print ('error listing experiments!')
            return []
        self._experiments=[f['label'] for f in sorted(result,key=lambda k:k['date'])]
        return self._experiments

    def list_scans(self,subject,experiment, listDcmFiles=False):
        '''
        Return the experiment's scan dicts sorted by 'xnat_imagescandata_id'; [] on failure.

        Each dict is tagged with 'subject' and 'experiment'; with listDcmFiles=True
        it also gets 'files', the DICOM file URIs of the scan.
        '''
        result=self._query_json('/subjects/'+ subject +'/experiments/'
            +experiment + "/scans?columns=ID,frames,type,series_description")
        if result is None:
            return []
        self._scans=sorted(result,key=lambda k:k['xnat_imagescandata_id'])
        for s in self._scans:
            s['subject']=subject
            s['experiment']=experiment
        if listDcmFiles:
            self.get_dcm_files_for_scans(subject,experiment,self._scans)
        return self._scans

    def get_dcm_files_for_scans(self,subject,experiment,scans):
        '''Store the DICOM file URIs of every scan in scan['files'].'''
        for s in scans:
            s['files']=self.list_scan_files(subject,experiment,s['ID'])

    def list_scan_files(self,subject,experiment,scan):
        '''Return URIs of the scan's DICOM resource files sorted by name; [] on failure.'''
        result=self._query_json('/subjects/'+ subject +'/experiments/'
            +experiment + '/scans/'+scan+'/resources/DICOM/files')
        if result is None:
            return []
        return [f['URI'] for f in sorted(result,key=lambda k:k['Name'])]

    def _save_json(self,obj,path):
        '''Write `obj` as JSON to `path`.'''
        with open(path,'w') as fp:
            json.dump(obj,fp)

    def list_scans_all(self,subjects,subject_prefix,json_out_file,output):
        '''
        List all scans in the project, filtered by subject prefix.

        Subjects whose label starts with `subject_prefix` (case-insensitive) are
        scanned. Progress is displayed in `output.value` (a text widget, or None
        to skip). The scans collected so far are saved to `json_out_file` after
        each matching subject, and once more at the end, so the file also exists
        when nothing matched.

        Returns:
            list of scan dicts as returned by list_scans.
        '''
        scans=[]
        prefix=subject_prefix.lower()
        for su in subjects:
            if not su.lower().startswith(prefix): continue
            for e in self.list_experiments(su):
                if output: output.value='running, found {} scans'.format(len(scans))
                scans.extend(self.list_scans(su,e))
            self._save_json(scans,json_out_file)
        self._save_json(scans,json_out_file)
        return scans

In [21]:
class Selector:
    '''
    GUI to select a particular experiment to process.

    Stages, each enabled by the previous one:
      1. log in to an XNAT server (show_login_form / on_connect);
      2. pick a project, which lists its subjects;
      3. build the scan list for subjects matching a name prefix;
      4. download one DICOM file per scan in a background thread, convert it to
         PNG, and run the text detector on the PNGs batch by batch;
      5. page through scans with detected text, preview them, label them as PHI
         or exclude a whole series description, and export a CSV report.
    '''

    def on_connect(self,b):
        '''"connect" button callback: read the login form, connect, and list projects.'''
        self._show_scanview(False)
        self.lbl1.value='status: connecting...'
        self.sp.server,self.sp.user,self.sp.password=self.text1.value,self.text2.value,self.text3.value
        if self.sp.connect():
            self.lbl1.value='status: connected'
            self.btn1.description='reconnect'
            # the main box is displayed once; later reconnects only refresh the project list.
            if self._first_time: self._show()
            self._project_list()
            self._first_time=False
        else:
            self.lbl1.value='status: connection failed'

    def show_login_form(self):
        '''Display the server/user/password form with its connect button.'''
        st={'description_width':'initial'}
        # NOTE(review): the default server and 'admin'/'admin' credentials are development
        # conveniences; real credentials should never be widget defaults.
        self.text1=ipw.Text(value='https://xnat-dev-mga1.nrg.wustl.edu', description='XNAT server:',
                            layout={'width':'200pt'}, style=st, disabled=False)
        self.text2=ipw.Text(value='admin',description='user:',
                                disabled=False, style=st, layout={'width':'120pt'})
        self.text3=ipw.Password(value='admin',description='password:',
                                disabled=False, style=st, layout={'width':'120pt'})
        self.lbl1=ipw.Label('status: not connected', layout={'width':'120pt'}, style=st)
        lbl2=ipw.Label('',layout={'width':'120pt'},style=st)
        self.btn1=ipw.Button(description="connect",style={},layout={'width':'200pt'})
        self.btn1.on_click(self.on_connect)
        vb1=ipw.HBox([self.text1,self.text2,self.text3])
        vb2=ipw.HBox([self.btn1,lbl2,self.lbl1])
        self._login_box=ipw.VBox([vb1,vb2])
        display(self._login_box)

    def _show(self):
        '''Display the main selector box and enable the project selector.'''
        display(self._box)
        self._set_enable('proj_sel',True)

    def _query_prefix(self):
        '''
        Shell prefix of a curl query against <server>/data/archive/projects/.

        The response is written to ./mga_temp_query.json. Callers append a
        shell-quoted path; adjacent shell words concatenate into one URL.
        '''
        return ("curl -o mga_temp_query.json -k --cookie "+shlex.quote("JSESSIONID="+self.sp.jsession)
                +" "+shlex.quote(self.sp.server+"/data/archive/projects/"))

    def _run_query(self,path):
        '''
        GET <server>/data/archive/projects/<path> and return the parsed
        ['ResultSet']['Result'] value, or None if the request or parsing failed.
        '''
        # a file left by an earlier query must not be read as this query's answer.
        try:
            os.remove('mga_temp_query.json')
        except OSError:
            pass
        os.system(self._query_prefix()+shlex.quote(path))
        try:
            with open('mga_temp_query.json') as tq:
                return json.load(tq)['ResultSet']['Result']
        except (OSError,ValueError,KeyError,TypeError):
            return None

    def _project_list(self):
        '''Load the IDs of the projects visible to this session into the project dropdown.'''
        sl=self._ps_lbl_status
        sl.value='status: listing projects...'
        result=self._run_query('?format=json')
        if result is None:
            sl.value='status: cannot list projects'
            return
        self._projects=[p['ID'] for p in sorted(result,key=lambda k:k['ID'])]
        self._ps_drop_prj.options=self._projects
        sl.value='status: ready'

    def _set_enable_below(self,module,status):
        '''Enable/disable every module whose stage number is greater than that of `module`.'''
        stage=self._modules[module][0]
        for name,(idx,_) in self._modules.items():
            if idx>stage: self._set_enable(name,status)

    def _set_enable(self,module,status):
        '''Enable/disable the widgets of a module, descending into nested boxes.'''
        pending=list(self._modules[module][1].children)
        while pending:
            w=pending.pop()
            kids=getattr(w,'children',None)
            if kids is not None:
                pending.extend(kids)
            else:
                w.disabled=not status

    def _hide_scans(self):
        '''Forget the scan lists of the previous project and clear the scan table.'''
        self._scans=[]
        self._scans_dtext=[]
        self._vs_boxL.children=[]

    def _subject_list(self):
        '''List the subjects of the selected project into self._subjects.'''
        proj=self._ps_drop_prj.value
        if proj is None: return
        self._ps_lbl_status.value='status: listing subjects...'
        result=self._run_query(proj+'/subjects?format=json')
        if result is None:
            # do not keep the previous project's subjects around.
            self._subjects=[]
            self._ps_lbl_status.value='status: cannot list subjects'
            return
        self._subjects=[f['label'] for f in sorted(result,key=lambda k:k['label'])]
        self._ps_lbl_status.value='status: found {} subject(s)'.format(len(self._subjects))

    def _on_project_changed(self,d):
        '''Project dropdown observer: reset later stages and list the new project's subjects.'''
        self._hide_scans()
        self._show_scanview(False)
        self._set_enable_below('proj_sel',False)
        self.sp.project=self._ps_drop_prj.value
        self._subject_list()
        self._set_enable('build_sl',True)
        self._sl_enable(True)

    def _sl_enable(self,enabled):
        '''Reset the scan-list status and enable/disable its build button.'''
        self._sl_lbl_status.value='status: ready'
        self._sl_btn_build.disabled=not enabled

    def _on_build_scan_list(self,b):
        '''"Build scan list" callback: collect scans of subjects matching the prefix.'''
        b.disabled=True
        st=self._sl_lbl_status
        st.value='listing matching scans in project'
        self._scans=self.xi.list_scans_all(self._subjects,self._sl_txt_pref.value,'temp_scans.json',st)
        st.value='status: collected {} scans'.format(len(self._scans))
        if self._scans:
            self._set_enable('dtext',True)
            self._dt_enable()
        else:
            # nothing matched: allow another attempt with a different prefix.
            b.disabled=False

    def _dt_enable(self):
        '''Enable the detection button and clear the detection log.'''
        self._dt_btn_detect.disabled=False
        self._dt_lbl_status.value='status: ready'
        self._dt_out.outputs=()

    def _show_scanview(self,show):
        '''Show or hide the page selector and the scan table/preview; always clears the log.'''
        if show:
            self._shs_box.children=[self._shs_pgsel,self._shs_lbl_ptot]
            self._vs_box.children=[self._vs_boxL,self._vs_boxR]
        else:
            self._vs_box.children=[]
            self._shs_box.children=[]
        self._dt_out.outputs=()

    def _make_wrk_dir(self):
        '''Create ./dtext_build/<project>_<timestamp> and remember it in self._wrkdir.'''
        dir_name='./dtext_build/'+self.sp.project+'_'+datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(dir_name,exist_ok=True)
        self._wrkdir=dir_name

    def _wait_for_batch(self,start,end,lock):
        '''Block until the loader thread has set 'png' on scans [start, end); return them.'''
        while True:
            print ('verifying batch {}-{}'.format(start+1,end))
            with lock:
                batch=self._scans[start:end]
                ready=all('png' in s for s in batch)
            if ready:
                return batch
            print ('waiting for batch {}-{} to download'.format(start+1,end))
            time.sleep(1.1)

    def _detect_batch(self,batch,lock):
        '''Run the detector on a batch's PNGs; tag and collect the scans where text is found.'''
        # scans without a preview ('N/A') cannot be analysed; keep scans and results aligned.
        ready=[s for s in batch if s['png']!='N/A']
        if not ready:
            return
        res=dt.run_detection_on_files(self._tf_model,self._dt_threshold,[s['png'] for s in ready])
        for s,r in zip(ready,res):
            infile=r.get('infile')
            if infile:
                try:
                    os.remove(infile)
                except OSError:
                    pass
            if r.get('text_present')==1:
                with lock:
                    s['dtext']=r
                self._scans_dtext.append(s)

    def _on_detect_text(self,b):
        '''
        "Detect text" callback.

        A background thread (XnatIterator.scan_file_loader) sets scan['png'] for
        every scan, while this (main) thread waits for each batch of
        self._batch_size scans and runs the detector on it. Scans with detected
        text get scan['dtext'] and are collected in self._scans_dtext.
        '''
        b.disabled=True
        ls=self._dt_lbl_status
        batch_size=self._batch_size
        self._make_wrk_dir()
        lock=threading.Lock()
        self._show_scanview(False)
        # start clean: a rerun must neither reuse deleted PNGs nor duplicate earlier results.
        self._scans_dtext=[]
        self._vs_boxL.children=[]
        for s in self._scans:
            s.pop('png',None)
            s.pop('dtext',None)

        threading.Thread(target=self.xi.scan_file_loader,args=(self._scans,self._wrkdir,lock)).start()
        nS=len(self._scans)
        ls.value='running text detection'
        self._dt_out.outputs=()
        with self._dt_out:
            for i in range(0,nS,batch_size):
                iE=min(i+batch_size,nS)
                batch=self._wait_for_batch(i,iE,lock)
                ls.value='running text detection, {}/{} complete'.format(i,nS)
                print ('detecting text in batch {}-{}'.format(i+1,iE))
                self._detect_batch(batch,lock)

        ls.value='detected text in {} scans'.format(len(self._scans_dtext))
        self._show_scanview(True)
        self._show_dtext_scans()
        b.disabled=False

    def _show_dtext_scans(self,page=1):
        '''Set up paging over self._scans_dtext and render `page`.'''
        ntot=len(self._scans_dtext)
        if ntot<1: return
        npg=int(math.ceil(ntot/self._scans_per_page))
        self._set_enable('selpage',True)
        self._scan_pages=npg
        self._shs_lbl_ptot.value='of '+str(npg)
        self._shs_pgsel.disabled=False
        if self._shs_pgsel.value==page:
            # an unchanged value fires no observer event, so render explicitly.
            self._show_dtext_scan_page()
        else:
            # the 'value' observer (_on_scan_page_changed) renders the page.
            self._shs_pgsel.value=page

    def _on_scan_page_changed(self,change):
        '''Page selector observer: clamp the page number to the valid range and render it.'''
        ps=self._shs_pgsel
        if self._scan_pages<1: return
        # re-assigning the value re-enters this observer, which then renders.
        if ps.value<1: ps.value=1; return
        if ps.value>self._scan_pages: ps.value=self._scan_pages; return
        self._show_dtext_scan_page()

    def _scan_page_range(self):
        '''Return [start, end) indices into self._scans_dtext for the selected page.'''
        start=self._scans_per_page*(self._shs_pgsel.value-1)
        return start,min(start+self._scans_per_page,len(self._scans_dtext))

    def _show_dtext_scan_page(self):
        '''Render the table rows of the selected page, plus the report button and link area.'''
        style={'description_width':'initial'}
        rows=[ipw.HBox([
            ipw.Label(value='Subj/Exp/Scan',style=style,layout={'width':'300px'}),
            ipw.Label(value='Description',style=style,layout={'width':'150px'}),
            ipw.Label(value='Frm',style=style,layout={'width':'40px'}),
            ipw.Label(value='Preview',style=style,layout={'width':'50px'}),
            ipw.Label(value='PHI',style=style,layout={'width':'60px'})
        ])]
        sst,sen=self._scan_page_range()
        for s in self._scans_dtext[sst:sen]:
            # the tooltip carries subject|experiment|scan|annotated-image path for the callbacks.
            btn=ipw.Button(description='view',
                           tooltip="{}|{}|{}|{}".format(s['subject'],s['experiment'],s['ID'],s['dtext']['outfile']),
                           disabled=False,layout={'width':'50px'})
            btn.on_click(self._on_dtext_view)
            phi='Y' if 'phi' in s else ''
            rows.append(ipw.HBox([
                ipw.Label(value="{}/{}/{}".format(s['subject'],s['experiment'],s['ID']),style=style,layout={'width':'300px'}),
                ipw.Label(value=str(s.get('series_description','')),style=style,layout={'width':'150px'}),
                ipw.Label(value=str(s.get('frames','')),style=style,layout={'width':'40px'}),
                btn,
                ipw.Label(value=phi,style=style,layout={'width':'60px'})
            ]))
        rows.append(self._dts_btn_rprt)
        rows.append(self._dts_out_lnk)
        self._vs_boxL.children=rows
        self._dts_btn_rprt.disabled=False

    def _on_dtext_view(self,b):
        '''"view" button callback: preview the annotated image of the row's scan.'''
        fil=b.tooltip.split('|',3)[-1]
        with open(fil,"rb") as f:
            img=f.read()
        self._vs_im.close()
        suffix='.dtext.png'
        name=fil[:-len(suffix)] if fil.endswith(suffix) else fil
        self._vs_imcaption.value="Previewing "+name
        self._vs_im=ipw.Image(value=img,width=400,height=400,format='png',layout={'width':'400px','height':'400px'})
        self._vs_btn_phi.tooltip=b.tooltip
        self._vs_btn_excl.tooltip=b.tooltip
        self._vs_boxR.children=[self._vs_imcaption,self._vs_im,self._vs_boxRB]
        self._vs_btn_phi.disabled=False
        self._vs_btn_excl.disabled=False

    def _scan_from_tooltip(self,tooltip):
        '''
        Find the scan named by a view-button tooltip on the current page.

        Returns:
            (scan dict, index into self._scans_dtext), or (None, None) if it is
            not on the current page.
        '''
        sst,sen=self._scan_page_range()
        sb,exp,sc,_=tooltip.split('|',3)
        for i in range(sst,sen):
            s=self._scans_dtext[i]
            if s['ID']==sc and s['subject']==sb and s['experiment']==exp:
                return s,i
        return None,None

    def _on_label_phi(self,b):
        '''"Label as PHI" callback: mark the previewed scan and refresh the page.'''
        s,_=self._scan_from_tooltip(b.tooltip)
        if s is not None:
            s['phi']=1
            self._show_dtext_scan_page()

    def _on_excl_scantype(self,b):
        '''
        "Exclude this scan type" callback: drop every scan with the previewed scan's
        series description and show the page holding the nearest earlier
        remaining scan.
        '''
        s,ind=self._scan_from_tooltip(b.tooltip)
        if s is None: return
        descr=s['series_description']

        # nearest scan at or before `ind` with a different type; it survives the exclusion.
        anchor=next((self._scans_dtext[i] for i in range(ind,-1,-1)
                     if self._scans_dtext[i]['series_description']!=descr),None)
        kept=[sc for sc in self._scans_dtext if sc['series_description']!=descr]
        if not kept: return

        self._scans_dtext=kept
        if anchor is None:
            page=1
        else:
            pos=next(k for k,sc in enumerate(kept) if sc is anchor)
            page=pos//self._scans_per_page+1
        # the previewed scan is gone; its action buttons no longer apply.
        self._vs_btn_phi.disabled=True
        self._vs_btn_excl.disabled=True
        self._show_dtext_scans(page)

    def _on_report_phi(self,b):
        '''"Generate csv report" callback: write PHI-labelled scans to a CSV and link it.'''
        out=self._dts_out_lnk
        out.outputs=()
        fil=self._wrkdir+'/phi_scans.csv'
        scans_phi=[s for s in self._scans_dtext if 'phi' in s]
        if not scans_phi:
            with out:
                print ('no scans labelled as PHI')
            return
        self._write_scans_csv(scans_phi,fil)
        with out:
            display(FileLink(fil))

    def _write_scans_csv(self, scans, file):
        '''Write scan dicts to `file` as CSV; columns are the union of all keys, in first-seen order.'''
        fieldnames=list(dict.fromkeys(k for s in scans for k in s))
        with open(file,'w',newline='') as output_file:
            dict_writer=csv.DictWriter(output_file,fieldnames)
            dict_writer.writeheader()
            dict_writer.writerows(scans)

    def __init__(self, server_params):
        '''
        Build all widgets (displayed later by show_login_form/_show) and load the
        text-detection model from ./models.

        Args:
            server_params: ServerParams shared with the internal XnatIterator.
        '''
        self.sp=server_params
        self.xi=XnatIterator(server_params)
        style={'description_width':'initial'}

        # project and subject selector box.
        self._ps_drop_prj=ipw.Dropdown(description='project:',style=style,layout={'width':'200px'})
        self._ps_drop_prj.observe(self._on_project_changed,names='value')
        self._ps_lbl_status=ipw.Label(value='status: ready',style=style,layout={'width':'200px'})
        self._ps_box=ipw.HBox([self._ps_drop_prj, self._ps_lbl_status])

        # build scan list box.
        self._sl_txt_pref=ipw.Text(value='',description='filter by subject name prefix:',
                                   disabled=False, style=style, layout={'width':'200pt'})
        self._sl_btn_build=ipw.Button(description='Build scan list', disabled=True,layout={'width':'150px'})
        self._sl_btn_build.on_click(self._on_build_scan_list)
        self._sl_lbl_status=ipw.Label(value='status: ready',style=style,layout={'width':'500px'})
        self._sl_box=ipw.HBox([self._sl_txt_pref,self._sl_btn_build,self._sl_lbl_status])

        # text detection module.
        self._dt_btn_detect=ipw.Button(description='Detect text', disabled=True,layout={'width':'100px'})
        self._dt_btn_detect.on_click(self._on_detect_text)
        self._dt_lbl_status=ipw.Label(value='status: ready',style=style,layout={'width':'300px'})
        self._dt_boxU=ipw.HBox([self._dt_btn_detect,self._dt_lbl_status])
        self._dt_out=ipw.Output(layout={'width':'600px','height':'100px','overflow_y':'auto'})
        self._dt_box=ipw.VBox([self._dt_boxU,self._dt_out])

        # scan page selection module.
        self._shs_pgsel=ipw.IntText(value=1,description='Scans with detected text, page', disabled=True,style=style,layout={'width':'300px'})
        self._shs_pgsel.observe(self._on_scan_page_changed,names='value')
        self._shs_lbl_ptot=ipw.Label(value='of 1',style=style,layout={'width':'100px'})
        self._shs_box=ipw.HBox([])

        # dtext scan list elements.
        self._dts_btn_rprt=ipw.Button(description='Generate csv report',disabled=True,layout={'width':'150px'})
        self._dts_btn_rprt.on_click(self._on_report_phi)
        self._dts_out_lnk=ipw.Output()

        # detected scan view.
        self._vs_imcaption=ipw.Label(value='Viewing scan',style=style,layout={'width':'400px'})
        self._vs_im=ipw.Image(width=400, height=400,layout={'width':'1px','height':'400px'})
        self._vs_btn_phi=ipw.Button(description='Label as PHI',disabled=True,layout={'width':'150px'})
        self._vs_btn_phi.on_click(self._on_label_phi)
        self._vs_btn_excl=ipw.Button(description='Exclude this scan type',disabled=True,layout={'width':'200px'})
        self._vs_btn_excl.on_click(self._on_excl_scantype)

        self._vs_boxL=ipw.VBox([])
        self._vs_boxRB=ipw.HBox([self._vs_btn_excl,self._vs_btn_phi])
        self._vs_boxR=ipw.VBox([self._vs_imcaption,self._vs_im,self._vs_boxRB])
        self._vs_box=ipw.HBox([])

        self._box=ipw.VBox([self._ps_box,self._sl_box,self._dt_box,self._shs_box,self._vs_box])

        self._subjects=[]
        self._scans=[]
        self._scan_pages=0
        self._scans_dtext=[]
        self._wrkdir=""

        # module name -> (stage number, container box); used by _set_enable/_set_enable_below.
        self._modules={'proj_sel':(1,self._ps_box), 'build_sl':(2,self._sl_box),'dtext':(3,self._dt_box),'selpage':(4,self._shs_box),'review':(5,self._vs_box)}
        self._scans_per_page=3
        self._batch_size=10
        # score threshold passed to dt.run_detection_on_files.
        self._dt_threshold=0.99

        self._tf_model=tf.keras.models.load_model(os.getcwd()+'/models/09.10.2019.on_5M.hd5')

        self._first_time=True


In [22]:
# launch the GUI with empty connection parameters; the login form fills them in.
selector=Selector(server_params=ServerParams())
selector.show_login_form()

VBox(children=(HBox(children=(Text(value='https://xnat-dev-mga1.nrg.wustl.edu', description='XNAT server:', la…

VBox(children=(HBox(children=(Dropdown(description='project:', layout=Layout(width='200px'), options=(), style…

In [None]:
import os
# per-user scratch directory for dtext; created without going through a shell.
os.makedirs(os.path.expanduser('~/.dtext/tmp'), exist_ok=True)

In [11]:
# `datetime` is the class here (from datetime import datetime), not the module.
d=datetime.now()