In [20]:
from collections import defaultdict
from itertools import product
import numpy as np

def discrete_uniform_sum_pmf(a, b, n):
    """Return the probability of every total when rolling ``n`` fair dice.

    Each die shows an integer face in ``a..b`` (inclusive) with equal
    probability ``1/(b-a+1)``.  The distribution is built by convolving the
    single-die distribution with itself ``n`` times.

    Returns:
        A mapping ``total -> probability``.  For ``n == 0`` this is the plain
        dict ``{0: 1}``; otherwise a ``defaultdict(float)``.
    """
    face_prob = {face: 1/(b-a+1) for face in range(a, b+1)}
    dist = {0: 1}

    for _ in range(n):
        nxt = defaultdict(float)
        # Combine every total reached so far with every possible next face.
        for total, p_total in dist.items():
            for face, p_face in face_prob.items():
                nxt[total + face] += p_total * p_face
        dist = nxt

    return dist

def discrete_uniform_sum_pmf_count(a, b, n):
    """Count the outcomes producing each total when rolling ``n`` dice.

    Each die has the integer faces ``a..b`` (inclusive).  Dividing a count
    by ``(b - a + 1) ** n`` gives the probability of that total.

    Returns:
        A mapping ``total -> number of outcomes``.  For ``n == 0`` this is the
        plain dict ``{0: 1}``; otherwise a ``defaultdict(int)``.
    """
    # Only the face values matter for counting, so no probability table is
    # built here (the previous version computed float probabilities it
    # never read).
    faces = range(a, b + 1)
    du_sum_pmf = {0: 1}

    for _ in range(n):
        new_sum_pmf = defaultdict(int)
        for prev_sum, prev_count in du_sum_pmf.items():
            for dice in faces:
                new_sum_pmf[prev_sum + dice] += prev_count
        du_sum_pmf = new_sum_pmf

    return du_sum_pmf

def discrete_uniform_sum_pmf_2(a, b, n):
    """Count the outcomes producing each total when rolling ``n`` dice.

    Faces run over the integers ``a..b`` (inclusive).  The counts are
    unnormalised; divide by ``(b - a + 1) ** n`` to get probabilities.

    Returns:
        ``{0: 1}`` when ``n == 0``, otherwise a ``defaultdict(int)`` mapping
        ``total -> number of outcomes``.
    """
    faces = range(a, b + 1)
    counts = {0: 1}
    for _ in range(n):
        grown = defaultdict(int)
        for (total, ways), face in product(counts.items(), faces):
            grown[total + face] += ways
        counts = grown
    return counts

def discrete_uniform_sum_pmf_3(a, b, n):
    """Count the outcomes of each total for ``n`` dice with faces ``a..b``.

    The totals are stored densely in a list, shifted so that the smallest
    possible total (``a * n``) sits at index 0.

    Returns:
        ``(offset, counts)`` where ``offset == a * n`` and ``counts[k]`` is the
        number of outcomes whose total is ``offset + k``.
    """
    width = b - a
    counts = [1]
    for rolled in range(1, n + 1):
        # After `rolled` dice the shifted totals span 0..rolled*width.
        nxt = [0] * (rolled * width + 1)
        for idx, ways in enumerate(counts):
            for shift in range(width + 1):
                nxt[idx + shift] += ways
        counts = nxt
    return (a * n, counts)

def discrete_uniform_sum_pmf_4(a: int, b: int, n: int):
    """Count the outcomes of each total for ``n`` dice with faces ``a..b``.

    The function uses two preallocated lists of the final size and swaps
    them after every roll, instead of allocating a new list per roll.

    Returns:
        ``(offset, counts)`` where ``offset == a * n`` and ``counts[k]`` is the
        number of outcomes whose total is ``offset + k``.
    """
    span = b - a
    size = n * span + 1

    current = [0] * size
    current[0] = 1
    scratch = [0] * size

    for rolled in range(1, n + 1):
        # Only the first `filled` slots of `current` can be non-zero so far.
        filled = (rolled - 1) * span + 1
        for j in range(filled):
            ways = current[j]
            for shift in range(span + 1):
                scratch[j + shift] += ways

        current, scratch = scratch, current
        # Zero the stale buffer (which now holds the previous roll) before reuse.
        for j in range(rolled * span + 1):
            scratch[j] = 0

    return (a * n, current)

def discrete_uniform_sum_pmf_5(a: int, b: int, n: int):
    """Count the outcomes of each total for ``n`` dice with faces ``a..b``.

    This is the NumPy variant of the double-buffered approach.  Each possible
    face adds a shifted copy of the previous counts in a single slice
    operation.

    Returns:
        ``(offset, counts)`` where ``offset == a * n`` and ``counts`` is an
        ``int64`` array with ``counts[k]`` outcomes totalling ``offset + k``.
        NOTE(review): large ``n`` may overflow ``int64`` silently.
    """
    span = b - a
    size = n * span + 1

    current = np.zeros(size, dtype='int64')
    current[0] = 1
    scratch = np.zeros(size, dtype='int64')

    for rolled in range(1, n + 1):
        prev_len = (rolled - 1) * span + 1
        window = current[:prev_len]
        for shift in range(span + 1):
            scratch[shift:shift + prev_len] += window

        current, scratch = scratch, current
        # Reset the buffer that will receive the next roll.
        scratch[:rolled * span + 1] = 0

    return (a * n, current)

def discrete_uniform_sum_pmf_6(a: int, b: int, n: int):
    """Count the outcomes of each total for ``n`` dice with faces ``a..b``.

    The counts are obtained by repeatedly convolving with an all-ones kernel
    of length ``b - a + 1`` (one entry per face).

    Returns:
        ``(offset, counts)`` where ``offset == a * n`` and ``counts`` is an
        ``int64`` array with ``counts[k]`` outcomes totalling ``offset + k``.
        NOTE(review): large ``n`` may overflow ``int64`` silently.
    """
    kernel = np.ones(b - a + 1, dtype='int64')
    counts = np.array([1], dtype='int64')
    for _ in range(n):
        counts = np.convolve(counts, kernel)
    return (a * n, counts)

def discrete_uniform_sum_pmf_6_fp(a: int, b: int, n: int):
    """Return the probability of each total for ``n`` fair dice with faces ``a..b``.

    This is the floating-point counterpart of ``discrete_uniform_sum_pmf_6``.
    It convolves with the single-die pmf ``1/(b-a+1)`` instead of counting.

    Returns:
        ``(offset, pmf)`` where ``offset == a * n`` and ``pmf`` is a ``float64``
        array with ``pmf[k]`` the probability that the total is ``offset + k``.
        The dtype is ``float64`` for every ``n``, including ``n == 0``.
        Previously the function seeded with an ``int64`` array, so it returned
        an integer array when ``n == 0``.
    """
    faces = b - a + 1
    pmf = np.array([1.0])
    die = np.full(faces, 1 / faces)

    for _ in range(n):
        pmf = np.convolve(pmf, die)
        # Renormalise each step so rounding error does not accumulate across
        # many rolls.
        pmf /= pmf.sum()

    return (a * n, pmf)

# NOTE(review): `min`/`max` shadow the builtins for the rest of the notebook;
# kept as-is because later cells may reference these names.
min = 0
max = 4
rolls = 7
print((max-min+1)**rolls)
print(discrete_uniform_sum_pmf_count(min, max, rolls))

print(discrete_uniform_sum_pmf_2(min, max, rolls))
print(discrete_uniform_sum_pmf_3(min, max, rolls))
print(discrete_uniform_sum_pmf_4(min, max, rolls))
print(discrete_uniform_sum_pmf_5(min, max, rolls))
print(discrete_uniform_sum_pmf_6(min, max, rolls))
print(discrete_uniform_sum_pmf_6_fp(min, max, rolls))


# Cross-check the convolution implementation against the dict-based reference
# over many face ranges and roll counts.  `correct` is initialised once,
# before the loop, so that a mismatch in ANY case is remembered (it used to be
# reset every iteration, so only the last case was reported).
correct = True
for tmin, tsize, trolls in product(range(100), range(10), range(1, 10 + 1)):
    offset, counts = discrete_uniform_sum_pmf_6(tmin, tmin + tsize, trolls)
    expected = discrete_uniform_sum_pmf_count(tmin, tmin + tsize, trolls)
    # Both representations must cover exactly the same set of totals.
    if len(counts) != len(expected):
        print(len(counts), len(expected))
        correct = False
    for amt, count in expected.items():
        if counts[amt - offset] != count:
            print(counts[amt - offset], count)
            correct = False

print(correct)



78125
defaultdict(<class 'int'>, {0: 1, 1: 7, 2: 28, 3: 84, 4: 210, 5: 455, 6: 875, 7: 1520, 8: 2415, 9: 3535, 10: 4795, 11: 6055, 12: 7140, 13: 7875, 14: 8135, 15: 7875, 16: 7140, 17: 6055, 18: 4795, 19: 3535, 20: 2415, 21: 1520, 22: 875, 23: 455, 24: 210, 25: 84, 26: 28, 27: 7, 28: 1})
defaultdict(<class 'int'>, {0: 1, 1: 7, 2: 28, 3: 84, 4: 210, 5: 455, 6: 875, 7: 1520, 8: 2415, 9: 3535, 10: 4795, 11: 6055, 12: 7140, 13: 7875, 14: 8135, 15: 7875, 16: 7140, 17: 6055, 18: 4795, 19: 3535, 20: 2415, 21: 1520, 22: 875, 23: 455, 24: 210, 25: 84, 26: 28, 27: 7, 28: 1})
(0, [1, 7, 28, 84, 210, 455, 875, 1520, 2415, 3535, 4795, 6055, 7140, 7875, 8135, 7875, 7140, 6055, 4795, 3535, 2415, 1520, 875, 455, 210, 84, 28, 7, 1])
(0, [1, 7, 28, 84, 210, 455, 875, 1520, 2415, 3535, 4795, 6055, 7140, 7875, 8135, 7875, 7140, 6055, 4795, 3535, 2415, 1520, 875, 455, 210, 84, 28, 7, 1])
(0, array([   1,    7,   28,   84,  210,  455,  875, 1520, 2415, 3535, 4795,
       6055, 7140, 7875, 8135, 7875, 7140, 