# Comparative Encoder

## 5:34 PM 6.9.22

In [87]:
import matplotlib.pyplot as plt
plt.style.use('dark_background')
import tensorflow as tf
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

In [88]:
# Start from a clean slate: dispose of every figure left open by earlier runs.
plt.close('all')

## Load SILVA Dataset

In [89]:
from Bio import SeqIO
import numpy as np
from tqdm.notebook import tqdm
# Parse every record of the SILVA FASTA file and keep the parsed records in an object array.
s = np.array([record for record in tqdm(SeqIO.parse('silva.fasta', "fasta"))], dtype=object)

0it [00:00, ?it/s]

In [90]:
import multiprocessing as mp
from tqdm.notebook import tqdm
def fn(i):
    """Return the first 300 bases of a record as an array of single characters.

    Args:
        i: Object exposing a ``seq`` attribute (converted with ``str``).

    Returns:
        1-D NumPy array of exactly 300 one-character strings. Records shorter
        than 300 bases are right-padded with 'N' so that all rows have the same
        length and can be stacked into a rectangular array.
    """
    return np.array(list(str(i.seq)[:300].ljust(300, 'N')))
# Truncate all records in parallel; imap preserves the input order, so row k of
# string_seqs corresponds to s[k].
with mp.Pool() as p:
    string_seqs = np.array(list(tqdm(p.imap(fn, s, chunksize=100), total=s.shape[0])))

  0%|          | 0/227331 [00:00<?, ?it/s]

In [91]:
BASES = ['A', 'U', 'G', 'C']
def fn(i):
    """One-hot encode a sequence of single-character bases into a (300, 5) array.

    Channels 0-3 follow the order of ``BASES``; channel 4 marks any other symbol.

    Args:
        i: Sequence of one-character strings (only the first 300 are used).

    Returns:
        ``np.intc`` array of shape (300, 5). Positions beyond ``len(i)`` stay
        all-zero instead of holding uninitialised memory, and the loop length no
        longer depends on the global ``string_seqs``.
    """
    idx = np.array([BASES.index(base) if base in BASES else 4 for base in i[:300]], dtype=np.intp)
    enc_seq = np.zeros((300, 5), dtype=np.intc)
    enc_seq[np.arange(len(idx)), idx] = 1
    return enc_seq
# One-hot encode every truncated sequence in parallel (order preserved by imap).
with mp.Pool() as p:
    seqs = np.asarray(list(tqdm(p.imap(fn, string_seqs, chunksize=100), total=string_seqs.shape[0])))

  0%|          | 0/227331 [00:00<?, ?it/s]

In [92]:
# The lineage is the second space-delimited token of each FASTA description.
desc = np.array([record.description.split(' ')[1] for record in s])
# Number of ';'-separated ranks per lineage (separator count + 1).
num_items = np.char.count(desc, ';') + 1
# Only lineages with exactly seven ranks are kept.
parsable = num_items == 7
raw_tax = desc[parsable]
tax = np.array([lineage.split(';') for lineage in raw_tax])
# Keep the encoded sequences aligned row-for-row with `tax`.
seqs = seqs[parsable]

## Preprocessing

In [93]:
codes = [*BASES, 'N']
def to_str(s):
    """Turn an iterable of integer base codes into a string using `codes`."""
    letters = [codes[code] for code in s]
    return ''.join(letters)

def decode(sample):
    """Convert each row of integer base codes in `sample` to a string (with a progress bar)."""
    return np.asarray([to_str(row) for row in tqdm(sample)])
# argmax over the channel axis recovers the integer base code at each position.
str_seqs = decode(seqs.argmax(axis=-1))

  0%|          | 0/180516 [00:00<?, ?it/s]

In [94]:
from sklearn.model_selection import train_test_split
# Hold out 1% of the sequences (encoded and string forms split together) for validation.
# NOTE(review): no random_state is given, so the split differs on every run.
seqs_train, seqs_val, str_seqs_train, str_seqs_val = train_test_split(seqs, str_seqs, test_size=.01)

In [95]:
# All ordered index pairs (i, j) over the validation set, first index varying slowest.
_val_idx = np.arange(seqs_val.shape[0])
pairs = np.column_stack([np.repeat(_val_idx, _val_idx.size), np.tile(_val_idx, _val_idx.size)])
first_idx, second_idx = pairs[:, 0], pairs[:, 1]
val_x1 = seqs_val[first_idx]
val_x2 = seqs_val[second_idx]

In [96]:
from Bio import pairwise2
def dissimilarity(pair):
    """Dissimilarity of two validation sequences given as a pair of indices into `str_seqs_val`.

    Computed as 1 / (local alignment score / 300) - 1.
    """
    first = str_seqs_val[pair[0]]
    second = str_seqs_val[pair[1]]
    score = pairwise2.align.localxx(first, second, score_only=True)
    return (1 / (score / 300)) - 1

import multiprocessing
# Score every validation pair in parallel; imap keeps results aligned with `pairs`.
with multiprocessing.Pool() as p:
    val_labels = np.array(list(tqdm(p.imap(dissimilarity, pairs, chunksize=1000), total=pairs.shape[0])))

  0%|          | 0/3261636 [00:00<?, ?it/s]

## Model Definition

In [97]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class TransformerBlock(layers.Layer):
    """Self-attention block: multi-head attention and a feed-forward network,
    each followed by dropout, a residual connection and layer normalization.

    Args:
        embed_dim: Size of the last input axis; also used as the attention
            ``key_dim`` and as the feed-forward output width.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden feed-forward layer.
        rate: Dropout rate applied after attention and after the feed-forward net.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block to `inputs`.

        `training` defaults to None so the layer can also be called directly
        without an explicit flag; it is only forwarded to the dropout layers.
        """
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

In [98]:
class DistanceLayer(layers.Layer):
    """Squared Euclidean distance between two batches of embeddings.

    Note: the result is the *squared* distance (no square root is taken).
    """

    def call(self, a, b):
        """Return sum((a - b) ** 2) over the last axis."""
        return tf.reduce_sum(tf.square(a - b), -1)

In [111]:
import tensorflow as tf
from tensorflow.keras import layers

# Build the encoder under a MirroredStrategy scope so its variables are created
# by the distribution strategy.
mirrored_strategy = tf.distribute.MirroredStrategy()

with mirrored_strategy.scope():
    inputs = layers.Input((300, 5))
    # Per-position projection 5 -> 50, then group three consecutive positions: (100, 150).
    den = layers.Dense(50)(inputs)
    res = layers.Reshape((100, 50 * 3))(den)
    
    # Conv/pool/flatten bottleneck, then expand back to a (100, 150) sequence.
    conv = layers.Conv1D(20, 3)(res)
    maxpool = layers.MaxPooling1D()(conv)
    res2 = layers.Flatten()(maxpool)
    norm = layers.BatchNormalization()(res2)
    den = layers.Dense(100 * 50 * 3)(norm)
    res = layers.Reshape((100, 50 * 3))(den)

    # Two transformer blocks, each followed by batch normalization.
    trans = TransformerBlock(50 * 3, 4, 100)(res)
    norm = layers.BatchNormalization()(trans)

    trans = TransformerBlock(50 * 3, 4, 100)(norm)
    norm = layers.BatchNormalization()(trans)

    conv = layers.Conv1D(20, 3)(norm)
    maxpool = layers.MaxPooling1D()(conv)
    res2 = layers.Flatten()(maxpool)
    norm = layers.BatchNormalization()(res2)

    # Final 2-D embedding per sequence.
    out = layers.Dense(2)(norm)

    embeddings = tf.keras.Model(inputs=inputs, outputs=out)
embeddings.summary()

Model: "model_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 300, 5)]          0         
                                                                 
 dense_28 (Dense)            (None, 300, 50)           300       
                                                                 
 reshape_10 (Reshape)        (None, 100, 150)          0         
                                                                 
 conv1d_7 (Conv1D)           (None, 98, 20)            9020      
                                                                 
 max_pooling1d_7 (MaxPooling  (None, 49, 20)           0         
 1D)                                                             
                                                                 
 flatten_8 (Flatten)         (None, 980)               0         
                                                           

In [112]:
from keras import backend as K
def correlation_coefficient_loss(y_true, y_pred):
    """Loss equal to 1 - r**2, where r is the Pearson correlation over the batch.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor (same shape and dtype as `y_true`).

    Returns:
        Scalar tensor in [0, 1]. If either tensor has zero variance the
        correlation is treated as 0 (loss 1) instead of producing NaN from 0/0.
    """
    xm = y_true - K.mean(y_true)
    ym = y_pred - K.mean(y_pred)
    r_num = K.sum(tf.multiply(xm, ym))
    r_den = K.sqrt(tf.multiply(K.sum(K.square(xm)), K.sum(K.square(ym))))
    # divide_no_nan returns 0 where the denominator is 0 and the plain quotient otherwise.
    r = tf.math.divide_no_nan(r_num, r_den)

    # Clamp against floating-point overshoot outside [-1, 1].
    r = K.maximum(K.minimum(r, 1.0), -1.0)
    return 1 - K.square(r)

def combined_loss(y_true, y_pred):
    """Sum of the correlation loss and an unreduced mean-squared-error loss."""
    mse = tf.keras.losses.MeanSquaredError(tf.keras.losses.Reduction.NONE)
    return correlation_coefficient_loss(y_true, y_pred) + mse(y_true, y_pred)

# Siamese wrapper: the same `embeddings` model encodes both inputs and the
# DistanceLayer output is the model prediction.
with mirrored_strategy.scope():
    inputa = layers.Input((300, 5), name='input_a')
    inputb = layers.Input((300, 5), name='input_b')
    distances = DistanceLayer()(
        embeddings(inputa),
        embeddings(inputb),
    )
    siamese_network = tf.keras.Model(inputs=[inputa, inputb], outputs=distances)
    siamese_network.compile(optimizer='adam',
                    loss=combined_loss,
                    metrics=[tf.keras.metrics.MeanAbsoluteError()])
siamese_network.summary()

Model: "model_10"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_a (InputLayer)           [(None, 300, 5)]     0           []                               
                                                                                                  
 input_b (InputLayer)           [(None, 300, 5)]     0           []                               
                                                                                                  
 model_9 (Functional)           (None, 2)            15529942    ['input_a[0][0]',                
                                                                  'input_b[0][0]']                
                                                                                                  
 distance_layer_4 (DistanceLaye  (None,)             0           ['model_9[0][0]',         

### Load Model

In [113]:
# Replace the freshly built models with previously saved ones.
# NOTE(review): the two models are loaded independently, so the loaded `embeddings`
# presumably does not share weights with the encoder inside `siamese_network` - confirm.
custom_objects = {'combined_loss': combined_loss}
with tf.keras.utils.custom_object_scope(custom_objects):
    siamese_network = tf.keras.models.load_model('Models/comparative_encoder/full_model')
    embeddings = tf.keras.models.load_model('Models/comparative_encoder/encoder')

## Training

In [109]:
from Bio import pairwise2
def dissimilarity(pair):
    """Dissimilarity of a pair of sequence strings: 1 / (local alignment score / 300) - 1."""
    first, second = pair[0], pair[1]
    score = pairwise2.align.localxx(first, second, score_only=True)
    return (1 / (score / 300)) - 1

from sklearn.model_selection import train_test_split
rng = np.random.default_rng()
def randomized_epoch(data, str_data):
    """Train `siamese_network` for one epoch on a freshly randomized pairing.

    The data is split into two random halves; row k of the first half is paired
    with row k of the second half and labelled with `dissimilarity`.

    Args:
        data: Encoded sequences.
        str_data: String form of the same sequences, aligned row-for-row with `data`.
    """
    # A new seed per call: a fixed random_state would reproduce the identical
    # pairing every epoch, so the model would only ever see one set of pairs.
    seed = int(rng.integers(np.iinfo(np.int32).max))
    x1, x2, x1_str, x2_str = train_test_split(data, str_data, test_size=.5, random_state=seed)
    # With an odd number of rows the halves differ by one; trim to a common length.
    n_pairs = min(len(x1), len(x2))
    x1, x2, x1_str, x2_str = x1[:n_pairs], x2[:n_pairs], x1_str[:n_pairs], x2_str[:n_pairs]

    import multiprocessing
    with multiprocessing.Pool() as p:
        y = np.array(list(tqdm(p.imap(dissimilarity, zip(x1_str, x2_str), chunksize=1000), total=n_pairs)))

    train_data = tf.data.Dataset.from_tensor_slices(({'input_a': x1, 'input_b': x2}, y))
    train_data = train_data.batch(1000)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_data = train_data.with_options(options)

    siamese_network.fit(train_data, epochs=1)

def validate():
    """Predict on the first 100000 validation pairs; return (r squared, mean squared error)."""
    features = {'input_a': val_x1[:100000], 'input_b': val_x2[:100000]}
    shard_options = tf.data.Options()
    shard_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    val_data = tf.data.Dataset.from_tensor_slices((features,)).batch(1000).with_options(shard_options)
    pred = siamese_network.predict(val_data)
    target = val_labels[:100000]
    r_squared = np.corrcoef(pred, target)[0, 1] ** 2
    mse = ((pred - target) ** 2).mean()
    return r_squared, mse

def train(epochs):
    """Run `epochs` rounds of training + validation, saving both models after each round."""
    for epoch in range(1, epochs + 1):
        print(f'Epoch {epoch}:')
        randomized_epoch(seqs_train, str_seqs_train)
        val_r, val_mse = validate()
        print(f'val_mse: {val_mse}; val_r2: {val_r}')
        siamese_network.save('Models/comparative_encoder/full_model')
        embeddings.save('Models/comparative_encoder/encoder')

In [110]:
train(100)

Epoch 1:


  0%|          | 0/89355 [00:00<?, ?it/s]

val_mse: 0.1130027553179244; val_r2: 0.003927908216945577
Epoch 2:


  0%|          | 0/89355 [00:00<?, ?it/s]

val_mse: 0.11232928356276786; val_r2: 0.004199364455520755
Epoch 3:


  0%|          | 0/89355 [00:00<?, ?it/s]

val_mse: 0.11195188866779204; val_r2: 0.007479852569580837
Epoch 4:


  0%|          | 0/89355 [00:00<?, ?it/s]

val_mse: 0.11200208549130619; val_r2: 0.0064317774553822185
Epoch 5:


  0%|          | 0/89355 [00:00<?, ?it/s]

val_mse: 0.11189072760260853; val_r2: 0.010331363003440477
Epoch 6:


  0%|          | 0/89355 [00:00<?, ?it/s]



KeyboardInterrupt: 

In [23]:
# Freeze both models before evaluation and before they are reused for fine-tuning.
embeddings.trainable = False
siamese_network.trainable = False

## Evaluation on SILVA

In [24]:
seq_reps = embeddings.predict(seqs)

2022-06-11 15:59:10.729210: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: -2
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_3933676"
    }
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024FlatMapDataset:10804"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_



### Correlation Plot

In [28]:
import random
from scipy.spatial.distance import euclidean
codes = BASES + ['N']
def to_str(s):
    """Turn an iterable of integer base codes into a string using `codes`."""
    return ''.join(codes[i] for i in s)
def evaluate():
    """Pick one random pair of sequences and return [alignment dissimilarity, embedding distance].

    NOTE(review): the siamese model is trained on the *squared* distance produced by
    DistanceLayer, whereas this uses the plain Euclidean distance - confirm which
    quantity should be compared against the alignment dissimilarity.
    """
    a = random.randint(0, seq_reps.shape[0] - 1)
    b = random.randint(0, seq_reps.shape[0] - 1)
    pred = euclidean(seq_reps[a], seq_reps[b])
    first = to_str(seqs[a].argmax(-1))
    second = to_str(seqs[b].argmax(-1))
    score = dissimilarity((first, second))
    return [score, pred]
# Column 0: alignment-based dissimilarity; column 1: distance between encodings.
results = np.asarray([evaluate() for _ in range(10000)])
f = plt.figure(figsize=(8, 6))
plt.scatter(results[:, 0], results[:, 1], alpha=.2)
# The x values come from `dissimilarity`, so label them as a dissimilarity rather than a similarity.
plt.xlabel('Pairwise Dissimilarity (Smith-Waterman)')
plt.ylabel('Euclidean Distance Between Encodings')
plt.title('Correlation Plot of Model Encodings')
plt.savefig('Results/it1/silva/eval/comparative_enc_eval.png')

In [29]:
np.corrcoef(results[:, 0], results[:, 1])

array([[1.        , 0.81438758],
       [0.81438758, 1.        ]])

In [30]:
np.mean((results[:, 0] - results[:, 1]) ** 2)

0.07276923042824604

### SILVA Clustering

In [31]:
import matplotlib.pyplot as plt
# Scatter of all encodings: first embedding coordinate on x, second on y.
f = plt.figure(figsize=(8, 6))
x, y = np.asarray(seq_reps).T
plt.scatter(x, y, marker='o', alpha=.1)
plt.title("Encoded Representations of SILVA Database")
plt.savefig('Results/it1/silva/silva_all.png')

#### Domain

In [32]:
# Split the encodings by the first taxonomy column.
archaea, bacteria, eukaryota = (
    seq_reps[tax[:, 0] == domain_name] for domain_name in ('Archaea', 'Bacteria', 'Eukaryota')
)

In [33]:
f = plt.figure(figsize=(8, 6))
# Bacteria gets a much lower opacity than the other two groups.
for group_points, opacity in ((archaea, .4), (bacteria, .015), (eukaryota, .4)):
    plt.scatter(group_points[:, 0], group_points[:, 1], alpha=opacity)
leg = plt.legend(['Archaea', 'Bacteria', 'Eukaryota'], markerscale=1, borderpad=1)
# Legend markers are drawn fully opaque regardless of the scatter alpha.
for lh in leg.legendHandles:
    lh.set_alpha(1)
plt.title('Encoded Representations of the Domains of Life')
plt.savefig('Results/it1/silva/silva_domains.png')

#### Phylum (Bacteria)

In [34]:
# Unique values of taxonomy column 1 among Bacteria rows, with their counts.
a, b = np.unique(tax[tax[:, 0] == 'Bacteria'][:, 1], return_counts=True)
# Despite the name, `genuses` here holds the column-1 groups with more than 1000 members.
genuses = a[b > 1000]

In [35]:
# Rows whose column-1 value is one of the selected groups.
plottable = np.isin(tax[:, 1], genuses)
# Membership matrix (rows: plottable sequences, columns: groups), 1.0 where they match.
to_plot = (tax[plottable][:, 1][:, None] == genuses[None, :]).astype(float)
plottable_seqs = seq_reps[plottable]

In [36]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
# One scatter series per group, each drawn from 1000 random draws (with replacement).
for member_mask in to_plot.T.astype(bool):
    pop = plottable_seqs[member_mask]
    samp = rng.integers(0, len(pop), 1000)
    ax.scatter(pop[samp, 0], pop[samp, 1], alpha=.3, marker='o')
ax.set_title("Encoded Representations of Phyla of Bacteria")
leg = plt.legend(genuses, markerscale=1, borderpad=1)
for lh in leg.legendHandles:
    lh.set_alpha(1)
plt.savefig('Results/it1/silva/silva_phylum.png')

#### Class (Proteobacteria)

In [37]:
# Unique values of taxonomy column 2 among Proteobacteria rows, with their counts.
a, b = np.unique(tax[tax[:, 1] == 'Proteobacteria'][:, 2], return_counts=True)
# Despite the name, `genuses` here holds the column-2 groups with more than 500 members.
genuses = a[b > 500]

In [38]:
# Rows whose column-2 value is one of the selected groups.
plottable = np.isin(tax[:, 2], genuses)
# Membership matrix (rows: plottable sequences, columns: groups), 1.0 where they match.
to_plot = (tax[plottable][:, 2][:, None] == genuses[None, :]).astype(float)
plottable_seqs = seq_reps[plottable]

In [39]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
# One scatter series per group, each drawn from 1000 random draws (with replacement).
for member_mask in to_plot.T.astype(bool):
    pop = plottable_seqs[member_mask]
    samp = rng.integers(0, len(pop), 1000)
    ax.scatter(pop[samp, 0], pop[samp, 1], alpha=.3, marker='o')
ax.set_title("Encoded Representations of Classes of Proteobacteria")
leg = plt.legend(genuses, markerscale=1, borderpad=1)
for lh in leg.legendHandles:
    lh.set_alpha(1)
plt.savefig('Results/it1/silva/silva_class.png')

#### Order (Alphaproteobacteria)

In [40]:
# Unique values of taxonomy column 3 among Alphaproteobacteria rows, with their counts.
a, b = np.unique(tax[tax[:, 2] == 'Alphaproteobacteria'][:, 3], return_counts=True)
# Despite the name, `genuses` here holds the column-3 groups with more than 500 members.
genuses = a[b > 500]

In [41]:
# Rows whose column-3 value is one of the selected groups.
plottable = np.isin(tax[:, 3], genuses)
# Membership matrix (rows: plottable sequences, columns: groups), 1.0 where they match.
to_plot = (tax[plottable][:, 3][:, None] == genuses[None, :]).astype(float)
plottable_seqs = seq_reps[plottable]

In [42]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
# One scatter series per group, each drawn from 1000 random draws (with replacement).
for member_mask in to_plot.T.astype(bool):
    pop = plottable_seqs[member_mask]
    samp = rng.integers(0, len(pop), 1000)
    ax.scatter(pop[samp, 0], pop[samp, 1], alpha=.3, marker='o')
ax.set_title("Encoded Representations of Orders of Alphaproteobacteria")
leg = plt.legend(genuses, markerscale=1, borderpad=1)
for lh in leg.legendHandles:
    lh.set_alpha(1)
plt.savefig('Results/it1/silva/silva_order.png')

#### Genus (Rhizobiales Rhizobiaceae)

In [43]:
# Unique values of taxonomy column 5 among rows whose column 4 is 'Rhizobiaceae', with counts.
a, b = np.unique(tax[tax[:, 4] == 'Rhizobiaceae'][:, 5], return_counts=True)
# Keep the groups with more than 500 members.
genuses = a[b > 500]

In [44]:
# Rows whose column-5 value is one of the selected groups.
plottable = np.isin(tax[:, 5], genuses)
# Membership matrix (rows: plottable sequences, columns: groups), 1.0 where they match.
to_plot = (tax[plottable][:, 5][:, None] == genuses[None, :]).astype(float)
plottable_seqs = seq_reps[plottable]

In [45]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
# One scatter series per genus, each drawn from 1000 random draws (with replacement).
for i in to_plot.T:
    pop = plottable_seqs[i.astype(bool)]
    samp = rng.integers(0, len(pop), 1000)
    x, y = zip(*pop[samp])
    ax.scatter(x, y, alpha=.5, marker='o')
# The groups plotted here are the genera selected within Rhizobiaceae, so the title
# says so (it previously repeated the title of the order-level figure).
ax.set_title("Encoded Representations of Genera of Rhizobiaceae")
leg = plt.legend(genuses,
                markerscale=1,
                borderpad=1)
for lh in leg.legendHandles:
    lh.set_alpha(1)
plt.savefig('Results/it1/silva/silva_genus.png')

# ANC Data

## Preprocessing

### Load

In [46]:
# Maps the first four characters of a FASTA file name to the label assigned to its reads.
label_map = {
    "S002": "black",
    "S001": "grey",
    "S003": "unpigmented"
}
def read_ion_reporter(path):
    """Read every ``.fasta`` file under `path` and return sequences, labels and headers.

    Each file's label is looked up in `label_map` using the first four characters
    of its file name.

    Args:
        path: Directory that is walked recursively for ``.fasta`` files.

    Returns:
        Tuple ``(string_seqs, labels, descriptions)`` of equally long 1-D NumPy
        arrays: sequence strings, per-record labels (object dtype) and FASTA
        descriptions. All three are empty when no records are found.

    Raises:
        KeyError: If a file name prefix is not present in `label_map`.
    """
    import os
    import numpy as np
    from Bio import SeqIO

    # Collect the paths of all FASTA files below `path`.
    fasta_paths = []
    for root, dirs, files in os.walk(path):
        fasta_paths.extend(os.path.join(root, name) for name in files if name.endswith(".fasta"))

    # Plain lists avoid re-allocating the arrays for every file.
    string_seqs, labels, descriptions = [], [], []
    for fasta_path in fasta_paths:
        # basename() works with any path separator, unlike splitting on "/".
        label = label_map[os.path.basename(fasta_path)[:4]]
        for record in SeqIO.parse(fasta_path, "fasta"):
            string_seqs.append(str(record.seq))
            labels.append(label)
            descriptions.append(record.description)

    return (np.array(string_seqs, dtype=str),
            np.array(labels, dtype=object),
            np.array(descriptions, dtype=str))

arlington, arlington_labels, arlington_desc = read_ion_reporter("Data/Arlington Processed")

### Header Parsing

In [47]:
import re
def get_size(s: str):
    """Count the '|'-separated fields after the first field that contains a '.'.

    Returns 0 when no field contains a '.'.
    """
    fields = s.split('|')
    marker = next((k for k, field in enumerate(fields) if re.search(r'\.', field)), -1)
    if marker == -1:
        return 0
    return len(fields) - marker - 1

import numpy as np
# Keep only reads with at least one header field after the first field containing a '.'.
arlington_known = np.vectorize(get_size)(arlington_desc) != 0
arlington = arlington[arlington_known]
arlington_labels = arlington_labels[arlington_known]
arlington_desc = arlington_desc[arlington_known]

def get_tax(s: str):
    """Extract six taxonomy names from a '|'-delimited header.

    The names are the fields between the first field containing a '.' and the
    last field (both excluded), with surrounding '[' / ']' stripped. Names that
    contain '/' or 'sp.' are replaced by "UNKNOWN".
    """
    fields = s.split('|')
    # Index of the first field containing '.', or the last index if there is none.
    marker = next((k for k, field in enumerate(fields) if re.search(r'\.', field)), len(fields) - 1)
    tax = [name.strip('[]') for name in fields[marker + 1:-1]]
    keep = [not ('/' in name or 'sp.' in name) for name in tax]
    return np.where(keep, tax, ["UNKNOWN"] * 6)

# One row of six taxonomy names per read.
arlington_tax = np.empty((arlington_desc.shape[0], 6), dtype=object)
for row, header in enumerate(arlington_desc):
    arlington_tax[row, :] = get_tax(header)

def get_conf(s: str):
    """Return, as a float, the first '|'-separated field containing a '.' (the last field if none does)."""
    fields = s.split('|')
    marker = next((k for k, field in enumerate(fields) if re.search(r'\.', field)), len(fields) - 1)
    return float(fields[marker])

# Value of the first dotted header field for every read, as floats.
arlington_conf = np.vectorize(get_conf)(arlington_desc)

def get_abund(s: str):
    """Return, as an int, the field immediately before the first field containing a '.'."""
    fields = s.split('|')
    marker = next((k for k, field in enumerate(fields) if re.search(r'\.', field)), len(fields) - 1)
    return int(fields[marker - 1])

# Integer header field preceding the first dotted field, for every read.
arlington_abund = np.vectorize(get_abund)(arlington_desc)

### Sequence Encoding

In [48]:
def preprocess_reads(x, length=300):
    """One-hot encode reads into an ``(n, length, 5)`` int32 array.

    Channels 0-3 are A, T, G, C; channel 4 marks any other character (previously
    such characters raised ValueError). Reads longer than `length` are truncated;
    shorter reads are left-padded with all-zero rows, and an empty read becomes
    an all-zero matrix.

    Args:
        x: Sequence (list or array) of read strings.
        length: Number of positions in the encoded output (default 300).

    Returns:
        ``np.int32`` array of shape ``(len(x), length, 5)``.
    """
    import numpy as np
    try:
        from tqdm.notebook import tqdm
    except ImportError:  # the progress bar is optional
        def tqdm(iterable, **_):
            return iterable

    # Byte value -> channel lookup table; everything that is not A/T/G/C maps to channel 4.
    lookup = np.full(256, 4, dtype=np.intp)
    for channel, base in enumerate("ATGC"):
        lookup[ord(base)] = channel

    print("Encoding sequences...")
    final_seqs = np.zeros((len(x), length, 5), dtype=np.int32)
    for row, read in enumerate(tqdm(x)):
        trimmed = str(read)[:length]
        if not trimmed:
            continue
        # Non-ASCII characters become '?', which the table maps to channel 4.
        channels = lookup[np.frombuffer(trimmed.encode('ascii', 'replace'), dtype=np.uint8)]
        # Right-align the read so that any padding sits at the start.
        positions = np.arange(length - len(channels), length)
        final_seqs[row, positions, channels] = 1

    return final_seqs

arlington_processed = preprocess_reads(arlington)

Encoding sequences...


  0%|          | 0/17142 [00:00<?, ?it/s]

### Variable Region Separation

#### Preprocessing

In [49]:
def get_variable_region(i):
    """Return the first '|'-separated header field that starts with 'V' followed by digits.

    The whole matching field is returned; "UNKNOWN" is returned when no field matches.
    """
    candidates = (field for field in i.split("|") if re.search(r"^V\d+", field))
    return next(candidates, "UNKNOWN")

import numpy as np
# Variable-region tag of every read ("UNKNOWN" where the header has none).
v_regions = np.vectorize(get_variable_region)(arlington_desc)

# Reads with a known region form the labelled set for the classifier.
known_mask = v_regions != 'UNKNOWN'
known_seqs, known_labels = arlington_processed[known_mask], v_regions[known_mask]

# Binarize the region names; `bn.classes_` maps columns back to names.
from sklearn.preprocessing import LabelBinarizer
bn = LabelBinarizer()
v_regions_enc = bn.fit_transform(known_labels)

#### NN Classifier

In [50]:
import tensorflow as tf
# Small classifier: per-position dense layer, flatten, softmax over the region classes.
# NOTE: this rebinds the module-level names `inputs` and `outputs`.
inputs = tf.keras.layers.Input((300, 5))
hidden = tf.keras.layers.Dense(300, activation='relu')(inputs)
flat = tf.keras.layers.Flatten()(hidden)
outputs = tf.keras.layers.Dense(v_regions_enc.shape[1], activation='softmax')(flat)

v_region_classifier = tf.keras.Model(inputs=inputs, outputs=outputs)
v_region_classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
v_region_classifier.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 300, 5)]          0         
                                                                 
 dense_11 (Dense)            (None, 300, 300)          1800      
                                                                 
 flatten_4 (Flatten)         (None, 90000)             0         
                                                                 
 dense_12 (Dense)            (None, 6)                 540006    
                                                                 
Total params: 541,806
Trainable params: 541,806
Non-trainable params: 0
_________________________________________________________________


In [51]:
from sklearn.model_selection import train_test_split
# 80/20 split of the labelled reads.
# NOTE(review): no random_state is given, so the split differs on every run.
x_train, x_val, y_train, y_val = train_test_split(known_seqs, v_regions_enc, test_size=.2)
import tensorflow as tf
# Train for up to 100 epochs; stop as soon as val_loss fails to improve for one epoch.
v_region_classifier.fit(x_train, y_train, validation_data=(x_val, y_val),
                       callbacks=[tf.keras.callbacks.EarlyStopping(patience=1, monitor='val_loss')],
                       batch_size=100,
                       epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100


<keras.callbacks.History at 0x14c2493562b0>

#### Classify All Seqs

In [53]:
# Class probabilities for the reads whose header carries no region tag.
unknown_labels = v_region_classifier.predict(arlington_processed[~known_mask], verbose=1)
int_predictions = unknown_labels.argmax(axis=1)
predictions = np.vectorize(lambda i: bn.classes_[i])(int_predictions)
certainty = unknown_labels.max(axis=1)
print(certainty.shape, np.nonzero(certainty > .99)[0].shape)  # (number of unlabelled reads,) vs (number predicted with > .99 probability,)

# Merge header labels with confident predictions; everything else stays 'UNKNOWN'.
v_region_labels = np.full_like(v_regions, 'UNKNOWN')
v_region_labels[known_mask] = known_labels
# Boolean indexing returns a copy, so edit the copy and write it back.
a = v_region_labels[~known_mask]
a[certainty > .99] = predictions[certainty > .99]
v_region_labels[~known_mask] = a
print(v_region_labels[v_region_labels != 'UNKNOWN'].shape[0] / v_region_labels.shape[0])
# The printed value is the fraction of ALL reads that now carry a region label
# (header-labelled plus confidently predicted), not only of the unlabelled ones.

(8571,) (8545,)
0.9984832574962081


In [54]:
import matplotlib.pyplot as plt
# plt.style.use('default')
# Count reads per region label (np.unique returns the labels sorted).
labels, sizes = np.unique(v_region_labels, return_counts=True)
# s = np.argsort(sizes)
# labels, sizes = labels[s], sizes[s]
fig, ax = plt.subplots(figsize=(16, 12))
# NOTE(review): `explode` is hard-coded for exactly seven wedges; a different number
# of labels would not match its length - confirm or derive it from `labels`.
ax.pie(sizes, labels=labels, startangle=90, autopct='%1.1f%%', explode=(.7, 0, 0, 0, 0, 0, .5))
ax.axis('equal')
ax.set_title('Variable Region Breakdown After Classification\n')
plt.savefig('ANC_vregion_breakdown.png')
# plt.style.use('dark_background')

### Prep for Model

In [55]:
# These definitions repeat the ones used for the SILVA data earlier in the notebook.
codes = [*BASES, 'N']
def to_str(s):
    """Turn an iterable of integer base codes into a string using `codes`."""
    letters = [codes[code] for code in s]
    return ''.join(letters)

def decode(sample):
    """Convert each row of integer base codes in `sample` to a string (with a progress bar)."""
    return np.asarray([to_str(row) for row in tqdm(sample)])
# NOTE(review): all-zero (padded) positions have argmax 0 and therefore decode as
# the first letter of `codes` - confirm this is intended for short reads.
anc_str_seqs = decode(arlington_processed.argmax(axis=-1))

  0%|          | 0/17142 [00:00<?, ?it/s]

In [56]:
from sklearn.model_selection import train_test_split
# Hold out 1% of the ANC reads (encoded and string forms split together) for validation.
# NOTE(review): no random_state is given, so the split differs on every run.
anc_seqs_train, anc_seqs_val, anc_str_seqs_train, anc_str_seqs_val = \
    train_test_split(arlington_processed, anc_str_seqs, test_size=.01)

In [57]:
# All ordered index pairs (i, j) over the ANC validation set, first index varying slowest.
_anc_val_idx = np.arange(anc_seqs_val.shape[0])
pairs = np.column_stack([np.repeat(_anc_val_idx, _anc_val_idx.size), np.tile(_anc_val_idx, _anc_val_idx.size)])
first_idx, second_idx = pairs[:, 0], pairs[:, 1]
anc_val_x1 = anc_seqs_val[first_idx]
anc_val_x2 = anc_seqs_val[second_idx]

In [58]:
from Bio import pairwise2
def dissimilarity(pair):
    """Dissimilarity of two ANC validation reads given as a pair of indices into `anc_str_seqs_val`.

    Computed as 1 / (local alignment score / 300) - 1.
    """
    first = anc_str_seqs_val[pair[0]]
    second = anc_str_seqs_val[pair[1]]
    score = pairwise2.align.localxx(first, second, score_only=True)
    return (1 / (score / 300)) - 1

# Label every validation pair (computed serially, with a progress bar).
anc_val_labels = np.array([dissimilarity(index_pair) for index_pair in tqdm(pairs)])

  0%|          | 0/29584 [00:00<?, ?it/s]

## Fine-Tuning

### Model Definition

In [60]:
import tensorflow as tf
from tensorflow.keras import layers
# Fine-tuning model: two new Dense layers placed in front of the existing `embeddings`
# model (the summary below reports only the 180,600 new parameters as trainable).
with mirrored_strategy.scope():
    anc_input = layers.Input((300, 5))
    # NOTE(review): Reshape((5, 300)) reinterprets the flattened values rather than
    # transposing the axes - confirm a Permute((2, 1)) was not intended.
    res = layers.Reshape((5, 300))(anc_input)
    # NOTE(review): no activation between the two Dense layers, so together they
    # form a single linear map - confirm this is intended.
    mid = layers.Dense(300)(res)
    mid = layers.Dense(300)(mid)
    res = layers.Reshape((300, 5))(mid)
    embed = embeddings(res)

    fine_embeddings = tf.keras.Model(inputs=anc_input, outputs=embed)

fine_embeddings.summary()

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 300, 5)]          0         
                                                                 
 reshape_5 (Reshape)         (None, 5, 300)            0         
                                                                 
 dense_13 (Dense)            (None, 5, 300)            90300     
                                                                 
 dense_14 (Dense)            (None, 5, 300)            90300     
                                                                 
 reshape_6 (Reshape)         (None, 300, 5)            0         
                                                                 
 model (Functional)          (None, 2)                 15529942  
                                                                 
Total params: 15,710,542
Trainable params: 180,600
Non-trai

In [61]:
from keras import backend as K
def correlation_coefficient_loss(y_true, y_pred):
    """Loss equal to 1 - r**2, where r is the Pearson correlation over the batch.

    Args:
        y_true: Target tensor.
        y_pred: Prediction tensor (same shape and dtype as `y_true`).

    Returns:
        Scalar tensor in [0, 1]. If either tensor has zero variance the
        correlation is treated as 0 (loss 1) instead of producing NaN from 0/0.
    """
    xm = y_true - K.mean(y_true)
    ym = y_pred - K.mean(y_pred)
    r_num = K.sum(tf.multiply(xm, ym))
    r_den = K.sqrt(tf.multiply(K.sum(K.square(xm)), K.sum(K.square(ym))))
    # divide_no_nan returns 0 where the denominator is 0 and the plain quotient otherwise.
    r = tf.math.divide_no_nan(r_num, r_den)

    # Clamp against floating-point overshoot outside [-1, 1].
    r = K.maximum(K.minimum(r, 1.0), -1.0)
    return 1 - K.square(r)

def combined_loss(y_true, y_pred):
    """Sum of the correlation loss and an unreduced mean-squared-error loss."""
    mse = tf.keras.losses.MeanSquaredError(tf.keras.losses.Reduction.NONE)
    return correlation_coefficient_loss(y_true, y_pred) + mse(y_true, y_pred)

# Siamese wrapper for fine-tuning: `fine_embeddings` encodes both inputs and the
# DistanceLayer output is the model prediction.
with mirrored_strategy.scope():
    inputa = layers.Input((300, 5), name='input_a')
    inputb = layers.Input((300, 5), name='input_b')
    distances = DistanceLayer()(
        fine_embeddings(inputa),
        fine_embeddings(inputb),
    )
    fine_siamese = tf.keras.Model(inputs=[inputa, inputb], outputs=distances)
    fine_siamese.compile(optimizer='adam',
                    loss=combined_loss,
                    metrics=[tf.keras.metrics.MeanAbsoluteError()])
fine_siamese.summary()

Model: "model_4"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_a (InputLayer)           [(None, 300, 5)]     0           []                               
                                                                                                  
 input_b (InputLayer)           [(None, 300, 5)]     0           []                               
                                                                                                  
 model_3 (Functional)           (None, 2)            15710542    ['input_a[0][0]',                
                                                                  'input_b[0][0]']                
                                                                                                  
 distance_layer_1 (DistanceLaye  (None,)             0           ['model_3[0][0]',          

### Training

In [62]:
from Bio import pairwise2
def dissimilarity(pair):
    """Dissimilarity of a pair of sequence strings: 1 / (local alignment score / 300) - 1."""
    first, second = pair[0], pair[1]
    score = pairwise2.align.localxx(first, second, score_only=True)
    return (1 / (score / 300)) - 1

from sklearn.model_selection import train_test_split
rng = np.random.default_rng()
def randomized_epoch(data, str_data):
    """Fine-tune `fine_siamese` for one epoch on a freshly randomized pairing.

    The data is split into two random halves; row k of the first half is paired
    with row k of the second half and labelled with `dissimilarity`.

    Args:
        data: Encoded reads.
        str_data: String form of the same reads, aligned row-for-row with `data`.
    """
    # A new seed per call: a fixed random_state would reproduce the identical
    # pairing every epoch, so the model would only ever see one set of pairs.
    seed = int(rng.integers(np.iinfo(np.int32).max))
    x1, x2, x1_str, x2_str = train_test_split(data, str_data, test_size=.5, random_state=seed)
    # With an odd number of rows the halves differ by one; trim to a common length.
    n_pairs = min(len(x1), len(x2))
    x1, x2, x1_str, x2_str = x1[:n_pairs], x2[:n_pairs], x1_str[:n_pairs], x2_str[:n_pairs]

    # Labels are computed serially here (the SILVA version uses a process pool).
    y = np.array([dissimilarity(i) for i in tqdm(zip(x1_str, x2_str), total=n_pairs)])

    train_data = tf.data.Dataset.from_tensor_slices(({'input_a': x1, 'input_b': x2}, y))
    train_data = train_data.batch(100)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_data = train_data.with_options(options)

    fine_siamese.fit(train_data, epochs=1)

def validate():
    """Predict on the first 10000 ANC validation pairs; return (r squared, mean squared error)."""
    features = {'input_a': anc_val_x1[:10000], 'input_b': anc_val_x2[:10000]}
    shard_options = tf.data.Options()
    shard_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    val_data = tf.data.Dataset.from_tensor_slices((features,)).batch(100).with_options(shard_options)
    pred = fine_siamese.predict(val_data)
    target = anc_val_labels[:10000]
    r_squared = np.corrcoef(pred, target)[0, 1] ** 2
    mse = ((pred - target) ** 2).mean()
    return r_squared, mse

def train(epochs):
    """Run ``epochs`` training rounds, printing validation metrics after each.

    Every round calls ``randomized_epoch`` on ``anc_seqs_train`` /
    ``anc_str_seqs_train`` and then ``validate``.

    Args:
        epochs: Number of rounds to run.
    """
    for epoch_number in range(1, epochs + 1):
        print(f'Epoch {epoch_number}:')
        randomized_epoch(anc_seqs_train, anc_str_seqs_train)
        # validate() yields the squared correlation first, then the MSE.
        val_r2, val_mse = validate()
        print(f'val_mse: {val_mse}; val_r2: {val_r2}')

In [63]:
# Request 100 epochs; the recorded output below ends with a KeyboardInterrupt during epoch 11.
train(100)

Epoch 1:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.06252066376104262; val_r2: 0.6816427546050022
Epoch 2:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.0610471633910062; val_r2: 0.7107947423946692
Epoch 3:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.052737997422355014; val_r2: 0.7061911952433837
Epoch 4:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.059985276111745575; val_r2: 0.7226272573397204
Epoch 5:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.062457244023415855; val_r2: 0.725667252024425
Epoch 6:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.050611884612311746; val_r2: 0.7218607896530376
Epoch 7:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.055272162841072794; val_r2: 0.725200961371239
Epoch 8:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.05967544963187861; val_r2: 0.7193539895501461
Epoch 9:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.05736218615570365; val_r2: 0.7209011743586109
Epoch 10:


  0%|          | 0/8485 [00:00<?, ?it/s]

val_mse: 0.05674700007609547; val_r2: 0.7255925506486571
Epoch 11:


  0%|          | 0/8485 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [64]:
# Save both models under Models/comparative_encoder/: `fine_siamese` as
# full_model_anc and `fine_embeddings` as encoder_anc.
fine_siamese.save('Models/comparative_encoder/full_model_anc')
fine_embeddings.save('Models/comparative_encoder/encoder_anc')



### Evaluation

In [65]:
# Run `fine_embeddings` over `arlington_processed`; the evaluation and plotting cells below read the result.
anc_seq_reps = fine_embeddings.predict(arlington_processed)

2022-06-11 16:07:02.106742: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Did not find a shardable source, walked to a node which is not a dataset: name: "FlatMapDataset/_9"
op: "FlatMapDataset"
input: "PrefetchDataset/_8"
attr {
  key: "Targuments"
  value {
    list {
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: -2
  }
}
attr {
  key: "f"
  value {
    func {
      name: "__inference_Dataset_flat_map_slice_batch_indices_4047434"
    }
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\024FlatMapDataset:11455"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
        dim {
          size: -1
        }
      }
    }
  }
}
attr {
  key: "output_types"
  value {
    list {
      type: DT_INT64
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_



#### Correlation Plot

In [68]:
import random
from scipy.spatial.distance import euclidean
codes = BASES + ['N']
def to_str(s):
    """Join the ``codes`` entries selected by each index in ``s`` into one string."""
    symbols = [codes[idx] for idx in s]
    return ''.join(symbols)
def evaluate():
    """Compare alignment dissimilarity with embedding distance for one random pair.

    Two row indices are drawn independently with ``random.randint`` (they may
    coincide).

    Returns:
        ``[score, pred]``: ``score`` is ``dissimilarity`` of the two decoded
        sequences, ``pred`` the Euclidean distance between their rows in
        ``anc_seq_reps``.
    """
    last_index = anc_seq_reps.shape[0] - 1
    a = random.randint(0, last_index)
    b = random.randint(0, last_index)
    pred = euclidean(anc_seq_reps[a], anc_seq_reps[b])
    # argmax over the last axis gives one code index per position, which to_str decodes.
    first, second = (to_str(arlington_processed[k].argmax(-1)) for k in (a, b))
    return [dissimilarity((first, second)), pred]
# 10,000 random pairs: column 0 is the alignment-based dissimilarity,
# column 1 the Euclidean distance between the two embeddings.
results = np.asarray([evaluate() for i in tqdm(range(10000))])
f = plt.figure(figsize=(8, 6))
plt.scatter(results[:, 0], results[:, 1], alpha=.2)
# The x values are produced by dissimilarity(), not a raw similarity score,
# so the axis is labelled as a dissimilarity.
plt.xlabel('Pairwise Dissimilarity (Smith-Waterman)')
plt.ylabel('Euclidean Distance Between Encodings')
plt.title('Correlation Plot of Model Encodings')
plt.savefig('Results/it1/anc/eval/comparative_enc_eval_anc.png')

  0%|          | 0/10000 [00:00<?, ?it/s]

In [69]:
# Correlation matrix between column 0 (alignment dissimilarity) and column 1 (embedding distance) of `results`.
np.corrcoef(results[:, 0], results[:, 1])

array([[1.        , 0.89908656],
       [0.89908656, 1.        ]])

In [70]:
# Mean squared difference between the two columns of `results`.
np.mean((results[:, 0] - results[:, 1]) ** 2)

0.033395462379869754

#### Full Dataset Visualization

In [71]:
# Unpacking the transpose into exactly two names requires `anc_seq_reps` to have two columns.
x, y = anc_seq_reps.T
f = plt.figure(figsize=(8, 6))
# Low alpha so that overlapping points remain distinguishable.
plt.scatter(x, y, alpha=.1, marker='o')
plt.title("Encoded Representations of ANC Dataset")
plt.savefig('Results/it1/anc/anc_all.png')

## Dataset Exploration

### Phylum

In [72]:
# Count each distinct value in column 0 of `arlington_tax` and keep those seen more than 50 times.
# NOTE: despite its name, `genuses` holds these column-0 labels (plotted as phyla below).
all_phyla, counts = np.unique(arlington_tax[:, 0], return_counts=True)
genuses = all_phyla[counts > 50]

In [73]:
# Rows whose column-0 label is one of the kept labels in `genuses`.
plottable = np.isin(arlington_tax[:, 0], genuses)
# Membership matrix: entry (r, c) is 1.0 when kept row r carries label genuses[c], else 0.0.
to_plot = (arlington_tax[plottable][:, 0][:, None] == genuses[None, :]).astype(np.float64)
plottable_seqs = anc_seq_reps[plottable]

In [74]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
# One scatter series per column of `to_plot` (one column per kept label).
for i in to_plot.T:
    pop = plottable_seqs[i.astype(bool)]
    # NOTE(review): indices are drawn with replacement, so points can repeat and
    # any group with fewer than 500 members necessarily does. Confirm this is intended.
    samp = rng.integers(0, len(pop), 500)
    x, y = zip(*pop[samp])
    ax.scatter(x, y, alpha=.3, marker='o')
ax.set_title("Encoded Representations of Phyla of Bacteria")
leg = plt.legend(genuses,
                markerscale=1,
                borderpad=1)
# Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and a later
# release removed the old name; use whichever this installation provides.
for lh in (leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles):
    # Make legend markers opaque even though the plotted points are translucent.
    lh.set_alpha(1)
plt.savefig('Results/it1/anc/anc_phylum.png')

### Variable Region

In [75]:
# Count each distinct value of `v_region_labels` and keep those seen more than 50 times.
# NOTE: `genuses` is reused here to hold variable-region labels.
all_regions, counts = np.unique(v_region_labels, return_counts=True)
genuses = all_regions[counts > 50]

In [76]:
# Rows whose variable-region label is one of the kept labels in `genuses`.
plottable = np.isin(v_region_labels, genuses)
# Membership matrix: entry (r, c) is 1.0 when kept row r carries label genuses[c], else 0.0.
to_plot = (v_region_labels[plottable][:, None] == genuses[None, :]).astype(np.float64)
plottable_seqs = anc_seq_reps[plottable]

In [77]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
# One scatter series per column of `to_plot` (one column per kept label).
for i in to_plot.T:
    pop = plottable_seqs[i.astype(bool)]
    # NOTE(review): indices are drawn with replacement, so points can repeat and
    # any group with fewer than 1000 members necessarily does. Confirm this is intended.
    samp = rng.integers(0, len(pop), 1000)
    x, y = zip(*pop[samp])
    ax.scatter(x, y, alpha=.2, marker='o')
ax.set_title("Encoded Representations (Colored by Variable Region)")
leg = plt.legend(genuses,
                markerscale=1,
                borderpad=1)
# Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and a later
# release removed the old name; use whichever this installation provides.
for lh in (leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles):
    # Make legend markers opaque even though the plotted points are translucent.
    lh.set_alpha(1)
plt.savefig('Results/it1/anc/anc_vregion.png')

### V8

In [78]:
# Boolean mask of rows labelled 'V8', applied to the taxonomy, the encodings and `arlington_labels`.
subset = v_region_labels == 'V8'
arlington_tax_sub = arlington_tax[subset]
anc_2d_sub = anc_seq_reps[subset]
# These labels are compared against 'black' / 'grey' / 'unpigmented' in the pigmentation plot below.
anc_lbl_sub = arlington_labels[subset]

In [79]:
# Scatter of the first two encoding columns for the V8 subset.
f = plt.figure(figsize=(8, 6))
plt.scatter(anc_2d_sub[:, 0], anc_2d_sub[:, 1], alpha=.05)
plt.title('V8 Sequences from Arlington National Cemetery')
plt.savefig('Results/it1/anc/anc_v8.png')

#### Phylum

In [80]:
# Count each distinct value in column 0 of the V8 taxonomy subset and keep those seen more than 50 times.
# NOTE: despite its name, `genuses` holds these column-0 labels (plotted as phyla below).
all_phyla, counts = np.unique(arlington_tax_sub[:, 0], return_counts=True)
genuses = all_phyla[counts > 50]

In [81]:
# V8 rows whose column-0 label is one of the kept labels in `genuses`.
plottable = np.isin(arlington_tax_sub[:, 0], genuses)
# Membership matrix: entry (r, c) is 1.0 when kept row r carries label genuses[c], else 0.0.
to_plot = (arlington_tax_sub[plottable][:, 0][:, None] == genuses[None, :]).astype(np.float64)
plottable_seqs = anc_2d_sub[plottable]

In [82]:
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
# One scatter series per column of `to_plot` (one column per kept label).
for i in to_plot.T:
    pop = plottable_seqs[i.astype(bool)]
    # NOTE(review): indices are drawn with replacement, so points can repeat and
    # any group with fewer than 1000 members necessarily does. Confirm this is intended.
    samp = rng.integers(0, len(pop), 1000)
    x, y = zip(*pop[samp])
    ax.scatter(x, y, alpha=.3, marker='o')
ax.set_title("Encoded Representations of Phyla of Bacteria (V8 Only)")
leg = plt.legend(genuses,
                markerscale=1,
                borderpad=1)
# Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and a later
# release removed the old name; use whichever this installation provides.
for lh in (leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles):
    # Make legend markers opaque even though the plotted points are translucent.
    lh.set_alpha(1)
plt.savefig('Results/it1/anc/anc_v8_phylum.png')

#### Pigment

In [83]:
# Count each distinct value of `anc_lbl_sub` and keep those seen more than 50 times.
# NOTE: `all_phyla` / `genuses` are reused names; here they hold the V8 subset's labels, not taxonomy.
all_phyla, counts = np.unique(anc_lbl_sub, return_counts=True)
genuses = all_phyla[counts > 50]

In [84]:
# V8 rows whose label is one of the kept labels in `genuses`.
# NOTE(review): the plotting cell directly below builds its own masks and does not read these three arrays.
plottable = np.isin(anc_lbl_sub, genuses)
# Membership matrix: entry (r, c) is 1.0 when kept row r carries label genuses[c], else 0.0.
to_plot = (anc_lbl_sub[plottable][:, None] == genuses[None, :]).astype(np.float64)
plottable_seqs = anc_2d_sub[plottable]

In [85]:
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111)
# One boolean mask per pigmentation label in the V8 subset.
black, grey, unpigmented = anc_lbl_sub == 'black', anc_lbl_sub == 'grey', anc_lbl_sub == 'unpigmented'
# Series are drawn in the same order as the legend entries below; each gets its own alpha.
plt.scatter(anc_2d_sub[black][:, 0], anc_2d_sub[black][:, 1], alpha=1)
plt.scatter(anc_2d_sub[grey][:, 0], anc_2d_sub[grey][:, 1], alpha=.2)
plt.scatter(anc_2d_sub[unpigmented][:, 0], anc_2d_sub[unpigmented][:, 1], alpha=.1)
ax.set_title("Encoded Representations of V8 Bacteria, by Pigmentation")
leg = plt.legend(['Black', 'Grey', 'Unpigmented'],
                markerscale=1,
                borderpad=1)
# Matplotlib 3.7 renamed Legend.legendHandles to legend_handles and a later
# release removed the old name; use whichever this installation provides.
for lh in (leg.legend_handles if hasattr(leg, 'legend_handles') else leg.legendHandles):
    # Make legend markers opaque even though the plotted points are translucent.
    lh.set_alpha(1)
plt.savefig('Results/it1/anc/anc_v8_pigment.png')