Skip to content

Commit

Permalink
fix matrix parsing
Browse files Browse the repository at this point in the history
Closes #46
  • Loading branch information
jbyuki committed Oct 9, 2022
1 parent d780a3d commit c58936e
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 62 deletions.
116 changes: 77 additions & 39 deletions lua/nabla/ascii.lua
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ local stack_subsup

local grid_of_exps

local combine_matrix_grid

local unpack_explist

local put_subsup_aside
Expand Down Expand Up @@ -904,6 +906,8 @@ function grid:combine_sub(other)
local spacer = grid:new(self.w, other.h)




local lower = spacer:join_hori(other)
local result = self:join_vert(lower, true)
result.my = self.my
Expand Down Expand Up @@ -1044,7 +1048,6 @@ function stack_subsup(explist, i, g)
end

function grid_of_exps(explist)
local cells = {}
local cellsgrid = {}
local maxheight = 0
local i = 1
Expand All @@ -1057,14 +1060,14 @@ function grid_of_exps(explist)

while i <= #explist do
if explist[i].kind == "symexp" and explist[i].sym == "&" then
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)
i = i+1
break

elseif explist[i].kind == "funexp" and explist[i].sym == "\\" then
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)

Expand All @@ -1079,7 +1082,7 @@ function grid_of_exps(explist)
end

if i > #explist then
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)

Expand All @@ -1090,7 +1093,40 @@ function grid_of_exps(explist)

end

return cells
return cellsgrid, maxheight
end

function combine_matrix_grid(cellsgrid, maxheight)
local res
for i=1,#cellsgrid[1] do
local col
for j=1,#cellsgrid do
local cell = cellsgrid[j][i]
local sup = maxheight - cell.h
local sdown = 0
local up, down
if sup > 0 then up = grid:new(cell.w, sup) end
if sdown > 0 then down = grid:new(cell.w, sdown) end

if up then cell = up:join_vert(cell) end
if down then cell = cell:join_vert(down) end

local colspacer = grid:new(1, cell.h)
colspacer.my = cell.my

if i < #cellsgrid[1] then
cell = cell:join_hori(colspacer)
end

if not col then col = cell
else col = col:join_vert(cell, true) end

end
if not res then res = col
else res = res:join_hori(col, true) end

end
return res
end

function unpack_explist(exp)
Expand Down Expand Up @@ -1674,53 +1710,55 @@ function to_ascii(explist, exp_i)
g = to_ascii({exp.exp}, 1):enclose_paren()

elseif exp.kind == "blockexp" then
local name = exp.sym
local sym = unpack_explist(exp.first)
exp_i = exp_i + 1
local name = sym.sym
if name == "matrix" then
local cells = grid_of_exps(exp.content.exps)
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)

-- @combine_matrix_brackets
res.my = math.floor(res.h/2)
g = res

elseif name == "pmatrix" then
local cells = grid_of_exps(exp.content.exps)

res.my = math.floor(res.h/2)
return res:enclose_paren()
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)
res.my = math.floor(res.h/2)
g = res:enclose_paren()

elseif name == "bmatrix" then
local cells = grid_of_exps(exp.content.exps)

local left_content, right_content = {}, {}
if res.h > 1 then
for y=1,res.h do
if y == 1 then
table.insert(left_content, style.matrix_upper_left)
table.insert(right_content, style.matrix_upper_right)
elseif y == res.h then
table.insert(left_content, style.matrix_lower_left)
table.insert(right_content, style.matrix_lower_right)
else
table.insert(left_content, style.matrix_vert_left)
table.insert(right_content, style.matrix_vert_right)
end
end
else
left_content = { style.matrix_single_left }
right_content = { style.matrix_single_right }
end
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)
local left_content, right_content = {}, {}
if res.h > 1 then
for y=1,res.h do
if y == 1 then
table.insert(left_content, style.matrix_upper_left)
table.insert(right_content, style.matrix_upper_right)
elseif y == res.h then
table.insert(left_content, style.matrix_lower_left)
table.insert(right_content, style.matrix_lower_right)
else
table.insert(left_content, style.matrix_vert_left)
table.insert(right_content, style.matrix_vert_right)
end
end
else
left_content = { style.matrix_single_left }
right_content = { style.matrix_single_right }
end

local leftbracket = grid:new(1, res.h, left_content)
local rightbracket = grid:new(1, res.h, right_content)
local leftbracket = grid:new(1, res.h, left_content)
local rightbracket = grid:new(1, res.h, right_content)

res = leftbracket:join_hori(res, true)
res = res:join_hori(rightbracket, true)
res = leftbracket:join_hori(res, true)
res = res:join_hori(rightbracket, true)

res.my = math.floor(res.h/2)
g = res
res.my = math.floor(res.h/2)
g = res

else
error("Unknown block expression " .. exp.sym)
error("Unknown block expression " .. name)
end


Expand Down
5 changes: 4 additions & 1 deletion lua/nabla/latex.lua
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,17 @@ function parse()

if sym.sym == "begin" then
local explist = parse()
local block_name = explist.exps[1]
table.remove(explist.exps, 1)

exp = {
kind = "blockexp",
first = block_name,
content = explist,
}

elseif sym.sym == "end" then
return explist
break
end

end
Expand Down
100 changes: 79 additions & 21 deletions src/core/ascii.lua.t
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ local result = leftgrid:combine_sub(rightgrid)
@make_spacer_lower_left+=
local spacer = grid:new(self.w, other.h)



@style_variables+=
matrix_upper_left = "",
matrix_upper_right = "",
Expand Down Expand Up @@ -751,21 +753,22 @@ elseif name == "lim" then

@transform_exp_to_grid+=
elseif exp.kind == "blockexp" then
local name = exp.sym
local sym = unpack_explist(exp.first)
exp_i = exp_i + 1
local name = sym.sym
@transform_block_expression
@otherwise_error_with_unknown_block_expression

@otherwise_error_with_unknown_block_expression+=
else
error("Unknown block expression " .. exp.sym)
error("Unknown block expression " .. name)
end

@transform_block_expression+=
if name == "matrix" then
local cells = grid_of_exps(exp.content.exps)
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)

@combine_to_matrix_grid
-- @combine_matrix_brackets
res.my = math.floor(res.h/2)
g = res

Expand All @@ -774,9 +777,8 @@ local grid_of_exps

@utility_functions+=
function grid_of_exps(explist)
local cells = {}
@make_grid_of_cells_from_exp_list
return cells
return cellsgrid, maxheight
end

@make_grid_of_cells_from_exp_list+=
Expand Down Expand Up @@ -808,14 +810,14 @@ while i <= #explist do
end

@switch_to_next_cell+=
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)
i = i+1
break

@switch_to_next_cell_and_row+=
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)

Expand All @@ -826,29 +828,85 @@ break

@if_last_one_add_cellgrid+=
if i > #explist then
local cellgrid = to_ascii(cell_list)
local cellgrid = to_ascii({cell_list}, 1)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)

table.insert(cellsgrid, rowgrid)
end

@make_grid_of_individual_cells+=
local cellsgrid = {}
local maxheight = 0
for _, row in ipairs(exp.rows) do
local rowgrid = {}
for _, cell in ipairs(row) do
local cellgrid = to_ascii(cell)
table.insert(rowgrid, cellgrid)
maxheight = math.max(maxheight, cellgrid.h)
end
table.insert(cellsgrid, rowgrid)
end

@declare_functions+=
local combine_matrix_grid

@utility_functions+=
function combine_matrix_grid(cellsgrid, maxheight)
local res
for i=1,#cellsgrid[1] do
local col
for j=1,#cellsgrid do
local cell = cellsgrid[j][i]
@add_row_spacer_to_center_cell
@add_col_spacer_to_center_cell
@add_cell_grid_to_row_grid
end
@add_row_grid_to_matrix_grid
end
return res
end

@add_cell_grid_to_row_grid+=
if not col then col = cell
else col = col:join_vert(cell, true) end

@add_row_grid_to_matrix_grid+=
if not res then res = col
else res = res:join_hori(col, true) end

@add_row_spacer_to_center_cell+=
local sup = maxheight - cell.h
local sdown = 0
local up, down
if sup > 0 then up = grid:new(cell.w, sup) end
if sdown > 0 then down = grid:new(cell.w, sdown) end

if up then cell = up:join_vert(cell) end
if down then cell = cell:join_vert(down) end

@add_col_spacer_to_center_cell+=
local colspacer = grid:new(1, cell.h)
colspacer.my = cell.my

if i < #cellsgrid[1] then
cell = cell:join_hori(colspacer)
end

@transform_block_expression+=
elseif name == "pmatrix" then
local cells = grid_of_exps(exp.content.exps)

@combine_to_matrix_grid
res.my = math.floor(res.h/2)
return res:enclose_paren()
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)
res.my = math.floor(res.h/2)
g = res:enclose_paren()

@transform_block_expression+=
elseif name == "bmatrix" then
local cells = grid_of_exps(exp.content.exps)

@combine_to_matrix_grid
@combine_matrix_brackets
res.my = math.floor(res.h/2)
g = res
local cellsgrid, maxheight = grid_of_exps(exp.content.exps)
local res = combine_matrix_grid(cellsgrid, maxheight)
@combine_matrix_brackets
res.my = math.floor(res.h/2)
g = res

@put_children_join_horiz+=
table.insert(c.children, { self, 0, s1 })
Expand Down
5 changes: 4 additions & 1 deletion src/latex/parser.lua.t
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,18 @@ end
@if_it_begins_until_enclosing_end+=
if sym.sym == "begin" then
local explist = parse()
local block_name = explist.exps[1]
table.remove(explist.exps, 1)

exp = {
kind = "blockexp",
first = block_name,
content = explist,
}

@if_it_ends_return_explist+=
elseif sym.sym == "end" then
return explist
break
end

@if_space_parse_as_single_space+=
Expand Down

0 comments on commit c58936e

Please sign in to comment.