Skip to content

Commit

Permalink
Pass kwargs to RDKit reading/writing functions
Browse files: browse the repository at this point in the history
  • Loading branch information
mwojcikowski committed Nov 13, 2020
1 parent 8745590 commit efe25b2
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions oddt/toolkits/rdk.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -118,7 +118,7 @@
"""A list of supported forcefields"""


def _filereader_mol2(filename, lazy=False):
def _filereader_mol2(filename, lazy=False, **kwargs):
block = ''
data = ''
n = 0
Expand All @@ -132,7 +132,7 @@ def _filereader_mol2(filename, lazy=False):
if lazy:
yield Molecule(source={'fmt': 'mol2', 'string': block})
else:
yield readstring('mol2', block)
yield readstring('mol2', block, **kwargs)
n += 1
block = data
data = ''
Expand All @@ -142,10 +142,10 @@ def _filereader_mol2(filename, lazy=False):
if lazy:
yield Molecule(source={'fmt': 'mol2', 'string': block})
else:
yield readstring('mol2', block)
yield readstring('mol2', block, **kwargs)


def _filereader_sdf(filename, lazy=False):
def _filereader_sdf(filename, lazy=False, **kwargs):
block = ''
n = 0
with gzip.open(filename, 'rb') if filename.split('.')[-1] == 'gz' else open(filename, 'rb') as f:
Expand All @@ -160,11 +160,11 @@ def _filereader_sdf(filename, lazy=False):
if block: # open last molecule if any
yield Molecule(source={'fmt': 'sdf', 'string': block})
else:
for mol in Chem.ForwardSDMolSupplier(f):
for mol in Chem.ForwardSDMolSupplier(f, **kwargs):
yield Molecule(mol)


def _filereader_pdb(filename, lazy=False, opt=None):
def _filereader_pdb(filename, lazy=False, opt=None, **kwargs):
block = ''
n = 0
with gzip.open(filename, 'rb') if filename.split('.')[-1] == 'gz' else open(filename, 'rb') as f:
Expand All @@ -175,17 +175,17 @@ def _filereader_pdb(filename, lazy=False, opt=None):
if lazy:
yield Molecule(source={'fmt': 'pdb', 'string': block, 'opt': opt})
else:
yield readstring('pdb', block)
yield readstring('pdb', block, **kwargs)
n += 1
block = ''
if block: # open last molecule if any
if lazy:
yield Molecule(source={'fmt': 'pdb', 'string': block, 'opt': opt})
else:
yield readstring('pdb', block)
yield readstring('pdb', block, **kwargs)


def _filereader_pdbqt(filename, lazy=False, opt=None):
def _filereader_pdbqt(filename, lazy=False, opt=None, **kwargs):
block = ''
n = 0
with gzip.open(filename, 'rb') if filename.split('.')[-1] == 'gz' else open(filename, 'rb') as f:
Expand All @@ -196,14 +196,14 @@ def _filereader_pdbqt(filename, lazy=False, opt=None):
if lazy:
yield Molecule(source={'fmt': 'pdbqt', 'string': block, 'opt': opt})
else:
yield readstring('pdbqt', block)
yield readstring('pdbqt', block, **kwargs)
n += 1
block = ''
if block: # open last molecule if any
if lazy:
yield Molecule(source={'fmt': 'pdbqt', 'string': block, 'opt': opt})
else:
yield readstring('pdbqt', block)
yield readstring('pdbqt', block, **kwargs)


def readfile(format, filename, lazy=False, opt=None, **kwargs):
Expand Down Expand Up @@ -237,13 +237,13 @@ def readfile(format, filename, lazy=False, opt=None, **kwargs):
# errors in the format and errors in opening the file.
# Then switch to an iterator...
if format in ["sdf", "mol"]:
return _filereader_sdf(filename, lazy=lazy)
return _filereader_sdf(filename, lazy=lazy, **kwargs)
elif format == "pdb":
return _filereader_pdb(filename, lazy=lazy)
return _filereader_pdb(filename, lazy=lazy, **kwargs)
elif format == "pdbqt":
return _filereader_pdbqt(filename, lazy=lazy)
return _filereader_pdbqt(filename, lazy=lazy, **kwargs)
elif format == "mol2":
return _filereader_mol2(filename, lazy=lazy)
return _filereader_mol2(filename, lazy=lazy, **kwargs)
elif format == "smi":
iterator = Chem.SmilesMolSupplier(filename, delimiter=" \t",
titleLine=False, **kwargs)
Expand Down Expand Up @@ -279,7 +279,7 @@ def readstring(format, string, **kwargs):
string = str(string)
format = format.lower()
if format in ["mol", "sdf"]:
supplier = Chem.SDMolSupplier()
supplier = Chem.SDMolSupplier(**kwargs)
supplier.SetData(string)
mol = next(supplier)
del supplier
Expand Down Expand Up @@ -317,15 +317,15 @@ class Outputfile(object):
write(molecule)
close()
"""
def __init__(self, format, filename, overwrite=False):
def __init__(self, format, filename, overwrite=False, **kwargs):
self.format = format
self.filename = filename
if not overwrite and os.path.isfile(self.filename):
raise IOError("%s already exists. Use 'overwrite=True' to overwrite it." % self.filename)
if format == "sdf":
self._writer = Chem.SDWriter(self.filename)
self._writer = Chem.SDWriter(self.filename, **kwargs)
elif format == "smi":
self._writer = Chem.SmilesWriter(self.filename, isomericSmiles=True, includeHeader=False)
self._writer = Chem.SmilesWriter(self.filename, isomericSmiles=True, includeHeader=False, **kwargs)
elif format in ('inchi', 'inchikey') and Chem.INCHI_AVAILABLE:
self._writer = open(filename, 'w')
elif format in ('mol2', 'pdbqt'):
Expand All @@ -335,6 +335,7 @@ def __init__(self, format, filename, overwrite=False):
else:
raise ValueError("%s is not a recognised RDKit format" % format)
self.total = 0 # The total number of molecules written to the file
self.writer_kwargs = kwargs

def write(self, molecule):
"""Write a molecule to the output file.
Expand All @@ -345,10 +346,10 @@ def write(self, molecule):
if not self.filename:
raise IOError("Outputfile instance is closed.")
if self.format in ('inchi', 'inchikey', 'mol2'):
self._writer.write(molecule.write(self.format) + '\n')
self._writer.write(molecule.write(self.format, **self.writer_kwargs) + '\n')
if self.format == 'pdbqt':
self._writer.write('MODEL %i\n' % (self.total + 1) +
molecule.write(self.format) + '\nENDMDL\n')
molecule.write(self.format, **self.writer_kwargs) + '\nENDMDL\n')
else:
self._writer.write(molecule.Mol)
self.total += 1
Expand Down

0 comments on commit efe25b2

Please sign in to comment.