In [1]:
import os
import logging
from datetime import datetime
from pathlib import Path

import boto3
import h5py
import pandas as pd
import torch
from accelerate import Accelerator
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [2]:
torch.multiprocessing.set_sharing_strategy('file_descriptor')

# Defines

## Constants

In [3]:
accelerator = Accelerator()
device = accelerator.device

In [4]:
bucket = ""

In [5]:
s3_output_path = ""

In [6]:
embeddings_s3_key = "interpro/processed/data_sample/esm_embeddings.h5"
embeddings_dir = "/mnt/sagemaker-nvme/esm/data/processed/"
embeddings_path = os.path.join(embeddings_dir, "esm_embeddings.h5")

In [7]:
if not Path(embeddings_path).is_file():
    os.makedirs(embeddings_dir, exist_ok=True)
    s3 = boto3.client('s3')
    s3.download_file(bucket, embeddings_s3_key, embeddings_path)

In [8]:
random_state = 42
train_frac = 0.8
test_frac = 0.5

model_save_dir = "model_save"

## Methods

In [9]:
class InterProDataset(Dataset):
    def __init__(
        self,
        df_proteinipr: pd.DataFrame,
        embeddings_path: str,
        index: pd.Index | None = None
    ):
        self.df_proteinipr = df_proteinipr
        self.embeddings_path = embeddings_path
        self.protein_embeddings = None

        if index is None:
            self.index = self.df_proteinipr.index
        else:
            self.index = index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx: int):
        idx = self.index[idx]

        if self.protein_embeddings is None:
            self.protein_embeddings = h5py.File(self.embeddings_path, "r")["protein_embeddings"]
        # Features - ESM protein embeddings
        x = torch.from_numpy(self.protein_embeddings[idx])

        # Labels
        y = torch.zeros(x.shape[0], dtype=torch.long)
        # TODO this only supports the demo case of working with just IPR004839 (or any other single InterPro ID)
        y_start, y_end = self.df_proteinipr.iloc[idx]["start"] + 1, self.df_proteinipr.iloc[idx]["end"] + 1
        y[y_start:y_end] = 1

        return x, y

In [10]:
class InterProModel(nn.Module):
    def __init__(
        self,
        n_layers: int,
        d_model: int,
        d_ff: int = 2048,
        n_heads: int = 4,
        dropout: float = 0.1,
        activation: str = "gelu",
        output_dim: int = 2
    ):
        super(InterProModel, self).__init__()

        if activation == "gelu":
            output_activation = nn.GELU()
        elif activation == "relu":
            output_activation = nn.ReLU()
        else:
            raise RuntimeError(f"activation should be relu/gelu, not {activation}")

        self.layers = nn.ModuleList(
            [
                nn.LayerNorm(d_model)
            ] + \
            [
                nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=n_heads,
                    dim_feedforward=d_ff,
                    dropout=dropout,
                    activation=activation
                )
                for _ in range(n_layers)
            ]
        )
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            output_activation,
            nn.Linear(d_model // 2, output_dim)
        )
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)

        x = self.output_layer(x)

        return x

In [11]:
class InterProModelTrainer:
    def __init__(self, device: str | None = None):
        self.accelerator = Accelerator()
        self.device = device or self.accelerator.device
        logging.basicConfig(level=logging.INFO)

    def train(
        self,
        model: InterProModel,
        train_loader: torch.utils.data.DataLoader,
        epochs: int = 1,
        lr: float = 1e-4,
        val_loader: torch.utils.data.DataLoader | None = None,
        checkpoint_metric: str = "val_acc",
        checkpoint_mode: str = "max",
        output_dir: str = "./checkpoints"
    ):
        os.makedirs(output_dir, exist_ok=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # Prepare objects with accelerator
        model, optimizer, train_loader = self.accelerator.prepare(model, optimizer, train_loader)
        if val_loader is not None:
            val_loader = self.accelerator.prepare(val_loader)

        best_metric = -float("inf") if checkpoint_mode == "max" else float("inf")

        for epoch in tqdm(range(epochs), desc="Epochs"):
            model.train()
            running_loss = 0.0
            n_batches = 0

            # Training loop with tqdm progress bar
            train_pbar = tqdm(train_loader, desc="Training", leave=False)
            for batch in train_pbar:
                inputs, targets = batch
                outputs = model(inputs)
                # Swapping time and class axes for nn.CrossEntropyLoss()
                loss = criterion(torch.swapaxes(outputs, -1, -2), targets)
                optimizer.zero_grad()
                self.accelerator.backward(loss)
                optimizer.step()

                running_loss += loss.item()
                n_batches += 1
                train_pbar.set_postfix(loss=loss.item())

            avg_train_loss = running_loss / n_batches if n_batches > 0 else 0.0
            logging.info(f"Epoch {epoch + 1}/{epochs} - Training Loss: {avg_train_loss:.4f}")

            # Evaluate on val set if provided
            if val_loader is not None:
                model.eval()
                total_loss = 0.0
                correct = 0
                total = 0
                with torch.no_grad():
                    val_pbar = tqdm(val_loader, desc="Evaluating", leave=False)
                    for batch in val_pbar:
                        inputs, targets = batch
                        outputs = model(inputs)
                        # Swapping time and class axes for nn.CrossEntropyLoss()
                        loss = criterion(torch.swapaxes(outputs, -1, -2), targets)
                        total_loss += loss.item() * targets.size(0)

                        preds = outputs.argmax(dim=-1)
                        correct += (preds == targets).sum().item()
                        total += targets.shape.numel()
                    avg_val_loss = total_loss / total if total > 0 else 0.0
                    accuracy = correct / total if total > 0 else 0.0
                logging.info(
                    f"Epoch {epoch + 1}/{epochs} - Validation Loss: {avg_val_loss:.4f}, "
                    f"Validation Accuracy: {accuracy:.4f}"
                )

                # Checkpoint based on metric improvement (using val accuracy here)
                current_metric = accuracy
                improved = (current_metric > best_metric) if checkpoint_mode == "max" \
                    else (current_metric < best_metric)
                if improved:
                    best_metric = current_metric
                    best_path = os.path.join(output_dir, "best_checkpoint.pt")
                    if self.accelerator.is_main_process:
                        torch.save(model.state_dict(), best_path)
                        logging.info(
                            f"Saved new best checkpoint at epoch {epoch + 1} "
                            f"with {checkpoint_metric}: {current_metric:.4f}"
                        )

            # Save checkpoint at end of epoch
            epoch_checkpoint = os.path.join(output_dir, f"checkpoint_epoch_{epoch + 1}.pt")
            if self.accelerator.is_main_process:
                torch.save(model.state_dict(), epoch_checkpoint)
                logging.info(f"Saved checkpoint for epoch {epoch + 1}")

        # Save final checkpoint
        final_path = os.path.join(output_dir, "final_checkpoint.pt")
        if self.accelerator.is_main_process:
            torch.save(model.state_dict(), final_path)
            logging.info("Saved final checkpoint.")

# Data

In [12]:
df = pd.read_parquet(
    os.path.join(s3_output_path, "proteinipr_with_sequences.parquet")
)

In [13]:
# Only "moderately" imbalanced
((df['end'] - df['start']) / 1000).mean()

0.32668393333333334

In [14]:
df_train, df_val = train_test_split(df, train_size=train_frac, random_state=random_state)
df_val, df_test = train_test_split(df_val, train_size=test_frac, random_state=random_state)

In [15]:
ds_train = InterProDataset(
    df_proteinipr=df,
    embeddings_path=embeddings_path,
    index=df_train.index
)
ds_val = InterProDataset(
    df_proteinipr=df,
    embeddings_path=embeddings_path,
    index=df_val.index
)
ds_test = InterProDataset(
    df_proteinipr=df,
    embeddings_path=embeddings_path,
    index=df_test.index
)

In [16]:
dl_train = DataLoader(
    dataset=ds_train,
    batch_size=64,
    shuffle=True,
)
dl_val = DataLoader(
    dataset=ds_val,
    batch_size=64,
    shuffle=False,
)
dl_test = DataLoader(
    dataset=ds_test,
    batch_size=64,
    shuffle=False,
)

# Training Loop

In [17]:
inputs, targets = next(iter(dl_train))

In [18]:
model = InterProModel(
    n_layers=2,
    d_model=inputs.shape[-1]
)

In [19]:
trainer = InterProModelTrainer()

In [20]:
current_timestamp = datetime.now().strftime("%Y-%m-%d_%H%M%S")
cur_output_dir = os.path.join(model_save_dir, current_timestamp)

In [21]:
trainer.train(
    model=model,
    train_loader=dl_train,
    val_loader=dl_val,
    output_dir=cur_output_dir,
    epochs=3
)

Epochs:   0%|          | 0/3 [00:00<?, ?it/s]

Training:   0%|          | 0/188 [00:00<?, ?it/s]

INFO:root:Epoch 1/3 - Training Loss: 0.0724


Evaluating:   0%|          | 0/24 [00:00<?, ?it/s]

INFO:root:Epoch 1/3 - Validation Loss: 0.0000, Validation Accuracy: 0.9837
INFO:root:Saved new best checkpoint at epoch 1 with val_acc: 0.9837
INFO:root:Saved checkpoint for epoch 1


Training:   0%|          | 0/188 [00:00<?, ?it/s]

INFO:root:Epoch 2/3 - Training Loss: 0.0445


Evaluating:   0%|          | 0/24 [00:00<?, ?it/s]

INFO:root:Epoch 2/3 - Validation Loss: 0.0000, Validation Accuracy: 0.9849
INFO:root:Saved new best checkpoint at epoch 2 with val_acc: 0.9849
INFO:root:Saved checkpoint for epoch 2


Training:   0%|          | 0/188 [00:00<?, ?it/s]

INFO:root:Epoch 3/3 - Training Loss: 0.0400


Evaluating:   0%|          | 0/24 [00:00<?, ?it/s]

INFO:root:Epoch 3/3 - Validation Loss: 0.0000, Validation Accuracy: 0.9847
INFO:root:Saved checkpoint for epoch 3
INFO:root:Saved final checkpoint.


# Evaluation

In [22]:
best_model_dict = torch.load(
    os.path.join(cur_output_dir, "best_checkpoint.pt"),
    weights_only=False
)

In [23]:
best_model = model

In [24]:
model_dict = best_model.state_dict()
model_dict.update(best_model_dict)
best_model.load_state_dict(model_dict)

<All keys matched successfully>

In [25]:
total_loss = 0.0
correct = 0
total = 0
criterion = nn.CrossEntropyLoss()

# Initialize counters for each class (binary: class 0 and class 1)
tp = [0, 0]  # true positives
fp = [0, 0]  # false positives
fn = [0, 0]  # false negatives

In [26]:
best_model, dl_test = trainer.accelerator.prepare(best_model, dl_test)
best_model.eval()
with torch.no_grad():
    val_pbar = tqdm(dl_test, desc="Test", leave=False)
    
    for batch in val_pbar:
        inputs, targets = batch
        outputs = best_model(inputs)
        # Swapping time and class axes for nn.CrossEntropyLoss()
        loss = criterion(torch.swapaxes(outputs, -1, -2), targets)
        total_loss += loss.item() * targets.size(0)
    
        preds = outputs.argmax(dim=-1)
        correct += (preds == targets).sum().item()
        total += targets.numel()
    
        # Update per-class counters
        for cls in [0, 1]:
            tp[cls] += ((preds == cls) & (targets == cls)).sum().item()
            fp[cls] += ((preds == cls) & (targets != cls)).sum().item()
            fn[cls] += ((preds != cls) & (targets == cls)).sum().item()

avg_test_loss = total_loss / total if total > 0 else 0.0
accuracy = correct / total if total > 0 else 0.0

# Compute precision and recall for each class
precision = [tp[i] / (tp[i] + fp[i]) if (tp[i] + fp[i]) > 0 else 0.0 for i in [0, 1]]
recall    = [tp[i] / (tp[i] + fn[i]) if (tp[i] + fn[i]) > 0 else 0.0 for i in [0, 1]]

logging.info(
    f"Test Loss: {avg_test_loss:.4f}, "
    f"Test Accuracy: {accuracy:.4f}, "
    f"Precision: {precision}, "
    f"Recall: {recall}"
)

Test:   0%|          | 0/24 [00:00<?, ?it/s]

INFO:root:Test Loss: 0.0000, Test Accuracy: 0.9862, Precision: [0.9914733976777852, 0.9754123168433498], Recall: [0.9879108102316648, 0.9825932128977707]
