# Common constants, classes and functions

In [3]:
import itertools
import logging
import os
import pickle
import re
from enum import Enum, IntEnum

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# NOTE(review): Tracer is deprecated since IPython 5.1 and may be missing in
# newer IPython releases; in this file it is only used by commented-out code.
from IPython.core.debugger import Tracer


# Common constants
EPS = 1e-6  # absolute tolerance used by StatUtils.almost_equals
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'  # one-letter codes of the 20 amino acids used here
DUMMY_AA = 'X'  # extra residue symbol accepted when allow_dummy=True
# Every unordered pair (a, b) with a at or before b in AMINO_ACIDS, including
# identical pairs (A, A): 20 * 21 / 2 = 210 pairs, in AMINO_ACIDS order.
AA_PAIRS = list(itertools.combinations_with_replacement(AMINO_ACIDS, 2))


# Utility functions
def basename(path, ext=True):
    """Return the last path component of *path*.

    When *ext* is false, the trailing extension (as split off by
    ``os.path.splitext``) is stripped as well.
    """
    name = os.path.basename(path)
    return name if ext else os.path.splitext(name)[0]

def desc_tensor(x):
    """Describe a tensor as the string form of the tuple ``(x.type(), x.shape)``."""
    summary = (x.type(), x.shape)
    return str(summary)

def desc_ndarray(x):
    """Describe an array as the string form of the tuple ``(x.dtype, x.shape)``."""
    summary = (x.dtype, x.shape)
    return str(summary)

def totensor(x):
    """Convert *x* (list, scalar, ndarray, ...) to a tensor via ``torch.tensor``.

    torch is imported lazily so the rest of this module does not require it
    unless this helper is actually used.
    """
    import torch  # local import: only this helper needs torch
    return torch.tensor(x)

# Common classes
class StrEnum(str, Enum):
    """Enum whose members are also ``str`` instances and print as their bare value."""

    def __str__(self):
        # Show only the underlying value instead of 'ClassName.MEMBER'.
        member_value = self.value
        return member_value

# class BindLevel(IntEnum):
#     POSITIVE_HIGH = 4
#     POSITIVE = 3
#     POSITIVE_INTERMEDIATE = 2
#     POSITIVE_LOW = 1
#     NEGATIVE = 0

#     @classmethod
#     def is_binder(cls, level):
#         return level > BindLevel.NEGATIVE

# BIND_LEVELS = list(BindLevel)

# class BindMetric(object):

#     def level(self, val):
#         raise NotImplementedError('Not implemented yet')

#     def is_binder(self, val):
#         return BindLevel.is_binder(self.level(val))

# class ContinuousBindMetric(BindMetric):

#     def __init__(self, cutoffs=None, comp_op=None):
#         self.cutoffs = cutoffs
#         self.comp_op = comp_op

#     def level(self, val):
#         for i, cutoff in enumerate(self.cutoffs):
#             if self.comp_op(val, cutoff):
#                 return BIND_LEVELS[i]
#         return BindLevel.NEGATIVE

# class BinaryBindMetric(BindMetric):
#     def __init__(self, bind_flags=(1, 0)):
#         self.bind_flags = bind_flags

#     def level(self, val):
#         return BindLevel.POSITIVE if val == self.bind_flags[0] else BindLevel.NEGATIVE

# DEFALUT_IC50_BM = ContinuousBindMetric(cutoffs=[100, 500, 1000, 5000], comp_op=np.less)
# DEFALUT_HALFALIVE_BM = ContinuousBindMetric(cutoffs=[240, 120, 50, 10], comp_op=np.greater)
# DEFALUT_BINARY_BM = BinaryBindMetric()

# def ic502prob(x):
#     return (1 - np.log(x) / np.log(50000)) if x <= 50000 else 0

# def prob2ic50(x):
#     return (np.exp((1 - x) * np.log(50000))) if x >= 0 else 50000


class PlotUtils(object):
    """Plotting helpers built on matplotlib.pyplot."""

    @staticmethod
    def plot_confusion_matrix(cm, classes,
                              normalize=False,
                              title='Confusion matrix',
                              cmap=plt.cm.Blues):
        """Print and plot a confusion matrix on the current pyplot figure.

        Args:
            cm: square array-like; rows are true labels and columns predicted
                labels (matching the axis labels set below).
            classes: tick labels, one per row/column of ``cm``.
            normalize: if true, divide each row by its sum; rows whose sum is
                zero are shown as zeros instead of NaN.
            title: plot title.
            cmap: matplotlib colormap for the image.
        """
        cm = np.asarray(cm)
        if normalize:
            # Normalize *before* drawing so the image, colorbar and the cell
            # annotations all show the same values.
            cm = cm.astype('float')
            row_sums = cm.sum(axis=1)[:, np.newaxis]
            cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
            print("Normalized confusion matrix")
        else:
            print('Confusion matrix, without normalization')

        print(cm)

        plt.imshow(cm, interpolation='nearest', cmap=cmap)
        plt.title(title)
        plt.colorbar()
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=45)
        plt.yticks(tick_marks, classes)

        # Integers are printed as-is; fractional values with two decimals.
        fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
        thresh = cm.max() / 2.
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, format(cm[i, j], fmt),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

        plt.tight_layout()
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
    
class FileUtils(object):
    """File-system helpers: pickle round-trips and regex-based listing/removal."""

    @staticmethod
    def pkl_save(fn, target):
        """Pickle *target* into file *fn*, overwriting it."""
        with open(fn, 'wb') as fh:
            pickle.dump(target, fh)

    @staticmethod
    def pkl_load(fn):
        """Load and return the object pickled in file *fn*.

        NOTE(security): unpickling can execute arbitrary code; only load
        files from trusted sources.
        """
        with open(fn, 'rb') as fh:
            return pickle.load(fh)

    @staticmethod
    def rm_files(path, fn_ptrn):
        """Delete entries of directory *path* whose names match regex *fn_ptrn*.

        Matching uses ``re.match`` (anchored at the start of the name).
        Matching subdirectories are skipped instead of aborting the removal
        part-way through; symlinks are removed (the link, not its target).
        """
        for fn in FileUtils.list_files(path, fn_ptrn):
            target = os.path.join(path, fn)
            if os.path.isdir(target) and not os.path.islink(target):
                continue
            os.remove(target)

    @staticmethod
    def list_files(path, fn_ptrn):
        """Return the names in directory *path* that match regex *fn_ptrn*.

        Names come back in ``os.listdir`` order; matching uses ``re.match``.
        """
        pattern = re.compile(fn_ptrn)
        return [fn for fn in os.listdir(path) if pattern.match(fn) is not None]
       
            
class StatUtils(object):
    """Small statistics helpers."""

    @staticmethod
    def minmax(x):
        """Return ``(min(x), max(x))``.

        One-shot iterators (e.g. generators) are materialized first so that
        ``max`` does not see an already-exhausted iterator. Raises ValueError
        if *x* is empty.
        """
        if iter(x) is x:
            x = list(x)
        return (min(x), max(x))

    @staticmethod
    def find_corr(tab, cutoff=0.8):
        """Find columns that are highly correlated with an earlier column.

        For each pair of columns (i, j) with i < j in ``tab.corr()``, column j
        is reported when ``|corr(i, j)| >= cutoff``. NaN correlations never
        qualify. Names are taken from the correlation matrix itself so they
        stay aligned with its rows/columns even if ``corr`` dropped columns.

        Returns:
            ``np.unique`` of the reported column names (sorted, de-duplicated).
        """
        corr = tab.corr()
        # Strict upper triangle (k=1) keeps exactly the pairs with i < j.
        hits = np.triu(np.abs(corr.values) >= cutoff, k=1)
        flagged = corr.columns[hits.any(axis=0)]
        return np.unique(list(flagged))

    @staticmethod
    def almost_equals(f1, f2, eps=None):
        """Return True if ``|f1 - f2| < eps``; *eps* defaults to the module-level EPS."""
        if eps is None:
            eps = EPS
        return abs(f1 - f2) < eps
    
from IPython.display import display, display_html
    
class PrintUtils(object):
    """Display helpers for consoles and IPython notebooks."""

    @staticmethod
    def fullprint(*args, **kwargs):
        """``pprint`` the arguments without numpy/pandas truncation.

        numpy's print threshold and pandas' ``display.max_rows`` are lifted
        only for the duration of the call and are restored afterwards, even
        if printing raises.
        """
        import sys
        from pprint import pprint
        import numpy
        import pandas as pd
        # threshold='nan' is rejected by modern numpy; sys.maxsize means
        # "never summarize". The context managers guarantee restoration.
        with numpy.printoptions(threshold=sys.maxsize), \
                pd.option_context('display.max_rows', None):
            pprint(*args, **kwargs)

    @staticmethod
    def display_side_by_side(*args):
        """Render the HTML tables of several DataFrames inline, side by side.

        Only opening ``<table`` tags get the inline style; closing tags and
        cell contents are left untouched.
        """
        html_str = ''.join(df.to_html() for df in args)
        display_html(html_str.replace('<table', '<table style="display:inline"'), raw=True)

        
class LogUtils(object):
    """Helpers for attaching file handlers to named loggers without duplicates."""

    @staticmethod
    def has_filehandler(loggername, filename):
        """Return True if logger *loggername* already has a FileHandler for *filename*.

        ``logging.FileHandler`` stores the absolute path of its file in
        ``baseFilename``, so *filename* is made absolute before comparing;
        otherwise a relative path would never match.
        """
        target = os.path.abspath(filename)
        logger = logging.getLogger(loggername)
        return any(isinstance(handler, logging.FileHandler) and handler.baseFilename == target
                   for handler in logger.handlers)

    @staticmethod
    def get_default_logger(loggername, filename):
        """Return logger *loggername*, logging to *filename*.

        Propagation to ancestor loggers is disabled. If no FileHandler for the
        same file is attached yet, one is added (append mode, DEBUG level,
        timestamped format) and the logger level is set to DEBUG. Repeated
        calls therefore do not duplicate log lines.
        """
        logger = logging.getLogger(loggername)
        logger.propagate = False
        if LogUtils.has_filehandler(loggername, filename):
            return logger

        fh = logging.FileHandler(filename, mode='a')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s: %(message)s'))
        logger.addHandler(fh)
        logger.setLevel(logging.DEBUG)
        return logger

class StrUtils(object):
    """Tiny string cleaners and predicates."""

    @staticmethod
    def rm_nonwords(s):
        """Return *s* with every non-word character (regex class \\W) removed."""
        return re.sub('\\W', '', s)

    @staticmethod
    def empty(s):
        """True when *s* is None or has length zero."""
        if s is None:
            return True
        return len(s) == 0
    
class SeqUtils(object):
    """Amino-acid sequence helpers: random generation, validation, FASTA output."""

    @staticmethod
    def rand_aaseqs(N=10, seq_len=9):
        """Return a list of *N* random sequences, each of length *seq_len*."""
        return [SeqUtils.rand_aaseq(seq_len) for _ in range(N)]

    @staticmethod
    def rand_aaseq(seq_len=9):
        """Return one random sequence of *seq_len* letters from AMINO_ACIDS.

        Letters are drawn uniformly with replacement using numpy's global RNG.
        """
        aas = np.asarray(list(AMINO_ACIDS))
        indices = np.random.choice(aas.shape[0], seq_len)
        return ''.join(aas[indices])

    @staticmethod
    def is_all_amino_acids(seq, allow_dummy=False):
        """True if every element of *seq* occurs in AMINO_ACIDS (plus DUMMY_AA if allowed).

        Membership is a substring test against the allowed-letter string.
        """
        allowed = AMINO_ACIDS + DUMMY_AA if allow_dummy else AMINO_ACIDS
        # Generator (not a list) so all() stops at the first invalid residue.
        return all(aa in allowed for aa in seq)

    @staticmethod
    def write_fa(fn, seqs, headers=None):
        """Write *seqs* to file *fn* in the FASTA text produced by ``format_fa``."""
        with open(fn, 'w') as fh:
            fh.write(SeqUtils.format_fa(seqs, headers))

    @staticmethod
    def format_fa(seqs, headers=None):
        """Format *seqs* as FASTA text: a ``>header`` line, then the sequence line.

        Headers default to 1-based record numbers. Records are joined by
        newlines with no trailing newline.

        Raises:
            ValueError: if *headers* is given and its length differs from that
                of *seqs* (pairing them would silently drop records).
        """
        seqs = list(seqs)
        if headers is None:
            headers = range(1, len(seqs) + 1)
        else:
            headers = list(headers)
            if len(headers) != len(seqs):
                raise ValueError('got %d headers for %d sequences' % (len(headers), len(seqs)))
        return '\n'.join('>%s\n%s' % (h, seq) for h, seq in zip(headers, seqs))
    
from IPython.core.debugger import Tracer

class FastaSeqParseListener(object):
    """Callback interface for FastaSeqParser; subclasses override on_seq_read."""

    def on_seq_read(self, header=None, seq=None):
        """Handle one parsed record; the base implementation ignores it."""
        return None
            
class FastaSeqParser(object):
    """Streaming FASTA parser that reports each record to registered listeners.

    A listener is any object with an ``on_seq_read(header=..., seq=...)``
    method, such as a FastaSeqParseListener subclass.
    """

    def __init__(self):
        self._listeners = []

    def add_parse_listener(self, listener=None):
        """Register *listener* to receive records from subsequent ``parse`` calls."""
        self._listeners.append(listener)

    def remove_parse_listener(self, listener=None):
        """Unregister *listener*; raises ValueError if it was never added."""
        self._listeners.remove(listener)

    def parse(self, in_stream, decode=None):
        """Parse FASTA records from *in_stream*, an iterable of lines.

        Each line is stripped and then, if *decode* is given, passed through
        it (e.g. to turn bytes into str). A line starting with '>' starts a new
        record whose header is the rest of the line; other lines are
        concatenated into the current sequence. A record is reported only if
        its sequence is non-empty, so empty input produces no callbacks.
        """
        header = None
        chunks = []
        for line in in_stream:
            line = line.strip()
            if decode is not None:
                line = decode(line)
            if line.startswith('>'):
                self._flush(header, chunks)
                header = line[1:]
                chunks = []
            else:
                chunks.append(line)

        self._flush(header, chunks)

    def _flush(self, header, chunks):
        """Report the record assembled from *chunks* if its sequence is non-empty."""
        # Joining once avoids quadratic string concatenation on long records.
        seq = ''.join(chunks)
        if seq:
            self._fire_seq_read(header=header, seq=seq)

    def _fire_seq_read(self, header=None, seq=None):
        """Deliver one record to every registered listener, in registration order."""
        for listener in self._listeners:
            listener.on_seq_read(header=header, seq=seq)
            
            
class TypeConvertUtils(object):
    """Loose conversions between common Python/numpy value types."""

    @staticmethod
    def to_boolean(x):
        """Interpret *x* as a boolean, or return None if it cannot be interpreted.

        Real numbers (including ``bool`` and numpy numeric/bool scalars) are
        True when ``x >= 1``. Strings (including ``numpy.str_``) are True when
        they equal 'TRUE' or 'T', case-insensitively. Anything else gives None.
        """
        import numbers
        if isinstance(x, (numbers.Real, np.bool_)):
            # bool() so numpy comparisons yield a plain Python bool.
            return bool(x >= 1)
        if isinstance(x, str):
            return x.upper() in ('TRUE', 'T')
        return None
    
class ArrayUtils(object):
    """Row-wise set operations on 2-D numpy arrays.

    Each row is treated as one set element. Results contain no duplicate rows
    and list rows in first-appearance order (rows of ``a1`` before rows of
    ``a2``), so output is deterministic. An empty result keeps its column
    count, i.e. has shape ``(0, a1.shape[1])``. If either input is not 2-D,
    ``None`` is returned.
    """

    @staticmethod
    def _both_2d(a1, a2):
        """True when both inputs are 2-D."""
        return len(a1.shape) == 2 and len(a2.shape) == 2

    @staticmethod
    def _unique_rows(*arrays):
        """Rows of *arrays* as tuples, de-duplicated, in first-seen order."""
        # dict keys preserve insertion order, so this acts as an ordered set.
        return list(dict.fromkeys(tuple(row) for a in arrays for row in a))

    @staticmethod
    def _to_array(rows, ncols, dtype):
        """Stack row tuples into an array; shape ``(0, ncols)`` when *rows* is empty."""
        if not rows:
            return np.empty((0, ncols), dtype=dtype)
        return np.array(rows)

    @staticmethod
    def intersect2d(a1, a2):
        """Rows present in both *a1* and *a2*, in *a1* order."""
        if not ArrayUtils._both_2d(a1, a2):
            return None
        other = set(ArrayUtils._unique_rows(a2))
        rows = [r for r in ArrayUtils._unique_rows(a1) if r in other]
        return ArrayUtils._to_array(rows, a1.shape[1], a1.dtype)

    @staticmethod
    def diff2d(a1, a2):
        """Rows of *a1* that do not occur in *a2*, in *a1* order."""
        if not ArrayUtils._both_2d(a1, a2):
            return None
        other = set(ArrayUtils._unique_rows(a2))
        rows = [r for r in ArrayUtils._unique_rows(a1) if r not in other]
        return ArrayUtils._to_array(rows, a1.shape[1], a1.dtype)

    @staticmethod
    def union2d(a1, a2):
        """Rows occurring in *a1* or *a2*: *a1* rows first, then new *a2* rows."""
        if not ArrayUtils._both_2d(a1, a2):
            return None
        rows = ArrayUtils._unique_rows(a1, a2)
        return ArrayUtils._to_array(rows, a1.shape[1], np.result_type(a1, a2))

class NumUtils(object):
    """Numeric helpers."""

    @staticmethod
    def padsize(t, s):
        """Split the shortfall between target size *t* and size *s* into padding.

        If ``t > s``, return ``(left, right)`` where ``left = floor((t - s) / 2)``
        and ``right = (t - s) - left`` (the right side gets the extra unit for
        odd gaps). Otherwise return ``(0, 0)``.
        """
        if t > s:
            gap = t - s
            # Floor division keeps integer arithmetic exact; int(gap / 2) goes
            # through a float and is wrong for gaps larger than 2**53.
            left = int(gap // 2)
            return (left, gap - left)
        return (0, 0)


from urllib import request
import ssl

class RemoteUtils(object):
    """Helpers for fetching remote resources with urllib."""

    # NOTE(security): this context skips TLS certificate and hostname
    # verification. It stays the default for backward compatibility; pass
    # verify=True to download_to to check certificates against system CAs.
    _ssl_context = ssl._create_unverified_context()

    @classmethod
    def download_to(cls, url, decode='utf-8', fnout=None, verify=False):
        """Download *url* and save its body as text in file *fnout*.

        The body is decoded with the *decode* codec and written back with the
        same encoding, so the saved file does not depend on the platform's
        default encoding. The output file is only created after the download
        and decode succeed.

        Args:
            url: URL passed to ``urllib.request.urlopen``.
            decode: codec used to decode the response and encode the file.
            fnout: output file path (required).
            verify: if true, verify the server's TLS certificate using
                ``ssl.create_default_context()``; otherwise use the class's
                unverified context.

        Raises:
            TypeError: if *fnout* is not given.
        """
        if fnout is None:
            raise TypeError('download_to() requires fnout (output file path)')
        context = ssl.create_default_context() if verify else cls._ssl_context
        with request.urlopen(url, context=context) as response:
            text = response.read().decode(decode)
        with open(fnout, 'w', encoding=decode) as fout:
            fout.write(text)
            

#################################
from unittest import *


class StrUtilsTest(TestCase):
    """Unit tests for StrUtils.empty."""

    def test_empty(self):
        # None and '' count as empty; a lone space does not.
        for blank in (None, ''):
            self.assertTrue(StrUtils.empty(blank))
        self.assertFalse(StrUtils.empty(' '))
        
class FileUtilsTest(TestCase):
    """Unit tests for FileUtils.list_files / rm_files."""

    def test_rm_files(self):
        # Use a private temporary directory instead of relying on a
        # pre-populated 'tmp/' folder existing next to the notebook.
        import tempfile
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ('extest1.txt', 'extest2.txt', 'keep.txt'):
                with open(os.path.join(tmpdir, name), 'w'):
                    pass
            fns = FileUtils.list_files(tmpdir, 'extest*')
            self.assertTrue(len(fns) > 0)
            FileUtils.rm_files(tmpdir, 'extest*')
            fns = FileUtils.list_files(tmpdir, 'extest*')
            self.assertEqual([], fns)
            self.assertEqual(['keep.txt'], os.listdir(tmpdir))

class SeqUtilsTest(TestCase):
    """Unit tests for SeqUtils."""

    def test_format_fa(self):
        seqs = ['AAA', 'BBB', 'CCC']
        headers = ['HA', 'HB', 'HC']
        expected_with_headers = '>HA\nAAA\n>HB\nBBB\n>HC\nCCC'
        expected_without_headers = '>1\nAAA\n>2\nBBB\n>3\nCCC'
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual(expected_with_headers, SeqUtils.format_fa(seqs=seqs, headers=headers))
        self.assertEqual(expected_without_headers, SeqUtils.format_fa(seqs=seqs))

    def test_is_all_amino_acids(self):
        self.assertTrue(SeqUtils.is_all_amino_acids('ACDEFRY'))
        self.assertFalse(SeqUtils.is_all_amino_acids('BQWPSMKYY'))
        self.assertFalse(SeqUtils.is_all_amino_acids('AQWPSMKYY-'))
        self.assertFalse(SeqUtils.is_all_amino_acids('ACDEFRY' + DUMMY_AA))
        self.assertTrue(SeqUtils.is_all_amino_acids('ACDEFRY' + DUMMY_AA, allow_dummy=True))

    def test_rand_seqs(self):
        seq = SeqUtils.rand_aaseq(seq_len=15)
        print(seq)
        # Check the generated sequence instead of only printing it.
        self.assertEqual(15, len(seq))
        self.assertTrue(SeqUtils.is_all_amino_acids(seq))
        seqs = SeqUtils.rand_aaseqs(N=3, seq_len=4)
        self.assertEqual(3, len(seqs))
        self.assertTrue(all(len(s) == 4 for s in seqs))

from io import StringIO

class FastaSeqParserTest(TestCase):
    """Round trip: format records with SeqUtils.format_fa, then parse them back."""

    class MyParserListener(FastaSeqParseListener):
        """Listener that echoes and records every parsed record."""

        def __init__(self):
            self.headers, self.seqs = [], []

        def on_seq_read(self, header=None, seq=None):
            print('Header:%s, Seq:%s' % (header, seq))
            self.headers.append(header)
            self.seqs.append(seq)

    def test_parse(self):
        expected_seqs = ['AAA', 'BBB', 'CCC']
        expected_headers = ['HA', 'HB', 'HC']
        text = SeqUtils.format_fa(seqs=expected_seqs, headers=expected_headers)

        recorder = FastaSeqParserTest.MyParserListener()
        fasta_parser = FastaSeqParser()
        fasta_parser.add_parse_listener(recorder)
        fasta_parser.parse(StringIO(text))

        self.assertEqual(expected_headers, recorder.headers)
        self.assertEqual(expected_seqs, recorder.seqs)
        
class TypeConvertUtilsTest(TestCase):
    """Unit tests for TypeConvertUtils.to_boolean."""

    def test_to_boolean(self):
        for truthy in (1, 2, 'True', 'T'):
            self.assertTrue(TypeConvertUtils.to_boolean(truthy))
        for falsy in (0, 'False', 'F'):
            self.assertFalse(TypeConvertUtils.to_boolean(falsy))

class ArrayUtilsTest(TestCase):
    """Unit tests for ArrayUtils row-wise set operations.

    Results are compared as sets of row tuples, so the tests do not depend on
    the order in which rows are returned.
    """

    @staticmethod
    def _row_set(arr):
        """Rows of a 2-D array as a set of tuples."""
        return {tuple(row) for row in arr}

    def test_intersect2d(self):
        a1 = np.array([[1, 2], [2, 3], [4, 5]])
        a2 = np.array([[1, 2], [4, 5]])
        self.assertEqual({(1, 2), (4, 5)}, self._row_set(ArrayUtils.intersect2d(a1, a2)))
        self.assertIsNone(ArrayUtils.intersect2d(np.array([1, 2]), a2))

    def test_diff2d(self):
        a1 = np.array([[1, 2], [2, 3], [4, 5]])
        a2 = np.array([[1, 2], [4, 5]])
        self.assertEqual({(2, 3)}, self._row_set(ArrayUtils.diff2d(a1, a2)))

    def test_union2d(self):
        a1 = np.array([[1, 2], [2, 3], [4, 5]])
        a2 = np.array([[1, 2], [4, 5]])
        self.assertEqual({(1, 2), (2, 3), (4, 5)}, self._row_set(ArrayUtils.union2d(a1, a2)))

class NumUtilsTest(TestCase):
    """Unit tests for NumUtils.padsize."""

    def test_padsize(self):
        # assertEqual: the assertEquals alias was removed in Python 3.12.
        self.assertEqual((3, 4), NumUtils.padsize(15, 8))
        self.assertEqual((2, 3), NumUtils.padsize(15, 10))
        self.assertEqual((11, 11), NumUtils.padsize(37, 15))
        self.assertEqual((7, 8), NumUtils.padsize(24, 9))
        self.assertEqual((0, 0), NumUtils.padsize(10, 15))
        self.assertEqual((0, 0), NumUtils.padsize(9, 9))
        self.assertEqual((0, 0), NumUtils.padsize(9, 10))
        self.assertEquals((0, 0), NumUtils.padsize(9, 10))

@skip('BindLevel/BindMetric/ic502prob helpers are commented out in this module; '
      're-enable this test together with them')
class BindMeasTest(TestCase):
    """Tests for the (currently commented-out) bind-level metrics.

    NOTE(review): with the commented-out implementation, is_binder() is
    ``level > BindLevel.NEGATIVE``, so an IC50 of 500 and a half-life of 120
    (both POSITIVE_INTERMEDIATE) would count as binders. That contradicts the
    assertFalse checks below; reconcile before removing the skip.
    """

    # Expected metrics (commented out above):
    # DEFALUT_IC50_BM = ContinuousBindMetric(cutoffs=[100, 500, 1000, 5000], comp_op=np.less)
    # DEFALUT_HALFALIVE_BM = ContinuousBindMetric(cutoffs=[240, 120, 50, 10], comp_op=np.greater)

    def test_bind_metric(self):
        ic50_bm = DEFALUT_IC50_BM
        halfalive_bm = DEFALUT_HALFALIVE_BM
        binary_bm = DEFALUT_BINARY_BM

        self.assertEqual(BindLevel.POSITIVE_HIGH, ic50_bm.level(99))
        self.assertEqual(BindLevel.POSITIVE, ic50_bm.level(100))
        self.assertEqual(BindLevel.POSITIVE_INTERMEDIATE, ic50_bm.level(500))
        self.assertEqual(BindLevel.POSITIVE_INTERMEDIATE, ic50_bm.level(999))
        self.assertEqual(BindLevel.POSITIVE_LOW, ic50_bm.level(1000))
        self.assertEqual(BindLevel.POSITIVE_LOW, ic50_bm.level(4999))
        self.assertEqual(BindLevel.NEGATIVE, ic50_bm.level(5000))

        self.assertEqual(BindLevel.POSITIVE_HIGH, halfalive_bm.level(241))
        self.assertEqual(BindLevel.POSITIVE, halfalive_bm.level(240))
        self.assertEqual(BindLevel.POSITIVE, halfalive_bm.level(121))
        self.assertEqual(BindLevel.POSITIVE_INTERMEDIATE, halfalive_bm.level(120))
        self.assertEqual(BindLevel.POSITIVE_INTERMEDIATE, halfalive_bm.level(51))
        self.assertEqual(BindLevel.POSITIVE_LOW, halfalive_bm.level(50))
        self.assertEqual(BindLevel.POSITIVE_LOW, halfalive_bm.level(11))
        self.assertEqual(BindLevel.NEGATIVE, halfalive_bm.level(10))

        self.assertEqual(BindLevel.POSITIVE, binary_bm.level(1))
        self.assertEqual(BindLevel.NEGATIVE, binary_bm.level(0))

        self.assertTrue(ic50_bm.is_binder(49))
        self.assertTrue(ic50_bm.is_binder(499))
        self.assertFalse(ic50_bm.is_binder(500))
        self.assertFalse(ic50_bm.is_binder(501))
        self.assertTrue(halfalive_bm.is_binder(241))
        self.assertTrue(halfalive_bm.is_binder(240))
        self.assertFalse(halfalive_bm.is_binder(120))
        self.assertFalse(halfalive_bm.is_binder(10))
        self.assertTrue(binary_bm.is_binder(1))
        self.assertFalse(binary_bm.is_binder(0))

    def test_ic502prob(self):
        x = 500
        p = ic502prob(x)
        self.assertAlmostEqual(x, prob2ic50(p), delta=EPS)

class RemoteUtilsTest(TestCase):
    """Network test for RemoteUtils.download_to (requires internet access)."""

    def test_download_to(self):
        # Download into a private temporary directory (cleaned up afterwards)
        # instead of requiring a pre-existing 'tmp/' folder.
        import tempfile
        url = 'https://www.ebi.ac.uk/ipd/mhc/group/BoLA/download/BoLA/NC2?type=protein'
        with tempfile.TemporaryDirectory() as tmpdir:
            fn_test = os.path.join(tmpdir, 'test.fa')
            RemoteUtils.download_to(url, fnout=fn_test)
            self.assertTrue(os.path.exists(fn_test))
            self.assertTrue(os.path.getsize(fn_test) > 0)
  
##################################
# suite = TestSuite()
# # suite.addTests(TestLoader().loadTestsFromTestCase(StrUtilsTest))
# # suite.addTests(TestLoader().loadTestsFromTestCase(BindMeasTest))
# # suite.addTests(TestLoader().loadTestsFromTestCase(RemoteUtilsTest))
# # suite.addTests(TestLoader().loadTestsFromTestCase(SeqUtilsTest))
# # suite.addTests(TestLoader().loadTestsFromTestCase(FastaSeqParserTest))
# # suite.addTests(TestLoader().loadTestsFromTestCase(NumUtilsTest))
# TextTestRunner(verbosity=3).run(suite)

ok

----------------------------------------------------------------------
Ran 1 test in 0.002s

OK


<unittest.runner.TextTestResult run=1 errors=0 failures=0>