# HydrAMP training on Colab

This Colab notebook trains the simplified HydrAMP variational autoencoder on the public [`pszmk/LAMP-datasets`](https://huggingface.co/datasets/pszmk/LAMP-datasets) release.  It mirrors the reference implementation from the Szczurek lab while:

* expanding the latent space to 128 dimensions,
* removing decoder conditioning inputs, and
* providing lightweight checkpoint helpers for saving and restoring training progress.

> **Tip:** Select *Runtime → Change runtime type → GPU* in Colab for practical training times.  CPU execution works for demonstration purposes but is slower.

In [None]:
!pip install -q datasets==2.19.1

In [None]:
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hyper-parameters consumed by the data loaders, the optimiser and the training loop.
BATCH_SIZE = 128
NUM_EPOCHS = 3
LEARNING_RATE = 1e-3
# Every encoded sequence is padded or truncated to exactly this many residues.
MAX_LENGTH = 25
# The padding token sits at index 0 of the vocabulary, followed by the 20 amino-acid letters.
PAD_TOKEN = " "
VOCAB = [PAD_TOKEN] + list("ACDEFGHIKLMNPQRSTVWY")
TOKEN_TO_IDX = {token: idx for idx, token in enumerate(VOCAB)}
PAD_IDX = TOKEN_TO_IDX[PAD_TOKEN]
# Hugging Face repository id passed to `load_dataset`.
DATASET_REPO = "pszmk/LAMP-datasets"
DATA_SOURCE = "huggingface"  # choose 'huggingface' for the full dataset or 'sample' for the tiny demo set
# Per-epoch checkpoints are written here; the directory is created as soon as this cell runs.
CHECKPOINT_DIR = Path("checkpoints")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Tiny bundled peptide set used when DATA_SOURCE is "sample" or when the
# Hugging Face download fails.  Each record mirrors the {"sequence": ...}
# shape the dataset wrapper reads.  Entries longer than MAX_LENGTH are
# truncated during encoding.
SAMPLE_DATA = {
    "train": [
        {"sequence": "GIGKFLHSAKKFGKAFVGEIMNS"},
        {"sequence": "KWKLFKKIEKVGQNIRDGIIKAGPAVAVVGQAT"},
        {"sequence": "ILPWKWPWWPWRR"},
        {"sequence": "LKLKLLLLKLK"},
        {"sequence": "GLFDIVKKVVGAFGSL"},
        {"sequence": "VNWKKVLGKIIKVVTMTTV"},
        {"sequence": "FFHHIFRGIVHVGKTIHRLVTG"},
        {"sequence": "ILPWRWPWWPWRR"},
        {"sequence": "GLWSKIKNVAAAAGKAALNAVN"},
        {"sequence": "KKLFKKILKYL"},
        {"sequence": "LLKKLLKKLLKK"},
        {"sequence": "GKKKKFLKKAKKFG"},
        {"sequence": "GWKRWWWW"},
        {"sequence": "GRWLRRFLRKIRRFRPPYLPRPRPRPV"},
        {"sequence": "KKVLKKSYKLLK"},
        {"sequence": "KWKLFKKIGAVLKVL"},
        {"sequence": "GLFKVLGKKISGLL"},
        {"sequence": "WFKKWWKFK"},
        {"sequence": "LKKIGKKIERVGQNTR"},
        {"sequence": "LRKKLWKKLLKLL"},
    ],
    "validation": [
        {"sequence": "FFRLLHSLGKIIKG"},
        {"sequence": "GKKLFKKKGGH"},
        {"sequence": "GLKLRFEK"},
        {"sequence": "GIGKFLHSAGKFGKAF"},
        {"sequence": "GLFDIVKKLVGAFGSL"},
    ],
}

## Load and preprocess the dataset

The Hugging Face dataset exposes peptide sequences as strings.  Each residue is mapped to an integer index, then sequences are padded or truncated to 25 residues (the HydrAMP target length).  By default the notebook downloads the full dataset from Hugging Face.  Set `DATA_SOURCE = "sample"` in the configuration cell above if you only want to run the tiny bundled demo split.  If the remote download fails, the helper prints a message and automatically falls back to that same bundled demo split.

In [None]:
def encode_sequence(sequence: str, max_length: int = MAX_LENGTH) -> torch.Tensor:
    """Turn a peptide string into a fixed-length tensor of vocabulary indices.

    The sequence is upper-cased, clipped to ``max_length`` characters and
    right-padded with ``PAD_IDX``.  Characters that are not in the vocabulary
    are also encoded as ``PAD_IDX``.

    Args:
        sequence: Raw peptide string.
        max_length: Number of positions in the returned tensor.

    Returns:
        A 1-D ``torch.long`` tensor of token indices.
    """
    # Clipping the characters first is equivalent to clipping the index list,
    # because every character yields exactly one index.
    residues = sequence.upper()[:max_length]
    indices = [TOKEN_TO_IDX.get(letter, PAD_IDX) for letter in residues]
    # A non-positive repeat count yields an empty list, so full-length
    # sequences are left as they are.
    indices += [PAD_IDX] * (max_length - len(indices))
    return torch.tensor(indices, dtype=torch.long)


class LAMPSequenceDataset(Dataset):
    """In-memory dataset of pre-encoded peptide sequences.

    All records are encoded once at construction time; indexing simply
    returns the stored tensor.
    """

    def __init__(self, split: list[Dict[str, str]]):
        """Encode the ``"sequence"`` field of every record in ``split``."""
        raw_sequences = (record["sequence"] for record in split)
        self.encoded = [encode_sequence(raw) for raw in raw_sequences]

    def __len__(self) -> int:
        """Return the number of encoded sequences."""
        return len(self.encoded)

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the encoded sequence stored at ``index``."""
        return self.encoded[index]


def _sequence_iterable(raw_split: Any) -> list[Dict[str, str]]:
    if isinstance(raw_split, list):
        if raw_split and isinstance(raw_split[0], dict):
            return raw_split
        return [{"sequence": str(item)} for item in raw_split]
    data = list(raw_split)
    if data and isinstance(data[0], dict):
        if "sequence" in data[0]:
            return data
        return [{"sequence": item.get("sequence", str(item))} for item in data]
    return [{"sequence": str(item)} for item in data]


def _load_sample_splits() -> tuple[list[Dict[str, str]], list[Dict[str, str]]]:
    """Return the bundled demo ``(train, validation)`` record lists.

    Prints a notice so the user knows the demo data is in use.
    """
    print("Using bundled sample data (demo only).")
    train_records = SAMPLE_DATA["train"]
    validation_records = SAMPLE_DATA["validation"]
    return train_records, validation_records


VALID_DATA_SOURCES = {"huggingface", "sample"}


def load_lamp_data(data_source: str = DATA_SOURCE) -> tuple[LAMPSequenceDataset, LAMPSequenceDataset]:
    """Build the training and validation datasets.

    Args:
        data_source: ``"huggingface"`` to download ``DATASET_REPO`` or
            ``"sample"`` to use the bundled demo records (case-insensitive).

    Returns:
        ``(train_dataset, val_dataset)``.

    Raises:
        ValueError: If ``data_source`` is not a recognised option.
    """
    source = data_source.lower()
    if source not in VALID_DATA_SOURCES:
        raise ValueError(f"Unknown data source {data_source!r}. Choose 'huggingface' or 'sample'.")

    use_sample = source == "sample"
    if not use_sample:
        try:
            train_split = load_dataset(DATASET_REPO, split="train")
            try:
                val_split = load_dataset(DATASET_REPO, split="validation")
            except ValueError:
                # Retry with the "test" split when "validation" is rejected.
                val_split = load_dataset(DATASET_REPO, split="test")
        except Exception as exc:  # noqa: BLE001
            # Deliberate best-effort: any download problem degrades to the demo set.
            print(f"Falling back to bundled sample data because dataset download failed: {exc}")
            use_sample = True

    if use_sample:
        train_split, val_split = _load_sample_splits()

    train_records = _sequence_iterable(train_split)
    val_records = _sequence_iterable(val_split)
    return LAMPSequenceDataset(train_records), LAMPSequenceDataset(val_records)


# Build both datasets when the cell runs; with DATA_SOURCE == "huggingface" this attempts a download.
train_dataset, val_dataset = load_lamp_data(DATA_SOURCE)
# Only the training loader shuffles; the validation loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)



## Define the HydrAMP encoder, decoder, and checkpoint helpers

The custom GRU layers follow the structure from the original HydrAMP project.  The decoder projects the 128-dimensional latent vector into the recurrent state size before unrolling the decoder GRU.  Lightweight `save_checkpoint` and `load_checkpoint` functions store the model, optimiser, and bookkeeping state.

In [None]:
class HydrAMPGRU(nn.Module):
    def __init__(self, units: int = 66, input_units: int = 66, output_len: int = 25, device: str | torch.device = "cpu"):
        super().__init__()
        self.output_len = output_len
        self.units = units
        self.input_units = input_units
        self.device = torch.device(device)
        self.kernel = nn.Parameter(torch.zeros(size=(input_units, units * 3), device=self.device))
        self.recurrent_kernel = nn.Parameter(torch.zeros(size=(units, units * 3), device=self.device))
        self.bias = nn.Parameter(torch.zeros(size=(units * 3,), device=self.device))

    def cell_forward(self, inputs: torch.Tensor, state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        h_tm1 = state
        matrix_x = torch.matmul(inputs, self.kernel)
        matrix_x = matrix_x + self.bias
        x_z, x_r, x_h = torch.split(matrix_x, self.units, dim=-1)

        matrix_inner = torch.matmul(h_tm1, self.recurrent_kernel[: self.units * 2])
        recurrent_z, recurrent_r, recurrent_h = torch.split(matrix_inner, self.units, dim=-1)

        z = torch.sigmoid(x_z + recurrent_z)
        r = torch.sigmoid(x_r + recurrent_r)

        recurrent_h = torch.matmul(r * h_tm1, self.recurrent_kernel[:, 2 * self.units :])
        hh = torch.tanh(x_h + recurrent_h)
        h = z * h_tm1 + (1 - z) * hh
        new_state = h
        return h, new_state

    def forward(self, input_: Optional[torch.Tensor], state: Optional[torch.Tensor] = None) -> torch.Tensor:
        if input_ is None:
            if state is None:
                raise ValueError("Either input_ or state must be provided to HydrAMPGRU.forward.")
            input_ = torch.zeros((state.shape[0], self.input_units), device=self.device)
        if state is None:
            state = torch.zeros((input_.shape[0], self.units), device=self.device)
        current_output = input_
        current_state = state
        outputs = []
        for _ in range(self.output_len):
            current_output, current_state = self.cell_forward(current_output, current_state)
            outputs.append(current_output)
        return torch.stack(outputs, dim=1)

    def forward_on_sequence(self, input_: torch.Tensor, state: Optional[torch.Tensor] = None) -> torch.Tensor:
        if state is None:
            state = torch.zeros((input_.shape[0], self.units), device=self.device)
        current_state = state
        outputs = []
        for i in range(input_.shape[1]):
            current_output, current_state = self.cell_forward(input_[:, i], current_state)
            outputs.append(current_output)
        return torch.stack(outputs, dim=1)


class HydrAMPDecoder(nn.Module):
    """Decode a latent vector into per-position scores over the vocabulary.

    Pipeline: linear projection of the latent vector to the GRU state size,
    then the self-feeding ``HydrAMPGRU`` unrolled from that state, then an
    LSTM, then a linear layer producing ``len(VOCAB)`` scores per step.
    """

    def __init__(self, device: str | torch.device = "cpu"):
        super().__init__()
        target = torch.device(device)
        self.device = target
        self.gru = HydrAMPGRU(units=66, input_units=66, device=target)
        self.latent_to_state = nn.Linear(128, self.gru.units).to(target)
        self.lstm = nn.LSTM(66, 100, batch_first=True).to(target)
        self.dense = nn.Linear(100, len(VOCAB)).to(target)

    def forward(self, latent_state: torch.Tensor, return_logits: bool = True, gumbel_temperature: float = 0.001) -> torch.Tensor:
        """Decode ``latent_state``.

        Args:
            latent_state: Latent vectors whose last dimension is 128.
            return_logits: When true, return the raw scores.  Otherwise
                return soft Gumbel-softmax samples drawn from them.
            gumbel_temperature: Temperature ``tau`` for the Gumbel-softmax.

        Returns:
            Logits, or relaxed samples with the same shape.
        """
        start_state = self.latent_to_state(latent_state.to(self.device))
        # With no explicit input the GRU starts from zeros and feeds its own output back in.
        unrolled = self.gru(None, start_state)
        recurrent_features, _ = self.lstm(unrolled)
        logits = self.dense(recurrent_features)
        if not return_logits:
            return torch.nn.functional.gumbel_softmax(logits, tau=gumbel_temperature, hard=False)
        return logits


class HydrAMPEncoder(nn.Module):
    """Encode token indices into the mean and log-variance of a 128-d latent.

    Tokens are embedded and passed through two stacked bidirectional layers,
    each built from a forward and a reverse ``HydrAMPGRU``.  The final outputs
    of the second layer's two directions are concatenated and projected by
    two linear heads.
    """

    def __init__(self, device: str | torch.device = "cpu"):
        super().__init__()
        target = torch.device(device)
        self.device = target
        self.embedding = nn.Embedding(num_embeddings=len(VOCAB), embedding_dim=100, device=target)
        self.gru1_f = HydrAMPGRU(input_units=100, units=128, device=target)
        self.gru1_r = HydrAMPGRU(input_units=100, units=128, device=target)
        self.gru2_f = HydrAMPGRU(input_units=256, units=128, device=target)
        self.gru2_r = HydrAMPGRU(input_units=256, units=128, device=target)
        self.mean_linear = nn.Linear(256, 128, device=target)
        self.logvar_linear = nn.Linear(256, 128, device=target)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mean, logvar)`` for a batch of token-index sequences."""
        embedded = self.embedding(x.to(self.device))

        # Layer 1: run both directions, then re-align the reverse pass with
        # the forward time axis before concatenating features.
        forward_1 = self.gru1_f.forward_on_sequence(embedded)
        backward_1 = self.gru1_r.forward_on_sequence(torch.flip(embedded, (1,)))
        layer_1 = torch.cat([forward_1, torch.flip(backward_1, (1,))], dim=-1)

        # Layer 2: only the last step of each direction is kept as the summary.
        forward_2 = self.gru2_f.forward_on_sequence(layer_1)
        backward_2 = self.gru2_r.forward_on_sequence(torch.flip(layer_1, (1,)))
        summary = torch.cat([forward_2[:, -1], backward_2[:, -1]], dim=-1)

        return self.mean_linear(summary), self.logvar_linear(summary)


@dataclass
class HydrAMPCheckpoint:
    """Everything stored in one training checkpoint."""

    # State dicts of the two model halves and of the optimiser.
    encoder_state_dict: Dict[str, Any]
    decoder_state_dict: Dict[str, Any]
    optimizer_state_dict: Dict[str, Any]
    # Bookkeeping counters captured at save time.
    epoch: int
    global_step: int
    # Free-form metadata supplied by the caller.
    extra: Dict[str, Any]

    def to_dict(self) -> Dict[str, Any]:
        """Return the checkpoint as a plain dict (shallow; values are not copied)."""
        return dict(
            encoder_state_dict=self.encoder_state_dict,
            decoder_state_dict=self.decoder_state_dict,
            optimizer_state_dict=self.optimizer_state_dict,
            epoch=self.epoch,
            global_step=self.global_step,
            extra=self.extra,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "HydrAMPCheckpoint":
        """Rebuild a checkpoint from a dict produced by :meth:`to_dict`.

        The three state dicts are required (a missing one raises
        ``KeyError``).  ``epoch`` and ``global_step`` default to ``0`` and
        ``extra`` to an empty dict.
        """
        encoder_state = data["encoder_state_dict"]
        decoder_state = data["decoder_state_dict"]
        optimizer_state = data["optimizer_state_dict"]
        return cls(
            encoder_state,
            decoder_state,
            optimizer_state,
            data.get("epoch", 0),
            data.get("global_step", 0),
            data.get("extra", {}),
        )


def save_checkpoint(path: str | Path, encoder: HydrAMPEncoder, decoder: HydrAMPDecoder, optimizer: torch.optim.Optimizer, epoch: int, global_step: int, extra: Optional[Dict[str, Any]] = None) -> None:
    """Write the model, optimiser and bookkeeping state to ``path``.

    Args:
        path: Destination file.
        encoder: Encoder whose ``state_dict`` is stored.
        decoder: Decoder whose ``state_dict`` is stored.
        optimizer: Optimiser whose ``state_dict`` is stored.
        epoch: Epoch number to record.
        global_step: Optimiser step count to record.
        extra: Optional metadata; a falsy value is stored as an empty dict.
    """
    payload = HydrAMPCheckpoint(
        encoder_state_dict=encoder.state_dict(),
        decoder_state_dict=decoder.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        epoch=epoch,
        global_step=global_step,
        extra=extra if extra else {},
    ).to_dict()
    destination = Path(path)
    torch.save(payload, destination)


def load_checkpoint(path: str | Path, encoder: HydrAMPEncoder, decoder: HydrAMPDecoder, optimizer: Optional[torch.optim.Optimizer] = None, map_location: str | torch.device | None = None) -> HydrAMPCheckpoint:
    """Restore model (and optionally optimiser) state from ``path``.

    Args:
        path: Checkpoint file written by ``save_checkpoint``.
        encoder: Encoder that receives the stored weights (modified in place).
        decoder: Decoder that receives the stored weights (modified in place).
        optimizer: When given, its state is restored as well.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The parsed checkpoint, including ``epoch``, ``global_step`` and ``extra``.
    """
    # NOTE(security): torch.load unpickles the file.  Depending on the
    # installed PyTorch version's default for `weights_only`, an untrusted
    # checkpoint may execute arbitrary code -- only load files you created.
    checkpoint = HydrAMPCheckpoint.from_dict(torch.load(Path(path), map_location=map_location))
    encoder.load_state_dict(checkpoint.encoder_state_dict)
    decoder.load_state_dict(checkpoint.decoder_state_dict)
    if optimizer is None:
        return checkpoint
    optimizer.load_state_dict(checkpoint.optimizer_state_dict)
    return checkpoint

## Instantiate the models and optimiser

In [None]:
# Build both halves of the VAE on the selected device.
encoder = HydrAMPEncoder(device=DEVICE).to(DEVICE)
decoder = HydrAMPDecoder(device=DEVICE).to(DEVICE)
# One optimiser updates encoder and decoder jointly; the same parameter list
# is reused for gradient clipping in the training loop.
parameters = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(parameters, lr=LEARNING_RATE)

# Running count of optimiser steps; incremented in run_epoch and stored in every checkpoint.
global_step = 0

## Train for a few epochs

The loop minimises the negative evidence lower bound (ELBO): a categorical cross-entropy reconstruction term plus a KL divergence regulariser (weighted by 0.1) that keeps the latent distribution close to a unit Gaussian.  Each epoch saves a checkpoint to the `checkpoints/` directory.

In [None]:
def kl_divergence(mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    """Return the KL divergence from N(mean, exp(logvar)) to N(0, I).

    The divergence is summed over dimension 1, so the result has one value
    per row of the inputs.
    """
    per_dimension = 1 + logvar - mean.pow(2) - logvar.exp()
    return -0.5 * per_dimension.sum(dim=1)


def run_epoch(epoch: int, train: bool = True) -> float:
    """Run one pass over the training or validation loader.

    In training mode each batch triggers a clipped optimiser step and
    increments the module-level ``global_step``.  In validation mode
    gradients are disabled and no parameters change.

    Args:
        epoch: Epoch number, used only in the printed summary.
        train: Selects the training loader and enables optimisation.

    Returns:
        The loss averaged over batches (``0.0`` if the loader is empty).
    """
    global global_step
    loader = train_loader if train else val_loader
    for module in (encoder, decoder):
        module.train(train)

    running_loss = 0.0
    batch_count = 0
    for batch in loader:
        batch = batch.to(DEVICE)
        with torch.set_grad_enabled(train):
            mean, logvar = encoder(batch)
            # Reparameterisation trick: sample the latent as mean + eps * std.
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            latent = mean + eps * std

            logits = decoder(latent)
            flat_logits = logits.view(-1, len(VOCAB))
            flat_targets = batch.view(-1)
            # Padding positions are excluded from the reconstruction term.
            recon_loss = F.cross_entropy(flat_logits, flat_targets, ignore_index=PAD_IDX)
            kl_loss = kl_divergence(mean, logvar).mean()
            loss = recon_loss + 0.1 * kl_loss

            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(parameters, max_norm=1.0)
                optimizer.step()
                global_step += 1
        running_loss += loss.detach().item()
        batch_count += 1

    avg_loss = running_loss / max(1, batch_count)
    if not train:
        print(f"Epoch {epoch} val loss: {avg_loss:.4f}")
    else:
        print(f"Epoch {epoch} train loss: {avg_loss:.4f}")
    return avg_loss


# Per-epoch average losses, kept for later inspection or plotting.
history = {"train": [], "val": []}
for epoch in range(1, NUM_EPOCHS + 1):
    # One optimisation pass followed by one evaluation pass.
    train_loss = run_epoch(epoch, train=True)
    val_loss = run_epoch(epoch, train=False)
    history["train"].append(train_loss)
    history["val"].append(val_loss)
    # Zero-padded epoch number keeps file names sortable in epoch order.
    checkpoint_path = CHECKPOINT_DIR / f"hydramp_epoch_{epoch:03d}.pt"
    save_checkpoint(
        checkpoint_path,
        encoder=encoder,
        decoder=decoder,
        optimizer=optimizer,
        epoch=epoch,
        global_step=global_step,
        extra={"train_loss": train_loss, "val_loss": val_loss},
    )
    print(f"Saved checkpoint to {checkpoint_path}")

## Restore from the latest checkpoint

Use `load_checkpoint` to resume training or run inference with a trained model.  The snippet below restores the highest-numbered epoch checkpoint found in the `checkpoints/` directory.

In [None]:
def _checkpoint_epoch(path: Path) -> int:
    """Return the epoch number in a checkpoint file name, or -1 if it has none."""
    suffix = path.stem.rsplit("_", 1)[-1]
    return int(suffix) if suffix.isdecimal() else -1


# Pick the checkpoint with the highest epoch number.  Ordering by the parsed
# number (not the raw file name) stays correct beyond the 3-digit zero padding.
latest_checkpoint = None
if CHECKPOINT_DIR.exists():
    checkpoints = sorted(
        CHECKPOINT_DIR.glob("hydramp_epoch_*.pt"),
        key=lambda candidate: (_checkpoint_epoch(candidate), candidate.name),
    )
    if checkpoints:
        latest_checkpoint = checkpoints[-1]

if latest_checkpoint is not None:
    print(f"Loading checkpoint from {latest_checkpoint}")
    state = load_checkpoint(
        latest_checkpoint,
        encoder=encoder,
        decoder=decoder,
        optimizer=optimizer,
        map_location=DEVICE,
    )
    # A checkpoint saved without a validation loss must not crash the summary:
    # formatting None with ":.4f" raises TypeError.
    restored_val_loss = state.extra.get("val_loss")
    if isinstance(restored_val_loss, (int, float)):
        val_loss_text = f"{restored_val_loss:.4f}"
    else:
        val_loss_text = "n/a"
    print(
        f"Restored epoch={state.epoch}, global_step={state.global_step}, "
        f"val_loss={val_loss_text}"
    )
else:
    print("No checkpoint found yet.")