## 1. Execute one of the following cells to pick the AG results to be plotted

### Specify 
- datasets that should be loaded, and the label you want to assign to them (eg GAN, GAN_versionXX ...)
- output directory name (will be created if needed)
- SNP position file(s)  
- number of individuals to keep from each dataset (randomly subsampled if this number is less than the total number of individuals)

In [None]:
# Output directory for the figures
outDir = "FIG/TEST/FIG_10K/"

# Number of individuals to keep for each dataset.
# A small value is handy for quick tests; use nsub = 5000 for the full dataset.
nsub = 100

# Root directory of the input data
infileDir = "../"

# Dataset label -> haplotype file.
# Every label can be renamed except 'Real', which is mandatory.
infiles = {
    'Real': f"{infileDir}1000G_real_genomes/10K_SNP_1000G_real.hapt.zip",
    'GAN': f"{infileDir}GAN_AGs/10K_SNP_GAN_AG_10800Epochs.hapt.zip",
    'RBM': f"{infileDir}RBM_AGs/10K_SNP_RBM_AG_1050epochs.hapt.zip",
}

# SNP position file of the real dataset
realposfname = f"{infileDir}1000G_real_genomes/10k_SNP.legend"

# All datasets share the same SNP positions, so every label maps to the same file.
position_fname = dict.fromkeys(infiles, realposfname)

# If needed, override the position file of a given label (it must exist in infiles):
# position_fname['stdpopsim'] = infileDir + "stdpop_genomes/CEU_chr15_matching_10k.legend"

## 2. Imports and general color dictionary

In [None]:
# Directory that holds the plot_utils module and the plotfig_utils_*.ipynb
# notebooks launched with %run further down.
dirscript = 'short'
from short import plot_utils as plu

In [None]:
import seaborn as sns
import pandas as pd
import numpy as np
import importlib
from short import plot_utils as plu
import os

In [None]:
inDir = ""
# Remember the user-chosen root output directory: outDir itself is later
# redefined as a transformation-specific subdirectory.
mainOutDir = outDir

# General colors.
# If you add a new key in infiles, you must add a corresponding color in the
# following dictionary (examples for extra labels are left below on purpose).
allcolpal = {
    'Real': "#95a5a6",
    'Binomial': "#2ecc71", 'Markov_w5': '#6a3d9a', 'Markov_w10': "#9b59b6",
    # blue main GANs (the same color can be used for multiple labels if you wish)
    'GAN': "#3498db", 'GAN_ep20k': "#3498db",
    'GAN_ep5k': '#0b559f',  # dark blue SM GAN 805
    'RBM': "#e74c3c", 'RBM_bis': "#e74c3c",  # red MAIN RBMs
    'RBM_SIG1': '#fa9b58', 'RBM_SIG2': '#fece7f',  # orange SM RBMs chunk
    'Test1': '#575757', 'Test2': '#393939',
    'HAPGEN': '#458B74',
    'stdpopsim': '#fa9b58',
    'RBM_init_random': 'salmon',  # sns.color_palette('Reds_r',14)[0]
    'RBM_init_test': 'sienna',  # sns.color_palette('Reds_r',14)[1]
    'RBM_init_train': 'darkorange',  # sns.color_palette('Reds_r',14)[2]
}
## Print the full color palette if needed:
# print(allcolpal.keys())
# sns.palplot(sns.color_palette(allcolpal.values()))

# Colors for RBMs at different epochs: 200 to 650 in 10 evenly spaced steps, plus 690.
RBM_labels = ['RBM_ep{}'.format(ep)
              for ep in np.concatenate([np.linspace(200, 650, 10).astype(int), [690]])]
# Three more colors than labels are requested; zip() stops at the last label,
# so the final three shades of the palette are left unused.
allcolpal.update(dict(zip(RBM_labels, sns.color_palette('Reds_r', len(RBM_labels) + 3))))

#print(allcolpal.keys())
#sns.palplot(sns.color_palette(allcolpal.values()))

# Fail early, with an explicit message, when a dataset label has no color
# (otherwise the lookup below raises a bare KeyError on the first missing label).
missing_colors = [key for key in infiles if key not in allcolpal]
if missing_colors:
    raise KeyError(
        f"No color defined in allcolpal for dataset label(s) {missing_colors}; "
        "add an entry for each of them to allcolpal"
    )

# Restrict the color palette to the dataset labels present in infiles
colpal = {key: allcolpal[key] for key in infiles}
sns.set_palette(colpal.values())

print("Datasets under study:\n",infiles)
sns.palplot(sns.color_palette())
print(f"Output Directory for figures: {outDir}\n",
      f"Real dataset positions: {realposfname}\n",
      f"Sample size:{nsub}")


## 3. Run the notebooks to plot all figures or a subset of summary statistics (for faster results)

In [None]:
# Bare expression: its value (a reminder of the root output directory) is shown as the cell output.
f"Figures will be saved in {mainOutDir} or its subdirectories"

In [None]:
## Sanity check: report, for each dataset about to be loaded, whether its path exists.
## Any "False" below means the setup (the paths in infiles) has to be fixed first.
[f"Input file {path} exists: {os.path.exists(path)}" for path in infiles.values()]

In [None]:
# Setup options (transformations, sumstats to compute etc) and output directory
# (automatically derived from mainOutDir)

importlib.reload(plu)  # useful only if plot_utils was changed after being imported (development)
boolComputeAATS = True  # if False notebook 5 will reload previously computed AATS instead of computing it
figwi = 12  # control size of some figures

# set allchecks to False for a first rapid scan
# set to True for computing/plotting all sumstats and scores (long, better on a cluster)
allchecks = True  # False

# pick the transformations you want to apply to the datasets
# For no transformation choose
# transformations=None
transformations = {'to_minor_encoding': False, 'min_af': 0, 'max_af': 1}

if transformations is not None:
    # One subdirectory per set of transformations, named after its key-value pairs
    tname = ';'.join(f'{k}-{v}' for k, v in transformations.items())
    outDir = os.path.join(mainOutDir, tname + '/')
else:
    # Reset to the root directory so that re-running this cell after a run with
    # transformations does not keep pointing at the previous subdirectory.
    outDir = mainOutDir
print(f"Figures will be saved in {outDir}")
if os.path.exists(outDir):
    print('This directory exists, the following files might be overwritten:')
    print(os.listdir(outDir))

### Compute summary statistics
**You can pick which notebooks to execute** (and comment the other lines)    
Only the **first one is mandatory**  (plotfig_utils_1_INIT.ipynb)  
It loads the datasets, applies basic transformations if requested, and initializes a few variables (such as a dictionary of haplotypes, allele counts, fixed site vectors, etc.)

In [None]:
# Run the plotting notebooks located in the directory named by dirscript, one after the other.
# The two notebooks guarded by `if allchecks:` are executed only when allchecks is True;
# the guards are kept separate so that either line can be commented out on its own.
%run -p {dirscript}/plotfig_utils_1_INIT.ipynb  # mandatory, all lines below are optional
%run -p {dirscript}/plotfig_utils_2_AF.ipynb 
%run -p {dirscript}/plotfig_utils_3_PCA.ipynb
%run -p {dirscript}/plotfig_utils_4_LD.ipynb
if allchecks:
    %run -p {dirscript}/plotfig_utils_5_DIST_AATS.ipynb # computationnally long
if allchecks:
    %run -p {dirscript}/plotfig_utils_6_3pointcorr.ipynb  # computationnally long