Skip to content

Commit

Permalink
Merge pull request #73 from RMeli/gsoc19/flex
Browse files Browse the repository at this point in the history
Optimisation of flexible side chains
  • Loading branch information
dkoes committed Aug 21, 2019
2 parents 1e2d49d + 42bc72c commit 86c3c10
Show file tree
Hide file tree
Showing 14 changed files with 7,219 additions and 67 deletions.
6 changes: 3 additions & 3 deletions caffe/include/caffe/layers/molgrid_data_layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ class MolGridDataLayer : public BaseDataLayer<Dtype> {
void getMappedLigandGradient(int batch_idx, std::unordered_map<string, gfloat3>& gradient);
void getMappedLigandRelevance(int batch_idx, std::unordered_map<string, float>& relevance);

virtual void setReceptor(const vector<atom>& receptor, const vec& translate =
virtual void setReceptor(const vector<float3>& coords, const vector<smt>& smtypes, const vec& translate =
{}, const qt& rotate = {});
virtual void setLigand(const vector<atom>& ligand, const vector<vec>& coords,
bool calcCenter=true);
virtual void setLigand(const vector<float3>& coords, const vector<smt>& smtypes,
bool calcCenter = true);

//set center to use for memory ligand
void setGridCenter(const vec& center) {
Expand Down
42 changes: 18 additions & 24 deletions caffe/src/caffe/layers/molgrid_data_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1039,21 +1039,20 @@ void MolGridDataLayer<Dtype>::copyToBlobs(const vector<Blob<Dtype>*>& top, bool
//set in memory buffer
//will apply translate and rotate iff rotate is valid
template <typename Dtype>
void MolGridDataLayer<Dtype>::setReceptor(const vector<atom>& receptor, const vec& translate, const qt& rotate ) {
void MolGridDataLayer<Dtype>::setReceptor(const vector<float3>& coords, const vector<smt>& smtypes, const vec& translate, const qt& rotate ) {
CHECK_GT(batch_info.size(), 0) << "Empty batch info in setReceptor";

vector<float3> cs; cs.reserve(receptor.size());
vector<float> types; types.reserve(receptor.size());
vector<float> radii; radii.reserve(receptor.size());
vector<float> types; types.reserve(smtypes.size());
vector<float> radii; radii.reserve(smtypes.size());

CHECK_EQ(coords.size(), smtypes.size()) << "Size mismatch between receptor coords and smtypes";


//receptor atoms
for (unsigned i = 0, n = receptor.size(); i < n; i++) {
const atom& a = receptor[i];
smt origt = a.sm;
for (unsigned i = 0, n = smtypes.size(); i < n; i++) {
smt origt = smtypes[i];
auto t_r = recTypes->get_int_type(origt);
int t = t_r.first;
const vec& coord = a.coords;
cs.push_back(gfloat3(coord[0],coord[1],coord[2]));
types.push_back(t);
radii.push_back(t_r.second);

Expand All @@ -1062,7 +1061,7 @@ void MolGridDataLayer<Dtype>::setReceptor(const vector<atom>& receptor, const ve
}
}

CoordinateSet rec(cs, types, radii, recTypes->num_types());
CoordinateSet rec(coords, types, radii, recTypes->num_types());
if(rotate.real() != 0) {
//apply transformation
vec c = getGridCenter();
Expand All @@ -1079,34 +1078,29 @@ void MolGridDataLayer<Dtype>::setReceptor(const vector<atom>& receptor, const ve

//set in memory buffer, will set grid_Center if it isn't set, but will only overwrite set grid_center if calcCenter
template <typename Dtype>
void MolGridDataLayer<Dtype>::setLigand(const vector<atom>& ligand, const vector<vec>& coords, bool calcCenter) {
void MolGridDataLayer<Dtype>::setLigand(const vector<float3>& coords, const vector<smt>& smtypes, bool calcCenter) {

CHECK_GT(batch_info.size(), 0) << "Empty batch info in setLigand";

vector<float3> cs; cs.reserve(ligand.size());
vector<float> types; types.reserve(ligand.size());
vector<float> radii; radii.reserve(ligand.size());
vector<float> types; types.reserve(coords.size());
vector<float> radii; radii.reserve(coords.size());

CHECK_EQ(coords.size(), smtypes.size()) << "Size mismatch between ligand coords and smtypes";

CHECK_EQ(ligand.size(), coords.size()) << "Size mismatch between ligand atoms and coords";
//ligand atoms, grid positions offset and coordinates are specified separately
vec center(0, 0, 0);
for (unsigned i = 0, n = ligand.size(); i < n; i++) {
smt origt = ligand[i].sm;
for (unsigned i = 0, n = smtypes.size(); i < n; i++) {
smt origt = smtypes[i];
auto t_r = ligTypes->get_int_type(origt);
int t = t_r.first;
float r = t_r.second;
//unsupported types are kept around with t == -1
const vec& coord = coords[i];
cs.push_back(gfloat3(coord[0],coord[1],coord[2]));
types.push_back(t);
radii.push_back(r);
radii.push_back(t_r.second);

if(t < 0 && origt > 1) { //don't warn about hydrogens
std::cerr << "Unsupported ligand atom type " << GninaIndexTyper::gnina_type_name(origt) << "\n";
}
}

CoordinateSet ligatoms(cs, types, radii, ligTypes->num_types());
CoordinateSet ligatoms(coords, types, radii, ligTypes->num_types());
batch_info[0].setLigand(ligatoms);

if (calcCenter || !isfinite(grid_center[0])) {
Expand Down
217 changes: 206 additions & 11 deletions gninasrc/lib/cnn_scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,12 @@ void CNNScorer::lrp(const model& m, const string& layer_to_ignore,
boost::lock_guard<boost::recursive_mutex> guard(*mtx);

caffe::Caffe::set_random_seed(cnnopts.seed); //same random rotations for each ligand..

setLigand(m);
setReceptor(m);

mgrid->setReceptor(m.get_fixed_atoms());
mgrid->setLigand(m.get_movable_atoms(), m.coordinates());
mgrid->setReceptor(receptor_coords, receptor_smtypes);
mgrid->setLigand(ligand_coords, ligand_smtypes);
mgrid->setLabels(1); //for now pose optimization only
mgrid->enableLigandGradients();
mgrid->enableReceptorGradients();
Expand All @@ -194,8 +197,11 @@ void CNNScorer::gradient_setup(const model& m, const string& recname,

caffe::Caffe::set_random_seed(cnnopts.seed); //same random rotations for each ligand..

mgrid->setReceptor(m.get_fixed_atoms());
mgrid->setLigand(m.get_movable_atoms(), m.coordinates());
setLigand(m);
setReceptor(m);

mgrid->setReceptor(receptor_coords, receptor_smtypes);
mgrid->setLigand(ligand_coords, ligand_smtypes);
mgrid->setLabels(1); //for now pose optimization only
mgrid->enableLigandGradients();
mgrid->enableReceptorGradients();
Expand Down Expand Up @@ -295,6 +301,170 @@ void CNNScorer::get_net_output(Dtype& score, Dtype& aff, Dtype& loss) {

loss = lossblob->cpu_data()[0];
}

// Extract ligand atoms and coordinates
void CNNScorer::setLigand(const model& m){

// Ligand atoms
sz n = std::distance(
m.get_movable_atoms().cbegin() + m.ligands[0].node.begin,
m.get_movable_atoms().cbegin() + m.m_num_movable_atoms
);

// Get ligand types and radii
if (ligand_smtypes.size() == 0){ // Do once at setup
ligand_coords.reserve(n);

auto cbegin = m.get_movable_atoms().cbegin() + m.ligands[0].node.begin;
auto cend = m.get_movable_atoms().cbegin() + m.m_num_movable_atoms;

for(auto it = cbegin; it != cend; ++it){
smt origt = it->sm; // Original smina type
ligand_smtypes.push_back(origt);
}
}

// Coordinates need to be updated at every call
auto cbegin = m.coordinates().cbegin() + m.ligands[0].node.begin;
auto cend = m.coordinates().cbegin() + m.m_num_movable_atoms;
if(ligand_coords.size() == 0){
ligand_smtypes.reserve(n);

// Back insertion on first call
std::transform(cbegin, cend, std::back_inserter(ligand_coords),
[](const vec& coord) -> float3 {return float3({coord[0], coord[1], coord[2]});}
);
}else if(ligand_coords.size() == n){
// Update coordinates (transforming from vec to float3)
std::transform(cbegin, cend, ligand_coords.begin(),
[](const vec& coord) -> float3 {return float3({coord[0], coord[1], coord[2]});}
);
}

// Check final size
CHECK_EQ(ligand_smtypes.size(), n);
CHECK_EQ(ligand_coords.size(), n);
}

// Extracts receptor atoms and coordinates
// Flex and inflex coordinates are taken from the model's movable atoms
// Flex coordinates are stored at the beginning, then inflex, then fixed
void CNNScorer::setReceptor(const model& m){

// Number of receptor movable atoms
num_flex_atoms = std::distance(
m.get_movable_atoms().cbegin(),
m.get_movable_atoms().cbegin() + m.ligands[0].node.begin
);

// Number of inflex atoms
sz n_inflex = std::distance(
m.get_movable_atoms().cbegin() + m.m_num_movable_atoms,
m.get_movable_atoms().cend()
);

// Number of fixed receptor atoms
sz n_rigid = m.get_fixed_atoms().size();

// Total receptor size
sz n = num_flex_atoms + n_inflex + n_rigid;

if(receptor_smtypes.size() == 0){ // Do once at setup

receptor_smtypes.reserve(n);

// Insert flexible residues movable atoms
auto cbegin = m.get_movable_atoms().cbegin();
auto cend = m.get_movable_atoms().cbegin() + m.ligands[0].node.begin;
for(auto it = cbegin; it != cend; ++it){
smt origt = it->sm; // Original smina type
receptor_smtypes.push_back(origt);
}

CHECK_EQ(receptor_smtypes.size(), num_flex_atoms);

// Insert inflex atoms
cbegin = m.get_movable_atoms().cbegin() + m.m_num_movable_atoms;
cend = m.get_movable_atoms().cend();
for(auto it = cbegin; it != cend; ++it){
smt origt = it->sm; // Original smina type
receptor_smtypes.push_back(origt);
}

CHECK_EQ(receptor_smtypes.size(), num_flex_atoms + n_inflex);

// Insert fixed receptor atoms
cbegin = m.get_fixed_atoms().cbegin();
cend = m.get_fixed_atoms().cend();
for(auto it = cbegin; it != cend; ++it){
smt origt = it->sm; // Original smina type
receptor_smtypes.push_back(origt);
}
}

if(receptor_coords.size() == 0 ){ // Do once at setup

// Reserve memory, but size() == 0
receptor_coords.reserve(n);

// Append flex
auto cbegin_flex = m.coordinates().cbegin();
auto cend_flex = m.coordinates().cbegin() + num_flex_atoms;
std::transform(cbegin_flex, cend_flex, std::back_inserter(receptor_coords),
[](const vec& coord) -> float3 {return float3({coord[0], coord[1], coord[2]});}
);

// Append inflex
auto cbegin_inflex = m.coordinates().cbegin() + m.m_num_movable_atoms;
auto cend_inflex = m.coordinates().cend();
std::transform(cbegin_inflex, cend_inflex, std::back_inserter(receptor_coords),
[](const vec& coord) -> float3 {return float3({coord[0], coord[1], coord[2]});}
);

// Append rigid receptor
auto cbegin_rigid = m.get_fixed_atoms().cbegin();
auto cend_rigid = m.get_fixed_atoms().cend();
std::transform(cbegin_rigid, cend_rigid, std::back_inserter(receptor_coords),
[](const atom& a) -> float3 {
const vec& coord = a.coords;
return float3({coord[0], coord[1], coord[2]});
}
);

}
else if(receptor_coords.size() == n){ // Update flex coordinates at every call
auto cbegin = m.coordinates().cbegin();
auto cend = m.coordinates().cbegin() + num_flex_atoms;
std::transform(cbegin, cend, receptor_coords.begin(),
[](const vec& coord) -> float3 {return float3({coord[0], coord[1], coord[2]});}
);
}

// Check final size
CHECK_EQ(receptor_smtypes.size(), n);
CHECK_EQ(receptor_coords.size(), n);
}

// Get ligand (and flexible receptor) gradient
void CNNScorer::getGradient(){
gradient.reserve(ligand_coords.size() + num_flex_atoms);

// Get ligand gradient
mgrid->getLigandGradient(0, gradient);

// Get receptor gradient
std::vector<gfloat3> gradient_rec;
if (num_flex_atoms != 0) { // Optimization of flexible residues
mgrid->getReceptorGradient(0, gradient_rec);
}

// Merge ligand and flexible residues gradient
// Flexible residues, if any, come first
gradient.insert(gradient.begin(), gradient_rec.cbegin(), gradient_rec.cbegin() + num_flex_atoms);

CHECK_EQ(gradient.size(), ligand_coords.size() + num_flex_atoms);
}

//return score of model, assumes receptor has not changed from initialization
//also sets affinity (if available) and loss (for use with minimization)
//if compute_gradient is set, also adds cnn atom gradient to m.minus_forces
Expand All @@ -314,11 +484,26 @@ float CNNScorer::score(model& m, bool compute_gradient, float& affinity,
mgrid->setGridCenter(current_center);
}

mgrid->setLigand(m.get_movable_atoms(), m.coordinates(), cnnopts.move_minimize_frame);
// Get ligand atoms and coords from movable atoms
setLigand(m);

// Get receptor atoms and flex/inflex coordinats from movable atoms
setReceptor(m);

// Checks
if(num_flex_atoms == 0){ // No flexible residues
CHECK_EQ(ligand_coords.size(), m.coordinates().size());
CHECK_EQ(receptor_coords.size(), m.get_fixed_atoms().size());
}
CHECK_EQ(num_flex_atoms + ligand_coords.size(), m.m_num_movable_atoms);

mgrid->setLigand(ligand_coords, ligand_smtypes, cnnopts.move_minimize_frame);

if (!cnnopts.move_minimize_frame) { //if fixed_receptor, rec_conf will be identify
mgrid->setReceptor(m.get_fixed_atoms(), m.rec_conf.position, m.rec_conf.orientation);
} else { //don't move receptor
mgrid->setReceptor(m.get_fixed_atoms());
mgrid->setReceptor(receptor_coords, receptor_smtypes, m.rec_conf.position, m.rec_conf.orientation);
}
else { //don't move receptor
mgrid->setReceptor(receptor_coords, receptor_smtypes);
current_center = mgrid->getGridCenter(); //has been recalculated from ligand
if(cnnopts.verbose) {
std::cout << "current center: ";
Expand All @@ -329,8 +514,12 @@ float CNNScorer::score(model& m, bool compute_gradient, float& affinity,

if(compute_gradient || cnnopts.outputxyz) {
mgrid->enableLigandGradients();
if(cnnopts.moving_receptor() || cnnopts.outputxyz)
if(cnnopts.moving_receptor() || cnnopts.outputxyz){
mgrid->enableReceptorGradients();
}
else if(num_flex_atoms != 0){
mgrid->enableReceptorGradients(); // rmeli: TODO flexres gradients only
}
}

m.clear_minus_forces();
Expand Down Expand Up @@ -358,8 +547,14 @@ float CNNScorer::score(model& m, bool compute_gradient, float& affinity,
if (compute_gradient || cnnopts.outputxyz) {

net->Backward();
mgrid->getLigandGradient(0, gradient);

// Get gradient from mgrid into CNNScorer::gradient
getGradient();

// Update ligand (and flexible residues) gradient
m.add_minus_forces(gradient);

// Gradient for rigid receptor transformation: translation and torque
if(cnnopts.moving_receptor())
mgrid->getReceptorTransformationGradient(0, m.rec_change.position, m.rec_change.orientation);
}
Expand All @@ -375,7 +570,7 @@ float CNNScorer::score(model& m, bool compute_gradient, float& affinity,
mgrid->getLigandChannels(0, channels);
outputXYZ(ligname, atoms, channels, gradient);

mgrid->getReceptorGradient(0, gradient);
mgrid->getReceptorGradient(0, gradient); // rmeli: TODO Full gradient or just flexres?
mgrid->getReceptorAtoms(0, atoms);
mgrid->getReceptorChannels(0, channels);
outputXYZ(recname, atoms, channels, gradient);
Expand Down
12 changes: 12 additions & 0 deletions gninasrc/lib/cnn_scorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ class CNNScorer {
std::vector<short> channels;
vec current_center; //center last time set_center was called, if min frame is moving, the mgrid center will be changing

// Receptor and ligand information
std::vector<float3> ligand_coords, receptor_coords;
std::vector<smt> ligand_smtypes, receptor_smtypes;

std::size_t num_flex_atoms; // Number of flexible atoms

// Set ligand and receptor atoms and coordinates from model
void setLigand(const model& m);
void setReceptor(const model& m);

void getGradient();

public:
CNNScorer()
: mgrid(NULL), mtx(new boost::recursive_mutex), current_center(NAN,NAN,NAN) {
Expand Down
Loading

0 comments on commit 86c3c10

Please sign in to comment.