Skip to content

Commit

Permalink
fix: support side-effect only imports and requires
Browse files Browse the repository at this point in the history
fix #24
  • Loading branch information
jedwards1211 committed Jan 21, 2021
1 parent 1bf9c98 commit 6c24069
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 24 deletions.
12 changes: 12 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"configurations": [
{
"name": "debug tests",
"request": "launch",
"runtimeArgs": ["test:debug"],
"runtimeExecutable": "yarn",
"skipFiles": ["<node_internals>/**"],
"type": "pwa-node"
}
]
}
117 changes: 95 additions & 22 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ const firstNode = (c) => c.at(0).nodes()[0]
const lastNode = (c) => c.at(-1).nodes()[0]

module.exports = function addImports(root, _statements) {
const found = findImports(root, _statements)
const statements = Array.isArray(_statements) ? _statements : [_statements]
const found = findImports(
root,
statements.filter((s) => s.type !== 'ExpressionStatement')
)
for (const name in found) {
if (found[name].type === 'Identifier') found[name] = found[name].name
else delete found[name]
Expand All @@ -31,7 +35,6 @@ module.exports = function addImports(root, _statements) {
} catch (error) {
// ignore
}
const statements = Array.isArray(_statements) ? _statements : [_statements]

const preventNameConflict = babelScope
? (_id) => {
Expand All @@ -54,12 +57,18 @@ module.exports = function addImports(root, _statements) {
return id
}

statements.forEach((statement) => {
for (const statement of statements) {
if (statement.type === 'ImportDeclaration') {
const { importKind } = statement
const source = { value: statement.source.value }
const filter = { source }
if (!definitelyFlow) filter.importKind = importKind
if (!statement.specifiers.length) {
if (!isSourceImported(root, statement.source.value)) {
addStatements(root, statement)
}
continue
}
let existing = root.find(j.ImportDeclaration, filter)
for (let specifier of statement.specifiers) {
if (found[specifier.local.name]) continue
Expand Down Expand Up @@ -99,12 +108,7 @@ module.exports = function addImports(root, _statements) {
j.stringLiteral(statement.source.value),
importKind
)
const allImports = root.find(j.ImportDeclaration)
if (allImports.size()) {
lastPath(allImports).insertAfter(newDeclaration)
} else {
insertProgramStatement(root, newDeclaration)
}
addStatements(root, newDeclaration)
existing = root.find(j.ImportDeclaration, { source })
}
}
Expand All @@ -118,6 +122,7 @@ module.exports = function addImports(root, _statements) {
},
init: {
type: 'CallExpression',
callee: { type: 'Identifier', name: 'require' },
arguments: [{ value: declarator.init.arguments[0].value }],
},
})
Expand All @@ -139,12 +144,7 @@ module.exports = function addImports(root, _statements) {
const newDeclaration = j.variableDeclaration('const', [
j.variableDeclarator(j.objectPattern([prop]), declarator.init),
])
const allImports = root.find(j.ImportDeclaration)
if (allImports.size()) {
lastPath(allImports).insertAfter(newDeclaration)
} else {
insertProgramStatement(root, newDeclaration)
}
addStatements(root, newDeclaration)
}
}
} else if (declarator.id.type === 'Identifier') {
Expand All @@ -155,21 +155,86 @@ module.exports = function addImports(root, _statements) {
const newDeclaration = j.variableDeclaration('const', [
j.variableDeclarator(declarator.id, declarator.init),
])
const allImports = root.find(j.ImportDeclaration)
if (allImports.size()) {
lastPath(allImports).insertAfter(newDeclaration)
} else {
insertProgramStatement(root, newDeclaration)
}
addStatements(root, newDeclaration)
}
}
})
} else if (statement.type === 'ExpressionStatement') {
if (isNodeRequireCall(statement.expression)) {
if (!isSourceImported(root, getSource(statement.expression))) {
addStatements(root, statement)
}
} else {
throw new Error(`statement must be an import or require`)
}
}
})
}

return found
}

function findTopLevelImports(root, predicate = () => true) {
const program = root.find(j.Program).at(0).paths()[0]
if (!program) return []
return j(
program
.get('body')
.filter((p) => p.node.type === 'ImportDeclaration' && predicate(p))
)
}

function isNodeRequireCall(node) {
return (
node.type === 'CallExpression' &&
node.callee.type === 'Identifier' &&
node.callee.name === 'require' &&
node.arguments[0] &&
(node.arguments[0].type === 'StringLiteral' ||
node.arguments[0].type === 'Literal')
)
}

function isPathRequireCall(path) {
return isNodeRequireCall(path.node) && !path.scope.lookup('require')
}

function findTopLevelRequires(root, predicate = () => true) {
const paths = []
const program = root.find(j.Program).at(0).paths()[0]
if (program) {
program.get('body').each((path) => {
if (path.node.type === 'ExpressionStatement') {
const expression = path.get('expression')
if (isPathRequireCall(expression) && predicate(expression))
paths.push(expression)
} else if (path.node.type === 'VariableDeclaration') {
for (const declaration of path.get('declarations')) {
const init = declaration.get('init')
if (isPathRequireCall(init) && predicate(init)) paths.push(init)
}
}
})
}
return j(paths)
}

function getSource(node) {
if (node.type === 'ImportDeclaration') return node.source.value
if (isNodeRequireCall(node)) {
const arg = node.arguments[0]
if (arg && (arg.type === 'Literal' || arg.type === 'StringLiteral'))
return arg.value
}
}

function isSourceImported(root, source) {
const hasSource = (p) => getSource(p.node) === source
return (
findTopLevelImports(root, hasSource).size() ||
findTopLevelRequires(root, hasSource).size()
)
}

function insertProgramStatement(root, ...statements) {
const program = root.find(j.Program).at(0).nodes()[0]
const firstProgramStatement = program.body[0]
Expand All @@ -193,3 +258,11 @@ function insertProgramStatement(root, ...statements) {
}
program.body.unshift(...statements)
}

function addStatements(root, ...statements) {
const imports = findTopLevelImports(root)
if (imports.size()) {
const last = lastPath(imports)
for (const statement of statements.reverse()) last.insertAfter(statement)
} else insertProgramStatement(root, ...statements)
}
93 changes: 91 additions & 2 deletions test/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const addImports = require('..')
for (const parser of ['babylon', 'ts']) {
describe(`with parser: ${parser}`, function () {
const j = jscodeshift.withParser(parser)
const { statement } = j.template
const { statement, statements } = j.template

const format = (code) =>
prettier
Expand All @@ -29,7 +29,7 @@ for (const parser of ['babylon', 'ts']) {
addImports(
root,
typeof importsToAdd === 'string'
? statement([importsToAdd])
? statements([importsToAdd])
: Array.isArray(importsToAdd)
? importsToAdd.map((i) =>
typeof i === 'string' ? statement([i]) : i
Expand All @@ -54,6 +54,13 @@ for (const parser of ['babylon', 'ts']) {
expectedError: 'statement must be an import or require',
})
})
it(`throws if statement is an ExpressionStatement that's not a require`, function () {
testCase({
code: `import Baz from 'baz'`,
add: `1 + 2`,
expectedError: 'statement must be an import or require',
})
})
it(`leaves existing non-default imports with alias untouched`, function () {
testCase({
code: `import {foo as bar} from 'baz'`,
Expand Down Expand Up @@ -550,6 +557,88 @@ for (const parser of ['babylon', 'ts']) {
})
})
describe(`bugs`, function () {
it(`adding side-effect only imports`, async function () {
testCase({
code: '',
add: `import 'foo'`,
expectedCode: `import 'foo'`,
})
testCase({
code: `
import 'bar'
`,
add: `import 'foo'`,
expectedCode: `
import 'bar'
import 'foo'
`,
})
})
it(`doesn't re-add side-effect only import`, async function () {
testCase({
code: `import 'foo'`,
add: `import 'foo'`,
expectedCode: `import 'foo'`,
})
testCase({
code: `require('foo')`,
add: `import 'foo'`,
expectedCode: `require('foo')`,
})
testCase({
code: `require('foo')`,
add: `
import 'foo'
import 'bar'
`,
expectedCode: `
import 'bar'
require('foo')
`,
})
})
it(`adding side-effect only requires`, async function () {
testCase({
code: '',
add: `require('foo')`,
expectedCode: `require('foo')`,
})
testCase({
code: `
require('bar')
`,
add: `require('foo')`,
expectedCode: `
require('foo')
require('bar')
`,
})
})
it(`doesn't re-add side-effect only require`, async function () {
testCase({
code: `require('foo')`,
add: `require('foo')`,
expectedCode: `require('foo')`,
})
testCase({
code: `import 'foo'`,
add: `require('foo')`,
expectedCode: `import 'foo'`,
})
testCase({
code: `
require('bar')
`,
add: `
require('foo')
require('bar')
`,
expectedCode: `
require('foo')
require('bar')
`,
})
})
if (parser !== 'ts') {
it(`import type { foo, type bar }`, function () {
testCase({
Expand Down

0 comments on commit 6c24069

Please sign in to comment.