<a href="https://colab.research.google.com/github/katarinagresova/AgoBind/blob/main/notebooks/DNABERT_for_CLASH_1_1_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [2]:
# Install AgoBind straight from GitHub. %pip (rather than !pip) guarantees the
# package lands in the environment of the running kernel.
# NOTE(review): no tag/commit is given, so this installs whatever the default
# branch currently holds; pin a revision (…/AgoBind@<ref>) for reproducible runs.
%pip install -q -U git+https://github.com/katarinagresova/AgoBind

[K     |████████████████████████████████| 3.8 MB 7.0 MB/s 
[K     |████████████████████████████████| 895 kB 35.5 MB/s 
[K     |████████████████████████████████| 6.5 MB 30.3 MB/s 
[K     |████████████████████████████████| 596 kB 26.6 MB/s 
[K     |████████████████████████████████| 67 kB 4.6 MB/s 
[?25h  Building wheel for agobind (setup.py) ... [?25l[?25hdone


In [3]:
# comet_ml is used below for experiment tracking; matplotlib is not imported
# directly in this notebook (presumably needed for plotting downstream).
# %pip (rather than !pip) installs into the running kernel's environment.
# NOTE(review): versions are unpinned; pin them for reproducible runs.
%pip install -q comet_ml matplotlib

[K     |████████████████████████████████| 342 kB 13.2 MB/s 
[K     |████████████████████████████████| 54 kB 2.4 MB/s 
[K     |████████████████████████████████| 551 kB 43.1 MB/s 
[K     |████████████████████████████████| 54 kB 2.2 MB/s 
[?25h  Building wheel for configobj (setup.py) ... [?25l[?25hdone


In [4]:
from google.colab import drive
# Attach Google Drive at /content/drive; the Comet setup section further down
# looks for the API key stored there.
drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


# Download data

In [5]:
# Fetch the train and evaluation TSV files from the miRBind repository into ./data
!wget --directory-prefix=data https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/train_set_1_1_CLASH2013_paper.tsv
!wget --directory-prefix=data https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/evaluation_set_1_1_CLASH2013_paper.tsv

--2022-03-25 14:17:08--  https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/train_set_1_1_CLASH2013_paper.tsv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2276853 (2.2M) [text/plain]
Saving to: ‘data/train_set_1_1_CLASH2013_paper.tsv’


2022-03-25 14:17:09 (64.3 MB/s) - ‘data/train_set_1_1_CLASH2013_paper.tsv’ saved [2276853/2276853]

--2022-03-25 14:17:09--  https://raw.githubusercontent.com/ML-Bioinfo-CEITEC/miRBind/main/Datasets/evaluation_set_1_1_CLASH2013_paper.tsv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 20

# Parameters

In [6]:
from agobind.models import get_dnabert

# Single place for every tunable of this run. The dict is handed to the
# backbone factory and to agobind's training helper, and is logged to Comet
# at the end of the notebook.
config = {
    # Paths of the TSV files fetched in the "Download data" section.
    "train_data": "data/train_set_1_1_CLASH2013_paper.tsv",
    "test_data": "data/evaluation_set_1_1_CLASH2013_paper.tsv",
    # Fraction of the train set that is set aside for evaluation.
    "eval_dset_ratio": 0.2,
    # Used for both the per-device train and eval batch size.
    "batch_size": 64,
    "gradient_accumulation_steps": 4,
    # Evaluation interval, in optimizer steps.
    "eval_steps": 100,
    # The next entries are consumed inside agobind (not visible here).
    "freeze": False,
    "layers_to_unfreeze": None,
    "random_weights": True,
    "kmer_len": 6,
    "stride": 1,
    "early_stopping_patience": 5,
    # Optimizer settings forwarded to TrainingArguments.
    "learning_rate": 2e-4,
    "weight_decay": 0.01,
    # Factory called as backbone(config); returns (model, tokenizer).
    "backbone": get_dnabert,
}

In [7]:
from transformers import TrainingArguments

# Hugging Face training configuration. Values that are meant to be tuned come
# from `config`; the rest are fixed for this experiment.
args = TrainingArguments(
    output_dir="output_checkpoints",
    learning_rate=config['learning_rate'],
    weight_decay=config['weight_decay'],
    # Upper bound only; early stopping (see config) is expected to end training sooner.
    num_train_epochs=500,
    per_device_train_batch_size=config['batch_size'],
    per_device_eval_batch_size=config['batch_size'],
    do_train=True,
    do_eval=True,
    logging_steps=10000,
    warmup_steps=5000,
    eval_steps=config['eval_steps'],
    evaluation_strategy="steps",
    logging_strategy="steps",
    logging_first_step=True,
    load_best_model_at_end=True,
    # load_best_model_at_end requires the save interval to be a round multiple
    # of the eval interval, so derive it from the same setting instead of
    # hard-coding 100 (which silently breaks once eval_steps is changed).
    save_steps=config['eval_steps'],
    save_total_limit=5,
    gradient_accumulation_steps=config['gradient_accumulation_steps'],
    metric_for_best_model="eval_loss",
)

# Setup comet.ml

When Google Drive is mounted, Comet uses the API key stored there. Otherwise, a prompt for the key will pop up.

In [8]:
import comet_ml

# Initialise Comet for this session; runs are grouped under this project name.
comet_ml.init(project_name="dnabert_for_clash")

COMET INFO: Comet API key is valid


# Train model

In [None]:
from agobind.training import get_trained_model

# Build the model/tokenizer pair with the factory selected in `config`.
build_backbone = config["backbone"]
model, tokenizer = build_backbone(config)

# Train; the helper also hands back the encoded test samples for the final evaluation.
trainer, encoded_samples_test = get_trained_model(
    config,
    args,
    model,
    tokenizer,
)
trained_model = trainer.model

# Final logging

In [None]:
from agobind.eval import get_f1_score, compute_pr_curve, get_probs_and_labels

# Re-open the experiment that was active during training so the test-set
# results are attached to the same Comet run.
current_experiment = comet_ml.get_global_experiment()
afterlog_experiment = comet_ml.ExistingExperiment(previous_experiment=current_experiment.get_key())

# Encode the key hyper-parameters in the experiment name.
name = f"{'CLASH2013_paper'}:{config['kmer_len']}:{config['stride']}:freeze={config['freeze']}:LR={config['learning_rate']}:WD={config['weight_decay']}:BS={config['batch_size']}:rand_weights={config['random_weights']}:"
afterlog_experiment.set_name(name)

# Evaluate the trained model on the held-out test data.
probs, labels = get_probs_and_labels(config['test_data'], encoded_samples_test, trained_model)
f1_score_test = get_f1_score(probs, labels)
recall, precision = compute_pr_curve(probs, labels)

afterlog_experiment.log_parameters(config)
afterlog_experiment.log_metric("test F1 score", f1_score_test)
afterlog_experiment.log_curve("pr-curve", recall, precision)

# Upload the checkpoint the Trainer ranked best instead of a hard-coded
# checkpoint number, which may not exist (or may not be the best one) in a
# given run. Assumes `trainer` is a transformers.Trainer — TODO confirm.
best_checkpoint = trainer.state.best_model_checkpoint
if best_checkpoint is None:
    # No checkpoint was ranked (e.g. training ended before the first
    # evaluation); persist the in-memory model so there is something to log.
    best_checkpoint = "./output_checkpoints/best_model"
    trainer.save_model(best_checkpoint)
afterlog_experiment.log_model("DNABERT_CLASH", best_checkpoint)

afterlog_experiment.end()

In [None]:
# -r recurses into the checkpoint directory; without it zip stores only the
# empty directory entry and none of the model files.
# NOTE(review): checkpoint-900 is hard-coded — confirm it exists and is the
# checkpoint you want to export for this run.
!zip -r /content/dnabert_for_clash_1_1.zip output_checkpoints/checkpoint-900/

  adding: output_checkpoints/checkpoint-900/ (stored 0%)
