<a href="https://colab.research.google.com/github/move37-labs/nucleobench/blob/documentation/nucleobench/colab/custom_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# `nucleopt` is a minimal package that has nucleobench optimizers.
# For the full library, including tasks, install `nucleobench`.
!pip install nucleobench

1. Non-gradient designers, like AdaBeam, can use tasks described simply as a function. This also applies to [AdaLead](https://arxiv.org/abs/2010.02141), Directed Evolution, Ordered Beam, and Unordered Beam.

1. Torch designers, like [Ledidi](https://www.biorxiv.org/content/10.1101/2020.05.21.109686v1), require a task class with gradients defined. This also applies to [FastSeqProp](https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-021-04437-5).

1. TISM designers, like Gradient Evo, require a task class with [Taylor In-Silico Mutagenesis](https://www.sciencedirect.com/science/article/pii/S2589004224020327) defined.

### Style 1: Task as a simple function.

In [2]:
# Step 1: Define the task.
import re
def count_regular_expression(seqs: list[str]) -> list[float]:
  """Scores each sequence by its number of occurrences of 'ACT'.

  Args:
    seqs: Sequences to score.

  Returns:
    One score per input sequence, in input order. Each score is the negated
    number of (possibly overlapping) 'ACT' matches, as a float.
  """
  # Use lookaheads so we allow overlapping regions.
  # Designers minimize the function, so negate the count: more matches gives
  # a lower (better) score.
  return [-1 * float(len(re.findall('(?=(ACT))', s))) for s in seqs]

# Step 2: Define the designer.
from nucleobench import optimizations
# Look up the designer by name.
opt_obj = optimizations.get_optimization('adabeam')
# Every designer has some baseline, default arguments to initialize.
opt_init_args = opt_obj.debug_init_args()
# Override the defaults with our task function and the starting sequence.
opt_init_args['model_fn'] = count_regular_expression
opt_init_args['start_sequence'] = 'C' * 10
designer = opt_obj(**opt_init_args)

# Step 3: Run the designer and show the results.
designer.run(n_steps=100)
ret = designer.get_samples(1)
# Re-score the returned sequence(s) with the task function.
ret_score = count_regular_expression(ret)
print(f'Final score: {ret_score[0]}')
print(f'Final sequence: {ret[0]}')

Step 99 current scores: [np.float64(3.0), np.float64(3.0), np.float64(3.0), np.float64(3.0), np.float64(3.0), np.float64(2.0), np.float64(2.0), np.float64(2.0), np.float64(2.0), np.float64(2.0)]
Final score: -3.0
Final sequence: ACTTACTACT


Simple functions do not work with gradient designers: constructing the designer fails.

In [3]:
# Gradient designers need more than a plain function as the task, so building
# 'ledidi' around `count_regular_expression` is expected to raise.
opt_obj = optimizations.get_optimization('ledidi')
opt_init_args = opt_obj.debug_init_args()
opt_init_args['model_fn'] = count_regular_expression
opt_init_args['start_sequence'] = 'C' * 100
try:
    designer = opt_obj(**opt_init_args)  # Will fail.
except AttributeError as err:
    # Show the expected error instead of letting it abort "Run All".
    print(f'AttributeError: {err}')

AttributeError: 'function' object has no attribute 'model'

### Style 2: Task as a PyTorch differentiable object.

Define a differentiable PyTorch model that counts substrings.

In [4]:
import torch
import torch.nn.functional as F
from nucleobench.common import constants
from nucleobench.common import string_utils
from nucleobench.optimizations import model_class as mc

# Step 1: Define the task class.

class CountSubstringModel(torch.nn.Module, mc.PyTorchDifferentiableModel):
    """Count number of substrings, using convs.

    The encoded substring is used as a single 1D-convolution kernel that is
    slid along the encoded input sequence. Each window's response is squared,
    the squares are summed per sequence, and the sum is negated because
    designers minimize.
    """
    def __init__(self, substring: str, vocab: list[str] = constants.VOCAB):
        """Encodes `substring` once so it can be reused as the conv kernel.

        Args:
            substring: The substring to search for.
            vocab: Vocabulary used to encode the substring.
        """
        super().__init__()
        self.substring = substring
        self.vocab = vocab

        self.substr_tensor = string_utils.dna2tensor(
            substring, vocab_list=self.vocab)
        # Add a leading dimension so the tensor can be used as a conv1d weight
        # with a single output channel.
        self.substr_tensor = torch.unsqueeze(self.substr_tensor, dim=0)
        # The kernel is fixed; it is never optimized.
        self.substr_tensor.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scores a batch of encoded sequences.

        Args:
            x: 3-dimensional tensor whose dim 1 is the channel dimension.
                Presumably (batch, vocab, length) -- TODO confirm against
                `string_utils.dna2tensor_batch`.

        Returns:
            1-dimensional tensor with one (negated) score per batch element.
        """
        assert x.ndim == 3
        # conv1d needs the input channel count to equal the kernel's. Derive
        # it from the kernel rather than hard-coding 4, so the check stays
        # correct when a different `vocab` is supplied.
        assert x.shape[1] == self.substr_tensor.shape[1], x.shape
        out_tensor = F.conv1d(x, self.substr_tensor)
        # Drop the single output-channel dimension.
        out_tensor = torch.squeeze(out_tensor, 1)
        out_tensor = torch.square(out_tensor)  # Square to incentivize exact matches.
        # Sum over window positions to get one value per sequence.
        out_tensor = torch.sum(out_tensor, dim=1)
        return -1 * out_tensor  # Flip sign so we minimize.

    def inference_on_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Tensor-in, tensor-out entry point; delegates to `forward`."""
        return self.forward(x)

    def __call__(self, seqs: list[str]):
        """Scores raw sequence strings.

        Args:
            seqs: Sequences to score.

        Returns:
            A list with one Python float per input sequence.
        """
        # NOTE(review): the batch is encoded with `dna2tensor_batch`'s default
        # vocabulary, not `self.vocab` -- confirm this is intended when a
        # custom `vocab` is passed to the constructor.
        torch_seq = string_utils.dna2tensor_batch(seqs)
        result = self.forward(torch_seq)
        assert result.ndim == 1, result.shape
        return [float(x) for x in result]


# Step 2: Instantiate the class (and sanity test it).
count_substring_model = CountSubstringModel('ACT')
# One exact match: a window response of 3, squared.
assert count_substring_model(['ACT'])[0] == -(3**2)
# Two exact matches.
assert count_substring_model(['ACTACT'])[0] == -(3**2 + 3**2)
# One exact match plus one partial match worth 2 (squared).
assert count_substring_model(['ACTCT'])[0] == -(3**2 + 2**2)

# Step 3: Define the designer.
from nucleobench import optimizations
opt_obj = optimizations.get_optimization('ledidi')
designer = opt_obj(
    model_fn = count_substring_model,
    start_sequence = 'C' * 10,
    rng_seed=0)

# Step 4: Run the designer and show the results.
_ = designer.run(n_steps=1000)
ret = designer.get_samples(1)
# Re-score the returned sequence(s) with the model.
ret_score = count_substring_model(ret)
print(f'Final score: {ret_score[0]}')
print(f'Final sequence: {ret[0]}')

Final score: -28.0
Final sequence: ACTACTCACT


### Style 3: Task as a TISM-aware object.

In [5]:
from typing import Optional

import torch
import torch.nn.functional as F

from nucleobench.common import constants
from nucleobench.common import attribution_lib_torch as att_lib
from nucleobench.common import string_utils

from nucleobench.optimizations import model_class as mc

class CountSubstringModel(torch.nn.Module, mc.TISMModelClass):
    """Count number of substrings, using convs, with TISM support.

    The encoded substring is used as a single 1D-convolution kernel that is
    slid along the encoded input sequence. Each window's response is squared,
    the squares are summed per sequence, and the sum is negated because
    designers minimize. `tism` additionally returns per-position attributions
    for a single sequence.
    """
    def __init__(self, substring: str, vocab: list[str] = constants.VOCAB):
        """Encodes `substring` once so it can be reused as the conv kernel.

        Args:
            substring: The substring to search for.
            vocab: Vocabulary used to encode the substring and, in `tism`,
                the input sequence.
        """
        super().__init__()
        self.substring = substring
        self.vocab = vocab

        self.substr_tensor = string_utils.dna2tensor(
            substring, vocab_list=self.vocab)
        # Add a leading dimension so the tensor can be used as a conv1d weight
        # with a single output channel.
        self.substr_tensor = torch.unsqueeze(self.substr_tensor, dim=0)
        # The kernel is fixed; it is never optimized.
        self.substr_tensor.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Scores a batch of encoded sequences.

        Args:
            x: 3-dimensional tensor whose dim 1 is the channel dimension.
                Presumably (batch, vocab, length) -- TODO confirm against
                `string_utils.dna2tensor_batch`.

        Returns:
            1-dimensional tensor with one (negated) score per batch element.
        """
        assert x.ndim == 3
        # conv1d needs the input channel count to equal the kernel's. Derive
        # it from the kernel rather than hard-coding 4, so the check stays
        # correct when a different `vocab` is supplied.
        assert x.shape[1] == self.substr_tensor.shape[1], x.shape
        out_tensor = F.conv1d(x, self.substr_tensor)
        # Drop the single output-channel dimension.
        out_tensor = torch.squeeze(out_tensor, 1)
        out_tensor = torch.square(out_tensor)  # Square to incentivize exact matches.
        # Sum over window positions to get one value per sequence.
        out_tensor = torch.sum(out_tensor, dim=1)
        return -1 * out_tensor  # Flip sign so we minimize.

    def inference_on_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Tensor-in, tensor-out entry point; delegates to `forward`."""
        return self.forward(x)

    def tism(self, x: str, idxs: Optional[list[int]] = None) -> tuple[torch.Tensor, list[dict[str, torch.Tensor]]]:
        """Computes the model output and TISM attributions for one sequence.

        Args:
            x: A single sequence string.
            idxs: Optional positions to restrict the attribution to. When
                None, the whole sequence is used.

        Returns:
            A `(y, sg)` tuple: `y` is the model output for `x` evaluated as a
            batch of one, and `sg` is the result of
            `att_lib.smoothgrad_to_tism`.
        """
        input_tensor = string_utils.dna2tensor(x, vocab_list=self.vocab)
        # NOTE(review): with zero noise and a single repetition this presumably
        # reduces smoothgrad to one plain gradient -- confirm in `att_lib`.
        sg_tensor = att_lib.smoothgrad_torch(
            input_tensor=input_tensor,
            model=self.inference_on_tensor,
            noise_stdev=0.0,
            times=1,
            idxs=idxs,
        )
        sg = att_lib.smoothgrad_tensor_to_dict(sg_tensor, vocab=self.vocab)
        # Pass only the characters at the requested positions, so the base
        # sequence lines up with the restricted attribution.
        x_effective = x if idxs is None else [x[idx] for idx in idxs]
        sg = att_lib.smoothgrad_to_tism(sg, x_effective)
        # Add a batch dimension before running inference on the one sequence.
        y = self.inference_on_tensor(torch.unsqueeze(input_tensor, dim=0))
        return y, sg

    def __call__(self, seqs: list[str]):
        """Scores raw sequence strings.

        Args:
            seqs: Sequences to score.

        Returns:
            A list with one Python float per input sequence.
        """
        # NOTE(review): the batch is encoded with `dna2tensor_batch`'s default
        # vocabulary, not `self.vocab` -- confirm this is intended when a
        # custom `vocab` is passed to the constructor.
        torch_seq = string_utils.dna2tensor_batch(seqs)
        result = self.inference_on_tensor(torch_seq)
        assert result.ndim == 1, result.shape
        return [float(x) for x in result]

# Step 2: Instantiate the class (and sanity test it).
count_substring_model = CountSubstringModel('ACT')
# One exact match: a window response of 3, squared.
assert count_substring_model(['ACT'])[0] == -(3**2)
# Two exact matches.
assert count_substring_model(['ACTACT'])[0] == -(3**2 + 3**2)
# One exact match plus one partial match worth 2 (squared).
assert count_substring_model(['ACTCT'])[0] == -(3**2 + 2**2)

# Step 3: Define the designer.
from nucleobench import optimizations
opt_obj = optimizations.get_optimization('directed_evolution')
designer = opt_obj(
    model_fn = count_substring_model,
    start_sequence = 'C' * 10,
    use_tism=True,
    location_only=False,
    budget=10,
    fraction_tism=1.0,
    rnd_seed=0)

# Step 4: Run the designer and show the results.
_ = designer.run(n_steps=7)
ret = designer.get_samples(2)
# Re-score the returned sequence(s) with the model.
ret_score = count_substring_model(ret)
print(f'Final score: {ret_score}')
print(f'Final sequence: {ret}')

Parsed TISM args: TISMArgs(location_only=False, budget=10, fraction_tism=1.0)


100%|██████████| 7/7 [00:01<00:00,  4.94it/s]

Best score: -28.0
Final score: [-28.0]
Final sequence: ['ACTACTCACT']



