In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import gymnasium as gym
from gymnasium import spaces
from collections import deque
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:
class ProteinEnergy:
    """Calculate energy terms for a chain-of-beads protein structure.

    Every method takes ``coords``: an ordered sequence of position vectors
    (anything ``np.asarray`` turns into an ``(n, d)`` float array, e.g. a list
    of 3-D numpy arrays or an ``(n, 3)`` array). Consecutive entries are
    treated as bonded. All energies are returned as Python floats in
    arbitrary units.
    """

    def __init__(self):
        # Physical constants (in arbitrary units for simplicity)
        self.bond_length = 3.8  # Angstroms, C-alpha distance
        self.bond_k = 100.0  # Spring constant
        self.angle_k = 50.0  # Angle spring constant
        self.vdw_epsilon = 1.0  # Van der Waals depth
        self.vdw_sigma = 4.0  # Van der Waals radius
        # Pair distances below this are clamped before evaluating the
        # Lennard-Jones term: avoids division by zero while keeping deeper
        # overlaps at least as costly as shallower ones.
        self.vdw_min_dist = 0.1
        # Steric-clash penalty: non-bonded pairs closer than clash_min_dist
        # cost clash_k * (clash_min_dist - dist)**2.
        self.clash_min_dist = 2.5
        self.clash_k = 100.0

    def bond_energy(self, coords):
        """Harmonic energy from bond-length deviations.

        Args:
            coords: ordered positions; entry i is bonded to entry i + 1.

        Returns:
            Sum over consecutive pairs of
            ``0.5 * bond_k * (dist - bond_length)**2``; 0.0 for fewer than
            two positions.
        """
        if len(coords) < 2:
            return 0.0
        xyz = np.asarray(coords, dtype=float)
        lengths = np.linalg.norm(np.diff(xyz, axis=0), axis=1)
        deviation = lengths - self.bond_length
        return float(0.5 * self.bond_k * np.sum(deviation ** 2))

    def angle_energy(self, coords):
        """Harmonic energy from bond-angle deviations around 110 degrees.

        The angle at interior position i + 1 is the one formed by positions
        i, i + 1 and i + 2.

        Args:
            coords: ordered positions.

        Returns:
            Sum of ``0.5 * angle_k * (angle - ideal)**2`` over all interior
            positions; 0.0 for fewer than three positions.
        """
        if len(coords) < 3:
            return 0.0
        ideal_angle = np.pi * 110 / 180  # 110 degrees in radians
        xyz = np.asarray(coords, dtype=float)
        centre = xyz[1:-1]
        v1 = xyz[:-2] - centre
        v2 = xyz[2:] - centre
        # The 1e-8 keeps the division finite for zero-length vectors
        # (coincident consecutive positions).
        cos_angle = np.sum(v1 * v2, axis=1) / (
            np.linalg.norm(v1, axis=1) * np.linalg.norm(v2, axis=1) + 1e-8
        )
        # Clip rounding error that could push arccos outside [-1, 1].
        angles = np.arccos(np.clip(cos_angle, -1.0, 1.0))
        return float(0.5 * self.angle_k * np.sum((angles - ideal_angle) ** 2))

    def vdw_energy(self, coords):
        """Van der Waals energy (Lennard-Jones 12-6 potential).

        Only pairs (i, j) with ``j - i >= 3`` are included; residues closer
        along the chain are skipped. Distances below ``vdw_min_dist`` are
        clamped to it, so overlapping positions receive the (very large)
        energy at that distance instead of a division by zero.

        Args:
            coords: ordered positions.

        Returns:
            Sum of ``4 * epsilon * ((sigma/r)**12 - (sigma/r)**6)`` over the
            included pairs; 0.0 when there are none.
        """
        n = len(coords)
        if n < 4:
            return 0.0
        xyz = np.asarray(coords, dtype=float)
        i, j = np.triu_indices(n, k=3)  # every index pair with j - i >= 3
        dists = np.linalg.norm(xyz[i] - xyz[j], axis=1)
        # Clamp instead of returning a fixed penalty: a flat value would be
        # far below the LJ energy at slightly larger distances, making full
        # overlap energetically favourable.
        dists = np.maximum(dists, self.vdw_min_dist)
        r6 = (self.vdw_sigma / dists) ** 6
        return float(np.sum(4 * self.vdw_epsilon * (r6 ** 2 - r6)))

    def clash_penalty(self, coords):
        """Penalty for non-bonded positions that are too close (steric clashes).

        Every pair (i, j) with ``j - i >= 2`` is checked; directly bonded
        neighbours are excluded.

        Args:
            coords: ordered positions.

        Returns:
            Sum of ``clash_k * (clash_min_dist - dist)**2`` over pairs closer
            than ``clash_min_dist``; 0.0 for fewer than three positions.
        """
        n = len(coords)
        if n < 3:
            return 0.0
        xyz = np.asarray(coords, dtype=float)
        i, j = np.triu_indices(n, k=2)  # every index pair with j - i >= 2
        dists = np.linalg.norm(xyz[i] - xyz[j], axis=1)
        overlap = np.clip(self.clash_min_dist - dists, 0.0, None)
        return float(self.clash_k * np.sum(overlap ** 2))

    def total_energy(self, coords):
        """Calculate total energy of the structure.

        Returns:
            Sum of bond, angle, van der Waals and clash terms.
        """
        return (self.bond_energy(coords) +
                self.angle_energy(coords) +
                self.vdw_energy(coords) +
                self.clash_penalty(coords))
