<a href="https://colab.research.google.com/github/greenrace666/biocolabs/blob/main/proteindiff.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's third-party dependencies with uv (diffusers, BioPython,
# torch_geometric, hf_xet); `--prerelease disallow` is passed to the resolver.
!uv pip install diffusers BioPython torch_geometric hf_xet --prerelease disallow

[2mUsing Python 3.11.12 environment at: /usr[0m
[2mAudited [1m4 packages[0m [2min 220ms[0m[0m


In [None]:
#!pip uninstall torch-cluster
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch._y_version__}.html

Found existing installation: torch_cluster 1.6.3+pt26cu124
Uninstalling torch_cluster-1.6.3+pt26cu124:
  Would remove:
    /usr/local/lib/python3.11/dist-packages/torch_cluster-1.6.3+pt26cu124.dist-info/*
    /usr/local/lib/python3.11/dist-packages/torch_cluster/*
Proceed (Y/n)? y
  Successfully uninstalled torch_cluster-1.6.3+pt26cu124
Looking in links: https://data.pyg.org/whl/torch-2.6.0+cu124.html
Collecting torch-cluster
  Using cached https://data.pyg.org/whl/torch-2.6.0%2Bcu124/torch_cluster-1.6.3%2Bpt26cu124-cp311-cp311-linux_x86_64.whl (3.4 MB)
Installing collected packages: torch-cluster
Successfully installed torch-cluster-1.6.3+pt26cu124


In [None]:
import requests
import random
import os
from pathlib import Path

def get_random_pdb_ids(num_ids=10):
    """Get a list of random PDB IDs from RCSB.

    Args:
        num_ids: Number of distinct IDs to draw. If the holdings list is
            shorter than this, every available ID is returned instead of
            raising.

    Returns:
        A list with up to ``num_ids`` entries sampled from the JSON payload of
        the RCSB holdings endpoint.

    Raises:
        Exception: If the endpoint answers with a status code other than 200.
        requests.exceptions.RequestException: On connection failure or timeout.
    """
    # RCSB REST endpoint that lists every current entry ID.
    url = "https://data.rcsb.org/rest/v1/holdings/current/entry_ids"
    # A timeout keeps the notebook from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)

    if response.status_code != 200:
        raise Exception(f"Failed to get PDB IDs. Status code: {response.status_code}")

    all_pdb_ids = response.json()
    # random.sample raises if asked for more items than exist, so clamp the count.
    return random.sample(all_pdb_ids, min(num_ids, len(all_pdb_ids)))

def download_mmcif(pdb_id, output_dir):
    """Download the mmCIF file for a given PDB ID into ``output_dir``.

    The payload is first written to ``<pdb_id>.cif.part`` and only renamed to
    ``<pdb_id>.cif`` once the write has finished, so an interrupted download
    never leaves a truncated ``.cif`` behind for later processing.

    Args:
        pdb_id: PDB entry identifier used in the download URL and file name.
        output_dir: Destination directory; created if it does not exist.

    Returns:
        True when the file was downloaded and saved, False when the HTTP
        request failed (the error is printed).
    """
    url = f"https://files.rcsb.org/download/{pdb_id}.cif"

    # Create output directory if it doesn't exist
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_file = output_dir / f"{pdb_id}.cif"
    partial_file = output_dir / f"{pdb_id}.cif.part"

    try:
        # Timeout prevents an unresponsive server from blocking the whole run.
        response = requests.get(url, timeout=60)
        response.raise_for_status()  # Raise an exception for bad status codes
    except requests.exceptions.RequestException as e:
        print(f"Error downloading {pdb_id}: {str(e)}")
        return False

    with open(partial_file, 'wb') as f:
        f.write(response.content)
    # Rename only after the full payload is on disk.
    partial_file.replace(output_file)

    print(f"Successfully downloaded {pdb_id}.cif")
    return True

def main():
    """Download ten randomly chosen mmCIF files into the ``data`` directory.

    Any exception raised while listing or downloading entries is reported on
    stdout instead of propagating.
    """
    output_dir = "data"

    try:
        pdb_ids = get_random_pdb_ids(10)
        print(f"Downloading {len(pdb_ids)} mmCIF files...")

        # download_mmcif reports success as a truthy value; count those.
        successful_downloads = sum(
            1 for pdb_id in pdb_ids if download_mmcif(pdb_id, output_dir)
        )

        print(f"\nDownload complete! Successfully downloaded {successful_downloads} out of {len(pdb_ids)} files.")
        print(f"Files are saved in the '{output_dir}' directory.")
    except Exception as e:
        print(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()

Downloading 10 mmCIF files...
Successfully downloaded 4PQB.cif
Successfully downloaded 8R39.cif
Successfully downloaded 4JIT.cif
Successfully downloaded 8F5U.cif
Successfully downloaded 2Q8L.cif
Successfully downloaded 1AIP.cif
Successfully downloaded 1CJA.cif
Successfully downloaded 8XSG.cif
Successfully downloaded 1G9I.cif
Successfully downloaded 7ESG.cif

Download complete! Successfully downloaded 10 out of 10 files.
Files are saved in the 'data' directory.


In [None]:
# -----------------GPU----------------------------------
# Import statements
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import EsmModel, EsmTokenizer
from diffusers import DDPMScheduler
from torch_geometric.data import Data, Batch
from torch_geometric.nn import knn_graph
import numpy as np
import pickle
from Bio.PDB import MMCIFParser, PPBuilder
import argparse
import torch_cluster

# Placeholder for SE(3)-equivariant GNN layer
class EGNNLayer(nn.Module):
    """Stand-in for an SE(3)-equivariant layer.

    Only the scalar features are transformed (linear + ReLU); coordinates and
    the edge index pass through untouched. ``vector_net`` is registered but not
    used in ``forward``.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        # Keep creation order and attribute names: both determine the
        # state_dict layout and the parameter initialisation sequence.
        self.scalar_net = nn.Linear(in_features=in_scalar_dim, out_features=hidden_dim)
        self.vector_net = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

    def forward(self, h, x, edge_index):
        """Return ``(relu(scalar_net(h)), x)``; ``edge_index`` is accepted but unused."""
        projected = self.scalar_net(h)
        activated = torch.relu(projected)
        return activated, x

class GraphUNet(nn.Module):
    """U-Net shaped stack of EGNN layers that predicts per-atom noise.

    Two "down" layers, a bottleneck and two "up" layers with concatenated skip
    connections feed a final linear head that outputs three values (x, y, z)
    per node. Every EGNN layer runs under gradient checkpointing.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        self.down1 = EGNNLayer(in_scalar_dim, hidden_dim)
        self.down2 = EGNNLayer(hidden_dim, hidden_dim)
        self.bottleneck = EGNNLayer(hidden_dim, hidden_dim)
        self.up2 = EGNNLayer(hidden_dim * 2, hidden_dim)  # Concatenated skip connection
        self.up1 = EGNNLayer(hidden_dim * 2, hidden_dim)
        self.out = nn.Linear(hidden_dim, 3)  # Predict noise for each atom (x, y, z)

    def _run_layer(self, layer, h, x_vector, edge_index):
        """Run one EGNN layer under non-reentrant checkpointing and return its scalar output."""
        # use_reentrant=False is required here: the reentrant variant only
        # records a graph when at least one *input* requires grad, and the
        # node features fed to the first layer never do, which would leave
        # every checkpointed layer without gradients.
        h_out, _ = checkpoint(layer, h, x_vector, edge_index, use_reentrant=False)
        return h_out

    def forward(self, data):
        """Predict noise from an object exposing ``x_scalar``, ``x_vector`` and ``edge_index``."""
        x_scalar, x_vector, edge_index = data.x_scalar, data.x_vector, data.edge_index

        h1 = self._run_layer(self.down1, x_scalar, x_vector, edge_index)
        h2 = self._run_layer(self.down2, h1, x_vector, edge_index)
        h_b = self._run_layer(self.bottleneck, h2, x_vector, edge_index)
        h_up2 = self._run_layer(self.up2, torch.cat([h_b, h2], dim=1), x_vector, edge_index)
        h_up1 = self._run_layer(self.up1, torch.cat([h_up2, h1], dim=1), x_vector, edge_index)

        pred_epsilon = self.out(h_up1)
        return pred_epsilon

# Dataset class for loading preprocessed protein data
class ProteinDataset(Dataset):
    """In-memory dataset holding every ``.pkl`` sample found directly in a directory."""

    def __init__(self, data_dir):
        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        pickle_names = [name for name in os.listdir(data_dir) if name.endswith('.pkl')]
        self.data = []
        for name in pickle_names:
            with open(os.path.join(data_dir, name), 'rb') as handle:
                self.data.append(pickle.load(handle))

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

# Function to parse CIF files and extract all-atom coordinates and sequence
def _extract_peptide_atoms(cif_file):
    """Parse an mmCIF file once and collect per-atom data for its polypeptides.

    Atoms are read from the peptide residues themselves instead of re-walking
    every model/chain and matching on ``residue.get_id()``: that id is only
    (hetero flag, sequence number, insertion code), so it repeats across chains
    and models and cannot identify a residue on its own.

    Returns:
        A tuple ``(sequence, coords, valid_atoms, atom_to_residue)`` where
        ``sequence`` is the concatenated one-letter sequence of all peptides,
        ``coords`` is a float32 tensor with one xyz row per atom,
        ``valid_atoms`` holds the owning residue's ``get_id()`` per atom, and
        ``atom_to_residue`` holds, per atom, the index of its residue within
        ``sequence``.
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('protein', cif_file)
    peptides = PPBuilder().build_peptides(structure)

    sequence = ''.join(str(pp.get_sequence()) for pp in peptides)

    coords = []
    valid_atoms = []
    atom_to_residue = []
    residue_idx = 0  # running position of the residue inside `sequence`
    for pp in peptides:
        for residue in pp:
            for atom in residue:
                coords.append(atom.get_coord())
                valid_atoms.append(residue.get_id())
                atom_to_residue.append(residue_idx)
            residue_idx += 1

    # reshape keeps a well-formed (0, 3) tensor when no atoms were found.
    coords = torch.tensor(np.array(coords, dtype=np.float32).reshape(-1, 3), dtype=torch.float32)
    return sequence, coords, valid_atoms, atom_to_residue


# Function to parse CIF files and extract all-atom coordinates and sequence
def parse_cif(cif_file):
    """Return ``(sequence, coords, valid_atoms)`` for the polypeptides in ``cif_file``.

    ``sequence`` is the concatenated one-letter peptide sequence, ``coords`` a
    float32 tensor of atom coordinates and ``valid_atoms`` the residue id of the
    residue owning each atom (same length as ``coords``).
    """
    sequence, coords, valid_atoms, _ = _extract_peptide_atoms(cif_file)
    return sequence, coords, valid_atoms


# Preprocess data: compute ESM-2 embeddings and save with coordinates
def preprocess_data(data_dir):
    """Turn every ``.cif`` file in ``data_dir`` into a pickled graph sample.

    For each structure the ESM-2 per-residue embeddings are computed, expanded
    to one row per atom, and stored together with the atom coordinates and a
    6-nearest-neighbour graph as ``<name>.pkl`` next to the input. Files that
    already have a ``.pkl`` are skipped.

    Args:
        data_dir: Directory containing the ``.cif`` inputs; created if missing.
    """
    expected_dim = 640  # width the downstream model is built for (640 + 32 inputs)

    os.makedirs(data_dir, exist_ok=True)
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
    esm_model = EsmModel.from_pretrained("facebook/esm2_t30_150M_UR50D").eval()
    print(f"ESM-2 hidden size: {esm_model.config.hidden_size}")  # recorded run prints 640

    cif_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.cif')]
    if not cif_files:
        print(f"No .cif files found in {data_dir}")
        return

    for cif_file in cif_files:
        output_file = os.path.join(data_dir, os.path.basename(cif_file).replace('.cif', '.pkl'))
        if os.path.exists(output_file):
            print(f"Skipping {cif_file}, already preprocessed.")
            continue
        print(f"Processing {cif_file}")
        sequence, coords, _, atom_to_residue = _extract_peptide_atoms(cif_file)
        if not atom_to_residue:
            # Nothing to embed; skip before running the language model.
            print(f"Error: No valid atom-to-residue mappings for {cif_file}")
            continue

        inputs = tokenizer(sequence, return_tensors="pt")
        with torch.no_grad():
            outputs = esm_model(**inputs)
        # Drop the leading <cls> and trailing <eos> tokens so that row i is the
        # embedding of residue i (the inference cell slices the same way).
        embeddings = outputs.last_hidden_state[0, 1:-1]  # [seq_len, hidden]
        print(f"Embeddings shape before mapping: {embeddings.shape}")  # Debug

        if max(atom_to_residue) >= len(embeddings):
            print(f"Error: Residue index out of bounds for {cif_file}")
            continue
        embeddings = embeddings[atom_to_residue]
        print(f"Embeddings shape after mapping: {embeddings.shape}")  # [num_atoms, hidden]
        if embeddings.shape[-1] != expected_dim:
            print(f"Error: Expected 640 dimensions, got {embeddings.shape[-1]} for {cif_file}")
            continue
        data = Data(
            x_scalar=embeddings,
            x_vector=coords,
            edge_index=knn_graph(coords, k=6)
        )
        with open(output_file, 'wb') as f:
            pickle.dump(data, f)
        print(f"Saved preprocessed data to {output_file}")
# Timestep embedding for diffusion model
def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: 1-D tensor of timestep values, one per item.
        dim: Width of the returned embedding.
        max_period: Controls the lowest frequency (``1 / max_period``).

    Returns:
        Float tensor of shape ``[len(timesteps), dim]``: ``dim // 2`` cosine
        columns followed by ``dim // 2`` sine columns, plus one zero column
        when ``dim`` is odd so the width always equals ``dim``.
    """
    half = dim // 2
    freqs = torch.exp(-torch.linspace(0, 1, half, device=timesteps.device) * np.log(max_period))
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # cos/sin pairs only fill an even number of columns; pad the last one.
        embedding = torch.cat([embedding, embedding.new_zeros(embedding.shape[0], 1)], dim=-1)
    return embedding

# Training function
def train(args):
    """Train the GraphUNet noise predictor on the preprocessed protein graphs.

    For every batch one diffusion timestep is drawn per graph, the atom
    coordinates are noised with the DDPM forward process, and the model is
    trained to regress the injected noise (MSE). After each epoch a state-dict
    checkpoint is written and a TorchScript trace is attempted.

    Args:
        args: Object exposing ``data_dir`` (folder with ``.pkl`` samples),
            ``device`` (``'gpu'`` requests CUDA) and ``epochs``.
    """
    device = torch.device('cuda' if args.device == 'gpu' and torch.cuda.is_available() else 'cpu')
    # Mixed precision only makes sense when we actually ended up on CUDA, not
    # merely when it was requested.
    use_amp = device.type == 'cuda'

    dataset = ProteinDataset(args.data_dir)
    if len(dataset) == 0:
        print(f"No preprocessed .pkl files found in {args.data_dir}. Please run preprocessing first.")
        return
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=Batch.from_data_list)
    model = GraphUNet(in_scalar_dim=640 + 32, hidden_dim=256).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = GradScaler(enabled=use_amp)

    # Resolve the checkpoint location once, before training starts, instead of
    # re-importing and re-mounting Drive at the end of every epoch.
    try:
        from google.colab import drive
    except ImportError:
        checkpoint_dir = args.data_dir
        print(f"google.colab not available; saving checkpoints to {checkpoint_dir}")
    else:
        drive.mount('/content/drive')
        checkpoint_dir = '/content/drive/MyDrive/'

    for epoch in range(args.epochs):
        model.train()
        for batch in loader:
            batch = batch.to(device)
            # One timestep per graph, broadcast to that graph's atoms via batch.batch.
            t = torch.randint(0, scheduler.config.num_train_timesteps, (batch.num_graphs,), device=device)
            alpha_bar_per_atom = scheduler.alphas_cumprod[t][batch.batch]
            sqrt_alpha_bar = torch.sqrt(alpha_bar_per_atom)
            sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar_per_atom)
            noise = torch.randn_like(batch.x_vector)
            x_t = sqrt_alpha_bar[:, None] * batch.x_vector + sqrt_one_minus_alpha_bar[:, None] * noise

            t_emb_per_atom = timestep_embedding(t, dim=32).to(device)[batch.batch]
            x_scalar = torch.cat([batch.x_scalar, t_emb_per_atom], dim=1)
            data = Data(x_scalar=x_scalar, x_vector=x_t, edge_index=batch.edge_index)

            optimizer.zero_grad()
            with autocast(device_type=device.type, enabled=use_amp):
                pred_epsilon = model(data)
                loss = F.mse_loss(pred_epsilon, noise)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        # Reports the loss of the last batch of the epoch.
        print(f"Epoch {epoch + 1}/{args.epochs}, Loss: {loss.item()}")

        # Save standard checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
        jit_checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}_jit.pt")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Saved standard checkpoint to {checkpoint_path}")

        # Save TorchScript checkpoint (best effort).
        # NOTE(review): the recorded run output shows this trace failing because
        # a Data object is not a traceable input type; the failure is only
        # logged. Tracing would need a tensor-only call signature.
        model.eval()  # JIT requires eval mode
        try:
            # Create example input for tracing
            example_data = Data(
                x_scalar=torch.randn(100, 640 + 32, device=device),
                x_vector=torch.randn(100, 3, device=device),
                edge_index=torch.randint(0, 100, (2, 200), device=device)
            )
            # Trace the model
            traced_model = torch.jit.trace(model, example_data)
            traced_model.save(jit_checkpoint_path)
            print(f"Saved TorchScript checkpoint to {jit_checkpoint_path}")
        except Exception as e:
            print(f"Failed to save TorchScript checkpoint: {e}")
        model.train()
        # Set up arguments
class Args:
    """Run configuration for the GPU cell (plain attribute container)."""

    def __init__(self):
        # Where the .cif inputs and .pkl samples live, and which device to request.
        self.data_dir, self.device = 'data', 'gpu'
        self.distributed = False
        self.epochs = 100  # For testing

# Build the run configuration shared by the two steps below.
args = Args()

# Preprocess CIF files in the data directory (writes one .pkl per .cif that
# does not already have one).
preprocess_data(args.data_dir)

# Run training on the preprocessed samples.
train(args)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ESM-2 hidden size: 640
Processing data/4PQB.cif
Embeddings shape before mapping: torch.Size([155, 640])
Embeddings shape after mapping: torch.Size([1216, 640])
Saved preprocessed data to data/4PQB.pkl
Processing data/8F5U.cif
Embeddings shape before mapping: torch.Size([2074, 640])




Embeddings shape after mapping: torch.Size([15901, 640])
Saved preprocessed data to data/8F5U.pkl
Processing data/1AIP.cif
Embeddings shape before mapping: torch.Size([2271, 640])




Embeddings shape after mapping: torch.Size([17560, 640])
Saved preprocessed data to data/1AIP.pkl
Processing data/4JIT.cif
Embeddings shape before mapping: torch.Size([593, 640])




Embeddings shape after mapping: torch.Size([4511, 640])
Saved preprocessed data to data/4JIT.pkl
Processing data/2Q8L.cif
Embeddings shape before mapping: torch.Size([316, 640])
Embeddings shape after mapping: torch.Size([2593, 640])
Saved preprocessed data to data/2Q8L.pkl
Processing data/8R39.cif
Embeddings shape before mapping: torch.Size([210, 640])
Embeddings shape after mapping: torch.Size([3172, 640])
Saved preprocessed data to data/8R39.pkl
Processing data/1G9I.cif
Embeddings shape before mapping: torch.Size([247, 640])
Embeddings shape after mapping: torch.Size([1791, 640])
Saved preprocessed data to data/1G9I.pkl
Processing data/1CJA.cif




Embeddings shape before mapping: torch.Size([656, 640])




Embeddings shape after mapping: torch.Size([6240, 640])
Saved preprocessed data to data/1CJA.pkl
Processing data/7ESG.cif
Embeddings shape before mapping: torch.Size([591, 640])
Embeddings shape after mapping: torch.Size([4514, 640])
Saved preprocessed data to data/7ESG.pkl
Processing data/8XSG.cif
Embeddings shape before mapping: torch.Size([344, 640])
Embeddings shape after mapping: torch.Size([5553, 640])
Saved preprocessed data to data/8XSG.pkl




batch.x_scalar shape: torch.Size([30907, 640])
t_emb_per_atom shape: torch.Size([30907, 32])
x_scalar shape: torch.Size([30907, 672])


  return fn(*args, **kwargs)


batch.x_scalar shape: torch.Size([13071, 640])
t_emb_per_atom shape: torch.Size([13071, 32])
x_scalar shape: torch.Size([13071, 672])
batch.x_scalar shape: torch.Size([19073, 640])
t_emb_per_atom shape: torch.Size([19073, 32])
x_scalar shape: torch.Size([19073, 672])
Epoch 1/100, Loss: 1.012090802192688
Mounted at /content/drive
Saved standard checkpoint to /content/drive/MyDrive/model_epoch_1.pth
Failed to save TorchScript checkpoint: Type 'Tuple[Tuple[str, Tensor], Tuple[str, Tensor], Tuple[str, Tensor]]' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and Tuples of Tensors can be traced
batch.x_scalar shape: torch.Size([16516, 640])
t_emb_per_atom shape: torch.Size([16516, 32])
x_scalar shape: torch.Size([16516, 672])
batch.x_scalar shape: torch.Size([27184, 640])
t_emb_per_atom shape: torch.Size([27184, 32])
x_scalar shape: torch.Size([27184, 672])
batch.x_scalar shape: torch.Size([19351, 640])
t_emb_per_atom shape: torch.Size([19351, 32])
x_scalar shape: torch.S

In [None]:
#----------------------TPU (currently not working)---------------------------
import os
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import knn_graph
from transformers import EsmModel, EsmTokenizer
from diffusers import DDPMScheduler
from torch.cuda.amp import GradScaler, autocast
from Bio.PDB import PPBuilder, MMCIFParser
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
import numpy as np
import torch_cluster

# Args class
class Args:
    """Run configuration for the TPU cell (plain attribute container)."""

    def __init__(self):
        self.data_dir, self.device = "data", "tpu"  # Updated to TPU
        self.epochs = 100

# Sinusoidal timestep embedding (same formulation as the GPU training cell)
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps.

    Replaces the earlier placeholder (``sin`` over linearly spaced factors,
    duplicated) so this cell produces the same timestep features as the GPU
    training cell.

    Args:
        t: Tensor of timestep values; flattened to one value per row.
        dim: Width of the returned embedding.
        max_period: Controls the lowest frequency (``1 / max_period``).

    Returns:
        Float tensor of shape ``[t.numel(), dim]``: cosine columns, then sine
        columns, plus one zero column when ``dim`` is odd.
    """
    half = dim // 2
    freqs = torch.exp(-torch.linspace(0, 1, half, device=t.device) * np.log(max_period))
    args = t.view(-1, 1).float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # cos/sin pairs only fill an even number of columns; pad the last one.
        embedding = torch.cat([embedding, embedding.new_zeros(embedding.shape[0], 1)], dim=-1)
    return embedding

# EGNNLayer (assumed, simplified for JIT compatibility)
class EGNNLayer(nn.Module):
    """Distance-based message-passing layer.

    Scalar features are projected, each edge sends a message built from the
    source node's features and the edge length, messages are summed at the
    target node, and the sum goes through an update layer. Coordinates are
    returned unchanged.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        self.scalar_net = nn.Linear(in_scalar_dim, hidden_dim)
        # Message input is [h_source, edge length]: hidden_dim features plus ONE
        # scalar distance (the previous `hidden_dim + 3` did not match what
        # forward() concatenates and raised a shape error).
        self.message_net = nn.Linear(hidden_dim + 1, hidden_dim)
        self.update_net = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, h, x, edge_index):
        """Return ``(updated_h, x)`` for node features ``h``, coordinates ``x`` and ``edge_index``."""
        h = F.relu(self.scalar_net(h))
        row, col = edge_index
        # Edge length, kept as a column so it can be concatenated with features.
        edge_len = torch.norm(x[row] - x[col], dim=-1, keepdim=True)
        msg = F.relu(self.message_net(torch.cat([h[row], edge_len], dim=-1)))
        # Sum incoming messages at each target node (col).
        aggr = torch.zeros_like(h).scatter_add_(0, col.view(-1, 1).expand(-1, h.size(1)), msg)
        h = F.relu(self.update_net(aggr))
        return h, x

# GraphUNet with gradient checkpointing
class GraphUNet(nn.Module):
    """U-Net shaped stack of EGNN layers that predicts per-atom noise.

    Two "down" layers, a bottleneck and two "up" layers with concatenated skip
    connections feed a linear head with three outputs per node. Each EGNN layer
    runs under gradient checkpointing.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        self.down1 = EGNNLayer(in_scalar_dim, hidden_dim)
        self.down2 = EGNNLayer(hidden_dim, hidden_dim)
        self.bottleneck = EGNNLayer(hidden_dim, hidden_dim)
        self.up2 = EGNNLayer(hidden_dim * 2, hidden_dim)
        self.up1 = EGNNLayer(hidden_dim * 2, hidden_dim)
        self.out = nn.Linear(hidden_dim, 3)

    def _run_layer(self, layer, h, x_vector, edge_index):
        """Run one EGNN layer under non-reentrant checkpointing and return its scalar output."""
        # use_reentrant=False: the reentrant variant needs an input that
        # requires grad to build a graph, and the node features entering the
        # first layer never do, so the checkpointed layers would get no
        # gradients.
        # NOTE(review): on XLA devices torch_xla provides its own checkpoint
        # wrapper; confirm whether it is needed here.
        h_out, _ = checkpoint(layer, h, x_vector, edge_index, use_reentrant=False)
        return h_out

    def forward(self, data):
        """Predict noise from an object exposing ``x_scalar``, ``x_vector`` and ``edge_index``."""
        x_scalar, x_vector, edge_index = data.x_scalar, data.x_vector, data.edge_index
        h1 = self._run_layer(self.down1, x_scalar, x_vector, edge_index)
        h2 = self._run_layer(self.down2, h1, x_vector, edge_index)
        h_b = self._run_layer(self.bottleneck, h2, x_vector, edge_index)
        h_up2 = self._run_layer(self.up2, torch.cat([h_b, h2], dim=1), x_vector, edge_index)
        h_up1 = self._run_layer(self.up1, torch.cat([h_up2, h1], dim=1), x_vector, edge_index)
        pred_epsilon = self.out(h_up1)
        return pred_epsilon

# ProteinDataset (assumed, minimal for loading .pkl files)
class ProteinDataset:
    """In-memory collection of every ``.pkl`` sample found directly in ``data_dir``."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        pickle_paths = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.pkl')]
        self.data = []
        for path in pickle_paths:
            # Context manager closes each file right away instead of leaking
            # one open handle per sample.
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            with open(path, 'rb') as handle:
                self.data.append(pickle.load(handle))

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

#parse cif
def _extract_peptide_atoms(cif_file):
    """Parse an mmCIF file once and collect per-atom data for its polypeptides.

    Atoms are read from the peptide residues themselves instead of re-walking
    every model/chain and matching on ``residue.get_id()``: that id is only
    (hetero flag, sequence number, insertion code), so it repeats across chains
    and models and cannot identify a residue on its own.

    Returns:
        A tuple ``(sequence, coords, valid_atoms, atom_to_residue)`` where
        ``sequence`` is the concatenated one-letter sequence of all peptides,
        ``coords`` is a float32 tensor with one xyz row per atom,
        ``valid_atoms`` holds the owning residue's ``get_id()`` per atom, and
        ``atom_to_residue`` holds, per atom, the index of its residue within
        ``sequence``.
    """
    parser = MMCIFParser(QUIET=True)
    structure = parser.get_structure('protein', cif_file)
    peptides = PPBuilder().build_peptides(structure)

    sequence = ''.join(str(pp.get_sequence()) for pp in peptides)

    coords = []
    valid_atoms = []
    atom_to_residue = []
    residue_idx = 0  # running position of the residue inside `sequence`
    for pp in peptides:
        for residue in pp:
            for atom in residue:
                coords.append(atom.get_coord())
                valid_atoms.append(residue.get_id())
                atom_to_residue.append(residue_idx)
            residue_idx += 1

    # reshape keeps a well-formed (0, 3) tensor when no atoms were found.
    coords = torch.tensor(np.array(coords, dtype=np.float32).reshape(-1, 3), dtype=torch.float32)
    return sequence, coords, valid_atoms, atom_to_residue


def parse_cif(cif_file):
    """Return ``(sequence, coords, valid_atoms)`` for the polypeptides in ``cif_file``.

    ``sequence`` is the concatenated one-letter peptide sequence, ``coords`` a
    float32 tensor of atom coordinates and ``valid_atoms`` the residue id of the
    residue owning each atom (same length as ``coords``).
    """
    sequence, coords, valid_atoms, _ = _extract_peptide_atoms(cif_file)
    return sequence, coords, valid_atoms


# Preprocess data (outputs per-atom ESM-2 embeddings plus a kNN graph)
def preprocess_data(data_dir):
    """Turn every ``.cif`` file in ``data_dir`` into a pickled graph sample.

    For each structure the ESM-2 per-residue embeddings are computed, expanded
    to one row per atom, and stored together with the atom coordinates and a
    6-nearest-neighbour graph as ``<name>.pkl`` next to the input. Files that
    already have a ``.pkl`` are skipped.

    Args:
        data_dir: Directory containing the ``.cif`` inputs; created if missing.
    """
    os.makedirs(data_dir, exist_ok=True)
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
    esm_model = EsmModel.from_pretrained("facebook/esm2_t30_150M_UR50D", ignore_mismatched_sizes=True).eval()
    print(f"ESM-2 hidden size: {esm_model.config.hidden_size}")  # Expect 640 per user

    cif_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.cif')]
    if not cif_files:
        print(f"No .cif files found in {data_dir}")
        return

    for cif_file in cif_files:
        output_file = os.path.join(data_dir, os.path.basename(cif_file).replace('.cif', '.pkl'))
        if os.path.exists(output_file):
            print(f"Skipping {cif_file}, already preprocessed.")
            continue
        print(f"Processing {cif_file}")
        sequence, coords, _, atom_to_residue = _extract_peptide_atoms(cif_file)
        if not atom_to_residue:
            # Nothing to embed; skip before running the language model.
            print(f"Error: No valid atom-to-residue mappings for {cif_file}")
            continue

        inputs = tokenizer(sequence, return_tensors="pt")
        with torch.no_grad():
            outputs = esm_model(**inputs)
        # Drop the leading <cls> and trailing <eos> tokens so that row i is the
        # embedding of residue i (the inference cell slices the same way).
        embeddings = outputs.last_hidden_state[0, 1:-1]  # [seq_len, 640]
        print(f"Embeddings shape before mapping: {embeddings.shape}")

        if max(atom_to_residue) >= len(embeddings):
            print(f"Error: Residue index out of bounds for {cif_file}")
            continue
        embeddings = embeddings[atom_to_residue]
        print(f"Embeddings shape after mapping: {embeddings.shape}")
        data = Data(
            x_scalar=embeddings,
            x_vector=coords,
            edge_index=knn_graph(coords, k=6)
        )
        with open(output_file, 'wb') as f:
            pickle.dump(data, f)
        print(f"Saved preprocessed data to {output_file}")

# Train function with TPU, gradient checkpointing, and TorchScript
def train(args):
    """Train the GraphUNet noise predictor on an XLA (TPU) device.

    For every batch one diffusion timestep is drawn per graph, the atom
    coordinates are noised with the DDPM forward process, and the model is
    trained to regress the injected noise (MSE). After each epoch a state-dict
    checkpoint is written and a TorchScript trace is attempted.

    Args:
        args: Object exposing ``data_dir`` (folder with ``.pkl`` samples) and
            ``epochs``.
    """
    # Set up TPU device
    device = xm.xla_device()
    print(f"Using device: {device}")

    # Initialize dataset and loader
    dataset = ProteinDataset(args.data_dir)
    if len(dataset) == 0:
        print(f"No preprocessed .pkl files found in {args.data_dir}. Please run preprocessing first.")
        return
    loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=Batch.from_data_list)

    # Wrap loader for TPU
    xm_loader = pl.MpDeviceLoader(loader, device)

    # Initialize model and move to TPU
    model = GraphUNet(in_scalar_dim=640 + 32, hidden_dim=256).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Resolve the checkpoint location once, before training starts, instead of
    # re-importing and re-mounting Drive at the end of every epoch.
    try:
        from google.colab import drive
    except ImportError:
        checkpoint_dir = args.data_dir
        print(f"google.colab not available; saving checkpoints to {checkpoint_dir}")
    else:
        drive.mount('/content/drive')
        checkpoint_dir = '/content/drive/MyDrive/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(args.epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch in xm_loader:
            # Kept explicit: makes sure the whole batch object is on the device.
            batch = batch.to(device)
            # One timestep per graph, broadcast to that graph's atoms via batch.batch.
            t = torch.randint(0, scheduler.config.num_train_timesteps, (batch.num_graphs,), device=device)
            alpha_bar_per_atom = scheduler.alphas_cumprod[t][batch.batch]
            sqrt_alpha_bar = torch.sqrt(alpha_bar_per_atom)
            sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bar_per_atom)
            noise = torch.randn_like(batch.x_vector, device=device)
            x_t = sqrt_alpha_bar[:, None] * batch.x_vector + sqrt_one_minus_alpha_bar[:, None] * noise
            t_emb_per_atom = timestep_embedding(t, dim=32).to(device)[batch.batch]
            x_scalar = torch.cat([batch.x_scalar, t_emb_per_atom], dim=1)
            data = Data(x_scalar=x_scalar, x_vector=x_t, edge_index=batch.edge_index)

            optimizer.zero_grad()
            pred_epsilon = model(data)
            loss = F.mse_loss(pred_epsilon, noise)
            loss.backward()
            xm.optimizer_step(optimizer)

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{args.epochs}, Average Loss: {avg_loss:.4f}")

        # Save standard checkpoint
        checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
        jit_checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}_jit.pt")
        # xm.save transfers XLA tensors to CPU before serialising, which a
        # plain torch.save on an XLA state dict does not do.
        xm.save(model.state_dict(), checkpoint_path)
        print(f"Saved standard checkpoint to {checkpoint_path}")

        # Save TorchScript checkpoint (best effort).
        # NOTE(review): torch.jit.trace only accepts tensors and nested
        # lists/dicts/tuples of tensors, so tracing with a Data object is
        # expected to fail and only be logged. Tracing would need a tensor-only
        # call signature.
        model.eval()  # JIT requires eval mode
        try:
            # Create example input for tracing
            example_data = Data(
                x_scalar=torch.randn(100, 640 + 32, device=device),
                x_vector=torch.randn(100, 3, device=device),
                edge_index=torch.randint(0, 100, (2, 200), device=device)
            )
            # Trace the model
            traced_model = torch.jit.trace(model, example_data)
            traced_model.save(jit_checkpoint_path)
            print(f"Saved TorchScript checkpoint to {jit_checkpoint_path}")
        except Exception as e:
            print(f"Failed to save TorchScript checkpoint: {e}")
        model.train()

# Main execution
def main():
    """Entry point for the TPU cell: preprocess the mmCIF files, then train."""
    run_args = Args()
    preprocess_data(run_args.data_dir)
    train(run_args)


if __name__ == "__main__":
    main()

INFERENCE

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDPMScheduler
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph
from transformers import EsmModel, EsmTokenizer
try:
    from Bio.PDB.MMCIFIO import MMCIFIO
except ImportError:
    print("Warning: MMCIFIO not available. Falling back to PDB output.")
    MMCIFIO = None
from Bio.PDB import Structure, Model, Chain, Residue, Atom, PDBIO
from Bio.PDB.MMCIFParser import MMCIFParser

# Device setup: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Timestep embedding
def timestep_embedding(t, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps.

    Uses the same cos/sin formulation as the GPU training cell so that a
    loaded checkpoint receives the timestep features it was trained with (the
    previous ``sin`` placeholder produced different values).

    Args:
        t: Tensor of timestep values; flattened to one value per row.
        dim: Width of the returned embedding.
        max_period: Controls the lowest frequency (``1 / max_period``).

    Returns:
        Float tensor of shape ``[t.numel(), dim]``: cosine columns, then sine
        columns, plus one zero column when ``dim`` is odd.
    """
    half = dim // 2
    freqs = torch.exp(-torch.linspace(0, 1, half, device=t.device) * np.log(max_period))
    args = t.view(-1, 1).float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # cos/sin pairs only fill an even number of columns; pad the last one.
        embedding = torch.cat([embedding, embedding.new_zeros(embedding.shape[0], 1)], dim=-1)
    return embedding

# EGNNLayer
class EGNNLayer(nn.Module):
    """Scalar-only layer: linear projection plus ReLU; coordinates pass through.

    ``vector_net`` is registered (so the parameter names line up with saved
    checkpoints) but is not used in ``forward``.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        # Keep creation order and attribute names: both determine the
        # state_dict layout and the parameter initialisation sequence.
        self.scalar_net = nn.Linear(in_features=in_scalar_dim, out_features=hidden_dim)
        self.vector_net = nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

    def forward(self, h, x, edge_index):
        """Return ``(relu(scalar_net(h)), x)``; ``edge_index`` is accepted but unused."""
        projected = self.scalar_net(h)
        activated = torch.relu(projected)
        return activated, x

# GraphUNet
class GraphUNet(nn.Module):
    """Inference-time U-Net over EGNN layers (no gradient checkpointing).

    Two "down" layers, a bottleneck and two "up" layers with concatenated skip
    connections feed a linear head with three outputs per node.
    """

    def __init__(self, in_scalar_dim, hidden_dim):
        super().__init__()
        # Registration order and names are kept so state dicts stay loadable.
        self.down1 = EGNNLayer(in_scalar_dim, hidden_dim)
        self.down2 = EGNNLayer(hidden_dim, hidden_dim)
        self.bottleneck = EGNNLayer(hidden_dim, hidden_dim)
        self.up2 = EGNNLayer(hidden_dim * 2, hidden_dim)
        self.up1 = EGNNLayer(hidden_dim * 2, hidden_dim)
        self.out = nn.Linear(hidden_dim, 3)

    def forward(self, data):
        """Predict noise from an object exposing ``x_scalar``, ``x_vector`` and ``edge_index``."""
        features = data.x_scalar
        coords = data.x_vector
        edges = data.edge_index

        # Encoder path; each layer hands its coordinates to the next one.
        skip1, coords = self.down1(features, coords, edges)
        skip2, coords = self.down2(skip1, coords, edges)
        deep, coords = self.bottleneck(skip2, coords, edges)

        # Decoder path with skip connections concatenated on the feature axis.
        merged2, coords = self.up2(torch.cat([deep, skip2], dim=1), coords, edges)
        merged1, _ = self.up1(torch.cat([merged2, skip1], dim=1), coords, edges)

        return self.out(merged1)

# One-letter -> three-letter residue names; structure files expect the three-letter form.
_AA_ONE_TO_THREE = {
    'A': 'ALA', 'C': 'CYS', 'D': 'ASP', 'E': 'GLU', 'F': 'PHE',
    'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 'K': 'LYS', 'L': 'LEU',
    'M': 'MET', 'N': 'ASN', 'P': 'PRO', 'Q': 'GLN', 'R': 'ARG',
    'S': 'SER', 'T': 'THR', 'V': 'VAL', 'W': 'TRP', 'Y': 'TYR',
}


def _sibling_output_paths(output_cif_path):
    """Return the ``(.pdb, .npy)`` paths written alongside ``output_cif_path``."""
    import os

    root, ext = os.path.splitext(output_cif_path)
    # Strip only a trailing ".cif"; for any other name append instead, so the PDB
    # file can never be written on top of the CIF file.
    base = root if ext.lower() == '.cif' else output_cif_path
    return base + '.pdb', base + '.npy'


def _build_ca_structure(sequence, coords):
    """Build a single-chain, CA-only Bio.PDB structure from per-residue coordinates."""
    structure = Structure.Structure('protein')
    bio_model = Model.Model(0)
    chain = Chain.Chain('A')
    for serial, (aa, xyz) in enumerate(zip(sequence, coords), start=1):
        residue = Residue.Residue((' ', serial, ' '), _AA_ONE_TO_THREE[aa], ' ')
        # Padded fullname ' CA ' plus an explicit element: with fullname 'CA' and no
        # element, Biopython infers the element from the name and picks calcium.
        atom = Atom.Atom('CA', xyz, 0.0, 1.0, ' ', ' CA ', serial, element='C')
        residue.add(atom)
        chain.add(residue)
    bio_model.add(chain)
    structure.add(bio_model)
    return structure


# Inference function
def infer_from_sequence(sequence, checkpoint_path, output_cif_path, num_timesteps=1000):
    """Generate CA coordinates for a protein sequence and write them to disk.

    The sequence is embedded with ESM-2 (``facebook/esm2_t30_150M_UR50D``), random
    3-D coordinates are iteratively denoised with a ``GraphUNet`` loaded from
    ``checkpoint_path``, and the result is saved as CIF (when ``MMCIFIO`` is
    available), PDB, and ``.npy`` files.

    Args:
        sequence: Non-empty string of single-letter codes from ACDEFGHIKLMNPQRSTVWY.
        checkpoint_path: ``.pth`` file (state dict for ``GraphUNet``) or ``.pt`` file
            (TorchScript module). ``str`` or path-like.
        output_cif_path: Destination of the CIF file. The PDB and NPY files use the
            same path with the ``.cif`` suffix swapped for ``.pdb`` / ``.npy``.
        num_timesteps: Number of scheduler timesteps and denoising iterations.

    Returns:
        CPU tensor of shape ``(len(sequence), 3)`` with the rescaled coordinates.

    Raises:
        ValueError: On an invalid sequence, an unsupported checkpoint extension, an
            embedding/sequence length mismatch, or inf/nan final coordinates.
    """
    import os  # local import so this cell does not rely on an earlier cell's imports

    # Validate sequence
    if not sequence or not all(aa in _AA_ONE_TO_THREE for aa in sequence):
        raise ValueError("Invalid sequence. Use single-letter codes: ACDEFGHIKLMNPQRSTVWY")

    # Accept pathlib.Path as well as str for both paths.
    checkpoint_path = os.fspath(checkpoint_path)
    output_cif_path = os.fspath(output_cif_path)
    if not checkpoint_path.endswith(('.pth', '.pt')):
        raise ValueError("Checkpoint must be .pth or .pt")

    # Initialize model and scheduler (640 ESM features + 32 timestep features)
    net = GraphUNet(in_scalar_dim=640 + 32, hidden_dim=256).to(device)
    scheduler = DDPMScheduler(num_train_timesteps=num_timesteps)

    # Load checkpoint
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    if checkpoint_path.endswith('.pth'):
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        net = torch.jit.load(checkpoint_path, map_location=device)
    net.eval()

    # Generate ESM-2 embeddings
    tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t30_150M_UR50D")
    esm_model = EsmModel.from_pretrained("facebook/esm2_t30_150M_UR50D").to(device).eval()
    inputs = tokenizer(sequence, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = esm_model(**inputs)
    # Drop the first and last token positions so one row remains per residue.
    x_scalar = outputs.last_hidden_state[0, 1:-1, :]  # [seq_len, 640]
    num_residues = x_scalar.size(0)
    if num_residues != len(sequence):
        raise ValueError(f"Embedding size {num_residues} does not match sequence length {len(sequence)}")

    # Initialize noisy coordinates
    x_t = torch.randn(num_residues, 3, device=device)
    edge_index = knn_graph(x_t, k=6).to(device)

    # Denoising loop
    with torch.no_grad():
        for t in reversed(range(num_timesteps)):
            t_tensor = torch.full((1,), t, dtype=torch.long, device=device)
            t_emb = timestep_embedding(t_tensor, dim=32)
            t_emb_per_atom = t_emb.repeat(num_residues, 1)
            x_scalar_t = torch.cat([x_scalar, t_emb_per_atom], dim=1)

            # Rebuild the neighbour graph only every 200 steps.
            if t % 200 == 0:
                edge_index = knn_graph(x_t, k=6).to(device)

            data_t = Data(x_scalar=x_scalar_t, x_vector=x_t, edge_index=edge_index)
            pred_epsilon = net(data_t)

            # Clip pred_epsilon to prevent explosions
            pred_epsilon = torch.clamp(pred_epsilon, -10.0, 10.0)

            noise_scale = torch.sqrt(1 - scheduler.alphas_cumprod[t]).to(device)
            signal_scale = torch.sqrt(scheduler.alphas_cumprod[t]).to(device)
            # NOTE(review): this assigns (x_t - sqrt(1-a)*eps) / sqrt(a) directly and
            # never calls scheduler.step(); confirm this matches how the model was trained.
            x_t = (x_t - noise_scale * pred_epsilon) / signal_scale if t > 0 else x_t - pred_epsilon

            # Normalize x_t to prevent divergence
            # NOTE(review): this rescales every residue's position vector to length 10,
            # i.e. all points end each step at the same distance from the origin.
            x_t = x_t / (x_t.norm(dim=-1, keepdim=True) + 1e-8) * 10.0

            # Check for inf/nan
            if torch.isnan(x_t).any() or torch.isinf(x_t).any():
                print(f"Warning: inf/nan detected at step {t}")
                break

            if t % 100 == 0:
                print(f"Denoising step {t}/{num_timesteps}, x_t mean: {x_t.mean().item():.4f}")

    # Final coordinates
    pred_coords = x_t.cpu()

    # Validate coordinates
    if torch.isnan(pred_coords).any() or torch.isinf(pred_coords).any():
        raise ValueError("Invalid coordinates: inf/nan detected")

    # Centre the coordinates and rescale each axis to a standard deviation of 3.8.
    centered = pred_coords - pred_coords.mean(dim=0)
    if num_residues > 1:
        pred_coords = centered / (pred_coords.std(dim=0) + 1e-8) * 3.8
    else:
        # The sample std of a single point is NaN; keep the lone residue at the origin.
        pred_coords = centered

    # Create structure
    structure = _build_ca_structure(sequence, pred_coords.numpy())

    # Save outputs
    output_pdb_path, output_npy_path = _sibling_output_paths(output_cif_path)

    if MMCIFIO is not None:
        # Best-effort: a CIF failure is reported but does not stop the PDB/NPY output.
        try:
            io = MMCIFIO()
            io.set_structure(structure)
            io.save(output_cif_path)
            print(f"Saved predicted structure to {output_cif_path}")
        except Exception as e:
            print(f"Failed to save CIF: {e}")
    else:
        print("MMCIFIO unavailable, skipping CIF output.")

    try:
        io = PDBIO()
        io.set_structure(structure)
        io.save(output_pdb_path)
        print(f"Saved predicted structure to {output_pdb_path}")
    except Exception as e:
        print(f"Failed to save PDB: {e}")

    np.save(output_npy_path, pred_coords.numpy())
    print(f"Saved raw coordinates to {output_npy_path}")

    return pred_coords

# Example usage: run inference on a 60-residue test sequence with a saved checkpoint.
sequence = "CDAFVGTWKLVSSENFDDYMKEVGVGFATRKVAGMAKPNMIISVNGDLVTIRSESTFKNT"  # 60 residues
checkpoint_path = "/content/drive/MyDrive/model_epoch_100.pth"
output_cif_path = "./predicted_structure2.cif"
# Writes the structure/coordinate files as a side effect and keeps the coordinates.
pred_coords = infer_from_sequence(
    sequence=sequence,
    checkpoint_path=checkpoint_path,
    output_cif_path=output_cif_path,
)

Using device: cuda


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Denoising step 900/1000, x_t mean: -0.4875
Denoising step 800/1000, x_t mean: -0.4984
Denoising step 700/1000, x_t mean: -0.5484
Denoising step 600/1000, x_t mean: -0.5607
Denoising step 500/1000, x_t mean: -0.6061
Denoising step 400/1000, x_t mean: -0.6205
Denoising step 300/1000, x_t mean: -0.6081
Denoising step 200/1000, x_t mean: -0.5845
Denoising step 100/1000, x_t mean: -0.5853
Denoising step 0/1000, x_t mean: -0.5809
MMCIFIO unavailable, skipping CIF output.
Saved predicted structure to ./predicted_structure2.pdb
Saved raw coordinates to ./predicted_structure2.npy


In [None]:
# Shell command: delete every .pkl file directly inside the data/ directory.
!rm data/*.pkl

In [None]:
# Shell command: recursively delete ~/.cache/huggingface/ without prompting.
# NOTE(review): presumably this forces models to be downloaded again on the next load; confirm before running.
!rm -rf ~/.cache/huggingface/