In [1]:
import shutup
import torch
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from tsgcn.simulation import MsprimeSimulation, run_sims
from tsgcn.util import get_idle_gpu, get_degree_histogram
from tsgcn.data import TreeSequenceData, TreeSequencesDataset, windowed_div_from_ts, compute_ys, get_node_features
from tsgcn.model import BiGCNModel, BiGCNEncoder, PNANet

import tskit

shutup.please()

  time=int(extended_GF.time.head(1) - 1), rate=0
  time=int(extended_GF.time.tail(1) + 1), rate=0


In [2]:
import torchvision
from torchview import draw_graph

In [3]:
import gpustat
gpustat.print_gpustat()

poppy                     Mon Aug  7 21:44:13 2023  525.125.06
[0] NVIDIA A100 80GB PCIe | 32°C,   0 % |  2158 / 81920 MB | murillor(1288M)
[1] NVIDIA A100 80GB PCIe | 50°C,  90 % | 81104 / 81920 MB | chriscs(80234M)
[2] NVIDIA A100 80GB PCIe | 45°C,  94 % | 81092 / 81920 MB | chriscs(80222M)


In [4]:
# this device variable will get used later during training
device = torch.device(f"cuda:{get_idle_gpu()}" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
## CONSTANTS
seed = 11379
num_reps = 1000
sp_name = "HomSap"
chrom = None
length = 10_000
model_name ="OutOfAfrica_3G09"
sample_size=10
HomSap_chr13_OOA = MsprimeSimulation(seed, num_reps, sp_name, model_name, "data/raw", chrom, length, sample_size, n_workers=512)

In [6]:
run_sims(HomSap_chr13_OOA)

In [7]:
# Importing our dataset
dataset = TreeSequencesDataset("data/", HomSap_chr13_OOA.sims_path,seeds=HomSap_chr13_OOA.seed_array, y_name="node-features")

In [8]:
compute_ys(dataset, get_node_features, "node-features", n_workers=256)

In [9]:
from torch.utils.data import random_split

torch.manual_seed(123)
train_set, valid_set, test_set = random_split(dataset,[70*dataset.len()//100, 15*dataset.len()//100, 15*dataset.len()//100])

trainloader = DataLoader(train_set, batch_size=1, shuffle=True)
validloader = DataLoader(valid_set, batch_size=1, shuffle=True)
testloader = DataLoader(test_set, batch_size=1, shuffle=True)

In [10]:
deg = get_degree_histogram(trainloader)

In [11]:
def get_y(batch):
    #return batch.x[:,0].unsqueeze(1)
    return batch.y[0][:,2]

In [12]:
dataset.num_edge_features

2

In [15]:
torch.manual_seed(1793335)
#model = BiGCNModel(dataset.num_features, channels=16, num_layers=4, device=device, num_out_features=1, dropout=0.0)
model = PNANet(input_dim=dataset.num_features, hidden_dim=4, edge_dim=dataset.num_edge_features, num_layers= 1, out_dim=1, deg = deg)
model = model.to(device)


In [16]:
optimizer = torch.optim.Adam(model.parameters(),lr=1e-5)
num_epochs=60
criterion = torch.nn.MSELoss()
last_train_loss = 0

for epoch in range(num_epochs):
    epoch_loss = 0
    model.train()
    num_batches = 0
    for batch in trainloader:
        num_batches+=1
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch)
        #print(output.shape, batch.x[:,0].unsqueeze(1).shape)
        loss = criterion(output,get_y(batch))
        loss.backward()
        epoch_loss += loss.item()
        optimizer.step()
    
    train_avg_loss = epoch_loss / num_batches
    val_loss = 0
    model.eval()
    num_batches = 0
    for batch in validloader:
        num_batches+=1
        batch.to(device)
        output = model(batch)
        loss = criterion(output,get_y(batch))
        val_loss += loss.item()
    val_avg_loss = val_loss/num_batches
    print(f"Epochs: {epoch} | epoch avg. loss: {train_avg_loss:.8f} | validation avg. loss: {val_avg_loss:.8f}")
    if abs(train_avg_loss-last_train_loss) < 1e-3 and epoch > 15: 
        print("early stopping")
        break
    last_train_loss = train_avg_loss

Epochs: 0 | epoch avg. loss: 1.04640994 | validation avg. loss: 1.04026600
Epochs: 1 | epoch avg. loss: 1.03486062 | validation avg. loss: 1.02977730
Epochs: 2 | epoch avg. loss: 1.02534501 | validation avg. loss: 1.02121227
Epochs: 3 | epoch avg. loss: 1.01766113 | validation avg. loss: 1.01439059
Epochs: 4 | epoch avg. loss: 1.01164180 | validation avg. loss: 1.00915479
Epochs: 5 | epoch avg. loss: 1.00713274 | validation avg. loss: 1.00535002
Epochs: 6 | epoch avg. loss: 1.00397062 | validation avg. loss: 1.00280093
Epochs: 7 | epoch avg. loss: 1.00196090 | validation avg. loss: 1.00129112
Epochs: 8 | epoch avg. loss: 1.00086068 | validation avg. loss: 1.00055167
Epochs: 9 | epoch avg. loss: 1.00038145 | validation avg. loss: 1.00028216
Epochs: 10 | epoch avg. loss: 1.00023366 | validation avg. loss: 1.00021881
Epochs: 11 | epoch avg. loss: 1.00020440 | validation avg. loss: 1.00020765
Epochs: 12 | epoch avg. loss: 1.00019573 | validation avg. loss: 1.00019813
Epochs: 13 | epoch avg

In [None]:
batch.edge_index

In [None]:
import torch_geometric.utils.to_dense_adj

In [None]:
ts = tskit.load(dataset.raw_file_names[0])

In [None]:
dataset[0].edge_attr

In [None]:
tree = next(ts.trees())

In [None]:
import networkx as nx
from torch_geometric.utils import from_networkx, to_networkx
import matplotlib.pyplot as plt


In [None]:
G2 = to_networkx(dataset[0])

In [None]:
G = nx.Graph(tree.as_dict_of_dicts())

In [None]:
nx.draw(G2)
plt.show()

In [None]:
from_networkx(G).edge_index

In [None]:
torch_geometric.utils.to_dense_adj(batch.edge_index)

In [None]:
get_y(batch).shape

In [None]:
output.shape

In [None]:
model.eval()
predictions = []
real = []

for batch in testloader:
    
    output = model(batch.to(device))
    predictions.append(output.detach().cpu().numpy())
    real.append(get_y(batch).detach().cpu().numpy())

predictions = np.concatenate(predictions)
real = np.concatenate(real)

In [None]:
#torch.save(model.encoder.state_dict(), "trained_encoder.pt")

In [None]:
model.encoder

In [None]:
predictions.shape

In [None]:
real.shape

In [None]:
batch = next(iter(testloader))

In [None]:
batch.to(device)
output = model(batch)
criterion(output, get_y(batch))

Visualizing `Predicted~Observed` diversity within a single tree sequence

In [None]:
plt.scatter(get_y(batch).detach().cpu().numpy(), output.detach().cpu().numpy())
#plt.axline((0.5,0.5), slope=1)

Now across all windows and tree sequences

In [None]:
import scipy.stats

In [None]:
# calculate r^2 of real and predictions
scipy.stats.pearsonr(real.flatten(), predictions.flatten())[0]**2

In [None]:


plt.scatter(real, predictions, alpha=0.3)
#plt.axline((0.5,0.5), slope=1)

plt.ylabel('Predicted diversity')
plt.xlabel('Observed diversity')