# X2K_Web Genetic Algorithm

## X2K Web API

In [None]:
import pandas as pd
import http.client
import json

# Search space of the genetic algorithm. Iteration order of this dict defines
# the layout of the binary string decoded by binary_to_parameters().
#   * list-valued entries: one bit per database, in list order; a "1" bit
#     selects the database at that position (see translateDatabases()).
#   * dict-valued entries: the key is the bit pattern and the value is the
#     setting it decodes to.
all_x2k_options = {
    'TF-target gene background database used for enrichment': [
        'ChEA 2015',
        'ENCODE 2015',
        'ChEA & ENCODE Consensus',
        'Transfac and Jaspar',
        'ChEA 2016',
        'ARCHS4 TFs Coexp',
        'CREEDS',
        'Enrichr Submissions TF-Gene Coocurrence',
    ],
    'kinase interactions to include': [#kea 2016
        'kea 2018',
        'ARCHS4',
        'iPTMnet',
        'NetworkIN',
        'Phospho.ELM',
        'Phosphopoint',
        'PhosphoPlus',
        'MINT',
    ],
    # Each selected entry here becomes an 'enable_<name>' option in
    # prepare_options_for_x2k().
    'enable_ppi': [
        'ppid',
        'Stelzl',
        'IntAct',
        'MINT',
        'BioGRID',
        'Biocarta',
        'BioPlex',
        'DIP',
        'huMAP',
        'InnateDB',
        'KEGG',
        'SNAVI',
        'iREF',
        'vidal',
        'BIND',
        'figeys',
        'HPRD',
    ],
    # Two-bit settings (four choices each).
    'max_number_of_interactions_per_article':  {"10":15, "01":50, "11":200, "00":1000000},
    'max_number_of_interactions_per_protein': {"10":50, "01":100, "11":200, "00":500},
    'min_network_size': {"10":1, "01":10, "11":50, "00":100},
    'min_number_of_articles_supporting_interaction': {"10":0, "01":1, "11":5, "00":10},
    # One-bit setting (two choices).
    'path_length': {"0":1, "1":2},
    # "RESHUFFLE" is a sentinel: reshuffle() replaces it with a random draw
    # from this dict's values until a real value comes up.
    'included organisms in the background database': {"10": "human", "01": "mouse", "11": "both", "00": "RESHUFFLE"},
}

def run_X2K(input_genes, x2k_options=None):
    """Submit a gene list to the X2K web API and return the cleaned results.

    Parameters
    ----------
    input_genes : iterable of str
        Gene symbols; joined with newlines into the ``text-genes`` form field.
    x2k_options : dict, optional
        Overrides for the default request options. Only keys that already
        exist in the defaults are applied, and ``text-genes`` can never be
        overridden. Any other key is silently ignored.
        NOTE(review): options produced by prepare_options_for_x2k() use names
        such as 'enable_<db>' and 'min_network_size' that are not default
        keys here, so they are dropped -- confirm against the X2K API.

    Returns
    -------
    dict
        The decoded JSON response with 'ChEA', 'G2N', 'KEA' and 'X2K'
        reduced to their tfs / network nodes / kinases / network payloads.

    Raises
    ------
    Whatever ``http.client`` or ``json`` raise on connection or decoding
    failures, and ``KeyError`` if the response lacks an expected section.
    """
    # Default options
    default_options = {'text-genes': '\n'.join(input_genes), 'included_organisms': 'both', 'included_database': 'ChEA 2015',
                       'path_length': 2, 'minimum network size': 50, 'min_number_of_articles_supporting_interaction': 2,
                       'max_number_of_interactions_per_protein': 200, 'max_number_of_interactions_per_article': 100,
                       'biocarta': True, 'biogrid': True, 'dip': True, 'innatedb': True, 'intact': True, 'kegg': True, 'mint': True,
                       'ppid': True, 'snavi': True, 'number_of_results': 50, 'sort_tfs_by': 'combined score', 'sort_kinases_by': 'combined score',
                       'kinase interactions to include': 'kea 2018'}
    # Apply overrides for known keys only
    if x2k_options:
        for key, value in x2k_options.items():
            if key in default_options and key != 'text-genes':
                default_options[key] = value
    # Build the multipart/form-data payload: one part per option
    boundary = "----WebKitFormBoundary7MA4YWxkTrZu0gW"
    part_template = '--' + boundary + '\r\nContent-Disposition: form-data; name="{key}"\r\n\r\n{value}\r\n'
    payload = ''.join(part_template.format(key=key, value=value)
                      for key, value in default_options.items()) + '--' + boundary + '--'
    headers = {
        'content-type': "multipart/form-data; boundary=" + boundary,
        'cache-control': "no-cache",
    }
    # Open the connection and always release the socket, even on failure
    conn = http.client.HTTPConnection("amp.pharm.mssm.edu")
    try:
        conn.request("POST", "/X2K/api", payload, headers)
        res = conn.getresponse()
        data = res.read().decode('utf-8')
    finally:
        conn.close()
    # Every top-level value except 'input' is itself a JSON-encoded string
    x2k_results = {key: json.loads(value) if key != 'input' else value for key, value in json.loads(data).items()}
    # Keep only the payload of each section
    x2k_results['ChEA'] = x2k_results['ChEA']['tfs']
    x2k_results['G2N'] = x2k_results['G2N']['network']['nodes']
    x2k_results['KEA'] = x2k_results['KEA']['kinases']
    x2k_results['X2K'] = x2k_results['X2K']['network']
    return x2k_results


In [None]:
def parse_GEO_line(line):
    """Parse one tab-separated GMT line into ``(experiment name, genes)``.

    Field 0 is the experiment name and field 1 is skipped. The remaining
    fields are gene entries, optionally followed by a ``,weight`` suffix
    (e.g. ``AKT1,1.0``); only the gene symbol is kept.

    The final field is skipped, as before -- presumably it is the newline or
    empty field after a trailing tab; verify against the GMT files in use.
    """
    fields = line.split('\t')
    expt_name = fields[0]
    # Cut at the first comma rather than str.strip(',1.0'): strip() removes
    # any of the characters ',', '1', '.', '0' from both ends and therefore
    # mangled symbols such as 'AKT1' -> 'AKT' or 'CDK10' -> 'CDK'.
    genes = [entry.split(',')[0].strip() for entry in fields[2:-1]]
    return expt_name, genes

def prepare_options_for_x2k(input_genes, x2k_parameters):
    """Flatten decoded GA parameters into a string-valued option dict.

    Each parameter is replaced by its 'selection', the gene list is stored
    under 'text-genes', and the 'enable_ppi' list is expanded into one
    'enable_<name>' = 'true' entry per selected database. Lists are joined
    with newlines; everything else is converted with str(). The input
    mapping is left untouched.
    """
    # Keep only the chosen value of every parameter
    flat = {name: entry['selection'] for name, entry in x2k_parameters.items()}
    flat['text-genes'] = input_genes
    # One enable flag per selected PPI database, then drop the list itself
    flat.update({'enable_' + db: 'true' for db in flat['enable_ppi']})
    del flat['enable_ppi']
    # Render every value as a string
    rendered = {}
    for name, val in flat.items():
        if type(val) is list:
            rendered[name] = '\n'.join(val)
        else:
            rendered[name] = str(val)
    return rendered

def reshuffle(x2k_parameters, options=None):
    """Resolve every "RESHUFFLE" selection to a randomly chosen real value.

    Parameters
    ----------
    x2k_parameters : dict
        Maps parameter name -> {'selection': ..., 'bits': ...}.
    options : dict, optional
        Search space providing the bit-pattern -> value mapping for each
        parameter. Defaults to the module-level ``all_x2k_options``.

    Returns
    -------
    dict
        A copy of ``x2k_parameters`` in which each reshuffled entry has both
        its 'selection' and its 'bits' replaced, so that re-encoding the
        parameters yields a string that decodes to the value actually used.
        The input and its nested dicts are not modified.

    Raises
    ------
    ValueError
        If a parameter to reshuffle offers no value other than "RESHUFFLE".
    """
    import random
    if options is None:
        options = all_x2k_options
    # Copy one level deep so the caller's entries are not mutated
    new_options = {param: dict(entry) for param, entry in x2k_parameters.items()}
    for param, entry in new_options.items():
        if entry['selection'] != "RESHUFFLE":
            continue
        candidates = [(bits, value) for bits, value in options[param].items() if value != "RESHUFFLE"]
        if not candidates:
            raise ValueError("No concrete value available to reshuffle '" + str(param) + "'")
        # Keep bits and selection in sync
        entry['bits'], entry['selection'] = random.choice(candidates)
    return new_options

def translateDatabases(binaryString_segment, _dbs):
    """Return the entries of ``_dbs`` whose position holds a "1" bit.

    Bit i of ``binaryString_segment`` switches ``_dbs[i]`` on; any character
    other than "1" leaves it off.
    """
    return [_dbs[position]
            for position, flag in enumerate(binaryString_segment)
            if flag == "1"]

def parameters_to_binary(x2k_parameters):
    """Concatenate the 'bits' of every parameter into one binary string.

    Parameters are visited in the key order of ``all_x2k_options``.
    """
    return ''.join(x2k_parameters[name]['bits'] for name in all_x2k_options)

def binary_to_parameters(binaryString):
    """Decode a binary string into X2K parameters.

    The string is consumed left to right in the key order of
    ``all_x2k_options``: a list-valued option takes one bit per database,
    a dict-valued option takes as many bits as its keys are long.

    Returns
    -------
    tuple
        ``(x2k_parameters, newBinary)`` where ``x2k_parameters`` maps each
        option to {'selection': ..., 'bits': ...} after reshuffle(), and
        ``newBinary`` is those parameters re-encoded by
        parameters_to_binary().

    Raises
    ------
    KeyError
        If the bits read for a dict-valued option match none of its keys
        (for example when ``binaryString`` is too short).
    """
    x2k_parameters = {}
    cursor = 0
    for param in all_x2k_options:
        choices = all_x2k_options[param]
        if isinstance(choices, list):
            # Database lists: one bit per database
            width = len(choices)
            bits = binaryString[cursor:cursor + width]
            selection = translateDatabases(bits, choices)
        else:
            # All other parameters: fixed-width bit pattern used as dict key
            width = len(next(iter(choices)))
            bits = binaryString[cursor:cursor + width]
            selection = choices[bits]
        x2k_parameters[param] = {'selection': selection, 'bits': bits}
        # Advance by the segment width. Advancing by the number of selected
        # databases misaligned every subsequent parameter.
        cursor += width
    # Reshuffle
    x2k_parameters = reshuffle(x2k_parameters)
    newBinary = parameters_to_binary(x2k_parameters)
    return x2k_parameters, newBinary
    
 
############ Parallel processing of X2K ############
def run_X2K_once(x2k_input):
    """Run X2K for a single experiment; used as the thread-pool worker.

    Parameters
    ----------
    x2k_input : dict
        {'options': option dict containing newline-joined 'text-genes',
         'expt_name': experiment label}.

    Returns
    -------
    dict
        {'experiment': name, 'results': X2K results}; 'results' is the
        string 'NA' when the request or its parsing failed.
    """
    x2k_options = x2k_input['options']
    input_genes = x2k_options['text-genes'].split('\n')
    expt_name = x2k_input['expt_name']
    try:
        x2k_results = run_X2K(input_genes=input_genes, x2k_options=x2k_options)
    except Exception:
        # Best effort: one failed experiment must not abort the batch.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        x2k_results = 'NA'
    return {'experiment': expt_name, 'results': x2k_results}

def prepare_all_x2k_inputs(binaryString, gmtLines):
    """Build one X2K input per GMT line for the given binary string.

    The binary string is decoded separately for every line, so any
    reshuffled parameter is drawn again for each experiment.

    Returns
    -------
    tuple
        ``(all_x2k_inputs, newBinary)``: a list of
        {'options': ..., 'expt_name': ...} dicts, and the re-encoded binary
        string from the last decoded line. With no lines, ``newBinary`` is
        ``binaryString`` itself.
    """
    all_x2k_inputs = []
    # Fallback so an empty GMT does not leave newBinary unbound
    newBinary = binaryString
    for line in gmtLines:
        expt_name, input_genes = parse_GEO_line(line)
        x2k_parameters, newBinary = binary_to_parameters(binaryString)
        x2k_options = prepare_options_for_x2k(input_genes, x2k_parameters)
        all_x2k_inputs.append({'options': x2k_options, 'expt_name': expt_name})
    return all_x2k_inputs, newBinary

import os
def parallel_x2k_results(binaryString, gmtLimit=False):
    """Run X2K for every experiment of the test GMT using a thread pool.

    Parameters
    ----------
    binaryString : str
        Encoded X2K parameters.
    gmtLimit : int or False
        If not False (and not 0), only the first ``gmtLimit`` lines are used.

    Returns
    -------
    tuple
        ``(results, newBinary)`` where ``results`` maps experiment name to
        its X2K results. Experiments whose request failed are left out, as
        in get_x2k_results().
    """
    from multiprocessing.dummy import Pool as ThreadPool
    # The first file found in the test directory is the GMT to process
    with open('Genetic_Algorithm/testgmt/'+os.listdir('Genetic_Algorithm/testgmt')[0]) as gmt_file:
        gmtLines = gmt_file.readlines()
    if gmtLimit != False:
        gmtLines = gmtLines[0:gmtLimit]
    # Parameters are decoded once per experiment
    all_x2k_inputs, newBinary = prepare_all_x2k_inputs(binaryString, gmtLines)

    # Requests are I/O bound, so threads are enough to overlap them
    pool = ThreadPool(20)
    try:
        raw_x2k_results = pool.map(run_X2K_once, all_x2k_inputs)
    finally:
        pool.close()
        pool.join()

    # Collect per-experiment results, dropping failed runs
    allExpts_x2k_results = {}
    for outcome in raw_x2k_results:
        if isinstance(outcome['results'], str):
            # run_X2K_once() reports a failure as the string 'NA'
            continue
        allExpts_x2k_results[outcome['experiment']] = outcome['results']
    # The populated dict is returned. The previous return value was a
    # separate dict that was never filled.
    return allExpts_x2k_results, newBinary


def get_x2k_results(binaryString, gmtLimit=False):
    """Run X2K sequentially for every experiment of the test GMT.

    Parameters
    ----------
    binaryString : str
        Encoded X2K parameters.
    gmtLimit : int or False
        If not False (and not 0), only the first ``gmtLimit`` lines are used.

    Returns
    -------
    tuple
        ``(results, newBinary)`` where ``results`` maps experiment name to
        its X2K results plus the 'x2k_options' that were sent. Experiments
        whose request failed are skipped. With no GMT lines, ``newBinary``
        is ``binaryString`` itself.
    """
    import os
    allExpts_x2k_results = {}
    # Fallback so an empty GMT does not leave newBinary unbound
    newBinary = binaryString

    # The first file found in the test directory is the GMT to process
    with open('Genetic_Algorithm/testgmt/'+os.listdir('Genetic_Algorithm/testgmt')[0]) as gmt_file:
        gmt = gmt_file.readlines()
    if gmtLimit != False:
        gmt = gmt[0:gmtLimit]
    for line in gmt:
        # Get experiment name and input genes
        expt_name, input_genes = parse_GEO_line(line)
        # Decode options (reshuffled parameters are drawn again per line)
        x2k_parameters, newBinary = binary_to_parameters(binaryString=binaryString)
        x2k_options = prepare_options_for_x2k(input_genes, x2k_parameters)
        # Run x2k API
        try:
            x2k_results = run_X2K(input_genes=input_genes, x2k_options=x2k_options)
        except Exception:
            # Best effort: skip experiments the API could not process
            continue
        x2k_results['x2k_options'] = x2k_options
        allExpts_x2k_results[expt_name] = x2k_results
    return allExpts_x2k_results, newBinary

## Run X2K GA

### GA Support Functions

In [None]:
###################################
# 0. Create initial population
###################################
def stringLength(options=None):
    """Return the number of bits needed to encode every option.

    A list-valued option needs one bit per entry. A dict-valued option needs
    as many bits as its keys are long (its keys are the bit patterns), not
    one bit per choice.

    Parameters
    ----------
    options : dict, optional
        Search space to measure; defaults to the module-level
        ``all_x2k_options``.
    """
    if options is None:
        options = all_x2k_options
    string_length = 0
    for choices in options.values():
        if isinstance(choices, dict):
            string_length += len(next(iter(choices), ''))
        else:
            string_length += len(choices)
    return string_length

def createPopulation(popSize):
    """Create ``popSize`` random binary strings and print each one.

    Every individual is stringLength() characters long, each character
    drawn independently from '0' and '1'.
    """
    from random import choice
    n_bits = stringLength()
    individuals = []
    for _ in range(popSize):
        genome = ''.join(choice(('0', '1')) for _ in range(n_bits))
        individuals.append(genome)
        print(genome)
    return individuals

###################################
# 1. Calculate fitness
###################################
def pvalue_matrix(all_x2k_results, dataType='KEA'):
    """Collect predicted p-values per experiment.

    Parameters
    ----------
    all_x2k_results : dict
        Experiment name -> X2K results holding a list of result entries
        under ``dataType``.
    dataType : {'KEA', 'ChEA', 'G2N'}
        Which result section to read. The entry name is taken from 'name'
        (KEA, G2N) or 'simpleName' (ChEA); G2N names are cut at the first
        '-'.

    Returns
    -------
    dict
        Experiment name -> {entry name: p-value}. If a name occurs more
        than once in an experiment, the last p-value wins.

    The input results are not modified.
    """
    nameKey = {'ChEA': 'simpleName', 'KEA': 'name', 'G2N': 'name'}
    name_field = nameKey[dataType]
    pvalDict = {}
    for expt, expt_results in all_x2k_results.items():
        per_expt = {}
        for entry in expt_results[dataType]:
            name = entry[name_field]
            if dataType == 'G2N':
                name = name.split("-")[0]
            per_expt[name] = entry['pvalue']
        pvalDict[expt] = per_expt
    return pvalDict


def population_fitness(gen, population, fitness_method='target_shuffled_difference', fitnessDict=None, gmtLimit=False, parallel=True):
    """Evaluate the fitness of every individual in a population.

    Parameters
    ----------
    gen : int
        Generation number, used in the result ids ('ind<i>_gen<gen>').
    population : iterable of str
        Binary strings to evaluate.
    fitness_method : str or callable
        Fitness function, or the name of one. It is called with a DataFrame
        built from the KEA p-values.
    fitnessDict : dict, optional
        Cache of previous results keyed by binary string. It is copied, not
        modified.
    gmtLimit : int or False
        Forwarded to the X2K runner.
    parallel : bool
        Use parallel_x2k_results() when true, get_x2k_results() otherwise.

    Returns
    -------
    tuple
        ``(population_results, newFitnessDict)``: results keyed by id, and
        the cache extended with every newly evaluated binary string.
    """
    population_results = {}
    newFitnessDict = {} if fitnessDict is None else fitnessDict.copy()
    fitness_fn = None
    for i, binaryString in enumerate(population):
        unique_id = "ind"+str(i)+"_gen"+str(gen)
        print(unique_id)
        if binaryString in newFitnessDict:
            print("Pulling info from fitnessDict")
            population_results[unique_id] = newFitnessDict[binaryString]
            continue
        if fitness_fn is None:
            # NOTE(security): a string is resolved with eval(), which runs
            # arbitrary code. Pass a callable, or only trusted names.
            fitness_fn = fitness_method if callable(fitness_method) else eval(fitness_method)
        if parallel:
            all_x2k_results, newBinary = parallel_x2k_results(binaryString=binaryString, gmtLimit=gmtLimit)
        else:
            all_x2k_results, newBinary = get_x2k_results(binaryString=binaryString, gmtLimit=gmtLimit)

        CHEA_pvalDict = pvalue_matrix(all_x2k_results, 'ChEA')
        KEA_pvalDict = pvalue_matrix(all_x2k_results, 'KEA')
        population_results[unique_id] = {'generation': gen, 'binaryString': binaryString, 'newBinary': newBinary,
                                         'fitness': fitness_fn(pd.DataFrame(KEA_pvalDict)),
                                         'KEA_results': KEA_pvalDict, 'CHEA_results': CHEA_pvalDict}
        # Add to fitness dictionary
        newFitnessDict[binaryString] = population_results[unique_id]
    return population_results, newFitnessDict

###################################
# 2. Select fittest individuals
###################################
def _tournament_winners(candidates, subsetSize, rounds):
    """Return the fittest row of each of ``rounds`` random subsets."""
    import pandas as pd
    winners = []
    for _ in range(rounds):
        subDF = candidates.sample(n=subsetSize, replace=False)
        winners.append(subDF.sort_values(by=['fitness'], ascending=False).iloc[0, :].copy())
    return pd.DataFrame(winners)


def selectFittest(topNum, GA_results, selectionMethod='Fitness-proportional'):
    """Select ``topNum`` individuals from one generation's results.

    Parameters
    ----------
    topNum : int
        Number of individuals to select.
    GA_results : dict
        Result id -> result dict containing at least 'fitness'.
    selectionMethod : str
        'Fitness-proportional': the ``topNum`` highest fitnesses.
        'Tournament': ``topNum`` times, draw a random subset of size
        population/topNum and keep its fittest member (subsets are drawn
        independently, so one individual can win more than once).
        'mixedTournament': the fittest ``topNum``/2 individuals, plus
        ``topNum``/2 tournament winners drawn from everybody else.

    Returns
    -------
    pandas.DataFrame
        One row per selected individual, indexed by result id.

    Raises
    ------
    ValueError
        If ``selectionMethod`` is not one of the supported names.
    """
    import pandas as pd
    fitDF = pd.DataFrame(GA_results).T
    if selectionMethod == 'Fitness-proportional':
        fittestDF = fitDF.sort_values(by=['fitness'], ascending=False).iloc[:topNum, :]
        print("Top fitnesses:  " + str(fittestDF["fitness"].values))
    elif selectionMethod == 'Tournament':
        if fitDF.shape[0] % topNum != 0:
            print("Tournament selection requires that populationSize/topNum and childrenPerGeneration/topNum are both whole numbers.")
        subsetSize = int(fitDF.shape[0] / topNum)
        fittestDF = _tournament_winners(fitDF, subsetSize, topNum)
    elif selectionMethod == 'mixedTournament':
        if fitDF.shape[0] % topNum != 0 or topNum % 2 != 0:
            print("WARNING:: Tournament selection requires that populationSize/topNum, \n"
                  + "childrenPerGeneration/topNum, and topNum/2 to be whole numbers.")
        topNumHalf = int(topNum / 2)
        sortedDF = fitDF.sort_values(by=['fitness'], ascending=False).copy()
        # The first half of the new pop are the fittest parents overall
        elite = sortedDF.iloc[:topNumHalf, :].copy()
        # The other half are tournament winners among the rest
        everybodyElse = sortedDF.iloc[topNumHalf:, :].copy()
        subsetSize = int(everybodyElse.shape[0] / topNumHalf)
        # pd.concat replaces DataFrame.append, which pandas 2.0 removed
        fittestDF = pd.concat([elite, _tournament_winners(everybodyElse, subsetSize, topNumHalf)])
    else:
        raise ValueError("Use viable 'selectionMethod'")
    return fittestDF

###################################
# 3. Crossover/breed fittest
###################################
###################################
# 4. Introduce random mutations
###################################
# individual1=  population[0]
# individual2=  population[1]
# crossoverPoints=8
def _split_at(parent, cutpoints):
    """Split ``parent`` into consecutive segments at the sorted cutpoints."""
    bounds = [0] + list(cutpoints) + [len(parent)]
    return [parent[start:end] for start, end in zip(bounds, bounds[1:])]


def createChild(individual1, individual2, crossoverPoints, crossoverLocations="evenlyDistributed"):
    """Breed one child by mixing segments of two parent strings.

    Both parents are cut into segments at the same positions; each segment
    of the child is then copied from one parent or the other with equal
    probability.

    Parameters
    ----------
    individual1, individual2 : str
        Parent binary strings.
    crossoverPoints : int
        Number of cut positions. For "evenlyDistributed" it sets the chunk
        size len(individual1) // (crossoverPoints + 1), at least 1.
    crossoverLocations : {"evenlyDistributed", "random"}
        How the cut positions are chosen. "random" samples distinct
        positions from range(1, len(individual1) - 1).

    Raises
    ------
    ValueError
        If ``crossoverLocations`` is not a supported mode, or (from
        random.sample) if more random cutpoints are requested than there
        are positions.
    """
    from random import random, sample

    if crossoverLocations == "evenlyDistributed":
        # A chunk of at least one character keeps range() from getting a
        # zero step when the parent is shorter than crossoverPoints + 1
        chunkSize = max(1, len(individual1) // (crossoverPoints + 1))
        ind1Split = [individual1[i:i + chunkSize] for i in range(0, len(individual1), chunkSize)]
        ind2Split = [individual2[i:i + chunkSize] for i in range(0, len(individual2), chunkSize)]
    elif crossoverLocations == 'random':
        # Randomly generate n non-overlapping cut positions
        cutpoints = sorted(sample(range(1, len(individual1) - 1), crossoverPoints))
        ind1Split = _split_at(individual1, cutpoints)
        ind2Split = _split_at(individual2, cutpoints)
    else:
        raise ValueError("Unknown crossoverLocations: " + repr(crossoverLocations))

    # Randomly pick each fragment from parent A or parent B
    childFragments = []
    for position, fragment in enumerate(ind1Split):
        if int(100 * random()) < 50:
            childFragments.append(fragment)
        else:
            childFragments.append(ind2Split[position])
    return "".join(childFragments)
# child = createChild( fittestDF['newBinary'].values[0] , fittestDF['newBinary'].values[1], 3)

def mutateChild(child, mutationRate):
    """Flip each bit of ``child`` with probability ``mutationRate``.

    One random number is drawn per character. A '0' or '1' is inverted when
    the draw is <= ``mutationRate``; any other character is copied as is.
    """
    from random import random
    inverse = {'1': '0', '0': '1'}
    pieces = []
    for symbol in child:
        draw = random()
        if draw <= mutationRate and symbol in inverse:
            pieces.append(inverse[symbol])
        else:
            pieces.append(str(symbol))
    return ''.join(pieces)

def createChildren(numberOfChildren, fittestDF, mutationRate, breedingVariation, crossoverLocations):
    """Breed and mutate ``numberOfChildren`` children from the selected parents.

    For each child two parents are drawn uniformly at random from the
    'newBinary' column of ``fittestDF`` (the same parent may be drawn
    twice), crossed with 3 crossover points via createChild() and mutated
    via mutateChild(). ``breedingVariation`` is accepted but not used.
    """
    from random import random
    breeders = fittestDF.newBinary.tolist()
    n_breeders = len(breeders)
    offspring = []
    for _ in range(numberOfChildren):
        # Once selected, every parent is equally likely to breed
        first = int(random() * n_breeders)
        second = int(random() * n_breeders)
        newborn = createChild(breeders[first], breeders[second], 3, crossoverLocations)
        offspring.append(mutateChild(newborn, mutationRate))
    return offspring
# createChildren(100, fitDF, .01)

### Fitness Functions

In [None]:
############################
# Fitness Support Functions
############################
def values_to_ranks(DF, ascending=False):
    """Convert each column of values into 0-based ranks.

    Parameters
    ----------
    DF : pandas.DataFrame
        Values to rank, one column at a time.
    ascending : bool
        Sort direction used before ranks are assigned; with the default
        (False) the largest value of a column gets rank 0.

    Returns
    -------
    pandas.DataFrame
        Same row labels and columns as ``DF``, holding the ranks. Entries
        equal to 0 are placed after all other entries, in random order
        among themselves.
    """
    Ranks = {}
    for col in DF:
        orderedCol = DF[col].sort_values(ascending=ascending)
        # Zeros go last and are shuffled so ties are not ordered by label
        nonZeros = orderedCol.loc[orderedCol != 0]
        try:
            zeros = orderedCol.loc[orderedCol == 0].sample(frac=1)
        except Exception:
            # Fall back to "no zeros" if sampling fails; Exception rather
            # than a bare except so interrupts still propagate
            zeros = pd.Series(dtype=float)
        shuffledCol = pd.concat([nonZeros, zeros])
        # Assign ranks by position, then restore label order
        newRanks = pd.Series(data=range(0, len(shuffledCol)), name=col, index=shuffledCol.index)
        newRanks.sort_index(inplace=True)
        Ranks[col] = dict(zip(newRanks.index, newRanks.values))
    return pd.DataFrame(Ranks)

import math
def values_to_zscores(DF, dropZeros=True):
    """Convert each column to z-scores (population standard deviation).

    Parameters
    ----------
    DF : pandas.DataFrame
        Values to standardise, one column at a time.
    dropZeros : bool
        When True, rows that are missing (NaN) in every column are removed
        first. The previous test ``!= math.nan`` is true for every value,
        NaN included, so it never removed anything.

    Returns
    -------
    pandas.DataFrame
        (value - column mean) / column std with ddof=0. Missing values are
        ignored by the mean and std and stay missing in the result.
    """
    if dropZeros:
        df = DF.dropna(how='all')
    else:
        df = DF.copy()
    zScore_dict = {}
    for col in df:
        zScore_dict[col] = (df[col] - df[col].mean()) / df[col].std(ddof=0)
    return pd.DataFrame(zScore_dict)

def scaled_ranks(DFstack):
    """Return a copy of ``DFstack`` with 'Rank' min-max scaled.

    The smallest rank is subtracted and the result divided by the largest
    remaining value, so the column runs from 0 to 1. The input frame is
    not modified.
    """
    scaledDF = DFstack.copy()
    shifted = scaledDF['Rank'] - scaledDF['Rank'].min()
    scaledDF['Rank'] = shifted / shifted.max()
    return scaledDF

def clearTestGMT():
    """Delete every .txt and .gmt file from Genetic_Algorithm/testgmt/."""
    import os
    dir_name = "Genetic_Algorithm/testgmt/"
    for entry in os.listdir(dir_name):
        if entry.endswith((".txt", ".gmt")):
            os.remove(os.path.join(dir_name, entry))


    
############################
# Fitness Functions
############################
def target_shuffled_difference(pval_matrix, zscore=True, scaledRanks=True):
    """Fitness: summed score of randomly drawn rows minus that of target rows.

    ``pval_matrix`` is converted with values_to_zscores() (or
    values_to_ranks() when ``zscore`` is false), stacked into long form with
    columns Kinase / Experiment / Rank, and optionally rescaled with
    scaled_ranks(). The result is random: the comparison rows are drawn
    with DataFrame.sample().

    Assumes the row labels of ``pval_matrix`` are kinases and the column
    labels are experiments named '<kinase>_...' -- TODO confirm against the
    GMT experiment names.
    """
    # Select ranking method 
    if zscore==True:
        DF = values_to_zscores(pval_matrix)
    else:
        DF = values_to_ranks(pval_matrix)
    # Long format: one row per (row label, column label) pair; the value
    # column is called 'Rank' whichever conversion was used
    DFstack = DF.stack().reset_index()
    DFstack.columns = ['Kinase','Experiment','Rank']
    if scaledRanks==True:
            DFstack = scaled_ranks(DFstack)
    # Target rows: the kinase equals the text before the first '_' of the
    # experiment name
    DFstack_target = DFstack.loc[DFstack['Kinase']==DFstack['Experiment'].str.split('_').str[0]]
    # Shuffled rows: keep the target kinases but overwrite experiment and
    # rank with those of as many rows sampled at random from the whole table
    DFstack_shuffled = DFstack_target.copy()
    DFstack_shuffled.loc[:,['Experiment','Rank']] = DFstack.sample(n=len(DFstack_target)).loc[:,['Experiment','Rank']].values
    # NOTE(review): merging on 'Kinase' alone pairs every target row with
    # every shuffled row of the same kinase, so kinases with several target
    # rows are counted more than once -- confirm this is intended
    DFmerged = pd.merge(DFstack_target, DFstack_shuffled, on='Kinase',suffixes=['_target', '_shuffled'])
    # Positive when the target rows score lower than the random rows
    fitness = DFmerged['Rank_shuffled'].sum() - DFmerged['Rank_target'].sum()
    return fitness

### GA Function

In [None]:

def X2K_Web_GA(GMT, gmtLimit, initial_pop_size=5, generations=5, select_fittest=5, selection_method='Fitness-proportional',
               fitness_method='target_shuffled_difference', children_per_generation=8, mutation_rate=.01, breeding_variation=0,
               crossover_locations='random', include_fittest_parents=2, parallel=True, save_results='No'):
    """Run the genetic algorithm over X2K parameters for one GMT file.

    The GMT file is copied into Genetic_Algorithm/testgmt/ (after that
    directory is cleared). Each generation is scored, the fittest
    individuals are selected, and the next population is built from their
    mutated children plus, optionally, the first ``include_fittest_parents``
    selected parents.

    Parameters
    ----------
    GMT : str
        Path of the GMT file to optimise against.
    gmtLimit : int or False
        Limit on the number of GMT lines used per evaluation.
    save_results : str
        'No' to skip saving; otherwise the results are pickled to
        Genetic_Algorithm/GA_Results/<save_results>.pkl.
    Other parameters are forwarded to createPopulation(),
    population_fitness(), selectFittest() and createChildren().

    Returns
    -------
    dict
        {'all_GA_results': results of every evaluated individual keyed by
         'ind<i>_gen<g>', 'GA_settings': subset of the settings used}.
    """
    # Prepare GMT input
    from shutil import copyfile
    clearTestGMT()
    copyfile(GMT, "Genetic_Algorithm/testgmt/" + GMT.split("/")[-1])
    # Store GA settings
    GA_settings = {'gmt_file': GMT.split("/")[-1], 'initial_pop_size': initial_pop_size, 'generations': generations,
                   'select_fittest': select_fittest, 'selection_method': selection_method}
    # Results dicts
    all_GA_results = {}
    fitnessDict = {}
    # 0. Create initial population
    population = createPopulation(initial_pop_size)
    # Loop over n generations
    for gen in range(generations):
        print('================ GENERATION '+str(gen)+' ================')
        # 1. Get all fitnesses (fitnessDict caches strings already scored)
        pop_fitness_results, fitnessDict = population_fitness(gen, population, fitness_method,
                                                              fitnessDict=fitnessDict, gmtLimit=gmtLimit, parallel=parallel)
        all_GA_results.update(pop_fitness_results)
        # 2. Select fittest
        fitnessDF = selectFittest(topNum=select_fittest, GA_results=pop_fitness_results, selectionMethod=selection_method)
        # 3. Create/mutate children
        population = createChildren(children_per_generation, fitnessDF, mutation_rate, breeding_variation, crossover_locations)
        if include_fittest_parents > 0:
            # Carry over the first selected parents unchanged (with
            # mixedTournament these are parents that bred, not necessarily
            # the fittest of the whole population)
            population.extend(fitnessDF['newBinary'].values[:include_fittest_parents].tolist())
    ga_resultsDict = {'all_GA_results': all_GA_results, 'GA_settings': GA_settings}
    if save_results != 'No':
        import pickle
        # Context manager so the file is flushed and closed
        with open("Genetic_Algorithm/GA_Results/"+save_results+".pkl", "wb") as results_file:
            pickle.dump(ga_resultsDict, results_file)
    return ga_resultsDict




# Settings for the train/test run below; GA_Train_Test() reads them by key
# and forwards them to X2K_Web_GA(). With these values every generation
# after the first holds 92 children plus 8 carried-over parents, i.e. the
# same 100 individuals as the initial population.
GAset = {'GMT':"../X2K_Genetic_Algorithm/Validation/Perturbation_Data/GEO/Kinase_Perturbations_from_GEO_SUBSET1.80per.txt",
                    'gmtLimit':False, 'initial_pop_size':100, 'generations':10, 'select_fittest':10, 'selection_method':'Fitness-proportional',
                    'fitness_method':'target_shuffled_difference', 'children_per_generation':92, 'mutation_rate':.01, 'breeding_variation':0,
                    'crossover_locations':'random','include_fittest_parents':8, 'parallel':True, 'save_results':'GA_results'}
def GA_Train_Test(GAset):
    """Run the GA on the training GMT, then on the test GMT, and save both.

    Parameters
    ----------
    GAset : dict
        GA settings keyed by the X2K_Web_GA() parameter names, plus
        'save_results' (base name of the output files). 'GMT' is the
        training file. An optional 'GMT_test' overrides the default test
        file, and an optional 'parallel' (default True) is forwarded.
        The dict is not modified.

    Returns
    -------
    dict
        {'GA_train', 'GA_test'}: the per-individual results of each run;
        {'GA_settings_train', 'GA_settings_test'}: the settings of each run.
        The same dict is pickled to
        Genetic_Algorithm/GA_Results/<save_results>.pkl.
    """
    import pickle
    test_gmt = GAset.get('GMT_test',
                         "../X2K_Genetic_Algorithm/Validation/Perturbation_Data/GEO/Kinase_Perturbations_from_GEO_SUBSET2.20per.txt")
    runs = {}
    for phase, gmt_path in (('train', GAset['GMT']), ('test', test_gmt)):
        runs[phase] = X2K_Web_GA(GMT=gmt_path, gmtLimit=GAset['gmtLimit'], initial_pop_size=GAset['initial_pop_size'],
                                 generations=GAset['generations'], select_fittest=GAset['select_fittest'],
                                 selection_method=GAset['selection_method'], fitness_method=GAset['fitness_method'],
                                 children_per_generation=GAset['children_per_generation'],
                                 mutation_rate=GAset['mutation_rate'], breeding_variation=GAset['breeding_variation'],
                                 crossover_locations=GAset['crossover_locations'],
                                 include_fittest_parents=GAset['include_fittest_parents'],
                                 parallel=GAset.get('parallel', True),
                                 save_results=GAset['save_results'] + '_' + phase)
    # X2K_Web_GA() returns one dict; index it by key. Tuple-unpacking that
    # dict would yield its two key strings instead of the results.
    GA_resultsDict = {'GA_train': runs['train']['all_GA_results'], 'GA_test': runs['test']['all_GA_results'],
                      'GA_settings_train': runs['train']['GA_settings'], 'GA_settings_test': runs['test']['GA_settings']}
    with open("Genetic_Algorithm/GA_Results/"+GAset['save_results']+".pkl", "wb") as results_file:
        pickle.dump(GA_resultsDict, results_file)

    return GA_resultsDict

# Runs the complete train/test GA as soon as this cell executes: it sends
# requests to the X2K web API and writes pickle files under
# Genetic_Algorithm/GA_Results/.
GA_resultsDict = GA_Train_Test(GAset)

## GA Results Plots

In [None]:


import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle

# Load the saved GA results. NOTE(security): unpickling can execute
# arbitrary code, so only load files produced by this notebook.
with open("Genetic_Algorithm/GA_Results/GA_results.pkl", "rb") as results_file:
    GA_results = pickle.load(results_file)


def prepare_df(GA_results):
    """Build a per-individual DataFrame from the training results.

    Parameters
    ----------
    GA_results : dict or sequence
        Either the dict saved by GA_Train_Test() (the training results are
        read from its 'GA_train' key) or a sequence whose first element is
        the training results.

    Returns
    -------
    pandas.DataFrame
        One row per individual, with 'generation' and 'fitness' converted
        to numbers.
    """
    if isinstance(GA_results, dict) and 'GA_train' in GA_results:
        GA_train = GA_results['GA_train']
    else:
        GA_train = GA_results[0]
    dat = pd.DataFrame(GA_train).T
    dat[['generation', 'fitness']] = dat[['generation', 'fitness']].apply(pd.to_numeric)
    return dat
    
def plot_fitness(GA_results):
    """Plot mean and peak fitness per generation as two point plots."""
    frame = prepare_df(GA_results)
    # Attach each generation's best fitness to its rows as 'fitness_peak'
    peak_by_generation = frame.groupby('generation')['fitness'].max()
    frame = frame.join(peak_by_generation, on='generation', rsuffix='_peak')
    # Draw both series on one axis
    figure, axis = plt.subplots(1, 1)
    sns.pointplot(data=frame, x='generation', y='fitness', label='Mean Fitness', color='limegreen', ax=axis)
    sns.pointplot(data=frame, x='generation', y='fitness_peak', label='Peak Fitness', color='forestgreen', markers='^', ax=axis)
    # Legend built from the axis lines after the first two
    legend_labels = ["Mean Fitness", "Peak Fitness"]
    axis.legend(handles=axis.lines[2:], labels=legend_labels)

   