In [1]:
import itertools
import pandas as pd
import networkx as nx
import numpy as np
import os
import random

import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers

from spektral.layers import GCNConv
from spektral.layers.convolutional import gcn_conv
from spektral.transforms import LayerPreprocess
from spektral.transforms import GCNFilter
from spektral.data import Dataset
from spektral.data import Graph
from spektral.data.loaders import SingleLoader

from sklearn.preprocessing import MinMaxScaler

from glob import glob

from Bio.Seq import Seq


In [2]:
# Run tf.function-decorated code eagerly (op by op) for the whole session.
tf.config.run_functions_eagerly(True)


In [3]:
class plasgraph(tf.keras.Model):
    """Graph convolutional network that scores every node of a graph.

    The forward pass embeds the node features with two parallel dense layers,
    then runs six rounds of dropout -> graph convolution. After every round
    the result is concatenated with the first dense embedding
    (``node_identity``), so the original node information is re-injected at
    each depth. A dense layer and the output layer finish the network.

    All six graph-convolution rounds call the *same* ``GCNConv`` instance, so
    its weights are shared across depths.

    Args:
        n_labels: Number of output units (one score per label).
        output_activation: Activation of the output layer.
        channels: Stored and returned by ``get_config`` only; the layer widths
            are fixed at 32 and do not depend on it.
        activation: Activation of the graph-convolution layer.
        dropout_rate: Rate shared by all dropout layers.
        l2_reg: L2 regularisation factor of the graph-convolution kernel.
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        n_labels,
        output_activation="sigmoid",
        channels=16,
        activation="relu",
        dropout_rate=0.1,
        l2_reg=2.5e-4,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.n_labels = n_labels
        self.channels = channels
        self.activation = activation
        self.output_activation = output_activation
        self.dropout_rate = dropout_rate
        self.l2_reg = l2_reg
        reg = tf.keras.regularizers.l2(l2_reg)

        # _dropout_layer_0 is never used in call(); it is kept so the set of
        # attributes on existing instances does not change.
        self._dropout_layer_0 = layers.Dropout(dropout_rate)
        self._dropout_layer_1 = layers.Dropout(dropout_rate)
        self._dropout_layer_2 = layers.Dropout(dropout_rate)
        self._dropout_layer_3 = layers.Dropout(dropout_rate)
        self._dropout_layer_4 = layers.Dropout(dropout_rate)
        self._dropout_layer_5 = layers.Dropout(dropout_rate)
        self._dropout_layer_6 = layers.Dropout(dropout_rate)
        self._dropout_layer_7 = layers.Dropout(dropout_rate)
        self._dropout_layer_8 = layers.Dropout(dropout_rate)

        self.fully_connected_1 = layers.Dense(32, activation="relu")
        self.fully_connected_2 = layers.Dense(32, activation="relu")
        self.fully_connected_3 = layers.Dense(32, activation="relu")
        # The output layer honours the constructor arguments. It used to be
        # hard-coded to Dense(2, "sigmoid"), which silently ignored both
        # n_labels and output_activation.
        self.fully_connected_4 = layers.Dense(n_labels, activation=output_activation)

        self._gcn_layer1 = gcn_conv.GCNConv(
            channels=32, activation=activation, kernel_regularizer=reg, use_bias=True
        )

    def get_config(self):
        """Return the constructor arguments needed to re-create the model."""
        return dict(
            n_labels=self.n_labels,
            channels=self.channels,
            activation=self.activation,
            output_activation=self.output_activation,
            dropout_rate=self.dropout_rate,
            l2_reg=self.l2_reg,
        )

    def call(self, inputs):
        """Compute the per-node output scores.

        Args:
            inputs: Pair ``(x, a)`` of node features and the adjacency/filter
                matrix that is passed on to the graph-convolution layer.

        Returns:
            The output of the final dense layer, ``n_labels`` values per node.
        """
        x, a = inputs

        # Embedding that is concatenated back in after every convolution.
        node_identity = self.fully_connected_1(x)

        x = self.fully_connected_2(x)
        merged = layers.concatenate([node_identity, x])

        # Six rounds of dropout -> shared GCN layer -> skip concatenation.
        # Each round has its own dropout layer, in the same order as before.
        gcn_dropouts = (
            self._dropout_layer_1,
            self._dropout_layer_2,
            self._dropout_layer_3,
            self._dropout_layer_4,
            self._dropout_layer_5,
            self._dropout_layer_6,
        )
        for dropout in gcn_dropouts:
            x = dropout(merged)
            x = self._gcn_layer1([x, a])
            merged = layers.concatenate([node_identity, x])

        x = self._dropout_layer_7(merged)
        x = self.fully_connected_3(x)
        x = self._dropout_layer_8(x)

        return self.fully_connected_4(x)


In [4]:
# convert networkx graph to a spektral graph
class Networkx_to_Spektral(Dataset):
    """Spektral dataset containing a single graph taken from a networkx graph.

    Every node of the networkx graph must carry an ``"x"`` attribute (feature
    vector) and a ``"y"`` attribute (label vector).
    """

    def __init__(self, nx_graph, **kwargs):
        # Store the graph before delegating to Dataset.__init__, which
        # presumably calls read() (TODO confirm against the Spektral version).
        self.nx_graph = nx_graph
        super().__init__(**kwargs)

    def read(self):
        """Build the single ``Graph`` of this dataset.

        Returns:
            A one-element list with a ``Graph`` whose features, adjacency
            matrix and labels are all cast to float.
        """
        node_view = self.nx_graph.nodes

        # Features and labels are collected in the graph's node order.
        features = np.array([node_view[name]["x"] for name in node_view])
        labels = np.array([node_view[name]["y"] for name in node_view])

        # Adjacency matrix with the diagonal (self-loops) zeroed and the
        # resulting explicit zeros dropped.
        adjacency = nx.adjacency_matrix(self.nx_graph)
        adjacency.setdiag(0)
        adjacency.eliminate_zeros()

        return [
            Graph(
                x=features.astype(float),
                a=adjacency.astype(float),
                y=labels.astype(float),
            )
        ]


In [5]:
# get pentamer distributions

kmer_length = 5

# every possible k-mer over the DNA alphabet, in lexicographic order
k_mers = ["".join(bases) for bases in itertools.product("ACGT", repeat=kmer_length)]

# Split the k-mers into canonical ("forward") representatives and their
# reverse complements: a k-mer becomes a representative unless it, or its
# reverse complement, has already been assigned.
fwd_kmers = []
rev_kmers = []
assigned_kmers = set()

for candidate in k_mers:
    if candidate in assigned_kmers:
        continue
    partner = str(Seq(candidate).reverse_complement())
    fwd_kmers.append(candidate)
    rev_kmers.append(partner)
    assigned_kmers.update((candidate, partner))


def get_kmer_distribution(
    sequence, k_mers=k_mers, fwd_kmers=fwd_kmers, kmer_length=5, scale=False
):
    """Count canonical k-mers in a sequence.

    Each entry of the result is the number of occurrences of one k-mer from
    ``fwd_kmers`` plus the occurrences of its reverse complement.

    Args:
        sequence: Sequence to scan. Windows are matched exactly against
            ``k_mers``, so windows that are not in that list (for example ones
            containing characters outside the alphabet used to build it) are
            not counted.
        k_mers: All k-mers that are counted; must contain every entry of
            ``fwd_kmers`` and its reverse complement.
        fwd_kmers: Canonical k-mers that define the order and length of the
            returned vector.
        kmer_length: Window length used when scanning ``sequence``.
        scale: If true, min-max scale the counts of this sequence.

    Returns:
        A list with one value per entry of ``fwd_kmers``. A sequence shorter
        than ``kmer_length`` yields unscaled zeros regardless of ``scale``.
    """
    # Too short to contain a single window. The guard and the vector length
    # follow the actual arguments instead of the literal 5 and 4**k / 2, so
    # the early return always has the same length as the regular result.
    if len(sequence) < kmer_length:
        return [0] * len(fwd_kmers)

    dict_kmer_count = dict.fromkeys(k_mers, 0)

    for start in range(len(sequence) + 1 - kmer_length):
        kmer = sequence[start : start + kmer_length]
        # windows that are not a known k-mer are deliberately skipped
        if kmer in dict_kmer_count:
            dict_kmer_count[kmer] += 1

    # merge each canonical k-mer with its reverse complement
    k_mer_counts = [
        dict_kmer_count[k_mer] + dict_kmer_count[str(Seq(k_mer).reverse_complement())]
        for k_mer in fwd_kmers
    ]

    if scale:
        # the counts form a single column, so they are scaled relative to the
        # smallest and largest count of this sequence
        scaler = MinMaxScaler()
        k_mer_counts = scaler.fit_transform(np.array(k_mer_counts).reshape(-1, 1))
        k_mer_counts = list(k_mer_counts.flatten())

    return k_mer_counts


# extract GC content

def get_gc_content(seq):
    """Return the GC fraction of a sequence, rounded to four decimals.

    Matching is case-insensitive and only A, C, G and T contribute to the
    denominator. A sequence without any of those bases yields 0.5.
    """
    lowered = seq.lower()
    gc_total = lowered.count("g") + lowered.count("c")
    acgt_total = gc_total + lowered.count("a") + lowered.count("t")

    # no countable bases: fall back to the neutral value
    if acgt_total == 0:
        return 0.5

    return round(gc_total / acgt_total, 4)


In [6]:
# set seeds for reproducibility
seed_number = 123

# NOTE(review): PYTHONHASHSEED is normally read when an interpreter starts,
# so setting it here presumably only affects processes launched afterwards.
os.environ["PYTHONHASHSEED"] = str(seed_number)

# seed Python's, NumPy's and TensorFlow's random generators with the same value
for seed_generator in (random.seed, np.random.seed, tf.random.set_seed):
    seed_generator(seed_number)

In [7]:
# Collect (alignment label file, assembly graph file) pairs for each species.


def collect_isolate_pairs(label_dir, graph_dir, name_fields, graph_suffix, limit=20):
    """Pair labelled alignment files with the assembly graph of their isolate.

    Args:
        label_dir: Directory prefix (ending in a separator) that holds the
            ``*alignment_labelled_ambiguity_cutoff_1.csv`` files.
        graph_dir: Directory prefix (ending in a separator) of the graph files.
        name_fields: Number of leading ``_``-separated fields of the alignment
            file name that make up the isolate name.
        graph_suffix: Text appended to the isolate name to get the graph file.
        limit: Maximum number of alignment files to use.

    Returns:
        ``(alignment_files, pairs)`` where ``pairs`` is a list of
        ``(alignment_file, graph_file)`` tuples in the order of
        ``alignment_files``.
    """
    # glob returns matches in arbitrary order; sorting makes the selection of
    # the first `limit` files reproducible across machines and runs.
    pattern = label_dir + "*alignment_labelled_ambiguity_cutoff_1.csv"
    alignment_files = sorted(glob(pattern))[0:limit]

    pairs = []
    for alignment_file in alignment_files:
        # os.path.basename honours the platform's separators; splitting on
        # "\\" only worked on Windows and returned the whole path elsewhere.
        file_name = os.path.basename(alignment_file)
        isolate = "_".join(file_name.split("_")[0:name_fields])
        pairs.append((alignment_file, graph_dir + isolate + graph_suffix))

    return alignment_files, pairs


# get 20 k. pneumoniae samples

alignment_files_k_pneumoniae, tuple_list_k_pneumoniae = collect_isolate_pairs(
    "data/labels/k_pneumoniae/", "data/graphs/k_pneumoniae/", 1, ".gfa"
)

# get 20 E. coli samples (isolate names span the first three "_" fields)

alignment_files_e_coli, tuple_list_e_coli = collect_isolate_pairs(
    "data/labels/e_coli/", "data/graphs/e_coli/", 3, "_assembly.gfa"
)

# get 20 E. faecium samples

alignment_files_e_faecium, tuple_list_e_faecium = collect_isolate_pairs(
    "data/labels/e_faecium/", "data/graphs/e_faecium/", 1, "_assembly.gfa"
)


# concatenate tuple lists and shuffle

tuple_list_all_species = tuple_list_k_pneumoniae + tuple_list_e_coli + tuple_list_e_faecium
random.shuffle(tuple_list_all_species)

Extract data to generate the graphs

In [8]:
# ID offset added to the contig numbers of the isolate currently being read
current_num_contigs = 0

# graph edges as (node id, node id) pairs, collected across all isolates
tuple_node1_node2 = list()

# per-contig values, filled while reading the graph and label files
dict_contig_length = dict()
dict_contig_length_normalized = dict()
dict_contig_gc = dict()
dict_contig_kmer = dict()
dict_contig_coverage = dict()
dict_contig_label = dict()
dict_contig_kmer_euclidean_distance = dict()

# per-contig neighbour counts, filled once the networkx graph exists
dict_contig_num_edges = dict()

In [9]:
for alignment_file, graph_file in tuple_list_all_species:

    # Read the assembly graph (GFA). The context manager closes the file even
    # if reading fails, and blank lines are dropped so they cannot break the
    # record-type lookup (fields[0]) below.
    with open(graph_file, "r") as file_:
        lines = [line for line in file_ if line.strip()]

    df_alignment = pd.read_csv(alignment_file, index_col=0)

    # Parse every record once: "S" lines describe contigs, "L" lines describe
    # edges. Contig numbers are shifted by current_num_contigs so that the
    # node ids of different isolates do not overlap.
    segments = []  # (contig number, sequence, raw line) per "S" record

    for line in lines:
        fields = line.split()
        if fields[0] == "S":
            segments.append((int(fields[1]), fields[2], line))
        elif fields[0] == "L":
            # get graph edges
            tuple_node1_node2.append(
                (
                    current_num_contigs + int(fields[1]),
                    current_num_contigs + int(fields[3]),
                )
            )

    sequences = [sequence for _number, sequence, _line in segments]

    # get gc of whole seq (all contig sequences of the isolate joined)
    gc_of_whole_seq = get_gc_content("".join(sequences))

    # get max contig length, used to normalise the contig lengths
    max_contig_length = max((len(sequence) for sequence in sequences), default=0)

    # unscaled pentamer counts of the current isolate, keyed by the contig
    # number as written in the file (no offset)
    dict_contig_kmer_current_isolate = {}

    for contig_, sequence, line in segments:
        node_id = current_num_contigs + contig_

        # contig length, absolute and relative to the longest contig
        dict_contig_length[node_id] = len(sequence)
        dict_contig_length_normalized[node_id] = len(sequence) / max_contig_length

        # gc content relative to the gc content of the whole isolate
        dict_contig_gc[node_id] = get_gc_content(sequence) - gc_of_whole_seq

        # pentamer distributions (scaled per contig, and raw for the distance)
        dict_contig_kmer[node_id] = get_kmer_distribution(
            sequence, k_mers=k_mers, scale=True
        )
        dict_contig_kmer_current_isolate[contig_] = get_kmer_distribution(
            sequence, k_mers=k_mers
        )

        # coverage: the number after the last ":" of the record
        # (presumably a depth tag such as dp:f:<value> - TODO confirm)
        dict_contig_coverage[node_id] = round(
            float(line.strip().split(":")[-1]), 2
        )

        # labels, encoded as [plasmid, chromosome]; contigs that are missing
        # from the alignment table or have another label stay [0, 0]
        try:
            label = df_alignment.loc[contig_, "label"]
            if label == "chromosome":
                dict_contig_label[node_id] = [0, 1]
            elif label == "plasmid":
                dict_contig_label[node_id] = [1, 0]
            elif label == "ambiguous":
                dict_contig_label[node_id] = [1, 1]
            else:
                dict_contig_label[node_id] = [0, 0]
        except KeyError:
            dict_contig_label[node_id] = [0, 0]

    # get euclidian distance of pentamer distribution for each node

    # calculate total pentamer distribution and scale between 0 and 1
    all_kmer_counts = [
        sum(x) for x in zip(*list(dict_contig_kmer_current_isolate.values()))
    ]
    scaler = MinMaxScaler()
    all_kmer_counts = scaler.fit_transform(np.array(all_kmer_counts).reshape(-1, 1))
    all_kmer_counts = list(all_kmer_counts.flatten())

    # get euclidean distance for each contig and add to dict
    for contig in dict_contig_kmer_current_isolate:
        kmer_distribution = np.array(dict_contig_kmer_current_isolate[contig])
        scaler = MinMaxScaler()
        scaled_kmer_distribution = scaler.fit_transform(
            np.array(kmer_distribution).reshape(-1, 1)
        )
        scaled_kmer_distribution = list(scaled_kmer_distribution.flatten())
        dict_contig_kmer_euclidean_distance[
            current_num_contigs + contig
        ] = np.linalg.norm(
            np.array(all_kmer_counts) - np.array(scaled_kmer_distribution)
        )

    # Offset for the next isolate: one past the largest node id used so far.
    # The number of stored contigs is only a safe offset when contig numbers
    # are contiguous; with gaps it would make ids of different isolates
    # collide and merge their graphs.
    current_num_contigs = max(dict_contig_coverage, default=-1) + 1


In [10]:
# generate networkx graph from the collected (node, node) pairs

G = nx.Graph()
G.add_edges_from(tuple_node1_node2)


# get number of neighbours per contig_

for contig_ in G.nodes:
    dict_contig_num_edges[contig_] = len(G[contig_])

# make feature dict: coverage, relative gc content, pentamer distance,
# neighbour count and normalised length
# (the scaled pentamer vector dict_contig_kmer is currently not appended)

dict_contig_list_coverage_gc_kmer = {
    contig_: [
        dict_contig_coverage[contig_],
        dict_contig_gc[contig_],
        dict_contig_kmer_euclidean_distance[contig_],
        dict_contig_num_edges[contig_],
        dict_contig_length_normalized[contig_],
    ]
    for contig_ in G.nodes
}

# add features to graph nodes
nx.set_node_attributes(G, dict_contig_list_coverage_gc_kmer, "x")

# add labels to graph nodes
nx.set_node_attributes(G, dict_contig_label, "y")

# remove all nodes < 100 bp and connect new neighbors
for node in list(G.nodes):
    if dict_contig_length[node] >= 100:
        continue
    # connect every pair of the node's neighbours, then drop the node
    former_neighbors = list(G.neighbors(node))
    G.add_edges_from(itertools.combinations(former_neighbors, 2))
    G.remove_node(node)

# generate spektral graph
all_graphs = Networkx_to_Spektral(G)

all_graphs.apply(GCNFilter())

print(all_graphs[0])


  self._set_arrayXarray(i, j, x)


Graph(n_nodes=11191, n_node_features=5, n_edge_features=None, n_labels=2)


Add sample weight

In [11]:
# sample weights and masking
number_total_nodes = len(G.nodes)

# labels in graph node order; encoded as [plasmid, chromosome]
node_labels = [G.nodes[node]["y"] for node in G.nodes]

num_unlabelled = node_labels.count([0, 0])
num_chromosome = node_labels.count([0, 1])
num_plasmid = node_labels.count([1, 0])
num_ambiguous = node_labels.count([1, 1])


print(
    "Chromosome contigs:",
    num_chromosome,
    "Plasmid contigs:",
    num_plasmid,
    "Ambiguous contigs:",
    num_ambiguous,
    "Unlabelled contigs:",
    num_unlabelled,
)


# for each class, calculate weight. Set unlabelled contigs weight to 0
# (a class's weight is the share of nodes that do not belong to it)
chromosome_weight = (num_unlabelled + num_plasmid + num_ambiguous) / number_total_nodes
plasmid_weight = (num_unlabelled + num_chromosome + num_ambiguous) / number_total_nodes
ambiguous_weight = (num_unlabelled + num_chromosome + num_plasmid) / number_total_nodes

weight_by_label = (
    ([0, 0], 0),
    ([0, 1], chromosome_weight),
    ([1, 0], plasmid_weight),
    ([1, 1], ambiguous_weight),
)

# one weight per node, in graph node order
masks = []

for label in node_labels:
    for known_label, weight in weight_by_label:
        if label == known_label:
            masks.append(weight)
            break


Chromosome contigs: 7160 Plasmid contigs: 2976 Ambiguous contigs: 963 Unlabelled contigs: 92


In [12]:
# 80% train 20% validate

# The first 80% of the nodes (in graph order) keep their weight for training,
# the rest keep it for validation; the other part of each vector is zero.
number_train_nodes = int(len(masks) * 0.8)

# Pad with exactly the number of remaining nodes. The former padding of
# int(len(masks) * 0.2) + 1 produced a vector one element too long whenever
# 80% of the node count is a whole number.
number_validate_nodes = len(masks) - number_train_nodes

masks_train = masks[0:number_train_nodes] + [0] * number_validate_nodes
masks_validate = [0] * number_train_nodes + masks[number_train_nodes:]

masks_train = np.array(masks_train).astype(float)
masks_validate = np.array(masks_validate).astype(float)

In [13]:
# print the length of each weight vector for comparison with the node count
for mask_vector in (masks_train, masks_validate):
    print(len(mask_vector))

11191
11191


In [14]:
learning_rate = 0.005

# two sigmoid outputs per node, trained with summed binary cross-entropy
model = plasgraph(n_labels=2, output_activation="sigmoid")
model.compile(
    optimizer=Adam(learning_rate),
    loss=BinaryCrossentropy(reduction="sum"),
)

# Both loaders serve the same graph; they differ only in the sample weights,
# which select the training or the validation nodes.
loader_tr = SingleLoader(all_graphs, sample_weights=masks_train)
loader_va = SingleLoader(all_graphs, sample_weights=masks_validate)

training_data = loader_tr.load()
validation_data = loader_va.load()

# stop after 100 epochs without improvement and keep the best weights
early_stopping = EarlyStopping(patience=100, restore_best_weights=True)

model.fit(
    training_data,
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=validation_data,
    validation_steps=loader_va.steps_per_epoch,
    epochs=1000,
    callbacks=[early_stopping],
)


Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000
Epoch 68/1000
Epoch 69/1000
Epoch 70/1000
Epoch 71/1000
Epoch 72/1000
E

<keras.callbacks.History at 0x1dcc2f44588>

In [15]:
#model.save("plASgraph_model")
