In [None]:
#Basic AstroCV example
#Detect galaxies on a sample image
#
#A pre-compiled version of darknet is provided for running remotely on the
#JupyterHub server. For local use, download and compile these
#repositories instead:
#https://github.com/astroCV/darknet
#https://github.com/astroCV/pyyolo

In [None]:
import os
import sys
import time

import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
from PIL import Image, ImageEnhance

import pyyolo

In [None]:
# Configuration for the detection run.
# datacfg and cfgfile are given relative to darknet_path.
darknet_path = './data/darknet'  # darknet path
datacfg = '../sdss.data'  # relative to darknet path
cfgfile = '../sdss.cfg'  # relative to darknet path
# Path to the weights file (original note: lupton rgb +2 brightness +2 contrast with ImageEnhance).
# Defaults to the shared server location; set the ASTROCV_WEIGHTS environment
# variable to use a local copy instead of editing this cell.
weightfile = os.environ.get('ASTROCV_WEIGHTS', '/mnt/data/astrocv/galaxy_sdss_hic.weights')
# High-contrast sample image to run detection on.
filename = 'data/hic/1140_301_1_206.jpg'  # image sample 1 hi contrast
# Alternative sample: filename = 'data/hic/1045_301_2_129.jpg'  # image sample 2 hi contrast
thresh = 0.2  # detection probability threshold
# Passed straight through to pyyolo.test; presumably darknet's hierarchical
# (tree) classification threshold — TODO confirm against pyyolo/darknet.
hier_thresh = 0.5

In [None]:
# Load the network, run detection on `filename`, and print each detection.
# pyyolo.cleanup() is in a `finally` so it is still called if detection raises
# (presumably it frees what pyyolo.init allocated — confirm against pyyolo).
t1 = time.perf_counter()  # perf_counter is monotonic, suited to elapsed-time measurement
pyyolo.init(darknet_path, datacfg, cfgfile, weightfile)  # init and load network
print('Initialization time = %5.3f seconds' % (time.perf_counter() - t1))
try:
    # Actually loading the image takes most of the time.
    t1 = time.perf_counter()
    # Trailing 0 is passed through unchanged — NOTE(review): meaning is defined by pyyolo, confirm.
    outputs = pyyolo.test(filename, thresh, hier_thresh, 0)  # load image and process
    print('Load from file + Image processing time = %5.3f seconds' % (time.perf_counter() - t1))
    for output in outputs:
        print(output)
finally:
    pyyolo.cleanup()

# Plot the image with a bounding box and class label for every detection.
# The file is opened in a `with` block so its handle is released; the
# contrast-enhanced image is a new Image object and stays usable afterwards.
with Image.open(filename) as img:
    # Contrast boost is just for visualization; it does not affect `outputs`.
    img2 = ImageEnhance.Contrast(img).enhance(1.5)
fig, ax = plt.subplots(figsize=(12, 9))
ax.axis('off')
plt.tight_layout(pad=0)
ax.imshow(img2)
ax.set_aspect('equal')
for output in outputs:
    left, right = output['left'], output['right']
    top, bottom = output['top'], output['bottom']
    # Box is padded by 4 px left/right, 3 px above and 1 px below the detection.
    rect = patches.Rectangle((left - 4, top - 3), right - left + 8, bottom - top + 4,
                             linewidth=1, edgecolor='b', facecolor='none')
    ax.add_patch(rect)
    # Label is drawn up-left of the box's top-left corner.
    ax.annotate(output['class'], (left - 7, top - 18), color='w', fontsize=14)
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
# To save the figure instead, uncomment:
# plt.savefig('writable/sample.jpg', dpi=180)
plt.show()

In [None]:
%%javascript
// Restart the kernel from the browser once the example has finished.
// NOTE(review): presumably this releases state held by the darknet/pyyolo
// extension beyond pyyolo.cleanup() — confirm why the restart is needed.
// NOTE(review): relies on the classic-Notebook `IPython.notebook` JS global,
// which is likely unavailable in JupyterLab — confirm before relying on it.
IPython.notebook.kernel.restart();