## 1. Quick start and overview of RNAglib

This section describes the main object types of RNAglib's Task module.
A more in-depth description of these objects is available in the online documentation (rnaglib.org), under the section "a peek under the hood".

### 1.1. RNADataset

RNADataset objects represent a set of RNAs, each one being represented by its 3D structure.

Each item of the RNADataset is encoded by a dictionary containing (under the key "rna") the networkx Graph representing the RNA.

It is also possible to add Representation and FeaturesComputer objects to an RNADataset.

To create a default RNADataset, run the code below:

In [1]:
from rnaglib.dataset import RNADataset

dataset = RNADataset(debug=True)

Database was found and not overwritten


When the `__getitem__` method of an `RNADataset` object is called with the index of one RNA, the following steps happen:
* If the dataset has `dataset.in_memory=False`, the graph of this RNA is loaded (otherwise, it has already been loaded).
* A dictionary encoding the RNA, called `rna_dict`, is built. This dictionary has 3 items: the graph of the RNA, the path of the graph and the path of the structures of the RNA.
* If transforms have been specified in `dataset.transforms`, they are applied to `rna_dict`.
* The features dictionary of this RNA is computed using the transform `dataset.features_computer`, which is an attribute of the dataset and maps a dictionary of type `rna_dict` to a dictionary of features.
* Each representation associated with the dataset (that is to say, contained in `dataset.representations`) is applied to the RNA and the result is added to the dictionary `rna_dict`.
* The method returns the dictionary `rna_dict`, which contains the graph of the RNA (under the key `rna`), the path to the graph (under the key `graph_path`), the path to the RNA structures (under the key `cif_path`) and one entry for each representation in `dataset.representations` (under the key corresponding to the representation name, such as `graph` or `point_cloud`).

In [2]:
dataset[0]

{'rna': <networkx.classes.digraph.DiGraph at 0x7c3fe560fec0>,
 'graph_path': PosixPath('/home/vincent/.rnaglib/datasets/rnaglib-nr-2.0.2/graphs/1a9n.json'),
 'cif_path': PosixPath('/home/vincent/.rnaglib/structures/1a9n.cif')}
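
The item-building steps listed above can be mimicked with a small stand-in class. This is only a sketch to make the order of operations concrete, not RNAglib's implementation: `ToyDataset`, `tag` and the `node_count` representation are hypothetical names, and a "graph" is simplified to a list of node names.

```python
# Toy stand-in for the item-building steps; NOT RNAglib's implementation.
class ToyDataset:
    def __init__(self, graphs, transforms=None, representations=None):
        self.graphs = graphs                          # already loaded (in-memory case)
        self.transforms = transforms or []            # callables: rna_dict -> rna_dict
        self.representations = representations or {}  # name -> callable(graph, features)

    def features_computer(self, rna_dict):
        # Hypothetical feature: length of each node name.
        return {node: len(node) for node in rna_dict["rna"]}

    def __getitem__(self, idx):
        # 1. Build the dictionary encoding the RNA.
        rna_dict = {"rna": self.graphs[idx], "graph_path": None, "cif_path": None}
        # 2. Apply the transforms, in order.
        for transform in self.transforms:
            rna_dict = transform(rna_dict)
        # 3. Compute the features dictionary.
        features = self.features_computer(rna_dict)
        # 4. Apply each representation and store it under its name.
        for name, represent in self.representations.items():
            rna_dict[name] = represent(rna_dict["rna"], features)
        return rna_dict


def tag(rna_dict):
    """Hypothetical transform that adds a flag to the RNA dictionary."""
    rna_dict["tagged"] = True
    return rna_dict


toy = ToyDataset(
    graphs=[["A.1", "A.2", "A.10"]],  # a "graph" is just a list of node names here
    transforms=[tag],
    representations={"node_count": lambda graph, features: len(features)},
)
item = toy[0]
print(sorted(item))
```

The returned dictionary holds the three base keys plus one key added by the transform and one key per representation, which mirrors the structure described above.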

### 1.2. Transform

The Transform class groups all functions taking RNAs (i.e. the items of an RNADataset object) as inputs. A dedicated tutorial gives further details about this class: https://rnaglib.org/en/latest/rnaglib.transforms.html

In [3]:
from rnaglib.transforms import IdentityTransform

rna = dataset[10]
t = IdentityTransform()
new_rna = t(rna)
new_rna

{'rna': <networkx.classes.digraph.DiGraph at 0x7c3fe5696720>,
 'graph_path': PosixPath('/home/vincent/.rnaglib/datasets/rnaglib-nr-2.0.2/graphs/1d4r.json'),
 'cif_path': PosixPath('/home/vincent/.rnaglib/structures/1d4r.cif')}

Transforms are split into several categories:

* **Filter**: accept or reject an RNA based on some criteria (e.g. remove RNAs that are too large)

* **Partition**: generate a collection of substructures from a whole RNA (e.g. break up an RNA into individual chains)

* **Annotation**: add or remove annotations from the RNA (e.g. query a database and store results in the RNA)

  * **Featurize**: a special kind of annotation that converts some RNA features into tensors for learning.

* **Represent**: compute tensor-based representations of RNAs (e.g. convert to voxel grid)

In [4]:
from rnaglib.transforms import SizeFilter

t = SizeFilter(max_size=200)
t(rna)

True
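
To illustrate how the filter and partition categories differ, here is a library-free sketch. It assumes a simplified RNA dictionary whose `"rna"` entry is a list of node names in a hypothetical `chain.position` format; `size_filter` and `split_by_chain` are illustrative helpers, not RNAglib functions.

```python
def size_filter(max_size):
    """Filter-style step: build a predicate that accepts small RNAs only."""
    def accept(rna_dict):
        return len(rna_dict["rna"]) <= max_size
    return accept


def split_by_chain(rna_dict):
    """Partition-style step: yield one RNA dictionary per chain."""
    chains = {}
    for node in rna_dict["rna"]:
        chain = node.split(".")[0]  # hypothetical "chain.position" node names
        chains.setdefault(chain, []).append(node)
    for chain, nodes in chains.items():
        yield {"rna": nodes, "chain": chain}


rnas = [{"rna": ["A.1", "A.2", "B.1"]}, {"rna": ["C.1"]}]
keep = size_filter(max_size=2)

# A filter maps one RNA to a boolean.
accepted = [keep(rna) for rna in rnas]
# A partition maps one RNA to several smaller RNAs.
pieces = [piece for rna in rnas for piece in split_by_chain(rna)]
# After partitioning, every piece is small enough to pass the filter.
kept = [piece for piece in pieces if keep(piece)]
print(accepted, [piece["chain"] for piece in pieces], len(kept))
```

A filter returns a single boolean per RNA, as `SizeFilter` does above, whereas a partition can return several items for one input.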

### 1.3. Dataset Transforms

`DSTransform` objects are transforms that take a whole RNADataset as input and return a whole dataset.

They mostly revolve around computing distances between RNAs in the dataset, which is in turn used to remove redundancy and split the dataset.

In [5]:
from rnaglib.dataset import RNADataset
from rnaglib.dataset_transforms import CDHitComputer

rna_names = ['1a9n', '1av6', '1b23']
dataset = RNADataset(rna_id_subset=rna_names)

dataset = CDHitComputer()(dataset)
dataset.distances

Database was found and not overwritten


CD-Hit: 100%|██████████| 3/3 [00:00<00:00, 72315.59it/s]

Subsetting started...
Subsetting completed successfully.





{'cd_hit': array([[0., 1., 1.],
        [1., 0., 1.],
        [1., 1., 0.]])}
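
A pairwise distance matrix such as the one above can drive redundancy removal. The sketch below shows one simple greedy strategy on plain nested lists. It is an assumption made for illustration and not necessarily the algorithm RNAglib uses; `remove_redundancy` and its `threshold` parameter are hypothetical.

```python
def remove_redundancy(distances, threshold):
    """Greedily keep items that are at least `threshold` away from all kept items."""
    kept = []
    for i in range(len(distances)):
        if all(distances[i][j] >= threshold for j in kept):
            kept.append(i)
    return kept


# Same values as the matrix above: every pair is maximally distant.
all_distinct = [[0.0, 1.0, 1.0],
                [1.0, 0.0, 1.0],
                [1.0, 1.0, 0.0]]

# Made-up matrix in which items 0 and 1 are near-duplicates.
near_duplicates = [[0.0, 0.1, 1.0],
                   [0.1, 0.0, 1.0],
                   [1.0, 1.0, 0.0]]

print(remove_redundancy(all_distinct, threshold=0.5))     # [0, 1, 2]
print(remove_redundancy(near_duplicates, threshold=0.5))  # [0, 2]
```

With the matrix computed above, no RNA would be dropped, since all pairwise distances equal 1.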

### 1.4. Tasks

A Task is an object representing a benchmarking task to be performed on the RNA.
 
This object gathers a specific RNADataset, obtained by applying RNA transforms to the original database. This is implemented in `process()`.
In addition, it includes redundancy removal and custom splitting, as implemented in `post_process()`.

We propose seven tasks that can be fetched directly from Zenodo. 

In [6]:
from rnaglib.tasks import BindingSite

task_site = BindingSite(root='my_root')

>>> Loading precomputed task...
>>> Done


## 2. Using the tasks

#### 2.1. Instantiate the task

Choose the task appropriate to your model. Here, we chose _RNA-Site_ for illustration.
When instantiating the task, custom splitters or other arguments can be passed if needed.

In [7]:
from rnaglib.tasks import BindingSite
from rnaglib.transforms import FeaturesComputer

task = BindingSite(root="my_root")

>>> Loading precomputed task...
>>> Done


#### 2.2. [Optional] Customize the task dataset with a custom annotator

You might want to go further by implementing custom transforms to add new features to the nodes of the graphs, for instance to propose a new task, or train a better model using those new annotations. The custom annotator will typically be called in the `process()` method of a `Task` object.

In [8]:
from rnaglib.transforms import AnnotationTransform
from networkx import set_node_attributes


class CustomAnnotator(AnnotationTransform):
    def forward(self, rna_dict: dict) -> dict:
        custom_annotation = {
            node: self._custom_annotation(nodedata)
            for node, nodedata in rna_dict['rna'].nodes(data=True)
        }
        set_node_attributes(rna_dict['rna'], custom_annotation, "custom_annotation")
        return rna_dict

    @staticmethod
    def _custom_annotation(nodedata):
        return None  # placeholder: compute the annotation from this node's data dictionary

Once defined, you can apply your custom annotator to the dataset using the following code:

In [9]:
task.dataset.transforms.append(CustomAnnotator())
# Get nodes in the first item of the dataset, to show the new custom annotation
nodes = list(task.dataset[0]['rna'].nodes(data=True))
node, data = nodes[0]
print(data['custom_annotation'])

None
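
The annotator above returns `None` as a placeholder. To show what a node-wise formula could look like, here is a library-free sketch in which the nodes are a plain dictionary mapping node names to their data. The `"nt"` key, `annotate_nodes` and `is_purine` are hypothetical and only mirror the structure of `forward` and `_custom_annotation`.

```python
def annotate_nodes(nodes, key, formula):
    """Store formula(nodedata) under `key` for every node (mirrors `forward`)."""
    for node, nodedata in nodes.items():
        nodedata[key] = formula(nodedata)
    return nodes


def is_purine(nodedata):
    """Node-wise formula (mirrors `_custom_annotation`): True for A and G."""
    return nodedata["nt"] in {"A", "G"}


# Hypothetical node data: each node only stores its nucleotide letter.
nodes = {"A.1": {"nt": "G"}, "A.2": {"nt": "U"}, "A.3": {"nt": "A"}}
annotate_nodes(nodes, "is_purine", is_purine)
print(nodes["A.1"]["is_purine"], nodes["A.2"]["is_purine"])
```

In the real annotator, the same kind of function would replace the body of `_custom_annotation`, using whichever node attributes are present in the RNA graph.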


#### 2.3. [Optional] Customize the features

You might want to use input features that differ from the default ones specified for this task in RNAglib. In this case, you have to add them to the features computer of the dataset.
The features can be chosen from the list of features available in the RNA graph.

The other avenue is to create custom features. To do so, you should subclass the `Transform` object and specify:

* its `name`

* its `forward` method taking as input the dictionary representing one RNA and returning the updated RNA dictionary (containing its additional features)

Once the custom features have been created, you still have to add them to the FeaturesComputer of the graph. To do so, you can check the documentation above (cf. section "Adding features to the features computer of a RNADataset").
Below is the structure to write such a transform:

In [10]:
from rnaglib.transforms import Transform


class AddCustomFeature(Transform):
    name = "add_custom_feature"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, rna_dict):
        # ... compute and add additional features
        rna_dict["custom"] = 0
        return rna_dict


task.dataset.transforms.append(AddCustomFeature())
task.dataset[0]

{'rna': <networkx.classes.digraph.DiGraph at 0x7c3fe5ad9a30>,
 'graph_path': PosixPath('my_root/dataset/1a34_0.json'),
 'cif_path': None,
 'custom': 0}

#### 2.4. Add a representation

##### 2.4.1. Add an already implemented representation

A representation of the RNA structure has to be added to the dataset. If the representation you need for the task is already implemented, you can add it as in the code below. Already implemented representations include graphs (`GraphRepresentation` class), point clouds (`PointCloudRepresentation` class), voxels (`VoxelRepresentation` class) and rings (`RingRepresentation` class).

In [11]:
from rnaglib.transforms import GraphRepresentation

task.dataset.add_representation(GraphRepresentation(framework='pyg'))
task.dataset[0]

>>> Adding graph to dataset representations.


{'rna': <networkx.classes.digraph.DiGraph at 0x7c3fdb8f6780>,
 'graph_path': PosixPath('my_root/dataset/1a34_0.json'),
 'cif_path': None,
 'custom': 0,
 'graph': Data(x=[20, 4], edge_index=[2, 54], edge_attr=[54], y=[20])}

##### 2.4.2. Create a custom representation

However, you might want to use a representation that is not among the already implemented ones. In this case, you have to define your own representation class.

In [12]:
from rnaglib.transforms import Representation


class CustomRepresentation(Representation):
    """
    Converts RNA into a custom representation (here just 0).
    """

    def __init__(self):
        super().__init__()

    def __call__(self, rna_graph, features_dict):
        # computes the representation
        return 0

    @property
    def name(self):
        return "custom_representation"  # the name of the representation

    def batch(self, samples):
        # ... defines the way to batch representations of different samples together
        return samples
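
The `batch` method above simply returns the list of samples. When samples are sequences of different lengths, one common option is to pad them to a common length and keep a mask. The sketch below uses plain lists, and `pad_batch` and its `pad_value` parameter are hypothetical helpers rather than part of RNAglib.

```python
def pad_batch(samples, pad_value=0.0):
    """Pad variable-length lists to the longest one and return a validity mask."""
    max_len = max(len(sample) for sample in samples)
    values = [list(sample) + [pad_value] * (max_len - len(sample)) for sample in samples]
    mask = [[1] * len(sample) + [0] * (max_len - len(sample)) for sample in samples]
    return {"values": values, "mask": mask}


# Two made-up samples with three and one entries respectively.
batch = pad_batch([[0.2, 0.5, 0.1], [0.9]])
print(batch["values"])
print(batch["mask"])
```

The mask lets a downstream model ignore the padded positions.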

Once the representation has been defined, you have to add it to the dataset, just as for an already implemented representation:

In [13]:
task.dataset.add_representation(CustomRepresentation())

>>> Adding custom_representation to dataset representations.


#### 2.5. Set loaders

Since we changed the dataset by adding a representation (and maybe some additional features) to it, it is necessary to call `set_loaders` in order to update the train, val and test dataloaders.

In [14]:
# The default splitting algorithm tries to balance the labels present in our data, which can take a few minutes.
# For illustration purposes we use a RandomSplitter instead.
from rnaglib.dataset_transforms import RandomSplitter
task.splitter = RandomSplitter()
task.set_loaders()
for batch in task.train_dataloader:
    print(batch)
    break

Subsetting started...
Subsetting completed successfully.
Subsetting started...
Subsetting completed successfully.
Subsetting started...
Subsetting completed successfully.
{'graph': DataBatch(x=[75, 4], edge_index=[2, 206], edge_attr=[206], y=[75], batch=[75], ptr=[2]), 'custom_representation': [0], 'rna': [<networkx.classes.digraph.DiGraph object at 0x7c3fdbd8ae70>], 'cif_path': [None], 'custom': [0], 'graph_path': [PosixPath('my_root/dataset/1asz_0.json')]}
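
For intuition about what a random split does, here is a minimal, library-free sketch. The 70/15/15 proportions, the `seed` argument and the `random_split` function are assumptions made for illustration; they are not taken from `RandomSplitter`.

```python
import random


def random_split(n_items, percents=(70, 15, 15), seed=0):
    """Shuffle the indices 0..n_items-1 and cut them into train, val and test."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded for reproducibility
    n_train = n_items * percents[0] // 100
    n_val = n_items * percents[1] // 100
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to the test set
    return train, val, test


train, val, test = random_split(20)
print(len(train), len(val), len(test))
```

Unlike the default splitter mentioned in the comment above, such a split makes no attempt to balance the labels across the three sets.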



## 3. Creating custom tasks

#### Create a custom Task

To create a custom task, define it as a subclass of a task category (for instance ResidueClassificationClass or a subclass you have created yourself) and specify the following:

* a name
* an input variable or a list of input variables: the inputs of the model
* a target variable: the variable to be predicted by the model
* a method `get_tasks_var` specifying the FeaturesComputer to build to perform the task (in general, it will use the aforementioned target and input variables)
* a method `process` creating the dataset and, if needed, applying some preprocessing to it (especially annotation and filtering transforms)

If the task belongs to a task category other than ResidueClassificationClass (that is to say, node-level classification), you have to define a new Task subclass corresponding to this task category and specify:
* a method named `dummy_model` returning a dummy model, which lets you check that the task works without having to define a real model
* a method named `evaluate` which, given a model, outputs a dictionary containing performance metrics of this model on the task of interest.
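
As an illustration of the kind of dictionary an `evaluate` method could return, the sketch below computes a few standard metrics for binary labels in plain Python. The function name `evaluate_binary` and the choice of metrics are assumptions; an actual task may report different metrics.

```python
def evaluate_binary(y_true, y_pred):
    """Return accuracy, precision, recall and F1 for binary labels (0 or 1)."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # true positives
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # false positives
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # false negatives
    accuracy = sum(1 for t, p in pairs if t == p) / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Made-up labels and predictions for six residues.
metrics = evaluate_binary([1, 0, 1, 1, 0, 0], [1, 0, 0, 1, 1, 0])
print(metrics)
```

In a real task, `y_pred` would come from running the model on the test dataloader.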

A step-by-step tutorial is available in the online documentation.