diff --git a/examples/examplescommon.lua b/examples/examplescommon.lua index bcf37ff..5617008 100644 --- a/examples/examplescommon.lua +++ b/examples/examplescommon.lua @@ -19,9 +19,82 @@ C.identity = memoize(function(A) return identity end) -C.print = memoize(function(A) - local identity = RM.lift( "print_"..J.verilogSanitize(tostring(A)), A, A, 0, function(sinp) return sinp end, function() return CT.print(A) end, "C.print") - return identity +C.fassert = memoize(function(filename,A) + err(types.isBasic(A),"C.fassert: type should be basic, but is: "..tostring(A) ) + + local fassert = {name=J.verilogSanitize("fassert_"..tostring(A).."_file"..tostring(filename))} + fassert.inputType = A + fassert.outputType = A + fassert.sdfInput={{1,1}} + fassert.sdfOutput={{1,1}} + fassert.stateful=true + fassert.delay=0 + function fassert.makeTerra() return CT.fassert(filename,A) end + + function fassert.makeSystolic() + local sm = Ssugar.moduleConstructor(fassert.name) + local inp = S.parameter( "process_input", rigel.lower(fassert.inputType) ) + sm:addFunction( S.lambda("process", inp, inp, "process_output", nil,nil,S.CE("process_CE")) ) + sm:addFunction( S.lambda("reset", S.parameter("r",types.null()), nil, "reset_out") ) + return sm + end + return rigel.newFunction(fassert) +end) + +C.print = memoize(function(A,str) + err(types.isBasic(A),"C.print: type should be basic, but is: "..tostring(A) ) + + local function constructPrint(A,symb) + if A:isUint() or A:isInt() or A:isBits() then + return {A}, "%d", {symb} + elseif A:isArray() then + local resTypes={} + local resStr="[" + local resValues={} + for y=0,A.size[2]-1 do + for x=0,A.size[1]-1 do + local tT,tS,tV = constructPrint(A:arrayOver(),S.index(symb,x,y)) + resStr = resStr..tS.."," + table.insert(resTypes,tT) + table.insert(resValues,tV) + end + end + return J.flatten(resTypes), resStr.."]", J.flatten(resValues) + elseif A:isTuple() then + local resTypes={} + local resStr="{" + local resValues={} + for i=1,#A.list do + 
local tT,tS,tV = constructPrint(A.list[1],S.index(symb,i-1)) + resStr = resStr..tS.."," + table.insert(resTypes,tT) + table.insert(resValues,tV) + end + return J.flatten(resTypes), resStr.."}", J.flatten(resValues) + else + err(false,"C.print NYI - type: "..tostring(A)) + end + end + + --local identity = RM.lift( , A, A, 0, function(sinp) return sinp end, function() return CT.print(A,str) end, "C.print") + local res = {name = J.verilogSanitize("print_"..tostring(A).."_STR"..tostring(str)),inputType=A,outputType=A,sdfInput={{1,1}}, sdfOutput={{1,1}}} + res.stateful=false + res.delay=0 + function res.makeSystolic() + local sm = Ssugar.moduleConstructor(res.name) + local inp = S.parameter( "process_input", rigel.lower(res.inputType) ) + + local typelist, printStr, valuelist = constructPrint(A,inp) + if str~=nil then printStr = str.." "..printStr end + local printInst = sm:add( S.module.print( types.tuple(typelist), printStr, true):instantiate("printInst") ) + local pipelines = {printInst:process( S.tuple(valuelist) )} + sm:addFunction( S.lambda("process", inp, inp, "process_output", pipelines,nil,S.CE("process_CE")) ) + return sm + end + + function res.makeTerra() return CT.print(A,str) end + + return rigel.newFunction(res) end) C.cast = memoize(function(A,B) @@ -36,6 +109,15 @@ C.cast = memoize(function(A,B) return docast end) +C.bitSlice = memoize( + function(A,low,high) + local bitslice = RM.lift( J.sanitize("bitslice_"..tostring(A).."_"..tostring(low).."_"..tostring(high)), A, nil, 0, + function(sinp) + return S.bitSlice(sinp,low,high) + end) + return bitslice + end) + C.bitcast = memoize(function(A,B) err(types.isType(A),"cast: A should be type") err(types.isType(B),"cast: B should be type") @@ -948,7 +1030,7 @@ end) -- V -> RV -C.downsampleSeq = memoize(function( A, W, H, T, scaleX, scaleY, X ) +C.downsampleSeq = memoize(function( A, W, H, T, scaleX, scaleY, framed, X ) err( types.isType(A), "C.downsampleSeq: A must be type") err( type(W)=="number", 
"C.downsampleSeq: W must be number") err( type(H)=="number", "C.downsampleSeq: H must be number") @@ -959,6 +1041,12 @@ C.downsampleSeq = memoize(function( A, W, H, T, scaleX, scaleY, X ) err( scaleY>=1, "C.downsampleSeq: scaleY must be >=1") err( X==nil, "C.downsampleSeq: too many arguments" ) + err( W%scaleX==0,"C.downsampleSeq: NYI - scaleX does not divide W") + err( H%scaleY==0,"C.downsampleSeq: NYI - scaleY does not divide H") + + if framed==nil then framed=false end + err( type(framed)=="boolean", "C.donwsampleSeq: framed must be boolean" ) + if scaleX==1 and scaleY==1 then return C.identity(A) end @@ -970,6 +1058,7 @@ C.downsampleSeq = memoize(function( A, W, H, T, scaleX, scaleY, X ) end if scaleX>1 then local mod = modules.liftDecimate(modules.downsampleXSeq( A, W, H, T, scaleX )) + print(mod) if scaleY>1 then mod=modules.RPassthrough(mod) end out = rigel.apply("downsampleSeq_X", mod, out) @@ -979,7 +1068,14 @@ C.downsampleSeq = memoize(function( A, W, H, T, scaleX, scaleY, X ) out = rigel.apply("downsampleSeq_incrate", modules.RPassthrough(modules.changeRate(A,1,downsampleT,T)), out ) elseif downsampleT>T then assert(false) end end - return modules.lambda( J.sanitize("downsampleSeq_"..tostring(A).."_W"..tostring(W).."_H"..tostring(H).."_T"..tostring(T).."_scaleX"..tostring(scaleX).."_scaleY"..tostring(scaleY)), inp, out,nil,"C.downsampleSeq") + local res = modules.lambda( J.sanitize("downsampleSeq_"..tostring(A).."_W"..tostring(W).."_H"..tostring(H).."_T"..tostring(T).."_scaleX"..tostring(scaleX).."_scaleY"..tostring(scaleY).."_framed"..tostring(framed)), inp, out,nil,"C.downsampleSeq") + + if framed then + res.inputType = res.inputType:addDim(W,H,true) + res.outputType = res.outputType:addDim(math.ceil(W/scaleX),math.ceil(H/scaleY),true) + end + + return res end) @@ -1017,7 +1113,7 @@ end) -- takes A to A[T] by duplicating the input C.broadcast = memoize(function(A,W,H) err( types.isType(A), "C.broadcast: A must be type A") - rigel.expectBasic(A) + err( 
types.isBasic(A), "C.broadcast: type should be basic, but is: "..tostring(A)) err( type(W)=="number", "broadcast: W should be number") if H==nil then return C.broadcast(A,W,1) end err( type(H)=="number", "broadcast: H should be number") @@ -1107,11 +1203,11 @@ C.cropHelperSeq = memoize(function( A, W, H, T, L, R, B, Top, X ) end) -C.stencilLinebuffer = memoize(function( A, w, h, T, xmin, xmax, ymin, ymax ) - err(types.isType(A), "stencilLinebuffer: A must be type") +C.stencilLinebuffer = memoize(function( A, w, h, T, xmin, xmax, ymin, ymax, framed, X ) + err(types.isType(A), "stencilLinebuffer: A must be type, but is: "..tostring(A)) err(type(T)=="number","stencilLinebuffer: T must be number") - err(type(w)=="number","stencilLinebuffer: w must be number") + err(type(w)=="number","stencilLinebuffer: w must be number, but is: "..tostring(w)) err(type(h)=="number","stencilLinebuffer: h must be number") err(type(xmin)=="number","stencilLinebuffer: xmin must be number") err(type(xmax)=="number","stencilLinebuffer: xmax must be number") @@ -1126,7 +1222,15 @@ C.stencilLinebuffer = memoize(function( A, w, h, T, xmin, xmax, ymin, ymax ) err(xmax==0,"stencilLinebuffer: xmax must be 0") err(ymax==0,"stencilLinebuffer: ymax must be 0") - return C.compose( J.sanitize("stencilLinebuffer_A"..tostring(A).."_w"..w.."_h"..h.."_T"..T.."_xmin"..tostring(math.abs(xmin)).."_ymin"..tostring(math.abs(ymin))), modules.SSR( A, T, xmin, ymin), modules.linebuffer( A, w, h, T, ymin ), "C.stencilLinebuffer" ) + err(X==nil,"C.stencilLinebuffer: Too many arguments") + + local SSRFn = modules.SSR( A, T, xmin, ymin) + + if framed then + SSRFn = modules.mapFramed(SSRFn,w/T,h,false) + end + + return C.compose( J.sanitize("stencilLinebuffer_A"..tostring(A).."_w"..w.."_h"..h.."_T"..T.."_xmin"..tostring(math.abs(xmin)).."_ymin"..tostring(math.abs(ymin))), SSRFn, modules.linebuffer( A, w, h, T, ymin, framed ), "C.stencilLinebuffer" ) end) -- this is basically the same as a stencilLinebuffer, but 
implemend using a register chain instead of rams @@ -1187,7 +1291,9 @@ end) -- purely wiring. This should really be implemented as a lift. -C.unpackStencil = memoize(function( A, stencilW, stencilH, T, arrHeight, X ) +-- framed: this fn is a bit strange (actually introduces a new dimension), so can't use mapFramed +-- instead, special case this +C.unpackStencil = memoize(function( A, stencilW, stencilH, T, arrHeight, framed, framedW, framedH, X ) assert(types.isType(A)) assert(type(stencilW)=="number") assert(stencilW>0) @@ -1196,11 +1302,29 @@ C.unpackStencil = memoize(function( A, stencilW, stencilH, T, arrHeight, X ) assert(type(T)=="number") assert(T>=1) err(arrHeight==nil, "Error: NYI - unpackStencil on non-height-1 arrays") + err( framed==nil or type(framed)=="boolean", "unpackStencil: framed must be nil or bool") + if framed==nil then framed=false end assert(X==nil) local res = {kind="unpackStencil", stencilW=stencilW, stencilH=stencilH,T=T,generator="C.unpackStencil"} res.inputType = types.array2d( A, stencilW+T-1, stencilH) res.outputType = types.array2d( types.array2d( A, stencilW, stencilH), T ) + + if framed then + err( type(framedW)=="number", "unpackStencil: framedW must be nil or number") + err( type(framedH)=="number", "unpackStencil: framedH must be nil or number") + +-- local idims = types.FramedCollectParallelDims(res.inputType) +-- local odims = types.FramedCollectParallelDims(res.outputType) + +-- table.insert(idims,{framedW/T,framedH}) +-- odims[#odims]={framedW,framedH} + + res.inputType = res.inputType:addDim(framedW/T,framedH,false) + res.outputType = res.outputType:addDim(framedW,framedH,true) + print("UPSOP",res.outputType) + end + res.sdfInput, res.sdfOutput = {{1,1}}, {{1,1}} res.stateful = false res.delay=0 @@ -1209,7 +1333,7 @@ C.unpackStencil = memoize(function( A, stencilW, stencilH, T, arrHeight, X ) if terralib~=nil then res.terraModule = CT.unpackStencil(res, A, stencilW, stencilH, T, arrHeight) end res.systolicModule = 
Ssugar.moduleConstructor(res.name) - local sinp = S.parameter("inp", res.inputType) + local sinp = S.parameter("inp", rigel.extractData(res.inputType) ) local out = {} for i=1,T do out[i] = {} @@ -1220,7 +1344,7 @@ C.unpackStencil = memoize(function( A, stencilW, stencilH, T, arrHeight, X ) end end - res.systolicModule:addFunction( S.lambda("process", sinp, S.cast( S.tuple(J.map(out,function(n) return S.cast( S.tuple(n), types.array2d(A,stencilW,stencilH) ) end)), res.outputType ), "process_output", nil, nil, S.CE("process_CE") ) ) + res.systolicModule:addFunction( S.lambda("process", sinp, S.cast( S.tuple(J.map(out,function(n) return S.cast( S.tuple(n), types.array2d(A,stencilW,stencilH) ) end)), rigel.extractData(res.outputType) ), "process_output", nil, nil, S.CE("process_CE") ) ) --res.systolicModule:addFunction( S.lambda("reset", S.parameter("r",types.null()), nil, "ro" ) ) return rigel.newFunction(res) @@ -1313,7 +1437,7 @@ C.generalizedChangeRate = memoize(function(inputBitsPerCyc, minTotalInputBits, i assert(type(minTotalOutputBits)=="number") assert(type(inputFactor)=="number") assert(type(outputFactor)=="number") - assert(minTotalInputBits%inputBitsPerCyc==0) + err(minTotalInputBits%inputBitsPerCyc==0,"generalizedChangeRate: inputBitsPerCycle ("..inputBitsPerCyc..") must divide minTotalInputBits ("..minTotalInputBits..")") assert(minTotalOutputBits%outputBitsPerCyc==0) assert(X==nil) @@ -1463,4 +1587,29 @@ function C.linearPipeline(t,modulename) return RM.lambda(modulename,inp,out) end +-- Hacky module for internal use: just convert a Handshake to a HandshakeFramed +C.handshakeToHandshakeFramed = memoize( + function(A,mixed,dims,X) + err( type(dims)=="table", "handshakeToHandshakeFramed: dims should be table") + assert(X==nil) + err(R.isHandshake(A),"handshakeToHandshakeFramed: input should be handshake") + local res = {inputType=A,outputType=types.HandshakeFramed(A.params.A,mixed,dims),sdfInput={{1,1}},sdfOutput={{1,1}},stateful=false} + 
res.name=J.sanitize("HandshakeToHandshakeFramed_"..tostring(A)) + + function res.makeSystolic() + local sm = Ssugar.moduleConstructor(res.name):onlyWire(true) + local r = S.parameter("ready_downstream",types.bool()) + sm:addFunction( S.lambda("ready", r, r, "ready") ) + local I = S.parameter("process_input", R.lower(A) ) + sm:addFunction( S.lambda("process",I,I,"process_output") ) + sm:addFunction( S.lambda("reset", S.parameter("r",types.null()), nil, "reset_out") ) + return sm + end + function res.makeTerra() + return CT.handshakeToHandshakeFramed(res,A,mixed,dims) + end + + return rigel.newFunction(res) + end) + return C diff --git a/examples/examplescommonTerra.t b/examples/examplescommonTerra.t index dddbed2..c54809e 100644 --- a/examples/examplescommonTerra.t +++ b/examples/examplescommonTerra.t @@ -18,11 +18,39 @@ function CT.identity(A) end end -function CT.print(A) +function CT.fassert(filename,ty) + assert(ty:verilogBits()%8==0) + assert(ty:verilogBits()==ty:sizeof()*8) + local struct Fassert { file : &cstdio.FILE, token:uint } + terra Fassert:init() + self.file = cstdio.fopen(filename, "rb") + [J.darkroomAssert](self.file~=nil, ["file "..filename.." 
doesnt exist"]) + end + terra Fassert:reset() self.token=0 end + terra Fassert:free() cstdio.fclose(self.file) end + terra Fassert:process(inp : &ty:toTerraType(), out : &ty:toTerraType()) + var outBytes = cstdio.fread(out,1,[ty:sizeof()],self.file) + [J.darkroomAssert](outBytes==[ty:sizeof()], "Error, freadSeq failed, probably end of file?") + + if @inp~=@out then + cstdio.printf("Fassert: file and input don't match!!\n") + cstdio.printf("input: %d/%#x\n",@inp,@inp) + cstdio.printf("file: %d/%#x\n",@out,@out) + cstdio.printf("at token number: %d, byte: %d\n",self.token,self.token*[ty:sizeof()]) + cstdlib.exit(1) + end + self.token = self.token+1 + end + return MT.new(Fassert) +end + +function CT.print(A,str) + err(types.isBasic(A),"CT.print: type should be basic, but is: "..tostring(A) ) + local function doprint(A,symb) assert(symb~=nil) - + if A:isArray() then local tab = {} table.insert(tab,quote cstdio.printf("[") end) @@ -42,19 +70,25 @@ function CT.print(A) table.insert(tab,quote cstdio.printf("}") end) return quote [tab] end elseif A:isUint() or A:isInt() or A:isBits() then - return quote cstdio.printf("%d",symb) end + return quote cstdio.printf("%d/%#x",symb,symb) end else print(A) assert(false) end end - - return terra( a : &A:toTerraType(), out : &A:toTerraType() ) + + local printS = quote end + if str~=nil then printS = quote cstdio.printf("%s:",str) end end + + local struct PrintModule {} + terra PrintModule:process( a : &A:toTerraType(), out : &A:toTerraType() ) var aa = @a + printS [doprint(A,aa)] cstdio.printf("\n") @out = @a end + return MT.new(PrintModule) end function CT.cast(A,B) @@ -295,4 +329,13 @@ function CT.plusConsttfn(ty,value) end end +function CT.handshakeToHandshakeFramed(res,A,mixed,dims) + local struct HandshakeToHandshakeFramed {ready:bool} + terra HandshakeToHandshakeFramed:process( inp:&rigel.lower(res.inputType):toTerraType(), out:&rigel.lower(res.outputType):toTerraType()) + @out = @inp + end + terra 
HandshakeToHandshakeFramed:calculateReady(readyDownstream:bool) self.ready = readyDownstream end + return MT.new(HandshakeToHandshakeFramed) +end + return CT diff --git a/examples/gold/convgenTaps.terra.cycles.txt b/examples/gold/convgenTaps.terra.cycles.txt deleted file mode 120000 index c0bfeb5..0000000 --- a/examples/gold/convgenTaps.terra.cycles.txt +++ /dev/null @@ -1 +0,0 @@ -convgen.terra.cycles.txt \ No newline at end of file diff --git a/examples/gold/convgen.bmp b/examples/gold/soc_convgen.bmp similarity index 100% rename from examples/gold/convgen.bmp rename to examples/gold/soc_convgen.bmp diff --git a/examples/gold/convgen.regout.lua b/examples/gold/soc_convgen.regout.lua similarity index 100% rename from examples/gold/convgen.regout.lua rename to examples/gold/soc_convgen.regout.lua diff --git a/examples/gold/convgen.terra.cycles.txt b/examples/gold/soc_convgen.terra.cycles.txt similarity index 100% rename from examples/gold/convgen.terra.cycles.txt rename to examples/gold/soc_convgen.terra.cycles.txt diff --git a/examples/gold/convgenTaps.bmp b/examples/gold/soc_convgenTaps.bmp similarity index 100% rename from examples/gold/convgenTaps.bmp rename to examples/gold/soc_convgenTaps.bmp diff --git a/examples/gold/convgenTaps.regout.lua b/examples/gold/soc_convgenTaps.regout.lua similarity index 100% rename from examples/gold/convgenTaps.regout.lua rename to examples/gold/soc_convgenTaps.regout.lua diff --git a/examples/gold/soc_convgenTaps.terra.cycles.txt b/examples/gold/soc_convgenTaps.terra.cycles.txt new file mode 100644 index 0000000..da25838 --- /dev/null +++ b/examples/gold/soc_convgenTaps.terra.cycles.txt @@ -0,0 +1 @@ +2092819 diff --git a/examples/gold/soc_convtest.bmp b/examples/gold/soc_convtest.bmp new file mode 100644 index 0000000..dd89f62 Binary files /dev/null and b/examples/gold/soc_convtest.bmp differ diff --git a/examples/gold/soc_convtest.regout.lua b/examples/gold/soc_convtest.regout.lua new file mode 100644 index 0000000..3941af5 
--- /dev/null +++ b/examples/gold/soc_convtest.regout.lua @@ -0,0 +1 @@ +return {} \ No newline at end of file diff --git a/examples/gold/soc_convtest.terra.cycles.txt b/examples/gold/soc_convtest.terra.cycles.txt new file mode 100644 index 0000000..99605f1 --- /dev/null +++ b/examples/gold/soc_convtest.terra.cycles.txt @@ -0,0 +1 @@ +7144 \ No newline at end of file diff --git a/examples/gold/soc_read.bmp b/examples/gold/soc_read.bmp new file mode 120000 index 0000000..117869b --- /dev/null +++ b/examples/gold/soc_read.bmp @@ -0,0 +1 @@ +soc_convtest.bmp \ No newline at end of file diff --git a/examples/gold/soc_read.regout.lua b/examples/gold/soc_read.regout.lua new file mode 100644 index 0000000..3941af5 --- /dev/null +++ b/examples/gold/soc_read.regout.lua @@ -0,0 +1 @@ +return {} \ No newline at end of file diff --git a/examples/gold/soc_read.terra.cycles.txt b/examples/gold/soc_read.terra.cycles.txt new file mode 100644 index 0000000..24e1814 --- /dev/null +++ b/examples/gold/soc_read.terra.cycles.txt @@ -0,0 +1 @@ +3900 \ No newline at end of file diff --git a/examples/harnessTerraSOC.t b/examples/harnessTerraSOC.t index d4ff7ed..ed18170 100644 --- a/examples/harnessTerraSOC.t +++ b/examples/harnessTerraSOC.t @@ -193,7 +193,7 @@ return function(top, memStart, memEnd) local addr = tonumber("0x"..addr) print("BYTES",bytes,"addr",addr) - for b=0,bytes/4-1 do + for b=0,math.ceil(bytes/4)-1 do local dat = string.sub(v,b*8+1,(b+1)*8) local data = tonumber("0x"..dat) print("DAT",dat,data) @@ -237,9 +237,9 @@ return function(top, memStart, memEnd) var cycle = 0 - if round==0 then +-- if round==0 then [setTaps] - end +-- end setReg( IP_CLK, IP_ARESET_N, m, 0xA0000000, 1 ) diff --git a/examples/makefile b/examples/makefile index 037f700..189004d 100644 --- a/examples/makefile +++ b/examples/makefile @@ -1,7 +1,7 @@ BUILDDIR ?= out # soc_flipWrite.lua -SRCS_SOC = soc_simple.lua soc_2in.lua convgen.lua convgenTaps.lua soc_flip.lua soc_15x15.lua soc_15x15x15.lua 
soc_flipWrite.lua soc_regin.lua soc_regout.lua +SRCS_SOC = soc_simple.lua soc_2in.lua soc_convgen.lua soc_convgenTaps.lua soc_flip.lua soc_15x15.lua soc_15x15x15.lua soc_flipWrite.lua soc_regin.lua soc_regout.lua soc_convtest.lua soc_read.lua VERILATOR_SOC = $(patsubst %.lua,$(BUILDDIR)/%.verilatorSOC.bit,$(SRCS_SOC)) VERILATOR_SOC = $(patsubst %.lua,$(BUILDDIR)/%.verilatorSOC.raw,$(SRCS_SOC)) diff --git a/examples/convgen.lua b/examples/soc_convgen.lua similarity index 69% rename from examples/convgen.lua rename to examples/soc_convgen.lua index 57c30a9..82f2fbf 100644 --- a/examples/convgen.lua +++ b/examples/soc_convgen.lua @@ -18,10 +18,10 @@ padSize = { 1920+16, 1080+3 } local conv = Module{ ar(u(8),ConvWidth,ConvWidth), function(inp) inp = Map{AddMSBs{24}}(inp) - local coeff = c{ar(u(32),ConvWidth,ConvWidth),{4, 14, 14, 4, - 14, 32, 32, 14, - 14, 32, 32, 14, - 4, 14, 14, 4}} + local coeff = c({4, 14, 14, 4, + 14, 32, 32, 14, + 14, 32, 32, 14, + 4, 14, 14, 4},ar(u(32),ConvWidth,ConvWidth)) local z = Zip(inp,coeff) local out = Map{Mul}(z) local res = Reduce{Add}(out) @@ -35,6 +35,6 @@ harness{ -- RS.HS(C.print(ar(u(8),1))), HS{Linebuffer{padSize,1,{3,0,3,0}}}, HS{Map{conv}}, - HS{Crop{padSize,1,{9,7,3,0}}}, - SOC.writeBurst("out/convgen",1920,1080,u(8),1), + HS{CropSeq{padSize,1,{9,7,3,0}}}, + SOC.writeBurst("out/soc_convgen",1920,1080,u(8),1), regs.done} diff --git a/examples/convgenTaps.lua b/examples/soc_convgenTaps.lua similarity index 90% rename from examples/convgenTaps.lua rename to examples/soc_convgenTaps.lua index d14009f..b5151cf 100644 --- a/examples/convgenTaps.lua +++ b/examples/soc_convgenTaps.lua @@ -35,6 +35,6 @@ harness{ HS{Pad{inSize,1,{8,8,2,1}}}, HS{Linebuffer{padSize,1,{3,0,3,0}}}, HS{Map{conv}}, - HS{Crop{padSize,1,{9,7,3,0}}}, - SOC.writeBurst("out/convgenTaps",1920,1080,u(8),1), + HS{CropSeq{padSize,1,{9,7,3,0}}}, + SOC.writeBurst("out/soc_convgenTaps",1920,1080,u(8),1), regs.done} diff --git a/examples/soc_convtest.lua 
b/examples/soc_convtest.lua new file mode 100644 index 0000000..8e81b4f --- /dev/null +++ b/examples/soc_convtest.lua @@ -0,0 +1,24 @@ +local R = require "rigel" +local SOC = require "soc" +local C = require "examplescommon" +local harness = require "harnessSOC" +local G = require "generators" +local RS = require "rigelSimple" +require "types".export() + +regs = SOC.axiRegs{}:instantiate() + +ConvTop = G.Module{ + function(i) + local readStream = G.AXIReadBurst{"frame_128.raw",{128,64},u(8),1}(i) + local O = G.HS{G.Stencil{{2,0,2,0}}}(readStream) + local OC = G.HS{G.Crop{{2,6,2,6}}}(O) + OC = G.HS{G.Downsample{{4,4}}}(OC) +-- print(OC.fn) + OC = G.HS{G.Map{G.Map{G.Rshift{3}}}}(OC) + local OM = G.HS{G.Map{G.Reduce{G.Add}}}(OC) +-- print("OMTYPE",OM.type) + return G.AXIWriteBurst{"out/soc_convtest"}(OM) + end} + +harness{regs.start, ConvTop, regs.done} diff --git a/examples/soc_flip.lua b/examples/soc_flip.lua index 652dcf4..f207a52 100644 --- a/examples/soc_flip.lua +++ b/examples/soc_flip.lua @@ -13,7 +13,7 @@ local W,H = 128,64 addrGen = Module{function(inp) local x, y = Index{0}(Index{0}(inp)), Index{1}(Index{0}(inp)) local resx = AddMSBs{16}(x) - local resy = Mul( Sub(c{H-1,u(32)},AddMSBs{16}(y)),c{W/8,u(32)} ) + local resy = Mul( Sub(c(H-1,u(32)),AddMSBs{16}(y)),c(W/8,u(32)) ) return Add(resx,resy) end} diff --git a/examples/soc_flipWrite.lua b/examples/soc_flipWrite.lua index 94670b9..b93f69d 100644 --- a/examples/soc_flipWrite.lua +++ b/examples/soc_flipWrite.lua @@ -13,7 +13,7 @@ local W,H = 128,64 AddrGen = Module{function(inp) local x, y = Index{0}(Index{0}(inp)), Index{1}(Index{0}(inp)) local resx = AddMSBs{16}(x) - local resy = Mul( Sub(c{H-1,u(32)},AddMSBs{16}(y)),c{W/8,u(32)} ) + local resy = Mul( Sub(c(H-1,u(32)),AddMSBs{16}(y)),c(W/8,u(32)) ) return Add(resx,resy) end} diff --git a/examples/soc_read.lua b/examples/soc_read.lua new file mode 100644 index 0000000..bb59efe --- /dev/null +++ b/examples/soc_read.lua @@ -0,0 +1,37 @@ +local R = require 
"rigel" +local G = require "generators" +local SOC = require "soc" +local harness = require "harnessSOC" +require "types".export() + +regs = SOC.axiRegs{}:instantiate() + +PosToAddr = G.Module{ "PosToAddr", ar(u16,2), + function(loc) + local i = G.PosSeq{{3,3},0}() -- inner loop from 0...2 + local x, y = G.Add(loc[0],i[0]), G.Add(loc[1],i[1]) -- x=loc.x+i.x, y=loc.y+i.y + local res = G.Add(x,G.Mul(y,R.constant(128,u16))) -- addr = x + y*W (W=128) + return G.AddMSBs{16}(res) + end} + +ReadStencilDMA = G.Module{ "ReadStencilDMA", Handshake(ar(u16,2)), + function(loc) + local locb = G.HS{G.BroadcastSeq{{3,3},0}}(loc) -- duplicate input (x,y) 9 times + local addrStream = G.HS{PosToAddr}(locb) + local pxStream = G.AXIRead{"frame_128.raw",128*64}(addrStream) + local pxArr = G.HS{G.Broadcast{1}}(pxStream) -- convert to array with 1 element + return G.HS{G.Deser{3*3}}(pxArr) -- Deserialize 9 reads into 9 element array + end} + +ConvTop = G.Module{ "ConvTop", + function(i) + i = G.TriggerBroadcast{30*14}(i) + local pos = G.HS{G.Pos{{30,14},0}}(i) + local posScaled = G.HS{G.Map{G.Map{G.Mul{4}}}}(pos) -- mult coords by 4 + local stencil = G.Map{ReadStencilDMA}(posScaled) + local shifted = G.HS{G.Map{G.Map{G.Rshift{3}}}}(stencil) + local fin = G.HS{G.Map{G.Reduce{G.Add}}}(shifted) + return G.AXIWriteBurst{"out/soc_read"}(fin) + end} + +harness{regs.start, ConvTop, regs.done} diff --git a/examples/soc_regout.lua b/examples/soc_regout.lua index 22d1f68..e8e0c30 100644 --- a/examples/soc_regout.lua +++ b/examples/soc_regout.lua @@ -4,26 +4,22 @@ local C = require "examplescommon" local harness = require "harnessSOC" local G = require "generators" local RS = require "rigelSimple" -require "types".export() +local types = require "types" +types.export() -local Regs = SOC.axiRegs{offset={u(32),200},lastPx={u(8),0,"out"}} +local Regs = SOC.axiRegs{offset={u(8),200},lastPx={u(8),0,"out"}} regs = Regs:instantiate() -print("REGS",Regs.offset) +local AddReg = G.Module{"AddReg",function(i) 
return G.Add(i,Regs.offset) end} -local AddReg = G.Module{"AddReg",function(i) return G.Add(i,G.RemoveMSBs{24}(Regs.offset)) end} - -local Top = G.Module{"Top", +local RegInOut = G.Module{ function(i) - local o = regs.start(i) - o = SOC.readBurst("frame_128.raw",128,64,u(8),0)(o) - local pipeOut = G.HS{AddReg}(o) + local inputStream = G.AXIReadBurst{"frame_128.raw",{128,64},u(8),0}(i) + local pipeOut = G.HS{G.Map{AddReg}}(inputStream) pipeOut = G.FanOut{2}(pipeOut) - local pipeOut0 = R.selectStream("ob0",pipeOut,0) - local pipeOut1 = R.selectStream("ob1",pipeOut,1) - o = SOC.writeBurst("out/soc_regout",128,64,u(8),0)(pipeOut0) - o = regs.done(o) - return R.statements{o,R.writeGlobal("go",Regs.lastPx,pipeOut1)} + local doneFlag = G.AXIWriteBurst{"out/soc_regout"}(pipeOut[0]) + local writeRegStatement = G.Map{G.WriteGlobal{Regs.lastPx}}(pipeOut[1]) + return R.statements{doneFlag,writeRegStatement} end} -harness(Top) +harness{regs.start,RegInOut,regs.done} diff --git a/examples/soc_simple.lua b/examples/soc_simple.lua index 061e100..d58fb2f 100644 --- a/examples/soc_simple.lua +++ b/examples/soc_simple.lua @@ -4,15 +4,19 @@ local C = require "examplescommon" local harness = require "harnessSOC" local G = require "generators" local RS = require "rigelSimple" -require "types".export() +local types = require "types" +types.export() regs = SOC.axiRegs{}:instantiate() -OffsetModule = G.Module{ +OffsetModule = G.Module{ "OffsetModule", R.HandshakeTrigger, function(i) local readStream = G.AXIReadBurst{"frame_128.raw",{128,64},u(8),8}(i) local offset = G.HS{G.Map{G.Add{200}}}(readStream) - return G.AXIWriteBurst{"out/soc_simple",{128,64} ,u(8),8}(offset) + return G.AXIWriteBurst{"out/soc_simple"}(offset) end} +--print(OffsetModule) +--print(OffsetModule:toVerilog()) + harness{regs.start, OffsetModule, regs.done} diff --git a/misc/rigelSimple.lua b/misc/rigelSimple.lua index 670e3c7..15d361b 100644 --- a/misc/rigelSimple.lua +++ b/misc/rigelSimple.lua @@ -355,18 +355,18 @@ 
function RS.HS(t,handshakeTrigger) if types.isType(t) then return R.Handshake(t) elseif R.isFunction(t) then - if R.isV(t.inputType) and R.isRV(t.outputType) then + if (R.isV(t.inputType) or t.inputType:is("VFramed")) and (R.isRV(t.outputType) or t.outputType:is("RVFramed")) then return RM.liftHandshake(t) elseif R.isHandshake(t.inputType) then return t -- elseif (R.isBasic(t.inputType) and R.isV(t.outputType)) or (t.outputType:isTuple() and #t.outputType.list>1 and t.outputType.list[2]:isBool()) then - elseif (R.isBasic(t.inputType) and R.isV(t.outputType)) then + elseif (R.isBasic(t.inputType) or t.inputType:is("StaticFramed")) and (R.isV(t.outputType) or t.outputType:is("VFramed")) then return RM.liftHandshake(RM.liftDecimate(t)) - elseif R.isBasic(t.inputType) and R.isBasic(t.outputType) then + elseif (R.isBasic(t.inputType) and R.isBasic(t.outputType)) or ((t.inputType:is("StaticFramed") or t.inputType==types.null()) and t.outputType:is("StaticFramed")) then if handshakeTrigger==nil then handshakeTrigger = true end return RM.makeHandshake(t,nil,handshakeTrigger) else - print(t.inputType, t.outputType) + print("RS.HS: could not lift fn with input type:",t.inputType, "outputType:",t.outputType) assert(false) end else diff --git a/modules/soc.lua b/modules/soc.lua index 82e627a..36619bd 100644 --- a/modules/soc.lua +++ b/modules/soc.lua @@ -889,8 +889,8 @@ SOC.axiReadBytes = J.memoize(function(filename,Nbytes,port,addressBase, X) J.err( type(port)=="number", "axiReadBytes: port must be number" ) J.err( port>=0 and port<=SOC.ports,"axiReadBytes: port out of range" ) J.err( type(Nbytes)=="number","axiReadBytes: Nbytes must be number" ) - J.err( Nbytes%8==0, "axiReadBytes: Nbytes must have 8 as a factor" ) - J.err( Nbytes>=8, "axiReadBytes: NYI - Nbytes must be >=8" ) +-- J.err( Nbytes%8==0, "axiReadBytes: Nbytes must have 8 as a factor" ) +-- J.err( Nbytes>=8, "axiReadBytes: NYI - Nbytes must be >=8" ) J.err( X==nil, "axiReadBytes: too many arguments" ) J.err( 
type(addressBase)=="number", "axiReadBytes: addressBase must be number") @@ -908,7 +908,7 @@ SOC.axiReadBytes = J.memoize(function(filename,Nbytes,port,addressBase, X) local ModuleName = J.sanitize("AXI_READ_BYTES_"..tostring(Nbytes).."_"..tostring(port)) - local burstCount = Nbytes/8 + local burstCount = math.ceil(Nbytes/8) J.err( burstCount<=16,"axiReadBytes: NYI - burst longer than 16") local res = RM.liftVerilog( ModuleName, R.Handshake(types.uint(32)), R.Handshake(types.bits(64)), @@ -1383,7 +1383,7 @@ SOC.bulkRamWrite = J.memoize(function(port) return BRR end) -SOC.readBurst = J.memoize(function(filename,W,H,ty,V,X) +SOC.readBurst = J.memoize(function(filename,W,H,ty,V,framed,X) J.err( type(filename)=="string","readBurst: filename must be string") J.err( type(W)=="number", "readBurst: W must be number") J.err( type(H)=="number", "readBurst: H must be number") @@ -1394,6 +1394,8 @@ SOC.readBurst = J.memoize(function(filename,W,H,ty,V,X) --J.err( nbytes%128==0,"NYI - readBurst requires 128-byte aligned size" ) J.err( V==nil or type(V)=="number", "readBurst: V must be number or nil") if V==nil then V=0 end + + J.err( framed==nil or type(framed)=="boolean", "framed must be nil or bool, but is: "..tostring(framed)) J.err(X==nil, "readBurst: too many arguments") local globalMetadata={} @@ -1430,6 +1432,10 @@ SOC.readBurst = J.memoize(function(filename,W,H,ty,V,X) local outType = ty if V>0 then outType=types.array2d(outType,V) end out = RM.makeHandshake(C.cast(types.bits(outType:verilogBits()),outType))(out) + + if framed then + out = C.handshakeToHandshakeFramed(out.type,V>0,{{W,H}})(out) + end local res = RM.lambda("ReadBurst_Wf"..W.."_H"..H.."_v"..V.."_port"..SOC.currentMAXIReadPort.."_addr"..SOC.currentAddr.."_"..tostring(ty),inp,out,nil,nil,nil,globalMetadata) @@ -1441,11 +1447,14 @@ end) -- This works like C pointer deallocation: -- input address N of type T actually reads at physical memory address N*sizeof(T)+base +-- readType: this is the output type we 
want SOC.read = J.memoize(function(filename,fileBytes,readType,X) - J.err( type(filename)=="string","read: filename must be string") + J.err( type(filename)=="string","read: filename must be string, but is: "..tostring(filename)) J.err( types.isType(readType), "read: type must be type") - J.err( readType:verilogBits()%8==0, "SOC.read: NYI - type must be byte aligned") - J.err( readType:verilogBits()%64==0, "SOC.read: NYI - type must be 8 byte aligned") + J.err( types.isBasic(readType), "read: type must be basic type, but is: "..tostring(readType)) + J.err( readType:verilogBits()%8==0, "SOC.read: NYI - type must be byte aligned, but is: "..tostring(readType)) +-- J.err( readType:verilogBits()%64==0, "SOC.read: NYI - type must be 8 byte aligned, but is: "..tostring(readType)) + J.err( X==nil, "SOC.read: too many arguments" ) local globalMetadata={} globalMetadata["MAXI"..SOC.currentMAXIReadPort.."_read_address"] = SOC.currentAddr @@ -1481,6 +1490,9 @@ SOC.read = J.memoize(function(filename,fileBytes,readType,X) out = RM.liftHandshake(RM.changeRate(types.bits(64),1,1,N))(out) out = RM.makeHandshake(C.bitcast(types.array2d(types.bits(64),N),readType))(out) + elseif readType:verilogBits()<64 then + out = RM.makeHandshake(C.bitSlice(types.bits(64),0,readType:verilogBits()-1))(out) + out = RM.makeHandshake(C.bitcast(types.bits(readType:verilogBits()),readType))(out) else assert(false) end @@ -1493,7 +1505,7 @@ SOC.read = J.memoize(function(filename,fileBytes,readType,X) return res end) -SOC.writeBurst = J.memoize(function(filename,W,H,ty,V,X) +SOC.writeBurst = J.memoize(function(filename,W,H,ty,V,framed,X) J.err( type(filename)=="string","writeBurst: filename must be string") J.err( type(W)=="number", "writeBurst: W must be number") J.err( type(H)=="number", "writeBurst: H must be number") @@ -1503,8 +1515,10 @@ SOC.writeBurst = J.memoize(function(filename,W,H,ty,V,X) --J.err( nbytes%8==0,"NYI - writeBurst requires 8-byte aligned size (input bytes is: "..nbytes..")" ) 
J.err( V==nil or type(V)=="number", "writeBurst: V must be number or nil") if V==nil then V=1 end + if framed==nil then framed=false end + J.err( type(framed)=="boolean","writeBurst: framed must be boolean") J.err(X==nil, "writeBurst: too many arguments") - + local globalMetadata={} globalMetadata["MAXI"..SOC.currentMAXIWritePort.."_write_W"] = W globalMetadata["MAXI"..SOC.currentMAXIWritePort.."_write_H"] = H @@ -1546,6 +1560,11 @@ SOC.writeBurst = J.memoize(function(filename,W,H,ty,V,X) SOC.currentMAXIWritePort = SOC.currentMAXIWritePort+1 SOC.currentAddr = SOC.currentAddr+totalBits/8 + if framed then + -- HACK + res.inputType = types.HandshakeFramed(inputType,V>0,{{W,H}}) + end + return res end) diff --git a/modules/socTerra.t b/modules/socTerra.t index 74ed672..a96cc2f 100644 --- a/modules/socTerra.t +++ b/modules/socTerra.t @@ -32,7 +32,7 @@ function SOCMT.axiRegs(mod,tab,port) local ty = mod.globalMetadata["TypeOfRegister_"..k] local glob = mod:getGlobal(k):terraValue() - for i=0,ty:verilogBits()/32-1 do + for i=0,math.ceil(ty:verilogBits()/32)-1 do table.insert(regSet, quote if data([mod:getGlobal("IP_SAXI"..port.."_AWADDR"):terraValue()])==[addr+i*4] then @@ -192,24 +192,24 @@ function SOCMT.axiBurstWriteN( mod, Nbytes, port, baseAddress ) assert(type(baseAddress)=="number") --print("AXIBURSTWRITE ",Nbytes,port) - local struct WriteBurst { nextByteToWrite:uint, ready:bool, addrReadyDownstream:bool, doneReg:bool, writtenBytes:uint, readyDownstream:bool, dataBuffer:uint64, writeFirst:bool } + local struct WriteBurst { nextAddrToWrite:uint, ready:bool, addrReadyDownstream:bool, doneReg:bool, writtenBytes:uint, readyDownstream:bool, dataBuffer:uint64, writeFirst:bool } local inputType = R.Handshake(types.bits(64)) --local stride = (R.extractData(inputType):sizeof()) terra WriteBurst:reset() - self.nextByteToWrite = Nbytes + self.nextAddrToWrite = Nbytes self.doneReg = true self.writtenBytes = Nbytes self.writeFirst = false end terra WriteBurst:process( 
dataIn:&R.lower(inputType):toTerraType(), done:&bool ) - if valid(dataIn) and self.nextByteToWrite==Nbytes and self.writtenBytes==Nbytes and self.writeFirst==false then + if valid(dataIn) and self.nextAddrToWrite==Nbytes and self.writtenBytes==Nbytes and self.writeFirst==false then -- starting - cstdio.printf("BurstWrite: start to send addresses %d %d %d\n",self.nextByteToWrite,valid(dataIn),data(dataIn)) + cstdio.printf("BurstWrite: start to send addresses nextAddrToWrite:%d valid:%d data:%d/%#x\n",self.nextAddrToWrite,valid(dataIn),data(dataIn),data(dataIn)) - self.nextByteToWrite = 0 + self.nextAddrToWrite = 0 self.dataBuffer = data(dataIn) self.doneReg = false self.writeFirst = true @@ -217,15 +217,15 @@ function SOCMT.axiBurstWriteN( mod, Nbytes, port, baseAddress ) valid([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]) = false elseif self.writeFirst then - cstdio.printf("WRITEFIRST %d %d\n",self.nextByteToWrite,[mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraReady()]) - valid([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]) = (self.nextByteToWrite>0) + cstdio.printf("WRITEFIRST nextAddrToWrite:%d IP_MAXI_WDATA_READY:%d data:%d/%#x\n",self.nextAddrToWrite,[mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraReady()],self.dataBuffer,self.dataBuffer) + valid([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]) = (self.nextAddrToWrite>0) cstring.memcpy( &data([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]), &self.dataBuffer, [R.extractData(inputType):sizeof()] ) - if (self.nextByteToWrite>0) and [mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraReady()] then + if (self.nextAddrToWrite>0) and [mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraReady()] then self.writeFirst=false self.writtenBytes = self.writtenBytes + 8 end - elseif valid(dataIn) and self.nextByteToWrite>0 then + elseif valid(dataIn) and self.nextAddrToWrite>0 then valid([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]) = true cstring.memcpy( 
&data([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]), &data(dataIn), [R.extractData(inputType):sizeof()] ) @@ -233,17 +233,24 @@ function SOCMT.axiBurstWriteN( mod, Nbytes, port, baseAddress ) if self.ready then self.writtenBytes = self.writtenBytes + 8 - if self.writtenBytes>self.nextByteToWrite then - cstdio.printf("WROTE TOO MUCH? writtenBytes:%d nextByteToWrite:%d totalBytes:%d\n",self.writtenBytes, self.nextByteToWrite,[Nbytes]) + if self.writtenBytes>self.nextAddrToWrite then + cstdio.printf("WROTE TOO MUCH? writtenBytes:%d nextAddrToWrite:%d totalBytes:%d\n",self.writtenBytes, self.nextAddrToWrite,[Nbytes]) cstdlib.exit(1) end + else + cstdio.printf("Internal ERROR: ready should be true? AA\n") + cstdlib.exit(1) end else valid([mod:getGlobal("IP_MAXI"..port.."_WDATA"):terraValue()]) = false + if self.ready and valid(dataIn) then + cstdio.printf("Internal ERROR: we're ready and valid, but no write is occuring?\n") + cstdlib.exit(1) + end end - if self.nextByteToWriteaddr, *(unsigned long*)(&memory[t->addr]), t->burst, QSize(&readQ[port])); + printf("MAXI%d Service Read Addr(base rel):%d data:%d/0x%x remaining_burst:%d outstanding_requests:%d\n", port, t->addr, *(unsigned long*)(&memory[t->addr]), *(unsigned long*)(&memory[t->addr]), t->burst, QSize(&readQ[port])); } t->burst--; @@ -520,7 +520,7 @@ void masterWriteDataLatchFlops( *(unsigned long*)(&memory[t->addr]) = *WDATA; if(verbose){ - printf("MAXI%d Accept Write, Addr: %d/%#x data: %d remaining_burst: %d outstanding_requests: %d\n", port, t->addr, t->addr, *WDATA, t->burst, QSize(&writeQ[port]) ); + printf("MAXI%d Accept Write, Addr(base rel): %d/%#x data: %d/%#x remaining_burst: %d outstanding_requests: %d\n", port, t->addr, t->addr, *WDATA, *WDATA, t->burst, QSize(&writeQ[port]) ); } t->burst--; diff --git a/rigel.lua b/rigel.lua index 4b3ae96..a51d2d7 100644 --- a/rigel.lua +++ b/rigel.lua @@ -51,11 +51,7 @@ function darkroom.RV(A) return 
types.named("RV("..tostring(A)..")",types.tuple{A,types.bool()}, "RV", {A=A}) end -function darkroom.Handshake(A) - err(types.isType(A),"Handshake: argument should be type") - err(darkroom.isBasic(A),"Handshake: argument should be basic type, but is: "..tostring(A)) - return types.named("Handshake("..tostring(A)..")", types.tuple{A,types.bool()}, "Handshake", {A=A} ) -end +darkroom.Handshake=types.Handshake function darkroom.HandshakeSparse(A) err(types.isType(A),"HandshakeSparse: argument should be type") @@ -124,31 +120,13 @@ function darkroom.isHandshakeTriggerArray( a ) return a:isNamed() and a.generato function darkroom.isHandshakeTuple( a ) return a:isNamed() and a.generator=="HandshakeTuple" end -- is this any of the handshaked types? -function darkroom.isHandshakeAny( a ) return darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeTmuxed(a) or darkroom.isHandshakeArrayOneHot(a) end +function darkroom.isHandshakeAny( a ) return darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeTmuxed(a) or darkroom.isHandshakeArrayOneHot(a) or a:is("HandshakeFramed") end function darkroom.isV( a ) return a:isNamed() and a.generator=="V" end function darkroom.isVTrigger( a ) return a:isNamed() and a.generator=="VTrigger" end function darkroom.isRV( a ) return a:isNamed() and a.generator=="RV" end function darkroom.isRVTrigger( a ) return a:isNamed() and a.generator=="RVTrigger" end -function darkroom.isBasic(A) - assert(types.isType(A)) - if A:isArray() then - return darkroom.isBasic(A:arrayOver()) - elseif A:isTuple() then - for _,v in ipairs(A.list) do - if darkroom.isBasic(v)==false then - return false - end - end - return true - elseif A:isNamed() and A.generator=="fixed" then - return true -- COMPLETE HACK, REMOVE - elseif A:isNamed() then - return false - end - - return true -end 
+darkroom.isBasic = types.isBasic function darkroom.expectBasic( A ) err( darkroom.isBasic(A), "type should be basic but is "..tostring(A) ) end function darkroom.expectV( A, er ) if darkroom.isV(A)==false then error(er or "type should be V but is "..tostring(A)) end end function darkroom.expectRV( A, er ) if darkroom.isRV(A)==false then error(er or "type should be RV") end end @@ -160,8 +138,8 @@ function darkroom.expectHandshake( A, er ) if darkroom.isHandshake(A)==false the -- RV(A) => {A,bool} -- Handshake(A) => {A,bool} function darkroom.lower( a, loc ) - assert(types.isType(a)) - if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isVTrigger(a) or darkroom.isRVTrigger(a) or darkroom.isRV(a) or darkroom.isV(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeArrayOneHot(a) or darkroom.isHandshakeTmuxed(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeTriggerArray(a) then + err( types.isType(a), "lower: input is not a type. is: "..tostring(a)) + if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isVTrigger(a) or darkroom.isRVTrigger(a) or darkroom.isRV(a) or darkroom.isV(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeArrayOneHot(a) or darkroom.isHandshakeTmuxed(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeTriggerArray(a) or a:is("StaticFramed") or a:is("HandshakeFramed") or a:is("VFramed") or a:is("RVFramed") or a:is("HandshakeArrayFramed") then return a.structure elseif darkroom.isBasic(a) then return a @@ -175,16 +153,16 @@ end -- RV(A) => A -- Handshake(A) => A function darkroom.extractData(a) - if darkroom.isHandshake(a) or darkroom.isV(a) or darkroom.isRV(a) then return a.params.A end + if darkroom.isHandshake(a) or darkroom.isV(a) or darkroom.isRV(a) or a:is("StaticFramed") or a:is("HandshakeFramed") or a:is("VFramed") or a:is("RVFramed") then return a.params.A end if darkroom.isHandshakeTrigger(a) or darkroom.isVTrigger(a) or darkroom.isRVTrigger(a) then return types.null() 
end if darkroom.isHandshakeArray(a) then return types.array2d(a.params.A,a.params.N) end return a -- pure end function darkroom.hasReady(a) - if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isRV(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeArrayOneHot(a) or darkroom.isHandshakeTmuxed(a) then + if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isRV(a) or darkroom.isHandshakeArray(a) or darkroom.isHandshakeTuple(a) or darkroom.isHandshakeArrayOneHot(a) or darkroom.isHandshakeTmuxed(a) or a:is("HandshakeFramed") or a:is("RVFramed") then return true - elseif darkroom.isBasic(a) or darkroom.isV(a) then + elseif darkroom.isBasic(a) or darkroom.isV(a) or a:is("StaticFramed") or a:is("VFramed") then return false else print("UNKNOWN READY",a) @@ -193,9 +171,7 @@ function darkroom.hasReady(a) end function darkroom.extractReady(a) - if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) then return types.bool() - elseif darkroom.isV(a) then return types.bool() - elseif darkroom.isRV(a) then return types.bool() + if darkroom.isHandshake(a) or darkroom.isHandshakeTrigger(a) or darkroom.isV(a) or darkroom.isRV(a) or a:is("HandshakeFramed") then return types.bool() elseif darkroom.isHandshakeTuple(a) then return types.array2d(types.bool(),#a.params.list) -- we always use arrays for ready bits. no reason not to. 
elseif darkroom.isHandshakeArrayOneHot(a) then @@ -214,11 +190,11 @@ function darkroom.extractValid(a) end function darkroom.streamCount(A) - if darkroom.isBasic(A) or darkroom.isV(A) or darkroom.isRV(A) then + if darkroom.isBasic(A) or darkroom.isV(A) or darkroom.isRV(A) or A:is("StaticFramed") or A:is("VFramed") then return 0 - elseif darkroom.isHandshake(A) or darkroom.isHandshakeTrigger(A) then + elseif darkroom.isHandshake(A) or A:is("HandshakeFramed") or darkroom.isHandshakeTrigger(A) then return 1 - elseif darkroom.isHandshakeArray(A) or darkroom.isHandshakeTriggerArray(A) then + elseif darkroom.isHandshakeArray(A) or darkroom.isHandshakeTriggerArray(A) or A:is("HandshakeArrayFramed") then return A.params.W*A.params.H elseif darkroom.isHandshakeTuple(A) then return #A.params.list @@ -312,6 +288,8 @@ local function typeToKey(t) outk="type" elseif darkroom.isFunction(v) then outk="rigelFunction" + elseif darkroom.isGlobal(v) then + outk="global" elseif type(v)=="table" and J.keycount(v)==2 and #v==2 and type(v[1])=="number" and type(v[2])=="number" then outk="size" elseif type(v)=="table" and J.keycount(v)==4 and #v==4 and type(v[1])=="number" and type(v[2])=="number" @@ -373,7 +351,7 @@ generatorMT.__call=function(tab,...) return res end else - J.err(false, "Called a generator with something other than a Rigel value or table ("..tostring(rawarg[1])..")? Make sure you call generator with curly brackets {}") + J.err(false, "Called generator '"..tab.name.."' with something other than a Rigel value or table ("..tostring(rawarg[1])..")? 
Make sure you call generator with curly brackets {}") end end generatorMT.__tostring=function(tab) @@ -414,9 +392,15 @@ function generatorFunctions:listArgs() end function generatorFunctions:complete(arglist) - self:checkArgs(arglist) + if self:checkArgs(arglist)==false then + print("Generator '"..self.name.."' is missing arguments!") + for k,v in pairs(self.requiredArgs) do + if self.curriedArgs[k]==nil then print("Argument '"..k.."'") end + end + assert(false) + end local mod = self.completeFn(arglist) - J.err( darkroom.isModule(mod), "generator returned something other than a rigel module?" ) + J.err( darkroom.isModule(mod), "generator '"..self.namespace.."."..self.name.."' returned something other than a rigel module?" ) mod.generator = self mod.generatorArgs = arglist return mod @@ -447,7 +431,7 @@ __index=function(tab,key) if key=="systolicModule" then -- build the systolic module as needed - assert( rawget(tab, "makeSystolic")~=nil ) + err( rawget(tab, "makeSystolic")~=nil, "missing makeSystolic() for module '"..tab.name.."'" ) local sm = rawget(tab,"makeSystolic")() assert(S.isModule(sm)) @@ -458,7 +442,7 @@ __index=function(tab,key) --err(A==B,"makeSystolic: side channels doesn't match globals") for k,_ in pairs(tab.globals) do err( sm.sideChannels[k.systolicValue]~=nil, "makeSystolic: systolic module lacks side channel for global "..k.name ) - err( k.systolicValueReady==nil or sm.sideChannels[k.systolicValueReady]~=nil, "makeSystolic: systolic module lacks side channel for global ready "..k.name ) + err( k.systolicValueReady==nil or sm.sideChannels[k.systolicValueReady]~=nil, "makeSystolic: systolic module '"..sm.name.."' lacks side channel for global ready "..k.name ) end for k,_ in pairs(sm.sideChannels) do @@ -680,8 +664,30 @@ function darkroom.newFunction(tab,X) end darkroomIRFunctions = {} +darkroomIRFunctions.isIR=true setmetatable( darkroomIRFunctions,{__index=IR.IRFunctions}) -darkroomIRMT = {__index = darkroomIRFunctions} +darkroomIRMT = 
{}--{__index = darkroomIRFunctions} +function darkroomIRMT.__index(tab,key) + local res = rawget(tab,key) + if res~=nil then return res end + + if darkroomIRFunctions[key]~=nil then + return darkroomIRFunctions[key] + end + + if type(key)=="number" and (darkroom.isHandshakeArray(tab.type) or darkroom.isHandshakeTuple(tab.type) or tab.type:is("HandshakeArrayFramed") ) then + print("TY",tab.type) + --assert(false) + local res = darkroom.selectStream(nil,tab,key) + rawset(tab,key,res) + return res + elseif type(key)=="number" and darkroom.isBasic(tab.type) and (tab.type:isArray() or tab.type:isTuple()) then + local G = require "generators" + local res = G.Index{key}(tab) + rawset(tab,key,res) + return res + end +end darkroomIRMT.__tostring = function(tab) if tab.kind=="apply" then @@ -721,6 +727,12 @@ function darkroom.newIR(tab) assert( type(tab) == "table" ) err( type(tab.name)=="string", "IR node "..tab.kind.." is missing name" ) err( type(tab.loc)=="string", "IR node "..tab.kind.." is missing loc" ) + err( type(tab.inputs)=="table","IR node "..tab.kind..", inputs should be table") + assert( #tab.inputs==J.keycount(tab.inputs)) + for i=1,#tab.inputs do + assert(darkroom.isIR(tab.inputs[i])) + end + IR.new( tab ) local r = setmetatable( tab, darkroomIRMT ) r:typecheck() @@ -976,10 +988,14 @@ function darkroomIRFunctions:typecheck() elseif n.kind=="statements" then n.type = n.inputs[1].type elseif n.kind=="selectStream" then - if darkroom.isHandshakeArray(n.inputs[1].type) then + if darkroom.isHandshakeArray(n.inputs[1].type) or n.inputs[1].type:is("HandshakeArrayFramed") then err( n.i < n.inputs[1].type.params.W, "selectStream index out of bounds") err( n.j==nil or (n.j < n.inputs[1].type.params.H), "selectStream index out of bounds") - n.type = darkroom.Handshake(n.inputs[1].type.params.A) + if n.inputs[1].type:is("HandshakeArrayFramed") then + n.type = types.HandshakeFramed(n.inputs[1].type.params.A,n.inputs[1].type.params.mixed,n.inputs[1].type.params.dims) + 
else + n.type = darkroom.Handshake(n.inputs[1].type.params.A) + end elseif darkroom.isHandshakeTriggerArray(n.inputs[1].type) then err( n.i < n.inputs[1].type.params.W, "selectStream index out of bounds") err( n.j==nil or (n.j < n.inputs[1].type.params.H), "selectStream index out of bounds") @@ -1161,13 +1177,30 @@ function darkroom.applyMethod( name, inst, fnname, input ) return darkroom.newIR( {kind = "applyMethod", name = name, fnname=fnname, loc=getloc(), inst = inst, inputs = {input} } ) end -function darkroom.constant( name, value, ty ) - err( type(name) == "string", "constant name must be string" ) - err( types.isType(ty), "constant type must be rigel type" ) - ty:checkLuaValue(value) +-- can be called either as constant(name,value,ty) or constant(value,ty) +function darkroom.constant( name, value, ty, X ) + err( X==nil, "rigel.constant: too many arguments" ) + + local res = {kind="constant", loc=getloc(), inputs = {}} + if type(name)=="string" then + res.name = name + res.value = value + res.type = ty + else + res.defaultName=true + res.name="const"..darkroom.__unnamedID + darkroom.__unnamedID = darkroom.__unnamedID+1 + res.value = name + res.type = value + err( ty==nil, "rigel.constant: too many arguments" ) + end + + err( types.isType(res.type), "rigel.constant: type must be rigel type" ) + res.type:checkLuaValue(res.value) - return darkroom.newIR( {kind="constant", name=name, loc=getloc(), value=value, type=ty, inputs = {}} ) + return darkroom.newIR( res ) end +darkroom.c = darkroom.constant function darkroom.concat( name, t, X ) local r = {kind="concat", name=name, loc=getloc(), inputs={} } @@ -1201,10 +1234,27 @@ function darkroom.concatArray2d( name, t, W, H, X ) end function darkroom.selectStream( name, input, i, X ) + + local r = {kind="selectStream"} + + if name==nil then + r.defaultName=true + r.name="selectStream"..darkroom.__unnamedID + darkroom.__unnamedID = darkroom.__unnamedID+1 + else + err( type(name)=="string", "first selectStream input 
should be name") + r.name=name + end + err( type(i)=="number", "i must be number") err( darkroom.isIR(input), "input must be IR") err(X==nil,"selectStream: too many arguments") - return darkroom.newIR({kind="selectStream", name=name, i=i, loc=getloc(), inputs={input}}) + + r.i=i + r.loc = getloc() + r.inputs={input} + + return darkroom.newIR(r) end function darkroom.statements( t ) @@ -1223,6 +1273,8 @@ function darkroom.handshakeMode(output) HANDSHAKE_MODE = HANDSHAKE_MODE or darkroom.isHandshakeAny(n.fn.inputType) or darkroom.isHandshakeAny(n.fn.outputType) elseif n.kind=="applyMethod" then HANDSHAKE_MODE = HANDSHAKE_MODE or darkroom.isHandshakeAny(n.inst.fn.inputType) or darkroom.isHandshakeAny(n.inst.fn.outputType) + elseif n.kind=="writeGlobal" then + HANDSHAKE_MODE = HANDSHAKE_MODE or n.global.type:is("Handshake") end end) return HANDSHAKE_MODE @@ -1231,28 +1283,7 @@ end function darkroom.export(t) if t==nil then t=_G end - -- constants - local t_c = function(arg) - J.err( type(arg)=="table", "c: argument should be table" ) - J.err( #arg==2, "c: should have 2 args" ) - - local ty, val - if types.isType(arg[1]) then - ty = arg[1] - val = arg[2] - else - ty = arg[2] - val = arg[1] - end - - local res = darkroom.constant("const"..darkroom.__unnamedID,val,ty) - res.defaultName = true - darkroom.__unnamedID = darkroom.__unnamedID+1 - - return res - end - - rawset(t,"c",t_c) + rawset(t,"c",darkroom.constant) end return darkroom diff --git a/src/common.lua b/src/common.lua index d8a5001..79a5e9c 100644 --- a/src/common.lua +++ b/src/common.lua @@ -691,4 +691,18 @@ function common.flattenTables(arr) return t end +-- flatten an array of arrays +function common.flatten(arr) + common.err( #arr==common.keycount(arr),"flatten: input must be array") + + local t = {} + for _,v in ipairs(arr) do + common.err( #v==common.keycount(v),"flatten: input must be array of arrays") + for _,vv in ipairs(v) do + table.insert(t,vv) + end + end + return t +end + return common diff --git 
a/src/generators.lua b/src/generators.lua index c4b8eda..bb36c71 100644 --- a/src/generators.lua +++ b/src/generators.lua @@ -21,6 +21,21 @@ function(args) end end) +-- a broadcast is basically the same as a upsample, but has no W,H +generators.BroadcastSeq = R.newGenerator("generators","BroadcastSeq",{"type","number","size"},{}, +function(args) + J.err( types.isBasic(args.type),"generators.BroadcastSeq: unsupported type: "..tostring(args.type)) + return RM.upsampleXSeq(args.type,args.number,args.size[1]*args.size[2]) +end) + +-- bits to unsigned +generators.BtoU = R.newGenerator("generators","BtoU",{"type"},{}, +function(args) + J.err( args.type:isBits(), "BtoU input type should be Bits, but is: "..tostring(args.type)) + return C.cast(args.type,types.uint(args.type.precision)) +end) + +generators.Print = R.newGenerator("generators","Print",{"type"},{"string"},function(args) return C.print(args.type,args.string) end) generators.FlattenBits = R.newGenerator("generators","FlattenBits",{"type"},{},function(args) return C.flattenBits(args.type) end) generators.PartitionBits = R.newGenerator("generators","PartitionBits",{"type","number"},{},function(args) return C.partitionBits(args.type,args.number) end) generators.Rshift = R.newGenerator("generators","Rshift",{"type"},{"number"}, function(args) return C.rshift(args.type,args.number) end) @@ -42,10 +57,19 @@ function(args) return RM.triggerCounter(args.number) end) +generators.TriggerBroadcast = R.newGenerator("generators","TriggerBroadcast",{"type","number"},{}, +function(args) + J.err( R.isHandshakeTrigger(args.type), "TriggerBroadcast: input should be HandshakeTrigger" ) + return C.triggerUp(args.number) +end) + generators.FanOut = R.newGenerator("generators","FanOut",{"type","number"},{}, function(args) - J.err( R.isHandshake(args.type),"FanOut: expected handshake input") - return RM.broadcastStream( R.extractData(args.type), args.number ) + J.err( R.isHandshake(args.type) or 
args.type:is("HandshakeFramed"),"FanOut: expected handshake input, but is: "..tostring(args.type)) + local mixed, dims + if args.type:is("HandshakeFramed") then mixed,dims=args.type.params.mixed,args.type.params.dims end + print("FO",mixed,dims) + return RM.broadcastStream( R.extractData(args.type), args.number, args.type:is("HandshakeFramed"), mixed, dims ) end) generators.FIFO = R.newGenerator("generators","FIFO",{"type","number"},{}, @@ -101,23 +125,26 @@ end) generators.HS = R.newGenerator("generators","HS",{"rigelFunction","type"},{}, function(args) + local mod - if R.isGenerator(args.rigelFunction) then + if R.isGenerator(args.rigelFunction) and args.type:is("HandshakeFramed") then + -- fill in args from HSF + --mod = args.rigelFunction{args.type.params.A,types.HSFV(args.type),types.HSFSize(args.type)} + mod = args.rigelFunction{ types.StaticFramed( args.type.params.A, args.type.params.mixed, args.type.params.dims ) } + elseif R.isGenerator(args.rigelFunction) then mod = args.rigelFunction{R.extractData(args.type)} else mod = args.rigelFunction end - - J.err( R.isGenerator(mod)==false, "generators.HS: input rigel function is a generator, not a module (arguments must be missing)" ) + J.err( R.isGenerator(mod)==false, "generators.HS: input rigel function is a generator, not a module (arguments must be missing)" ) J.err( R.isModule(mod), "generators.HS: input Rigel function didn't yield a Rigel module? 
(is "..tostring(mod)..")" ) - + return RS.HS(mod) end) generators.Linebuffer = R.newGenerator("generators","Linebuffer",{"type","size","number","bounds"},{}, function(args) - local itype if args.number==0 then itype = args.type @@ -126,18 +153,43 @@ function(args) itype = args.type:arrayOver() end - local a = C.stencilLinebuffer( itype, args.size[1], args.size[2], math.max(args.number,1), -args.bounds[1], args.bounds[2], -args.bounds[3], args.bounds[4] ) local b = C.unpackStencil( itype, args.bounds[1]+1,args.bounds[3]+1, math.max(args.number,1) ) local mod = C.compose("generators_Linebuffer_"..a.name.."_"..b.name,b,a) - + if args.number==0 then mod = C.linearPipeline({C.arrayop(itype,1,1),mod,C.index(mod.outputType,0)},"generators_Linebuffer_0wrap_"..mod.name) end - + return mod end) +generators.Stencil = R.newGenerator("generators","Linebuffer",{"type","bounds"},{}, +function(args) + + if args.type:is("StaticFramed") then + local pixelType = types.HSFPixelType(args.type) + local size = types.HSFSize(args.type) + local V = types.HSFV(args.type) + print("ST",pixelType,size[1],size[2],V) + + local a = C.stencilLinebuffer( pixelType, size[1], size[2], V, -args.bounds[1], args.bounds[2], -args.bounds[3], args.bounds[4], true ) + local b = C.unpackStencil( pixelType, args.bounds[1]+1,args.bounds[3]+1, V, nil, true, size[1], size[2] ) + local mod = C.compose("generators_Linebuffer_"..a.name.."_"..b.name,b,a) + + if args.number==0 then + mod = C.linearPipeline({C.arrayop(itype,1,1),mod,C.index(mod.outputType,0)},"generators_Linebuffer_0wrap_"..mod.name) + end + + return mod + elseif types.isBasic(args.type) then + -- fully parallel + assert(false) + else + err(false,"generators.Stencil: unsupported input type: "..tostring(args.type)) + end +end) + generators.Pad = R.newGenerator("generators","Pad",{"type","size","number","bounds"},{}, function(args) J.err(args.number>0,"NYI - V<=0") @@ -145,32 +197,111 @@ function(args) return 
RM.padSeq(A,args.size[1],args.size[2],args.number,args.bounds[1],args.bounds[2],args.bounds[3],args.bounds[4],0) end) -generators.Crop = R.newGenerator("generators","Crop",{"type","size","number","bounds"},{}, + +generators.CropSeq = R.newGenerator("generators","Crop",{"type","size","number","bounds"},{}, function(args) J.err(args.number>0,"NYI - V<=0") local A = args.type:arrayOver() return RM.cropSeq(A,args.size[1],args.size[2],args.number,args.bounds[1],args.bounds[2],args.bounds[3],args.bounds[4]) end) +generators.Crop = R.newGenerator("generators","Crop",{"type","bounds"},{}, +function(args) + if args.type:is("StaticFramed") then + local pixelType = types.HSFPixelType(args.type) + local size = types.HSFSize(args.type) + local V = types.HSFV(args.type) + + return RM.cropSeq(pixelType,size[1],size[2],V,args.bounds[1],args.bounds[2],args.bounds[3],args.bounds[4],true) + else + err(false,"generators.Crop: unsupported input type: "..tostring(args.type)) + end +end) + +generators.Downsample = R.newGenerator("generators","Downsample",{"type","size"},{}, +function(args) + if args.type:is("StaticFramed") then + local size = types.HSFSize(args.type) + local V = types.HSFV(args.type) + + return C.downsampleSeq( args.type:FPixelType(), args.type:FW(), args.type:FH(), args.type:FV(), args.size[1], args.size[2], true ) + else + err(false,"generators.Downsample: unsupported input type: "..tostring(args.type)) + end +end) + generators.PosSeq = R.newGenerator("generators","PosSeq",{"size","number"},{"type"}, function(args) - J.err( args.number>0, "NYI - V<=0" ) if args.type~=nil then J.err( args.type==types.null(), "PosSeq: expected null input") end return RM.posSeq(args.size[1],args.size[2],args.number) end) -generators.Map = R.newGenerator("generators","Map",{"type","rigelFunction"},{}, +generators.Pos = R.newGenerator("generators","PosSeq",{"size","number"},{"type"}, function(args) - J.err( args.type:isArray(), "generators.Map: type should be array, but is: 
"..tostring(args.type) ) - local mod = args.rigelFunction - if R.isGenerator(mod) then - mod = mod{args.type:arrayOver()} + if args.type~=nil then + J.err( args.type==types.null(), "PosSeq: expected null input") + end + return RM.posSeq( args.size[1], args.size[2], args.number, nil, true, true ) +end) + +-- size/bool: output size/mixed for framed types (needed if vector width/ SDF rate changes...) +-- think of this like a combined map&flatten +generators.Map = R.newGenerator("generators","Map",{"type","rigelFunction"},{"size","bool"}, +function(args) + if args.type:is("StaticFramed") then + + if args.type.params.mixed and #args.type.params.dims==1 then + local mod = args.rigelFunction + if R.isGenerator(mod) then + mod = mod{args.type.params.A:arrayOver()} + end + J.err( R.isModule(mod), "generators.Map: input didn't yield a rigel module?") + + assert( args.type.params.A.size[2]==1 ) + local res = RM.map( mod, args.type.params.A.size[1], args.type.params.A.size[2] ) + return RM.mapFramed( res, args.type.params.dims[1][1], args.type.params.dims[1][2], true ) + else + local size = types.HSFSize(args.type) + local mod = args.rigelFunction + if R.isGenerator(mod) then + mod = mod{args.type:framedOver()} + end + J.err( R.isModule(mod), "generators.Map: input didn't yield a rigel module?") + + -- fully serial + return RM.mapFramed( mod, size[1], size[2], false ) + end + elseif args.type:is("HandshakeFramed") then + if args.type.params.mixed==false and #args.type.params.dims==1 then + -- this is basically just applying a HSF wrapper + local size = types.HSFSize(args.type) + local mod = args.rigelFunction + if R.isGenerator(mod) then + mod = mod{args.type:framedOver()} + end + J.err( R.isModule(mod), "generators.Map: input didn't yield a rigel module?") + + local ow,oh + if args.size~=nil then ow,oh=args.size[1],args.size[2] end + + -- fully serial + return RM.mapFramed( mod, size[1], size[2], false, ow, oh, args.bool ) + else + J.err(false, "generators.Map: attempted to 
apply to value with unsupported type: "..tostring(args.type)) + end + elseif args.type:isArray() then + local mod = args.rigelFunction + if R.isGenerator(mod) then + mod = mod{args.type:arrayOver()} + end + J.err( R.isModule(mod), "generators.Map: input didn't yield a rigel module?") + + return RM.map( mod, args.type.size[1], args.type.size[2] ) + else + J.err(false, "generators.Map: attempted to apply to value with unsupported type: "..tostring(args.type)) end - J.err( R.isModule(mod), "generators.Map: input didn't yield a rigel module?") - - return RM.map( mod, args.type.size[1], args.type.size[2] ) end) generators.Reduce = R.newGenerator("generators","Reduce",{"type","rigelFunction"},{}, @@ -208,9 +339,22 @@ function(args) return RM.changeRate( args.type:arrayOver(), args.type:arrayLength()[2], args.type:arrayLength()[1], args.type:arrayLength()[1]/args.number ) end) -generators.Fwrite = R.newGenerator("generators","Fwrite",{"type","string","size"},{}, +generators.Deser = R.newGenerator("generators","Deser",{"type","number"},{}, function(args) - return RS.modules.fwriteSeq({type=args.type,filename=args.string}) + J.err( args.type:isArray(), "generators.Deser: type should be array" ) + return RM.changeRate( args.type:arrayOver(), args.type:arrayLength()[2], args.type:arrayLength()[1], args.type:arrayLength()[1]*args.number ) +end) + +generators.Fwrite = R.newGenerator("generators","Fwrite",{"type","string"},{"size"}, +function(args) + --return RS.modules.fwriteSeq({type=args.type,filename=args.string}) + return RM.fwriteSeq(args.string,args.type,args.string..".verilog.raw",true) +end) + +-- assert that the input stream is the same as the file. Early out on errors. 
+generators.Fassert = R.newGenerator("generators","Fassert",{"type","string"},{}, +function(args) + return C.fassert(args.string,args.type) end) generators.WriteBurst = R.newGenerator("generators","WriteBurst",{"type","string","size"},{}, @@ -233,16 +377,39 @@ function(args) return RM.lambda( args.string, input, out ) end) +generators.Generator = generators.Module + generators.AXIReadBurst = R.newGenerator("generators","AXIReadBurst",{"string","size","type","number"},{}, function(args) - return SOC.readBurst( args.string, args.size[1], args.size[2], args.type, args.number ) + return SOC.readBurst( args.string, args.size[1], args.size[2], args.type, args.number, true ) +end) + +generators.AXIRead = R.newGenerator("generators","AXIRead",{"string","type","number"},{}, +function(args) + if args.type:is("Handshake") then + return SOC.read( args.string, args.number, types.uint(8) ) + else + J.err( false, "AXIRead: unsupported input type: "..tostring(args.type)) + end end) -generators.AXIWriteBurst = R.newGenerator("generators","AXIWriteBurst",{"string","size","type","number"},{}, +generators.AXIWriteBurst = R.newGenerator("generators","AXIWriteBurst",{"string","type"},{"size","number"}, function(args) - return SOC.writeBurst( args.string, args.size[1], args.size[2], args.type, args.number ) + if args.type:is("HandshakeFramed") then + return SOC.writeBurst( args.string, args.type:FW(), args.type:FH(), args.type:FPixelType(), args.type:FV(), true ) + else + J.err( false, "AXIWriteBurst: unsupported input type: "..tostring(args.type)) + end end) +generators.WriteGlobal = R.newGenerator("generators","WriteGlobal",{"global"},{}, +function(args) + local WG = generators.Module{"WG",function(i) return R.writeGlobal("go",args.global,i) end} + WG = WG{args.global.type} + assert(R.isModule(WG)) + return WG +end) + function generators.export(t) if t==nil then t=_G end for k,v in pairs(generators) do rawset(t,k,v) end diff --git a/src/ir.lua b/src/ir.lua index 1307ebf..e8372e8 100644 
--- a/src/ir.lua +++ b/src/ir.lua @@ -193,6 +193,9 @@ function IR.new(node) end function IR.isIR(v) + -- hack: if table declares itself to be IR, just trust it + if v.isIR==true then return true end + local mt = getmetatable(v) if type(mt)~="table" then return false end if type(mt.__index)~="table" then return false end diff --git a/src/modules.lua b/src/modules.lua index 7b27638..bb51352 100644 --- a/src/modules.lua +++ b/src/modules.lua @@ -505,10 +505,12 @@ modules.liftDecimate = memoize(function(f, handshakeTrigger, X) if handshakeTrigger==nil then handshakeTrigger=true end local res = {kind="liftDecimate", fn = f} - rigel.expectBasic(f.inputType) + err( types.isBasic(f.inputType) or f.inputType:is("StaticFramed"), "liftDecimate: fn input type should be basic or StaticFramed" ) if f.inputType==types.null() then res.inputType = rigel.VTrigger + elseif f.inputType:is("StaticFramed") then + res.inputType = types.VFramed( f.inputType.params.A, f.inputType.params.mixed, f.inputType.params.dims ) else res.inputType = rigel.V(f.inputType) end @@ -519,6 +521,8 @@ modules.liftDecimate = memoize(function(f, handshakeTrigger, X) res.outputType = rigel.RV(rigel.extractData(f.outputType)) elseif rigel.isVTrigger(f.outputType) then res.outputType = rigel.RVTrigger + elseif f.outputType:is("VFramed") then + res.outputType = types.RVFramed( f.outputType.params.A, f.outputType.params.mixed, f.outputType.params.dims ) else err(false, "liftDecimate: expected V output type, but is "..tostring(f.outputType)) end @@ -539,7 +543,7 @@ modules.liftDecimate = memoize(function(f, handshakeTrigger, X) end return rigel.newFunction(res) - end) +end) -- converts V->RV to RV->RV modules.RPassthrough = memoize(function(f) @@ -705,17 +709,21 @@ end modules.liftHandshake = memoize(function( f, X ) err( X==nil, "liftHandshake: too many arguments" ) local res = {kind="liftHandshake", fn=f} - err( rigel.isV(f.inputType) or rigel.isVTrigger(f.inputType), "liftHandshake: expected V or VTrigger input 
type") - err( rigel.isRV(f.outputType) or rigel.isRVTrigger(f.outputType),"liftHandshake: expected RV or RVTrigger output type") + err( rigel.isV(f.inputType) or rigel.isVTrigger(f.inputType) or f.inputType:is("VFramed"), "liftHandshake: expected V or VTrigger or VFramed input type") + err( rigel.isRV(f.outputType) or rigel.isRVTrigger(f.outputType) or f.outputType:is("RVFramed"),"liftHandshake: expected RV or RVTrigger or RVFramed output type") if rigel.isVTrigger(f.inputType) then res.inputType = rigel.HandshakeTrigger + elseif f.inputType:is("VFramed") then + res.inputType = types.HandshakeFramed( f.inputType.params.A, f.inputType.params.mixed, f.inputType.params.dims ) else res.inputType = rigel.Handshake(rigel.extractData(f.inputType)) end if rigel.isRVTrigger(f.outputType) then res.outputType = rigel.HandshakeTrigger + elseif f.outputType:is("RVFramed") then + res.outputType = types.HandshakeFramed( f.outputType.params.A, f.outputType.params.mixed, f.outputType.params.dims ) else res.outputType = rigel.Handshake(rigel.extractData(f.outputType)) end @@ -871,6 +879,126 @@ modules.map = memoize(function( f, W, H ) return res end) +-- vectorized: if true, treat fn as a vectorized module (ie has some array size V, where V will be part of the {w,h} loop) +-- (this is the same as 'mixed') +-- if false, just treat fn as some opaque data type +-- +-- outputW/outputH/outputMixed: if 'fn' changes SDF Rate or vector size, this allows us to override the outputW/H/Mixed +-- think of this like a combined map+flatten operation (not just a hack..) 
+modules.mapFramed = memoize(function( fn, w, h, vectorized, outputW, outputH, outputMixed, X )
+  err( rigel.isFunction(fn), "mapFramed: first argument to map must be Rigel module, but is "..tostring(fn) )
+  err( type(w)=="number", "mapFramed: width must be number")
+  err( type(h)=="number", "mapFramed: height must be number")
+  err( type(vectorized)=="boolean", "mapFramed: vectorized must be bool")
+  err( outputW==nil or type(outputW)=="number", "mapFramed: outputW must be number or nil" )
+  err( outputH==nil or type(outputH)=="number", "mapFramed: outputH must be number or nil" )
+  err( outputMixed==nil or type(outputMixed)=="boolean", "mapFramed: outputMixed must be boolean or nil" )
+  err( X==nil, "mapFramed: too many arguments" )
+
+  local res = {kind="mapFramed",fn=fn,w=w,h=h,vectorized=vectorized,stateful=fn.stateful,delay=fn.delay}
+  -- the name encodes every parameter, so each distinct configuration gets a distinct module name
+  res.name="MapFramed_"..fn.name.."_W"..tostring(w).."_H"..tostring(h).."_vectorized"..tostring(vectorized).."_outputW"..tostring(outputW).."_outputH"..tostring(outputH).."_outputMixed"..tostring(outputMixed)
+  -- SDF rates are copied from fn; if fn declares none, fall back to a 1/1 rate
+  if fn.sdfInput~=nil then
+    assert(#fn.sdfInput==1)
+    assert(#fn.sdfOutput==1)
+    res.sdfInput={{fn.sdfInput[1][1],fn.sdfInput[1][2]}}
+    res.sdfOutput={{fn.sdfOutput[1][1],fn.sdfOutput[1][2]}}
+  else
+    res.sdfInput={{1,1}}
+    res.sdfOutput={{1,1}}
+  end
+
+  res.globals={}
+  for k,_ in pairs(fn.globals) do res.globals[k]=1 end
+
+  res.globalMetadata = {}
+  for k,v in pairs(fn.globalMetadata) do res.globalMetadata[k]=v end
+
+  res.inputType = fn.inputType:addDim(w,h,vectorized)
+
+  -- tokens: # of actual data items we process.
+  -- make sure this ends up being consistent... (ie possible based on given vector widths/SDF)
+  local inTok = w*h
+  if vectorized and fn.inputType:is("HandshakeFramed") then
+    inTok = inTok/fn.inputType:FV()
+  elseif vectorized then
+    err( types.isBasic(fn.inputType) or fn.inputType:is("Handshake"), "mapFramed unsupported type: "..tostring(fn.inputType) )
+    local BT = rigel.extractData(fn.inputType)
+    inTok = inTok/BT:channels()
+  end
+
+  local outTok
+  if outputW==nil then
+    res.outputType = fn.outputType:addDim(w,h,vectorized)
+    outTok = w*h
+    if vectorized and fn.outputType:is("HandshakeFramed") then
+      outTok = outTok/fn.outputType:FV()
+    elseif vectorized then
+      err( types.isBasic(fn.outputType) or fn.outputType:is("Handshake"), "mapFramed unsupported type: "..tostring(fn.outputType) )
+      local BT = rigel.extractData(fn.outputType)
+      outTok = outTok/BT:channels()
+    end
+  else
+    err( outputH~=nil, "mapFramed: outputH must be given when outputW is given" )
+    res.outputType = fn.outputType:addDim(outputW,outputH,outputMixed)
+    outTok = outputW*outputH
+    if outputMixed then outTok = outTok/fn.outputType:FV() end
+  end
+
+  if fn.inputType:is("Handshake") then
+    -- sanity check: make sure # of tokens we say we're making is consistent with SDF
+    -- scale # of input tokens by the SDF ratio; res.sdf* is used so the 1/1 default applies when fn declares no SDF
+    -- outputSDF/inputSDF = (out[1]/out[2])/(in[1]/in[2]) = (out[1]*in[2])/(out[2]*in[1])
+    local n,d = res.sdfOutput[1][1]*res.sdfInput[1][2], res.sdfOutput[1][2]*res.sdfInput[1][1]
+    local SDFTok = (inTok*n)/d
+
+    err( outTok==SDFTok, "mapFramed: error, number of input and output tokens not equal based on specified params! inTok:"..tostring(inTok).." outTok:"..tostring(outTok).." SDFTok:"..tostring(SDFTok).." inputW:"..tostring(w).." inputH:"..tostring(h).." outputW:"..tostring(outputW).." outputH:"..tostring(outputH))
+  end
+
+  function res.makeSystolic()
+    local sm = Ssugar.moduleConstructor(res.name)
+
+    for k,_ in pairs(fn.globals) do
+      sm:addSideChannel(k.systolicValue)
+      if k.systolicValueReady~=nil then sm:addSideChannel(k.systolicValueReady) end
+    end
+
+    -- 'inner' is the wrapped fn; ready/process/reset below all forward to it
+    local inner = sm:add(fn.systolicModule:instantiate("inner"))
+
+    -- Handshake fns: forward ready(), mark the module onlyWire, and pass no CE to process.
+    -- Any other fn must be basic->basic and keeps the process_CE clock enable.
+    local CE = S.CE("process_CE")
+    local reset_valid = S.parameter("reset",types.bool())
+    if fn.inputType:is("Handshake") and fn.outputType==types.null() then
+      sm:addFunction( S.lambda("ready", S.parameter("ready_downstream",types.null()), inner:ready(), "ready" ) )
+      sm:onlyWire(true)
+      CE=nil
+    elseif fn.inputType:is("Handshake") and fn.outputType:is("Handshake") then
+      local rds = S.parameter("ready_downstream",types.bool())
+      sm:addFunction( S.lambda("ready", rds, inner:ready(rds), "ready" ) )
+      sm:onlyWire(true)
+      CE=nil
+    else
+      err(types.isBasic(fn.inputType) and types.isBasic(fn.outputType), "NYI - "..tostring(fn.inputType)..tostring(fn.outputType) )
+    end
+
+    local I = S.parameter("process_input", rigel.lower(fn.inputType) )
+    sm:addFunction( S.lambda("process",I,inner:process(I),"process_output",nil,nil,CE) )
+
+    if fn.stateful or fn.inputType:is("Handshake") then
+      sm:addFunction( S.lambda("reset", S.parameter("r",types.null()), inner:reset(nil,reset_valid), "ro", nil, reset_valid) )
+    end
+
+    return sm
+  end
+
+  function res.makeTerra() return MT.mapFramed(res,fn,w,h,vectorized) end
+
+  return rigel.newFunction(res)
+end)
+
 -- type {A,bool}->A
 -- rate: {n,d} format frac. If coerce=true, 1/rate must be integer.
-- if rate={1,16}, filterSeq will return W*H/16 pixels @@ -1106,6 +1234,8 @@ modules.downsampleXSeq = memoize(function( A, W, H, T, scale, X ) err( J.isPowerOf2(scale), "NYI - scale must be power of 2") err( X==nil, "downsampleXSeq: too many arguments" ) + err( W%scale==0,"downsampleXSeq: NYI - scale ("..tostring(scale)..") does not divide W ("..tostring(W)..")") + local sbits = math.log(scale)/math.log(2) local outputT @@ -1121,9 +1251,11 @@ modules.downsampleXSeq = memoize(function( A, W, H, T, scale, X ) local tfn, sdfOverride if scale>T then -- A[T] to A[1] - sdfOverride = {{1,scale}} + -- tricky: each token contains multiple pixels, any of of which can be valid + assert(scale%T==0) + sdfOverride = {{1,scale/T}} if terralib~=nil then tfn = MT.downsampleXSeqFn(innerInputType,outputType,scale) end - else + else -- scale <= T... for every token, we output 1 token sdfOverride = {{1,1}} if terralib~=nil then tfn = MT.downsampleXSeqFnShort(innerInputType,outputType,scale,outputT) end end @@ -1137,17 +1269,24 @@ modules.downsampleXSeq = memoize(function( A, W, H, T, scale, X ) if scale>T then -- A[T] to A[1] svalid = S.eq(S.cast(S.bitSlice(sy,0,sbits-1),types.uint(sbits)),S.constant(0,types.uint(sbits))) sdata = S.index(sinp,1) +-- print("SDATA",sdata.type) + sdata = S.index(sdata,0) +-- print("SDATA",sdata.type) + sdata = S.cast(S.tuple{sdata},types.array2d(sdata.type,1)) else svalid = S.constant(true,types.bool()) local datavar = S.index(sinp,1) sdata = J.map(J.range(0,outputT-1), function(i) return S.index(datavar, i*scale) end) sdata = S.cast(S.tuple(sdata), types.array2d(A,outputT)) end - return S.tuple{sdata,svalid} + local res = S.tuple{sdata,svalid} +-- print("DXSRES",res.type,scale,T,sinp.type,sdata.type) + return res end, function() return tfn end, nil, sdfOverride) +-- print(f) return modules.liftXYSeq( modname, "rigel.downsampleXSeq", f, W, H, T ) end) @@ -1249,25 +1388,31 @@ modules.upsampleXSeq = memoize(function( A, T, scale, X ) err( scale<65536, 
"upsampleXSeq: NYI - scale>=65536") err(X==nil, "upsampleXSeq: too many arguments") - if T==1 then + if T==1 or T==0 then -- special case the EZ case of taking one value and writing it out N times local res = {kind="upsampleXSeq",sdfInput={{1,scale}}, sdfOutput={{1,1}}, stateful=true, A=A, T=T, scale=scale} - local ITYPE = types.array2d(A,T) - res.inputType = ITYPE - res.outputType = rigel.RV(types.array2d(A,T)) + if T==0 then + res.inputType = A + res.outputType = rigel.RV(A) + else + local ITYPE = types.array2d(A,T) + res.inputType = ITYPE + res.outputType = rigel.RV(types.array2d(A,T)) + end + res.delay=0 res.name = verilogSanitize("UpsampleXSeq_"..tostring(A).."_T_"..tostring(T).."_scale_"..tostring(scale)) - if terralib~=nil then res.terraModule = MT.upsampleXSeq(res,A, T, scale, ITYPE ) end + if terralib~=nil then res.terraModule = MT.upsampleXSeq(res,A, T, scale, res.inputType ) end ----------------- function res.makeSystolic() local systolicModule = Ssugar.moduleConstructor(res.name) - local sinp = S.parameter( "inp", ITYPE ) + local sinp = S.parameter( "inp", res.inputType ) local sPhase = systolicModule:add( Ssugar.regByConstructor( types.uint(16), fpgamodules.sumwrap(types.uint(16),scale-1) ):CE(true):setReset(0):instantiate("phase") ) - local reg = systolicModule:add( S.module.reg( ITYPE,true ):instantiate("buffer") ) + local reg = systolicModule:add( S.module.reg( res.inputType,true ):instantiate("buffer") ) local reading = S.eq(sPhase:get(),S.constant(0,types.uint(16))):disablePipelining() local out = S.select( reading, sinp, reg:get() ) @@ -1710,10 +1855,13 @@ end) -- We do that here by modifying the valid bit combinationally!! This could potentially -- cause a combinationaly loop (validOut depends on readyDownstream) if another later unit does the opposite -- (readyUpstream depends on validIn). But I don't think we will have any units that do that?? 
-modules.broadcastStream = memoize(function(A,N,X) +modules.broadcastStream = memoize(function(A,N,framed,framedMixed,framedDims,X) err( types.isType(A), "broadcastStream: A must be type") rigel.expectBasic(A) err( type(N)=="number", "broadcastStream: N must be number") + if framed==nil then framed=false end + err( type(framed)=="boolean","broadcastStream: framed must be boolean") + err( framed==false or type(framedMixed)=="boolean", "broadcastStream: frameMixed should be boolean") assert(X==nil) local res = {kind="broadcastStream", A=A, N=N, stateful=false} @@ -1721,6 +1869,9 @@ modules.broadcastStream = memoize(function(A,N,X) if A==types.null() then res.inputType = rigel.HandshakeTrigger res.outputType = rigel.HandshakeTriggerArray(N) + elseif framed then + res.inputType = types.HandshakeFramed(A,framedMixed,framedDims) + res.outputType = types.HandshakeArrayFramed( A, framedMixed, framedDims, N ) else res.inputType = rigel.Handshake(A) res.outputType = rigel.HandshakeArray(A, N) @@ -1786,30 +1937,44 @@ modules.broadcastStream = memoize(function(A,N,X) end) -- output type: {uint16,uint16}[T] -modules.posSeq = memoize(function( W, H, T, bits, X ) +-- asArray: return output as u16[2][T] instead +modules.posSeq = memoize(function( W, H, T, bits, framed, asArray, X ) err(type(W)=="number","posSeq: W must be number") err(type(H)=="number","posSeq: H must be number") err(type(T)=="number","posSeq: T must be number") err(W>0, "posSeq: W must be >0"); err(H>0, "posSeq: H must be >0"); - err(T>=1, "posSeq: T must be >=1"); + err(T>=0, "posSeq: T must be >=0"); if bits==nil then bits=16 end - err( type(bits)=="number", "posSeq: bits should be number") + err( type(bits)=="number", "posSeq: bits should be number, but is: "..tostring(bits)) + if framed==nil then framed=false end + err( type(framed)=="boolean", "posSeq: framed should be boolean") + if asArray==nil then asArray=false end + err( type(asArray)=="boolean", "posSeq: asArray should be boolean") err(X==nil, "posSeq: 
too many arguments") local res = {kind="posSeq", T=T, W=W, H=H } res.inputType = types.null() - res.outputType = types.array2d(types.tuple({types.uint(bits),types.uint(bits)}),T) + + local sizeType = types.tuple({types.uint(bits),types.uint(bits)}) + if asArray then sizeType = types.array2d(types.uint(bits),2) end + if T==0 then + res.outputType = sizeType + if framed then res.outputType = types.StaticFramed(res.outputType,false,{{W,H}}) end + else + res.outputType = types.array2d(sizeType,T) + if framed then res.outputType = types.StaticFramed(res.outputType,true,{{W,H}}) end + end res.stateful = true res.sdfInput, res.sdfOutput = {{1,1}},{{1,1}} res.delay = 0 res.name = sanitize("PosSeq_W"..W.."_H"..H.."_T"..T.."_bits"..tostring(bits)) - if terralib~=nil then res.terraModule = MT.posSeq(res,W,H,T) end + if terralib~=nil then res.terraModule = MT.posSeq(res,W,H,T,asArray) end function res.makeSystolic() local systolicModule = Ssugar.moduleConstructor(res.name) - local posX = systolicModule:add( Ssugar.regByConstructor( types.uint(bits), fpgamodules.incIfWrap( types.uint(bits), W-T, T ) ):setInit(0):setReset(0):CE(true):instantiate("posX_posSeq") ) + local posX = systolicModule:add( Ssugar.regByConstructor( types.uint(bits), fpgamodules.incIfWrap( types.uint(bits), W-math.max(T,1), math.max(T,1) ) ):setInit(0):setReset(0):CE(true):instantiate("posX_posSeq") ) local posY = systolicModule:add( Ssugar.regByConstructor( types.uint(bits), fpgamodules.incIfWrap( types.uint(bits), H-1 ) ):setInit(0):setReset(0):CE(true):instantiate("posY_posSeq") ) local printInst @@ -1818,7 +1983,7 @@ modules.posSeq = memoize(function( W, H, T, bits, X ) printInst = systolicModule:add( S.module.print( types.tuple{types.uint(bits),types.uint(bits)}, "x %d y %d", true):instantiate("printInst") ) end - local incY = S.eq( posX:get(), S.constant(W-T,types.uint(bits)) ):disablePipelining() + local incY = S.eq( posX:get(), S.constant(W-math.max(T,1),types.uint(bits)) ):disablePipelining() local 
out = {S.tuple{posX:get(),posY:get()}} for i=1,T-1 do @@ -1831,7 +1996,18 @@ modules.posSeq = memoize(function( W, H, T, bits, X ) table.insert( pipelines, posY:setBy( incY ) ) if DARKROOM_VERBOSE then table.insert( pipelines, printInst:process( S.tuple{posX:get(),posY:get()}) ) end - systolicModule:addFunction( S.lambda("process", S.parameter("pinp",types.null()), S.cast(S.tuple(out),types.array2d(types.tuple{types.uint(bits),types.uint(bits)},T)), "process_output", pipelines, nil, CE ) ) + local finout + if T==0 then + finout = out[1] + if asArray then + finout = S.cast(finout,types.array2d(types.uint(bits),2)) + end + else + finout = S.cast(S.tuple(out),types.array2d(types.tuple{types.uint(bits),types.uint(bits)},T)) + assert(asArray==false) -- NYI + end + + systolicModule:addFunction( S.lambda("process", S.parameter("pinp",types.null()), finout, "process_output", pipelines, nil, CE ) ) systolicModule:addFunction( S.lambda("reset", S.parameter("r",types.null()), nil, "ro", {posX:reset(), posY:reset()}, S.parameter("reset",types.bool())) ) @@ -1891,7 +2067,7 @@ function modules.liftXYSeqPointwise( name, generatorStr, f, W, H, T, X ) end -- takes an image of size A[W,H] to size A[W-L-R,H-B-Top] -modules.cropSeq = memoize(function( A, W, H, T, L, R, B, Top, X ) +modules.cropSeq = memoize(function( A, W, H, T, L, R, B, Top, framed, X ) err( types.isType(A), "cropSeq: type must be rigel type ") err( rigel.isBasic(A),"cropSeq: expects basic type") err( type(W)=="number", "cropSeq: W must be number"); err(W>=0, "cropSeq: W must be >=0") @@ -1901,13 +2077,15 @@ modules.cropSeq = memoize(function( A, W, H, T, L, R, B, Top, X ) err( type(R)=="number", "cropSeq: R must be number"); err(R>=0, "cropSeq: R must be >=0") err( type(B)=="number", "cropSeq: B must be number"); err(B>=0, "cropSeq: B must be >=0") err( type(Top)=="number", "cropSeq: Top must be number"); err(Top>=0, "cropSeq: Top must be >=0") - + if framed==nil then framed=false end + err( type(framed)=="boolean", 
"cropSeq: framed must be boolean") + err(T>=1,"cropSeq T must be <1") err( L>=0, "cropSeq, L must be <0") err( R>=0, "cropSeq, R must be <0") err( W%T==0, "cropSeq, W%T must be 0") - err( L%T==0, "cropSeq, L%T must be 0") + err( L%T==0, "cropSeq, L%T must be 0. L="..tostring(L).." T="..tostring(T)) err( R%T==0, "cropSeq, R%T must be 0, R="..tostring(R)..", T="..tostring(T)) err( (W-L-R)%T==0, "cropSeq, (W-L-R)%T must be 0") err( X==nil, "cropSeq: too many arguments" ) @@ -1954,7 +2132,21 @@ modules.cropSeq = memoize(function( A, W, H, T, L, R, B, Top, X ) end, nil, {{((W-L-R)*(H-B-Top))/T,(W*H)/T}}) - return modules.liftXYSeq( modname, "rigel.cropSeq", f, W, H, T, BITS ) + local res = modules.liftXYSeq( modname, "rigel.cropSeq", f, W, H, T, BITS ) + + -- HACK + if framed then + print("CROPHACK", res.inputType, res.outputType,W,H) + --local idim = res.inputType:dims() + --idim[#idim]={W,H} + res.inputType = res.inputType:addDim(W,H,true) --types.StaticFramed(res.inputType,idim) + --local odim = rigel.extractData(res.outputType):dims() + --odim[#odim]={W-L-R,H-B-Top} + res.outputType = types.VFramed(res.outputType.params.A,true,{{W-L-R,H-B-Top}}) + print("CROPHACK", res.inputType, res.outputType) + end + + return res end) -- takes an image of size A[W,H] to size A[W-L-R,H-B-Top]. 
@@ -2097,6 +2289,7 @@ modules.changeRate = memoize(function(A, H, inputRate, outputRate, X) err( types.isType(A), "A should be a type") err( type(H)=="number", "H should be number") err( type(inputRate)=="number", "inputRate should be number") + err( inputRate>0, "changeRate: inputRate must be >0") err( inputRate==math.floor(inputRate), "inputRate should be integer") err( type(outputRate)=="number", "outputRate should be number, but is: "..tostring(outputRate)) err( outputRate==math.floor(outputRate), "outputRate should be integer") @@ -2104,7 +2297,7 @@ modules.changeRate = memoize(function(A, H, inputRate, outputRate, X) local maxRate = math.max(inputRate,outputRate) - err( maxRate % inputRate == 0, "maxRate % inputRate ~= 0") + err( maxRate % inputRate == 0, "maxRate ("..tostring(maxRate)..") % inputRate ("..tostring(inputRate)..") ~= 0") err( maxRate % outputRate == 0, "maxRate % outputRate ~=0") rigel.expectBasic(A) @@ -2177,10 +2370,12 @@ modules.changeRate = memoize(function(A, H, inputRate, outputRate, X) return modules.waitOnInput(rigel.newFunction(res)) end) -modules.linebuffer = memoize(function( A, w, h, T, ymin, X ) +modules.linebuffer = memoize(function( A, w, h, T, ymin, framed, X ) assert(w>0); assert(h>0); assert(ymin<=0) err(X==nil,"linebuffer: too many args!") + err( framed==nil or type(framed)=="boolean", "modules.linebuffer: framed must be bool or nil") + if framed==nil then framed=false end -- if W%T~=0, then we would potentially have to do two reads on wraparound. So don't allow this case. err( w%T==0, "Linebuffer error, W%T~=0 , W="..tostring(w).." 
T="..tostring(T)) @@ -2189,10 +2384,17 @@ modules.linebuffer = memoize(function( A, w, h, T, ymin, X ) rigel.expectBasic(A) res.inputType = types.array2d(A,T) res.outputType = types.array2d(A,T,-ymin+1) + + if framed then + res.inputType = types.StaticFramed(res.inputType,true,{{w,h}}) + -- this is strange for a reason: inner loop is no longer a serialized flat array, so size must change + res.outputType = types.StaticFramed(res.outputType,false,{{w/T,h}}) + end + res.stateful = true res.sdfInput, res.sdfOutput = {{1,1}},{{1,1}} res.delay = 0 - res.name = sanitize("linebuffer_w"..w.."_h"..h.."_T"..T.."_ymin"..ymin.."_A"..tostring(A)) + res.name = sanitize("linebuffer_w"..w.."_h"..h.."_T"..T.."_ymin"..ymin.."_A"..tostring(A).."_framed"..tostring(framed)) if terralib~=nil then res.terraModule = MT.linebuffer(res, A, w, h, T, ymin) end @@ -2488,11 +2690,15 @@ modules.makeHandshake = memoize(function( f, tmuxRates, nilhandshake ) end else - rigel.expectBasic(f.inputType) - rigel.expectBasic(f.outputType) + --rigel.expectBasic(f.inputType) + --rigel.expectBasic(f.outputType) + err( types.isBasic(f.inputType) or f.inputType:is("StaticFramed"),"makeHandshake: fn input type should be basic or StaticFramed") + err( types.isBasic(f.outputType) or f.outputType:is("StaticFramed"),"makeHandshake: fn output type should be basic or StaticFramed") if f.inputType==types.null() and nilhandshake==true then res.inputType = rigel.HandshakeTrigger + elseif f.inputType~=types.null() and f.inputType:is("StaticFramed") then + res.inputType = types.HandshakeFramed( f.inputType.params.A, f.inputType.params.mixed, f.inputType.params.dims ) elseif f.inputType~=types.null() then res.inputType = rigel.Handshake(f.inputType) else @@ -2501,6 +2707,8 @@ modules.makeHandshake = memoize(function( f, tmuxRates, nilhandshake ) if f.outputType==types.null() and nilhandshake==true then res.outputType = rigel.HandshakeTrigger + elseif f.outputType~=types.null() and f.outputType:is("StaticFramed") then + 
res.outputType = types.HandshakeFramed( f.outputType.params.A, f.outputType.params.mixed, f.outputType.params.dims ) elseif f.outputType~=types.null() then res.outputType = rigel.Handshake(f.outputType) else @@ -2518,7 +2726,7 @@ modules.makeHandshake = memoize(function( f, tmuxRates, nilhandshake ) res.sdfInput, res.sdfOutput = {{1,1}},{{1,1}} end - res.stateful = f.stateful + res.stateful = true -- for the shift register of valid bits res.name = "MakeHandshake_HST_"..tostring(nilhandshake).."_"..f.name res.globals={} @@ -2693,7 +2901,6 @@ modules.fifo = memoize(function( A, size, nostall, W, H, T, csimOnly, X ) load:setOutput( fifo:hasData(), "load_output" ) else local res = S.tuple{fifo:popFront( nil, fifo:hasData() ), fifo:hasData() } - print("LOAD_OUTPUT",res.type) load:setOutput( res, "load_output" ) end @@ -3387,10 +3594,10 @@ function modules.lambda( name, input, output, instances, generatorStr, generator -- for the non-handshake (purely systolic) modules, the ready bit doesn't flow from outputs to inputs, -- it flows from inputs to outputs. The reason is that upstream things can't stall downstream things anyway, so there's really no point of doing it the 'right' way. -- this is kind of messed up! 
- if rigel.isRV( fn.inputType ) then + if rigel.isRV( fn.inputType ) or fn.inputType:is("RVFramed") then assert( S.isAST(out[2]) ) local readyfn = module:addFunction( S.lambda("ready", readyInput, out[2], "ready", {} ) ) - elseif rigel.isRV( fn.outputType ) then + elseif rigel.isRV( fn.outputType ) or fn.outputType:is("RVFramed") then local readyfn = module:addFunction( S.lambda("ready", S.parameter("RINIL",types.null()), out[2], "ready", {} ) ) elseif HANDSHAKE_MODE then @@ -3425,7 +3632,7 @@ function modules.lambda( name, input, output, instances, generatorStr, generator local value = i[1] local thisi = value[parentKey] - if rigel.isHandshake(n.type) or rigel.isHandshakeTrigger(n.type) then + if rigel.isHandshake(n.type) or rigel.isHandshakeTrigger(n.type) or n.type:is("HandshakeFramed") then assert(systolicAST.isSystolicAST(thisi)) assert(thisi.type:isBool()) @@ -3569,7 +3776,7 @@ function modules.lambda( name, input, output, instances, generatorStr, generator err( #n.inputs==0 or type(res)=="table","res should be table "..n.kind.." inputs "..tostring(#n.inputs)) for k,i in ipairs(n.inputs) do - if rigel.isHandshake(i.type) then + if rigel.isHandshake(i.type) or i.type:is("HandshakeFramed") then err(systolicAST.isSystolicAST(res[k]), "incorrect output format "..n.kind.." input "..tostring(k)..", not systolic value" ) err(systolicAST.isSystolicAST(res[k]) and res[k].type:isBool(), "incorrect output format "..n.kind.." input "..tostring(k).." (type "..tostring(i.type)..", name "..i.name..") is "..tostring(res[k].type).." 
but expected bool, "..n.loc ) elseif rigel.isHandshakeTrigger(i.type) then @@ -3661,7 +3868,7 @@ function modules.lift( name, inputType, outputType, delay, makeSystolic, makeTer err( (outputType==types.null() and systolicOutput==nil) or systolicAST.isSystolicAST(systolicOutput), "modules.lift: makeSystolic returned something other than a systolic value (module "..name..")" ) if outputType~=nil and systolicOutput~=nil then -- user may not have passed us a type, and is instead using the systolic system to calculate it - err( systolicOutput.type==rigel.lower(outputType), "lifted systolic output type does not match. Is "..tostring(systolicOutput.type).." but should be "..tostring(outputType).." (module "..name..")" ) + err( systolicOutput.type==rigel.lower(outputType), "lifted systolic output type does not match. Is "..tostring(systolicOutput.type).." but should be "..tostring(outputType)..", which lowers to "..tostring(rigel.lower(outputType)).." (module "..name..")" ) end if systolicInstances~=nil then @@ -4178,5 +4385,4 @@ modules.triggerCounter = memoize(function(N) end) - return modules diff --git a/src/modulesTerra.t b/src/modulesTerra.t index e3bd793..4a98fe0 100644 --- a/src/modulesTerra.t +++ b/src/modulesTerra.t @@ -242,7 +242,7 @@ function MT.RPassthrough(res,f) end function MT.liftHandshake(res,f,delay) - local struct LiftHandshake{ delaysr: simmodules.fifo( rigel.lower(f.outputType):toTerraType(), delay, "liftHandshake"), + local struct LiftHandshake{ delaysr: simmodules.fifo( rigel.lower(f.outputType):toTerraType(), delay, "liftHandshake("..f.name..")"), inner: f.terraModule, ready:bool, readyDownstream:bool} terra LiftHandshake:reset() self.delaysr:reset(); self.inner:reset() end terra LiftHandshake:init() self.inner:init() end @@ -312,6 +312,44 @@ function MT.map(res,f,W,H) return MT.new(MapModule) end +function MT.mapFramed(res,f,W,H,vectorized) + local struct MapFramed {fn:f.terraModule, ready:bool, readyDownstream:bool} + + terra MapFramed:init() 
self.fn:init() end + + if f.stateful then + terra MapFramed:reset() self.fn:reset() end + end + + if f.inputType:is("Handshake") and f.outputType==types.null() then +-- assert(false) + terra MapFramed:process( inp : &res.inputType:toTerraType(), out : &res.outputType:toTerraType() ) + self.fn:process(inp,out) + end + + terra MapFramed:calculateReady() + self.fn:calculateReady() + self.ready = self.fn.ready + end + elseif f.inputType:is("Handshake") and f.outputType:is("Handshake") then + terra MapFramed:process( inp : &res.inputType:toTerraType(), out : &res.outputType:toTerraType() ) + self.fn:process(inp,out) + end + + terra MapFramed:calculateReady( readyDS:bool) + self.readyDownstream = readyDS + self.fn:calculateReady(readyDS) + self.ready = self.fn.ready + end + else + terra MapFramed:process( inp : &res.inputType:toTerraType(), out : &res.outputType:toTerraType() ) + self.fn:process(inp,out) + end + end + + return MT.new(MapFramed) +end + function MT.filterSeq( res, A, W,H, rate, fifoSize, coerce ) assert(type(coerce)=="boolean") @@ -714,15 +752,33 @@ function MT.broadcastStream(res,A,N) return BroadcastStream end -function MT.posSeq(res,W,H,T) +function MT.posSeq(res,W,H,T,asArray) local struct PosSeq { x:uint16, y:uint16 } terra PosSeq:reset() self.x=0; self.y=0 end - terra PosSeq:process( out : &rigel.lower(res.outputType):toTerraType() ) - for i=0,T do - (@out)[i] = {self.x,self.y} - self.x = self.x + 1 - if self.x==W then self.x=0; self.y=self.y+1 end - if self.y==H then self.y=0 end + + if T==0 then + local slf = symbol(&PosSeq) + local out = symbol(&rigel.lower(res.outputType):toTerraType()) + local assign + if asArray then + assign = quote (@out)[0] = slf.x; (@out)[1] = slf.y; end + else + assign = quote (@out) = {slf.x,slf.y} end + end + terra PosSeq.methods.process( [slf], [out] ) + [assign] + slf.x = slf.x + 1 + if slf.x==W then slf.x=0; slf.y=slf.y+1 end + if slf.y==H then slf.y=0 end + end + else + terra PosSeq:process( out : 
&rigel.lower(res.outputType):toTerraType() ) + for i=0,T do + (@out)[i] = {self.x,self.y} + self.x = self.x + 1 + if self.x==W then self.x=0; self.y=self.y+1 end + if self.y==H then self.y=0 end + end end end @@ -1026,7 +1082,8 @@ function MT.makeHandshake(res, f, tmuxRates, nilhandshake ) local delay = math.max( 1, f.delay ) --assert(delay>0) -- we don't need an input fifo here b/c ready is always true - local struct MakeHandshake{ delaysr: simmodules.fifo( rigel.lower(res.outputType):toTerraType(), delay, "makeHandshake"), + + local struct MakeHandshake{ delaysr: simmodules.fifo( rigel.lower(res.outputType):toTerraType(), delay, "makeHandshake("..f.name..")"), inner: f.terraModule, ready:bool, readyDownstream:bool} terra MakeHandshake:reset() self.delaysr:reset(); self.inner:reset() end @@ -1054,6 +1111,7 @@ function MT.makeHandshake(res, f, tmuxRates, nilhandshake ) var tout : rigel.lower(res.outputType):toTerraType() valid(tout) = true self.inner:process(&data(tout)) + self.delaysr:pushBack(&tout) end end @@ -1128,7 +1186,6 @@ function MT.makeHandshake(res, f, tmuxRates, nilhandshake ) valid(tout) = valid(inp); if (valid(inp)~=validFalse) or innerconst then self.inner:process(&data(inp),&data(tout)) end -- don't bother if invalid - self.delaysr:pushBack(&tout) end end @@ -1462,13 +1519,15 @@ end function MT.fwriteSeq(filename,ty,passthrough) local struct FwriteSeq { file : &cstdio.FILE } terra FwriteSeq:init() - self.file = cstdio.fopen(filename, "wb") + self.file = cstdio.fopen(filename, "wb") if self.file==nil then cstdio.perror(["Error opening "..filename.." for writing"]) end [J.darkroomAssert]( self.file~=nil, ["Error opening "..filename.." 
for writing"] ) end + terra FwriteSeq:free() cstdio.fclose(self.file) end + terra FwriteSeq:reset() end @@ -1599,7 +1658,7 @@ function MT.lambdaCompile(fn) -- readyInput = symbol(bool, "readyinput") end - if rigel.hasReady(fn.inputType) or rigel.isRV(fn.output.type) then + if rigel.hasReady(fn.inputType) or rigel.isRV(fn.output.type) or fn.output.type:is("RVFramed") then table.insert( Module.entries, {field="ready", type=rigel.extractReady(fn.input.type):toTerraType()} ) end @@ -1626,7 +1685,7 @@ function MT.lambdaCompile(fn) local readyOutput -- build ready calculation - if rigel.isRV(fn.output.type) or fn.output:outputStreams()>0 or rigel.streamCount(fn.inputType)>0 or rigel.handshakeMode(fn.output) then + if rigel.isRV(fn.output.type) or fn.output.type:is("RVFramed") or fn.output:outputStreams()>0 or rigel.streamCount(fn.inputType)>0 or rigel.handshakeMode(fn.output) then fn.output:visitEachReverseBreakout( function(n, args) @@ -1734,10 +1793,10 @@ return {`mself.[n.name].ready} end) end - if rigel.isRV(fn.inputType) then + if rigel.isRV(fn.inputType) or fn.inputType:is("RVFramed") then assert(readyOutput~=nil) terra Module.methods.calculateReady( [mself], [readyInput] ) mself.readyDownstream = readyInput; [readyStats]; mself.ready = readyOutput end - elseif rigel.isRV(fn.outputType) then + elseif rigel.isRV(fn.outputType) or fn.outputType:is("RVFramed") then assert(readyOutput~=nil) -- notice that we set readyInput to true here. This is kind of a hack to make the code above simpler. This should never actually be read from. 
terra Module.methods.calculateReady( [mself] ) var [readyInput] = true; [readyStats]; mself.ready = readyOutput end diff --git a/src/systolic.lua b/src/systolic.lua index 9d31b5a..d22ba83 100644 --- a/src/systolic.lua +++ b/src/systolic.lua @@ -730,7 +730,7 @@ function systolicASTFunctions:internalDelay() else return 0,0 -- if pipelining is disabled on an op end - elseif self.kind=="tuple" or self.kind=="fndefn" or self.kind=="parameter" or self.kind=="slice" or self.kind=="cast" or self.kind=="module" or self.kind=="constant" or self.kind=="null" or self.kind=="bitSlice" or self.kind=="readSideChannel" then + elseif self.kind=="tuple" or self.kind=="fndefn" or self.kind=="parameter" or self.kind=="slice" or self.kind=="cast" or self.kind=="module" or self.kind=="constant" or self.kind=="null" or self.kind=="bitSlice" or self.kind=="readSideChannel" or self.kind=="writeSideChannel" then return 0,0 -- purely wiring, or inputs elseif self.kind=="delay" then return 0,0 @@ -1440,7 +1440,7 @@ function userModuleFunctions:instanceToVerilog( instance, module, fnname, datava end - if fn.CE==nil and cevar~=nil then err(false, "module was given a CE, but does not expect a CE. Function '"..fnname.."' on instance '"..instance.name.."' in module '"..module.name.."' "..instance.loc) end + if fn.CE==nil and cevar~=nil then err(false, "module was given a CE, but does not expect a CE. Function '"..fnname.."' on instance '"..instance.name.."' of module '"..instance.module.name.."' inside module '"..module.name.."' "..instance.loc) end if fn.CE~=nil then err(type(cevar)=="string", "Module expected a CE, but was not given one. 
Function '"..fnname.."' on instance '"..instance.name.."' (of module "..instance.module.name..") inside module '"..module.name.."' "..instance.loc) @@ -1741,7 +1741,7 @@ function systolic.module.new( name, fns, instances, onlyWire, parameters, verilo for _,inst in pairs(instances) do if inst.module.sideChannels~=nil then for sc,_ in pairs(inst.module.sideChannels) do - err(SC[sc]~=nil,"systolic.module.new: Instance '"..inst.name.."' has dangling side channel '"..sc.name.."'") + err(SC[sc]~=nil,"systolic.module.new: Instance '"..inst.name.."' has dangling side channel '"..sc.name.."', when creating new module '"..name.."'") end end end diff --git a/src/types.lua b/src/types.lua index 612bf46..65429e6 100644 --- a/src/types.lua +++ b/src/types.lua @@ -100,7 +100,8 @@ function types.array2d( _type, w, h ) err(w==math.floor(w), "non integer array width "..tostring(w)) assert(h==math.floor(h)) err( _type:verilogBits()>0, "types.array2d: array type must have >0 bits" ) - + err( w*h>0,"types.array2d: w*h must be >0" ) + -- dedup the arrays local ty = setmetatable( {kind="array", over=_type, size={w,h}}, TypeMT ) return J.deepsetweak(types._array, {_type,w,h}, ty) @@ -501,7 +502,7 @@ end function TypeFunctions:isArray() return self.kind=="array" end function TypeFunctions:arrayOver() - err(self.kind=="array","arrayOver type was not an array") + err(self.kind=="array","arrayOver type was not an array, but is: "..tostring(self)) return self.over end @@ -564,6 +565,10 @@ function TypeFunctions:isUint() return self.kind=="uint" end function TypeFunctions:isBits() return self.kind=="bits" end function TypeFunctions:isNull() return self.kind=="null" end function TypeFunctions:isNamed() return self.kind=="named" end +function TypeFunctions:is(str) + if self.kind=="named" then return self.generator==str + else return self.kind==str end +end function TypeFunctions:isNumber() return self.kind=="float" or self.kind=="uint" or self.kind=="int" @@ -708,17 +713,256 @@ function 
TypeFunctions:toCPUType() end end +function types.isBasic(A) + assert(types.isType(A)) + if A:isArray() then + return types.isBasic(A:arrayOver()) + elseif A:isTuple() then + for _,v in ipairs(A.list) do + if types.isBasic(v)==false then + return false + end + end + return true + elseif A:isNamed() and A.generator=="fixed" then + return true -- COMPLETE HACK, REMOVE + elseif A:isNamed() then + return false + end + + return true +end + +function types.Handshake(A) + err(types.isType(A),"Handshake: argument should be type") + err(types.isBasic(A),"Handshake: argument should be basic type, but is: "..tostring(A)) + return types.named("Handshake("..tostring(A)..")", types.tuple{A,types.bool()}, "Handshake", {A=A} ) +end + +--[=[ +local function framedName(dims,Adims) + local str = "" + err( #dims>=#Adims, "Framed: number of total dims ("..#dims..") must be >= number of parallel dims ("..#Adims..")") + for i=1,#Adims do + if Adims[i][1]==dims[i][1] and Adims[i][2]==dims[i][2] then + -- parallel and serial dims match + str = str.."["..dims[i][1]..","..dims[i][2].."]" + elseif dims[i][1]>=Adims[i][1] and Adims[i][2]==1 then + print("DIM",Adims[i][1],Adims[i][2],dims[i][1],dims[i][2]) + str = str.."["..Adims[i][1]..";"..dims[i][1]..","..dims[i][2].."}" + else + err(false,"NYI - HandshakeFramed, parallel and serial dim doesn't match?") + end + end + + for i=#Adims+1,#dims do + str = str.."{"..dims[i][1]..","..dims[i][2].."}" + end + return str +end +]=] + +-- dims goes from innermost (idx 1) to outermost (idx n) +local function makeFramedType(kind,A,mixed,dims,extra0,extra1,X) + err(types.isType(A),kind.."Framed: argument should be type") + err(types.isBasic(A),kind.."Framed: argument should be basic type, but is: "..tostring(A)) + err( type(mixed)=="boolean", kind.."Framed: mixed should be boolean, but is: "..tostring(mixed)) + err( type(dims)=="table", kind.."Framed: dims should be table") + err( X==nil, kind.."Framed: too many arguments") + + -- make a deep copy, just in 
case + local ldims = {} + for i=1,#dims do + err( type(dims[i])=="table", kind.."Framed: each entry of dims should be a table of size 2") + err( #dims[i]==2, kind.."Framed: each entry of dims should be a table of size 2") + + err(type(dims[i][1])=="number", kind.."Framed: dim must be number") + err(math.floor(dims[i][1])==dims[i][1], kind.."Framed: dim must be integer, but is: "..tostring(dims[i][1])) + err(type(dims[i][2])=="number", kind.."Framed: dim must be number") + err(math.floor(dims[i][2])==dims[i][2], kind.."Framed: dim must be integer, but is: "..tostring(dims[i][2])) + table.insert(ldims,{dims[i][1],dims[i][2]}) + end + + err( mixed==false or A:isArray(),kind.."Framed: if mixed, input type must be an array") + + if mixed then + err(A.size[1]outermost +function types.HandshakeFramed(A,mixed,dims) return makeFramedType("Handshake",A,mixed,dims) end +function types.HandshakeArrayFramed(A,mixed,dims,W,H) return makeFramedType("HandshakeArray",A,mixed,dims,W,H) end +function types.StaticFramed(A,mixed,dims) return makeFramedType("Static",A,mixed,dims) end +function types.VFramed(A,mixed,dims) return makeFramedType("V",A,mixed,dims) end +function types.RVFramed(A,mixed,dims) return makeFramedType("RV",A,mixed,dims) end + +-- Add an extra outermost dim (loop) to the type +function TypeFunctions:addDim(w,h,mixed) + err( type(w)=="number", ":addDim w should be number") + err( type(h)=="number", ":addDim h should be number") + err( type(mixed)=="boolean", ":addDim mixed should be boolean") + err( mixed==false or types.isBasic(self) or self:is("V") or self:is("RV"), "addDim: if mixed, this must be a basic type, but is: "..tostring(self) ) + + local ldims = {} + local A + if self:is("StaticFramed") or self:is("HandshakeFramed") then + for i=1,#self.params.dims do + table.insert(ldims,{self.params.dims[i][1],self.params.dims[i][2]}) + end + elseif self:is("V") or self:is("RV") or self:is("Handshake") then + A = self.params.A + else + A = self + 
err(types.isBasic(self),":addDim - "..tostring(self)) + end + + table.insert(ldims,{w,h}) + + if self:is("StaticFramed") or types.isBasic(self) then + return types.StaticFramed(A,mixed,ldims) + elseif self:is("HandshakeFramed") or self:is("Handshake") then + return types.HandshakeFramed(A,mixed,ldims) + elseif self:is("V") then + return types.VFramed(A,mixed,ldims) + elseif self:is("RV") then + return types.RVFramed(A,mixed,ldims) + else + print("COULD NOT ADDDIM "..tostring(self)) + assert(false) + end +end + +--[=[ +function types.FramedCollectParallelDims(A) + assert(types.isBasic(A)) + local Adims = {} + local innerType + local function rec(ty) + if ty:isArray() then + rec(ty:arrayOver()) + table.insert(Adims,{ty.size[1],ty.size[2]}) + else + innerType = ty + end + end + rec(A) + return Adims,innerType +end + +function TypeFunctions:dims() + return types.FramedCollectParallelDims(self) +end +]=] + +-- figure out 'V' vector width setting from HandshakeFramed +-- 'V' is basically the last parallel dimension +function types.HSFV(A) + assert(types.isType(A)) + err( A:is("HandshakeFramed") or A:is("StaticFramed"), "calling HSFV on unsupported type: "..tostring(A) ) + + assert(#A.params.dims==1) + -- err(A.params.mixed, "HSFV: NYI - "..tostring(A)) + if A.params.mixed then + assert(A.params.A.size[2]==1) + return A.params.A.size[1] + else + return 0 + end +end + +function types.HSFSize(A) + assert(types.isType(A)) + assert(A:is("HandshakeFramed") or A:is("StaticFramed") ) + + assert(#A.params.dims==1) + return {A.params.dims[1][1],A.params.dims[1][2]} +end + +-- if we have HandshakeFramed(u8[8;640,480}), this will return u8 +-- if we have HandshakeFramed(u8{640,480}), this will return u8 +function types.HSFPixelType(A) + assert(types.isType(A)) + assert( A:is("HandshakeFramed") or A:is("StaticFramed") ) +-- local Adims,innerType = types.FramedCollectParallelDims(A.params.A) + + if A.params.mixed then + return A.params.A:arrayOver() + else + return A.params.A + end 
+end + +function TypeFunctions:FV() return types.HSFV(self) end +function TypeFunctions:FW() return types.HSFSize(self)[1] end +function TypeFunctions:FH() return types.HSFSize(self)[2] end +function TypeFunctions:FPixelType() return types.HSFPixelType(self) end + +-- this is sort of like arrayOver but for framed types +function TypeFunctions:framedOver() + assert(#self.params.dims==1) + assert(self.params.mixed==false) + + if self:is("StaticFramed") then + return self.params.A + elseif self:is("HandshakeFramed") then + return types.Handshake(self.params.A) + end + +end + if terralib~=nil then require("typesTerra") end function types.export(t) if t==nil then t=_G end rawset(t,"u",types.uint) + + for i=1,32 do + rawset(t,"u"..i,types.uint(i)) + end + rawset(t,"i",types.int) rawset(t,"b",types.bits) -- rawset(t,"bool",types.bool(false)) -- used in terra!! rawset(t,"ar",types.array2d) rawset(t,"tup",types.tuple) + rawset(t,"Handshake",types.Handshake) end return types