In [10]:
import sys

sys.path.append("/home/ubuntu/graphseq-inference/")
from graphseq_inference.data_utils import *
from graphseq_inference.models import *
from graphseq_inference.train_utils import *

In [3]:
%load_ext autoreload
%autoreload 2

In [11]:
import matplotlib.pyplot as plt

In [8]:
#rm -rf ./20k_dataset/.ipynb_checkpoints

In [None]:
def RMSELoss(yhat, y, eps=1e-12):
    """Root-mean-square error between predictions and targets.

    Args:
        yhat: Predicted tensor.
        y: Target tensor; combined with ``yhat`` by element-wise subtraction.
        eps: Small constant added to the mean squared error before taking the
            square root. The derivative of ``sqrt`` at 0 is infinite, so an
            exactly-zero error would otherwise back-propagate NaN gradients
            (inf * 0). The training loop in this notebook zeroes masked
            entries of both tensors, so a fully masked sample hits that case.
            Pass ``eps=0`` to get the plain, unstabilised RMSE.

    Returns:
        Scalar tensor ``sqrt(mean((yhat - y) ** 2) + eps)``.
    """
    mse = torch.mean((yhat - y) ** 2)
    return torch.sqrt(mse + eps)


# Loss function used by the training loop.
criterion = RMSELoss

In [None]:
def convert_tree_sequence_to_data_object(tree_sequence: tskit.trees.TreeSequence,
                                         parameters: np.ndarray,
                                         num_trees: int = 500,
                                         num_embedding: int = 60,
                                         ):
    """Turn the leading trees of a tree sequence into graph data objects.

    Each of the first ``num_trees`` trees is converted to a graph via
    ``from_networkx`` and given:

    * ``edge_weight`` -- the attribute originally named ``branch_length``;
    * ``x`` -- a ``(2 * num_samples - 1, num_embedding)`` feature matrix that
      starts as ``torch.eye`` and has every row from the tree's own node
      count onwards set to zero;
    * ``y`` -- the natural log of the values ``pop_size_0`` .. ``pop_size_59``
      taken from ``parameters``;
    * ``num_nodes`` -- overwritten with ``2 * num_samples - 1``.

    Args:
        tree_sequence: Tree sequence whose trees are converted.
        parameters: Labelled record containing ``pop_size_0`` .. ``pop_size_59``.
            NOTE(review): annotated as ``np.ndarray`` but sliced by string
            label, and the training loop passes a DataFrame row -- this looks
            like a pandas Series; confirm and fix the annotation.
        num_trees: Maximum number of trees to convert.
        num_embedding: Number of columns of the node feature matrix.

    Returns:
        List of data objects, one per converted tree, in tree order.
    """
    target = torch.Tensor(parameters["pop_size_0":"pop_size_59"].tolist())

    # Common node count that every graph is padded to.
    node_capacity = 2 * tree_sequence.num_samples - 1

    graphs = []
    for index, tree in enumerate(tree_sequence.trees()):
        # Stop as soon as the requested number of trees has been collected.
        if index >= num_trees:
            break

        graph = from_networkx(nx.Graph(tree.as_dict_of_dicts()))
        rename_data_attribute(graph, "branch_length", "edge_weight")

        # Read the tree's own node count before it is overwritten below.
        real_nodes = graph.num_nodes

        features = torch.eye(node_capacity, num_embedding)
        features[real_nodes:] = torch.zeros(num_embedding)
        graph.x = features

        graph.y = torch.Tensor(torch.log(target))
        graph.num_nodes = node_capacity

        graphs.append(graph)

    return graphs

In [13]:
from scipy.interpolate import interp1d
# Rejection-sample a smooth population-size trajectory: draw `steps` knot
# values, fit a cubic interpolant in log-log space, and redraw until the whole
# interpolated curve stays within [10_000, 10_000_000].
steps = 18
while True:
    x = np.log(get_population_time(time_rate=0.1, num_time_windows=steps, tmax=10_000_000).tolist())
    y = np.log(sample_population_size(10_000, 10_000_000, steps))

    # Dense grid between the first and last knot, on the log-time axis.
    xnew = np.linspace(x[0], x[-1], num=10000, endpoint=True)
    f_cubic = interp1d(x, y, kind='cubic')
    ynew = f_cubic(xnew)

    # The cubic fit can overshoot the sampling range between knots.
    curve = np.exp(ynew)
    upper_out_of_bound = np.any(curve > 10_000_000)
    lower_out_of_bound = np.any(curve < 10_000)
    if not (upper_out_of_bound or lower_out_of_bound):
        break

# 60 evenly spaced positions on the dense grid, starting at index 10.
sample_idx = np.linspace(10, 9999, 60).astype(int)
x_sample = xnew[sample_idx]
y_sample = ynew[sample_idx]

# Back from log-time to the original time scale; read by the training loop.
population_time = np.exp(x_sample)

In [221]:
# Table of simulated demographies; the training loop selects one row per
# sample file with `parameters.iloc[nth_scenario]`.
DEMOGRAPHY_CSV = "20k_seed_0x1337_demographies.csv"
parameters = pd.read_csv(DEMOGRAPHY_CSV)

In [None]:
directory = "./20k_dataset/"
# Keep only regular, non-hidden files. os.listdir also returns entries such as
# Jupyter's ".ipynb_checkpoints" directory, which cannot be loaded with
# torch.load and whose name does not carry a scenario index.
# Sorting makes the listing reproducible, since os.listdir returns entries in
# arbitrary order.
files = sorted(
    directory + name
    for name in os.listdir(directory)
    if not name.startswith(".") and os.path.isfile(directory + name)
)
len(files)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# NOTE(review): the four positional sizes are not explained anywhere in this
# notebook -- confirm their meaning against the DiffPoolNet definition.
model = DiffPoolNet(19, 60, 192, 60)
model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Weights are (re-)initialised after the optimizer is built, as in the
# original cell; the optimizer still sees the same parameter objects.
initialize_weights(model)

In [None]:
# No validation loop here: the training dataset is very large, and
# validation is done later on specifically chosen scenarios.

In [None]:
# Number of values per target row; used to reshape the target into
# (number of graphs, length).
length = 60
# Per-sample losses accumulated since the last log line.
loss_all = []

# Output locations (file names unchanged). The directories are created up
# front because torch.save does not create missing parents and the first
# intermediate checkpoint is only written after 100000 steps, so a missing
# directory would otherwise abort the run hours in.
intermediate_prefix = "./large_models_demo_nescaling/mmc_diffpool_model_demography_inference_intermediate"
final_prefix = "./trained_models/demography_models//mmc_diffpool_model_demography_inference"
log_path = intermediate_prefix + ".txt"
for prefix in (intermediate_prefix, final_prefix):
    os.makedirs(os.path.dirname(prefix), exist_ok=True)

for epoch in range(0, 2):
    np.random.shuffle(files)
    for i in tqdm(range(0, len(files))):

        file = files[i]
        # NOTE(review): torch.load unpickles the file contents -- only run
        # this on trusted dataset files.
        ts, mask = torch.load(file)

        # The third "_"-separated field of the full path (directory prefix
        # included) is used as the row index into `parameters`.
        nth_scenario = int(file.split("_")[2])

        parameter_set = parameters.iloc[nth_scenario]
        data_objects = convert_tree_sequence_to_data_object(ts, parameter_set)

        # Samples with at most one tree are skipped entirely: no update and
        # no periodic logging for this step.
        if len(data_objects) <= 1:
            continue

        num_graphs = len(data_objects)

        # Mask out every position whose time is <= 10, then repeat the mask
        # once per graph so it lines up with the model output.
        mask[population_time <= 10] = False
        mask = torch.tile(torch.Tensor(mask), (num_graphs, 1))
        mask = mask.to(device).bool()

        optimizer.zero_grad()
        # A single batch holding every tree of this sample.
        dl = DataLoader(data_objects, batch_size=num_graphs)
        for batch in dl:
            batch = batch.to(device)
            y_hat = model(batch)
            # Every graph of a sample carries the same target, so the first
            # one is repeated once per graph.
            y_true = data_objects[0].y.tile(num_graphs).reshape(num_graphs, length).to(device)

            # Zero masked positions in both tensors so they add no error.
            # They still count towards the mean inside the criterion.
            y_true[~mask] = 0
            y_hat[~mask] = 0
            loss = criterion(y_hat, y_true)
            loss.backward()
            loss_all.append(loss.item())
            optimizer.step()

        if i != 0 and i % 10000 == 0:
            # Separate name for the mean so `loss_all` always stays a list.
            mean_loss = np.mean(loss_all)
            print(f"loss {mean_loss}")
            if i % 100000 == 0:
                torch.save(model.state_dict(), intermediate_prefix + str(epoch) + "_" + str(i) + ".pth")
            # Append the log line directly instead of shelling out to `echo`,
            # so a write failure raises rather than passing silently.
            with open(log_path, "a") as log_file:
                log_file.write(f"Epoch: {epoch:03d}, Train Loss: {mean_loss:.4f}\n")
            loss_all = []

    # End-of-epoch checkpoint, tagged with the epoch and the last step index.
    torch.save(model.state_dict(), final_prefix + str(epoch) + "_" + str(i) + ".pth")

  0%|                                | 1675/2000000 [14:00<258:21:01,  2.15it/s]