# Performance Analysis - Jax
> Number of effective sequences implemented in Jax
- toc: true
- branch: master
- badges: true
- author: Donatas Repečka
- categories: [performance]

## Introduction

In [the previous post](https://donatasrep.github.io/donatas.repecka/performance/2021/04/27/Performance-comparison.html) I compared various languages and libraries in terms of speed. This notebook contains the code used in that comparison, as well as some details about the choices made to improve the performance of the Jax implementation.

## Setup

```python
# !wget https://github.com/donatasrep/donatas.repecka/blob/master/data/picked_msa.fasta
```

```python
# ! pip install --upgrade pip
# ! pip install numpy
# ! pip install pandas
# ! pip install --upgrade jax jaxlib
# ! pip install --upgrade jax jaxlib==0.1.66+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

## Getting data

```python
import pandas as pd
import numpy as np
```

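```python
def get_data(path):
    # Every FASTA record starts with '>', so '>' is used as the record terminator and '\n'
    # separates the header from the sequence (this assumes single-line sequences).
    # index_col=False tolerates the trailing '\n' at the end of each record.
    fasta_df = pd.read_csv(path, sep="\n", lineterminator=">", index_col=False, names=['id', 'seq'])
    return fasta_df.seq.to_numpy(dtype=str)
```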
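The `read_csv` trick relies on every sequence sitting on a single line after its header. A dependency-free alternative that also handles sequences wrapped over several lines could look like the sketch below; `read_fasta` is a hypothetical helper, not part of the original notebook.

```python
def read_fasta(text):
    """Hypothetical parser: returns a list of (id, sequence) tuples from FASTA text."""
    records = []
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        if line.startswith(">"):
            # A new header closes the previous record, if any.
            if header is not None:
                records.append((header, "".join(chunks)))
            header, chunks = line[1:], []
        else:
            chunks.append(line)  # sequence lines may be wrapped
    if header is not None:
        records.append((header, "".join(chunks)))
    return records
```

Wrapping the sequences in `np.array([seq for _, seq in read_fasta(text)])` gives the same kind of fixed-width string array that `get_data` returns.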

```python
seqs = get_data('../data/picked_msa.fasta')
```

As a reminder, the pseudocode looks like this:

```
meff = 0
for seq1 in seqs:
  weight = 0
  for seq2 in seqs:
    if count_matches(seq1, seq2) / len(seq1) > threshold:
      weight += 1
  meff += 1/weight

meff = meff/(len(seq1)^0.5)
```
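A direct, unoptimized translation of the pseudocode into plain Python can serve as a reference for checking faster versions on a small subset. It assumes that "matches" means the fraction of identical aligned positions and uses the same 0.8 threshold as the Jax code below; `get_nf_reference` is a hypothetical name.

```python
def get_nf_reference(seqs, threshold=0.8):
    """Hypothetical loop-based reference for the number of effective sequences."""
    seq_len = len(seqs[0])
    meff = 0.0
    for seq1 in seqs:
        weight = 0
        for seq2 in seqs:
            # Fraction of aligned positions where the two sequences agree.
            identity = sum(c1 == c2 for c1, c2 in zip(seq1, seq2)) / seq_len
            if identity > threshold:
                weight += 1
        # weight >= 1 because every sequence matches itself.
        meff += 1.0 / weight
    return meff / seq_len ** 0.5
```

Its quadratic Python loop is far too slow for the full alignment, which is the motivation for the vectorized versions.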

```python
import jax.numpy as jnp
from jax import jit as jax_jit
from jax import vmap
```

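```python
@jax_jit
def get_nf_jax_pair(a, b, threshold=0.8):
    # a: (seq_len,), b: (n_seqs, seq_len) -> (n_seqs,) booleans: identity(a, b_i) > threshold
    return jnp.equal(a, b).mean(-1) > threshold

@jax_jit
def get_nf_jax_gpu(seqs):
    n_seqs, seq_len = seqs.shape
    # Map over the rows of the first argument, keep the second fixed -> (n_seqs, n_seqs)
    out = vmap(get_nf_jax_pair, (0, None))(seqs, seqs)
    # Weight of each sequence = number of similar sequences (itself included)
    weights = out.sum(-1)
    return jnp.sum(1.0 / weights) / (seq_len ** 0.5)
```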
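The `vmap(..., (0, None))` call maps the pairwise comparison over the rows of the first argument while passing the second argument whole, so the result is an n×n boolean similarity matrix. The same matrix can be written with explicit NumPy broadcasting; the sketch below does that and applies the 1/weight reduction from the pseudocode. `get_nf_broadcast` is a hypothetical name, and it materializes an (n, n, seq_len) intermediate, so it is only meant as a cross-check on small subsets.

```python
import numpy as np

def get_nf_broadcast(codes, threshold=0.8):
    """Hypothetical NumPy cross-check; codes is an (n_seqs, seq_len) integer matrix."""
    n_seqs, seq_len = codes.shape
    # (n, 1, L) == (1, n, L) -> (n, n, L); the mean over L is the pairwise identity.
    identity = (codes[:, None, :] == codes[None, :, :]).mean(-1)
    similar = identity > threshold   # (n, n) boolean matrix
    weights = similar.sum(-1)        # similar sequences per row, itself included
    return np.sum(1.0 / weights) / seq_len ** 0.5
```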

```python
seqs_ = seqs[:100]
get_nf_jax_gpu(seqs_.view(np.uint32).reshape(seqs_.shape[0], -1))
```
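The `view(np.uint32)` call turns the array of strings into an integer matrix without copying on the NumPy side: NumPy stores fixed-width unicode strings (dtype `<U…`) with 4 bytes per character, so reinterpreting the buffer as `uint32` yields one code point per character, and `reshape` splits it into one row per sequence. Shorter strings would be padded with zeros, but sequences in an alignment all have the same length. A small illustration, with `to_code_points` as a hypothetical helper:

```python
import numpy as np

def to_code_points(seqs):
    """Hypothetical helper: (n,) equal-length strings -> (n, seq_len) uint32 matrix."""
    seqs = np.asarray(seqs, dtype=str)  # fixed-width unicode dtype '<U{seq_len}'
    # Each character occupies 4 bytes, so a uint32 view gives one code point per character.
    return seqs.view(np.uint32).reshape(seqs.shape[0], -1)

codes = to_code_points(["ACD", "AC-"])
print(codes.tolist())  # [[65, 67, 68], [65, 67, 45]]
```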

```python
%%timeit -n 3 -r 3
seqs_ = seqs[:100]
get_nf_jax_gpu(seqs_.view(np.uint32).reshape(seqs_.shape[0], -1)).block_until_ready()
```

```python
import jax
# Place the input on the CPU so the jitted function runs there
seqs_cpu = jax.device_put(seqs_.view(np.uint32).reshape(seqs_.shape[0], -1), jax.devices('cpu')[0])
get_nf_jax_gpu(seqs_cpu)
```

```python
%%timeit -n 3 -r 3
get_nf_jax_gpu(seqs_cpu).block_until_ready()
```

A couple of points:
* At the time of writing, Jax did not support Windows.
* In the comparison it was significantly faster than PyTorch and TensorFlow, and I am not entirely sure why.
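Before attributing the gap to Jax itself, two JAX behaviours are worth ruling out when timing: `jit` traces and compiles a function on its first call for each new input shape and dtype, and JAX dispatches work asynchronously, so a timer can stop before the result is actually computed unless `block_until_ready()` is called. A minimal sketch, assuming JAX is installed; `time_jax` and the toy `pairwise_identity_sum` are hypothetical names, not part of the original notebook.

```python
import time

import jax
import jax.numpy as jnp

def time_jax(fn, *args, repeats=3):
    """Hypothetical helper: returns (first-call seconds, best steady-state seconds)."""
    start = time.perf_counter()
    fn(*args).block_until_ready()  # first call includes tracing and XLA compilation
    first = time.perf_counter() - start
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # wait for the asynchronously dispatched result
        best = min(best, time.perf_counter() - start)
    return first, best

@jax.jit
def pairwise_identity_sum(x):
    # Toy stand-in for the real workload: sum of all pairwise identities.
    return (x[:, None, :] == x[None, :, :]).mean(-1).sum()

x = jnp.arange(12, dtype=jnp.uint32).reshape(3, 4)
first, best = time_jax(pairwise_identity_sum, x)
```

With the notebook's data the same helper could be called as `time_jax(get_nf_jax_gpu, jnp.asarray(seqs_.view(np.uint32).reshape(seqs_.shape[0], -1)))`.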