In [1]:
%load_ext autoreload
%autoreload 2

import os.path as osp
import pyrootutils
from functools import partial

# Locate the repository root (the first ancestor holding .git or pyproject.toml)
# and configure the session around it.
root = pyrootutils.setup_root(
    search_from=osp.abspath(""),
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,  # put the root on sys.path so the `bin.*` / `src.*` imports resolve
    dotenv=True,  # load environment variables from a .env file
    cwd=True,  # chdir to the root so relative paths such as "data/logs/..." work
)


import numpy as np
import hydra
import jax
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
import equinox.nn as nn
import matplotlib.pyplot as plt
from qdax.utils.metrics import default_qd_metrics
from bin.init.config import load_model_weights
from src.task.dnaqd import QDSearchDNA
from src.problem.levelgen import SimpleLevelGeneration
from src.evo.qd import MAPElites, CMAOptEmitter
from src.model.dev import NCA, NCA_DNA
from src.nn.ca import IdentityAndSobelFilter, SliceOutput, MaxPoolAlive
from src.nn.dna import DNAContextEncoder, DNAControl, DNAIndependentSampler
from src.analysis.levelgen import plot_generated_levels
from src.utils import tree_shape
from src.analysis.run_utils import select_and_unstack, generate_outputs

  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


In [2]:
# Architecture hyper-parameters. Keep them in sync with the checkpoint loaded at the
# end of this cell; the DNA shape (8 positions x 4 symbols) is also hard-coded in the
# regression section below (`reshape(-1, 8, 4)`).
STATE_SIZE = 9  # channels per NCA cell
GRID_SIZE = (16, 16)
DNA_LENGTH = 8  # positions per DNA string
DNA_ALPHABET = 4  # symbols per position
DNA_EMBED_DIM = 16  # NOTE(review): assumed to be the width shared by the DNA encoder and control module — confirm
HIDDEN_CHANNELS = 32  # width of the 1x1-conv update network

model = NCA_DNA(
    nca=NCA(
        state_size=STATE_SIZE,
        grid_size=GRID_SIZE,
        # NOTE(review): this is the int 50, not a tuple — confirm a single step count
        # (rather than a (lo, hi) range) is what NCA expects here.
        dev_steps=50,
        update_prob=0.5,
        # NOTE(review): positional order (alphabet, length, embed) is assumed — check
        # DNAContextEncoder's signature.
        context_encoder=DNAContextEncoder(
            DNA_ALPHABET, DNA_LENGTH, DNA_EMBED_DIM, key=jr.PRNGKey(5)
        ),
        control_fn=DNAControl(STATE_SIZE, DNA_EMBED_DIM, key=jr.PRNGKey(4)),
        alive_fn=MaxPoolAlive(alive_bit=3, alive_threshold=0.1),
        message_fn=IdentityAndSobelFilter(),
        update_fn=nn.Sequential(
            layers=[
                # Presumably identity + two Sobel directions -> 3 feature maps per state channel.
                nn.Conv2d(
                    in_channels=STATE_SIZE * 3,
                    out_channels=HIDDEN_CHANNELS,
                    kernel_size=1,
                    key=jr.PRNGKey(1),
                ),
                nn.Lambda(jax.nn.relu),
                nn.Conv2d(
                    in_channels=HIDDEN_CHANNELS,
                    out_channels=STATE_SIZE,
                    kernel_size=1,
                    key=jr.PRNGKey(2),
                ),
            ],
        ),
        # NOTE(review): if `end_idx` is exclusive this keeps a single channel, and an
        # argmax over that axis is always 0 — confirm SliceOutput's convention.
        output_decoder=SliceOutput(
            dim=0,
            start_idx=0,
            end_idx=1,
            squashing_function=partial(jnp.argmax, axis=0),
        ),
        output_dev_steps=True,
    ),
    dna_generator=DNAIndependentSampler(DNA_LENGTH, DNA_ALPHABET, jr.PRNGKey(3)),
)

# Overlay the trained parameters from the chosen evolution run onto the freshly built model.
RUN_DIR = "data/logs/dnaqd/simple_level_gen/nca_dna/evo/2024-01-19_07-42"
CHECKPOINT_FILE = "best_ckpt-iteration_001676"
weights = load_model_weights(model, RUN_DIR, checkpoint_file=CHECKPOINT_FILE)
model = eqx.combine(weights, model)

In [3]:
# NOTE(review): the positional arguments (10, 32, 0.1) are opaque here — pass them by
# keyword once CMAOptEmitter's signature is confirmed.
# NOTE(review): the emitter is told num_centroids=10 while the task below uses
# n_centroids=100 — confirm these are meant to differ.
qd_emitter = CMAOptEmitter(
    10,
    32,
    0.1,
    num_descriptors=2,
    num_centroids=10,
    random_key=jr.PRNGKey(6),
)
qd_metrics_fn = partial(default_qd_metrics, qd_offset=0.0)
qd_algorithm = MAPElites(qd_emitter, qd_metrics_fn)

# Short QD search (2 iterations, population 10) over a 16x16 level-generation problem.
level_problem = SimpleLevelGeneration(16, 16)
task = QDSearchDNA(
    level_problem,
    qd_algorithm=qd_algorithm,
    n_iters=2,
    popsize=10,
    n_centroids=100,
    n_centroid_samples=1000,
)

In [4]:
model_outputs = generate_outputs(model, task, jr.PRNGKey(1))

# First entry: DNA distributions, decoded outputs and NCA states (indexed later as
# [iteration, individual, ...] — presumably; confirm against generate_outputs).
dna_dist, outputs, states = model_outputs[0]

# Third entry: its first element begins with (scores, measures); anything after is unused.
qd_eval = model_outputs[2][0]
scores, measures = qd_eval[0], qd_eval[1]  # type: ignore

  repertoire = MapElitesRepertoire.init(


## Intervening on DNA strings

In [None]:
def intervene_on_dna(model, attn_weights):
    """Mask out the DNA gene that receives the most attention on average.

    The DNA position with the highest attention weight, averaged over development
    steps and grid cells, is found. A copy of ``model`` is then returned whose
    ``nca.control_fn.dna_mask`` has that position switched off for every grid cell.
    ``model`` itself is not modified.

    Args:
        model: The NCA_DNA model to intervene on.
        attn_weights: Attention weights with shape ``(n_dev_iters, dna_seqlen, H, W)``.

    Returns:
        A new model in which only ``nca.control_fn.dna_mask`` is replaced.

    Raises:
        ValueError: if ``attn_weights`` is not 4-dimensional.
    """
    _, dna_seqlen, height, width = attn_weights.shape

    # One score per DNA position. A tuple of axes works for both NumPy and JAX
    # arrays (NumPy reductions reject a list).
    max_gene_idx = int(attn_weights.mean(axis=(0, 2, 3)).argmax())

    # One row per grid cell, one column per DNA position. Start all-True and switch
    # off only the most-attended gene. Starting from zeros would make the
    # assignment below a no-op and leave every gene masked.
    # NOTE(review): assumes dna_mask uses True = "attend" — confirm in DNAControl.
    mask = np.ones((height * width, dna_seqlen), dtype=bool)
    mask[:, max_gene_idx] = False

    # dna_mask may currently be None, so treat None as a leaf that tree_at can replace.
    return eqx.tree_at(
        lambda m: m.nca.control_fn.dna_mask, model, mask, is_leaf=lambda x: x is None
    )

In [None]:
# Pick the final-iteration individual that maximises one behaviour descriptor.
MEASURE_IDX = 0  # column of `measures`: 0 = path length, 1 = symmetry
best_dna = measures[-1].argmax(axis=0)[MEASURE_IDX]

dna = dna_dist[-1, best_dna]
output = outputs[-1, best_dna]
# states[1] holds the attention weights consumed by intervene_on_dna. A distinct name
# keeps the checkpoint `weights` loaded earlier from being overwritten.
attn_weights = states[1][-1, best_dna]

intervened_model = intervene_on_dna(model, attn_weights)
# NOTE(review): the NCA updates stochastically (update_prob=0.5), and `output` was
# presumably produced with a different key. Differences between the two outputs
# therefore mix the intervention with sampling noise. Consider also re-running the
# unmodified `model.nca(dna, key)` with the same key as a baseline.
intervened_output, intervened_states = intervened_model.nca(dna, jr.key(0))

In [None]:
# Shape of the generated outputs (presumably iterations x individuals x level dims).
outputs.shape

In [None]:
# Side-by-side comparison of the selected level before and after the DNA intervention.
# A shared colour scale gives each tile value the same colour in both panels;
# imshow would otherwise rescale each image to its own min/max.
vmin = min(float(output.min()), float(intervened_output.min()))
vmax = max(float(output.max()), float(intervened_output.max()))

fig, (ax_original, ax_intervened) = plt.subplots(ncols=2)
ax_original.imshow(output, vmin=vmin, vmax=vmax)
ax_original.set_title("Original")
ax_intervened.imshow(intervened_output, vmin=vmin, vmax=vmax)
ax_intervened.set_title("Most-attended gene masked")
for panel in (ax_original, ax_intervened):
    panel.set_axis_off()
# plt.show() rather than a trailing `fig`: with the inline backend, returning the
# figure as well would render it twice.
plt.show()

## Finding directions of variation within the DNA

In [5]:
from sklearn.linear_model import LassoCV
from sklearn.cross_decomposition import PLSRegression

In [6]:
# Collapse the leading population axes so each row is one evaluated individual.
# Each DNA becomes 8 integer symbol codes (argmax over the 4-symbol axis).
# NOTE(review): these codes are label-encoded, so the linear models below treat
# symbol 3 as "more" than symbol 0. One-hot features (8 * 4 columns) would avoid
# imposing that arbitrary ordering.
flattened_dnas = dna_dist.reshape(-1, 8, 4).argmax(axis=-1)
flattened_scores = scores.reshape(-1)
flattened_measures = measures.reshape(-1, 2)
path_lengths = flattened_measures[:, 0]
symmetries = flattened_measures[:, 1]

# Columns: score, path length, symmetry.
metrics = jnp.concatenate([flattened_scores[:, None], flattened_measures], axis=-1)

In [7]:
def make_lasso_cv(alpha_min=0.01, alpha_max=1.0, n_alphas=100, max_iter=10_000):
    """Build an unfitted cross-validated Lasso over a log-spaced penalty grid.

    Args:
        alpha_min: Smallest L1 penalty tried.
        alpha_max: Largest L1 penalty tried.
        n_alphas: Number of log-spaced penalties between the two bounds.
        max_iter: Maximum iterations passed to ``LassoCV``.

    Returns:
        A new ``LassoCV`` estimator.
    """
    return LassoCV(alphas=np.geomspace(alpha_min, alpha_max, n_alphas), max_iter=max_iter)


# One estimator per target, so each selects its own cross-validated penalty.
lasso_pl = make_lasso_cv()
lasso_symmetry = make_lasso_cv()
lasso_scores = make_lasso_cv()

In [8]:
# Regress each target on the same per-position DNA features.
for estimator, target in (
    (lasso_pl, path_lengths),
    (lasso_symmetry, symmetries),
    (lasso_scores, flattened_scores),
):
    estimator.fit(flattened_dnas, target)
lasso_scores

In [10]:
# Rows: one fitted Lasso per target. Columns: DNA positions.
coeffs = np.stack([lasso_pl.coef_, lasso_symmetry.coef_, lasso_scores.coef_])
coeff_row_labels = ["path length", "symmetry", "score"]

# Lasso coefficients are signed, so use a diverging colormap centred on zero.
# A sequential map such as 'hot' hides the direction of each effect. Fall back to
# 1.0 if Lasso zeroed every coefficient, to avoid a degenerate colour range.
coeff_limit = float(np.abs(coeffs).max()) or 1.0

coef_fig, coef_ax = plt.subplots()
coef_image = coef_ax.imshow(coeffs, cmap="RdBu_r", vmin=-coeff_limit, vmax=coeff_limit)
coef_ax.set_yticks(range(len(coeff_row_labels)))
coef_ax.set_yticklabels(coeff_row_labels)
coef_ax.set_xticks(range(coeffs.shape[1]))
coef_ax.set_xlabel("DNA position")
coef_ax.set_title("Lasso coefficients per DNA position")
coef_fig.colorbar(coef_image, ax=coef_ax)
plt.show()

In [11]:
# Per-position Lasso coefficients for the symmetry descriptor.
lasso_symmetry.coef_