In [1]:
%run data_splitting.ipynb

2025-04-22 16:08:38.729886: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-22 16:08:38.773726: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-22 16:08:38.773761: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-22 16:08:38.774821: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-22 16:08:38.781779: I tensorflow/core/platform/cpu_feature_guar

Check GPU runtime type... 
Change Runtype Type in top menu for GPU acceleration
 "Runtime" -> "Change Runtime Type" -> "GPU"


2025-04-22 16:08:43.621093: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2025-04-22 16:08:43.622161: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Found 5316 protein-pocket pairs.
Input shape: torch.Size([4, 32, 32, 32])
Label shape: torch.Size([1, 32, 32, 32])
Pocket voxels in label: 367.0
5305


In [2]:
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import time
import json


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
print(device)

cuda


# **Building predictor**    


In [5]:
class Pocket3DCNN(nn.Module):
    def __init__(self, in_channels=4):  # 4 = atom types: C, N, O, S
        super(Pocket3DCNN, self).__init__()
        self.model = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),

            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),

            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),

            nn.Conv3d(128, 1, kernel_size=1),  # 1 output channel = binary classification
            nn.Sigmoid()  # voxel-wise output
        )

    def forward(self, x):
        return self.model(x)  # output shape: (batch, 1, D, H, W)

### Initializing the model

In [6]:
model = Pocket3DCNN(in_channels=4)
model = model.to(device)


We calculate the evaluation metrics scores with **custom functions**, so that is more **Pytorch-tensor frindly**:

In [None]:
# grab one sample from the dataset
X, Y = dataset[3]

X_input = X.unsqueeze(0)  

# run the model
model.eval()
with torch.no_grad():
    X_input = X_input.to(device)  
    preds = model(X_input)[0]  # remove batch dimension again, shape: (1, 32, 32, 32)
    

In [12]:
from sklearn.metrics import recall_score, precision_score

def accuracy(preds, labels):
    preds = preds > 0.5 
    correct = (preds == labels).sum().item()
    total = labels.numel()
    return correct / total

def precision(preds, labels, threshold=0.5):
    preds_bin = (preds > threshold).cpu().numpy().flatten()
    labels_bin = labels.cpu().numpy().flatten()
    return precision_score(labels_bin, preds_bin, zero_division=0)

def recall(preds, labels, threshold=0.5):
    preds_bin = (preds > threshold).cpu().numpy().flatten()
    labels_bin = labels.cpu().numpy().flatten()
    return recall_score(labels_bin, preds_bin, zero_division=0)


def f1_score(prec, rec, eps=1e-8):
    return 2 * prec * rec / (prec + rec + eps)

def dice_loss(probs, target, smooth=1e-8):
    target = target.to(preds.device)

    # dice Loss
    intersection = (probs * target).sum()
    dice_score = (2. * intersection + smooth) / (probs.sum() + target.sum() + smooth)
    loss = 1 - dice_score

    return loss

pos_weight = torch.tensor([10.0], device=device) 
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

We measure:

    - accuracy
    - precision
    - recall
    - f1_score
    - BCE loss

#### Checkpoint functions

In [13]:
# Loading checkpoint
def load_checkpoint(model, optimizer, filename='checkpoint.pth'):
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from epoch {epoch}.")
    return model, optimizer, epoch, loss

# Saving checkpoint
def save_checkpoint(model, optimizer, epoch, loss, filename='checkpoint.pth'):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved at epoch {epoch}.")

### Setup

In [18]:

val_loader   = DataLoader(val_set, batch_size=2, shuffle=False)
train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
test_loader  = DataLoader(test_set, batch_size=2, shuffle=False)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 60

# === Metric tracking ===
train_losses, val_losses = [], []
train_accuracies, train_recalls, train_pres, train_f1 = [], [], [], []
val_accuracies, val_recalls, val_pres, val_f1 = [], [], [], []

# === Checkpoint handling ===
start_epoch = 0
best_val_loss = float('inf')


### Training (fitting)

In [19]:
try:
    model, optimizer, start_epoch, _ = load_checkpoint(model, optimizer, 'model_checkpoint.pth')
    print(f"Resuming from epoch {start_epoch}")
except FileNotFoundError:
    start_epoch = 0
    print("No checkpoint found, starting from scratch.")

# === Training Loop ===
for epoch in range(start_epoch, num_epochs):
    epoch_start = time.time()
    model.train()
    total_loss = correct_train = total_train = 0
    total_train_recall = total_pres = total_f1 = 0

    for i, (X, Y) in enumerate(train_loader):
        batch_start = time.time()

        X, Y = X.to(device), Y.to(device)
        preds = model(X)
        loss = loss_fn(preds, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        acc = accuracy(preds, Y)
        rec = recall(preds, Y)
        pres = precision(preds, Y)
        f1 = f1_score(pres, rec)

        correct_train += acc
        total_train += 1
        total_train_recall += rec
        total_pres += pres
        total_f1 += f1

        batch_time = time.time() - batch_start
        if i % 100 == 0:
            print(f"Batch {i+1}/{len(train_loader)} | Loss: {loss.item():.4f} | Time: {batch_time:.2f}")
    
    epoch_time = time.time() - epoch_start

    # Average training metrics
    avg_loss = total_loss / len(train_loader)
    avg_accuracy = correct_train / total_train
    avg_recall = total_train_recall / total_train
    avg_pres = total_pres / total_train
    avg_f1 = total_f1 / total_train

    train_losses.append(avg_loss)
    train_accuracies.append(avg_accuracy)
    train_recalls.append(avg_recall)
    train_pres.append(avg_pres)
    train_f1.append(avg_f1)

    print(f"Epoch {epoch+1} Train | Loss: {avg_loss:.4f}, Acc: {avg_accuracy:.4f}, Recall: {avg_recall:.4f}, Precision: {avg_pres:.4f}, F1: {avg_f1:.4f}")

    # === Validation ===
    model.eval()
    val_loss = correct_val = total_val = total_val_recall = total_val_pres = total_val_f1 = 0

    with torch.no_grad():
        for X_val, Y_val in val_loader:
            X_val, Y_val = X_val.to(device), Y_val.to(device)
            preds_val = model(X_val)
            loss_val = loss_fn(preds_val, Y_val)
            val_loss += loss_val.item()

            acc_val = accuracy(preds_val, Y_val)
            rec_val = recall(preds_val, Y_val)
            pres_val = precision(preds_val, Y_val)
            f1_val = f1_score(pres_val, rec_val)

            correct_val += acc_val
            total_val += 1
            total_val_recall += rec_val
            total_val_pres += pres_val
            total_val_f1 += f1_val

    # Average validation metrics
    avg_val_loss = val_loss / len(val_loader)
    avg_val_accuracy = correct_val / total_val
    avg_val_recall = total_val_recall / total_val
    avg_val_pres = total_val_pres / total_val
    avg_val_f1 = total_val_f1 / total_val

    val_losses.append(avg_val_loss)
    val_accuracies.append(avg_val_accuracy)
    val_recalls.append(avg_val_recall)
    val_f1.append(avg_val_f1)
    val_pres.append(avg_val_pres)

    print(f"Epoch {epoch+1} Val   | Loss: {avg_val_loss:.4f}, Acc: {avg_val_accuracy:.4f}, Recall: {avg_val_recall:.4f}, Precision: {avg_val_pres:.4f}, F1: {avg_val_f1:.4f}")

    save_checkpoint(model, optimizer, epoch+1, avg_loss, filename="model_checkpoint.pth")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'model.pt')
        print("Best model saved")

# === Final Test Evaluation ==='''

print("\n Evaluating on test set: ")
model.load_state_dict(torch.load('model.pt'))
model.eval()

test_loss = correct_test = total_test = total_test_recall = total_test_pres = total_test_f1 = 0
model = model.to(device)
with torch.no_grad():
    for X_test, Y_test in test_loader:
        X_test, Y_test = X_test.to(device), Y_test.to(device)
        preds_test = model(X_test)
        loss_test = loss_fn(preds_test, Y_test)
        test_loss += loss_test.item()

        acc_test = accuracy(preds_test, Y_test)
        rec_test = recall(preds_test, Y_test)
        pres_test = precision(preds_test, Y_test)
        f1_test = f1_score(pres_test, rec_test)

        correct_test += acc_test
        total_test += 1
        total_test_recall += rec_test
        total_test_pres += pres_test
        total_test_f1 += f1_test

print(f"\n Final test results:")
print(f"Loss:      {test_loss / len(test_loader):.4f}")
print(f"Accuracy:  {correct_test / total_test:.4f}")
print(f"Recall:    {total_test_recall / total_test:.4f}")
print(f"Precision: {total_test_pres / total_test:.4f}")
print(f"F1 Score:  {total_test_f1 / total_test:.4f}")

  checkpoint = torch.load(filename)
  model.load_state_dict(torch.load('model.pt'))


Checkpoint loaded from epoch 60.
Resuming from epoch 60

 Evaluating on test set: 

 Final test results:
Loss:      0.7472
Accuracy:  0.9814
Recall:    0.5500
Precision: 0.3157
F1 Score:  0.3923


In [None]:
import json

all_metrics = {
    "train": {
        "loss": train_losses,
        "acc": train_accuracies,
        "recall": train_recalls,
        "precision": train_pres,
        "f1": train_f1
    },
    "val": {
        "loss": val_losses,
        "acc": val_accuracies,
        "recall": val_recalls,
        "precision": val_pres,
        "f1": val_f1
    },
    "test": {
    "loss": test_loss / total_test,
    "accuracy": correct_test / total_test,
    "recall": total_test_recall / total_test,
    "precision": total_test_pres / total_test,
    "f1_score": total_test_f1 / total_test
    }
}

with open("all_metrics.json", "w") as f:
    json.dump(all_metrics, f, indent=4)