In [1]:
import shutup
import torch
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tsgcn.simulation import MsprimeSimulation, run_sims
from tsgcn.util import get_idle_gpu
from tsgcn.data import TreeSequenceData, TreeSequencesDataset, windowed_div_from_ts, compute_ys
from tsgcn.model import BiGCNModel, BiGCNEncoder

import tskit

# NOTE(review): presumably silences Python warnings for the rest of the session — confirm against the `shutup` package docs
shutup.please()

  time=int(extended_GF.time.head(1) - 1), rate=0
  time=int(extended_GF.time.tail(1) + 1), rate=0


In [2]:
# Pick the compute device once; later cells pass it to the models and move batches onto it.
if torch.cuda.is_available():
    # get_idle_gpu() supplies the index that goes into the "cuda:<index>" device string
    device = torch.device(f"cuda:{get_idle_gpu()}")
else:
    device = torch.device("cpu")
print(device)

cuda:1


In [3]:
## CONSTANTS
# The values below are passed positionally to MsprimeSimulation further down in this cell.
seed = 11379
num_reps = 1_000
sp_name = "HomSap"
# NOTE(review): no chromosome is selected here, although the simulation variable below has "chr13" in its name — confirm which is intended
chrom = None
length = 1_000_000
model_name ="OutOfAfrica_3G09"
sample_size=10
# Number of equal-width windows behind `out_breaks`; also passed to compute_ys in a later cell
num_windows = 100
# Number of equal-width windows behind `model_breaks`
model_num_windows = 20
# NOTE(review): n_workers=512 is hard-coded — confirm it suits the machine this runs on
HomSap_chr13_OOA = MsprimeSimulation(seed, num_reps, sp_name, model_name, "data/raw/", chrom, length, sample_size, n_workers=512)
# Evenly spaced breakpoints from 0 to the contig length; N windows need N + 1 edges
model_breaks = np.linspace(0, HomSap_chr13_OOA.contig.length, model_num_windows+1)
out_breaks = np.linspace(0, HomSap_chr13_OOA.contig.length, num_windows+1)

In [4]:
# NOTE(review): presumably runs the simulations configured above and writes the tree sequences to disk
# (their paths are read from `ts_paths` in the next cell) — confirm; likely expensive to re-run
run_sims(HomSap_chr13_OOA)

In [5]:
# Just exploring the output of convert_tseq
from tsgcn.util import convert_tseq
# Load the tree sequence stored at the first simulation path
ts=tskit.load(HomSap_chr13_OOA.ts_paths[0])
# convert_tseq returns four values; they are used below as edge_index, edge_interval, x and sequence_length
eix, ei, nf, sq = convert_tseq(ts)
# Wrap the pieces in a TreeSequenceData object by hand (not used by later cells in this notebook)
a=TreeSequenceData(edge_index=eix, edge_interval=ei, x = nf, sequence_length=sq)

In [6]:
# Importing our dataset
# y_name matches the name passed to compute_ys in the next cell ("windowed-diversity")
dataset = TreeSequencesDataset("data/", HomSap_chr13_OOA.sims_path,seeds=HomSap_chr13_OOA.seed_array, y_name="windowed-diversity")

In [7]:
# Compute the "windowed-diversity" targets for the dataset with windowed_div_from_ts over num_windows windows.
# NOTE(review): presumably parallelised over n_workers and stored alongside the dataset — confirm
compute_ys(dataset, windowed_div_from_ts, "windowed-diversity", num_windows=num_windows, n_workers=256)

In [8]:
def node_num_child(ts):
    """Sum, for every node, its number of children over all trees of ``ts``.

    Parameters
    ----------
    ts : object exposing ``num_nodes`` and ``trees()``; each tree must expose
        ``nodes()`` and ``num_children(u)``.

    Returns
    -------
    torch.Tensor
        1-D float32 tensor of length ``ts.num_nodes``. Entry ``u`` is the
        child count of node ``u`` accumulated across every tree that
        contains it; nodes appearing in no tree stay at 0.
    """
    child_totals = np.zeros(ts.num_nodes, dtype=np.float32)
    for current_tree in ts.trees():
        for node_id in current_tree.nodes():
            child_totals[node_id] += current_tree.num_children(node_id)
    return torch.from_numpy(child_totals)

In [9]:
# Number of examples in the dataset (expected to equal num_reps)
dataset.len()

1000

In [10]:
# Inspect the first example: its repr lists x, edge_index, edge_interval, sequence_length and y
dataset[0]

TreeSequenceData(x=[1419, 1], edge_index=[2, 7494], edge_interval=[2, 7494], sequence_length=1000000.0, y=[100])

In [11]:
# Build a standalone encoder on the model_breaks grid to try it out in isolation.
# NOTE(review): the last two arguments are presumably the input and output feature counts — confirm
test = BiGCNEncoder(model_breaks, device, dataset[0].num_features, dataset[0].num_features)

In [12]:
# Run the freshly constructed (untrained) encoder on the first example to check that a forward pass works
test(dataset[0])

tensor([[0.2346],
        [0.2346],
        [0.2346],
        ...,
        [0.2346],
        [0.2346],
        [0.2346]], grad_fn=<AddmmBackward0>)

In [13]:
from torch.utils.data import random_split

# Seed the global RNG so the train/validation/test partition is reproducible.
torch.manual_seed(123)

# 70 / 15 / 15 split. The test share is taken as the remainder so the three
# sizes always add up to the dataset size: random_split raises if they do not,
# which the floor divisions alone could cause for sizes not divisible by 20.
num_total = dataset.len()
num_train = 70 * num_total // 100
num_valid = 15 * num_total // 100
num_test = num_total - num_train - num_valid
train_set, valid_set, test_set = random_split(dataset, [num_train, num_valid, num_test])

# One tree sequence per batch; get_y and the loss in later cells rely on that layout.
trainloader = DataLoader(train_set, batch_size=1, shuffle=True)
validloader = DataLoader(valid_set, batch_size=1, shuffle=True)
testloader = DataLoader(test_set, batch_size=1, shuffle=True)

In [14]:
def get_y(batch):
    """Return the targets of ``batch`` with an extra axis inserted at position 1.

    Reads ``batch.y`` and calls ``unsqueeze(1)`` on it, so a 1-D target of
    length ``n`` comes back with shape ``(n, 1)``.
    """
    targets = batch.y
    return targets.unsqueeze(1)

In [15]:
# Seed the global RNG so weight initialisation and batch shuffling are repeatable.
torch.manual_seed(1793335)
# NOTE(review): both `breaks` and `out_breaks` receive `out_breaks`; `model_breaks`
# from the constants cell is not used here — confirm this is intended.
model = BiGCNModel(
    device,
    num_encoder_in_features=dataset.num_features,
    num_encoder_out_features=2,
    breaks=out_breaks,
    pooling="windowed_sum",
    out_breaks=out_breaks,
)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
num_epochs = 100
criterion = torch.nn.MSELoss()

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    epoch_loss = 0
    num_batches = 0
    for batch in trainloader:
        num_batches += 1
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, get_y(batch))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    train_avg_loss = epoch_loss / num_batches

    # ---- validation pass ----
    # No gradients are needed here; disabling autograd avoids building graphs
    # that are never back-propagated through.
    model.eval()
    val_loss = 0
    num_batches = 0
    with torch.no_grad():
        for batch in validloader:
            num_batches += 1
            batch = batch.to(device)
            output = model(batch)
            val_loss += criterion(output, get_y(batch)).item()
    val_avg_loss = val_loss / num_batches

    print(f"Epochs: {epoch} | epoch avg. loss: {train_avg_loss:.8f} | validation avg. loss: {val_avg_loss:.8f}")

Epochs: 0 | epoch avg. loss: 75709059.83178571 | validation avg. loss: 5023676.29666667
Epochs: 1 | epoch avg. loss: 2783248.23660714 | validation avg. loss: 1532665.96791667
Epochs: 2 | epoch avg. loss: 1028326.06111607 | validation avg. loss: 677268.54354167
Epochs: 3 | epoch avg. loss: 484941.44696429 | validation avg. loss: 338603.79885417
Epochs: 4 | epoch avg. loss: 246977.28233259 | validation avg. loss: 174548.76723958
Epochs: 5 | epoch avg. loss: 126539.53973772 | validation avg. loss: 87999.74856771
Epochs: 6 | epoch avg. loss: 62238.49425502 | validation avg. loss: 41644.50421875
Epochs: 7 | epoch avg. loss: 28274.00491839 | validation avg. loss: 17848.81119792
Epochs: 8 | epoch avg. loss: 11660.73476249 | validation avg. loss: 7032.29307780
Epochs: 9 | epoch avg. loss: 4692.07811994 | validation avg. loss: 3035.63477987
Epochs: 10 | epoch avg. loss: 2416.26430324 | validation avg. loss: 1958.30739543
Epochs: 11 | epoch avg. loss: 1887.13746124 | validation avg. loss: 1752.5

In [None]:
# Collect predictions and targets for every batch of the test set.
model.eval()
predictions = []
real = []

# Inference only: run without autograd so no graph is kept per batch.
with torch.no_grad():
    for batch in testloader:
        batch = batch.to(device)
        output = model(batch)
        predictions.append(output.detach().cpu().numpy())
        real.append(get_y(batch).detach().cpu().numpy())

# Stack the per-batch arrays along the first axis.
predictions = np.concatenate(predictions)
real = np.concatenate(real)

In [None]:
# Display the model's `encoder` attribute
model.encoder

In [None]:
# Shape of the concatenated test-set predictions
predictions.shape

In [None]:
# Shape of the concatenated test-set targets; should match predictions.shape
real.shape

In [None]:
# Draw one batch from the test loader; the loader shuffles, so which one comes first depends on the RNG state
batch = next(iter(testloader))

In [None]:
# Evaluate the model on the single batch drawn above.
batch = batch.to(device)
# Forward pass without autograd: `output` is only used for the loss value and for plotting.
with torch.no_grad():
    output = model(batch)
# Loss on this batch; left as the last expression so the cell displays it.
criterion(output, get_y(batch))

Visualizing `Predicted~Observed` diversity within a single tree sequence

In [None]:
# Observed (x) versus predicted (y) values for the single batch evaluated above.
plt.scatter(get_y(batch).detach().cpu().numpy(), output.detach().cpu().numpy())
# Reference line y = x: points on it are perfect predictions. The anchor point is
# kept at (0.7, 0.7) because axline adds it to the data limits used for autoscaling.
plt.axline((0.7,0.7), slope=1)

# Same axis labels as the pooled plot further down, so the figure reads on its own.
plt.ylabel('Predicted diversity')
plt.xlabel('Observed diversity');

Now across all windows and tree sequences

In [None]:
import scipy.stats

In [None]:
# Squared Pearson correlation between observed and predicted values,
# pooled over every entry of `real` and `predictions`
pearson_r = scipy.stats.pearsonr(real.flatten(), predictions.flatten())[0]
pearson_r ** 2

In [None]:


# Observed versus predicted values pooled over the whole test set; alpha keeps overlapping points visible
plt.scatter(real, predictions, alpha=0.3)
# Reference line with slope 1 through (0.7, 0.7), i.e. y = x: points on it are perfect predictions
plt.axline((0.7,0.7), slope=1)

plt.ylabel('Predicted diversity')
plt.xlabel('Observed diversity')