# Ribonanza - Attempt 3

A second approach to the [Stanford Ribonanza problem](https://www.kaggle.com/competitions/stanford-ribonanza-rna-folding/) that builds off the first and second approaches.

Major differences:
- use of pytorch instead of tensorflow
- use of attention model architecture
- use bpp

Currently, the attention architecture scores 0.182

## Todo

- experiment with other types of attention
- try using LinearFold instead of EternaFold
- try using ThreshKnot instead of EternaFold

## Setup

### Filesystem Setup

Your project directory should look like this:

- `(project directory)`
    - `ribonanza2.ipynb`
    - `train_data.csv`
    - `test_data.csv` (optional)

`train_data.csv` is the only file necessary for training, and it can be downloaded from the kaggle competition linked in the description.

`test_data.csv` is only necessary if you intend to make and submit predictions.

### Dependency Seetup

Need to install pip packages:
```sh
pip install torch numpy seaborn xformers arnie datasets tensorboard
```
Need to install conda packages for eternafold:
```sh
conda install -c conda-forge "libgcc-ng>=12" "libstdcxx-ng>=12"
conda install -c bioconda eternafold
```

### Code Setup

In [1]:
# imports
import torch
from torch.utils.tensorboard.writer import SummaryWriter
import torch.utils.data as data
import numpy as np
from tqdm import tqdm
from datasets import Dataset
import os

# for visualization
import seaborn

# typing hints
from typing import List
from collections.abc import Callable

# used for better attention mechanisms
import xformers.components.positional_embedding as embeddings
import xformers.ops as xops
import xformers.components.attention as attentions
import xformers.components.attention.utils as att_utils
import xformers.components as components

# used for bpps
from arnie.bpps import bpps

In [2]:
# constants

# according to kaggle, this is the maximum # of reactivites to be used
NUM_REACTIVITIES = 457

# there are 4 different bases (AUCG)
NUM_BASES = 4

In [3]:
# if no gpu available, use cpu
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data Preprocessing

### Filter Data

In [4]:
def filter_data(out: str, key: str, value: str, file_name: str, force: bool):
    """
    Filters a file to only take datapoints
    whose values of `key` are `value`.

    Parameters:
        - out: str - the name of the file that will store the filtered datapoints
        - key: str - the name of the key to look at
        - value: str - the value that the key should have
        - file_name: str - the name of the file that contains all the datapoints.
        - force: bool - whether or not to force re-processing of the data (if False and `out` already exists, no work will be done)
    """
    if os.path.exists(out) and not force:
        print("File already exists, not doing any work")
        return

    count = 0

    # count how many lines we have in total
    with open(file_name) as file:
        line = file.readline()  # ignore the header
        line = (
            file.readline()
        )  # take the first line since we increment count in the loop
        while line != "":
            count += 1
            line = file.readline()

    # use that knowledge for a progress bar
    with open(file_name, "r") as file, open(out, "w") as outfile:
        # write the header
        header = file.readline()
        outfile.write(header)

        # get what index the SN_filter is
        SN_idx = header.split(",").index(key)

        # only take the approved filtered lines
        for _ in tqdm(range(count)):
            line = file.readline()
            temp = line.split(",")
            if temp[SN_idx] == value:
                outfile.write(line)


def filter_train_data(force: bool = False):
    """
    Filters the immense train_data.csv to only take datapoints
    whose SN_filter (Signal to Noise filter) is 1. In other words,
    we only take good reads. These filtered datapoints are then
    written to the file provided

    Parameters:
        - force: bool - whether or not to force re-processing of the data (if False and `out` already exists, no work will be done)
    """
    filter_data("train_data_filtered.csv", "SN_filter", "1", "train_data.csv", force)


def filter_2A3(force: bool = False):
    """
    Only take the 2A3 points

    Parameters:
        - force: bool - whether or not to force re-processing of the data (if False and `out` already exists, no work will be done)
    """
    filter_data(
        "train_data_2a3.csv",
        "experiment_type",
        "2A3_MaP",
        "train_data_filtered.csv",
        force,
    )


def filter_DMS(force: bool = False):
    """
    Only take the DMS points

    Parameters:
        - force: bool - whether or not to force re-processing of the data (if False and `out` already exists, no work will be done)
    """
    filter_data(
        "train_data_dms.csv",
        "experiment_type",
        "DMS_MaP",
        "train_data_filtered.csv",
        force,
    )

In [5]:
# filter our data
filter_train_data()

File already exists, not doing any work


In [6]:
# take the 2a3 points
filter_2A3()

File already exists, not doing any work


In [7]:
# take the dms points
filter_DMS()

File already exists, not doing any work


### Convert Data to Inputs and Outputs

In [8]:
# encode inputs as
# A : 1
# U : 2
# C : 3
# G : 4
base_map = {
    "A": 1,
    "U": 2,
    "C": 3,
    "G": 4,
}

In [9]:
def process_data(row):
    """
    Convert a row containing all csv columns in the original dataset
    to a row containing only the columns:
    - inputs
    - outputs
    - bpp
    - output_masks
    - reactivity_error
    - bool_output_masks
    """
    # initialize arrays
    # note that we assume everything is masked until told otherwise
    inputs = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)
    bpp = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)
    output_masks = np.ones((NUM_REACTIVITIES,), dtype=np.bool_)
    reactivity_errors = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)
    reactivities = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)

    seq_len = len(row["sequence"])

    # encode the bases
    inputs[:seq_len] = np.array(
        list(map(lambda letter: base_map[letter], row["sequence"]))
    )

    # get the probability that any of those bases are paired
    bpp[:seq_len] = np.max(bpps(row["sequence"], package="eternafold"), axis=-1)

    # get the reactivities and their errors
    reactivities[:seq_len] = np.array(
        list(
            map(
                lambda seq_idx: np.float32(
                    row["reactivity_" + str(seq_idx + 1).rjust(4, "0")]
                ),
                range(seq_len),
            )
        )
    )
    reactivity_errors[:seq_len] = np.array(
        list(
            map(
                lambda seq_idx: np.float32(
                    row["reactivity_error_" + str(seq_idx + 1).rjust(4, "0")]
                ),
                range(seq_len),
            )
        )
    )

    # replace reactivity error nans with 0s (assume no error)
    reactivity_errors = np.where(np.isnan(reactivity_errors), 0.0, reactivity_errors)

    # get where all the reactivities are nan
    nan_locats = np.isnan(reactivities)

    # where it is nan, store True, else False
    output_masks[:seq_len] = nan_locats[:seq_len]

    # where it is not nan, store the reactivity and error, else 0
    reactivities[:seq_len] = np.where(
        nan_locats[:seq_len] == False, reactivities[:seq_len], 0.0
    )
    reactivity_errors[:seq_len] = np.where(
        nan_locats[:seq_len] == False, reactivity_errors[:seq_len], 0.0
    )

    # store the values
    row = {}
    row["inputs"] = inputs
    row["bpp"] = bpp
    row["outputs"] = np.clip(reactivities, 0, 1)
    row["output_masks"] = np.clip(
        np.where(output_masks, 0.0, 1.0) - np.abs(reactivity_errors), 0, 1
    )
    row["bool_output_masks"] = output_masks
    row["reactivity_errors"] = np.abs(reactivity_errors)

    return row


def process_data_test(row):
    """
    Almost the same as process_data, except it only takes inputs and bpp
    """
    # initialize arrays
    # note that we assume everything is masked until told otherwise
    inputs = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)
    bpp = np.zeros((NUM_REACTIVITIES,), dtype=np.float32)

    seq_len = len(row["sequence"])

    # encode the bases
    inputs[:seq_len] = np.array(
        list(map(lambda letter: base_map[letter], row["sequence"]))
    )

    # get the probability that any of those bases are paired
    bpp[:seq_len] = np.max(bpps(row["sequence"], package="eternafold"), axis=-1)

    row["inputs"] = inputs
    row["bpp"] = bpp
    return row

In [10]:
def preprocess_csv(
    out: str,
    file_name: str,
    n_proc: int = 12,
    map_fn: Callable = process_data,
    extra_cols_to_keep: List[str] = [],
):
    """
    Preprocess the csv and save the preprocessed data as a dataset
    that can be loaded via datasets.Dataset.load_from_file

    The dataset contains the following items:
        - bool_output_masks: Tensor(dtype=torch.bool) - the output masks.
            If True, then that item should NOT be used to calculate loss.
            If False, then that item should be used to calculate loss
        - reactivity_errors: Tensor(dtype=torch.float32) - the reactivity errors
        - output_masks: Tensor(dtype=torch.float32) - the elementwise weights to multiply the loss by to properly
            account for masked items and reactivity errors
        - inputs: tensor(dtype=torch.float32) - the input sequence, specifically of shape (None, NUM_REACTIVITIES)
        - bpp: tensor(dtype=torch.float32)
        - outputs: tensor(dtype=torch.float32) - the expected reactivities. Note that a simple MAE or MSE loss will not
            suffice for training models on this dataset. Please use the output_masks tensor as well.

    Parameters:
        - out: str - the name of the file to save the arrays to
        - file_name: str - the name of the input csv file
        - n_proc: int - the number of processes to use while processing data
        - map_fn: Callable - the function to apply to all dataset rows
        - extra_cols_to_keep: List[str] - the names of any extra columns to keep in the dataset
    """
    if os.path.exists(out):
        print(
            "File already exists, not doing any work.\n"
            + "To force re-preprocessing, delete the dataset directory and restart the kernel."
        )
        return

    names_to_keep = [
        "reactivity_errors",
        "bool_output_masks",
        "output_masks",
        "inputs",
        "outputs",
        "bpp",
    ] + extra_cols_to_keep

    # load dataset and map it to our preprocess function
    ds = Dataset.from_csv(file_name).map(map_fn, num_proc=n_proc)

    # drop excess columns and save to disk
    ds.remove_columns(
        list(filter(lambda c: c not in names_to_keep, ds.column_names))
    ).save_to_disk(out)

In [11]:
preprocess_csv("train_data_2a3_preprocessed", "train_data_2a3.csv")

File already exists, not doing any work.
To force re-preprocessing, delete the dataset directory and restart the kernel.


In [12]:
preprocess_csv("train_data_dms_preprocessed", "train_data_dms.csv")

File already exists, not doing any work.
To force re-preprocessing, delete the dataset directory and restart the kernel.


In [13]:
preprocess_csv(
    "test_data_preprocessed",
    "test_sequences.csv",
    map_fn=process_data_test,
    extra_cols_to_keep=["id_min", "id_max"],
)

File already exists, not doing any work.
To force re-preprocessing, delete the dataset directory and restart the kernel.


### Load the desired dataset

In [14]:
desired_dataset = "2a3"  # either "2a3" or "dms"

In [15]:
dataset = Dataset.load_from_disk(
    f"train_data_{desired_dataset}_preprocessed"
).with_format("torch")
dataset

Dataset({
    features: ['inputs', 'bpp', 'outputs', 'output_masks', 'bool_output_masks', 'reactivity_errors'],
    num_rows: 210992
})

In [16]:
columns = ["inputs", "outputs", "output_masks", "bpp"]
split = dataset.train_test_split(test_size=0.1).select_columns(columns)
train_dataset = split["train"]
val_dataset = split["test"]

print(
    "train set is len", len(train_dataset), "and val dataset is len", len(val_dataset)
)

train set is len 189892 and val dataset is len 21100


### Visualize

Now that we have the dataset preprocessed, we can visualize what the distribution looks like

In [17]:
visualize = False

In [18]:
def multiply(*args):
    """
    Gets the product of all arguments passed to it
    """
    prod = 1
    for item in args:
        prod *= item
    return prod

In [19]:
if visualize:
    # select all the reactivities that are valid (that shouldn't be masked)
    visualized_items = torch.masked_select(
        dataset["outputs"], dataset["bool_output_masks"] == False
    ).numpy()

    # sanity check that we didn't take all the items
    print(
        f"took {visualized_items.shape[0] / multiply(*dataset['outputs'].shape):.5f}% of the data"
    )

In [20]:
if visualize:
    seaborn.histplot(visualized_items, binwidth=0.1)
else:
    print("Not visualizing. Set `visualize` to `True` to visualize data")

Not visualizing. Set `visualize` to `True` to visualize data


## Model

To model our distribution, we have two models:
- an AttentionModel that uses attention layers
- a BaselineModel that we compare against that uses only a couple convolution layers and a Linear layer

### Baseline

First, we establish a baseline model comprised of Linear and Convolutional layers

In [21]:
class BaselineModel(torch.nn.Module):
    def __init__(self, context_window: int = 31, device: str = DEVICE):
        super(BaselineModel, self).__init__()
        self.preLayer = torch.nn.Linear(2, 2).to(device)
        self.conv_layer = torch.nn.Conv1d(2, 2, context_window, padding="same").to(
            device
        )
        self.conv_layer_b = torch.nn.Conv1d(2, 2, context_window, padding="same").to(
            device
        )
        self.ff = torch.nn.Linear(NUM_REACTIVITIES * 2, NUM_REACTIVITIES).to(device)
        self.gelu = torch.nn.GELU()
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor):
        x = self.gelu(self.preLayer(x))
        x = x.reshape(x.shape[0], -1, x.shape[1])

        x = self.gelu(
            self.conv_layer(x)
            + torch.flip(self.conv_layer_b(torch.flip(x, dims=[2])), dims=[2])
        )

        return self.relu(self.ff(x.flatten(start_dim=1)))

Now we can write our attention model

In [22]:
class CustomTransformerEncoderLayer(torch.nn.Module):
    def __init__(
        self,
        attention: components.Attention,
        latent_dim: int,
        ff_dim: int,
        n_heads: int,
        device: str = DEVICE,
        *args,
        **kwargs
    ) -> None:
        super(CustomTransformerEncoderLayer, self).__init__()
        self.attention = components.MultiHeadDispatch(
            dim_model=latent_dim, num_heads=n_heads, attention=attention, **kwargs
        ).to(device)
        self.layer_norm = torch.nn.LayerNorm(latent_dim).to(device)

        self.ff1 = torch.nn.Linear(latent_dim, ff_dim).to(device)
        self.ff2 = torch.nn.Linear(ff_dim, latent_dim).to(device)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
        # MHA self attention, add, norm
        x = self.layer_norm(self.attention(x, att_mask=attention_mask) + x)

        # ff, add, norm
        x = self.layer_norm(self.gelu(self.ff2(self.gelu(self.ff1(x)))) + x)

        return x


class CustomTransformerDecoderLayer(torch.nn.Module):
    def __init__(
        self,
        attention: components.Attention,
        latent_dim: int,
        ff_dim: int,
        n_heads: int,
        device: str = DEVICE,
        *args,
        **kwargs
    ) -> None:
        super(CustomTransformerDecoderLayer, self).__init__()
        self.crossattention = components.MultiHeadDispatch(
            dim_model=latent_dim, num_heads=n_heads, attention=attention, **kwargs
        ).to(device)

        self.selfattention = components.MultiHeadDispatch(
            dim_model=latent_dim, num_heads=n_heads, attention=attention, **kwargs
        ).to(device)
        self.layer_norm = torch.nn.LayerNorm(latent_dim).to(device)

        self.ff1 = torch.nn.Linear(latent_dim, ff_dim).to(device)
        self.ff2 = torch.nn.Linear(ff_dim, latent_dim).to(device)
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor, ctx: torch.Tensor, attention_mask: torch.Tensor):
        # MHA self attention, add norm
        x = self.layer_norm(self.selfattention(x) + x)

        # MHA cross attention, add, norm
        x = self.layer_norm(
            self.crossattention(key=ctx, query=ctx, value=x, att_mask=attention_mask)
            + x
        )

        # ff, add, norm
        x = self.layer_norm(self.gelu(self.ff2(self.gelu(self.ff1(x)))) + x)

        return x


class CustomTransformerEncoder(torch.nn.Module):
    def __init__(
        self,
        attention_type: components.Attention,
        n_layers: int,
        latent_dim: int,
        ff_dim: int,
        n_heads: int,
        device: str = DEVICE,
        **kwargs
    ) -> None:
        super(CustomTransformerEncoder, self).__init__()
        for i in range(n_layers):
            self.add_module(
                str(i),
                CustomTransformerEncoderLayer(
                    attention=attention_type,
                    latent_dim=latent_dim,
                    ff_dim=ff_dim,
                    n_heads=n_heads,
                    device=device,
                    **kwargs
                ),
            )

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
        for module in self._modules.values():
            x = module(x, attention_mask=attention_mask)
        return x


class CustomTransformerDecoder(torch.nn.Module):
    def __init__(
        self,
        attention_type: components.Attention,
        n_layers: int,
        latent_dim: int,
        ff_dim: int,
        n_heads: int,
        device: str = DEVICE,
        **kwargs
    ) -> None:
        super(CustomTransformerDecoder, self).__init__()
        for i in range(n_layers):
            self.add_module(
                str(i),
                CustomTransformerDecoderLayer(
                    attention=attention_type,
                    latent_dim=latent_dim,
                    ff_dim=ff_dim,
                    n_heads=n_heads,
                    device=device,
                    **kwargs
                ),
            )

    def forward(self, x: torch.Tensor, ctx: torch.Tensor, attention_mask: torch.Tensor):
        for module in self._modules.values():
            x = module(x, ctx, attention_mask=attention_mask)
        return x


class AttentionModel(torch.nn.Module):
    def __init__(
        self,
        attention_type: attentions.Attention = attentions.ScaledDotProduct(dropout=0.1),
        latent_dim: int = 128,
        ff_dim: int = 1024,
        n_heads: int = 2,
        enc_layers: int = 1,
        dec_layers: int = 1,
        device: str = DEVICE,
    ) -> None:
        super(AttentionModel, self).__init__()

        # data
        self.n_heads = n_heads
        self.latent_dim = latent_dim

        # projection layer
        self.proj = torch.nn.Linear(2, latent_dim).to(device)

        # positional embedding encoder/decoder layers
        self.pos_embedding = embeddings.SinePositionalEmbedding(latent_dim).to(device)
        self.has_encoder = enc_layers >= 1
        self.has_decoder = dec_layers >= 1
        if self.has_encoder:
            self.encoder_layers = CustomTransformerEncoder(
                latent_dim=latent_dim,
                ff_dim=ff_dim,
                n_heads=n_heads,
                device=device,
                attention_type=attention_type,
                n_layers=enc_layers,
            )
        if self.has_decoder:
            self.decoder_layers = CustomTransformerDecoder(
                latent_dim=latent_dim,
                ff_dim=ff_dim,
                n_heads=n_heads,
                attention_type=attention_type,
                n_layers=dec_layers,
            )

        # output head
        self.head = torch.nn.Linear(latent_dim, 1).to(device)
        self.final_result = torch.nn.Linear(NUM_REACTIVITIES, NUM_REACTIVITIES).to(
            device
        )

        # activations
        self.relu = torch.nn.ReLU()
        self.gelu = torch.nn.GELU()

    def forward(self, x: torch.Tensor):
        mask = att_utils.maybe_merge_masks(
            att_mask=None,
            key_padding_mask=(x != 0).any(dim=-1),
            batch_size=x.shape[0],
            num_heads=self.n_heads,
            src_len=x.shape[1],
        )

        # project to latent dimension
        x = self.proj(x)

        # embed and then perform attention
        x = self.pos_embedding(x)
        if self.has_decoder and self.has_encoder:
            x = self.decoder_layers(
                x, ctx=self.encoder_layers(x, attention_mask=mask), attention_mask=mask
            )
        elif self.has_encoder:
            x = self.encoder_layers(x, attention_mask=mask)
        elif self.has_decoder:
            x = self.decoder_layers(x, ctx=x, attention_mask=mask)

        # final result
        x = self.relu(self.final_result(self.gelu(self.head(x).flatten(start_dim=1))))
        return x

### Instantiate Model

In [23]:
use_baseline = False

In [24]:
model_dms_kwargs = dict(
    latent_dim=32,
    n_heads=1,
    enc_layers=4,
    dec_layers=4,
    ff_dim=2048,
)
model_2a3_kwargs = dict(
    latent_dim=32,
    n_heads=1,
    enc_layers=4,
    dec_layers=4,
    ff_dim=2048,
)

In [25]:
# select the model
if use_baseline:
    model = BaselineModel()
elif desired_dataset == "dms":
    model = AttentionModel(**model_dms_kwargs)
elif desired_dataset == "2a3":
    model = AttentionModel(**model_2a3_kwargs)

In [26]:
# load old weights if possible
if os.path.exists(f"{desired_dataset}_model"):
    try:
        model.load_state_dict(torch.load(f"{desired_dataset}_model"))
        print("loaded previous weights")
    except Exception as e:
        print("not loading previous weights because", e)
        pass

loaded previous weights


In [27]:
# make sure that calling the model works as expected
inp = torch.zeros((2, NUM_REACTIVITIES, 2))
inp[:, 0, :] = 1

model(inp.to(DEVICE)).cpu().detach()

tensor([[0.0000e+00, 0.0000e+00, 2.9502e-02, 3.2924e-02, 0.0000e+00, 3.1256e-02,
         2.7086e-02, 1.0857e-02, 2.4872e-02, 3.2272e-02, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.1683e-02, 3.2852e-02,
         8.5152e-03, 2.9365e-02, 2.4005e-02, 2.1168e-02, 1.8684e-02, 3.7977e-02,
         3.7581e-02, 0.0000e+00, 3.0670e-01, 8.0938e-02, 1.2839e-01, 1.8455e-01,
         7.5870e-02, 2.2543e-01, 2.3599e-01, 2.4518e-01, 2.7627e-01, 2.6089e-01,
         2.1842e-01, 2.2787e-01, 2.6565e-01, 2.3789e-01, 2.3019e-01, 2.5626e-01,
         2.6424e-01, 2.9859e-01, 2.8677e-01, 2.7021e-01, 3.0774e-01, 2.9375e-01,
         2.8281e-01, 3.2400e-01, 2.3871e-01, 2.8888e-01, 2.5705e-01, 2.7922e-01,
         2.6251e-01, 2.6378e-01, 2.8091e-01, 2.4998e-01, 2.3434e-01, 1.7966e-01,
         1.7853e-01, 2.0429e-01, 2.2204e-01, 2.8498e-01, 2.7079e-01, 2.4806e-01,
         2.6214e-01, 2.5215e-01, 2.6187e-01, 2.8171e-01, 2.5798e-01, 2.7527e-01,
         3.1098e-01, 2.3426e

In [28]:
print(model)

AttentionModel(
  (proj): Linear(in_features=2, out_features=32, bias=True)
  (pos_embedding): SinePositionalEmbedding()
  (encoder_layers): CustomTransformerEncoder(
    (0): CustomTransformerEncoderLayer(
      (attention): MultiHeadDispatch(
        (attention): ScaledDotProduct(
          (attn_drop): Dropout(p=0.1, inplace=False)
        )
        (in_proj_container): InputProjection(
          (q_proj): Linear(in_features=32, out_features=32, bias=True)
          (k_proj): Linear(in_features=32, out_features=32, bias=True)
          (v_proj): Linear(in_features=32, out_features=32, bias=True)
        )
        (resid_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=32, out_features=32, bias=True)
      )
      (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (ff1): Linear(in_features=32, out_features=2048, bias=True)
      (ff2): Linear(in_features=2048, out_features=32, bias=True)
      (gelu): GELU(approximate='none')
    )
    (1)

In [29]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Total params:", params)

Total params: 1325851


In [30]:
# create optimizer
optimizer = torch.optim.AdamW(model.parameters(), 1e-4, weight_decay=1e-2)

## Train

### Setup Training Utils

In [31]:
run_name = "2a3_SN_4enc_4dec"
writer = SummaryWriter(f"runs/{run_name}")

In [32]:
BATCH_SIZE = 64
SHUFFLE = True

# create dataloaders
train_dataloader = data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)
# no point in shuffling validation set
val_dataloader = data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [33]:
def unweightedL1(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    weights: torch.Tensor,
    l1=torch.nn.L1Loss(reduction="none"),
):
    """
    MAE Loss function where sample weights are only used to determine masks.
    """
    return (l1(y_pred, y_true))[weights != 0].mean()


def weightedL1(
    y_pred: torch.Tensor,
    y_true: torch.Tensor,
    weights: torch.Tensor,
    l1=torch.nn.L1Loss(reduction="none"),
):
    """
    MAE loss function that takes into account sample weights
    """
    return (l1(y_pred, y_true) * weights)[weights != 0].mean()

In [34]:
def train_batch(
    m: torch.nn.Module, inps: torch.Tensor, outs: torch.Tensor, masks: torch.Tensor
):
    """
    Get the loss on a batch and perform the corresponding weight updates.
    Used for training purposes
    """
    optimizer.zero_grad()
    preds = m(inps)

    # get the weighted mae
    weighted_loss = weightedL1(preds, outs, masks)
    weighted_loss.backward()

    # calculate gradients
    optimizer.step()

    with torch.no_grad():
        unweighted_loss = unweightedL1(preds, outs, masks)

    # return weighted and unweighted mae loss
    return weighted_loss.detach().cpu(), unweighted_loss.detach().cpu()


def noupdate_batch(
    m: torch.nn.Module, inps: torch.Tensor, outs: torch.Tensor, masks: torch.Tensor
):
    """
    Get the loss on a batch without performing any updates.
    Used for validation purposes
    """
    with torch.no_grad():
        preds = m(inps)
        weighted_loss = weightedL1(preds, outs, masks)
        unweighted_loss = unweightedL1(preds, outs, masks)

    # return weighted and unweighted mae loss
    return weighted_loss.cpu(), unweighted_loss.cpu()


def masked_train(
    m: torch.nn.Module,
    train_dataloader: data.DataLoader,
    val_dataloader: data.DataLoader,
    epochs: int = 1,
    device: torch.device = DEVICE,
    log: bool = True,
):
    """
    Train the given model.

    Arguments:
        - m: torch.nn.Module - the model to train.
        - train_dataloader: data.Dataloader - the dataloader that provides the batched training data
        - val_dataloader: data.Dataloader - the dataloader that provides the batched validation data
        - epochs: int - how many epochs to train for. Defaults to `1`.
        - device: torch.device - the device to train on, defaults to `DEVICE`
        - log: bool - whether or not to log to `writer`
    """
    m = m.to(device)

    for epoch in range(1, epochs + 1):
        print(f"Epoch {epoch}")
        epoch_mae = 0.0
        epoch_weighted_mae = 0.0

        m = m.train()
        for tdata in (prog := tqdm(train_dataloader, desc="batch")):
            inps = torch.stack([tdata["inputs"], tdata["bpp"]], dim=-1)
            outs = tdata["outputs"]
            masks = tdata["output_masks"]

            inps = inps.to(device)
            outs = outs.to(device)
            masks = masks.to(device)

            weighted_mae, mae = train_batch(m, inps, outs, masks)

            epoch_weighted_mae += weighted_mae
            epoch_mae += mae

            # log
            prog.set_postfix_str(
                f"mae_loss: {mae:.5f}, weighted_mae_loss: {weighted_mae:.5f}"
            )

            # break  # used for sanity check
        epoch_weighted_mae /= len(train_dataloader)
        epoch_mae /= len(train_dataloader)

        # do validation
        val_mae = 0.0
        val_weighted_mae = 0.0
        m = m.eval()
        for vdata in val_dataloader:
            inps = torch.stack([vdata["inputs"], vdata["bpp"]], dim=-1)
            outs = vdata["outputs"]
            masks = vdata["output_masks"]

            inps = inps.to(device)
            outs = outs.to(device)
            masks = masks.to(device)
            weighted_mae, mae = noupdate_batch(m, inps, outs, masks)

            val_weighted_mae += weighted_mae
            val_mae += mae
        val_weighted_mae /= len(val_dataloader)
        val_mae /= len(val_dataloader)

        print(
            f"Epoch MAE: {epoch_mae:.5f}\tEpoch Weighted MAE: {epoch_weighted_mae:.5f}\t"
            + f"Val MAE: {val_mae:.5f}\tVal Weighted MAE: {val_weighted_mae:.5f}"
        )

        if log:
            writer.add_scalar("epoch_mae", epoch_mae, global_step=epoch)
            writer.add_scalar("val_mae", val_mae, global_step=epoch)

### Actually Train

In [35]:
# baseline gets ~ 3.02 on dms w/ 10 epochs, ~ 3.87 on 2a3 w/ 20 epochs
masked_train(
    model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=3,
    log=False,
)

Epoch 1


batch: 100%|██████████| 2968/2968 [08:44<00:00,  5.66it/s, mae_loss: 0.21019, weighted_mae_loss: 0.18523]


Epoch MAE: 0.19372	Epoch Weighted MAE: 0.15266	Val MAE: 0.19210	Val Weighted MAE: 0.15071
Epoch 2


batch: 100%|██████████| 2968/2968 [08:47<00:00,  5.63it/s, mae_loss: 0.23693, weighted_mae_loss: 0.15563]


Epoch MAE: 0.19339	Epoch Weighted MAE: 0.15238	Val MAE: 0.19188	Val Weighted MAE: 0.15054
Epoch 3


batch: 100%|██████████| 2968/2968 [08:51<00:00,  5.58it/s, mae_loss: 0.11292, weighted_mae_loss: 0.10283]


Epoch MAE: 0.19312	Epoch Weighted MAE: 0.15218	Val MAE: 0.19175	Val Weighted MAE: 0.15044


## Pseudo Label data

In [36]:
def label(m: torch.nn.Module, t0: Dataset, device: torch.device = DEVICE):
    """
    Label T with predictions
    """

    def predict(row):
        with torch.no_grad():
            row["outputs"] = m(
                torch.stack([row["inputs"], row["bpp"]], dim=-1).to(device)
            ).cpu()
        return row

    t0 = t0.map(
        predict, batched=True, batch_size=BATCH_SIZE, load_from_cache_file=False
    )
    return t0


def train_pseudo(
    m: torch.nn.Module,
    t0: Dataset,
    epochs: int = 1,
    device: torch.device = DEVICE,
):
    """
    Make pseudo-labeled data and train on that pseudo-labeled data

    Arguments:
        - m: torch.nn.Module - the model to train
        - t0: Dataset - the dataset containing `inputs` and `bpp` on which to pseudo-label and train on
        - epochs: int - how many epochs to do on the pseudo-labeled data
        - device: torch.device - the device to train on, defaults to `DEVICE`
    """
    m = m.train().to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch}")

        t0 = label(m, t0, device)
        dataloader = data.DataLoader(t0, batch_size=BATCH_SIZE)

        epoch_weighted_mae = 0
        epoch_mae = 0

        for pdata in (prog := tqdm(dataloader)):
            inps = torch.stack([pdata["inputs"], pdata["bpp"]], dim=-1).to(device)
            outs = pdata["outputs"].to(device)

            weights = torch.where(inps[:, :, 0] != 0, 1, 0).to(device)
            weighted_mae, mae = train_batch(m, inps, outs, weights)

            epoch_weighted_mae += weighted_mae
            epoch_mae += mae

            # log
            prog.set_postfix_str(
                f"mae_loss: {mae:.5f}, weighted_mae_loss: {weighted_mae:.5f}"
            )

        epoch_weighted_mae /= len(dataloader)
        epoch_mae /= len(dataloader)

        print(
            f"Pseudo Epoch MAE: {epoch_mae:.5f}\tEpoch Weighted MAE: {epoch_weighted_mae:.5f}"
        )

In [37]:
def ssl_train(
    m: torch.nn.Module,
    train_dataloader: data.DataLoader,
    val_dataloader: data.DataLoader,
    t0: str = "test_data_preprocessed",
    epochs: int = 1,
    s_epochs: int = 1,
    p_epochs: int = 1,
    device: torch.device = DEVICE,
):
    """
    Train a model via Semi Supervised Learning.

    The process is to train it on the T0 (the unlabeled data) for `pseudo_epochs` epochs
    before training it on real data for `real_epochs` epochs.

    Arguments:
        - m: torch.nn.Module - the model to train.
        - train_dataloader: data.Dataloader - the dataloader that provides the batched training data
        - val_dataloader: data.Dataloader - the dataloader that provides the batched validation data
        - t0: str - the dataset that contains T0 (unlabeled data). Defaults to `test_data_preprocessed`
        - epochs: int - how many epochs to train for. Defaults to `1`. Note that one epoch involves
            training on BOTH real and pseudo datasets
        - s_epochs: int - how many sub-epochs to do on the train dataset (the supervised dataset). Defaults to `1`.
        - p_epochs: int - how many sub-epochs to do on the pseudo-labeled data. Defaults to `1`.
        - device: torch.device - the device to train on, defaults to `DEVICE`
    """

    t0 = Dataset.load_from_disk(t0).with_format("torch")

    for epoch in range(1, epochs + 1):
        print("=" * 64)
        print(f"Macro Epoch {epoch}/{epochs}")
        print("=" * 32)
        print(f"Pseudo-labeled Epochs")

        # do pseudo training first
        train_pseudo(m, t0, p_epochs, device)
        print("=" * 32)
        print(f"Supervised Epochs")

        # do real training
        masked_train(m, train_dataloader, val_dataloader, s_epochs, device, log=False)

In [38]:
ssl_train(model, train_dataloader, val_dataloader, epochs=20, s_epochs=4)

Epoch 0 Pseudo
Epoch 0


Map:   0%|          | 0/1343823 [00:00<?, ? examples/s]

100%|██████████| 20998/20998 [1:02:38<00:00,  5.59it/s, mae_loss: 0.00612, weighted_mae_loss: 0.00612]


Pseudo Epoch MAE: 0.01669	Epoch Weighted MAE: 0.01669
Epoch 0 Real
Epoch 1


batch:  53%|█████▎    | 1566/2968 [06:18<04:31,  5.16it/s, mae_loss: 0.18101, weighted_mae_loss: 0.14370]

## Save

Now that we've trained our model, we can save its weights and biases ("state dict") so that we can load them for later inferencing or further training

In [None]:
torch.save(model.state_dict(), f"{desired_dataset}_model")

## Process Outputs

Once we have both models, it's time to create a submission file. 
This section creates a zipped csv submission file that can
be submitted on Kaggle.

In [None]:
make_submissions = False

In [None]:
valid = False

if (
    os.path.exists("2a3_model")
    and os.path.exists("dms_model")
    and os.path.exists("test_sequences.csv")
    and make_submissions
):
    if use_baseline:
        model_dms = BaselineModel()
        model_2a3 = BaselineModel()
    else:
        model_2a3 = AttentionModel(**model_2a3_kwargs)
        model_dms = AttentionModel(**model_dms_kwargs)

    model_2a3.load_state_dict(torch.load("2a3_model"))
    model_dms.load_state_dict(torch.load("dms_model"))

    model_2a3.eval().to(DEVICE)
    model_dms.eval().to(DEVICE)

    valid = True
else:
    print("Not going to create submissions.")

Not going to create submissions.


In [None]:
def pipeline(
    model_2a3: torch.nn.Module,
    model_dms: torch.nn.Module,
    input_ds: str,
    out: str,
    batch_size: int,
):
    """
    Make predictions on the test dataset and write them to a csv file

    Parameters:
        - model_2a3: torch.nn.Module - the model trained on the 2a3 distribution
        - model_dms: torch.nn.Module - the model trained on the dms distribution
        - input_ds: str - name of the dataset to load
        - out: str - name of the file to write to
        - batch_size: int - size of the batches to use to process the data.
            In general, larger batch sizes mean faster runtime
    """
    ds = Dataset.load_from_disk(input_ds).with_format("torch")
    loader = data.DataLoader(ds, batch_size=batch_size, shuffle=False)

    iterable = iter(loader)

    with open(out, "w") as outfile:
        # write the header
        outfile.write("id,reactivity_DMS_MaP,reactivity_2A3_MaP\n")

        for _ in tqdm(range(len(loader))):
            # get the next group of data
            tdata = next(iterable)
            inputs = torch.stack([tdata["inputs"], tdata["bpp"]], dim=-1).to(DEVICE)
            min_ids = tdata["id_min"].numpy()
            max_ids = tdata["id_max"].numpy()

            # make predictions w/o gradients
            with torch.no_grad():
                preds_2a3 = model_2a3(inputs).cpu().numpy()
                preds_dms = model_dms(inputs).cpu().numpy()

            # write preds
            for i in range(inputs.shape[0]):
                outfile.writelines(
                    map(
                        lambda seq_idx: f"{seq_idx},{preds_dms[i, seq_idx-min_ids[i]]:.3f},{preds_2a3[i, seq_idx-min_ids[i]]:.3f}\n",
                        # +1 since the id_max is inclusive
                        range(min_ids[i], max_ids[i] + 1),
                    )
                )

In [None]:
if valid:
    pipeline(
        model_2a3,
        model_dms,
        "test_data_preprocessed",
        "submission.csv",
        batch_size=BATCH_SIZE,
    )
else:
    print("Not going to create submissions.")

Not going to create submissions.


In [None]:
if valid:
    # zip our submission into an easily-uploadable zip file
    print("zipping submissions. This may take a while...")
    os.system("zip submission.csv.zip submission.csv")
    print("Done zipping submissions!")
else:
    print("Not going to zip submissions.")

Not going to zip submissions.
