## Saving, Loading, and Converting Models

A challenge faced by bpnet-lite is that the official BPNet and ChromBPNet repositories are in TensorFlow. This means that optimal usage of 

### Saving Models

Let's start simple. Given a bpnet-lite model (either created using bpnet-lite or loaded into PyTorch using one of the techniques below) you can save the model in the same way you save any other PyTorch model.

In [2]:
import torch
from bpnetlite import BPNet

toy_model = BPNet(n_filters=4, n_layers=2) # Make the model small to save disk space

torch.save(toy_model, "toy_bpnet_model.torch")

Because bpnet-lite attempts to be as low-level as possible, there are no additional wrappers or tricks for saving these models. Any feature that is present in PyTorch can be used out-of-the-box with bpnet-lite models.
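As one illustration of a standard PyTorch feature, here is a minimal sketch of saving only a model's parameters through its `state_dict` rather than pickling the whole object. The `TinyConvNet` class is a hypothetical stand-in so the example runs without bpnet-lite. It assumes a PyTorch version recent enough to support the `weights_only` argument of `torch.load`, and it is not taken from the bpnet-lite code base.

```python
import io

import torch
from torch import nn


# Hypothetical stand-in module; any nn.Module follows the same pattern.
class TinyConvNet(nn.Module):
    def __init__(self, n_filters=4):
        super().__init__()
        self.conv = nn.Conv1d(4, n_filters, kernel_size=3, padding=1)
        self.linear = nn.Linear(n_filters, 1)

    def forward(self, X):
        # Convolve, apply ReLU, average over the length axis, then project
        return self.linear(torch.relu(self.conv(X)).mean(dim=-1))


model = TinyConvNet()

# Save only the parameters; an in-memory buffer stands in for a file path.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Rebuild the architecture with the same arguments, then load the parameters.
restored = TinyConvNet()
restored.load_state_dict(torch.load(buffer, weights_only=True))

# Both copies should give identical predictions on the same input.
X = torch.randn(2, 4, 16)
same = torch.allclose(model(X), restored(X))
```

The trade-off is that a `state_dict` stores no architecture, so the constructor arguments have to be tracked separately, whereas saving the whole object keeps everything in one file.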

Saving works the same way with ChromBPNet models. You can save the entire model or either of its two components on its own.

In [3]:
from bpnetlite import ChromBPNet

accessibility = BPNet(n_filters=4, n_layers=2)
bias = BPNet(n_filters=4, n_layers=2)

toy_chrombpnet_model = ChromBPNet(bias, accessibility)

torch.save(toy_chrombpnet_model, "toy_chrombpnet_model.torch")
torch.save(toy_chrombpnet_model.bias, "toy_chrombpnet_model.bias.torch")
torch.save(toy_chrombpnet_model.accessibility, "toy_chrombpnet_model.accessibility.torch")

These models can also be saved after being wrapped, if you would like to do that. Personally, I do not do this because then I have to remember which models I wrapped in which ways. I would rather just save the base models and re-wrap them however I need to for each subsequent analysis.

In [4]:
from bpnetlite.bpnet import CountWrapper

toy_count_model = CountWrapper(toy_model)

torch.save(toy_count_model, "toy_count_model.torch")

### Loading Models
#### From PyTorch

Loading models that have been trained using bpnet-lite is just as easy as loading any other PyTorch model. This can be done directly with `torch.load`. If it's just the 

In [5]:
toy_model2 = torch.load("toy_bpnet_model.torch", weights_only=False)
toy_model2

BPNet(
  (iconv): Conv1d(4, 4, kernel_size=(21,), stride=(1,), padding=(10,))
  (irelu): ReLU()
  (rconvs): ModuleList(
    (0): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (1): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
  )
  (rrelus): ModuleList(
    (0-1): 2 x ReLU()
  )
  (fconv): Conv1d(6, 2, kernel_size=(75,), stride=(1,), padding=(37,))
  (linear): Linear(in_features=5, out_features=1, bias=True)
)

If we have wrapped the models, we can also load them the same way.

In [6]:
toy_count_model2 = torch.load("toy_count_model.torch", weights_only=False)
toy_count_model2

CountWrapper(
  (model): BPNet(
    (iconv): Conv1d(4, 4, kernel_size=(21,), stride=(1,), padding=(10,))
    (irelu): ReLU()
    (rconvs): ModuleList(
      (0): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): Conv1d(4, 4, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
    )
    (rrelus): ModuleList(
      (0-1): 2 x ReLU()
    )
    (fconv): Conv1d(6, 2, kernel_size=(75,), stride=(1,), padding=(37,))
    (linear): Linear(in_features=5, out_features=1, bias=True)
  )
)

#### From Official Repositories (ChromBPNet)

#### From tar.gz files

On the ENCODE Portal, ChromBPNet models come in sets of five with one model trained on each of five cross-chromosomal folds. These models, along with important metadata, are packaged together and uploaded as a single tar.gz file. One could untar these files and then operate on the model files independently using the code above, but that might be inconvenient, result in too many files, or one may want to keep all of these compressed together. Conveniently, one can load models directly from these tar.gz files without needing to unpack them.

To load directly from a tar.gz file, first you'll need to find where in the tar your model files are. 
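When an archive contains many files, filtering the member names can be quicker than reading the full listing. The sketch below builds a tiny archive in memory so that it runs without a download. The member names are made up to mimic the layout of such an archive, and the `make_toy_archive` helper is hypothetical.

```python
import io
import tarfile


def make_toy_archive(names):
    """Build a small tar.gz in memory whose members have the given names."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        for name in names:
            payload = name.encode()  # dummy contents
            info = tarfile.TarInfo(name=name)
            info.size = len(payload)
            tar.addfile(info, io.BytesIO(payload))
    buffer.seek(0)
    return buffer


# Made-up member names, for illustration only.
toy_names = [
    "./fold_0/model.bias_scaled.fold_0.TOY.h5",
    "./fold_0/model.chrombpnet_nobias.fold_0.TOY.h5",
    "./fold_0/logs/args.json",
    "./fold_1/model.bias_scaled.fold_1.TOY.h5",
    "./fold_1/model.chrombpnet_nobias.fold_1.TOY.h5",
]

archive = make_toy_archive(toy_names)

# Keep only the .h5 members, then split them by model type.
with tarfile.open(fileobj=archive, mode="r:gz") as tar:
    h5_names = [name for name in tar.getnames() if name.endswith(".h5")]

bias_paths = sorted(n for n in h5_names if "bias_scaled" in n)
accessibility_paths = sorted(n for n in h5_names if "chrombpnet_nobias" in n)
```

With a real archive, `tarfile.open` would be given the path to the downloaded file instead of the in-memory buffer.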

In [12]:
import tarfile

with tarfile.open("ENCFF574YLK.tar.gz", "r:gz") as tar:
    for filename in tar.getnames():
        print(filename)

.
./fold_2
./fold_2/model.chrombpnet_nobias.fold_2.ENCSR000EOT.tar
./fold_2/model.bias_scaled.fold_2.ENCSR000EOT.h5
./fold_2/logs.models.fold_2.ENCSR000EOT
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.stdout_v1.txt
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.batch_loss.tsv
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.chrombpnet_model_params.tsv
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.args.json
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.epoch_loss.csv
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.chrombpnet_data_params.tsv
./fold_2/logs.models.fold_2.ENCSR000EOT/logfile.modelling.fold_2.ENCSR000EOT.chrombpnet.params.json
./fold_2/model.bias_scaled.fold_2.ENCSR000EOT.tar
./fold_2/model.chrombpnet_nobias.fold_2.ENCSR000EOT.h5
./fold_2/model.chrombpnet.fold_2.ENCSR000EOT.h5
./fold_2/model.chrom

We can see that the tar.gz is organized into five folders, one for each of the five folds. Within each folder there is an accessibility model and a bias model, along with logs and some other files. We want the .h5 files. We then need to extract the accessibility and bias .h5 files from the archive (we do not need to read the entire archive into memory, which is nice), use the `BytesIO` wrapper to turn those bytes into a fake readable file, and pass that into the existing code.
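The extract-then-wrap pattern can be tried on its own using only the standard library. The following is a toy sketch in which a short byte string stands in for a model file and the member name is hypothetical. It shows only that `extractfile` returns the member's bytes and that `BytesIO` exposes them through a file-like interface.

```python
import io
import tarfile

payload = b"pretend these bytes are an .h5 model file"

# Write a one-member tar.gz into memory (stand-in for a downloaded archive).
archive = io.BytesIO()
with tarfile.open(fileobj=archive, mode="w:gz") as tar:
    info = tarfile.TarInfo(name="fold_0/toy_model.h5")  # hypothetical member name
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))
archive.seek(0)

# Read only the requested member; nothing is written to disk.
with tarfile.open(fileobj=archive, mode="r:gz") as tar:
    member_bytes = tar.extractfile("fold_0/toy_model.h5").read()

# BytesIO gives the raw bytes a file-like interface (read, seek, tell).
fake_file = io.BytesIO(member_bytes)
header = fake_file.read(7)
fake_file.seek(0)  # rewind so a downstream reader starts from the beginning
```

Any function that accepts an open file object can then be handed `fake_file` in place of a path on disk.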

Below is the entirety of the code to load up the ChromBPNet model directly from the tar.gz.

In [14]:
import tarfile

from io import BytesIO

from bpnetlite.chrombpnet import ChromBPNet

with tarfile.open("ENCFF574YLK.tar.gz", "r:gz") as tar:
    # Read the raw bytes of each .h5 file without unpacking the archive to disk
    bias_bytes = tar.extractfile("./fold_0/model.bias_scaled.fold_0.ENCSR000EOT.h5").read()
    accessibility_bytes = tar.extractfile("./fold_0/model.chrombpnet_nobias.fold_0.ENCSR000EOT.h5").read()

# Wrap the bytes in BytesIO so they can be read like files
chrombpnet = ChromBPNet.from_chrombpnet(
    BytesIO(bias_bytes),
    BytesIO(accessibility_bytes)
)

chrombpnet

ChromBPNet(
  (bias): BPNet(
    (iconv): Conv1d(4, 128, kernel_size=(21,), stride=(1,), padding=(10,))
    (irelu): ReLU()
    (rconvs): ModuleList(
      (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(4,), dilation=(4,))
      (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(8,), dilation=(8,))
      (3): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(16,), dilation=(16,))
    )
    (rrelus): ModuleList(
      (0-3): 4 x ReLU()
    )
    (fconv): Conv1d(128, 1, kernel_size=(75,), stride=(1,), padding=(37,))
    (linear): Linear(in_features=128, out_features=1, bias=True)
  )
  (accessibility): BPNet(
    (iconv): Conv1d(4, 512, kernel_size=(21,), stride=(1,), padding=(10,))
    (irelu): ReLU()
    (rconvs): ModuleList(
      (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
      (1): Conv1d(512, 512, kernel_size=(3,), stride=(1

This results in exactly the same type of model as if one had unpacked the archive and loaded the two .h5 files individually.

### Converting Models

At this point, converting models is simple. We know how to load models into PyTorch from a variety of formats, and we know how to save a model once it is in the bpnet-lite PyTorch format. Let's consider the situation where we want to download a model from the ENCODE Portal that was trained using the official TensorFlow repository and convert it into PyTorch. It's basically just one more line of code compared to the loading cell above.

In [15]:
import tarfile
import torch

from io import BytesIO

from bpnetlite.chrombpnet import ChromBPNet

with tarfile.open("ENCFF574YLK.tar.gz", "r:gz") as tar:
    # Read the raw bytes of each .h5 file without unpacking the archive to disk
    bias_bytes = tar.extractfile("./fold_0/model.bias_scaled.fold_0.ENCSR000EOT.h5").read()
    accessibility_bytes = tar.extractfile("./fold_0/model.chrombpnet_nobias.fold_0.ENCSR000EOT.h5").read()

# Wrap the bytes in BytesIO so they can be read like files
chrombpnet = ChromBPNet.from_chrombpnet(
    BytesIO(bias_bytes),
    BytesIO(accessibility_bytes)
)

# New line of code here saving the model
torch.save(chrombpnet, "chrombpnet-test-model.torch")