In [None]:
from collections import defaultdict, Counter
from string import ascii_uppercase
from os import path
import os

from metadata import load_csv, zfilter

import pandas as pd
from numpy import random
import numpy as np


%matplotlib inline

In [None]:
# Map lower-cased variety group labels onto two coarse supergroups.
# Labels missing from this mapping are later tagged 'other'.
supergroups = dict(
    [('indica', 'Indica')]
    + [(label, 'Japonica')
       for label in ('japonica', 'temperate japonica', 'tropical japonica')]
)

In [None]:
# Load the raw per-run metadata table via `load_csv` from the local `metadata`
# module. NOTE(review): presumably one record per SRA run (the file name
# suggests 25228 runs) — TODO confirm load_csv's return type. The result is
# handed directly to `pd.DataFrame` in the next cell.
allruns = load_csv("all_25228runs_raw.csv")

In [None]:
allpd = pd.DataFrame(allruns)
# Lower-case the group labels so they line up with the keys of `supergroups`,
# then tag every row with its supergroup ('other' when the label is unmapped).
allpd['group'] = allpd['group'].str.lower()
allpd['supergroup'] = allpd['group'].map(lambda label: supergroups.get(label, 'other'))

In [None]:
# Restrict to the two supergroups of interest, then report how many rows fall
# in each fine-grained group and in each supergroup, most common first.
indjap = allpd.query('supergroup in ["Indica", "Japonica"]')

for heading, column in (("Groups:", 'group'), ("Supergroups:", 'supergroup')):
    print(heading)
    for label, n_rows in Counter(indjap[column].tolist()).most_common():
        print("  -", label, n_rows)
   

In [None]:
# Keep only samples sequenced as exactly RUNS_PER_SAMPLE runs, drop outlier
# runs by read count, then re-apply the run-count filter because the outlier
# step can leave a sample with an incomplete set of runs.
# NOTE(review): this cell rebinds `indjap` from itself, so re-running it
# without re-running the previous cell filters already-filtered data again
# (and, if `zfilter` recomputes mean/sd from its input, drops more runs).
RUNS_PER_SAMPLE = 6


def _complete_samples(runs, n_runs=RUNS_PER_SAMPLE):
    """Return the rows of ``runs`` whose ``sra_sample`` has exactly ``n_runs`` rows."""
    return runs.groupby('sra_sample').filter(lambda grp: len(grp) == n_runs)


print()
indjap = _complete_samples(indjap)
print("Select 6s:", len(indjap))

# Remove all runs outside 1 sd from mean number of reads
# (`zfilter` is provided by the local `metadata` module).
indjap = zfilter(indjap, 'num_reads', 1)
print("Z-score filtering:", len(indjap))

# Remove all samples with fewer than RUNS_PER_SAMPLE remaining runs.
indjap = _complete_samples(indjap)
print("6s again:", len(indjap))

In [None]:
# Rows remaining per supergroup after filtering (displayed as the cell output).
indjap.groupby('supergroup')['supergroup'].count()

In [None]:
# Unique SRA sample accessions per supergroup, in first-seen row order.
# Kept as a defaultdict(list) so lookups of an absent supergroup still
# yield an empty list.
groups = defaultdict(list, {
    supergroup_name: list(frame['sra_sample'].unique())
    for supergroup_name, frame in indjap.groupby('supergroup')
})

In [None]:
def gensets(groups, n, each=8, replace=False, rng=None):
    """Yield ``n`` random sets, drawing ``each`` items from every group.

    Parameters
    ----------
    groups : mapping of name -> sequence
        Candidate items per group (here: SRA sample accessions per supergroup).
    n : int
        Number of sets to generate.
    each : int, default 8
        Number of items drawn from each group for every set.
    replace : bool, default False
        Draw with replacement. Defaults to False so a single set never
        contains the same item twice (which would duplicate its runs).
    rng : numpy ``Generator`` / ``RandomState``, optional
        Source of randomness. When omitted the global ``numpy.random`` state is
        used, so ``numpy.random.seed`` still controls the output.

    Yields
    ------
    list
        ``each * len(groups)`` items, grouped in ``groups`` iteration order.

    Raises
    ------
    ValueError
        From numpy, when ``replace`` is False and a group has fewer than
        ``each`` items (raised when that set is drawn).
    """
    if rng is None:
        rng = random
    for _ in range(n):
        sample = []
        for items in groups.values():
            picks = rng.choice(len(items), each, replace=replace)
            sample.extend(items[i] for i in picks)
        yield sample

In [None]:
# Prepare the output directory for the sample-set files.
# os.rmdir cannot remove a non-empty directory, so an rmdir-then-mkdir
# approach breaks on every re-run. Instead, ensure the directory exists and
# delete only the set files this notebook writes, so stale sets from an
# earlier run (e.g. with a larger `ngroup`) don't linger.
os.makedirs('sets', exist_ok=True)
for _fname in os.listdir('sets'):
    if _fname.startswith('3krice_set_') and _fname.endswith('.txt'):
        os.remove(path.join('sets', _fname))

In [None]:
# Draw `ngroup` random sets of samples (8 per supergroup) and write the run
# accessions (`sra_id`) of each set to its own file. Also write one file
# listing every distinct run used by any set.
ngroup = 100
# NOTE: this rebinds `allruns`, which earlier held the raw records returned by
# `load_csv`. From here on it is the set of selected run IDs.
allruns = set()

# Run IDs per sample, in frame order. Built once instead of re-querying the
# whole frame (with a string-built query) for every sample of every set.
runs_by_sample = {
    samp: list(frame['sra_id'])
    for samp, frame in indjap.groupby('sra_sample')
}
# Zero-pad set numbers to the digit count of the largest label (001..100).
width = len(str(ngroup))

for i, samples in enumerate(gensets(groups, ngroup, each=8)):
    label = "{0:0{width}d}".format(i + 1, width=width)
    outfile = "sets/3krice_set_{}.txt".format(label)
    with open(outfile, 'w') as fh:
        for sample in samples:
            for run in runs_by_sample.get(sample, ()):
                print(run, file=fh)
                allruns.add(run)
with open("sets/3krice_set_ALL.txt", 'w') as fh:
    for run in sorted(allruns):
        print(run, file=fh)

In [None]:
# Number of distinct run accessions collected while writing the sets (the
# number of lines in sets/3krice_set_ALL.txt), assuming the set-writing cell
# ran immediately before this one.
len(allruns)