### 1. Quick start and overview of RNAglib

In this section, we describe the main object types of RNAglib.

#### 1.1. RNADataset

RNADataset objects represent a set of RNAs, each one being represented by its 3D structure.

Each item of the RNADataset is encoded by a dictionary containing (under the key "rna") the networkx Graph representing the RNA.

It is also possible to add Representation and FeaturesComputer objects to a RNADataset.

To create a default RNADataset, run the code below:

In [1]:
from rnaglib.data_loading import RNADataset

dataset = RNADataset()

Index file not found at /Users/wissam/.rnaglib/indexes/rnaglib-nr-1.0.0.json. Run rnaglib_index
Database was found and not overwritten




When calling the `__getitem__` method of a `RNADataset` object (that is, when indexing the dataset with the index of one RNA), the following steps happen:
* If the dataset has `dataset.in_memory=False`, the graph of this RNA is loaded (otherwise, it has already been loaded).
* A dictionary encoding the RNA, called `rna_dict`, is built. This dictionary has 3 items: the graph of the RNA, the path of the graph and the path of the structures of the RNA.
* If transforms have been specified in `dataset.transforms`, they are applied to `rna_dict`.
* The features dictionary of this RNA is computed using the transform `dataset.features_computer`, which is an attribute of the dataset and maps a dictionary of type `rna_dict` to a dictionary of features.
* Each representation associated with the dataset (that is to say, contained in `dataset.representations`) is applied to the considered RNA and its result is added to the dictionary `rna_dict`.
* The method returns the dictionary `rna_dict`, which contains the graph of the RNA (under the key `rna`), the path to the graph (under the key `graph_path`), the path to the RNA structures (under the key `cif_path`) and one entry for each representation of `dataset.representations` (under the keys corresponding to the representation names, such as `graph` or `point_cloud`).

In [2]:
dataset[0]

{'rna': <networkx.classes.digraph.DiGraph at 0x16a651350>,
 'graph_path': PosixPath('/Users/wissam/.rnaglib/datasets/rnaglib-nr-1.0.0/graphs/1a9n.json'),
 'cif_path': PosixPath('/Users/wissam/.rnaglib/datasets/rnaglib-nr-1.0.0/structures/1a9n.cif')}
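
The retrieval steps listed above can be pictured with a small stand-in class. This is a sketch in plain Python, not RNAglib's implementation: `ToyDataset`, its constructor arguments and the toy callables are hypothetical, and the graph is replaced by a plain dictionary of node attributes.

```python
from pathlib import Path


class ToyDataset:
    """Hypothetical stand-in that mimics the item-retrieval steps; not RNAglib code."""

    def __init__(self, graphs, transforms=(), features_computer=None, representations=None):
        self.graphs = graphs                          # name -> graph-like object
        self.names = sorted(graphs)
        self.transforms = list(transforms)            # callables rna_dict -> rna_dict
        self.features_computer = features_computer    # callable rna_dict -> features
        self.representations = representations or {}  # name -> callable

    def __getitem__(self, idx):
        name = self.names[idx]
        # 1. Build the RNA dictionary: the graph plus the two paths
        rna_dict = {
            "rna": self.graphs[name],
            "graph_path": Path("graphs") / f"{name}.json",
            "cif_path": Path("structures") / f"{name}.cif",
        }
        # 2. Apply the optional transforms in order
        for transform in self.transforms:
            rna_dict = transform(rna_dict)
        # 3. Compute the features dictionary
        features = self.features_computer(rna_dict) if self.features_computer else {}
        # 4. Add one entry per representation, keyed by its name
        for rep_name, represent in self.representations.items():
            rna_dict[rep_name] = represent(rna_dict["rna"], features)
        return rna_dict


toy = ToyDataset(
    graphs={"1a9n": {"A.1": {"nt_code": "G"}, "A.2": {"nt_code": "C"}}},
    features_computer=lambda d: {node: data["nt_code"] for node, data in d["rna"].items()},
    representations={"sequence": lambda graph, feats: "".join(feats[n] for n in sorted(graph))},
)
item = toy[0]
print(sorted(item))  # ['cif_path', 'graph_path', 'rna', 'sequence']
```

The point of the sketch is the ordering: transforms run before the features are computed, and each representation receives both the graph and the computed features.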

#### 1.2. Transform

The Transform class groups all the functions which map the dictionaries representing RNAs (i.e. the items of a RNADataset object) into other objects (other dictionaries or objects of a different nature).

A specific tutorial gives further details about this class: https://rnaglib.org/en/latest/rnaglib.transforms.html

Below are detailed some subclasses of Transform: Representation, FeaturesComputer, FilterTransform, AnnotationTransform, PartitionTransform, Compose and ComposeFilters.

##### 1.2.1. Representation

A Representation object is a Transform that maps a RNA dictionary (as defined above) to a mathematical representation of this RNA. In the current version of RNAglib, 4 representations are already implemented: GraphRepresentation, PointCloudRepresentation, VoxelRepresentation and RingRepresentation.

GraphRepresentation converts RNA into a Leontis-Westhof graph (2.5D) where nodes are residues and edges are either base pairs or backbones.

PointCloudRepresentation converts RNA into a 3D point cloud based representation.

VoxelRepresentation converts RNA into a voxel/3D grid representation.

RingRepresentation converts RNA into a ring-based representation.

##### 1.2.2. FeaturesComputer

A FeaturesComputer is a Transform that maps a RNA Dictionary to a dictionary of features and targets (both RNA-level and node-level features and targets) of this RNA.

##### 1.2.3. FilterTransform

A FilterTransform returns the RNAs of the RNADataset that pass a certain filter.

##### 1.2.4. AnnotationTransform

An AnnotationTransform is a transform computing additional node features within each RNA graph.

##### 1.2.5. PartitionTransform

A PartitionTransform is a transform which breaks up each RNA structure into substructures.

##### 1.2.6. Compose

A Compose object is a Transform which consists in the composition of a series of transforms.

##### 1.2.7. ComposeFilters

A ComposeFilters object is a Transform consisting of the composition of a series of filters (objects of type FilterTransform).
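
The idea behind composition can be sketched in plain Python. The functions below are hypothetical stand-ins, not the RNAglib classes: `compose` chains dictionary-to-dictionary callables, and `compose_filters` assumes that an RNA is kept only if it passes every filter (this all-must-pass rule is an assumption of the sketch).

```python
from typing import Callable, Dict, Iterable


def compose(transforms: Iterable[Callable[[Dict], Dict]]) -> Callable[[Dict], Dict]:
    """Chain transforms: the output of one step is the input of the next."""
    steps = list(transforms)

    def composed(rna_dict: Dict) -> Dict:
        for step in steps:
            rna_dict = step(rna_dict)
        return rna_dict

    return composed


def compose_filters(filters: Iterable[Callable[[Dict], bool]]) -> Callable[[Dict], bool]:
    """Combine filters: an item passes only if every filter accepts it."""
    checks = list(filters)

    def composed(rna_dict: Dict) -> bool:
        return all(check(rna_dict) for check in checks)

    return composed


def uppercase_sequence(rna_dict: Dict) -> Dict:
    return {**rna_dict, "sequence": rna_dict["sequence"].upper()}


def add_length(rna_dict: Dict) -> Dict:
    return {**rna_dict, "length": len(rna_dict["sequence"])}


def is_long_enough(rna_dict: Dict) -> bool:
    return rna_dict["length"] >= 4


def contains_gc(rna_dict: Dict) -> bool:
    return "GC" in rna_dict["sequence"]


pipeline = compose([uppercase_sequence, add_length])
keep = compose_filters([is_long_enough, contains_gc])

rna = pipeline({"sequence": "gcaua"})
print(rna, keep(rna))  # {'sequence': 'GCAUA', 'length': 5} True
```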

#### 1.3. Tasks

A Task is an object representing a benchmarking task to be performed on the RNA. It is associated with a specific RNADataset. Once implemented, the Task object can be used to evaluate the performance of various models on the defined task. One particular category of tasks is already implemented as a subclass of Task: ResidueClassificationTask, which groups all the tasks consisting of classifying the residues (nucleotides) of the RNA.

#### 1.4. Encoders

Encoders are objects that vectorize features with a specific encoding. The features available in the RNA NetworkX graph can have different types, including text, so they must be vectorized before they can be used for learning.
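
To make the idea of vectorizing heterogeneous features concrete, here is a minimal sketch in plain Python. The functions `one_hot_encode` and `float_encode` are hypothetical illustrations and do not reproduce the RNAglib encoder classes; the vocabulary `"ACGU"` and the fallback values are assumptions of the sketch.

```python
from typing import List, Optional, Sequence


def one_hot_encode(value: Optional[str], vocabulary: Sequence[str]) -> List[float]:
    """Return a one-hot vector; unknown or missing values map to all zeros."""
    return [1.0 if value == token else 0.0 for token in vocabulary]


def float_encode(value, default: float = 0.0) -> List[float]:
    """Wrap a numeric value in a length-1 vector, falling back to a default."""
    try:
        return [float(value)]
    except (TypeError, ValueError):
        return [default]


# A text feature and a numeric feature of one node, turned into a single vector
node = {"nt_code": "G", "phase_angle": 17.3}
vector = one_hot_encode(node["nt_code"], "ACGU") + float_encode(node["phase_angle"])
print(vector)  # [0.0, 0.0, 1.0, 0.0, 17.3]
```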

### 2. Using the tasks

First, generate the necessary index files with the following command:
```
$ rnaglib_index
```

#### 2.1. Instantiate the task

Choose the task appropriate to your model. For illustration, we use _RNA-Site_, which is implemented by the `BindingSiteDetection` class.
When instantiating the task, custom splitters or other arguments can be passed if needed.

In [3]:
from rnaglib.tasks import BindingSiteDetection
from rnaglib.transforms import FeaturesComputer

task = BindingSiteDetection(root="tutorial", recompute=True)  # set recompute=True to rebuild the dataset designed for the task; otherwise the dataset located at tutorial/dataset is used

Creating task dataset from scratch...
Database was found and not overwritten




>>> Saving dataset.




>>> Done


#### 2.2. [Optional] Customize the task dataset

##### 2.2.1. Apply already implemented transforms to the dataset

You might want to apply transforms or preprocessing steps that the task does not apply to your dataset by default. In this case, you can apply additional transforms to the dataset.

We illustrate this below with the application of the transform PDBIDNameTransform.

In [14]:
from rnaglib.data_loading import RNADataset
from rnaglib.transforms import PDBIDNameTransform

rnas = PDBIDNameTransform()(task.dataset)
task.dataset = RNADataset(rnas=[r["rna"] for r in rnas])

You might want to go further by implementing custom transforms to preprocess the dataset. We show below how to create custom annotators and custom filters.

##### 2.2.2. Create a custom annotator

You might want to create a custom annotator to add new features to the nodes of the graphs, for instance to perform a new task using those new annotations. The custom annotator will typically be called in the `process` method of a `Task` object.

In [None]:
from rnaglib.transforms import AnnotationTransform
from networkx import set_node_attributes

class CustomAnnotator(AnnotationTransform):
    def forward(self, rna_dict: dict) -> dict:
        # Compute the annotation of each node from its attributes
        custom_annotation = {
            node: self._custom_annotation(nodedata)
            for node, nodedata in rna_dict['rna'].nodes(data=True)
        }
        # Store the annotation on the graph under the key "custom_annotation"
        set_node_attributes(rna_dict['rna'], custom_annotation, "custom_annotation")
        return rna_dict

    @staticmethod
    def _custom_annotation(nodedata: dict):
        return ... # node-wise formula to compute the custom annotation

Once defined, you can apply your custom annotator to the dataset using the following code:

In [None]:
rnas = CustomAnnotator()(task.dataset)
task.dataset = RNADataset(rnas=[r["rna"] for r in rnas])
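
As an illustration of what a node-wise annotation rule could look like, the sketch below works on a plain dictionary standing in for the `(node, attributes)` pairs of an RNA graph. The rule `has_small_molecule_contact` is hypothetical, and so is the assumption that the `binding_small-molecule` attribute is `None` when no ligand is bound.

```python
def has_small_molecule_contact(nodedata: dict) -> bool:
    """Hypothetical node-wise rule: True if the residue carries a small-molecule annotation."""
    return nodedata.get("binding_small-molecule") is not None


# Plain-dict stand-in for the nodes of an RNA graph and their attributes
nodes = {
    "1a9n.A.1": {"nt_code": "G", "binding_small-molecule": None},
    "1a9n.A.2": {"nt_code": "C", "binding_small-molecule": {"name": "LIG"}},
}

# Same pattern as in the annotator: one annotation value per node
custom_annotation = {node: has_small_molecule_contact(data) for node, data in nodes.items()}

# Write the values back on each node, which is the role of set_node_attributes on a real graph
for node, value in custom_annotation.items():
    nodes[node]["custom_annotation"] = value

print(custom_annotation)  # {'1a9n.A.1': False, '1a9n.A.2': True}
```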

##### 2.2.3. Create a custom filter

Several filters are already implemented and available in `rnaglib.transforms`: `SizeFilter`, which rejects RNAs that are not within the given size bounds; `RNAAttributeFilter`, which rejects RNAs that lack a certain annotation at the whole-RNA level; `ResidueAttributeFilter`, which rejects RNAs that lack a certain annotation at the residue level; `RibosomalFilter`, which rejects ribosomal RNA; and `NameFilter`, which filters RNAs based on their names. However, you might want to create your own filter. It could for instance be called in the `process` method of a new `Task` object.

In [None]:
from rnaglib.transforms import FilterTransform

class CustomFilter(FilterTransform):

    def __init__(self, **kwargs):
        # Add filter-specific parameters to the signature and store them here
        super().__init__(**kwargs)

    def forward(self, rna_dict: dict) -> bool:

        ...

        return ... # should return a Boolean (True if the RNA described by rna_dict passes the filter, False otherwise)

Once defined, you can apply your custom filter to the dataset using the following code:

In [None]:
rnas = CustomFilter()(task.dataset)
task.dataset = RNADataset(rnas=[r["rna"] for r in rnas])

#### 2.3. [Optional] Customize the features

##### 2.3.1. Add features from the graph to the dataset

You might want to use input features which are different from the default ones specified for this task in RNAglib. In this case, it is necessary to add them to the features computer of the dataset.

The features can be chosen among the list of features available in the RNA graph: 'index', 'index_chain', 'chain_name', 'nt_resnum', 'nt_name', 'nt_code', 'nt_id', 'nt_type', 'dbn', 'summary', 'alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', 'epsilon_zeta', 'bb_type', 'chi', 'glyco_bond', 'C5prime_xyz', 'P_xyz', 'form', 'ssZp', 'Dp', 'splay_angle', 'splay_distance', 'splay_ratio', 'eta', 'theta', 'eta_prime', 'theta_prime', 'eta_base', 'theta_base', 'v0', 'v1', 'v2', 'v3', 'v4', 'amplitude', 'phase_angle', 'puckering', 'sugar_class', 'bin', 'cluster', 'suiteness', 'filter_rmsd', 'frame', 'sse', 'binding_protein', 'binding_ion', 'binding_small-molecule'.

When adding a feature to the features computer, you have to specify a dictionary named `custom_encoders` mapping each feature to the encoder chosen to encode the feature. Canonical encoders corresponding to each feature are available in [NODE_FEATURE_MAP](https://github.com/cgoliver/rnaglib/blob/30bded91462f655c235ef57efc07e834456615a4/src/rnaglib/config/feature_encoders.py#L7)

In the example below, we add the feature named `"phase_angle"` to the features computer of the dataset and specify that it should be encoded using the pre-implemented FloatEncoder.

In [5]:
from rnaglib.encoders import FloatEncoder

task.dataset.features_computer.add_feature(feature_names="phase_angle", custom_encoders={"phase_angle":FloatEncoder()})

##### 2.3.2. Create custom features

The strategy to create custom features consists in creating a Transform object which takes as input a RNADataset and transforms it by adding the new features to the graphs representing all of the items of the RNADataset.

To do so, you have to build a subclass of `Transform` and specify:

* its `name`

* its associated `encoder`

* its `forward` method taking as input the dictionary representing one RNA and returning the updated RNA dictionary (containing its additional features)

Once the custom features have been created, you still have to add them to the FeaturesComputer of the dataset. To do so, follow the steps described above (cf. section 2.3.1, "Add features from the graph to the dataset").

Below is the structure to write such a transform:

In [None]:
from typing import Dict

from rnaglib.transforms import Transform

class AddCustomFeature(Transform):
    name = "add_custom_feature"
    encoder = ...

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, rna_dict: Dict) -> Dict:

        ... # compute and add additional features

        return rna_dict

Once the transform above has been defined, it remains to apply it to the dataset as illustrated below:

In [None]:
rnas = AddCustomFeature()(task.dataset)
task.dataset = RNADataset(rnas=[r["rna"] for r in rnas])

#### 2.4. Add a representation

##### 2.4.1. Add an already implemented representation

It is necessary to add to the dataset a representation of the RNA structure. If the representation you want to add to perform the task is already implemented, you have to follow the code below. Already implemented representations include graphs (`GraphRepresentation` class), point clouds (`PointCloudRepresentation` class), voxels (`VoxelRepresentation` class) and rings (`RingRepresentation` class).

In [8]:
from rnaglib.transforms import GraphRepresentation

task.dataset.add_representation(GraphRepresentation(framework='pyg'))

>>> Adding <rnaglib.transforms.represent.graph.GraphRepresentation object at 0x3ad265390> representations.


##### 2.4.2. Create a custom representation

However, you might want to use a representation which doesn't belong to the aforementioned already implemented representations. In this case, you have to define your transformation.

In [None]:
from rnaglib.transforms import Representation

class CustomRepresentation(Representation):
    """
    Converts RNA into a custom representation.
    """

    def __init__(self,
                 **kwargs):
        super().__init__(**kwargs)
        pass

    def __call__(self, rna_graph, features_dict):

        ... # computes the representation

        return representation

    @property
    def name(self):
        return "custom_representation" # the name of the representation

    def batch(self, samples):

        ... # defines the way to batch representations of different samples together

        return batched_samples

Once the representation has been defined, you have to add it to the dataset, exactly as for an already implemented representation.

In [None]:
task.dataset.add_representation(CustomRepresentation())  # pass an instance, as with GraphRepresentation above
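
The `batch` method has to decide how samples of different sizes are combined. One common option, shown here as a plain-Python sketch and not as the strategy used by any RNAglib representation, is to pad every sample to the length of the longest one and keep a mask of the real entries. The function `pad_batch` and the point-cloud layout are hypothetical.

```python
from typing import List, Tuple

Point = Tuple[float, float, float]


def pad_batch(samples: List[List[Point]]) -> Tuple[List[List[Point]], List[List[bool]]]:
    """Pad point clouds to a common length and return a mask marking the real points."""
    max_len = max(len(sample) for sample in samples)
    padded, mask = [], []
    for sample in samples:
        n_pad = max_len - len(sample)
        # Padding points sit at the origin and are flagged False in the mask
        padded.append(list(sample) + [(0.0, 0.0, 0.0)] * n_pad)
        mask.append([True] * len(sample) + [False] * n_pad)
    return padded, mask


coords, mask = pad_batch([
    [(0.0, 1.0, 2.0)],
    [(1.0, 1.0, 1.0), (2.0, 2.0, 2.0)],
])
print(mask)  # [[True, False], [True, True]]
```

Padding keeps the batch rectangular at the cost of some wasted entries; graph frameworks often concatenate samples instead and track which node belongs to which sample.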

#### 2.5. Set loaders

Since we changed the dataset by adding a representation (and maybe some additional features) to it, it is necessary to call `set_loaders` in order to update the train, val and test dataloaders.

In [9]:
task.set_loaders()

#### 2.6. Build a model

We first define the architecture of the model.

In [6]:
import torch

class LinearModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 5 input features per residue, 1 output per residue
        self.linear = torch.nn.Linear(5, 1)

    def forward(self, data):
        # Only the node features are used; this linear model ignores the edges
        x = self.linear(data.x)
        return torch.sigmoid(x)
    

We then instantiate the model and the optimizer.

In [7]:
import torch.optim as optim

# Define model
learning_rate = 0.0001
epochs = 2
device = "cpu"
model = LinearModel()
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.BCELoss()

#### 2.7. [Optional] Define your own `evaluate` method

If the representation you have chosen isn't the canonical representation for the task, or if you want performance metrics which aren't implemented by default, you have to define a custom `evaluate` method. Below is an example.

In [None]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, matthews_corrcoef

my_representation = "graph" # or "point_cloud" or "custom_representation"

def evaluate(loader):
    model.eval()
    all_probs = []
    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for batch in loader:
            graph = batch[my_representation]
            graph = graph.to(device)
            out = model(graph)
            # Same loss call as in training: BCELoss needs float targets shaped like the output
            loss = criterion(out, graph.y)
            total_loss += loss.item()
            # The model outputs one sigmoid probability per residue, so we threshold it
            preds = out > 0.5
            all_probs.extend(out.cpu().flatten().tolist())
            all_preds.extend(preds.cpu().flatten().tolist())
            all_labels.extend(graph.y.cpu().flatten().tolist())

    avg_loss = total_loss / len(loader)

    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    auc = roc_auc_score(all_labels, all_probs)  # AUC is computed from the probabilities
    mcc = matthews_corrcoef(all_labels, all_preds)

    return accuracy, f1, auc, avg_loss, mcc

#### 2.8. Train the model

In [10]:
def train():
    model.train()
    for batch in task.train_dataloader:
        graph = batch["graph"]
        graph = graph.to(device)
        optimizer.zero_grad()
        out = model(graph)
        loss = criterion(out, graph.y)
        loss.backward()
        optimizer.step()


for epoch in range(epochs):
    train()
    train_metrics = task.evaluate(model, task.train_dataloader) # you might be using evaluate instead of task.evaluate
    val_metrics = task.evaluate(model, task.val_dataloader) # you might be using evaluate instead of task.evaluate
    print(
    f"""Epoch {epoch + 1}, TrainAcc {train_metrics["accuracy"]:.4f} Val Acc: {val_metrics["accuracy"]:.4f}"""
    )

Epoch 1, TrainAcc 0.7101 Val Acc: 0.7604
Epoch 2, TrainAcc 0.7100 Val Acc: 0.7604


#### 2.9. Evaluate the model on the dataset

In [11]:
test_metrics = task.evaluate(
    model, task.test_dataloader, device
) # you might be using evaluate instead of task.evaluate

print(
    f"""Test Accuracy: {test_metrics["accuracy"]:.4f}, Test F1 Score: {test_metrics["f1"]:.4f}, Test AUC: {test_metrics["auc"]:.4f}, Test MCC: {test_metrics["mcc"]:.4f}"""
)

Test Accuracy: 0.6861, Test F1 Score: 0.0005, Test AUC: 0.4964, Test MCC: -0.0068
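
The scores above combine an accuracy near 0.69 with an F1 score near zero. One pattern that produces this kind of gap is a model that predicts the negative class for almost every residue on imbalanced labels; whether this is what happens in this run is an assumption, not something the output proves. The plain-Python sketch below uses made-up labels and a hypothetical helper `binary_metrics` to show the effect.

```python
import math


def binary_metrics(labels, preds):
    """Accuracy, F1 and MCC from the confusion counts of binary labels and predictions."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)

    accuracy = (tp + tn) / len(labels)
    f1_den = 2 * tp + fp + fn
    f1 = 2 * tp / f1_den if f1_den else 0.0
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / mcc_den if mcc_den else 0.0
    return {"accuracy": accuracy, "f1": f1, "mcc": mcc}


# Made-up example: 70% negatives, and a model that always predicts the negative class
labels = [0] * 7 + [1] * 3
preds = [0] * 10
metrics = binary_metrics(labels, preds)
print(metrics)  # {'accuracy': 0.7, 'f1': 0.0, 'mcc': 0.0}
```

This is why F1 and MCC are worth reporting next to accuracy for binding-site style tasks, where positives are the minority class.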


### 3. Creating custom tasks

#### 3.1. Create a custom Task

In order to create a custom task, you have to define it as a subclass of a task category (for instance `ResidueClassificationTask` or a subclass you have created yourself) and specify the following:

* a target variable: the variable which has to be predicted by the model
* an input variable or a list of input variables: the inputs of the model
* a method `get_task_vars` specifying the FeaturesComputer to build to perform the task (in general, it will use the aforementioned target and input variables)
* a method `process` creating the dataset and applying some preprocessing to the dataset (especially annotation and filtering transforms) if needed

If the task belongs to a task category other than `ResidueClassificationTask` (that is to say, it is not a node-level classification task), you have to define a new Task subclass corresponding to this task category and specify:
* a method named `dummy_model` returning a dummy model, which can be used to check that the task works without having to define a model
* a method named `evaluate` which, given a model, outputs a dictionary containing performance metrics of this model on the task of interest.

For instance, in the cell below, we define a toy task called AnglePrediction consisting in predicting the phase angle by using the nucleotide code.

Since it is a regression task and not a classification task, we first need to define a new subclass of the `Task` class, which we will call `ResidueRegression`.

In [12]:
import torch
from rnaglib.tasks import Task
from rnaglib.utils import DummyResidueModel
from sklearn.metrics import root_mean_squared_error, mean_absolute_error

class ResidueRegression(Task):
    def __init__(self, root, splitter=None, **kwargs):
        super().__init__(root=root, splitter=splitter, **kwargs)

    @property
    def dummy_model(self) -> torch.nn.Module:
        return DummyResidueModel()

    def evaluate(self, model: torch.nn.Module, device: str = "cpu") -> dict:
        model.eval()
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for batch in self.test_dataloader:
                graph = batch["graph"]
                graph = graph.to(device)
                out = model(graph)

                # Regression: the raw model output is the prediction (no thresholding)
                all_preds.extend(out.cpu().flatten().tolist())
                all_labels.extend(graph.y.cpu().flatten().tolist())

        # Compute performance metrics over all the test residues
        RMSE = root_mean_squared_error(all_labels, all_preds)
        MAE = mean_absolute_error(all_labels, all_preds)

        return {"RMSE": RMSE, "MAE": MAE}
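
For reference, the two regression metrics returned by `evaluate` can be written out by hand. The sketch below is plain Python with made-up values; the helper names `rmse` and `mae` are hypothetical and only mirror the definitions of root mean squared error and mean absolute error.

```python
import math


def rmse(labels, preds):
    """Root mean squared error: square root of the mean squared difference."""
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(labels, preds)) / len(labels))


def mae(labels, preds):
    """Mean absolute error: mean of the absolute differences."""
    return sum(abs(y - p) for y, p in zip(labels, preds)) / len(labels)


# Made-up targets and predictions (errors of 2, 2, 3 and 3)
labels = [10.0, 20.0, 30.0, 40.0]
preds = [12.0, 18.0, 33.0, 37.0]
print(mae(labels, preds))  # 2.5
print(rmse(labels, preds) == math.sqrt(6.5))  # True
```

Because errors are squared before averaging, RMSE is at least as large as MAE and penalizes large individual errors more strongly.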

Once the subclass `ResidueRegression` is defined, one can define the specific task `AnglePrediction`.

In [13]:
from rnaglib.transforms import PDBIDNameTransform
from rnaglib.encoders import FloatEncoder

# RNADataset and FeaturesComputer were imported in earlier cells;
# SPLITTING_VARS must also be in scope before the task is instantiated.

class AnglePrediction(ResidueRegression):
    # Target variable
    target_var = "phase_angle"
    # Input variable
    input_var = "nt_code"

    def __init__(self, root, splitter=None, **kwargs):
        super().__init__(root=root, splitter=splitter, **kwargs)

    # Creation and preprocessing of the dataset
    def process(self) -> RNADataset:
        rnas = RNADataset(debug=False, redundancy='all', rna_id_subset=SPLITTING_VARS['PDB_TO_CHAIN_TR60_TE18'].keys())
        # TODO: remove wrong chains using  SPLITTING_VARS["PDB_TO_CHAIN_TR60_TE18"]
        rnas = PDBIDNameTransform()(rnas)
        dataset = RNADataset(rnas=[r["rna"] for r in rnas])
        return dataset

    # Computation of the FeaturesComputer
    def get_task_vars(self) -> FeaturesComputer:
        return FeaturesComputer(
            nt_features=[self.input_var],
            nt_targets=self.target_var,
            # phase_angle is a continuous value, so it is encoded as a float
            custom_encoders={self.target_var: FloatEncoder()},
        )