In [None]:
def tj_recursion(L, W):
    """
    Minimum text-justification penalty by plain recursion (every line penalized).

    P(i, j) = penalty for the line W[i] .. W[j-1]:
              (L - length)**3, or math.inf if the line is longer than L,
              where length = sum of len(W[k]) for k in i..j-1, plus (j - i - 1)
              for the single spaces between the j - i words.
    TJP(i)  = smallest penalty for splitting W[0] .. W[i-1]

    TJP(0) = 0
    TJP(i) = min_{j=0,...,i-1} (P(j, i) + TJP(j))

    There is no memoization: TJP(i) calls TJP(j) for every j < i, so the number
    of calls grows as 2^n. The bottom-up versions in this notebook avoid the
    recomputation.
    """
    def P(i, j):
        # j - i words are separated by j - i - 1 spaces.
        length = sum(len(W[k]) for k in range(i, j)) + (j - i - 1)
        if length > L:
            return math.inf
        return (L - length) ** 3

    def TJP(i):
        if i == 0:
            return 0
        return min(TJP(j) + P(j, i) for j in range(i))

    return TJP(len(W))

In [None]:
def tj_dynamic_obf(L, W):
    """
    Bottom-up evaluation of the TJP recurrence (every line, including the last, is penalized).

    best[end] holds the smallest penalty for laying out W[0] .. W[end-1]. Each
    entry is the minimum, over every start position of its final line, of
    best[start] + penalty(W[start] .. W[end-1]).
    """
    def line_penalty(start, end):
        # Word characters plus one space between each pair of adjacent words.
        width = sum(len(word) for word in W[start:end]) + (end - start - 1)
        return math.inf if width > L else (L - width) ** 3

    count = len(W)
    best = [0] + [math.inf] * count
    for end in range(1, count + 1):
        best[end] = min(best[start] + line_penalty(start, end) for start in range(end))
    return best[count]

In [None]:
def tj_dynamic(L, W):
    """
    Minimum text-justification penalty by bottom-up dynamic programming.

    tbl[i] is the smallest penalty for splitting W[0] .. W[i-1] into lines.
    A line W[j] .. W[i-1] has length sum(len(W[k])) + (i - j - 1) (single
    spaces between words). It costs (L - length)**3 if that length is at most L,
    and math.inf otherwise. Every line, including the last one, is penalized.

    Time complexity is O(n^3): there are O(n^2) pairs (j, i), and for each one
    the line length is recomputed from scratch in O(i - j).
    """
    n = len(W)
    tbl = [math.inf] * (n + 1)  # O(n) space
    tbl[0] = 0

    for i in range(1, n + 1):          # summed over i: O(1^2 + ... + n^2) = O(n^3)
        for j in range(i):             # O(i) candidate starts for the last line
            length = i - j - 1         # spaces between the i - j words
            for k in range(j, i):      # O(i - j)
                length += len(W[k])
            if length > L:
                P = math.inf
            else:
                P = (L - length) ** 3
            tbl[i] = min(tbl[i], tbl[j] + P)

    return tbl[n]

W_example = ["jars", "jaws", "joke", "jury", "juxtaposition"]

L_example = 15

tbl = [0, 1331, 216, 1, 432, 440]
split = [0,0,0,0,2,4], where split[i] is the index of the first word on the last line of the best layout of W[:i]

$440 = 2^3 + 432$, where $2^3 = (15 - 13)^3$ is the penalty of the last line "juxtaposition" (13 characters) and $432$ = tbl[4] = tbl[split[-1]].

$432 = \min\begin{cases} 11^3 + 1 = 1332 & (j = 3) \\ 6^3 + 216 = 432 & (j = 2) \\ 1^3 + 1331 = 1332 & (j = 1) \end{cases}$

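so the previous line break is at split[-2] = split[4] = 2 (with cost tbl[2]), and split[2] = 0 marks the start of the first line. Hence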

lines = [W[4:5], W[2:4], W[:2]]

return lines[::-1]

Complexity of reconstructing the lines by following split backwards: O(n)

In [None]:
def tj_dynamic_new(L, W):
    """
    Minimum text-justification penalty in O(n^2) time (every line penalized).

    This uses the same recurrence as the O(n^3) version. Instead of recomputing
    each candidate line's length from scratch, the line W[j] .. W[i-1] is grown
    leftwards one word at a time, so its length is updated in O(1) per step.

    Each step adds at least one character (the separating space), so once the
    candidate line is longer than L, every further extension is too. The inner
    loop therefore stops early, which bounds the work per i by roughly the
    number of words that fit on one line.
    """
    n = len(W)
    tbl = [math.inf] * (n + 1)
    tbl[0] = 0

    for i in range(1, n + 1):
        length = -1  # cancels the separating space counted for the first word
        for j in range(i - 1, -1, -1):
            length += 1 + len(W[j])
            if length > L:
                break  # all longer lines W[j'] .. W[i-1] with j' < j are too long
            tbl[i] = min(tbl[i], tbl[j] + (L - length) ** 3)

    return tbl[n]

In [None]:
import math

def _tj_table(L, W):
    """
    Fill the text-justification DP table in which the last line is free.

    tbl[i] is the smallest penalty for laying out W[0] .. W[i-1]. A line holding
    W[j] .. W[i-1] has length sum(len(W[k])) + (i - j - 1) (single spaces between
    words), and lines longer than L are not allowed. The line that ends with the
    final word costs 0. Every other line costs (L - length)**3.

    split[i] is the index j of the first word on the last line of that optimal
    layout, so the lines can be recovered by following split back from len(W).

    Returns (tbl, split). For a non-empty W, tbl[len(W)] is math.inf exactly
    when some word is longer than L (no layout fits).
    """
    n = len(W)
    tbl = [math.inf] * (n + 1)
    split = [0] * (n + 1)
    tbl[0] = 0

    for i in range(1, n + 1):
        is_last_line = i == n
        length = -1  # cancels the separating space counted for the first word
        # Grow the candidate line W[j] .. W[i-1] leftwards one word at a time.
        for j in range(i - 1, -1, -1):
            length += 1 + len(W[j])
            if length > L:
                break  # adding more words can only make the line longer
            penalty = 0 if is_last_line else (L - length) ** 3
            if tbl[j] + penalty < tbl[i]:
                tbl[i] = tbl[j] + penalty
                split[i] = j

    return tbl, split


def tj_cost(L, W):
    """
    Return the smallest total penalty for justifying the words W into lines of width L.

    Every line except the last costs (L - length)**3, and the last line is free.
    Returns 0 for an empty word list and math.inf if some word is longer than L.
    """
    tbl, _ = _tj_table(L, W)
    return tbl[-1]


def tj(L, W):
    """
    Return the optimal justification of W (same cost as tj_cost) as one string.

    Lines are joined with "\\n" and words within a line with a single space.
    Returns "" for an empty word list.

    Raises:
        ValueError: if some word is longer than L, so no layout fits.
    """
    tbl, split = _tj_table(L, W)
    if math.isinf(tbl[-1]):
        raise ValueError(f"a word is longer than the line width L={L}")

    lines = []
    end = len(W)
    while end > 0:
        start = split[end]
        lines.append(" ".join(W[start:end]))
        end = start

    return "\n".join(reversed(lines))

if __name__ == "__main__":
    L_example = 15

    W_example = ["juxtaposition", "jury", "jury", "jury", "jury"]
    print(tj_cost(L_example, W_example))
    print(tj(L_example, W_example))

    # Worked example: the cost should be 432 and the layout
    #   jars jaws
    #   joke jury
    #   juxtaposition
    W_example = ["jars", "jaws", "joke", "jury", "juxtaposition"]
    print(tj_cost(L_example, W_example))
    print(tj(L_example, W_example))