# In [None]:
import numpy as np
import scipy.stats as st
import seaborn as sns

def get_all_combinations(parent): # Finds all possible combinations of alleles a parent can pass on to their offspring, assuming independent assortment.
    """Return every allele combination a parent can pass on.

    Each trait contributes exactly one of its two alleles, independently of
    the other traits (independent assortment).

    Args:
        parent: sequence of traits, e.g. ``['Aa', 'Bb']``. Only the first two
            items of each trait are used as its alleles.

    Returns:
        A list of ``2 ** len(parent)`` combinations in which the first trait
        varies fastest, e.g. ``['AB', 'aB', 'Ab', 'ab']``. An empty parent
        yields ``['']`` (one combination that carries nothing).
    """
    if len(parent) == 0:
        # The recursive version never terminated on an empty parent.
        return ['']
    last = parent[-1]
    gametes = [last[0], last[1]]
    # Walk the remaining traits right-to-left, prefixing each allele onto
    # every combination built so far; this keeps the original ordering.
    for pair in reversed(parent[:-1]):
        gametes = [allele + tail for tail in gametes for allele in (pair[0], pair[1])]
    return gametes

def make_row(genotype, allele):
    """Build one Punnett-square row: every item of `genotype` followed by `allele`."""
    return [gamete + allele for gamete in genotype]

def make_table(parent1, parent2):
    """Build the Punnett square: one row per item of `parent1`, one column per item of `parent2`."""
    return [make_row(parent2, row_allele) for row_allele in parent1]

def print_table(table, c1, c2): # formats and prints Punnett square
    """Print a Punnett square and return it as LaTeX tabular fragments.

    Args:
        table: list of rows, each a list of cell strings.
        c1: row labels, one per row of `table`; must be non-empty because the
            width of ``c1[0]`` sets the column padding.
        c2: column labels printed as the header line.

    Returns:
        A list of strings which, concatenated, form the body of a LaTeX
        tabular: the header cells, a ``\\hline``, then one line per row.
    """
    latextable = []
    width = len(c1[0])
    divlength = (width*2+4)*2**width
    divider = '\n' + ' '*(width+1) + '-'*divlength

    print('')
    print('', end=' ')
    for a in c2:
        print(' '*(width+3) + a, end=' ')
        latextable.append('& ' + a + ' ')
    print(divider)
    # '\\\\' is a LaTeX row break; the original '\\\ ' relied on an invalid
    # escape sequence to produce the same characters.
    latextable.append('\\\\ \n\\hline\n')

    last_row = len(table) - 1
    for i, row in enumerate(table):
        # Use the enumerate index: table.index(row) is a linear search and
        # returns the first match when two rows compare equal.
        label = c1[i]
        print(label, end=' ')
        latextable.append(label + ' & ')
        print('|', end=' ')
        last_cell = len(row) - 1
        for j, cell in enumerate(row):
            print(cell + ' | ', end=' ')
            latextable.append(cell + (' & ' if j != last_cell else ' '))
        print(divider)
        if i != last_row:
            latextable.append('\\\\ \n')
    return latextable
	
def print_genotype_frequencies(table): # calculates frequencies for each genotype present in table
    """Print and return the frequency of every distinct genotype in `table`.

    Two cells count as the same genotype when they contain the same
    characters in any order (``'aA'`` and ``'Aa'``). Each genotype is
    reported once, in order of first appearance, labelled with the spelling
    of its first occurrence.

    Args:
        table: list of rows, each a list of genotype strings.

    Returns:
        A list starting with ``'\\n'`` followed by one LaTeX table line per
        distinct genotype (``'<genotype> & <percent>\\% \\\\ \\hline \\n'``).
    """
    genotypes = [cell for row in table for cell in row]
    total = len(genotypes)

    # Single pass: count by order-independent key and remember first spellings.
    counts = {}
    first_seen = []
    for genotype in genotypes:
        key = tuple(sorted(genotype))
        if key not in counts:
            counts[key] = 0
            first_seen.append((key, genotype))
        counts[key] += 1

    freqtable = ['\n']
    for key, label in first_seen:
        percent = str(float(counts[key]) / float(total) * 100)
        print("The frequency of the " + label + " genotype is " + percent + "%.")
        # '\\\\' replaces the invalid escape sequence '\\\ ' used before;
        # the emitted characters are unchanged.
        freqtable.append(label + ' & ' + percent + '\\% \\\\ \\hline \n')
    return freqtable


def coin_toss(p):
    '''
    Flip one biased coin using a single uniform draw from np.random.

    input: p, the probability of getting heads
    output: 1 for heads (the draw fell below p), 0 for tails
    '''
    draw = np.random.rand(1)
    heads = draw < p
    return heads.sum()
def virtual_plant (N, traits, p=0.5, crosslinking = False, crosslinking_p = 0.5):
    '''
    Simulate a random population of plants.

    input: N: number of plants to simulate
           traits: number of traits per plant
           p: probability that any single allele is dominant
           crosslinking: when true, each trait also gets a random crosslinking marker
           crosslinking_p: probability that a trait's crosslinking marker is 1
    output: list of N dicts, each shaped like {0:[[0,1],[0]], 1:[[0,0],[0]], 2:[[1,1],[0]]}
            key: trait number
            dict[key][0]: the two alleles, 1 = dominant, 0 = recessive
            dict[key][1]: one-item list with the crosslinking marker (always 0 when crosslinking is off)
    '''
    def random_trait():
        # Draw order matters for reproducibility: allele 1, allele 2, then the marker.
        alleles = [coin_toss(p), coin_toss(p)]
        marker = coin_toss(crosslinking_p) if crosslinking else 0
        return [alleles, [marker]]

    return [{j: random_trait() for j in range(traits)} for _ in range(N)]
# Note on indexing:
# plants[i] = plant i
# plant[j] / plants[i][j] = trait j
# plant[j][0][0] / plants[i][j][0][0] = trait j, allele 1: 1 is dominant, 0 is recessive
# plant[j][0][1] / plants[i][j][0][1] = trait j, allele 2: 1 is dominant, 0 is recessive
# plant[j][1][0] / plants[i][j][1][0] = crosslinking indicator: 1 is crosslinked, 0 is not crosslinked

def cross_breeder(plant1, plant2, N=1, p =0.5):
    '''
    Cross two plants and return N offspring.

    input: plant1: dict form, keys = trait index, values = [[allele1, allele2], [crosslinking indicator]]
           plant2: same form as plant1; must contain every trait key of plant1
           N: number of offspring to generate
           p: probability that a coin toss selects the second allele (index 1) of a parent
    output: list of N plants in the same dict form (empty list when N is 0).
            A trait of an offspring is marked crosslinked only when it is
            crosslinked in both parents.
    raises: KeyError if plant2 lacks a trait key that plant1 has.
    '''
    new_plants = []
    for _ in range(N):
        new_plant = {}
        # One draw per parent per offspring: every crosslinked trait of a
        # parent passes on the allele at this same position.
        crosslinker1 = coin_toss(p)
        crosslinker2 = coin_toss(p)
        # Iterate over the actual trait keys (in sorted order) so plants whose
        # keys do not start at 0 work too; for keys 0..n-1 this is range(n).
        for j in sorted(plant1):
            if plant1[j][1][0]:
                allele1 = plant1[j][0][crosslinker1]
            else:
                allele1 = plant1[j][0][coin_toss(p)]
            if plant2[j][1][0]:
                allele2 = plant2[j][0][crosslinker2]
            else:
                allele2 = plant2[j][0][coin_toss(p)]
            new_plant[j] = [[allele1, allele2], [plant1[j][1][0]*plant2[j][1][0]]]
        new_plants.append(new_plant)
    # Return only after all N offspring are built (the return used to sit
    # inside the loop and cut it short after the first offspring).
    return new_plants
    
# Trait letters: position i names trait i. The converters below use the
# upper-case letter for a dominant allele and the lower-case one for a recessive allele.
alphabet1 = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
alphabet0 = alphabet1.lower()
def genotype_converter(plant): 
    '''
    Convert a plant in dict form to its alphabetical genotype.

    input: a plant, with trait keys in 0..25; each value is [[allele1, allele2], [crosslinking indicator]]
    output: a list with one two-letter string per trait, in ascending key order:
            'AA' for two dominant alleles, 'Aa' for one, 'aa' for none
            (trait 0 uses A/a, trait 1 uses B/b, and so on)
    '''
    string = []
    # Walk the actual trait keys so plants keyed by alphabet index (e.g. only
    # traits 2 and 3) work; for keys 0..n-1 this matches range(n).
    for j in sorted(plant):
        alleles = plant[j][0]
        dominant_count = alleles[1] + alleles[0]
        if dominant_count == 2:
            string.append(alphabet1[j] + alphabet1[j])
        elif dominant_count == 1:
            string.append(alphabet1[j] + alphabet0[j])
        else:
            string.append(alphabet0[j] + alphabet0[j])
    return string
def phenotype_converter(plant, method = 'auto'): 
    '''
    Convert a plant in dict form to its phenotype.

    input: plant: a plant, with trait keys in 0..25; each value is [[allele1, allelle2], [crosslinking indicator]]
           method: 'auto' labels a trait 'Dominant' when at least one allele is 1;
                   any other value swaps the two labels
    output: dict mapping each trait's upper-case letter (A for trait 0, B for trait 1, ...)
            to 'Dominant' or 'Recessive'
    '''
    if method == 'auto':
        with_dominant_allele, without_dominant_allele = 'Dominant', 'Recessive'
    else:
        with_dominant_allele, without_dominant_allele = 'Recessive', 'Dominant'

    phenotype = {}
    # Walk the actual trait keys so plants keyed by alphabet index (e.g. only
    # traits 2 and 3) work; for keys 0..n-1 this matches range(n).
    for j in sorted(plant):
        alleles = plant[j][0]
        if alleles[1] or alleles[0]:
            phenotype[alphabet1[j]] = with_dominant_allele
        else:
            phenotype[alphabet1[j]] = without_dominant_allele
    return phenotype

def reverse_genotype(string,crosslinking = False): 
    '''
    Turn a two-letter genotype into a single-trait plant in dict form.

    input: string: two characters such as 'Aa', 'Bb', 'AA' or 'aa'
           crosslinking: value summed with np.sum to form the crosslinking marker
    output: {trait index: [[allele1, allele2], [crosslinking marker]]}, where the
            trait index is the alphabet position of the first letter and the
            alleles are [1,1] (both upper case), [0,0] (both lower case) or [0,1]
    '''
    head = string[0]
    trait_index = alphabet1.index(head.upper())
    if head.isupper() and string[1].isupper():
        alleles = [1,1]
    elif head.islower() and string[1].islower():
        alleles = [0,0]
    else:
        alleles = [0,1]
    marker = np.sum(crosslinking)
    return {trait_index: [alleles, [marker]]}

def punnett_square_generator(p1,p2):
    ''' Print the Punnett square and genotype frequencies for a cross.

        input: two plants in dict form (p1 labels the rows, p2 the columns)
        output: (latextable, freqtable), the LaTeX fragments returned by
                print_table and print_genotype_frequencies
    '''
    genotype1, genotype2 = genotype_converter(p1), genotype_converter(p2)
    gametes1 = get_all_combinations(genotype1)
    gametes2 = get_all_combinations(genotype2)
    square = make_table(gametes1, gametes2)
    latex_rows = print_table(square, gametes1, gametes2)
    frequency_rows = print_genotype_frequencies(square)
    print('')
    return latex_rows, frequency_rows

# Plant class: pass a well-formed genotype string such as 'AaBb'; 'CcDd' also works. Do not pass something like '#$%S>APD'.
# crosslinking: list of the trait indices that are crosslinked, e.g. [1, 2]; index 1 is B and index 2 is C.
# They could be specified sequentially instead, but I do not like that approach.
# Plant.genotype: genotype of the plant.
# Plant.phenotype: phenotype of the plant.
# Plant.virtue: the virtual (dict) form of the plant, the same form used by the functions above.
# Plant.cross: with no argument you get a self-cross; given another plant you get cross-breeding.
# Plant.punnett_square: get a Punnett square.


class Plant:
    """A plant described by a genotype string such as ``'AaBb'``.

    Attributes:
        genotype: the genotype string exactly as given.
        virtue: dict form of the plant,
            ``{trait index: [[allele1, allele2], [crosslinking marker]]}``,
            where the trait index is the alphabet position of the trait letter.
        phenotype: result of ``phenotype_converter(virtue)``.
    """

    def __init__(self, genotype, crosslinking=()):
        """Build the plant from its genotype string.

        Args:
            genotype: letters taken two at a time, one pair per trait. Two
                upper-case letters give alleles ``[1, 1]``, two lower-case
                letters ``[0, 0]``, anything else ``[0, 1]``. A trailing
                unpaired character is ignored.
            crosslinking: iterable of trait indices to mark as crosslinked
                (1 is B, 2 is C, ...).

        Raises:
            ValueError: if the first character of a pair is not a letter A-Z.
            KeyError: if `crosslinking` names a trait absent from `genotype`.
        """
        self.genotype = genotype    # instance variable unique to each instance
        virtue = {}
        for i in range(len(genotype) // 2):
            first, second = genotype[2*i], genotype[2*i + 1]
            key = alphabet1.index(first.upper())
            if first.isupper() and second.isupper():
                vgenotype = [1, 1]
            elif first.islower() and second.islower():
                vgenotype = [0, 0]
            else:
                vgenotype = [0, 1]
            virtue[key] = [vgenotype, [0]]
        for item in crosslinking:
            virtue[item][1][0] = 1
        self.virtue = virtue
        self.phenotype = phenotype_converter(virtue)

    # cross and punnett_square used to be defined inside __init__, which made
    # them local functions that were never reachable on the class or instance.

    def cross(self, plant2=None):
        """Cross this plant with `plant2` (with itself when omitted).

        Prints the genotype string of each offspring produced by
        ``cross_breeder`` and returns those strings as a list.
        """
        if plant2 is None:
            plant2 = self
        offspring = cross_breeder(self.virtue, plant2.virtue)
        # cross_breeder returns a list of plants; convert each one separately.
        genotypes = [''.join(genotype_converter(child)) for child in offspring]
        for genotype in genotypes:
            print(genotype)
        return genotypes

    def punnett_square(self, plant2=None):
        """Print the Punnett square for a cross with `plant2` (self-cross when omitted).

        This plant supplies the row labels. Returns the
        ``(latextable, freqtable)`` pair from ``punnett_square_generator``.
        """
        if plant2 is None:
            plant2 = self
        return punnett_square_generator(self.virtue, plant2.virtue)