# GNN  

In [1]:
import os
import shutil
from collections import defaultdict, Counter

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data, Dataset
from torch_geometric.data import NeighborSampler
# from torch_cluster import random_walk
# import torch_sparse
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import SeqIO

### experiment path

In [6]:
# Directory holding this experiment's outputs; later cells read and write under it.
exp = '../Results/zymo10_1000kto1100k/output/'
oblr_npz = '../Results/zymo10_1000kto1100k/oblr/data.npz'
clusters_npz = exp + 'new_classes.npz'

# truth = np.array(open(exp + "ground_truth.txt").read().strip().split("\n"))
data = np.load(oblr_npz)                  # provides 'edges' and 'scaled' (read below)
updated_clusters = np.load(clusters_npz)  # provides 'classes' and 'classified' (read below)


### fetch data from files

In [7]:
# Pull the arrays this notebook needs out of the two loaded .npz archives.
edges, comp = data['edges'], data['scaled']
read_cluster, read_indices = updated_clusters['classes'], updated_clusters['classified']

print(read_cluster.shape)
print(read_cluster)


(100000,)
[3 2 2 ... 0 2 2]


# GraphSAGE Steps

In [8]:
class SAGE(torch.nn.Module):
    """GraphSAGE node classifier: stacked ``SAGEConv`` layers followed by a two-layer MLP head.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: number of output classes.
        num_layers: number of message-passing hops; ``forward`` applies one conv
            per sampled adjacency and ``inference`` runs ``num_layers`` convs.
        device: device the module is moved to on construction.

    Every conv layer and the hidden MLP layer use
    ``(in_channels + out_channels) // 2`` units.
    """

    def __init__(self, in_channels, out_channels, num_layers, device):
        super().__init__()

        self.num_layers = num_layers
        hidden_channels = (in_channels + out_channels) // 2

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        # NOTE(review): appended unconditionally, so with num_layers == 1 the list
        # holds two convs and this second one is never used.
        self.convs.append(SAGEConv(hidden_channels, hidden_channels))

        self.fc1 = torch.nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = torch.nn.Linear(hidden_channels, out_channels)

        self.device = device

        self.to(device)

    def reset_parameters(self):
        """Re-initialise every learnable layer: all convs and the MLP head."""
        for conv in self.convs:
            conv.reset_parameters()
        # The MLP head is trainable too; skipping it would leave a "reset" model
        # starting from the previous run's head weights.
        self.fc1.reset_parameters()
        self.fc2.reset_parameters()

    def forward(self, x, adjs):
        """Mini-batch forward pass over sampled bipartite hops.

        Args:
            x: features of every node in the sampled subgraph; the target nodes
                of each hop are its leading rows.
            adjs: list of ``(edge_index, e_id, size)`` tuples, one per hop,
                outermost hop first.

        Returns:
            ``(log_probs, embedding)``. ``log_probs`` holds log-softmax class
            scores for the batch's target nodes. ``embedding`` is the ``fc1``
            output before ReLU/dropout.
        """
        for i, (edge_index, _, size) in enumerate(adjs):
            x_target = x[:size[1]]  # target nodes are the first size[1] rows
            x = self.convs[i]((x, x_target), edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=0.2, training=self.training)

        x = self.fc1(x)
        embedding = x
        x = F.relu(x)
        x = F.dropout(x, p=0.2, training=self.training)

        x = self.fc2(x)

        return x.log_softmax(dim=-1), embedding

    def inference(self, x_all, subgraph_loader):
        """Layer-wise inference over the whole graph (no dropout).

        Each conv layer is applied to every batch of ``subgraph_loader`` before
        the next layer starts, so the loader should yield single-hop batches.

        Args:
            x_all: feature matrix indexed by node id.
            subgraph_loader: iterable of ``(batch_size, n_id, adj)``.

        Returns:
            ``(idx, logits)``. ``idx`` is a NumPy array of the target node ids in
            the order the loader yielded them (recorded on the first pass).
            ``logits`` is a CPU tensor of pre-softmax scores, one row per target
            node of the final pass (matching ``idx`` for a deterministic loader).
        """
        idx = []

        for i in range(self.num_layers):
            xs = []
            for batch_size, n_id, adj in subgraph_loader:
                if i == 0:
                    idx += list(n_id[:batch_size].numpy())
                edge_index, _, size = adj.to(self.device)
                x = x_all[n_id].to(self.device)
                x_target = x[:size[1]]
                x = self.convs[i]((x, x_target), edge_index)
                x = F.relu(x)
                xs.append(x)

            # Rows are in loader order, but the next layer indexes them by node id.
            # This assumes the loader visits nodes 0..N-1 in order (shuffle=False).
            x_all = torch.cat(xs, dim=0)

        x = self.fc1(x_all)
        x = F.relu(x)
        x = self.fc2(x)

        x = x.cpu()

        return np.array(idx), x

In [9]:
def train(model, x, y, optimizer, train_loader, device):
    """Run one training epoch and return the mean per-batch NLL loss.

    Args:
        model: module whose ``forward(x, adjs)`` returns ``(log_probs, embedding)``.
        x: node feature matrix, indexed by the node ids the loader yields.
        y: label vector, indexed by node id.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(batch_size, n_id, adjs)``. The first
            ``batch_size`` entries of ``n_id`` are the batch's target nodes, and
            ``adjs`` is a list of ``(edge_index, e_id, size)`` tuples.
        device: device the sampled adjacencies are moved to.

    Returns:
        The mean of the per-batch losses, as a Python float.

    Raises:
        ValueError: if ``train_loader`` yields no batches. This can happen, for
            example, with ``drop_last=True`` and fewer training nodes than
            ``batch_size``.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_size, n_id, adjs in train_loader:
        adjs = [adj.to(device) for adj in adjs]

        optimizer.zero_grad()
        out, _ = model(x[n_id], adjs)  # embedding output is not needed for training

        # Only the first batch_size nodes are targets; the rest are sampled neighbours.
        loss = F.nll_loss(out, y[n_id[:batch_size]])
        loss.backward()
        optimizer.step()

        total_loss += float(loss)
        num_batches += 1

    # Counting batches avoids relying on len() and makes an empty epoch explicit.
    if num_batches == 0:
        raise ValueError("train_loader produced no batches; cannot compute an epoch loss")

    return total_loss / num_batches


@torch.no_grad()
def test(model, x, subgraph_loader):
    model.eval()

    out = model.inference(x, subgraph_loader)
    return out

In [10]:
def get_graph_data(features, edges):
    """Wrap node features and an edge list into a PyG ``Data`` object.

    Args:
        features: node feature matrix, one row per node; stored as float32.
        edges: integer node-id pairs convertible to shape ``(num_edges, 2)``.

    Returns:
        ``Data`` with ``x`` (float32) and ``edge_index`` of shape ``(2, num_edges)``.

    Raises:
        ValueError: if a non-empty ``edges`` does not have shape ``(num_edges, 2)``.
    """
    edge_index = torch.tensor(edges, dtype=torch.long)
    if edge_index.numel() == 0:
        # An empty list/array would otherwise be 1-D and transpose to the wrong shape.
        edge_index = edge_index.reshape(0, 2)
    elif edge_index.dim() != 2 or edge_index.size(1) != 2:
        raise ValueError(
            f"edges must have shape (num_edges, 2), got {tuple(edge_index.shape)}")
    print(edge_index)

    data = Data(x=torch.tensor(features, dtype=torch.float),
                edge_index=edge_index.t().contiguous())
    print(data)

    return data

data = get_graph_data(features=comp, edges=edges)  # PyG graph: features + read-overlap edges


tensor([[    2,  3395],
        [    2, 50755],
        [    2, 57235],
        ...,
        [93996, 47238],
        [97240, 45621],
        [97312, 45820]])
Data(x=[100000, 136], edge_index=[2, 144163])


In [11]:
def get_train_data(features, read_cluster):
    """Build training indices, a label tensor and the class count from cluster ids.

    Args:
        features: unused; kept so existing call sites keep working.
        read_cluster: 1-D integer array-like of cluster ids, where ``-1`` marks
            reads without a label.

    Returns:
        train_idx: ``LongTensor`` of the positions whose label is not ``-1``.
        y: ``LongTensor`` copy of ``read_cluster``; unlabelled reads keep ``-1``.
        no_classes: number of output classes the classifier needs, i.e. the
            largest label + 1, or 0 when no read is labelled.
    """
    read_cluster = np.asarray(read_cluster)
    labelled = read_cluster != -1

    # Train only on reads that carry a label (-1 means "unlabelled").
    train_idx = torch.tensor(np.where(labelled)[0], dtype=torch.long)
    print(train_idx.shape)

    y = torch.tensor(read_cluster, dtype=torch.long)
    print(y)

    # nll_loss needs an output for every label value up to the largest one.
    # Counting distinct ids minus one would undercount when -1 is absent or
    # when ids are not contiguous.
    no_classes = int(read_cluster[labelled].max()) + 1 if labelled.any() else 0
    print(no_classes)
    return train_idx, y, no_classes

train_idx, y, no_classes = get_train_data(features=comp, read_cluster=read_cluster)

torch.Size([97412])
tensor([3, 2, 2,  ..., 0, 2, 2])
5


### Preparing the model and data

In [12]:
# Seed the global RNGs before building the model so weight initialisation and
# the loaders' shuffling repeat on Restart & Run All.
# NOTE(review): neighbour sampling runs inside torch_sparse / worker processes and
# may not be fully covered by these seeds — confirm if exact reproducibility matters.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# forward() applies one conv per sampled hop, so keep this layer count equal to
# the number of fan-outs in train_loader's `sizes`.
model = SAGE(data.x.shape[1], no_classes, 2, device)

In [13]:
device_label = str(device)
print("Using the device", device_label)

Using the device cpu


In [14]:
summary_parts = ("Model summary\n-------------\n", model, "\n-------------")
print(*summary_parts)

Model summary
-------------
 SAGE(
  (convs): ModuleList(
    (0): SAGEConv(136, 70, aggr=mean)
    (1): SAGEConv(70, 70, aggr=mean)
  )
  (fc1): Linear(in_features=70, out_features=70, bias=True)
  (fc2): Linear(in_features=70, out_features=5, bias=True)
) 
-------------


In [15]:
from torch_sparse import SparseTensor

# Pinned host memory only helps host->GPU copies; skip it on CPU-only runs.
use_pinned = device.type == 'cuda'

# Training: sample up to 50 neighbours per hop, one fan-out per SAGE layer.
# num_nodes is passed explicitly. Otherwise the sampler infers the node count from
# the largest id in edge_index, and labelled reads with a larger id and no edges
# would fall outside its adjacency.
train_loader = NeighborSampler(data.edge_index,
                               node_idx=train_idx,
                               num_nodes=data.num_nodes,
                               sizes=[50, 50],
                               batch_size=64,
                               pin_memory=use_pinned,
                               shuffle=True,
                               drop_last=True,
                               num_workers=8)

# Inference: one hop at a time over all neighbours (sizes=[-1]). A finite fan-out
# would randomly subsample high-degree nodes and make predictions vary between runs.
# shuffle=False keeps nodes in id order, which SAGE.inference relies on.
subgraph_loader = NeighborSampler(data.edge_index,
                                  node_idx=None,
                                  sizes=[-1],
                                  pin_memory=use_pinned,
                                  batch_size=10240,
                                  shuffle=False,
                                  num_workers=8)



In [16]:
# Features (built from `comp`) and per-read cluster labels, on the training device.
x, y = data.x.to(device), y.to(device)
print(x.shape, y.shape, sep="\n")

torch.Size([100000, 136])
torch.Size([100000])


In [17]:
# Adam with a light L2 penalty (weight_decay 1e-5, i.e. the same value as 10e-6).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3, weight_decay=1e-5)

In [18]:
%matplotlib notebook
import matplotlib.pyplot as plt

epochs = 100

# for plotting
fig = plt.figure()
ax = fig.add_subplot(111)
ax.set_xlim(-1, 201)
ax.set_ylim(0, 1)
plt.ion()
fig.show()
fig.canvas.draw()
losses = []
prev_loss = 100

for epoch in range(1, epochs+1):
    loss = train(model, x, y, optimizer, train_loader, device)
    dloss = prev_loss-loss
    prev_loss = loss
    # plot params
    losses.append(loss)    
    ax.clear()
    ax.plot(losses)
    ax.set_xlim(-1, 101)
    ax.set_ylim(0, 1)
    fig.canvas.draw()
    
    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}', end="\r", flush=True)
    
    if loss < 0.05:
        print()
        print('Early stopping, loss less than 0.05')
        break
    
print()

<IPython.core.display.Javascript object>

Epoch 07, Loss: 0.0499
Early stopping, loss less than 0.05



In [19]:
# inference() reports the node ids it produced predictions for, in loader order.
# Keep them paired with the predictions rather than assuming every read 0..N-1
# was covered.
inference_out = test(model, x, subgraph_loader)
idx, preds = inference_out

In [20]:
classes = preds.argmax(dim=1)  # highest-scoring bin per entry of `idx`

### final result to be carried forward

This consists of:

* **classified:** Index of the read (0 to N-1, where N is number of reads)
* **classes:** Class/bin of the read


In [21]:
# Persist predictions: classes[i] is the bin assigned to read classified[i].
updated_path = exp + 'updated_classes.npz'
np.savez(updated_path, classes=classes.numpy(), classified=idx)

### Evaluation (only when ground truth is available)

In [None]:
classification = np.load(exp + 'updated_classes.npz')

classes = classification['classes']
idx = classification['classified']

# Ground truth is optional: one species label per read, indexed by read number,
# with 'Unknown' for reads that have no label.
truth_path = exp + "ground_truth.txt"
if not os.path.isfile(truth_path):
    print(f'No ground truth found at {truth_path}; skipping evaluation')
else:
    with open(truth_path) as truth_file:
        truth = np.array(truth_file.read().strip().split("\n"))

    # Sorting gives a stable species -> row mapping across runs. The iteration
    # order of a set of strings changes with hash randomisation.
    spec_list = sorted(set(truth) - {'Unknown'})
    spec_idx = {s: n for n, s in enumerate(spec_list)}
    idx_spec = dict(enumerate(spec_list))

    # Columns are indexed by bin id, so size by the largest id rather than by the
    # number of distinct ids (a bin nobody was assigned to would shift the count).
    no_classes = int(classes.max()) + 1 if classes.size else 0

    matrix = np.zeros((len(spec_list), no_classes))
    for c, t in tqdm(zip(classes, truth[idx]), total=len(idx)):
        if t == 'Unknown':
            continue
        matrix[spec_idx[t], c] += 1

    tot = matrix.sum()
    print(matrix.shape)

    if tot == 0:
        print('No labelled reads among the classified ones; nothing to score')
    else:
        # Precision credits each bin with its dominant species (max over rows).
        # Recall credits each species with its dominant bin (max over columns).
        precision_hits = matrix.max(0).sum()
        recall_hits = matrix.max(1).sum()

        p, r = 100 * precision_hits / tot, 100 * recall_hits / tot
        f1 = 2 * p * r / (p + r)

        print(f'Precision  =  {p:3.2f}')
        print(f'Recall     =  {r:3.2f}')
        print(f'F1-score   =  {f1:3.2f}')

### Write reads into per-bin FASTA files

In [None]:
bins_dir = exp + 'binned_reads'

# Start from an empty output directory so bins from an earlier run don't linger.
if os.path.isdir(bins_dir):
    shutil.rmtree(bins_dir)
os.mkdir(bins_dir)

# Normalise bin ids to plain ints. This works whether `classes` is a NumPy array
# (loaded from updated_classes.npz) or the torch tensor from the argmax cell.
# Tensor elements hash by identity, so set() would not de-duplicate them.
bin_ids = sorted(set(np.asarray(classes).tolist()))

# One open FASTA handle per bin; they stay open for the read-writing step below.
bin_file = {}
for c in bin_ids:
    os.makedirs(exp + f'binned_reads/bin-{c}', exist_ok=True)
    bin_file[c] = open(exp + f'binned_reads/bin-{c}/reads.fasta', 'w+')

In [None]:
# Map read number -> bin id as plain ints, so lookups work whether `idx`/`classes`
# are NumPy arrays or torch tensors (tensor elements hash by identity).
idx_class = dict(zip(np.asarray(idx).tolist(), np.asarray(classes).tolist()))

try:
    for n, record in tqdm(enumerate(SeqIO.parse(exp + 'reads.fasta', "fasta"))):
        if n in idx_class:
            bin_file[idx_class[n]].write(f'>{str(record.id)}\n{str(record.seq)}\n')
finally:
    # Close every handle that was opened, even if parsing or writing fails midway.
    for handle in bin_file.values():
        handle.close()