In [5]:
%load_ext autoreload
%autoreload 2

import sys

sys.path.append("/home/ubuntu/graphseq-inference/")
from graphseq_inference.data_utils import *
from graphseq_inference.models import *
from graphseq_inference.train_utils import *



In [6]:
directory = "./20k_dataset/"

# Build the path of every entry in the dataset directory by prefixing the
# directory string (which already ends with a slash) to each entry name.
files = [directory + name for name in os.listdir(directory)]

len(files)

FileNotFoundError: [Errno 2] No such file or directory: './20k_dataset/'

In [14]:
def reduce_tree_sequence(ts, num_samples):
    """Simplify ``ts`` down to a random subset of its sample nodes.

    Args:
        ts: Tree sequence exposing ``samples()`` and ``simplify(samples)``.
        num_samples: How many sample nodes to keep; they are drawn without
            replacement using NumPy's global random state.

    Returns:
        Whatever ``ts.simplify`` returns for the chosen sample node IDs.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when ``num_samples``
            exceeds the number of available samples.
    """
    # simplify() is given sample *node IDs*, so draw from ts.samples() instead
    # of assuming the IDs are exactly 0..num_samples-1. When they are, the
    # draw is the same as indexing range(ts.num_samples).
    chosen = np.random.choice(ts.samples(), num_samples, replace=False)
    return ts.simplify(chosen.tolist())

In [15]:
def convert_tree_sequence_to_data_object_alpha(tree_sequence: tskit.trees.TreeSequence,
                                                     parameters: np.ndarray,
                                                     num_trees:int = 500,
                                                     num_embedding:int = 60, 
                           ):
    """Turn the first ``num_trees`` trees of a tree sequence into graph data objects.

    Each tree is converted through ``tree.as_dict_of_dicts()`` ->
    ``nx.Graph`` -> ``from_networkx``. Its ``branch_length`` attribute is
    renamed to ``edge_weight``, and every object gets the same label ``y``.

    Args:
        tree_sequence: Source tree sequence.
        parameters: Object whose ``.model`` attribute is used as the label.
            The annotation says ``np.ndarray``, but the code only needs the
            ``.model`` attribute; the training loop in this notebook passes a
            DataFrame row.
        num_trees: Maximum number of trees to convert.
        num_embedding: Number of columns of the node feature matrix ``x``.

    Returns:
        A list with one data object per converted tree. ``x`` has
        ``2 * num_samples - 1`` rows: the first ``data.num_nodes`` rows are
        taken from an identity-like matrix, the remaining rows are zero.
    """
    # Read the label from the argument. The previous version read the notebook
    # global `parameter_set`, so the result depended on hidden kernel state
    # rather than on what the caller passed in.
    alpha = parameters.model
    y = torch.Tensor([alpha])

    max_num_nodes = 2 * tree_sequence.num_samples - 1
    data_objects = []

    for i, tree in enumerate(tree_sequence.trees()):
        if i >= num_trees:
            break

        data = from_networkx(nx.Graph(tree.as_dict_of_dicts()))
        rename_data_attribute(data, "branch_length", "edge_weight")

        # One-hot style features for the nodes that exist, zero rows as
        # padding up to the fixed maximum node count.
        num_nodes = data.num_nodes
        data.x = torch.eye(max_num_nodes, num_embedding)
        data.x[num_nodes:] = torch.zeros(num_embedding)
        data.y = torch.Tensor(y)
        data.num_nodes = max_num_nodes
        data_objects.append(data)

    return data_objects

In [19]:
# Load the parameter table from a CSV in the current working directory.
# The training loop below selects one row per file with `.iloc[...]`, and the
# converter reads that row's `model` attribute as the regression target.
parameters = pd.read_csv("20k_seed_0x1337_demographies.csv")

In [6]:
def get_alpha(file):
    """Parse the trailing numeric field of a ``*.trees`` file name.

    The text after the last underscore is taken, every ``.trees`` occurrence
    in it is removed, and the remainder is converted to ``float``.
    """
    last_field = file.rsplit("_", 1)[-1]
    return float(last_field.replace(".trees", ""))

In [39]:
from os.path import basename

class AlphaInferenceModel(nn.Module):
    """Regression head stacked on top of a feature-extracting network.

    The wrapped network's output is passed through three linear layers of
    widths ``time_window -> time_window // 2 -> time_window // 4 -> 1`` with
    ReLU after the first two.
    """

    def __init__(self, DemographyNet, time_window=60):
        super().__init__()
        half_width = time_window // 2
        quarter_width = time_window // 4
        # Layers are created in the order l1, l2, l3 and keep these attribute
        # names so that parameter initialisation and state_dict keys stay the same.
        self.l1 = nn.Linear(time_window, half_width)
        self.l2 = nn.Linear(half_width, quarter_width)
        self.l3 = nn.Linear(quarter_width, 1)
        self.DemographyNet = DemographyNet

    def forward(self, batch):
        """Run the wrapped network on ``batch`` and map its output to one value per row."""
        hidden = self.DemographyNet(batch)
        hidden = F.relu(self.l1(hidden))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)


    
# Use the GPU when torch reports one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Feature extractor wrapped by the regression head; the head's default
# time_window is 60.
demography_net = DiffPoolNet(19, 60, 192, 60, track_running_stats=False)
model = AlphaInferenceModel(demography_net).to(device)

criterion = RMSELoss
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)




In [8]:
def preorder_dist(tree):
    """List every node of ``tree`` with its incoming edge and endpoint times.

    Nodes are visited with an explicit stack, starting from each root in
    ``tree.roots`` in turn.

    Args:
        tree: Object exposing ``roots``, ``children(u)`` and ``get_time(u)``.

    Returns:
        A list of ``(node, (parent, node), (time_parent, time_node))`` tuples.
        A root entry uses ``-1`` as its parent and ``(root_time, 0)`` as its
        time pair.
    """
    result = []
    for root in tree.roots:
        # Use the time of the root being visited. `tree.root` is only defined
        # for single-root trees, so reading it here broke multi-root trees.
        stack = [(root, (-1, root), (tree.get_time(root), 0))]
        while stack:
            u, pc, time = stack.pop()
            result.append((u, pc, time))
            for v in tree.children(u):
                stack.append((v, (u, v), (tree.get_time(u), tree.get_time(v))))
    return result


def multiple_mergerized_to_data_object(result):
    """Build a graph data object from a list of edge entries.

    Every entry except the last one contributes an undirected edge
    ``(parent, node)`` weighted by ``time_parent - time_node``. The graph is
    converted with ``from_networkx`` and its ``weight`` attribute is renamed
    to ``edge_weight``.
    """
    weighted_edges = [
        (parent, child, t_parent - t_child)
        for _, (parent, child), (t_parent, t_child) in result[:-1]
    ]

    graph = nx.Graph()
    graph.add_weighted_edges_from(weighted_edges)

    data = from_networkx(graph)
    rename_data_attribute(data, "weight", "edge_weight")
    return data


def add_x(data, num_embedding = 60, num_samples = 10):
    """Attach a padded identity-like feature matrix ``x`` to ``data`` in place.

    Args:
        data: Object with a ``num_nodes`` attribute; it is modified and returned.
        num_embedding: Number of feature columns.
        num_samples: Number of samples per tree. It fixes the padded node
            count at ``2 * num_samples - 1``. The default of 10 reproduces the
            previously hard-coded value.

    Returns:
        The same ``data`` object, with ``x`` set and ``num_nodes`` overwritten
        by the padded node count.
    """
    max_num_nodes = 2 * num_samples - 1
    num_nodes = data.num_nodes
    data.x = torch.eye(max_num_nodes, num_embedding)
    # Rows beyond the nodes actually present are zero padding.
    data.x[num_nodes:] = 0
    data.num_nodes = max_num_nodes
    return data

def ts_to_data_objects(ts):
    """Convert every tree of ``ts`` into a graph data object.

    Each tree is listed with ``preorder_dist``, the list is reversed, passed
    through ``restructure_result`` and finally turned into a data object.
    """

    def _tree_to_data(tree):
        # Reverse the traversal before merging, as the pipeline expects.
        entries = preorder_dist(tree)[::-1]
        merged = restructure_result(entries, threshold=10)
        return multiple_mergerized_to_data_object(merged)

    return [_tree_to_data(tree) for tree in ts.trees()]



def restructure_result(result, threshold = 2):
    """Collapse short branches in a list of edge entries.

    Entries are ``(id, (parent, node), (time_parent, time_node))`` tuples.
    The list is scanned from the start for the first entry whose branch length
    ``time_parent - time_node`` is non-zero, below a cutoff chosen from
    ``time_parent``, and whose ``parent`` differs from ``node``. That entry is
    turned into a ``(parent, parent)`` self-loop, and every entry that
    references ``node`` (as child or as parent) is re-pointed to ``parent``
    with ``time_parent``. The scan then restarts, until no entry qualifies.

    ``result`` is modified in place. The returned list is a new list holding
    the entries whose two edge endpoints differ.

    NOTE(review): the ``threshold`` argument is never used; the cutoff is
    always derived from ``time_parent``. Confirm whether that is intended.
    """

    def _cutoff_for(parent_time):
        # Cutoff ladder keyed on the parent's time. The top band (>= 2000)
        # uses 20, the band below it uses 200.
        if parent_time < 5:
            return 0.5
        if parent_time < 10:
            return 1
        if parent_time < 20:
            return 2
        if parent_time < 200:
            return 20
        if parent_time < 2000:
            return 200
        return 20

    def _collapse_first_short_branch():
        # Returns True when one branch was merged, False when none qualifies.
        for pos, (_, (parent, node), (time_parent, time_node)) in enumerate(result):
            span = time_parent - time_node
            cutoff = _cutoff_for(time_parent)
            if not (span < cutoff) or span == 0 or parent == node:
                continue

            result[pos] = (-1, (parent, parent), (time_parent, time_parent))
            for k, (_, (p, n), (tp, tn)) in enumerate(result):
                # Both checks use the entry as it was read for this position.
                if node == n:
                    result[k] = (-1, (p, parent), (tp, time_parent))
                if node == p:
                    result[k] = (-1, (parent, n), (time_parent, tn))
            return True
        return False

    # Restart the scan from the beginning after every merge.
    while _collapse_first_short_branch():
        pass

    return [(a, (b, c), (d, e)) for a, (b, c), (d, e) in result if b != c]

In [None]:
length = 1
loss_all = []

checkpoint_dir = "./alpha_inf"
checkpoint_prefix = checkpoint_dir + "/mmc_diffpool_model_alpha_inference_intermediate"
log_path = checkpoint_prefix + ".txt"

# Create the output directory up front so that the first checkpoint, which is
# only written after many iterations, cannot fail on a missing directory.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(0, 10):
    np.random.shuffle(files)
    for i, file in enumerate(tqdm(files)):

        # NOTE(review): torch.load unpickles the file, so only load files from
        # a trusted source. `mask` is loaded but not used below.
        ts, mask = torch.load(file)

        # NOTE(review): the scenario index is the third "_"-separated field of
        # the full path, so underscores in `directory` shift it. Confirm
        # against the file naming scheme.
        nth_scenario = int(file.split("_")[2])
        parameter_set = parameters.iloc[nth_scenario]
        data_objects = convert_tree_sequence_to_data_object_alpha(ts, parameter_set)

        if len(data_objects) > 1:

            optimizer.zero_grad()
            # A single batch holding every tree of this tree sequence.
            dl = DataLoader(data_objects, batch_size=len(data_objects))
            for batch in dl:
                batch = batch.to(device)
                y_hat = model(batch)
                # Every tree of the sequence shares the same label.
                y_true = data_objects[0].y.tile(len(data_objects)).reshape(len(data_objects), length).to(device)

                loss = criterion(y_hat, y_true)
                loss.backward()
                loss_all.append(loss.item())
                optimizer.step()

            if i != 0 and i % 10000 == 0:
                # Keep `loss_all` a list; the mean gets its own name.
                mean_loss = np.mean(loss_all)
                print(f"loss {mean_loss}")
                torch.save(model.state_dict(), checkpoint_prefix + str(epoch) + "_" + str(i) + ".pth")
                # Append to the log file directly instead of shelling out to `echo`.
                with open(log_path, "a") as log_file:
                    log_file.write(f"Epoch: {epoch:03d}, Train Loss: {mean_loss:.4f}\n")
                loss_all = []

torch.save(model.state_dict(), checkpoint_prefix + str(epoch) + "_" + str(i) + ".pth")

  0%|                                 | 1200/1000000 [08:06<92:48:05,  2.99it/s]