# Utilities for sorting and classifying STED images of spot pairs

## Imports and defs

In [2]:
import os
import re
import sys
import shutil
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt
from scipy.ndimage import gaussian_filter, imread, median_filter, gaussian_laplace, sobel
from skimage.feature import peak_local_max
from scipy.spatial import kdtree
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.svm import SVC
from sklearn.cross_validation import cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib
import pickle

%matplotlib inline
plt.rcParams["figure.figsize"] = [6, 6]

def mkdir_if_necessary(path):
    """Create directory *path* (and any missing parents) unless it exists.

    If anything already exists at *path* (directory or file) nothing is done.
    """
    if os.path.exists(path):
        return
    os.makedirs(path)
        
        
def exists_with_postfix(path, postfix=".jpg"):
    """Return whether ``path + postfix`` exists (by default the .jpg preview of *path*)."""
    candidate = path + postfix
    return os.path.exists(candidate)

def sort_overviews(d):
    """Move every top-level file of *d* that is not a STED field image into ``d/overviews``.

    A file is kept in place when its name matches ``.*?field.*?sted.*?`` from
    the start (i.e. contains "field" followed later by "sted"); subdirectories
    are not touched.
    """
    _, _, top_level_files = next(os.walk(d))

    overview_dir = os.path.join(d, 'overviews')
    mkdir_if_necessary(overview_dir)

    sted_name = re.compile('.*?field.*?sted.*?')

    for name in top_level_files:
        if sted_name.match(name):
            continue
        shutil.move(os.path.join(d, name), os.path.join(overview_dir, name))
            
def recommend_quality(im, thresh_brightest = 10, max_dist=25):
    """Rate a two-channel spot-pair image with a rule-based heuristic.

    Channels ``im[:, :, 0]`` and ``im[:, :, 1]`` are each smoothed (Gaussian,
    sigma=1) and their local maxima located; peaks are ranked by the raw
    (unsmoothed) intensity at the maximum. Rules are applied in this order,
    the first match wins, and a short reason is printed:

    1. 'b' if a channel has no local maximum, or its brightest peak is below
       ``thresh_brightest / 2`` times the channel mean (channel 1 checked first);
    2. 'm' if a channel's brightest peak is below ``thresh_brightest`` times
       the channel mean;
    3. 'b' if a channel has more than 5 candidate peaks (peaks brighter than
       0.67 times its brightest one), or no candidate at all;
    4. 'm' if the distance between the closest channel-2 / channel-1
       candidate pair exceeds ``max_dist`` pixels;
    5. 'g' otherwise.

    Parameters
    ----------
    im : array indexed as ``im[row, col, channel]`` with at least 2 channels.
    thresh_brightest : required ratio between brightest peak and channel mean.
    max_dist : largest accepted distance (pixels) between the closest
        candidate peaks of the two channels.

    Returns
    -------
    str
        'g' (good), 'm' (mediocre) or 'b' (bad).
    """
    # Public location of KDTree; the scipy.spatial.kdtree namespace is deprecated.
    from scipy.spatial import KDTree

    i1 = im[:,:,0]
    i2 = im[:,:,1]
    p1 = peak_local_max(gaussian_filter(i1, 1), min_distance=2)
    p2 = peak_local_max(gaussian_filter(i2, 1), min_distance=2)

    # (raw intensity at peak, row index into p1/p2), brightest first
    p1i = sorted([(i1[p1[i,0], p1[i,1]], i) for i in range(len(p1))], key=lambda x: x[0], reverse=True)
    p2i = sorted([(i2[p2[i,0], p2[i,1]], i) for i in range(len(p2))], key=lambda x: x[0], reverse=True)

    # A channel without any local maximum counts as dark (this used to raise
    # an IndexError on p1i[0] / p2i[0]).
    if not p1i or p1i[0][0] < thresh_brightest / 2 * np.mean(i1):
        print('BAD: channel 1 dark')
        return 'b'

    if not p2i or p2i[0][0] < thresh_brightest / 2 * np.mean(i2):
        print('BAD: channel 2 dark')
        return 'b'

    if p1i[0][0] < thresh_brightest * np.mean(i1):
        print('MEDIOCRE: channel 1 dark')
        return 'm'

    if p2i[0][0] < thresh_brightest * np.mean(i2):
        print('MEDIOCRE: channel 2 dark')
        return 'm'

    # Number of candidate peaks: those brighter than 67 % of the brightest.
    # The lists are sorted descending, so the candidates are a prefix.
    halflife1 = sum(1 for v, _ in p1i if v > 0.67 * p1i[0][0])
    halflife2 = sum(1 for v, _ in p2i if v > 0.67 * p2i[0][0])

    print('Found ' + str(halflife1) + ' candidate peaks in channel 1')
    print('Found ' + str(halflife2) + ' candidate peaks in channel 2')

    if (halflife1 > 5):
        print('BAD: found too many peaks in channel 1')
        return 'b'

    if (halflife2 > 5):
        print('BAD: found too many peaks in channel 2')
        return 'b'

    # Only reachable when the brightest peak is <= 0 (e.g. an all-zero
    # channel); building a KDTree from no points would fail.
    if halflife1 == 0 or halflife2 == 0:
        print('BAD: no bright peaks found')
        return 'b'

    p1good = [p1[idx] for _, idx in p1i[:halflife1]]
    p2good = [p2[idx] for _, idx in p2i[:halflife2]]

    # distance from every channel-2 candidate to its nearest channel-1 candidate
    mindist = np.min(KDTree(p1good).query(p2good)[0])
    print('approximate minimal distance: ' + str(mindist))

    if (mindist > max_dist):
        print('MEDIOCRE: minimal distance too high')
        return 'm'

    print('GOOD')
    return 'g'

def getfeatures(img):
    """Compute the 35-element feature vector used by the ML quality classifier.

    Each of the first two channels (``img[:, :, 0]``, ``img[:, :, 1]``) is
    smoothed (Gaussian, sigma=1) and its local maxima are found; candidate
    peaks are those whose raw intensity exceeds 0.67 times the brightest
    peak of that channel. Of all channel-2 candidates, the one closest to a
    channel-1 candidate is chosen, and features are read at that pixel pair.

    Feature order (the stored scaler/classifier depend on it):
      mean1, mean2, var1, var2, n_candidates1, n_candidates2, pair distance,
      then for sigma in (0.7, 1, 1.5, 2.25, 3.5, 5):
        gaussian1, gaussian2, gaussian_laplace1, gaussian_laplace2,
      then sobel1, sobel2, raw1, raw2
    (the "1"/"2" values are read at the channel-1/channel-2 pixel of the pair).

    Returns
    -------
    list of float

    Raises
    ------
    ValueError
        If a channel has no candidate peak, so no pixel pair exists.
    """
    # Public location of KDTree; the scipy.spatial.kdtree namespace is deprecated.
    from scipy.spatial import KDTree

    # Read the argument. The previous version read the notebook-global ``im``
    # and only worked because every caller happened to assign ``im`` first.
    i1 = img[:,:,0]
    i2 = img[:,:,1]
    p1 = peak_local_max(gaussian_filter(i1, 1), min_distance=2)
    p2 = peak_local_max(gaussian_filter(i2, 1), min_distance=2)

    # (raw intensity at peak, row index into p1/p2), brightest first
    p1i = sorted([(i1[p1[i,0], p1[i,1]], i) for i in range(len(p1))], key=lambda x: x[0], reverse=True)
    p2i = sorted([(i2[p2[i,0], p2[i,1]], i) for i in range(len(p2))], key=lambda x: x[0], reverse=True)

    # candidates form a prefix of the descending lists; 0 if a list is empty
    halflife1 = sum(1 for v, _ in p1i if v > 0.67 * p1i[0][0])
    halflife2 = sum(1 for v, _ in p2i if v > 0.67 * p2i[0][0])

    if halflife1 == 0 or halflife2 == 0:
        raise ValueError('getfeatures: no candidate peak found in channel %d'
                         % (1 if halflife1 == 0 else 2))

    p1good = [p1[idx] for _, idx in p1i[:halflife1]]
    p2good = [p2[idx] for _, idx in p2i[:halflife2]]

    # nearest channel-1 candidate for every channel-2 candidate
    dists, nearest = KDTree(p1good).query(p2good)

    m2 = np.argmin(dists)
    m1 = nearest[m2]
    d = dists[m2]
    pos1 = tuple(p1good[m1])
    pos2 = tuple(p2good[m2])

    features = [np.mean(i1), np.mean(i2), np.var(i1), np.var(i2), halflife1, halflife2, d]

    # NOTE(review): scipy.ndimage filters return the input dtype by default, so
    # for an integer image (e.g. uint8 decoded from JPEG) negative Laplace /
    # Sobel responses are not preserved. Deliberately left unchanged: the
    # stored scaler/classifier were trained on exactly these values.
    for sigma in [0.7, 1 , 1.5 , 2.25 , 3.5, 5]:
        features.append(gaussian_filter(i1, sigma)[pos1])
        features.append(gaussian_filter(i2, sigma)[pos2])
        features.append(gaussian_laplace(i1, sigma)[pos1])
        features.append(gaussian_laplace(i2, sigma)[pos2])

    features.append(sobel(i1)[pos1])
    features.append(sobel(i2)[pos2])
    features.append(i1[pos1])
    features.append(i2[pos2])

    return [float(f) for f in features]

def predict_ml(img, sc, cls):
    """Classify *img* with a fitted feature scaler and classifier.

    Parameters
    ----------
    img : image passed to ``getfeatures``.
    sc : fitted object with ``transform`` (the notebook uses StandardScaler).
    cls : fitted object with ``predict`` returning integer labels; label i
        maps to ['good', 'bad', 'mediocre'][i]. The classifier trained in
        this notebook is fitted on labels 0/1 only.

    Returns
    -------
    str
        'good', 'bad' or 'mediocre'.
    """
    feat = np.array(getfeatures(img)).reshape(1,-1)
    # predict() returns an array with one entry per sample. Indexing a list
    # with that length-1 array raises TypeError on current numpy, so take the
    # scalar label first.
    label = int(cls.predict(sc.transform(feat))[0])
    return ['good', 'bad', 'mediocre'][label]

## Classifier Training

In [3]:
# Training-data accumulators: one feature vector and one class index
# (0 = good, 1 = bad, 2 = mediocre) per labelled image.
features, classes = [], []

In [4]:
# Compute training features and classes from a list of directories.
# Each directory is walked recursively; every .jpg inside a folder named
# 'good', 'bad' or 'mediocre' adds its feature vector to ``features`` and
# the folder's index in that list to ``classes``.

# directories with 30 percent STED images
ds = [
    '/Users/david/Desktop/8th_shipment_20170130/mixed_HS1_HS4_B/K562_B/',
    '/Users/david/Desktop/8th_shipment_20170130/mixed_HS1_HS4_B/GM_B/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS1_HS4_A/GM_A/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS1_HS4_A/K562_A/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS1_HS4_B/GM_B/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS1_HS4_B/K562_B/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS2_HBG2_A/GM_A/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS2_HBG2_A/K562_A/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS2_HBG2_B/GM_B/',
    '/Users/david/Desktop/9th_shipment_20170216/mixed_HS2_HBG2_B/K562_B/',
    '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170221/HS1_HS4_A/K562_Apos1/',
    '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170221/HS1_HS4_A/K562_Apos2/',
    '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170222/HS1_HS4_B/K562_Bpos1/',
    '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170222/HS1_HS4_B/K562_Bpos2/'
    ]

label_names = ['good', 'bad', 'mediocre']

for d in ds:
    for di, _ , fl in os.walk(d):
        folder = di.split(os.sep)[-1]
        if folder not in label_names:
            continue
        for f in fl:
            if not f.endswith('.jpg'):
                continue
            # keep the global name ``im``: it is also the notebook-level image variable
            im = imread(os.path.join(di, f))
            features.append(getfeatures(im))
            classes.append(label_names.index(folder))

In [5]:
# Fit a feature scaler and a binary Random Forest (good vs. not good),
# report its 10-fold cross-validated accuracy, then fit it on all data.

sc = StandardScaler()
sc.fit(features)

scaled_features = sc.transform(features)
# collapse 'bad' (1) and 'mediocre' (2) into a single class 1; 'good' stays 0
binary_classes = [0 if x == 0 else 1 for x in classes]

cls = RandomForestClassifier(n_estimators=100)

cv_scores = cross_val_score(cls, scaled_features, binary_classes, cv=10)
print('Mean classifier accuracy (10-fold c.v.): ' + 
      str(np.mean(cv_scores)))
cls.fit(scaled_features, binary_classes)

Mean classifier accuracy (10-fold c.v.): 0.882122779508


RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

In [6]:
# Persist the fitted scaler and classifier so later sessions can skip training.
for obj, out_path in [(sc, '/Users/david/Desktop/scaler_30sted_2.pks'),
                      (cls, '/Users/david/Desktop/goodbadclassifier_30sted_2.pks')]:
    with open(out_path, 'wb') as fd:
        pickle.dump(obj, fd)

## Loading a preexisting classifier

In [3]:
# Load a previously saved scaler / classifier pair.
# pickle can run arbitrary code while loading: only load files you wrote yourself.
scaler_file = '/Users/david/Desktop/scaler_30sted_2.pks'
classifier_file = '/Users/david/Desktop/goodbadclassifier_30sted_2.pks'

with open(scaler_file, 'rb') as fd:
    sc = pickle.load(fd)
with open(classifier_file, 'rb') as fd:
    cls = pickle.load(fd)

# Actual sorting

In [64]:
### 1: set the directory to process
# (to process acquisitions below the current working directory instead, use
#  os.path.join(os.getcwd(), 'AutomatedAcquisitions'))
dir_to_process = '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170306/HS2_HBG2/K562/'

In [34]:
### 2: move all the overview files into a separate folder
# Every non-hidden subdirectory of ``dir_to_process`` is handed to
# sort_overviews(), which moves its non-"field...sted" files into an
# ``overviews`` subfolder.
subdirs = next(os.walk(dir_to_process))[1]
dirs = [name for name in subdirs if not name.startswith('.')]
print(dirs)

for d in dirs:
    sort_overviews(os.path.join(dir_to_process, d))

['HS2_HBG2', 'HS2Delta_HBG2']


In [62]:
### 3: set subfolder to process
# ``d`` is the folder whose .msr/.jpg files the sorting cell below works on.
subfolder = 'K562_HS2Delta'
d = os.path.join(dir_to_process, subfolder)

## Sorting, ML assisted if classifier present

In [None]:
### 4: SORTING into good/bad/mediocre
# For every .msr acquisition in ``d``: show its .jpg preview, ask for a
# verdict and move both files into the good/, bad/ or mediocre/ subfolder.
# If a scaler ``sc`` and classifier ``cls`` exist, their prediction is
# printed and used when the answer is left empty.

gd_d = os.path.join(d, 'good')
bd_d = os.path.join(d, 'bad')
md_d = os.path.join(d, 'mediocre')

mkdir_if_necessary(gd_d)
mkdir_if_necessary(bd_d)
mkdir_if_necessary(md_d)

# first letter of the upper-cased decision -> destination folder
targets = {'G': gd_d, 'B': bd_d, 'M': md_d}

# The ML recommendation is optional. Without the training/loading cells,
# ``sc`` / ``cls`` are not defined at all, and the old ``sc != None`` test
# raised NameError.
try:
    use_ml = sc is not None and cls is not None
except NameError:
    use_ml = False

files = [f for f in next(os.walk(d))[2] if f.endswith('.msr')]

for fi in files:
    f = os.path.join(d, fi)
    if not exists_with_postfix(f):
        print('SKIPPED (no .jpg preview): ' + fi)
        continue
    im = imread(f + ".jpg")

    rec = predict_ml(im, sc, cls) if use_ml else None

    plt.imshow(im)
    plt.show()

    if rec is not None:
        print(rec.upper())

    print('-----')
    sys.stdout.flush()
    decision = input("ISGOOD? [(g)ood/(b)ad/(m)ediocre] :") or rec
    # Empty answer and no recommendation: leave the file unsorted
    # (previously this crashed on ``None.upper()``).
    if not decision:
        continue

    target = targets.get(decision.upper()[0])
    if target is None:
        # unrecognised answer: leave the file unsorted
        continue
    shutil.move(f, os.path.join(target, fi))
    shutil.move(f + ".jpg", os.path.join(target, fi + ".jpg"))

## Automated sorting, needs classifier

In [76]:
### 4a: SORTING into good/bad/(mediocre) AUTOMATED
# Classify every .msr acquisition in ``d`` from its .jpg preview using the
# loaded scaler ``sc`` and classifier ``cls`` (no user interaction), and
# move both files into the matching good/, bad/ or mediocre/ subfolder.
d = '/Users/david/Desktop/VisitFebMar2017/visit_feb_20170311/K562/HS2_HBG2/'

gd_d = os.path.join(d, 'good')
bd_d = os.path.join(d, 'bad')
md_d = os.path.join(d, 'mediocre')

mkdir_if_necessary(gd_d)
mkdir_if_necessary(bd_d)
mkdir_if_necessary(md_d)

# first letter of the upper-cased prediction -> destination folder
targets = {'G': gd_d, 'B': bd_d, 'M': md_d}

files = [f for f in next(os.walk(d))[2] if f.endswith('.msr')]

for fi in files:
    f = os.path.join(d, fi)
    if not exists_with_postfix(f):
        # one missing preview must not abort the whole batch
        print('SKIPPED (no .jpg preview): ' + fi)
        continue
    im = imread(f + ".jpg")

    rec = predict_ml(im, sc, cls)
    dec = rec.upper()[0]

    target = targets.get(dec)
    if target is not None:
        shutil.move(f, os.path.join(target, fi))
        shutil.move(f + ".jpg", os.path.join(target, fi + ".jpg"))

