# Import Packages and environmental setup

In [76]:
import numpy as np
import torch
torch.manual_seed(0)
import torch.nn as nn
from torch_geometric.data import Data
import itertools
import json
import sys
import time
from tqdm import tqdm
import os
from os import path
sys.path.insert(0, '../')
import gc
import torch_geometric.transforms as T
from torch_geometric.nn.conv import TransformerConv
from torch_geometric.nn import VGAE
from torch_geometric.loader import DataLoader
from torch_geometric.utils import batched_negative_sampling
gc.collect()

1801

# Functions

In [77]:
# function to convert string to numbers
def convert_string_to_numbers(str, dict):
    ''' str: string to convert
        dict dictionary with the relative ordering of each char'''
            # create a map iterator using a lambda function
    # lambda x -> return dict[x]
    # This return the value for each key in dict based on str
    numbers = map(lambda x: dict[x], str)
    # return an array of int64 numbers
    return np.fromiter(numbers, dtype=np.int64)

In [78]:
# function to create a graph for each 
def construct_single_graph(idx, label):
    ''' idx: the current graph index w.r.t the label
        label: the current label'''
    # transform the character of amino acid in to numbers for all 5 sequences in this graph
    transformed_x = []
    for i in range(5):
        # get the index of the sequence from the original dataset
        seq_idx = 5*idx + i
        transformed_x.append(convert_string_to_numbers(seq_string[seq_idx][:-1], dict_amino))
        
    # set feature vectors of internal nodes to -1 with same length
    vec_len = len(transformed_x[0])
    internal_node_vec = np.full(vec_len, -1, dtype=np.int64)
    # append the three internal node
    for i in range(3):
        transformed_x.append(internal_node_vec)
    # create the node feature vector
    x = torch.tensor(transformed_x, dtype=torch.float)
    
    # now we create the edge set w.r.t the label
    # This part is quite dumb as I'm hard coding the 15 edge set
    if label == 0:
        edge_index = torch.tensor([[0,5],[5,0],[1,5],[5,1],
                                   [5,6],[6,5],[4,6],[6,4],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [3,7],[7,3]], dtype=torch.long)
    elif label == 1:
        edge_index = torch.tensor([[0,5],[5,0],[1,5],[5,1],
                                   [5,6],[6,5],[3,6],[6,3],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [4,7],[7,4]], dtype=torch.long)
    elif label == 2:
        edge_index = torch.tensor([[0,5],[5,0],[1,5],[5,1],
                                   [5,6],[6,5],[2,6],[6,2],
                                   [6,7],[7,6],[3,7],[7,3],
                                   [4,7],[7,4]], dtype=torch.long)
    elif label == 3:
        edge_index = torch.tensor([[0,5],[5,0],[2,5],[5,2],
                                   [5,6],[6,5],[4,6],[6,4],
                                   [6,7],[7,6],[3,7],[7,3],
                                   [1,7],[7,1]], dtype=torch.long)
    elif label == 4:
        edge_index = torch.tensor([[0,5],[5,0],[2,5],[5,2],
                                   [5,6],[6,5],[3,6],[6,3],
                                   [6,7],[7,6],[4,7],[7,4],
                                   [1,7],[7,1]], dtype=torch.long)
    elif label == 5:
        edge_index = torch.tensor([[0,5],[5,0],[2,5],[5,2],
                                   [5,6],[6,5],[1,6],[6,1],
                                   [6,7],[7,6],[4,7],[7,4],
                                   [3,7],[7,3]], dtype=torch.long)
    elif label == 6:
        edge_index = torch.tensor([[0,5],[5,0],[3,5],[5,3],
                                   [5,6],[6,5],[4,6],[6,4],
                                   [6,7],[7,6],[1,7],[7,1],
                                   [2,7],[7,2]], dtype=torch.long)
    elif label == 7:
        edge_index = torch.tensor([[0,5],[5,0],[3,5],[5,3],
                                   [5,6],[6,5],[2,6],[6,2],
                                   [6,7],[7,6],[1,7],[7,1],
                                   [4,7],[7,4]], dtype=torch.long)
    elif label == 8:
        edge_index = torch.tensor([[0,5],[5,0],[3,5],[5,3],
                                   [5,6],[6,5],[1,6],[6,1],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [4,7],[7,4]], dtype=torch.long)
    elif label == 9:
        edge_index = torch.tensor([[0,5],[5,0],[4,5],[5,4],
                                   [5,6],[6,5],[3,6],[6,3],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [1,7],[7,1]], dtype=torch.long)
    elif label == 10:
        edge_index = torch.tensor([[0,5],[5,0],[4,5],[5,4],
                                   [5,6],[6,5],[2,6],[6,2],
                                   [6,7],[7,6],[3,7],[7,3],
                                   [1,7],[7,1]], dtype=torch.long)
    elif label == 11:
        edge_index = torch.tensor([[0,5],[5,0],[4,5],[5,4],
                                   [5,6],[6,5],[1,6],[6,1],
                                   [6,7],[7,6],[3,7],[7,3],
                                   [2,7],[7,2]], dtype=torch.long)
    elif label == 12:
        edge_index = torch.tensor([[1,5],[5,1],[2,5],[5,2],
                                   [5,6],[6,5],[0,6],[6,0],
                                   [6,7],[7,6],[3,7],[7,3],
                                   [4,7],[7,4]], dtype=torch.long)
    elif label == 13:
        edge_index = torch.tensor([[1,5],[5,1],[3,5],[5,3],
                                   [5,6],[6,5],[0,6],[6,0],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [4,7],[7,4]], dtype=torch.long)
    else:
        edge_index = torch.tensor([[1,5],[5,1],[4,5],[5,4],
                                   [5,6],[6,5],[0,6],[6,0],
                                   [6,7],[7,6],[2,7],[7,2],
                                   [3,7],[7,3]], dtype=torch.long)
    
    # Now we create the graph object as Data
    data = Data(x=x, edge_index=edge_index.t().contiguous())
    return data

# File inputs

In [79]:
# get name of the script
# nameScript = sys.argv[0].split('/')[-1]
nameScript = "gae_model.py"
# get json file name of the script
nameJson = "gae.json"
# nameJson = sys.argv[1]
print("------------------------------------------------------------------------")
print("Training the Garph Auto Encoder for 5-taxa dataset")
print("------------------------------------------------------------------------")
print("Executing " + nameScript + " following " + nameJson, flush = True)

# opening Json file 
jsonFile = open(nameJson) 
dataJson = json.load(jsonFile)

# loading the input data from the json file
ngpu = dataJson["ngpu"]                  # number of GPUS
lr = dataJson["lr"]                      # learning rate
embedSize = dataJson["embedSize"]        # Embedding size
nEpochs = dataJson["nEpochs"]            # Number of Epochs
batchSize = dataJson["batchSize"]        # batchSize


data_root = dataJson["dataRoot"]         # data folder
model_root = dataJson["modelRoot"]       # folder to save the data

label_files = dataJson["labelFile"]      # file with labels
sequence_files = dataJson["matFile"]     # file with sequences

if "summaryFile" in dataJson:
    summary_file = dataJson["summaryFile"]
else :
    summary_file = "summary_file.txt"


print("------------------------------------------------------------------------")
print("Loading Sequence Data in " + sequence_files, flush = True)
print("Loading Label Data in " + label_files, flush = True)

# we read the labels as list of strings
with open(data_root+label_files, 'r') as f:
    label_char = f.readlines()

# we read the sequence as a list of strings
with open(data_root+sequence_files, 'r') as f:
    seq_string = f.readlines()

n_samples = len(label_char)
seq_length = len(seq_string[0])-1
print("Number of samples:{}; Sequence length of each sample:{}"
        .format(n_samples, seq_length))
print("------------------------------------------------------------------------")

------------------------------------------------------------------------
Training the Garph Auto Encoder for 5-taxa dataset
------------------------------------------------------------------------
Executing gae_model.py following gae.json
------------------------------------------------------------------------
Loading Sequence Data in sequences12062021.in
Loading Label Data in labels12062021.in
Number of samples:10000; Sequence length of each sample:1550
------------------------------------------------------------------------


# Data pre-processing

In [80]:
# We need to extract the dictionary with the relative positions
# for each aminoacid

# first we need to extract all the different chars
strL = ""
for c in seq_string[0][:-1]:
    if not c in strL:
        strL += c

# we sort them
strL = sorted(strL)

# we give them a relative order
dict_amino = {}
for ii, c in enumerate(strL):
    dict_amino[c] = ii

# looping over the labels and create array. Here each element of the
# label_char has the form "1\n", so we only take the first one
labels = np.fromiter(map(lambda x: int(x[0])-1,
                         label_char), dtype= np.int64)

In [81]:
# Create all graphs from raw dataset
# empty dataset for all graphs
dataset = []
for i in range(n_samples):
    data = construct_single_graph(i, labels[i])
    if (not data.validate(raise_on_error=True)):
        print("Error! Node number and edge set does not match!")
        break
    if (not data.is_undirected()):
        print("Error! Incorrect edge set!")
        break
    dataset.append(data)

# Model

In [82]:
class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(VariationalGCNEncoder, self).__init__()
        self.conv1 = TransformerConv(in_channels, 3 * out_channels, heads=4, concat=False, beta=True)
        self.conv_mu = TransformerConv(3 * out_channels, out_channels, heads=4, concat=False, beta=True)
        self.conv_logstd = TransformerConv(3 * out_channels, out_channels, heads=4, concat=False, beta=True)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [83]:
in_channels = seq_length
out_channels = embedSize

model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
model = model.to(device)

# Training

In [84]:
train_dataset = dataset[:9000]
test_dataset = dataset[9000:]
train_loader = DataLoader(train_dataset, batch_size=batchSize, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batchSize, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [85]:
# Train
def run_one_epoch(data_loader):
    for _, batch in enumerate(tqdm(data_loader)):
        batch.to(device) 
        optimizer.zero_grad() 
        batch_neg_edge = batched_negative_sampling(batch.edge_index, batch.batch)
        z = model.encode(batch.x, batch.edge_index)
        loss = model.recon_loss(z, batch.edge_index, batch_neg_edge)
        loss = loss + (1 / batch.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()
    return float(loss)

#def test(data_loader):
#    for _, batch in enumerate(tqdm(data_loader)):
#        with torch.no_grad():
#            z = model.encode(x, train_pos_edge_index)
#    return model.test(z, pos_edge_index, neg_edge_index)

In [86]:
for epoch in range(1, nEpochs+1):
    model.train()
    loss = run_one_epoch(train_loader)
    print(loss)

100%|█████████████████████████████████████████| 282/282 [00:47<00:00,  6.00it/s]


2.739959239959717


100%|█████████████████████████████████████████| 282/282 [01:07<00:00,  4.19it/s]


2.5902953147888184


100%|█████████████████████████████████████████| 282/282 [01:23<00:00,  3.38it/s]


2.453690528869629


100%|█████████████████████████████████████████| 282/282 [01:17<00:00,  3.62it/s]


2.5720934867858887


100%|█████████████████████████████████████████| 282/282 [01:33<00:00,  3.01it/s]


2.686994791030884


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.58it/s]


2.5583713054656982


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.60it/s]


2.5617260932922363


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.54it/s]


2.5496678352355957


100%|█████████████████████████████████████████| 282/282 [01:13<00:00,  3.84it/s]


2.5518059730529785


100%|█████████████████████████████████████████| 282/282 [01:09<00:00,  4.08it/s]


2.5417938232421875


100%|█████████████████████████████████████████| 282/282 [01:17<00:00,  3.62it/s]


2.553239107131958


100%|█████████████████████████████████████████| 282/282 [00:59<00:00,  4.72it/s]


2.4066555500030518


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.54it/s]


2.4194602966308594


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.53it/s]


2.5908684730529785


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.61it/s]


2.463777780532837


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.54it/s]


2.4508845806121826


100%|█████████████████████████████████████████| 282/282 [01:05<00:00,  4.31it/s]


2.5266098976135254


100%|█████████████████████████████████████████| 282/282 [01:14<00:00,  3.79it/s]


2.352996826171875


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.51it/s]


2.5209977626800537


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.48it/s]


2.470663547515869


100%|█████████████████████████████████████████| 282/282 [01:00<00:00,  4.63it/s]


2.4318885803222656


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.48it/s]


2.4952921867370605


100%|█████████████████████████████████████████| 282/282 [01:22<00:00,  3.40it/s]


2.4877803325653076


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.51it/s]


2.454202651977539


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.52it/s]


2.4832112789154053


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.49it/s]


2.458240270614624


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.54it/s]


2.3802433013916016


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.56it/s]


2.4667422771453857


100%|█████████████████████████████████████████| 282/282 [01:22<00:00,  3.41it/s]


2.4021830558776855


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.50it/s]


2.395709753036499


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.48it/s]


2.3775384426116943


100%|█████████████████████████████████████████| 282/282 [01:04<00:00,  4.37it/s]


2.304204225540161


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.52it/s]


2.398435592651367


100%|█████████████████████████████████████████| 282/282 [01:22<00:00,  3.42it/s]


2.4413442611694336


100%|█████████████████████████████████████████| 282/282 [01:00<00:00,  4.64it/s]


2.4598278999328613


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.41it/s]


2.506908893585205


100%|█████████████████████████████████████████| 282/282 [01:10<00:00,  4.01it/s]


2.4965782165527344


100%|█████████████████████████████████████████| 282/282 [01:08<00:00,  4.09it/s]


2.481613874435425


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.52it/s]


2.406102180480957


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.51it/s]


2.443437099456787


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.48it/s]


2.474085807800293


100%|█████████████████████████████████████████| 282/282 [01:17<00:00,  3.64it/s]


2.456878185272217


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.50it/s]


2.466817855834961


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.41it/s]


2.499765396118164


100%|█████████████████████████████████████████| 282/282 [01:16<00:00,  3.69it/s]


2.6124067306518555


100%|█████████████████████████████████████████| 282/282 [01:05<00:00,  4.32it/s]


2.421323299407959


100%|█████████████████████████████████████████| 282/282 [01:27<00:00,  3.21it/s]


2.393373966217041


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.51it/s]


2.4429054260253906


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.48it/s]


2.4697179794311523


100%|█████████████████████████████████████████| 282/282 [01:05<00:00,  4.31it/s]


2.4641785621643066


100%|█████████████████████████████████████████| 282/282 [01:15<00:00,  3.72it/s]


2.5014052391052246


100%|█████████████████████████████████████████| 282/282 [01:17<00:00,  3.64it/s]


2.523127555847168


100%|█████████████████████████████████████████| 282/282 [01:00<00:00,  4.66it/s]


2.4435782432556152


100%|█████████████████████████████████████████| 282/282 [01:14<00:00,  3.80it/s]


2.447129487991333


100%|█████████████████████████████████████████| 282/282 [01:08<00:00,  4.10it/s]


2.489325761795044


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.45it/s]


2.5426998138427734


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.53it/s]


2.4346137046813965


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.48it/s]


2.36669659614563


100%|█████████████████████████████████████████| 282/282 [01:00<00:00,  4.69it/s]


2.5046184062957764


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.49it/s]


2.4565176963806152


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.59it/s]


2.403986930847168


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.52it/s]


2.366711378097534


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.51it/s]


2.400484085083008


100%|█████████████████████████████████████████| 282/282 [01:14<00:00,  3.81it/s]


2.4267983436584473


100%|█████████████████████████████████████████| 282/282 [01:07<00:00,  4.16it/s]


2.527409076690674


100%|█████████████████████████████████████████| 282/282 [01:22<00:00,  3.43it/s]


2.3454694747924805


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.61it/s]


2.4541029930114746


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.51it/s]


2.418696165084839


100%|█████████████████████████████████████████| 282/282 [01:00<00:00,  4.63it/s]


2.4071578979492188


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.56it/s]


2.51910400390625


100%|█████████████████████████████████████████| 282/282 [01:10<00:00,  3.98it/s]


2.4732367992401123


100%|█████████████████████████████████████████| 282/282 [01:09<00:00,  4.05it/s]


2.4560706615448


100%|█████████████████████████████████████████| 282/282 [01:17<00:00,  3.63it/s]


2.516623020172119


100%|█████████████████████████████████████████| 282/282 [01:22<00:00,  3.43it/s]


2.4932451248168945


100%|█████████████████████████████████████████| 282/282 [00:59<00:00,  4.72it/s]


2.482408285140991


100%|█████████████████████████████████████████| 282/282 [01:18<00:00,  3.60it/s]


2.4616341590881348


100%|█████████████████████████████████████████| 282/282 [01:21<00:00,  3.46it/s]


2.4812607765197754


100%|█████████████████████████████████████████| 282/282 [00:59<00:00,  4.74it/s]


2.455271005630493


100%|█████████████████████████████████████████| 282/282 [01:16<00:00,  3.67it/s]


2.4741933345794678


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.44it/s]


2.4869704246520996


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.53it/s]


2.5455875396728516


100%|█████████████████████████████████████████| 282/282 [01:08<00:00,  4.12it/s]


2.465162754058838


100%|█████████████████████████████████████████| 282/282 [01:13<00:00,  3.84it/s]


2.4853363037109375


100%|█████████████████████████████████████████| 282/282 [01:19<00:00,  3.55it/s]


2.511230230331421


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.46it/s]


2.466195583343506


100%|█████████████████████████████████████████| 282/282 [01:23<00:00,  3.40it/s]


2.424468517303467


100%|█████████████████████████████████████████| 282/282 [01:01<00:00,  4.60it/s]


2.465144157409668


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.45it/s]


2.4295811653137207


100%|█████████████████████████████████████████| 282/282 [01:02<00:00,  4.48it/s]


2.4105782508850098


100%|█████████████████████████████████████████| 282/282 [01:04<00:00,  4.37it/s]


2.4472241401672363


100%|█████████████████████████████████████████| 282/282 [01:16<00:00,  3.67it/s]


2.4565796852111816


100%|█████████████████████████████████████████| 282/282 [01:04<00:00,  4.37it/s]


2.441279411315918


100%|█████████████████████████████████████████| 282/282 [01:09<00:00,  4.05it/s]


2.555305004119873


100%|█████████████████████████████████████████| 282/282 [01:11<00:00,  3.92it/s]


2.3475558757781982


100%|█████████████████████████████████████████| 282/282 [01:03<00:00,  4.45it/s]


2.3537673950195312


100%|█████████████████████████████████████████| 282/282 [01:10<00:00,  4.03it/s]


2.452988624572754


100%|█████████████████████████████████████████| 282/282 [01:24<00:00,  3.32it/s]


2.4660987854003906


100%|█████████████████████████████████████████| 282/282 [01:14<00:00,  3.80it/s]


2.5219616889953613


100%|█████████████████████████████████████████| 282/282 [01:13<00:00,  3.83it/s]


2.3720574378967285


100%|█████████████████████████████████████████| 282/282 [01:20<00:00,  3.52it/s]

2.489607810974121



