In [None]:
%load_ext autoreload
%autoreload 2

# Simple representation-space tests with a fully connected network (FCN)


## Imports

In [None]:
from typing import Optional, List, Callable, Dict, Any, Tuple, Union
from dataclasses import dataclass
import os

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import optax  # https://github.com/deepmind/optax
import torch  # https://pytorch.org
from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping
import ast

import equinox as eqx
import wandb

import seaborn as sns
import matplotlib.pyplot as plt

jax.config.update('jax_platform_name', 'gpu')

from synbio_morpher.utils.misc.numerical import make_symmetrical_matrix_from_sequence
from synbio_morpher.utils.misc.string_handling import convert_liststr_to_list
from synbio_morpher.utils.misc.type_handling import flatten_listlike
from synbio_morpher.utils.results.analytics.naming import get_true_names_analytics, get_true_interaction_cols

jax.devices()

## Load data

In [None]:
fn = '../data/processed/ensemble_mutation_effect_analysis/2023_07_17_105328/tabulated_mutation_info.csv'
# 'Unnamed: 0' is pandas' name for a header-less first column (presumably an index
# written by an earlier `to_csv`); it carries no data, so discard it.
data = pd.read_csv(fn).drop(columns=['Unnamed: 0'])

# These binding-site columns come out of the CSV as strings holding Python literals;
# ast.literal_eval turns each cell back into the corresponding Python object.
literal_cols = (
    get_true_interaction_cols(data, interaction_attr='binding_sites_idxs', remove_symmetrical=True)
    + get_true_interaction_cols(data, interaction_attr='binding_site_group_range', remove_symmetrical=True)
)
for c in literal_cols:
    data[c] = data[c].map(ast.literal_eval)

## Model: network of fully connected layers

In [None]:
# Equinox MLP classifier; cf. https://docs.kidger.site/equinox/api/nn/linear/

class FCN(eqx.Module):
    """Fully connected classifier mapping one input vector to class log-probabilities.

    Implemented as an Equinox module rather than ``hk.Module``. A Haiku module can only
    be constructed inside ``hk.transform`` and does not take a PRNG key. This notebook,
    however, builds the model eagerly from a key (``FCN(subkey, ...)``) and calls it
    directly under ``jax.vmap``.
    """

    # Linear layers interleaved with activation callables, applied in order.
    layers: list

    def __init__(self, key, in_expected: int, layer_sizes: List[int], n_head: int):
        """Build the network.

        Args:
            key: JAX PRNG key used to initialise the Linear layers.
            in_expected: length of a single (unbatched) input vector.
            layer_sizes: widths of the hidden Linear layers, in order.
            n_head: number of output classes (width of the final Linear layer).
        """
        self.layers = self.create_layers(in_expected, layer_sizes, n_head, key)

    @staticmethod
    def create_layers(in_expected: int, layer_sizes: List[int], n_head: int, key) -> list:
        """Return the layer list: Linear layers of widths in_expected -> *layer_sizes -> n_head.

        A ReLU is inserted between consecutive Linear layers. Before every even-indexed
        Linear layer (the 3rd, 5th, ...), a sigmoid follows that ReLU. The output goes
        through ``jax.nn.log_softmax``, so the network returns log-probabilities.
        """
        sizes = [in_expected, *layer_sizes, n_head]
        # len(sizes) - 1 subkeys, one per Linear layer; the first split key is unused.
        _, *subkeys = jax.random.split(key, len(sizes))
        layers = []
        for i, (size_in, size_out, subkey) in enumerate(zip(sizes[:-1], sizes[1:], subkeys)):
            if i > 0:
                layers.append(jax.nn.relu)
                if i % 2 == 0:
                    layers.append(jax.nn.sigmoid)
            # Optional regularisation before the head, e.g. eqx.nn.Dropout(p=0.4);
            # __call__ already forwards `inference`/`key` to Dropout layers.
            layers.append(eqx.nn.Linear(size_in, size_out, key=subkey))
        layers.append(jax.nn.log_softmax)
        return layers

    def __call__(self, x: Float[Array, " num_interactions"], inference: bool = False, seed: int = 0) -> Float[Array, " n_head"]:
        """Apply the network to a single, unbatched input vector.

        Args:
            x: one input feature vector (wrap this call in ``jax.vmap`` for a batch).
            inference: forwarded as the ``inference`` flag of any ``eqx.nn.Dropout`` layer.
            seed: seeds the PRNG key passed to ``eqx.nn.Dropout`` layers.
        Returns:
            Log-probabilities over the ``n_head`` classes.
        """
        for layer in self.layers:
            if isinstance(layer, eqx.nn.Dropout):
                x = layer(x, inference=inference, key=jax.random.PRNGKey(seed))
            else:
                x = layer(x)
        return x

## Hyperparameters

In [None]:
# --- Data / optimisation ---
BATCH_SIZE = 64
N_BATCHES = 8000
TOTAL_DS = BATCH_SIZE * N_BATCHES       # upper bound on rows taken from the filtered data
TRAIN_SPLIT = int(0.8 * N_BATCHES)      # 80% of the batches ...
TEST_SPLIT = N_BATCHES - TRAIN_SPLIT    # ... and the remaining 20%
LEARNING_RATE = 1e-5
STEPS = 5000
PRINT_EVERY = 200
SEED = 0
INPUT_SPECIES = 'RNA_1'                 # value of `sample_name` used to filter rows

# --- CNN architecture (not referenced by the FCN defined above) ---
N_CHANNELS = 1
OUT_CHANNELS = 3
KERNEL_SIZE = 1
MAX_POOL_KERNEL_SIZE = 1

# --- FCN architecture: hidden-layer widths ---
LAYER_SIZES = [10, 20, 50, 50, 50, 50]

# Number of distinct sample names in the dataset.
n_samples = len(data['sample_name'].unique())

# Derive two keys from the seed: `subkey` initialises the model, `key` is used for shuffling.
key, subkey = jax.random.split(jax.random.PRNGKey(SEED), 2)

## Define input

In [None]:
def convert_to_scientific_exponent(x):
    """Return the base-10 exponent of ``x`` as printed in 1-significant-digit scientific notation.

    The value is rounded before the exponent is read, so 9.6 -> '1e+01' -> 1,
    and 0 -> '0e+00' -> 0.
    """
    sci_repr = f'{x:.0e}'
    exponent_text = sci_repr.split('e')[1]
    return int(exponent_text)

# otypes avoids np.vectorize's failure on size-0 input (e.g. when the filter matches nothing).
vectorized_convert_to_scientific_exponent = np.vectorize(convert_to_scientific_exponent, otypes=[int])
filt = data['sample_name'] == INPUT_SPECIES
data_species = data.loc[filt]

# Features: dissociation-rate interaction columns, each reduced to its base-10 exponent.
# At most TOTAL_DS matching rows are used.
x = data_species[get_true_interaction_cols(data, 'binding_rates_dissociation', remove_symmetrical=True)].iloc[:TOTAL_DS].values
x = vectorized_convert_to_scientific_exponent(x)

# Make binding into 2D Interactions
# x = np.expand_dims(np.array([make_symmetrical_matrix_from_sequence(xx, n_samples) for xx in x]), axis=1)

# Targets: base-10 exponent of the sensitivity w.r.t. species-6. Exponents can be negative
# or non-contiguous, so re-encode them as class indices 0..N_HEAD-1 for cross_entropy.
# `y_class_exponents[k]` is the exponent represented by class k.
y_exponents = vectorized_convert_to_scientific_exponent(
    data_species['sensitivity_wrt_species-6'].iloc[:TOTAL_DS].to_numpy())
y_class_exponents, y_idx = np.unique(y_exponents, return_inverse=True)
y = y_idx.reshape(-1, 1)  # (N, 1): the shape take_along_axis(..., axis=1) expects

N_HEAD = len(y_class_exponents)

# Shuffle rows with ONE permutation shared by x and y, so every feature row stays paired
# with its label. Do not use `independent=True`: that shuffles each column separately.
perm = jax.random.permutation(key, x.shape[0])
x = jnp.asarray(x)[perm]
y = jnp.asarray(y)[perm]

if x.shape[0] < TOTAL_DS:
    print(f'WARNING: The filtered data is not as large as the requested total dataset size: {x.shape[0]} vs. requested {TOTAL_DS}')

### Define model

In [None]:
model = FCN(
    key=subkey,
    in_expected=x.shape[1],   # one input per interaction feature column
    layer_sizes=LAYER_SIZES,
    n_head=N_HEAD,            # one output per target class
)


In [None]:
def loss(
    model: FCN, x: Float[Array, "batch num_interactions"], y: Int[Array, "batch 1"]
) -> Float[Array, ""]:
    """Mean cross-entropy of ``model`` over a batch.

    Args:
        model: callable mapping one unbatched input vector to class log-probabilities.
        x: batch of input feature vectors; ``model`` is vmapped over the leading axis.
        y: integer class indices for each example, passed straight to ``cross_entropy``.
    Returns:
        Scalar loss.
    """
    pred_y = jax.vmap(model)(x)
    return cross_entropy(y, pred_y)


def cross_entropy(
    y: Int[Array, "batch 1"], pred_y: Float[Array, "batch n_head"]
) -> Float[Array, ""]:
    """Mean negative log-likelihood of the true classes.

    Args:
        y: integer class indices in ``[0, n_head)``, shape ``(batch, 1)``; a 1-D
            ``(batch,)`` array is also accepted and expanded to ``(batch, 1)``.
        pred_y: log-softmax'd predictions, shape ``(batch, n_head)``.
    Returns:
        Scalar ``-mean_i pred_y[i, y[i]]``.
    """
    y = jnp.asarray(y)
    # take_along_axis needs y to have the same rank as pred_y.
    if y.ndim == pred_y.ndim - 1:
        y = jnp.expand_dims(y, -1)
    log_p_true = jnp.take_along_axis(pred_y, y, axis=1)
    return -jnp.mean(log_p_true)


# Smoke test on the first 10 examples: the loss should be a scalar and the
# predictions one row of class log-probabilities per example.
example_x, example_y = x[:10], y[:10]
loss_value = loss(model, example_x, example_y)
print(loss_value.shape)  # scalar loss -> ()
output = jax.vmap(model)(example_x)
print(output.shape)  # batch of predictions -> (batch, N_HEAD)