Skip to content

Commit

Permalink
Python 3 support
Browse files Browse the repository at this point in the history
Merges branch 'wip/py3'.
  • Loading branch information
mgedmin committed Jun 4, 2015
2 parents 1574ee0 + 0ef99a5 commit 96f0f23
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 113 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[run]
source = findimports
branch = True
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ check test:

.PHONY: coverage
coverage:
coverage run --source=findimports testsuite.py
coverage run testsuite.py
coverage report

.PHONY: test-all-pythons
Expand Down
160 changes: 81 additions & 79 deletions findimports.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,20 @@
Ave, Cambridge, MA 02139, USA.
"""

import os
import sys
import getopt
from __future__ import print_function

import ast
import doctest
import compiler
import getopt
import linecache
import os
import pickle
import sys
import zipfile
from operator import attrgetter
from compiler.visitor import ASTVisitor


__version__ = '1.3.3.dev0'
__version__ = '1.4.0.dev0'
__author__ = 'Marius Gedminas <marius@gedmin.as>'
__licence__ = 'GPL v2 or later'
__url__ = 'https://github.com/mgedmin/findimports'
Expand Down Expand Up @@ -110,7 +111,7 @@ def __repr__(self):
self.filename, self.lineno, self.level)


class ImportFinder(ASTVisitor):
class ImportFinder(ast.NodeVisitor):
"""AST visitor that collects all imported names in its imports attribute.
For example, the following import statements in the AST tree
Expand Down Expand Up @@ -144,53 +145,55 @@ def processImport(self, name, imported_as, full_name, level, node):
info = ImportInfo(full_name, self.filename, lineno, level)
self.imports.append(info)

def visitImport(self, node):
for name, imported_as in node.names:
self.processImport(name, imported_as, name, None, node)
def visit_Import(self, node):
for alias in node.names:
self.processImport(alias.name, alias.asname, alias.name, None, node)

def visitFrom(self, node):
if node.modname == '__future__':
def visit_ImportFrom(self, node):
if node.module == '__future__':
return

for name, imported_as in node.names:
self.processImport(name, imported_as,
'%s.%s' % (node.modname, name)
if node.modname else name, node.level, node)
for alias in node.names:
name = alias.name
imported_as = alias.asname
fullname = '%s.%s' % (node.module, name) if node.module else name
self.processImport(name, imported_as, fullname, node.level, node)

def visitSomethingWithADocstring(self, node):
self.processDocstring(node.doc, node.lineno)
for c in node.getChildNodes():
self.visit(c)
# ClassDef and FunctionDef have a 'lineno' attribute, Module doesn't.
lineno = getattr(node, 'lineno', None)
self.processDocstring(ast.get_docstring(node, clean=False), lineno)
self.generic_visit(node)

visitModule = visitSomethingWithADocstring
visitClass = visitSomethingWithADocstring
visitFunction = visitSomethingWithADocstring
visit_Module = visitSomethingWithADocstring
visit_ClassDef = visitSomethingWithADocstring
visit_FunctionDef = visitSomethingWithADocstring

def processDocstring(self, docstring, lineno):
if not docstring:
return
if lineno is None:
# Module nodes have a lineno of None.
# Module nodes don't have a lineno
lineno = 0
dtparser = doctest.DocTestParser()
try:
examples = dtparser.get_examples(docstring)
except Exception:
print >> sys.stderr, ("%s:%s: error while parsing doctest"
% (self.filename, lineno))
print("{filename}:{lineno}: error while parsing doctest".format(
filename=self.filename, lineno=lineno), file=sys.stderr)
raise
for example in examples:
try:
source = example.source
if isinstance(source, unicode):
if not isinstance(source, str):
source = source.encode('UTF-8')
ast = compiler.parse(source)
node = ast.parse(source, filename='<docstring>')
except SyntaxError:
print >> sys.stderr, ("%s:%s: syntax error in doctest"
% (self.filename, lineno))
print("{filename}:{lineno}: syntax error in doctest".format(
filename=self.filename, lineno=lineno), file=sys.stderr)
else:
self.lineno_offset += lineno + example.lineno
compiler.walk(ast, self)
self.visit(node)
self.lineno_offset -= lineno + example.lineno


Expand Down Expand Up @@ -260,9 +263,9 @@ def processDocstring(self, docstring, lineno):
ImportFinder.processDocstring(self, docstring, lineno)
self.leaveScope()

def visitFunction(self, node):
def visit_FunctionDef(self, node):
self.newScope(self.scope, 'function %s' % node.name)
ImportFinder.visitFunction(self, node)
ImportFinder.visit_FunctionDef(self, node)
self.leaveScope()

def processImport(self, name, imported_as, full_name, level, node):
Expand All @@ -274,25 +277,25 @@ def processImport(self, name, imported_as, full_name, level, node):
if (self.warn_about_duplicates and
self.scope.haveImport(imported_as)):
where = self.scope.whereImported(imported_as).lineno
print >> sys.stderr, ("%s:%s: %s imported again"
% (self.filename, lineno, imported_as))
print("{filename}:{lineno}: {name} imported again".format(
filename=self.filename, lineno=lineno, name=imported_as), file=sys.stderr)
if self.verbose:
print >> sys.stderr, ("%s:%s: (location of previous import)"
% (self.filename, where))
print("{filename}:{lineno}: (location of previous import)".format(
filename=self.filename, lineno=where), file=sys.stderr)
else:
self.scope.addImport(imported_as, self.filename, level, lineno)

def visitName(self, node):
self.scope.useName(node.name)

def visitGetattr(self, node):
full_name = [node.attrname]
parent = node.expr
while isinstance(parent, compiler.ast.Getattr):
full_name.append(parent.attrname)
parent = parent.expr
if isinstance(parent, compiler.ast.Name):
full_name.append(parent.name)
def visit_Name(self, node):
self.scope.useName(node.id)

def visit_Attribute(self, node):
full_name = [node.attr]
parent = node.value
while isinstance(parent, ast.Attribute):
full_name.append(parent.attr)
parent = parent.value
if isinstance(parent, ast.Name):
full_name.append(parent.id)
full_name.reverse()
name = ""
for part in full_name:
Expand All @@ -301,18 +304,18 @@ def visitGetattr(self, node):
else:
name += part
self.scope.useName(name)
for c in node.getChildNodes():
self.visit(c)
self.generic_visit(node)


def find_imports(filename):
"""Find all imported names in a given file.
Returns a list of ImportInfo objects.
"""
ast = compiler.parseFile(filename)
with open(filename) as f:
root = ast.parse(f.read(), filename)
visitor = ImportFinder(filename)
compiler.walk(ast, visitor)
visitor.visit(root)
return visitor.imports


Expand All @@ -322,11 +325,12 @@ def find_imports_and_track_names(filename, warn_about_duplicates=False,
Returns ``(imports, unused)``. Both are lists of ImportInfo objects.
"""
ast = compiler.parseFile(filename)
with open(filename) as f:
root = ast.parse(f.read(), filename)
visitor = ImportFinderAndNameTracker(filename)
visitor.warn_about_duplicates = warn_about_duplicates
visitor.verbose = verbose
compiler.walk(ast, visitor)
visitor.visit(root)
visitor.leaveAllScopes()
return visitor.imports, visitor.unused_names

Expand Down Expand Up @@ -401,7 +405,7 @@ def warn(self, about, message, *args):
return
if args:
message = message % args
print >> self._stderr, message
print(message, file=self._stderr)
self._warned_about.add(about)

def parsePathname(self, pathname):
Expand All @@ -424,15 +428,13 @@ def parsePathname(self, pathname):

def writeCache(self, filename):
"""Write the graph to a cache file."""
f = file(filename, 'wb')
pickle.dump(self.modules, f)
f.close()
with open(filename, 'wb') as f:
pickle.dump(self.modules, f)

def readCache(self, filename):
"""Load the graph from a cache file."""
f = file(filename, 'rb')
self.modules = pickle.load(f)
f.close()
with open(filename, 'rb') as f:
self.modules = pickle.load(f)

def parseFile(self, filename):
"""Parse a single file."""
Expand Down Expand Up @@ -480,7 +482,7 @@ def findModuleOfName(self, dotted_name, level, filename, extrapath=None):

# extrapath is None only in a couple of test cases; in real life it's
# always present
if level > 1 and extrapath:
if level and level > 1 and extrapath:
# strip trailing path bits for each extra level to account for
# relative imports
# from . import X has level == 1 and nothing is stripped (the level > 1 check accounts for this case)
Expand Down Expand Up @@ -693,20 +695,20 @@ def visit2(u):
def printImportedNames(self):
"""Produce a report of imported names."""
for module in self.listModules():
print "%s:" % module.modname
print " %s" % "\n ".join(imp.name for imp in module.imported_names)
print("%s:" % module.modname)
print(" %s" % "\n ".join(imp.name for imp in module.imported_names))

def printImports(self):
"""Produce a report of dependencies."""
for module in self.listModules():
print "%s:" % module.label
print("%s:" % module.label)
if self.external_dependencies:
imports = list(module.imports)
else:
imports = [modname for modname in module.imports
if modname in self.modules]
imports.sort()
print " %s" % "\n ".join(imports)
print(" %s" % "\n ".join(imports))

def printUnusedImports(self):
"""Produce a report of unused imports."""
Expand All @@ -720,34 +722,34 @@ def printUnusedImports(self):
if '#' in line:
# assume there's a comment explaining why it's not used
continue
print "%s:%s: %s not used" % (module.filename, lineno, name)
print("%s:%s: %s not used" % (module.filename, lineno, name))

def printDot(self):
"""Produce a dependency graph in dot format."""
print "digraph ModuleDependencies {"
print " node[shape=box];"
print("digraph ModuleDependencies {")
print(" node[shape=box];")
allNames = set()
nameDict = {}
for n, module in enumerate(self.listModules()):
module._dot_name = "mod%d" % n
nameDict[module.modname] = module._dot_name
print " %s[label=\"%s\"];" % (module._dot_name,
quote(module.label))
print(" %s[label=\"%s\"];" % (module._dot_name,
quote(module.label)))
allNames |= module.imports
print " node[style=dotted];"
print(" node[style=dotted];")
if self.external_dependencies:
myNames = set(self.modules)
extNames = list(allNames - myNames)
extNames.sort()
for n, name in enumerate(extNames):
nameDict[name] = id = "extmod%d" % n
print " %s[label=\"%s\"];" % (id, name)
print(" %s[label=\"%s\"];" % (id, name))
for modname, module in sorted(self.modules.items()):
for other in sorted(module.imports):
if other in nameDict:
print " %s -> %s;" % (nameDict[module.modname],
nameDict[other])
print "}"
print(" %s -> %s;" % (nameDict[module.modname],
nameDict[other]))
print("}")


def quote(s):
Expand Down Expand Up @@ -775,9 +777,9 @@ def main(argv=sys.argv):
'packages', 'level=', 'help', 'collapse',
'noext', 'tests', 'write-cache=',
'duplicate', 'verbose'])
except getopt.error, e:
print >> sys.stderr, "%s: %s" % (progname, e)
print >> sys.stderr, "Try %s --help." % progname
except getopt.error as e:
print("%s: %s" % (progname, e), file=sys.stderr)
print("Try %s --help." % progname, file=sys.stderr)
return 1
for k, v in opts:
if k in ('-d', '--dot'):
Expand Down Expand Up @@ -807,7 +809,7 @@ def main(argv=sys.argv):
elif k == '--write-cache':
write_cache = v
elif k in ('-h', '--help'):
print helptext
print(helptext)
return 0
g.trackUnusedNames = (action == 'printUnusedImports')
if not args:
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ def read(filename):
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4',
'License :: OSI Approved :: GNU General Public License (GPL)'
if licence.startswith('GPL') else
'License :: OSI Approved :: MIT License'
Expand Down
7 changes: 5 additions & 2 deletions tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os
import unittest
from cStringIO import StringIO
try:
from cStringIO import StringIO
except ImportError:
from io import StringIO

import findimports

Expand Down Expand Up @@ -40,6 +43,7 @@ def test_warn_suppresses_duplicates(self):

def test_parsePathname_regular_file(self):
mg = findimports.ModuleGraph()
mg.warn = self.warn
mg.parsePathname(__file__.rstrip('co')) # .pyc -> .py
self.assertTrue('unittest' in mg.modules[__name__].imports)

Expand All @@ -64,7 +68,6 @@ def test_isModule(self):
self.assertTrue(mg.isModule('sys'))
self.assertTrue(mg.isModule('datetime'))
self.assertFalse(mg.isModule('nosuchmodule'))
self.assertFalse(mg.isModule('logging')) # it's a package

def test_isModule_warns_about_bad_zip_files(self):
# anything that's a regular file but isn't a valid zip file
Expand Down
Loading

0 comments on commit 96f0f23

Please sign in to comment.