## Model Architecture

In this section you will **implement the convolutional neural network** that scores RNA sequences (encoded as one‑hot tensors) for binding. The scaffold below contains a `GlobalCNN` class with a placeholder (`TODO-1`) where you must fill in the architecture.
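For reference, here is a minimal sketch of how a single RNA sequence could be turned into the `(1, 4, L)` one-hot matrix the model expects. The helper name `one_hot_encode`, the right-padding with all-zero columns, mapping DNA `T` to `U`, and leaving unknown symbols (e.g. `N`) as all-zero columns are assumptions for illustration. The course's own encoder in `utils.preprocess_fa` may behave differently, and its row order is the one that counts.

```python
import numpy as np

NUCLEOTIDES = "ACGU"  # row order; must match the encoder used for training
NUC_TO_ROW = {nt: i for i, nt in enumerate(NUCLEOTIDES)}

def one_hot_encode(seq: str, length: int) -> np.ndarray:
    """Hypothetical helper: encode `seq` as a (1, 4, length) float32 array.

    Sequences longer than `length` are truncated; shorter ones are
    right-padded with all-zero columns. T is mapped to U, and any other
    symbol (e.g. N) is left as an all-zero column.
    """
    mat = np.zeros((1, 4, length), dtype=np.float32)
    for pos, nt in enumerate(seq.upper()[:length]):
        nt = "U" if nt == "T" else nt
        row = NUC_TO_ROW.get(nt)
        if row is not None:
            mat[0, row, pos] = 1.0
    return mat

# Stack several encoded sequences into a batch of shape (N, 1, 4, L)
batch = np.stack([one_hot_encode(s, length=8) for s in ["ACGU", "GGNUAC"]])
print(batch.shape)  # (2, 1, 4, 8)
```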

The function we will be modeling is:

$$f(\text{sequence}) = P(\text{binding} \mid \text{sequence})$$
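Because the network returns an unbounded score rather than a probability, it may help to write the link out explicitly. Writing $z$ for the model's raw output (the logit) and $y \in \{0, 1\}$ for the label, the probability and the per-sequence loss that `BCEWithLogitsLoss` minimizes (averaged over the batch by default) are:

$$
f(\text{sequence}) = \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
\ell(z, y) = -\bigl[\, y \log \sigma(z) + (1 - y) \log\bigl(1 - \sigma(z)\bigr) \bigr]
$$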

### Your Goals

1. **Understand the expected input shape**: `(N, 1, 4, L)`  
   - `N` = batch size  
   - First dimension after batch is a dummy channel (set to 1)  
   - Height = 4 rows (A, C, G, U order **must** match your encoder)  
   - Width = fixed sequence length after padding/truncation (e.g. `L = 501`)

2. **Design the network layers.** You can use:
   - **Conv layers**: decide on kernel sizes and other parameters.
   - **ReLU** activations.
   - **MaxPool** layers along the width only, to reduce the sequence length.
   - **Flatten** → Fully connected → ReLU → Output linear layer (a single value, the logit, whose sigmoid is the binding probability).

3. **Do NOT add a final sigmoid layer** inside the model.  
   - During training you will use `BCEWithLogitsLoss`, which internally applies a numerically stable sigmoid.  
   - During evaluation or for thresholding, apply `torch.sigmoid(y_hat)` to the outputs to obtain probabilities.
   - The final layer should be `nn.Linear(..., 1)`; replace `...` with the number of flattened features produced by your architecture.
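As one possible starting point (not the required solution), the sketch below follows the layout described in the `GlobalCNN` docstring: a first convolution spanning all 4 nucleotide rows, width-only pooling, then two linear layers ending in a single logit with no sigmoid. The class name `ExampleCNN`, the filter counts (32, 64), kernel widths (10, 5), pooling width 3 and hidden size 64 are arbitrary assumptions. The flattened size is computed with a dummy forward pass, so the `...` in `nn.Linear(..., 1)` does not have to be worked out by hand.

```python
import torch
from torch import nn

class ExampleCNN(nn.Module):
    """Hypothetical CNN sketch; all layer sizes are illustrative assumptions."""

    def __init__(self, seq_len: int = 501):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(4, 10)),  # (N,32,1,L-9): motif detectors over all 4 rows
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 3)),       # pool along width only (stride defaults to 3)
            nn.Conv2d(32, 64, kernel_size=(1, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(1, 3)),
            nn.Flatten(),
        )
        # Infer the flattened feature size from a dummy input of the expected shape
        with torch.no_grad():
            n_flat = self.features(torch.zeros(1, 1, 4, seq_len)).shape[1]
        self.classifier = nn.Sequential(
            nn.Linear(n_flat, 64),
            nn.ReLU(),
            nn.Linear(64, 1),  # single logit, no sigmoid
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (N,1,4,L) -> (N,1) -> (N,)
        return self.classifier(self.features(x)).squeeze(1)

demo_model = ExampleCNN(seq_len=501)
out = demo_model(torch.randn(4, 1, 4, 501))
print(out.shape)  # torch.Size([4])
```

Computing `n_flat` from a dummy input keeps the linear layer in sync when kernel sizes or `seq_len` change; `nn.LazyLinear`, which infers `in_features` on the first forward pass, is an alternative.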

If you encounter issues, **verify the shapes**:
   - After implementing, create a fake batch:
     ```python
     x = torch.randn(4, 1, 4, L)  # L = your sequence length
     out = model(x)
     print(out.shape)  # should be torch.Size([4])
     ```
   - If this fails, print intermediate shapes by temporarily breaking up `self.net` or inserting debug `forward` code.
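When debugging shapes, it can also help to compute the expected width by hand. For a convolution or pooling layer without padding or dilation, the output width is $\lfloor (W - k)/s \rfloor + 1$, where $k$ is the kernel width and $s$ the stride (for `nn.MaxPool2d` the stride defaults to the kernel size). The helpers below, with hypothetical names `out_width` and `flat_features`, apply this formula to a list of `(kernel, stride)` pairs; the layer list in the example is an arbitrary assumption, not the required architecture.

```python
def out_width(width: int, kernel: int, stride: int = 1) -> int:
    """Output width of an unpadded, undilated conv or pooling layer."""
    if width < kernel:
        raise ValueError(f"kernel {kernel} is wider than input width {width}")
    return (width - kernel) // stride + 1

def flat_features(seq_len: int, layers, out_channels: int) -> int:
    """Final width times the last layer's channel count.

    Assumes the first conv already collapsed the height (4 -> 1),
    so the flattened size is channels * final width.
    """
    width = seq_len
    for kernel, stride in layers:
        width = out_width(width, kernel, stride)
    return out_channels * width

# Example stack (assumed): conv k=10, pool k=3 (stride 3), conv k=5, pool k=3 (stride 3)
layers = [(10, 1), (3, 3), (5, 1), (3, 3)]
print(flat_features(501, layers, out_channels=64))  # 3392
```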

In [None]:
import torch, numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import roc_auc_score

class GlobalCNN(nn.Module):
    """
    GlobalCNN
    =========
    A compact convolutional neural network for RNA (or DNA) binary
    binding-site classification. It treats each sequence as a single
    "image" with 1 channel, height 4 (nucleotide channels A,C,G,U),
    and width L (sequence length), and outputs a single value per
    sequence (before sigmoid).

    Expected Input
    --------------
    x : array of shape (N, 1, 4, L)
        - N = batch size
        - 1 = dummy channel dimension (nucleotide identity is already encoded along the 4 rows)
        - 4 = nucleotide axis (order must match your one-hot: A,C,G,U)
        - L = fixed padded sequence length (e.g. 501)

   
    Output
    ------
    logit : array of shape (N,)
        Raw (unnormalized) scores. Use torch.sigmoid() to
        convert to probabilities in (0,1). Use BCEWithLogitsLoss
        during training (so keep the final layer linear).

    Notes
    -----
    * No padding is used: width shrinks at each conv/pool step.
    * Height is collapsed to 1 after the first 4×10 convolution,
      effectively treating filters as motif detectors spanning all
      nucleotide rows simultaneously.
    """

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # TODO-1: Network architecture
            # === Your code here ===

            nn.Linear(..., 1)  # logit
            # === End of your code ===
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Parameters
        ----------
        x : FloatTensor (N,1,4,L)
            One-hot encoded batch.

        Returns
        -------
        logits : FloatTensor (N,)
            Raw scores (before sigmoid).
        """
        return self.net(x).squeeze(1)

## Training
Now, we define a function to train the model.
Note that the model outputs a real-valued logit for each sequence (a vector $\mathbf{z} \in \mathbb{R}^N$ for a batch of $N$ sequences), but we want a probability. To obtain one, we apply the sigmoid function, which maps any real number to a value between 0 and 1.
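To see why the sigmoid is left out of the model and folded into the loss, compare a naive sigmoid-then-log computation with the algebraically equivalent form $\ell(z, y) = \max(z, 0) - z y + \log(1 + e^{-|z|})$, which never takes the log of zero. The NumPy sketch below only illustrates that identity (the function names are hypothetical); `BCEWithLogitsLoss` has its own internal implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def naive_bce(z, y):
    """Sigmoid first, then log: breaks down once sigmoid(z) rounds to 0 or 1."""
    p = sigmoid(z)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def stable_bce_with_logits(z, y):
    """Same loss, rearranged so the log argument is always in [1, 2]."""
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))

z = np.array([-2.0, 0.0, 3.0, 40.0])  # logits
y = np.array([0.0, 1.0, 1.0, 0.0])    # labels
with np.errstate(divide="ignore"):
    print(naive_bce(z, y))            # last entry is inf: sigmoid(40.0) rounds to 1.0
print(stable_bce_with_logits(z, y))   # last entry is about 40, the true loss
```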

The skeleton code below is only an example; feel free to modify it to produce any plots or information you want to include in your report.

Notes:

* Place the dataset files in the same directory as this notebook
* For validation, use `roc_auc_score` from `sklearn.metrics`
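A minimal sketch of the validation step, assuming the loader yields `(X, y)` pairs with 0/1 labels as in the training loop: collect the sigmoid probabilities and labels from every batch, concatenate them, and compute the ROC AUC once over the whole validation set (averaging per-batch AUCs generally gives a different number). The function name `evaluate_auc` is hypothetical.

```python
import numpy as np
import torch
from sklearn.metrics import roc_auc_score

def evaluate_auc(model, loader, device="cpu") -> float:
    """Hypothetical helper: ROC AUC of `model` over every batch in `loader`."""
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for X, y in loader:
            # logits -> probabilities; move to CPU before converting to NumPy
            probs = torch.sigmoid(model(X.to(device))).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(y.cpu().numpy())
    return roc_auc_score(np.concatenate(all_labels), np.concatenate(all_probs))
```

Inside `train`, one could call `val_auc = evaluate_auc(model, val_loader, DEVICE)` at the end of each epoch and print or store it for plotting; since `model.train()` runs at the start of every epoch, switching to eval mode here is safe.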

In [None]:
from utils.preprocess_fa import load_dataset, MAX_LEN, DEVICE

def train(model, train_loader, val_loader, epochs=30, lr=1e-3, wd=1e-4):
    model.to(DEVICE)
    opt  = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    criterion = nn.BCEWithLogitsLoss()

    for ep in range(1, epochs+1):
        # --- train ---
        model.train()
        for X,y in train_loader:
            X,y = X.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            y_hat = model(X)
            loss = criterion(y_hat, y)
            loss.backward()
            opt.step()

        # --- validate ---
        
        # TODO-2: Evaluate the model using the validation set
        model.eval()
        y_hat, labels = [], []
        with torch.no_grad():
            for X,y in val_loader: 
                probs = torch.sigmoid(model(X.to(DEVICE))).cpu().numpy()
                # === Your code here ===

        # === End of your code ===


# Load data and convert to 1-hot encoded matrices
train_pos = "./CLIPSEQ_AGO2.train.positives.fa"
train_neg = "./CLIPSEQ_AGO2.train.negatives.fa"
val_pos   = "./CLIPSEQ_AGO2.ls.positives.fa"
val_neg   = "./CLIPSEQ_AGO2.ls.negatives.fa"

X_train, y_train = load_dataset(train_pos, train_neg)
X_validate, y_validate = load_dataset(val_pos,   val_neg)

train_ds = TensorDataset(X_train, y_train)
val_ds   = TensorDataset(X_validate, y_validate)

train_loader = DataLoader(train_ds, batch_size=256, shuffle=True,  drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)

model = GlobalCNN()

# Train model
train(model, train_loader, val_loader, epochs=30)

# Motif Prediction

In this step, we generate a large set of random synthetic sequences to serve as a baseline for model scoring. Using the provided `get_random_seqs` function, we create both the raw sequences and their one-hot encoded representations (`one_hot`). We then evaluate these sequences in batches with our trained model, collecting their binding **probabilities**.

Complete the code below: store the binding probability of each random sequence in a NumPy array called `scores`, with one entry per sequence.
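One way to approach TODO-3, sketched under assumptions: `one_hot` is taken to be indexable along its first axis with shape `(N, 4, L)` or `(N, 1, 4, L)`. The exact format returned by `get_random_seqs` is not shown here, so check it first. The helper `score_in_batches` (a hypothetical name) feeds fixed-size slices through the model and converts logits to probabilities with `torch.sigmoid`, so memory use stays bounded even for 100,000 sequences.

```python
import numpy as np
import torch

def score_in_batches(model, one_hot, batch_size=1024, device="cpu") -> np.ndarray:
    """Hypothetical helper: binding probability for every sequence in `one_hot`."""
    model.eval()
    scores = []
    with torch.no_grad():
        for start in range(0, len(one_hot), batch_size):
            batch = torch.as_tensor(np.asarray(one_hot[start:start + batch_size]),
                                    dtype=torch.float32)
            if batch.dim() == 3:  # (B, 4, L) -> (B, 1, 4, L): add the dummy channel
                batch = batch.unsqueeze(1)
            logits = model(batch.to(device))
            scores.append(torch.sigmoid(logits).cpu().numpy())
    return np.concatenate(scores)  # shape (N,)
```

In the cell below, `scores = score_in_batches(model, one_hot, device=DEVICE)` would satisfy the length assertion, provided `one_hot` has one entry per sequence along its first axis.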

In [None]:
import torch, numpy as np

import importlib, utils.random_seqs
importlib.reload(utils.random_seqs)
from utils.random_seqs import get_random_seqs

# Generate random sequences
N_RANDOM = 100000  # how many random synthetic sequences to sample
synthetic_seqs, one_hot = get_random_seqs(N_RANDOM)
model.eval()

# Score with model
scores = []
with torch.no_grad():
    # TODO-3 : Score the random sequences in batches
    # === Your code here ===
    ...
    # === End of your code ===


assert (scores.shape[0] if hasattr(scores, "shape") else len(scores)) == N_RANDOM

## Save high scoring sequences to file
After scoring the random sequences, we filter for those with predicted binding scores above a defined threshold (`BINDING_SCORE`). This identifies sequences the model considers strong binders. From each high-scoring sequence, we extract a fixed-length central window (`WINDOW_LEN`) under the arbitrary assumption that this is the relevant region for motif discovery. These positive windows are saved to a text file, which can later be uploaded to tools such as [WebLogo](https://weblogo.berkeley.edu/logo.cgi) to visualize conserved sequence motifs.

**Note**: You might need to adjust `BINDING_SCORE`, the binding-probability threshold that decides whether a sequence is included in the PWM. Choose a value that selects only the top 10 to 50 sequences with the highest binding probability.
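Instead of guessing a value, the threshold can be derived from the scores themselves. A minimal sketch, assuming `scores` is a 1-D array of probabilities: take the k-th largest score as the cut-off, so that `scores >= threshold` keeps the top k sequences (a few more if there are ties at the cut-off). `TOP_K` and `threshold_for_top_k` are hypothetical names, `demo_scores` is random stand-in data, and k = 30 is just one value in the suggested range.

```python
import numpy as np

def threshold_for_top_k(scores, k: int) -> float:
    """Score of the k-th highest-scoring sequence (ties may admit a few more)."""
    scores = np.asarray(scores, dtype=float).ravel()
    k = min(k, scores.size)
    # np.partition places the k-th largest value at index -k without a full sort
    return float(np.partition(scores, -k)[-k])

rng = np.random.default_rng(0)
demo_scores = rng.random(1000)  # stand-in for real model scores
TOP_K = 30
threshold = threshold_for_top_k(demo_scores, TOP_K)
print(threshold, (demo_scores >= threshold).sum())
```

With real scores, `BINDING_SCORE = threshold_for_top_k(scores, TOP_K)` would replace the hard-coded value.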

In [None]:
BINDING_SCORE    = 0.9
WINDOW_LEN       = 100 # central window length to export
scores = np.asarray(scores) if isinstance(scores, list) else scores

# Extract positive windows
start = (MAX_LEN - WINDOW_LEN) // 2
end   = start + WINDOW_LEN
pos_idx = np.where(scores >= BINDING_SCORE)[0]
pos_windows = [synthetic_seqs[i][start:end] for i in pos_idx]
print(f"Positive threshold: {BINDING_SCORE:.4f}  (selected {len(pos_idx)})")

# Save
pos_file = "positive.txt"
with open(pos_file, "w") as f:
    for w in pos_windows:
        f.write(w + "\n")

print(f"Saved {len(pos_windows)} positive windows to: {pos_file}")
print("Example positive windows (first 5):", pos_windows[:5])
print(f"Done. Upload {pos_file} to WebLogo to visualize the motif.")