In [9]:
import os

import numpy as np
import stdpopsim
import tskit

In [10]:
## CONSTANTS
# Master seed; the per-replicate simulation seeds below are drawn from it.
seed = 11379
# Number of replicate simulations (one seed is drawn per replicate).
num_reps = 50

# stdpopsim identifiers for the species, chromosome and demographic model.
sp_name = "HomSap"
chrom = "chr13"
model_name = "OutOfAfrica_3G09"

# Number of samples requested from each population of the model.
sample_size = 10

engine = stdpopsim.get_engine("msprime")

# One integer seed in [1, 2**31) per replicate, reproducible from `seed`.
rng = np.random.default_rng(seed)
seed_array = rng.integers(1, 2**31, num_reps)

# Directory prefix (with trailing slash) under which simulations are stored.
sims_path = f"data/sims/{sp_name}/{chrom}/{model_name}/"

In [11]:
# Create the simulation output directory; exist_ok makes re-running this cell harmless.
os.makedirs(sims_path, exist_ok=True)

In [12]:
# Resolve the stdpopsim objects named in the constants cell.
species = stdpopsim.get_species(sp_name)
model = species.get_demographic_model(model_name)

# Use the demographic model's own mutation rate for the contig.
contig = species.get_contig(chrom, mutation_rate=model.mutation_rate)

# Request the same number of samples from every population in the model.
samples = {
    population.name: sample_size
    for population in model.populations
}

In [32]:
# Simulate one tree sequence per seed, reusing any file already on disk.
# ts_paths[i] and num_trees_list[i] both correspond to seed_array[i].
ts_paths = []
num_trees_list = []
for sim_seed in seed_array:
    # A dedicated loop name keeps the `seed` constant from being overwritten.
    spath = f"{sims_path}sim_{sim_seed}.ts"
    ts_paths.append(spath)
    if os.path.exists(spath):
        ts = tskit.load(spath)
    else:
        ts = engine.simulate(model, contig, samples, seed=sim_seed)
        ts.dump(spath)
    # Record the tree count on both branches; appending only for cached files
    # left this list empty or misaligned with the seeds on a first run.
    num_trees_list.append(ts.num_trees)

In [13]:
# Load one stored tree sequence for the interactive inspection cells below.
ts = tskit.load("data/sims/HomSap/chr13/OutOfAfrica_3G09/sim_1057502661.ts")

In [14]:
# Take the first tree yielded by the tree-sequence iterator.
tree = next(ts.trees())

In [18]:
# Query the sample flag of node 8 (the recorded run returned 1).
ts.node(8).is_sample()

1

In [None]:
# Materialise the node IDs that tree.nodes() yields for this tree.
list(tree.nodes())

In [None]:
# Length of the branch above node 0 in this tree. The deprecated
# get_length() alias takes no node argument (it reports the tree's span),
# so calling it with a node ID raises TypeError.
tree.branch_length(0)

In [25]:
import torch
from torch_geometric.data import Dataset, Data
import tskit

In [26]:
def convert_tree(tree, ts):
    """Convert one local tree into PyTorch Geometric style graph tensors.

    Nodes are renumbered 0..N-1 in the order yielded by ``tree.nodes()`` and
    edges are expressed with those local indices, so every entry of the
    returned ``edge_index`` refers to a valid row of the feature matrix.
    (Using the raw tree-sequence node IDs would index past the feature rows
    whenever the tree contains only a subset of the nodes.)

    Args:
        tree: Tree object providing ``nodes()``, ``as_dict_of_dicts()`` and
            ``num_edges``.
        ts: Tree sequence providing ``node(node_id)`` records that expose a
            ``time`` attribute and an ``is_sample()`` method.

    Returns:
        tuple: ``(edge_index, node_features)``. ``edge_index`` is a
        ``torch.long`` tensor of shape ``(2, num_edges)`` whose columns are
        ``[parent, child]`` pairs; ``node_features`` is a ``torch.float32``
        tensor of shape ``(num_nodes, 2)`` holding ``[time, is_sample]``.

    Raises:
        AssertionError: if the number of collected edges differs from
            ``tree.num_edges``.
    """
    node_ids = list(tree.nodes())
    # Map each tree-sequence node ID to its row in the feature matrix.
    local_index = {node_id: row for row, node_id in enumerate(node_ids)}

    edge_index = [
        [local_index[parent], local_index[child]]
        for parent, children in tree.as_dict_of_dicts().items()
        for child in children
    ]
    assert len(edge_index) == tree.num_edges

    node_features = []
    for node_id in node_ids:
        node = ts.node(node_id)  # fetch the record once per node
        node_features.append([node.time, node.is_sample()])

    # reshape(-1, 2) keeps the shapes well-formed even when a list is empty.
    edge_tensor = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).T.contiguous()
    feature_tensor = torch.tensor(node_features, dtype=torch.float32).reshape(-1, 2)
    return edge_tensor, feature_tensor

In [40]:

class TreeSequenceDataset(Dataset):
    """PyTorch Geometric dataset holding one graph per local tree.

    Every tree of every simulated tree sequence is converted with
    ``convert_tree`` and saved as ``tree_{k}.pt`` in ``processed_dir``, where
    ``k`` counts trees across all tree sequences in the order of ``seeds``.
    The label ``y`` of a graph is the branch-mode diversity of its tree,
    standardised within the tree sequence it came from.

    Args:
        root: Dataset root directory handed to the base class.
        transform: Forwarded unchanged to the base class.
        pre_transform: Forwarded unchanged to the base class.
        pre_filter: Forwarded unchanged to the base class.
        seeds: Simulation seeds identifying the ``sim_{seed}.ts`` files.
            ``None`` is treated as an empty collection.
        num_trees_list: Number of trees in each tree sequence, aligned with
            ``seeds``. ``None`` is treated as an empty collection.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, seeds=None, num_trees_list=None):
        # Set before calling the base constructor: the file-name properties
        # below read these attributes, and in the recorded run construction
        # itself started processing.
        self.seeds = [] if seeds is None else seeds
        self.num_trees_list = [] if num_trees_list is None else num_trees_list
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        """Paths of the simulated tree-sequence files, one per seed.

        process() passes these strings straight to ``tskit.load``, so they
        are resolved against the current working directory.
        """
        return [f'data/sims/HomSap/chr13/OutOfAfrica_3G09/sim_{s}.ts' for s in self.seeds]

    @property
    def processed_file_names(self):
        """Names of the processed files: ``tree_0.pt`` .. ``tree_{N-1}.pt``.

        ``N`` is the total number of trees over all (seed, num_trees) pairs,
        matching the running index that process() uses when saving.
        """
        total_trees = sum(num_trees for _, num_trees in zip(self.seeds, self.num_trees_list))
        return [f'tree_{i}.pt' for i in range(int(total_trees))]

    def download(self):
        """Nothing to download; the tree sequences are simulated locally."""
        pass

    def process(self):
        """Convert every tree of every tree sequence into a saved graph."""
        idx = -1
        for raw_file_name, seed in zip(self.raw_file_names, self.seeds):
            ts = tskit.load(raw_file_name)
            tree_breaks = list(ts.breakpoints())
            # One diversity value per tree: the windows are the tree intervals.
            div = ts.diversity(windows=tree_breaks, mode="branch")
            centred = div - np.mean(div)
            spread = np.std(div)
            # Skip the division when the spread is zero (e.g. a single tree),
            # which would otherwise turn every label into NaN.
            div = centred / spread if spread > 0 else centred
            for i, tree in enumerate(ts.trees()):
                idx += 1
                print(f"Processing tree {idx} from ts {seed}")
                edge_index, node_features = convert_tree(tree, ts)
                # Store the label as a float tensor so that batches
                # concatenate into a float32 target vector.
                label = torch.tensor([div[i]], dtype=torch.float)
                data = Data(x=node_features, edge_index=edge_index, y=label)
                torch.save(data, os.path.join(self.processed_dir, f'tree_{idx}.pt'))

    def len(self):
        """Return the total number of trees (graphs) in the dataset."""
        return len(self.processed_file_names)

    def get(self, idx):
        """Load the graph saved by process() under the global index ``idx``."""
        data = torch.load(os.path.join(self.processed_dir, f'tree_{idx}.pt'))
        return data

In [41]:
# Peek at the first three simulation seeds.
seed_array[:3]

array([1347859693,  372481089,  697986435])

In [42]:
# Build the dataset from the first three simulations only.
dataset = TreeSequenceDataset(
    "data/",
    seeds=seed_array[:3],
    num_trees_list=num_trees_list[:3],
)

Processing...


Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
Processing tree 0 from ts 1347859693
P

KeyboardInterrupt: 

In [None]:
# DataLoader was never imported in this notebook; use the PyTorch Geometric
# loader, which knows how to batch Data objects.
from torch_geometric.loader import DataLoader

loader = DataLoader(dataset, batch_size=32, shuffle=True)


In [None]:
# Inspect the dataset's edge-feature count; the graphs saved by process()
# are built from x, edge_index and y only, with no edge attributes.
dataset.num_edge_features