Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrong write shape ICE #1481

Closed
maedoc opened this issue Oct 7, 2021 · 0 comments
Closed

Wrong write shape ICE #1481

maedoc opened this issue Oct 7, 2021 · 0 comments

Comments

@maedoc
Copy link

maedoc commented Oct 7, 2021

I tried to minimise the example, but the problem went away when I did, so here is the full program:

-- Interface for the values being transformed: an element type t with a zero,
-- addition, and scaling by reals from R. Despite the name, no subtraction,
-- negation or inverse is part of this signature.
module type field = {
    -- Real-number module supplying the scalars (coefficients, grid points).
    module R: real
    -- Element type.
    type t
    -- Additive identity.
    val zero : t
    -- Addition of two elements.
    val +: t -> t -> t
    -- Scaling of an element by a real scalar (R.t -> t -> t, not t -> t -> t).
    val *: R.t -> t -> t
    -- dummy function to generate new non-zero values of type t
    -- (mk_lt's bench uses it to fill its input array from indices i, j, k)
    val tab3: i64 -> i64 -> i64 -> t
}

-- Instantiates `field` with the reals themselves: t = R.t, and every
-- operation is the corresponding operation of R.
module mk_scalar_field (R: real) = {
    module R = R
    type t = R.t
    let zero = R.i64 0
    let (+) (x:t) (y:t): t = R.(x + y)
    -- Scaling is plain multiplication, since scalars and elements coincide.
    let (*) (a:R.t) (x:t): t = R.(a * x)
    -- Test-data generator: R.i64 i + R.i64 j + R.i64 k.
    let tab3 i j k: t = R.((i64 i)+(i64 j)+(i64 k))
}

-- Parametric module providing a forward transform (lt) and an inverse
-- transform (ilt) over arrays of F.t. For each order m, data sampled at the
-- points cx is projected onto (or synthesised from) basis functions P^m_n,
-- n = m .. np1 - 1, generated by a three-term recurrence in the degree n.
-- The starting values (all_amm) and recurrence coefficients (amn, bmn)
-- presumably describe orthonormalised associated Legendre functions -
-- confirm against the intended definition.
module mk_lt (F: field) = {
    module R = F.R
    -- Real constants -1, 0, 1, 2, 3, 4 used by the formulas below.
    let rm1 = R.i64 (-1)
    let r0 = R.i64 0
    let r1 = R.i64 1
    let r2 = R.i64 2
    let r3 = R.i64 3
    let r4 = R.i64 4

    -- (N + 1) * (N + 2) / 2: the number of index pairs (m, l) with
    -- 0 <= m <= l <= N, which is also the length of gen_ml N.
    let len_q (N:i64): i64 = (N + 1) * (N + 2) // 2

    -- Lists the pairs (m, l) with 0 <= m <= l <= N in lexicographic order,
    -- starting at (0, 0): (m, l) is followed by (m, l + 1), and (m, N) by
    -- (m + 1, m + 1). Not used elsewhere in this file; appending with ++ on
    -- every iteration makes it quadratic in the output length.
    let gen_ml (N:i64): [](i64,i64) =
        loop ml = [(0,0)] for i < ((len_q N)-1) do
            let (m,l) = ml[i]
            let nl = if N == l then m + 1 else l + 1
            let nm = if N == l then m + 1 else m
            in ml ++ [(nm,nl)]

    -- Returns N + 1 values; element k (0 <= k <= N) is
    --   sqrt((prod_{j=1..k} (2j + 1) / (2j)) / (4 * pi)),
    -- so element 0 is sqrt(1 / (4 * pi)). bench passes this array to lt/ilt,
    -- which use element m as the starting value amm for P^m_m.
    let all_amm (N:i64): []R.t = iota N
        |> map (\i -> R.i64(i + 1))
        |> map (\k -> R.((r2*k+r1)/(r2*k)))
        |> ([r1]++)
        |> scan (R.*) r1
        |> map (\el -> R.(sqrt(el/(r4*pi))))

    -- Recurrence coefficient sqrt((4n^2 - 1) / (n^2 - m^2)): the factor on
    -- cx * P^m_(n-1) when computing degree n. Divides by zero when n == m;
    -- it is only called with n > m in this file.
    let amn (m:R.t) (n:R.t): R.t = R.(sqrt((r4*n*n - r1)/(n*n - m*m)))

    -- Recurrence coefficient
    --   -sqrt((2n + 1) / (2n - 3) * ((n - 1)^2 - m^2) / (n^2 - m^2)),
    -- the factor on P^m_(n-2) when computing degree n. Only called with
    -- n >= m + 2 in this file.
    let bmn (m:R.t) (n:R.t): R.t =
        let l = R.((r2*n + r1)/(r2*n - r3))
        let r = R.(((n - r1)*(n - r1) - m*m)/(n*n - m*m))
        in R.(rm1*sqrt(l * r))

    -- nlat points cos(k * pi / nlat) for k = 0 .. nlat - 1, i.e. cosines of
    -- equally spaced angles in [0, pi).
    let lat_grid (nlat:i64): []R.t = iota nlat
        |> map R.i64
        |> map (\x -> R.(cos (x / (i64 nlat) * pi)))

    -- m<n
    -- Forward transform for a single order m; requires m < n, n = np1 - 1.
    -- Returns X of length np1 with X[k] = sum_i P^m_k(cx[i]) * x[i] for
    -- k = m .. n, and X[k] = F.zero for k < m.
    -- m < n matters: X[m + 1] is written unconditionally below, which would
    -- be out of bounds for m == n (Lmx handles that case without calling this).
    let Lmx' [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (x:[nlat]F.t): [np1]F.t =
        let n = np1 - 1
        let X = tabulate np1 (\i -> F.zero)
        let m' = R.i64 m
        -- Sx p = sum_i p[i] * x[i], using F's scaling and addition.
        let Sx p = map2 (F.*) p x |> reduce (F.+) F.zero
        -- P^m_m = amm * (1 - cx^2)^(m/2) * (-1)^m
        let p0 = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx
        let X[m] = Sx p0
        -- P^m_(m + 1)
        let p1 = map2 (\cx p0 -> R.((amn m' (m' + r1))*cx*p0)) cx p0
        let X[m + 1] = Sx p1
        -- p2 n p1 p0: degree-n values from degree n-1 (p1) and degree n-2 (p0).
        let p2 n p1 p0 = map3 (\cx p1 p0 -> R.((amn m' n)*cx*p1 + (bmn m' n)*p0)) cx p1 p0
        -- P^m_n for n = m + 2 .. n: iteration i produces degree m + 2 + i.
        -- Case 0 is effectively unreachable (the X[m + 1] update above is
        -- already out of bounds when n == m); pn is bound but not used.
        let (X, pn, _) = match (n-m)
            case 0 -> (X, p0, p0)
            case 1 -> (X, p1, p0)
            case _ -> loop (X,p1,p0)
                for i < (n - m - 1) do
                    let pi = p2 R.(m'+r2+i64 i) p1 p0
                    let X[m+2+i] = Sx pi
                    in (X, pi, p1)
        in X

    -- n==m and m<n
    -- Forward transform for a single order m, valid for any m <= n
    -- (n = np1 - 1). When n == m only X[m] is non-zero and it is computed
    -- here; otherwise the call is delegated to Lmx', which recomputes p0 and
    -- X[m] itself, so the work done here is discarded in that case.
    let Lmx [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (x:[nlat]F.t): [np1]F.t =
        let n = np1 - 1
        let X = tabulate np1 (\i -> F.zero)
        let m' = R.i64 m
        let Sx p = map2 (F.*) p x |> reduce (F.+) F.zero
        -- P^m_m
        let p0 = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx
        let X[m] = Sx p0
        in
        if (n-m)==0 then X else Lmx' m np1 amm cx x

    -- Inverse transform for a single order m; requires m < n, n = np1 - 1.
    -- Returns x of length nlat with x[i] = sum_{k=m..n} X[k] * P^m_k(cx[i]),
    -- accumulated one degree at a time; the synthesis counterpart of Lmx'.
    let iLmX' [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (X:[np1]F.t): [nlat]F.t =
        let n = np1 - 1
        let x = tabulate nlat (\i -> F.zero)
        let m' = R.i64 m
        -- SX k x p adds p[i] * X[k] to each x[i]: the contribution of the
        -- degree-k basis values p. (k deliberately does not shadow m.)
        let SX k x p = map2 (\xi pi -> xi F.+ (pi F.* X[k])) x p
        -- P^m_m pairs with coefficient X[m].
        let p0:[nlat]R.t = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx
        let x[:] = SX m x p0
        -- P^m_(m + 1) pairs with coefficient X[m + 1], the entry Lmx' writes
        -- for this degree (using X[m] here would count X[m] twice).
        let p1 = map2 (\cx p0 -> R.((amn m' (m' + r1))*cx*p0)) cx p0
        let x[:] = SX (m+1) x p1
        -- p2 n p1 p0: degree-n values from degree n-1 (p1) and degree n-2 (p0).
        let p2 n p1 p0 = map3 (\cx p1 p0 -> R.((amn m' n)*cx*p1 + (bmn m' n)*p0)) cx p1 p0
        -- Degrees m + 2 .. n: iteration i adds X[m + 2 + i] * P^m_(m+2+i).
        -- Case 0 is not reached via iLmX; pn is bound but not used.
        let (x, pn, _) = match (n-m)
            case 0 -> (x, p0, p0)
            case 1 -> (x, p1, p0)
            case _ -> loop (x,p1,p0)
                for i < (n - m - 1) do
                    let pi = p2 R.(m'+r2+i64 i) p1 p0
                    let x[:] = SX (m+2+i) x pi
                    in (x, pi, p1)
        in x

    -- Inverse transform for a single order m, valid for any m <= n
    -- (n = np1 - 1). When n == m, x[i] = X[m] * P^m_m(cx[i]) is computed here;
    -- otherwise the call is delegated to iLmX', which starts from scratch.
    let iLmX [nlat] (m:i64) (np1:i64) (amm:R.t) (cx:[nlat]R.t) (X:[np1]F.t): [nlat]F.t =
        let n = np1 - 1
        let x = tabulate nlat (\i -> F.zero)
        let m' = R.i64 m
        -- SX k x p adds p[i] * X[k] to each x[i].
        let SX k x p = map2 (\xi pi -> xi F.+ (pi F.* X[k])) x p
        -- P^m_m
        let p0:[nlat]R.t = map (\cx -> R.(amm*(r1 - cx*cx)**(m'/r2)*(rm1)**m')) cx
        let x[:] = SX m x p0
        in
        if (n-m)==0 then x else iLmX' m np1 amm cx X

    -- Forward transform of one [nlon][nlat] array: row m of the result is Lmx
    -- applied to x[m] with starting value amm[m], for m = 0 .. np1 - 1.
    -- Only the first np1 rows of x are read, so nlon >= np1 is required.
    let lt [np1][nlon][nlat] (amm:[np1]R.t) (cx:[nlat]R.t) (x:[nlon][nlat]F.t): [np1][np1]F.t =
        map2 (\m x -> Lmx m np1 amm[m] cx x) (iota np1) x[:np1] :> [np1][np1]F.t

    -- Inverse transform: row m (m < np1) of the [nlon][nlat] result is iLmX
    -- applied to X[m]; rows np1 .. nlon - 1 stay F.zero. Requires nlon >= np1.
    -- The in-place update of out[:np1] is the statement named in the reported
    -- "WriteReturns" internal compiler error (out_16963 / iLmX_res_16887).
    let ilt [np1][nlon][nlat] (amm:[np1]R.t) (cx:[nlat]R.t) (X:[np1][np1]F.t): [nlon][nlat]F.t =
        let out = tabulate_2d nlon nlat (\_ _ -> F.zero)
        let out[:np1] = map2 (\m x -> iLmX m np1 amm[m] cx x) (iota np1) X
        in out :> [nlon][nlat]F.t

    -- Round trip: builds nxfm input arrays of shape [nlon][nlat] with F.tab3,
    -- applies lt and then ilt to each, and returns the reconstructions.
    let bench (nxfm:i64) (lmax:i64) (nlat:i64) (nlon:i64): [nxfm][nlon][nlat]F.t =
        -- lmax > nlat
        -- NOTE(review): the benchmark inputs use lmax = 20 < nlat = 128, so the
        -- line above may have been meant as nlat > lmax - confirm. What the code
        -- does require is nlon >= lmax + 1: all_amm lmax returns lmax + 1 values
        -- (np1 in lt/ilt), and lt/ilt slice the nlon axis with [:np1].
        let amm = all_amm lmax
        let x = tabulate_3d nxfm nlon nlat F.tab3
        let cx = lat_grid nlat
        let X = map (lt amm cx) x
        let x' = map (ilt amm cx) X
        in x'
}

-- Instantiate the transforms with single-precision scalars: every element is
-- a plain f32.
module lts = mk_lt (mk_scalar_field f32)

-- Benchmark entry point: forwards (nxfm, lmax, nlat, nlon) to lts.bench.
-- The futhark bench stanza below supplies two datasets differing only in nxfm.

-- == 
-- input { 1i64 20i64 128i64 256i64 }
-- input { 8i64 20i64 128i64 256i64 }
entry main (nxfm:i64) (lmax:i64) (nlat:i64) (nlon:i64) = lts.bench nxfm lmax nlat nlon

With the multicore backend the program compiles without issue, but with the OpenCL backend I get the following internal compiler error:

$ futhark -V
Futhark 0.21.0
git: clean-ad @ 8ec2df6 (Wed Sep 15 00:28:33 2021 +0200)
$ futhark bench --backend=opencl -r 1000 lt.fut 
...
Internal compiler error.
Please report this at https://github.com/diku-dk/futhark/issues.
Type error after pass 'In-place lowering':
In function entry_main
When checking function body
In expression of statement
  {defunc_4_map_res_15556 : ({}, [nxfm_15478][nlon_15481][nlat_15480]f32)}
in false branch
In expression of statement
  {defunc_4_map_res_18995 : ({}, [nxfm_15478][nlon_15481][nlat_15480]f32)}
in true branch
In expression of statement
  {defunc_4_map_res_16712 : ({}, [nxfm_15478][nlon_15481][nlat_15480]f32)}
In expression of statement
  {out_16963 : ({}, [nlon_15481][nlat_15480]f32)}
Type error:
WriteReturns returning iLmX_res_16887 of type [nlat_15480]f32, shape=[nlon_15481][nlat_15480], but destination array has type [nlon_15481][nlat_15480]f32
@athas athas closed this as completed in 45ad3d1 Oct 7, 2021
athas added a commit that referenced this issue Oct 7, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant