# This is a code reading comprehension practice.
## Below is a block of code written in JAX/Haiku.<br>
This code describes a [Transformer model](https://arxiv.org/abs/1706.03762). The model is intended to be trained to predict the lipophilicity (logD) of a molecule directly from its SMILES string. However, there are errors in this code.<br>


## We would like you to answer a few questions about the code.
Please answer each question in fewer than 5 sentences. <br>

## We would also like you to try to give a corrected version of the model

Note 1. There is no dataset and training loop. **You don't need to train the model.** <br>
Note 2. The JAX/Haiku code reads very similarly to PyTorch or TensorFlow. <br>
Don't worry about the exact syntax when writing the corrected version of the code.

## Task and dataset explanation:
We have a simple artificial dataset where each data point is a pair of (SMILES, logD). In this artificial dataset you can assume all SMILES strings have 10 characters and that logD is a real number which is the regression target.

You can find some example data in the code block below.
Note that we assume the characters in SMILES are tokenized (str->List[int]) already.

[**SMILES**](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system) <br>
The simplified molecular-input line-entry system (SMILES) is a specification in the form of a line notation for describing the structure of chemical species using short ASCII strings.

For example, ethanol ($C_2H_6O$, 2D drawing shown below) can be written as **CCO** in SMILES notation.

In [7]:
#@title Example minibatch data
import numpy as np
import collections

print('Two example data points here:')
print('-'* 50)
print(f'{"c1ccccc1NO":<15} {"logD: 1.48":<25}')
print(f'{"Oc1ccccc1N":<15} {"logD: 0.97":<25}')

# Let's assume the dataset has been tokenized at the character level
# using a dictionary.
print('')
print('The tokenized data shown below:')
print('the minibatch of size = 2:')
print('-'* 50)
# '1' -> 1, 'c' -> 2, 'N' -> 3, 'O' -> 4
tokenized_text = np.array([[2, 1, 2, 2, 2, 2, 2, 1, 3, 4],
                           [4, 2, 1, 2, 2, 2, 2, 2, 1, 3]])
logD = np.array([1.48, 0.97])

# minibatch is the input to the lipo_net below.
minibatch = {'inputs': tokenized_text, 'logD': logD}

for i in range(minibatch['inputs'].shape[0]):
  print('inputs: ', minibatch['inputs'][i], 'logD: ', minibatch['logD'][i])

Two example data points here:
--------------------------------------------------
c1ccccc1NO      logD: 1.48               
Oc1ccccc1N      logD: 0.97               

The tokenized data shown below:
the minibatch of size = 2:
--------------------------------------------------
inputs:  [2 1 2 2 2 2 2 1 3 4] logD:  1.48
inputs:  [4 2 1 2 2 2 2 2 1 3] logD:  0.97
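
The token ids printed above can be reproduced with a simple character-level lookup. The following is a minimal sketch, not part of the original exercise: the `VOCAB` dictionary is copied from the comment in the example cell, and the `tokenize` helper is a hypothetical name introduced only for illustration. Real SMILES tokenizers usually also need to handle multi-character tokens such as `Cl` or `Br`, which this sketch ignores.

```python
# Character-to-id mapping taken from the comment in the example cell.
VOCAB = {'1': 1, 'c': 2, 'N': 3, 'O': 4}


def tokenize(smiles):
    """Map each character of a SMILES string to its integer id (str -> List[int])."""
    return [VOCAB[ch] for ch in smiles]


# The two example molecules from the minibatch.
token_ids = [tokenize(s) for s in ("c1ccccc1NO", "Oc1ccccc1N")]
for ids in token_ids:
    print(ids)
```

A character missing from `VOCAB` raises a `KeyError` here; a production tokenizer would more likely map it to a dedicated unknown-token id.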


# The Lipophilicity Network

In [2]:
#@title Import relevant packages.
# !pip install dm-haiku
import jax
import haiku as hk
import jax.numpy as jnp
import numpy as np

## lipo_net with Transformer block. (with bugs)

In [None]:
# Constants def.
_VOCAB_SIZE = 20
_EMBED_SIZE = 64
_NUM_ATTENTION_HEAD = 4
_NUM_TRANSFORMER_LAYERS = 1
_DROPOUT_RATE = 0.1

In [None]:
def lipo_net(batch):
  """The main network that predicts logD from SMILES."""
  x = batch['inputs']  # (B, T), B:batch_size, T: sequence_length
  seq_len = x.shape[-1]

  # Embed the tokens.
  embedder = hk.Embed(_VOCAB_SIZE, _EMBED_SIZE)
  res = []
  for i in range(x.shape[0]):
    res.append(embedder(x[i]))
  token_embed = jnp.stack(res)

  trf_blocks = Transformer(
      num_heads=_NUM_ATTENTION_HEAD,
      num_layers=_NUM_TRANSFORMER_LAYERS,
      dropout_rate=_DROPOUT_RATE)

  out_tokens = trf_blocks(token_embed, is_training=True)

  y = jnp.sum(out_tokens, axis=1)
  return hk.Linear(1)(y)  # the predicted logD

class Transformer(hk.Module):
  """A transformer stack."""

  def __init__(self, num_heads, num_layers, dropout_rate, name=None):
    super().__init__(name=name)
    self._num_layers = num_layers
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate

  def __call__(self, h, is_training):
    """Connects the transformer.
    Args:
      h: Inputs, [B, T, D].
      is_training: Whether we're training or not.
    Returns:
      Array of shape [B, T, D].
    """
    print(f"input embedding")
    print("-"*50)
    h_pooled = h.sum(axis=1)
    print(h_pooled[0] - h_pooled[1])

    init_scale = 2. / self._num_layers
    dropout_rate = self._dropout_rate if is_training else 0.0
    batch_size, seq_len, emb_size = h.shape

    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))
    causal_mask = jnp.tile(causal_mask, (batch_size, 1, 1, 1))
    # causal_mask is of shape (B, 1, T, T)
    # the dummy 2nd dim of size 1 is for broadcasting over multi attention heads.

    for i in range(self._num_layers):
      # Pass through the multi-head attention.
      # haiku.MultiHeadAttention
      # https://github.com/deepmind/dm-haiku/blob/main/haiku/_src/attention.py
      attn_module = hk.MultiHeadAttention(num_heads=self._num_heads,
                                          key_size=emb_size // self._num_heads,
                                          model_size=emb_size,
                                          w_init_scale=init_scale)
      h_attn = attn_module(query=h, key=h, value=h) #, mask=causal_mask)
      
      print(f"attention output layer {i}")
      print("-"*50)
      h_attn_pooled = h_attn.sum(axis=1)
      print(h_attn_pooled[0] - h_attn_pooled[1])
      
      
      h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn)
      h += h_attn
      h = layer_norm(h, f"atten_block_ln_{i}")

      # Pass through the MLP block for each token.
      mlp_block = hk.Sequential([
          hk.Linear(emb_size, w_init=hk.initializers.VarianceScaling(init_scale)),
          jax.nn.gelu,
          hk.Linear(emb_size, w_init=hk.initializers.VarianceScaling(init_scale))
      ])
      h_dense = mlp_block(h)
      h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense)
      h += h_dense
      h = layer_norm(h, f"mlp_block_ln_{i}")

    # Apply LayerNorm at the end of the whole Transformer stack.
    h = layer_norm(h, 'ln_final')
    return h


def layer_norm(x, name):
  """Apply a unique LayerNorm to x with default settings."""
  return hk.LayerNorm(
      axis=-1, create_scale=True, create_offset=True, name=name)(x)


In [82]:
# Demo of a single forward pass of the lipo_net.
net = hk.transform(lipo_net)
params = net.init(jax.random.PRNGKey(42), minibatch)
model_out = net.apply(params, jax.random.PRNGKey(42), minibatch)
print('The predicted logD output from the lipo net:')
print('-'*50)
print(model_out)

input embedding
--------------------------------------------------
[ 2.9802322e-08 -9.5367432e-07  0.0000000e+00 -2.3841858e-07
  4.7683716e-07  0.0000000e+00 -4.7683716e-07  0.0000000e+00
  0.0000000e+00  0.0000000e+00 -5.9604645e-08  0.0000000e+00
  5.9604645e-08  0.0000000e+00  0.0000000e+00  0.0000000e+00
 -9.5367432e-07  2.3841858e-07  0.0000000e+00  0.0000000e+00
  0.0000000e+00  0.0000000e+00  0.0000000e+00 -4.7683716e-07
  0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
  4.7683716e-07  0.0000000e+00  0.0000000e+00  8.9406967e-08
 -2.3841858e-07  0.0000000e+00  0.0000000e+00  0.0000000e+00
  4.7683716e-07 -4.7683716e-07 -1.1920929e-07  0.0000000e+00
  9.5367432e-07 -1.1920929e-07  0.0000000e+00  9.5367432e-07
  0.0000000e+00  1.4901161e-08  0.0000000e+00  4.7683716e-07
  0.0000000e+00  0.0000000e+00  0.0000000e+00  0.0000000e+00
 -4.7683716e-07  0.0000000e+00  9.5367432e-07 -4.7683716e-07
  0.0000000e+00 -4.7683716e-07  1.1920929e-07  4.7683716e-07
 -4.7683716e-07  0
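
The values printed above are the element-wise difference between the sum-pooled input embeddings of the two molecules, and they are all at the level of float32 rounding error. The sketch below reproduces that observation with NumPy only. The random `table` is a hypothetical stand-in for the parameters of `hk.Embed`, so the exact numbers will differ from the notebook output; only the order of magnitude is meant to match.

```python
import numpy as np

tokens = np.array([[2, 1, 2, 2, 2, 2, 2, 1, 3, 4],
                   [4, 2, 1, 2, 2, 2, 2, 2, 1, 3]])

# Hypothetical embedding table with the notebook's sizes (vocab 20, embed 64).
rng = np.random.default_rng(0)
table = rng.normal(size=(20, 64)).astype(np.float32)

embedded = table[tokens]        # (B, T, D): one vector per token
pooled = embedded.sum(axis=1)   # (B, D): sum over the sequence axis

# Both rows contain the same tokens, only in a different order.
same_counts = np.array_equal(np.bincount(tokens[0], minlength=20),
                             np.bincount(tokens[1], minlength=20))
max_diff = float(np.abs(pooled[0] - pooled[1]).max())

print('same token counts:', same_counts)
print('largest pooled difference:', max_diff)
```

Because a sum over the sequence axis does not depend on token order, any residual difference comes only from the order in which floating-point additions are performed.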

# Questions about the code.
Let's assume we have an L2 loss function and a training loop to minimize the loss. <br>
[Loss function and the training loop not shown for simplicity.] <br>
Now we run the training loop and see that the loss has been minimized. <br>

We found the model performance is not great in general though. <br>
Could you help debug by answering the questions below? <br>
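
For concreteness, one common reading of "L2 loss" is the mean squared error between predicted and target logD. The following is a minimal sketch under that assumption, written with NumPy so that it runs without JAX; the `l2_loss` name and the example predictions are hypothetical and not taken from the original training code. It also accounts for the fact that `hk.Linear(1)` returns an array of shape `(B, 1)`, whereas `minibatch['logD']` has shape `(B,)`.

```python
import numpy as np


def l2_loss(predictions, targets):
    """Mean squared error between predicted and target logD values."""
    predictions = np.asarray(predictions, dtype=np.float64)
    targets = np.asarray(targets, dtype=np.float64)
    # The network output has shape (B, 1) and the targets have shape (B,).
    # Reshape first so the subtraction does not broadcast to shape (B, B).
    predictions = predictions.reshape(targets.shape)
    return float(np.mean((predictions - targets) ** 2))


example_preds = np.array([[1.0], [1.0]])    # made-up model outputs, shape (B, 1)
example_targets = np.array([1.48, 0.97])    # logD values from the example minibatch
loss = l2_loss(example_preds, example_targets)
print(loss)
```

The same function body should carry over to `jax.numpy` with only the import changed, although that variant is not shown here.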

## Question 1. When evaluated, the model predicts the exact same logD for
  *  Oc1ccccc1N
  *  c1ccccc1NO  <br>
These molecules are shown below. Can you think of what ingredients may be missing from the model? <br>

### Answer:

## Question 2. The model is not stable during training (output NaN)
  * In your experience, what usually happens to a model that causes it to output NaN?
  * What may be missing from our Transformer implementation that makes it less stable than a standard implementation?

### Answer:

## Question 3. The model is not performing that well on the evaluation dataset.
 * From a hyperparameter sweep, strangely, the eval performance seems to get worse with a higher dropout rate.
 * Could you guess what may be wrong in the training loop/model setup?

### Answer:

## Question 4. We find the model to be slightly worse than the one from our colleagues.
  * Which part of the implemented model may prevent it from using its full capacity?
  * Hint: Think of language models (e.g. GPT-x) compared to BERT. What is the difference?

### Answer:

## Question 5. Can you find parts in the model that could be improved for performance?
  * Any parts that can be vectorized/batch-applied?
  * No need to be familiar with JAX/Haiku. You can give pseudo-code in PyTorch/TensorFlow.

### Answer:

## Question 6. Do you think this model will run for longer SMILES strings (e.g. 20 characters)?
  * Yes or No? Can you explain why?
  * Is there a specific line you would change to help the model generalize better to longer sequences?

### Answer:

## Question 7. Combining all the questions above, could you improve the model in the block below?
  * Please try to keep the code backbone, and comment the changes you make.
  * Feel free to use the syntax from PyTorch/Tensorflow with comments.
  * We are not expecting the code to be syntactically correct or runnable.