# Learned Quality-Diversity using checkpoint parameters

## Imports

In [1]:
import os
import pickle

import jax
import jax.numpy as jnp
import wandb
from omegaconf import OmegaConf

from learned_qd.evo.populations.learned_population import LearnedPopulation

In [2]:
# Fixed PRNG seed so every population below is initialised reproducibly.
key = jax.random.key(seed=0)

## Load model

In [3]:
def get_config(run_path):
	"""Load the configuration of a training run.

	Args:
		run_path: Either an existing local run directory (expected to contain
			``.hydra/config.yaml``) or a WandB run path passed to
			``wandb.Api().run``.

	Returns:
		The run configuration, as produced by ``OmegaConf.load`` (local) or
		``OmegaConf.create`` from the WandB run's stored config (remote).
	"""
	if not os.path.isdir(run_path):
		# Not a local directory: fetch the config dict stored on the WandB run.
		remote_run = wandb.Api().run(run_path)
		return OmegaConf.create(remote_run.config)

	# Local run directory: read the config saved under `.hydra/`.
	config_file = os.path.join(run_path, ".hydra", "config.yaml")
	return OmegaConf.load(config_file)


def get_model_path(run_path):
	"""Return a path to the model checkpoint for a training run.

	Args:
		run_path: Either an existing local run directory, which is returned
			unchanged, or a WandB run path passed to ``wandb.Api().run``.

	Returns:
		``run_path`` itself for a local directory; otherwise the path returned
		by ``download()`` on the first logged artifact of type ``"model"``.

	Raises:
		ValueError: If the WandB run has no logged artifact of type ``"model"``.
	"""
	if os.path.isdir(run_path):
		# Local run directory: used as the model path as-is.
		return run_path

	# WandB run: stop at the first "model" artifact instead of materialising
	# the whole logged-artifact list just to index element 0.
	run = wandb.Api().run(run_path)
	model = next(
		(artifact for artifact in run.logged_artifacts() if artifact.type == "model"),
		None,
	)
	if model is None:
		# Previously surfaced as a bare IndexError from `[...][0]`.
		raise ValueError(f"WandB run {run_path!r} has no logged artifact of type 'model'.")
	return model.download()


def get_config_and_model_path(run_path):
	"""Return ``(config, model_path)`` for a local run directory or WandB run path.

	The config is fetched first, then the model path (which may trigger an
	artifact download for WandB runs).
	"""
	return get_config(run_path), get_model_path(run_path)

## Learned Quality-Diversity trained for Fitness

### Config

In [4]:
# Batch of genotypes committed at once, genotype length, descriptor length.
batch_size, genotype_size, descriptor_size = 1024, 10, 2

### Load fitness competition params

In [5]:
# Path is relative to the working directory, so run the notebook from the
# repository root. pickle can execute code on load: only use trusted files.
params_path = os.path.join("learned_qd", "evo", "populations", "params", "fitness.pickle")
with open(params_path, "rb") as params_file:
	params = pickle.load(params_file)

### Instantiate population

In [6]:
# Learned-fitness network hyperparameters; presumably these must match the
# architecture the loaded `params` were trained with — TODO confirm.
population = LearnedPopulation.init(
	genotype=jnp.zeros((genotype_size,)),
	key=key,
	max_size=1024,
	descriptor_size=descriptor_size,
	learned_fitness=dict(
		num_layers=4,
		num_heads=4,
		num_features=16,
		num_ffn_features=16,
	),
).replace(params=params)  # swap in the checkpoint parameters

### Commit genotypes to the population

In [7]:
# Placeholder batch: all-zero genotypes, fitnesses and descriptors.
genotypes, fitness, descriptor = (
	jnp.zeros((batch_size, genotype_size)),
	jnp.zeros((batch_size,)),
	jnp.zeros((batch_size, descriptor_size)),
)

# Commit the placeholder batch to the population.
population = population.commit(genotypes, fitness, descriptor)

## Learned Quality-Diversity trained for Novelty

### Config

In [8]:
# Batch of genotypes committed at once, genotype length, descriptor length.
batch_size, genotype_size, descriptor_size = 1024, 10, 2

### Load novelty competition params

In [9]:
# Path is relative to the working directory, so run the notebook from the
# repository root. pickle can execute code on load: only use trusted files.
params_path = os.path.join("learned_qd", "evo", "populations", "params", "novelty.pickle")
with open(params_path, "rb") as params_file:
	params = pickle.load(params_file)

### Instantiate population

In [10]:
# Learned-fitness network hyperparameters (2 layers for the novelty model);
# presumably these must match the architecture of the loaded `params` — TODO confirm.
population = LearnedPopulation.init(
	genotype=jnp.zeros((genotype_size,)),
	key=key,
	max_size=1024,
	descriptor_size=descriptor_size,
	learned_fitness=dict(
		num_layers=2,
		num_heads=4,
		num_features=16,
		num_ffn_features=16,
	),
).replace(params=params)  # swap in the checkpoint parameters

### Commit genotypes to the population

In [11]:
# Placeholder batch: all-zero genotypes, fitnesses and descriptors.
genotypes, fitness, descriptor = (
	jnp.zeros((batch_size, genotype_size)),
	jnp.zeros((batch_size,)),
	jnp.zeros((batch_size, descriptor_size)),
)

# Commit the placeholder batch to the population.
population = population.commit(genotypes, fitness, descriptor)

## Learned Quality-Diversity trained for Quality-Diversity

### Config

In [13]:
# Batch of genotypes committed at once, genotype length, descriptor length.
batch_size, genotype_size, descriptor_size = 1024, 10, 2

### Load quality-diversity competition params

In [17]:
# Path is relative to the working directory, so run the notebook from the
# repository root. pickle can execute code on load: only use trusted files.
params_path = os.path.join("learned_qd", "evo", "populations", "params", "qd.pickle")
with open(params_path, "rb") as params_file:
	params = pickle.load(params_file)

### Instantiate population

In [18]:
# Learned-fitness network hyperparameters; presumably these must match the
# architecture the loaded `params` were trained with — TODO confirm.
population = LearnedPopulation.init(
	genotype=jnp.zeros((genotype_size,)),
	key=key,
	max_size=1024,
	descriptor_size=descriptor_size,
	learned_fitness=dict(
		num_layers=4,
		num_heads=4,
		num_features=16,
		num_ffn_features=16,
	),
).replace(params=params)  # swap in the checkpoint parameters

### Commit genotypes to the population

In [19]:
# Placeholder batch: all-zero genotypes, fitnesses and descriptors.
genotypes, fitness, descriptor = (
	jnp.zeros((batch_size, genotype_size)),
	jnp.zeros((batch_size,)),
	jnp.zeros((batch_size, descriptor_size)),
)

# Commit the placeholder batch to the population.
population = population.commit(genotypes, fitness, descriptor)