In [2]:
import pyximport
pyximport.install()

import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sp
import dgl
from sklearn.cluster import AgglomerativeClustering

import DRBin
import DRBin.utils
from DRBin.models import DGI, LogReg
from DRBin import process
from DRBin.calculate_graph import *
from DRBin.eval import *
from DRBin.vMF_VAE import *

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import normalize
import scipy.sparse as sp

%matplotlib inline

Using backend: pytorch


In [6]:
# Load the inputs for the metahit dataset: the result of
# DRBin.parsecontigs.read_contigs on the contig FASTA file (bound to `tnfs`)
# and the 'arr_0' array of the abundance archive (bound to `rpkms`).
with open('/home/maog/data/metahit/contigs.fna', 'rb') as filehandle:
    tnfs = DRBin.parsecontigs.read_contigs(filehandle)

# np.load on an .npz returns a lazily-read NpzFile that keeps the archive
# open; reading inside a context manager closes it once the array is in memory.
with np.load('/home/maog/data/metahit/abundance.npz') as abundance_npz:
    rpkms = abundance_npz['arr_0']

# The VAE is sized by the number of columns of the abundance matrix.
vae = DRBin.vMF_VAE.vMF_VAE(nsamples=rpkms.shape[1])
dataloader = DRBin.vMF_VAE.make_dataloader(rpkms, tnfs)

# Training/encoding is disabled; a previously saved latent matrix is loaded
# instead.
# NOTE(review): the disabled savetxt targets the 'urog' directory while the
# latent below is read from 'metahit' — confirm which dataset is intended.
#vae.trainmodel(dataloader)
#latent = vae.encode(dataloader)
#np.savetxt('/home/maog/data/urog/latent.txt', latent)
latent = np.loadtxt('/home/maog/data/metahit/vMF-VAE_latent.txt')

In [8]:
# Build the graph returned by calculate_graph over the latent vectors and save
# its adjacency matrix to knngraph.npz.
# NOTE(review): marker_contigs is not assigned anywhere in the visible
# notebook, and contig_id_idx is only built in a later cell — this cell relies
# on kernel state from an out-of-order run; confirm where both come from.
latent = np.loadtxt('/home/maog/data/metahit/vMF-VAE_latent.txt')
u, v = calculate_graph(latent, marker_contigs, contig_id_idx)
# (u, v) are passed to dgl.graph as the edge endpoint pair.
g = dgl.graph((u, v))
# Export the adjacency as a SciPy CSR matrix so it can be stored with save_npz.
knn_graph = g.adj(scipy_fmt='csr')
sp.save_npz('/home/maog/data/metahit/knngraph.npz', knn_graph)

In [11]:
# Build the graph returned by calculate_negativate_graph from the same inputs
# and save its adjacency matrix to knn_neg_graph.npz.
# NOTE(review): `g` and `knn_graph` are re-bound here, so after this cell they
# no longer refer to the graph saved as knngraph.npz. marker_contigs is not
# assigned anywhere in the visible notebook — confirm where it comes from.
u, v = calculate_negativate_graph(latent, marker_contigs, contig_id_idx)
g = dgl.graph((u, v))
# Export the adjacency as a SciPy CSR matrix so it can be stored with save_npz.
knn_graph = g.adj(scipy_fmt='csr')
sp.save_npz('/home/maog/data/metahit/knn_neg_graph.npz', knn_graph)

In [None]:
# Train the DGI model on the latent features with the two saved graphs
# (knngraph.npz as `adj`, knn_neg_graph.npz as `adj_hat`), then mix the
# learned embeddings with the input features into the clustering matrix X.
latent = np.loadtxt('/home/maog/data/metahit/vMF-VAE_latent.txt')

# training params
batch_size = 1
nb_epochs = 300
patience = 20          # epochs without a new best loss before stopping
lr = 0.001
l2_coef = 0.0          # passed to Adam as weight_decay
drop_prob = 0.3        # NOTE(review): not referenced again in this cell
hid_units = 32
# Weight of the DGI embeddings in X (see the last statement of the cell).
# NOTE(review): (1 - 1e-64) evaluates to exactly 1.0 in double precision and
# a * embeds is at most ~1e-64 in magnitude for row-normalised embeds, so X is
# numerically just `features` — confirm this weight is intended.
a = 1e-64
sparse = True
nonlinearity = 'prelu' # special name to separate parameters
features = latent
adj = sp.load_npz('/home/maog/data/metahit/knngraph.npz')
adj_hat = sp.load_npz('/home/maog/data/metahit/knn_neg_graph.npz')

features = sp.csr_matrix(features)
features, _ = process.preprocess_features(features)
nb_nodes = features.shape[0]
ft_size = features.shape[1]
# Add self-loops before handing both adjacencies to process.normalize_adj.
adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))
adj_hat = process.normalize_adj(adj_hat + sp.eye(adj_hat.shape[0]))

if sparse:
    sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
    sp_adj_hat = process.sparse_mx_to_torch_sparse_tensor(adj_hat)
else:
    adj = (adj + sp.eye(adj.shape[0])).todense()
    adj_hat = (adj_hat + sp.eye(adj_hat.shape[0])).todense()

# Add a leading batch axis of size 1.
features = torch.FloatTensor(features[np.newaxis])
if not sparse:
    adj = torch.FloatTensor(adj[np.newaxis])
    # adj_hat needs the same conversion; it was previously left as a dense
    # NumPy matrix, which has no .cuda() and cannot be fed to the model.
    adj_hat = torch.FloatTensor(adj_hat[np.newaxis])

model = DGI(ft_size, hid_units, nonlinearity)
optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)

use_cuda = torch.cuda.is_available()
if use_cuda:
    # Keep using GPU 1 when it exists; fall back to GPU 0 on single-GPU hosts
    # where set_device(1) would raise.
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
    model.cuda()
    print('Using CUDA')
    features = features.cuda()
    if sparse:
        sp_adj = sp_adj.cuda()
        sp_adj_hat = sp_adj_hat.cuda()
    else:
        adj = adj.cuda()
        adj_hat = adj_hat.cuda()

b_xent = nn.BCEWithLogitsLoss()
xent = nn.CrossEntropyLoss()
cnt_wait = 0
best = 1e9
best_t = 0

# Binary targets: ones for the first nb_nodes entries, zeros for the second
# nb_nodes entries. They do not change between epochs, so build them once.
lbl_1 = torch.ones(batch_size, nb_nodes)
lbl_2 = torch.zeros(batch_size, nb_nodes)
lbl = torch.cat((lbl_1, lbl_2), 1)
if use_cuda:
    lbl = lbl.cuda()

for epoch in range(nb_epochs):
    model.train()
    optimiser.zero_grad()

    logits = model(features, sp_adj if sparse else adj, sp_adj_hat if sparse else adj_hat, sparse, None, None, None)

    # The targets are float 0/1 values with one entry per logit, which is the
    # input BCEWithLogitsLoss expects; CrossEntropyLoss would instead treat
    # each row as a single distribution over 2 * nb_nodes classes.
    loss = b_xent(logits, lbl)

    # Work with the plain float so `best` does not hold on to the autograd graph.
    loss_value = loss.item()
    print('loss:', loss_value)

    if loss_value < best:
        best = loss_value
        best_t = epoch
        cnt_wait = 0
#        torch.save(model.state_dict(), '/home/maog/data/metahit/best_dgi.pkl')
    else:
        cnt_wait += 1

    # Stop once the loss has not improved for `patience` consecutive epochs;
    # cnt_wait was tracked before but never acted on.
    if cnt_wait == patience:
        print('Early stopping!')
        break

    loss.backward()
    optimiser.step()

#model.load_state_dict(torch.load('/home/maog/data/metahit/best_dgi.pkl'))
embeds, _ = model.embed(features, sp_adj if sparse else adj, sparse, None)
# Convert to NumPy before normalising: sklearn's normalize returns an ndarray,
# so calling .numpy() on its result raises AttributeError.
embeds = embeds.squeeze(-3).detach().cpu().numpy()
embeds = normalize(embeds)
features = features.squeeze(-3)
features = features.cpu().numpy()
#get the final vector for clustering
X = ((1 - a) * features + a * embeds)
#np.savetxt('/home/maog/data/metahit/DRBin_latent.txt', X)

In [7]:
from DRBin.my_cluster import *
from Bio import SeqIO

In [8]:
# Index the contigs of the FASTA file. Three lookups are filled per record:
#   contig_length: record id -> sequence length
#   contig_id_idx: record id -> integer index
#   contig_idx_id: integer index -> record id
contig_length, contig_id_idx, contig_idx_id = {}, {}, {}
contigs = '/home/maog/data/metahit/contigs.fna'
for record in SeqIO.parse(contigs, "fasta"):
    # The index is the number of ids registered so far.
    next_idx = len(contig_id_idx)
    contig_length[record.id] = len(record.seq)
    contig_idx_id[next_idx] = record.id
    contig_id_idx[record.id] = next_idx

In [21]:
# Cluster the rows of X (the matrix built at the end of the training cell)
# with cluster_points, which is in scope via `from DRBin.my_cluster import *`.
clusters = cluster_points(X)

In [22]:
# Filter the clusters and collect the surviving bins as sets of contig ids.
#
# The FASTA index (contig_length / contig_idx_id / contig_id_idx) built in the
# earlier indexing cell is reused as-is. Re-running the indexing loop on the
# already populated dicts is not idempotent: len(contig_id_idx) no longer
# grows, so every contig_id_idx value is overwritten with the same number and
# a stray extra key is added to contig_idx_id.
filtered_bins, cluster_contig_id = filterclusters(clusters, contig_length, contig_idx_id)

import collections

# bin name ("bins<label>") -> set of contig ids in that bin
cluster = collections.defaultdict(set)
for k, v in filtered_bins.items():
    # Entries stored under label -1 are left out of the result.
    if k == -1:
        continue
    for i in v:
        cluster["bins"+ str(k)].add(contig_idx_id[i])

# Number of bins kept.
len(cluster)

316

In [78]:
# This writes a .tsv file with the clusters and corresponding sequences.
# `cluster` is the bin-name -> contig-id-set mapping built in the previous cell.
with open('/home/maog/data/metahit/result/DRBin_cluster.tsv', 'w') as file:
    DRBin.utils.write_clusters(file, cluster)

# Disabled: export of one FASTA file per bin.
# NOTE(review): the paths below point at the 'urog' dataset while the rest of
# the notebook uses 'metahit' — update them before re-enabling.
# # Only keep contigs in any filtered bin in memory
# keptcontigs = set.union(*cluster.values())

# with open('/home/maog/data/urog/contigs.fna', 'rb') as file:
#     fastadict = DRBin.utils.loadfasta(file, keep=keptcontigs)
    
# bindir = '/home/maog/data/urog/result/bins'
# DRBin.utils.write_bins(bindir, cluster, fastadict, maxbins=1500)

In [5]:
# First load in the Reference
reference_path = '/home/maog/data/metahit/reference.tsv'

# IPython shell escape: preview the file with `head`.
!head $reference_path # show first 10 lines of reference file

# Build the benchmark Reference object from the file; it is bound to
# `reference` and used by the taxonomy and benchmarking cells below.
with open(reference_path) as reference_file:
    reference = DRBin.benchmark.Reference.from_file(reference_file)

gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_1-5871	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	1	5871
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_5841-8340	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	5841	8340
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_8310-10809	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	8310	10809
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_10779-29944	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	10779	29944
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_29914-33073	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	29914	33073
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_33043-41174	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	33043	41174
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococcus_D21_uid55871]_41144-44994	Acidaminococcus_D21_uid55871	NZ_ACGB01000001.1	41144	44994
gi|224815735|ref|NZ_ACGB01000001.1|_[Acidaminococc

In [6]:
# Attach the taxonomy file to the already loaded `reference` object.
taxonomy_path = '/home/maog/data/metahit/taxonomy.tsv'

# IPython shell escape: preview the file with `head`.
# NOTE: the trailing shell comment says "reference file", but the file shown
# here is the taxonomy file.
!head $taxonomy_path # show first 10 lines of reference file

# load_tax_file is called on `reference` for its effect; its return value is
# not used.
with open(taxonomy_path) as taxonomy_file:
    reference.load_tax_file(taxonomy_file)

Acidaminococcus_D21_uid55871	Acidaminococcus_D21_uid55871	Acidaminococcus
Acidaminococcus_fermentans_DSM_20731_uid43471	Acidaminococcus fermentans	Acidaminococcus
Acidaminococcus_intestini_RyC_MR95_uid74445	Acidaminococcus intestini	Acidaminococcus
Actinomyces_ICM47_uid170984	Actinomyces_ICM47_uid170984	Actinomyces
Adlercreutzia_equolifaciens_DSM_19450_uid223286	Adlercreutzia equolifaciens	Adlercreutzia
Aeromicrobium_JC14_uid199535	Aeromicrobium_JC14_uid199535	Aeromicrobium
Akkermansia_muciniphila_ATCC_BAA_835_uid58985	Akkermansia muciniphila	Akkermansia
Alcanivorax_hongdengensis_A_11_3_uid176602	Alcanivorax hongdengensis	Alcanivorax
Alistipes_AP11_uid199714	Alistipes_AP11_uid199714	Alistipes
Alistipes_HGB5_uid67587	Alistipes_HGB5_uid67587	Alistipes


In [7]:
# Read back the DRBin clusters written earlier and benchmark them against
# `reference`, passing minsize=100000 to Binning.
with open('/home/maog/data/metahit/result/DRBin_cluster.tsv') as clusters_file:
    DRBin_clusters = DRBin.utils.read_clusters(clusters_file)
    DRBin_bins = DRBin.benchmark.Binning(DRBin_clusters, reference, minsize=100000)

In [8]:
# Print the benchmark summary, one tab-separated line per item returned by
# DRBin_bins.summary().
# NOTE(review): the header reads 'DGRBin' while the package is named DRBin —
# confirm whether that spelling is intended.
print('DGRBin bins:')
for rank in DRBin_bins.summary():
    print('\t'.join(map(str, rank)))

DGRBin bins:
112	108	100	98	89	66	26	7	0
107	105	97	95	88	65	25	7	0
52	52	49	48	48	41	21	6	0


In [10]:
# Benchmark the vamb clusters against the same `reference` for comparison.
# The results get their own names: binding them to DRBin_clusters / DRBin_bins
# silently replaced the DRBin results, so any cell run afterwards that used
# those names would have been looking at vamb's bins.
with open('/home/maog/data/metahit/vamb_cluster.tsv') as clusters_file:
    vamb_clusters = DRBin.utils.read_clusters(clusters_file)
    vamb_bins = DRBin.benchmark.Binning(vamb_clusters, reference, minsize=100000)

# Print the summary, one tab-separated line per item returned by summary().
print('vamb bins:')
for rank in vamb_bins.summary():
    print('\t'.join(map(str, rank)))

vamb bins:
108	107	105	102	91	66	21	6	0
104	104	102	99	90	65	20	6	0
53	53	52	51	49	43	17	5	0
