# 🧬 Genesis RNA: Train RNA Foundation Model in Colab

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/oluwafemidiakhoa/genesi_ai/blob/main/genesis_rna/genesis_rna_colab_training.ipynb)

Train a transformer-based RNA foundation model with **Adaptive Sparse Training (AST)** for energy-efficient pretraining.

## Features:
- 🚀 Free GPU training (T4/V100/A100)
- ⚡ Adaptive Sparse Training (60% FLOP reduction)
- 🎯 Multi-task learning (MLM + structure + pairing)
- 📊 Real-time visualization
- 💾 Automatic checkpoint saving

## Runtime Settings:
**⚠️ IMPORTANT**: Go to `Runtime → Change runtime type → GPU (T4/V100/A100)`

## 📦 Step 1: Setup & Installation

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"\n✅ PyTorch version: {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ CUDA version: {torch.version.cuda}")
else:
    print("⚠️ No GPU detected! Go to Runtime → Change runtime type → GPU")

In [None]:
# Clone repository
!git clone https://github.com/oluwafemidiakhoa/genesi_ai.git
%cd genesi_ai/genesis_rna

In [None]:
# Install dependencies
!pip install -q transformers datasets biopython pyyaml tqdm
!pip install -q adaptive-sparse-training

print("\n✅ All dependencies installed!")

In [None]:
# Optional: Mount Google Drive to save checkpoints
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory in Drive
!mkdir -p /content/drive/MyDrive/genesis_rna_checkpoints
CHECKPOINT_DIR = "/content/drive/MyDrive/genesis_rna_checkpoints"
print(f"✅ Checkpoints will be saved to: {CHECKPOINT_DIR}")

## 📊 Step 2: Data Preparation

Choose one of the following options:

### Option A: Quick Test with Dummy Data (Fastest)

In [None]:
# Use built-in dummy data generator
USE_DUMMY_DATA = True
DATA_PATH = None

print("✅ Using dummy data for quick testing")
print("   This will generate synthetic RNA sequences")
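
For intuition, synthetic data of this kind can be as simple as random strings over the RNA alphabet. The sketch below is an assumption about what a dummy generator might look like, not the repository's actual implementation; `make_dummy_sequences` and its parameters are hypothetical names chosen for illustration.

```python
import random

RNA_ALPHABET = "ACGU"

def make_dummy_sequences(n, min_len=50, max_len=512, seed=0):
    """Return n random RNA strings with lengths in [min_len, max_len]."""
    rng = random.Random(seed)  # local RNG so results are reproducible
    sequences = []
    for _ in range(n):
        length = rng.randint(min_len, max_len)  # inclusive on both ends
        sequences.append("".join(rng.choice(RNA_ALPHABET) for _ in range(length)))
    return sequences

# Small demo: three short sequences
for s in make_dummy_sequences(3, min_len=10, max_len=20):
    print(len(s), s)
```

Random sequences like these carry no real structure, so they are useful for checking that the pipeline runs, not for judging model quality.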

### Option B: Small Real Dataset (Human ncRNAs, ~5 min download)

In [None]:
# Download human non-coding RNAs from Ensembl
!wget -q ftp://ftp.ensembl.org/pub/current_fasta/homo_sapiens/ncrna/Homo_sapiens.GRCh38.ncrna.fa.gz
!gunzip -f Homo_sapiens.GRCh38.ncrna.fa.gz

# Preprocess
!python scripts/preprocess_rna.py \
    --input Homo_sapiens.GRCh38.ncrna.fa \
    --output ./data/human_ncrna \
    --min_len 50 \
    --max_len 512 \
    --format pickle

USE_DUMMY_DATA = False
DATA_PATH = "./data/human_ncrna"

print("\n✅ Human ncRNA data ready!")
!cat ./data/human_ncrna/stats.json
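
The `--min_len` and `--max_len` flags suggest a length filter over FASTA records. The following is a minimal, standard-library sketch of such a filter, assuming plain FASTA input; the real `scripts/preprocess_rna.py` may do more (or work differently), and the helper names `read_fasta` and `filter_by_length` are hypothetical. The T to U conversion is also an assumption, included because FASTA downloads are often written in the DNA alphabet.

```python
def read_fasta(lines):
    """Yield (header, sequence) pairs from an iterable of FASTA lines."""
    header, chunks = None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []
        else:
            chunks.append(line)
    if header is not None:
        yield header, "".join(chunks)

def filter_by_length(records, min_len=50, max_len=512):
    """Keep records whose length is in [min_len, max_len]; map T -> U."""
    kept = []
    for header, seq in records:
        seq = seq.upper().replace("T", "U")
        if min_len <= len(seq) <= max_len:
            kept.append((header, seq))
    return kept

# Toy input: one record split over two lines, one record that is too short
example = [">r1", "ACGT", "ACGT", ">r2", "AC"]
print(filter_by_length(read_fasta(example), min_len=4, max_len=10))
```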

### Option C: Large Dataset (RNAcentral, ~30 min download, 15GB)

In [None]:
# Download RNAcentral (WARNING: Large file!)
!mkdir -p data/rnacentral
!wget -c ftp://ftp.ebi.ac.uk/pub/databases/RNAcentral/current_release/sequences/rnacentral_active.fasta.gz \
    -O data/rnacentral/rnacentral_active.fasta.gz
!gunzip -f data/rnacentral/rnacentral_active.fasta.gz

# Preprocess (limit to 1M sequences to fit in Colab)
!python scripts/preprocess_rna.py \
    --input data/rnacentral/rnacentral_active.fasta \
    --output ./data/rnacentral_processed \
    --min_len 50 \
    --max_len 512 \
    --max_sequences 1000000 \
    --format pickle

USE_DUMMY_DATA = False
DATA_PATH = "./data/rnacentral_processed"

print("\n✅ RNAcentral data ready!")
!cat ./data/rnacentral_processed/stats.json

## ⚙️ Step 3: Training Configuration

In [None]:
# Training hyperparameters
CONFIG = {
    # Model size: 'small', 'base', or 'large'
    'model_size': 'small',  # Use 'small' for Colab free tier
    
    # Training settings
    'batch_size': 16,       # Adjust based on GPU memory
    'num_epochs': 5,        # Increase for better performance
    'learning_rate': 1e-4,
    
    # AST settings (energy-efficient training)
    'use_ast': True,
    'ast_target_activation': 0.4,  # Train on 40% of samples
    
    # Output
    # Use the Drive directory if the optional mount cell was run, else a local folder
    'output_dir': CHECKPOINT_DIR if 'CHECKPOINT_DIR' in globals() else './checkpoints',
}

print("📋 Training Configuration:")
for key, value in CONFIG.items():
    print(f"   {key}: {value}")
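
To make `ast_target_activation` concrete: a value of 0.4 means roughly 40% of the samples in a batch contribute to each update. One simple way to picture this is to keep the highest-loss fraction of every batch. The sketch below illustrates that idea only; it is an assumption for intuition, not the selection rule used by the `adaptive-sparse-training` package, and `select_active_samples` is a hypothetical helper.

```python
import math

def select_active_samples(losses, target_activation=0.4):
    """Return indices of the highest-loss samples, about target_activation of the batch."""
    if not 0.0 < target_activation <= 1.0:
        raise ValueError("target_activation must be in (0, 1]")
    # Number of samples to keep (at least one)
    k = max(1, math.ceil(target_activation * len(losses)))
    # Rank sample indices by loss, largest first
    ranked = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)
    return sorted(ranked[:k])

batch_losses = [0.2, 1.5, 0.9, 0.1, 2.3, 0.4, 0.7, 1.1, 0.3, 0.6]
active = select_active_samples(batch_losses, target_activation=0.4)
print(active, len(active) / len(batch_losses))
```

Skipping the backward pass for the remaining samples is where the compute savings would come from under this picture.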

## 🚀 Step 4: Train the Model!

In [None]:
# Build the training command as a list of arguments, then join it into one line.
# (In a regular triple-quoted string, backslash-newline is swallowed by Python and
# the trailing newline would push any appended flags onto a separate shell line.)
import os
import shlex

args = [
    "python -m genesis_rna.train_pretrain",
    f"--model_size {CONFIG['model_size']}",
    f"--batch_size {CONFIG['batch_size']}",
    f"--num_epochs {CONFIG['num_epochs']}",
    f"--learning_rate {CONFIG['learning_rate']}",
    f"--ast_target_activation {CONFIG['ast_target_activation']}",
    f"--output_dir {shlex.quote(str(CONFIG['output_dir']))}",
]

if USE_DUMMY_DATA:
    args.append("--use_dummy_data")
else:
    args.append(f"--data_path {shlex.quote(str(DATA_PATH))}")

if CONFIG['use_ast']:
    args.append("--use_ast")

cmd = " ".join(args)

print("🚀 Starting training...\n")
print(f"Command: {cmd}\n")

# Add the current directory to PYTHONPATH so Python can find the genesis_rna package
os.environ['PYTHONPATH'] = os.pathsep.join(
    p for p in (os.getcwd(), os.environ.get('PYTHONPATH', '')) if p
)

!{cmd}
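
Outside a notebook, or when paths may contain spaces, an alternative to a shell string is an argument list passed to `subprocess.run`, which sidesteps shell quoting altogether. The sketch below only assembles the list; `build_train_argv` is a hypothetical helper, and the flags are the ones used in the cell above.

```python
import sys

def build_train_argv(config, use_dummy_data, data_path=None):
    """Assemble the training command as a list of arguments (no shell quoting needed)."""
    argv = [
        sys.executable, "-m", "genesis_rna.train_pretrain",
        "--model_size", str(config["model_size"]),
        "--batch_size", str(config["batch_size"]),
        "--num_epochs", str(config["num_epochs"]),
        "--learning_rate", str(config["learning_rate"]),
        "--ast_target_activation", str(config["ast_target_activation"]),
        "--output_dir", str(config["output_dir"]),
    ]
    if use_dummy_data:
        argv.append("--use_dummy_data")
    else:
        if data_path is None:
            raise ValueError("data_path is required when use_dummy_data is False")
        argv += ["--data_path", str(data_path)]
    if config.get("use_ast"):
        argv.append("--use_ast")
    return argv

example_config = {
    "model_size": "small", "batch_size": 16, "num_epochs": 5,
    "learning_rate": 1e-4, "use_ast": True, "ast_target_activation": 0.4,
    "output_dir": "./checkpoints",
}
print(build_train_argv(example_config, use_dummy_data=True))
# To launch: import subprocess; subprocess.run(argv, check=True)
```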

## 📊 Step 5: Monitor Training Progress

In [None]:
# Visualize training metrics (if training is complete)
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Check for training logs
log_file = Path(CONFIG['output_dir']) / 'training_log.json'

if log_file.exists():
    with open(log_file) as f:
        logs = json.load(f)
    
    # Plot loss curves
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Total loss
    axes[0, 0].plot(logs['epochs'], logs['train_loss'], label='Train')
    axes[0, 0].plot(logs['epochs'], logs['val_loss'], label='Val')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Total Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    # MLM accuracy
    axes[0, 1].plot(logs['epochs'], logs['mlm_accuracy'])
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].set_title('MLM Accuracy')
    axes[0, 1].grid(True)
    
    # AST activation rate
    axes[1, 0].plot(logs['epochs'], logs['activation_rate'])
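    # Draw the target from CONFIG so the line stays correct if the setting changes
    axes[1, 0].axhline(y=CONFIG['ast_target_activation'], color='r', linestyle='--', label='Target')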
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Activation Rate')
    axes[1, 0].set_title('AST Sample Selection Rate')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    # Pair F1 score
    axes[1, 1].plot(logs['epochs'], logs['pair_f1'])
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('F1 Score')
    axes[1, 1].set_title('Base-Pair Prediction F1')
    axes[1, 1].grid(True)
    
    plt.tight_layout()
    plt.savefig(f"{CONFIG['output_dir']}/training_curves.png", dpi=150)
    plt.show()
    
    print(f"\n✅ Training metrics plotted and saved to {CONFIG['output_dir']}/training_curves.png")
else:
    print("⚠️ No training logs found yet. Run training first!")
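
The plotting cell above assumes `training_log.json` holds one list per metric, all the same length as `epochs`. A quick check like the sketch below can turn a `KeyError` in the middle of plotting into a readable message. The key names are taken from the plotting code; `check_training_log` itself is a hypothetical helper, and the actual log written by the trainer may contain other fields.

```python
import json

EXPECTED_KEYS = [
    "epochs", "train_loss", "val_loss",
    "mlm_accuracy", "activation_rate", "pair_f1",
]

def check_training_log(logs):
    """Return a list of human-readable problems; an empty list means the log looks plottable."""
    problems = [f"missing key: {key}" for key in EXPECTED_KEYS if key not in logs]
    if "epochs" in logs:
        n = len(logs["epochs"])
        # Every metric present should have one value per epoch
        for key in EXPECTED_KEYS[1:]:
            if key in logs and len(logs[key]) != n:
                problems.append(f"{key} has {len(logs[key])} values, expected {n}")
    return problems

sample = json.loads('{"epochs": [1, 2], "train_loss": [2.1, 1.7], "val_loss": [2.2]}')
for problem in check_training_log(sample):
    print(problem)
```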

## 🔬 Step 6: Test the Trained Model

In [None]:
# Load the trained model
import sys
sys.path.insert(0, '/content/genesi_ai/genesis_rna')

import torch
from genesis_rna.model import GenesisRNAModel
from genesis_rna.tokenization import RNATokenizer

# Load model checkpoint (falls back to CPU when no GPU is available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_path = f"{CONFIG['output_dir']}/best_model.pt"
model = GenesisRNAModel.from_pretrained(model_path, device=device)
model.eval()

# Create tokenizer
tokenizer = RNATokenizer()

print(f"✅ Model loaded from {model_path}")
print(f"   Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Test on example sequences
test_sequences = [
    "ACGUACGUACGUACGU",
    "GGCCGGCCGGCCGGCC",
    "UUAAUUAAUUAAUUAA",
]

print("🧬 Testing model on example sequences:\n")

for seq in test_sequences:
    # Tokenize and move the input to the same device as the model
    input_ids = tokenizer.encode(seq, max_len=128).unsqueeze(0).to(next(model.parameters()).device)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(input_ids)
    
    # Decode MLM predictions
    mlm_preds = outputs['mlm_logits'].argmax(dim=-1)[0]
    predicted_seq = tokenizer.decode(mlm_preds)
    
    # Get structure predictions
    struct_preds = outputs['struct_logits'].argmax(dim=-1)[0]
    struct_labels = ['NONE', 'STEM', 'LOOP', 'BULGE', 'HAIRPIN']
    struct_pred_str = ' '.join([struct_labels[s] for s in struct_preds[1:len(seq)+1].cpu().numpy()])
    
    print(f"Input:     {seq}")
    print(f"Predicted: {predicted_seq[:len(seq)]}")
    print(f"Structure: {struct_pred_str}")
    print()

print("✅ Inference test complete!")
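
To put a number on how closely `Predicted` matches `Input`, a per-position identity score is a simple option. The helper below is a hypothetical sketch, not part of the `genesis_rna` package. Because the inputs in this cell are not masked, a high score here is best read as a sanity check and not as a measure of masked-token accuracy.

```python
def sequence_identity(reference, predicted):
    """Fraction of reference positions where the predicted character matches."""
    if not reference:
        raise ValueError("reference must be non-empty")
    # zip stops at the shorter string, so missing predictions count as mismatches
    matches = sum(1 for a, b in zip(reference, predicted) if a == b)
    return matches / len(reference)

print(sequence_identity("ACGUACGU", "ACGUACGA"))
```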

## 📊 Step 7: Visualize Predictions

In [None]:
# Visualize base-pair predictions
import matplotlib.pyplot as plt
import numpy as np

# Pick a sequence
seq = "GCGCAAACGCGC"  # Simple hairpin
# Move the input to the same device as the model
input_ids = tokenizer.encode(seq, max_len=64).unsqueeze(0).to(next(model.parameters()).device)

with torch.no_grad():
    outputs = model(input_ids)

# Get pair predictions
pair_logits = outputs['pair_logits'][0].cpu().numpy()
pair_probs = 1 / (1 + np.exp(-pair_logits))  # Sigmoid

# Plot heatmap
plt.figure(figsize=(10, 8))
plt.imshow(pair_probs[:len(seq)+2, :len(seq)+2], cmap='Blues', interpolation='nearest')
plt.colorbar(label='Pairing Probability')
plt.title(f'Predicted Base-Pair Matrix\nSequence: {seq}')
plt.xlabel('Position')
plt.ylabel('Position')

# Add sequence labels
labels = ['[CLS]'] + list(seq) + ['[SEP]']
plt.xticks(range(len(seq)+2), labels, rotation=90)
plt.yticks(range(len(seq)+2), labels)

plt.tight_layout()
plt.show()

print("✅ Base-pair prediction heatmap generated!")
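
Beyond the heatmap, it can help to list the pairs the matrix implies. The sketch below thresholds a probability matrix and keeps only canonical or wobble pairs separated by a minimum loop length. It assumes the matrix has already been cropped to the sequence positions (special-token rows and columns removed) and uses toy logits in place of real model output; `extract_pairs`, the 0.5 threshold, and the minimum loop of 3 are illustrative choices, not settings from the repository.

```python
import math

CANONICAL_PAIRS = {
    ("A", "U"), ("U", "A"), ("G", "C"), ("C", "G"), ("G", "U"), ("U", "G"),
}

def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def extract_pairs(seq, probs, threshold=0.5, min_loop=3):
    """List (i, j, p) with i < j, p >= threshold, pairable bases, and >= min_loop bases between."""
    pairs = []
    for i in range(len(seq)):
        for j in range(i + min_loop + 1, len(seq)):
            p = probs[i][j]
            if p >= threshold and (seq[i], seq[j]) in CANONICAL_PAIRS:
                pairs.append((i, j, p))
    return pairs

# Toy example: an idealized hairpin with a four-pair stem
seq = "GCGCAAACGCGC"
stem = {(0, 11), (1, 10), (2, 9), (3, 8)}
toy_logits = [
    [4.0 if (min(i, j), max(i, j)) in stem else -4.0 for j in range(len(seq))]
    for i in range(len(seq))
]
toy_probs = [[sigmoid(x) for x in row] for row in toy_logits]

for i, j, p in extract_pairs(seq, toy_probs):
    print(f"{seq[i]}{i} - {seq[j]}{j}: {p:.2f}")
```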

## 💾 Step 8: Download Trained Model

In [None]:
# Zip checkpoint directory
!zip -r genesis_rna_model.zip {CONFIG['output_dir']}

# Download
from google.colab import files
files.download('genesis_rna_model.zip')

print("✅ Model checkpoint downloaded as genesis_rna_model.zip")
print("\n📁 Checkpoint contents:")
!ls -lh {CONFIG['output_dir']}

## 🎯 Next Steps

### Continue Training
- Increase `num_epochs` for better performance
- Try `model_size='base'` for higher capacity (if you have Colab Pro)
- Use RNAcentral for full-scale pretraining

### Fine-Tuning
- Mutation effect prediction
- RNA-protein binding
- mRNA optimization

### Evaluation
- Test on benchmark datasets
- Compare with RiNALMo/RNA-FM
- Ablation studies (with/without AST)

### Deploy
- Export to ONNX for inference
- Build REST API
- Create web interface

---

## 📚 Resources

- **GitHub**: https://github.com/oluwafemidiakhoa/genesi_ai
- **Documentation**: `genesis_rna/claude/genesis_rna_design_doc.md`
- **Paper**: (Link to your paper when published)

---

**Built with ❤️ for RNA research | Powered by Adaptive Sparse Training**