Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dev.project.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
}
}
},
"ServerScriptService": {
"TestService": {
"tests": {
"$path": "tests"
}
Expand Down
6 changes: 2 additions & 4 deletions scripts/analyze.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

curl -s -O https://raw.githubusercontent.com/JohnnyMorganz/luau-lsp/master/scripts/globalTypes.d.lua

cp .github/workflows/.luaurc Packages
rojo sourcemap dev.project.json -o sourcemap.json

luau-lsp analyze --sourcemap=sourcemap.json --defs=globalTypes.d.lua --defs=testez.d.lua src/
luau-lsp analyze --sourcemap=sourcemap.json --defs=globalTypes.d.lua --defs=testez.d.lua --ignore=**/_Index/** src/

rm Packages/.luaurc
rm globalTypes.d.lua
rm globalTypes.d.lua
1 change: 1 addition & 0 deletions scripts/test.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

rojo build dev.project.json -o studio-tests.rbxl
run-in-roblox --place studio-tests.rbxl --script tests/init.server.lua
pkill -n RobloxStudio
32 changes: 32 additions & 0 deletions src/createTablePassthrough.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
--[[
Creates a table that can be indexed and added to while also adding to a base
table.

This is used for module globals so that a module can define variables on _G
which are maintained in a dictionary of all globals AND a dictionary of the
globals a given module has defined.

This makes it easy to clear out the globals a modeule defines when removing
it from the cache.
]]

type AnyTable = { [any]: any }

local function createTablePassthrough(base: AnyTable): AnyTable
local proxy = {}

setmetatable(proxy, {
__index = function(self, key)
local global = rawget(self, key)
return if global then global else base[key]
end,
__newindex = function(self, key, value)
base[key] = value
rawset(self, key, value)
end,
})

return proxy :: any
end

return createTablePassthrough
23 changes: 23 additions & 0 deletions src/createTablePassthrough.spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
return function()
local createTablePassthrough = require(script.Parent.createTablePassthrough)

it("should work for the use case of maintaining global variables", function()
local allGlobals = {}
local moduleGlobals1 = createTablePassthrough(allGlobals)
local moduleGlobals2 = createTablePassthrough(allGlobals)

moduleGlobals1.foo = true
moduleGlobals2.bar = true

expect(moduleGlobals1.foo).to.equal(true)
expect(moduleGlobals1.bar).to.equal(true)
expect(rawget(moduleGlobals1, "bar")).never.to.be.ok()

expect(moduleGlobals2.bar).to.equal(true)
expect(moduleGlobals2.foo).to.equal(true)
expect(rawget(moduleGlobals2, "foo")).never.to.be.ok()

expect(allGlobals.foo).to.equal(true)
expect(allGlobals.bar).to.equal(true)
end)
end
3 changes: 2 additions & 1 deletion src/getEnv.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
local baseEnv = getfenv()

local function getEnv(scriptRelativeTo: LuaSourceContainer?)
local function getEnv(scriptRelativeTo: LuaSourceContainer?, globals: { [any]: any }?)
local newEnv = {}

setmetatable(newEnv, {
Expand All @@ -13,6 +13,7 @@ local function getEnv(scriptRelativeTo: LuaSourceContainer?)
end,
})

newEnv._G = globals
newEnv.script = scriptRelativeTo

local realDebug = debug
Expand Down
10 changes: 10 additions & 0 deletions src/getEnv.spec.lua
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,14 @@ return function()
local env = getEnv(script.Parent.getEnv)
expect(env.script).to.equal(script.Parent.getEnv)
end)

it("should set _G to the 'globals' argument", function()
local globals = {}
local env = getEnv(script.Parent.getEnv, globals)

expect(env._G).to.be.ok()
expect(env._G).to.equal(globals)
-- selene: allow(global_usage)
expect(env._G).never.to.equal(_G)
end)
end
11 changes: 11 additions & 0 deletions src/getRobloxTsRuntime.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
local ReplicatedStorage = game:GetService("ReplicatedStorage")

local function getRobloxTsRuntime()
local rbxtsInclude = ReplicatedStorage:FindFirstChild("rbxts_include")
if rbxtsInclude then
return rbxtsInclude:FindFirstChild("RuntimeLib")
end
return nil
end

return getRobloxTsRuntime
26 changes: 26 additions & 0 deletions src/getRobloxTsRuntime.spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
return function()
local ReplicatedStorage = game:GetService("ReplicatedStorage")

local getRobloxTsRuntime = require(script.Parent.getRobloxTsRuntime)

it("should retrieve the roblox-ts runtime library", function()
local includes = Instance.new("Folder")
includes.Name = "rbxts_include"
includes.Parent = ReplicatedStorage

local mockRuntime = Instance.new("ModuleScript")
mockRuntime.Name = "RuntimeLib"
mockRuntime.Parent = includes

local runtime = getRobloxTsRuntime()

includes:Destroy()

expect(runtime == mockRuntime).to.equal(true)
end)

it("should return nil if the runtime can't be found", function()
local runtime = getRobloxTsRuntime()
expect(runtime).never.to.be.ok()
end)
end
105 changes: 82 additions & 23 deletions src/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ local GoodSignal = require(script.Parent.GoodSignal)
local bind = require(script.bind)
local getCallerPath = require(script.getCallerPath)
local getEnv = require(script.getEnv)
local createTablePassthrough = require(script.createTablePassthrough)
local getRobloxTsRuntime = require(script.getRobloxTsRuntime)
local types = require(script.types)

type ModuleConsumers = types.ModuleConsumers
type ModuleGlobals = types.ModuleGlobals

--[=[
ModuleScript loader that bypasses Roblox's require cache.
Expand All @@ -22,7 +28,8 @@ export type CachedModule = {
module: ModuleScript,
isLoaded: boolean,
result: any,
consumers: { string },
consumers: ModuleConsumers,
globals: ModuleGlobals,
}

--[=[
Expand All @@ -35,6 +42,7 @@ function ModuleLoader.new()
self._loadstring = loadstring
self._debugInfo = debug.info
self._janitors = {}
self._globals = {}

--[=[
Fired when any ModuleScript required through this class has its ancestry
Expand Down Expand Up @@ -93,19 +101,6 @@ function ModuleLoader:_getSource(module: ModuleScript): any?
return if success then result else nil
end

function ModuleLoader:_clearConsumerFromCache(moduleFullName: string)
local cachedModule: CachedModule = self._cache[moduleFullName]

if cachedModule then
for _, consumer in ipairs(cachedModule.consumers) do
self._cache[consumer] = nil
self:_clearConsumerFromCache(consumer)
end

self._cache[moduleFullName] = nil
end
end

--[=[
Tracks the changes to a required module's ancestry and `Source`.

Expand All @@ -122,13 +117,12 @@ function ModuleLoader:_trackChanges(module: ModuleScript)
janitor:Cleanup()

janitor:Add(module.AncestryChanged:Connect(function()
self.loadedModuleChanged:Fire(module)
self:clearModule(module)
end))

janitor:Add(module.Changed:Connect(function(prop: string)
if prop == "Source" then
self:_clearConsumerFromCache(module:GetFullName())
self.loadedModuleChanged:Fire(module)
self:clearModule(module)
end
end))

Expand Down Expand Up @@ -156,6 +150,7 @@ function ModuleLoader:cache(module: ModuleScript, result: any)
result = result,
isLoaded = true,
consumers = {},
globals = createTablePassthrough(self._globals),
}

self._cache[module:GetFullName()] = cachedModule
Expand All @@ -178,10 +173,7 @@ function ModuleLoader:require(module: ModuleScript)
local callerPath = getCallerPath()

if cachedModule then
if self._cache[callerPath] then
table.insert(cachedModule.consumers, callerPath)
end

cachedModule.consumers[callerPath] = true
return self:_loadCachedModule(module)
end

Expand All @@ -192,17 +184,20 @@ function ModuleLoader:require(module: ModuleScript)
error(("Could not parse %s: %s"):format(module:GetFullName(), parseError))
end

local globals = createTablePassthrough(self._globals)

local newCachedModule: CachedModule = {
module = module,
result = nil,
isLoaded = false,
consumers = {
if self._cache[callerPath] then callerPath else nil,
[callerPath] = true,
},
globals = globals,
}
self._cache[module:GetFullName()] = newCachedModule

local env = getEnv(module)
local env = getEnv(module, globals)
env.require = bind(self, self.require)
setfenv(moduleFn, env)

Expand All @@ -220,6 +215,69 @@ function ModuleLoader:require(module: ModuleScript)
return self:_loadCachedModule(module)
end

function ModuleLoader:_getConsumers(module: ModuleScript): { ModuleScript }
local function getConsumersRecursively(cachedModule: CachedModule, found: { [ModuleScript]: true })
for consumer in cachedModule.consumers do
local cachedConsumer = self._cache[consumer]

if cachedConsumer then
if not found[cachedConsumer.module] then
found[cachedConsumer.module] = true
getConsumersRecursively(cachedConsumer, found)
end
end
end
end

local cachedModule: CachedModule = self._cache[module:GetFullName()]
local found = {}

getConsumersRecursively(cachedModule, found)

local consumers = {}
for consumer in found do
table.insert(consumers, consumer)
end

return consumers
end

function ModuleLoader:clearModule(moduleToClear: ModuleScript)
if not self._cache[moduleToClear:GetFullName()] then
return
end

local consumers = self:_getConsumers(moduleToClear)
local modulesToClear = { moduleToClear, table.unpack(consumers) }

local index = table.find(modulesToClear, getRobloxTsRuntime())
if index then
table.remove(modulesToClear, index)
end

for _, module in modulesToClear do
local fullName = module:GetFullName()

local cachedModule = self._cache[fullName]

if cachedModule then
self._cache[fullName] = nil

for key in cachedModule.globals do
self._globals[key] = nil
end

local janitor = self._janitors[fullName]
janitor:Cleanup()
end
end

for _, module in modulesToClear do
print("loadedModuleChanged", module:GetFullName())
self.loadedModuleChanged:Fire(module)
end
end

--[=[
Clears out the internal cache.

Expand All @@ -243,6 +301,7 @@ end
]=]
function ModuleLoader:clear()
self._cache = {}
self._globals = {}

for _, janitor in self._janitors do
janitor:Cleanup()
Expand Down
Loading