# Task & Talk
A full reimplementation of the task from [Kottur et al. (2017)](https://arxiv.org/pdf/1706.08502.pdf), using tabular Q-learning for now.

## Initialization and Housekeeping
Set the parameters, initialize the Q-tables, and define the helper functions needed for training and post-analysis.
### Parameters

In [163]:
# Length of training, in episodes.
num_episodes = 500_000

# Learning rate and discount factor.
# NOTE(review): neither is referenced by the training loop in this notebook;
# table updates there are plain reward accumulation.
eta, gamma = 0.8, 0.95

# Epsilon-greedy schedule: epsilon starts at 1.0 and is decayed exponentially
# from max_epsilon toward min_epsilon at decay_rate per episode.
epsilon = 1.0
max_epsilon, min_epsilon = 1.0, 0.01
decay_rate = 1e-4;

### Setting up Q-Table

In [164]:
# Vocabulary sizes: Q-bot speaks one of q_vocab tokens per turn, A-bot one of
# a_vocab tokens.
q_vocab, a_vocab = 3, 12

# Number of tasks (each selects an ordered pair of object attributes for Q-bot
# to recover) and number of values each of the three object attributes can take.
num_tasks, num_attributes = 6, 4;

In [165]:
# Reward tables. Token axes sized q_vocab+1 / a_vocab+1 reserve the last index
# as a "nothing said yet" placeholder; the trailing axis (two axes for guesses)
# is the action being scored.

# A-bot: state = (attr1, attr2, attr3, Q turn-1 token,
#                 A turn-1 token or empty, Q turn-2 token or empty);
# action = A-bot reply token.
a_table = zeros(
    num_attributes, num_attributes, num_attributes,
    q_vocab, a_vocab + 1, q_vocab + 1,
    a_vocab,
)

# Q-bot utterances: state = (task, Q turn-1 token or empty, A turn-1 token or empty);
# action = Q-bot question token.
q_table_utt = zeros(num_tasks, q_vocab + 1, a_vocab + 1, q_vocab)

# Q-bot guess: state = (task, Q1, A1, Q2, A2); action = ordered attribute pair.
q_table_guess = zeros(
    num_tasks, q_vocab, a_vocab, q_vocab, a_vocab,
    num_attributes, num_attributes,
)

# Visited masks cover each table's state axes, i.e. everything but the action axes.
a_visited = falses(size(a_table)[1:end-1])
q_utt_visited = falses(size(q_table_utt)[1:end-1])
q_guess_visited = falses(size(q_table_guess)[1:end-2])

# Initialized to 0; presumably a correct-guess tally — confirm against the loop.
num_correct = 0;

## Accuracy measurements

In [166]:
num_exploits = zero(Int);  # greedy (non-exploring) episodes counted for accuracy

### Helper Functions

In [167]:
"""
    get_reward(a_state, guess, num_task)

Score Q-bot's two-attribute `guess` against the object held in the first three
entries of `a_state`. Each task asks for one ordered pair of those attributes:

| `num_task` | expected `guess`           |
|------------|----------------------------|
| 1          | `[a_state[1], a_state[2]]` |
| 2          | `[a_state[2], a_state[3]]` |
| 3          | `[a_state[1], a_state[3]]` |
| 4          | `[a_state[2], a_state[1]]` |
| 5          | `[a_state[3], a_state[2]]` |
| any other  | `[a_state[3], a_state[1]]` |

Return `1` when `guess == expected`, otherwise `-1`.
"""
function get_reward(a_state, guess, num_task)
    # Attribute indices requested by tasks 1-5; every other id falls back to (3, 1).
    task_indices = ((1, 2), (2, 3), (1, 3), (2, 1), (3, 2))
    slot = findfirst(t -> t == num_task, 1:length(task_indices))
    first_idx, second_idx = slot === nothing ? (3, 1) : task_indices[slot]

    expected = [a_state[first_idx], a_state[second_idx]]
    return expected == guess ? 1 : -1
end

get_reward (generic function with 2 methods)

## Q-Learning
The core loop of the program:

In [None]:
# Counters for the greedy-performance measurement, reset so that re-running this
# cell does not mix in counts from a previous run.
total_rewards = 0
num_exploits = 0
num_correct = 0
for episode in 1:num_episodes
    # These globals are reassigned below. In a non-interactive run
    # (`julia script.jl`) they would otherwise become undefined loop-locals;
    # in a notebook they already resolve to the globals, so this is a no-op there.
    global epsilon, total_rewards, num_exploits, num_correct

    # One exploration draw per episode: every move either explores, or exploits
    # wherever the acting agent has already visited its current state.
    explore = rand() < epsilon

    # Generate random task; q_vocab+1 / a_vocab+1 mark "nothing said yet".
    q_state = [rand(1:num_tasks), q_vocab+1, a_vocab+1]

    # Q-BOT
    #  -> TURN 1: choose the first question token.
    if explore || !q_utt_visited[q_state...]
        q_state[2] = rand(1:q_vocab)
    else
        q_state[2] = argmax(q_table_utt[q_state..., :])
    end

    # Generate random object (three attributes) plus A-bot's view of the dialog.
    a_state = [rand(1:num_attributes), rand(1:num_attributes), rand(1:num_attributes),
               q_state[2], a_vocab+1, q_vocab+1]

    # A-BOT
    #  -> TURN 1: reply to Q-bot's first token.
    if explore || !a_visited[a_state...]
        a_state[5] = rand(1:a_vocab)
    else
        a_state[5] = argmax(a_table[a_state..., :])
    end
    q_state[3] = a_state[5]

    # Q-BOT
    #  -> TURN 2: choose the second question token (recorded in A-bot's state).
    if explore || !q_utt_visited[q_state...]
        a_state[6] = rand(1:q_vocab)
    else
        a_state[6] = argmax(q_table_utt[q_state..., :])
    end

    # Q-bot's guessing state: (task, Q1, A1, Q2, A2); A2 is filled in below.
    q_state = [q_state[1], q_state[2], q_state[3], a_state[6], a_vocab+1]

    # A-BOT
    #  -> TURN 2: reply to Q-bot's second token.
    if explore || !a_visited[a_state...]
        q_state[5] = rand(1:a_vocab)
    else
        q_state[5] = argmax(a_table[a_state..., :])
    end

    # Q-BOT
    #  -> GUESSING PHASE: pick an ordered attribute pair.
    if explore || !q_guess_visited[q_state...]
        guess = [rand(1:num_attributes), rand(1:num_attributes)]
    else
        options = q_table_guess[q_state..., :, :]
        # Scan att1-major and keep the first strict maximum (ties resolve as
        # before). Starting from -Inf guarantees an in-range pair is chosen
        # even if every entry is very negative.
        best_first, best_second = 1, 1
        max_val = -Inf
        for att1 in 1:num_attributes, att2 in 1:num_attributes
            if options[att1, att2] > max_val
                max_val = options[att1, att2]
                best_first, best_second = att1, att2
            end
        end
        guess = [best_first, best_second]
    end

    # Score the guess against this episode's task, which is q_state[1].
    reward = get_reward(a_state, guess, q_state[1])

    # Credit every (state, action) pair taken this episode with the final reward.
    # A-bot turn 1: initial state; action = first reply (q_state[3] == a_state[5]).
    a_table[a_state[1:4]..., a_vocab+1, q_vocab+1, q_state[3]] += reward
    a_visited[a_state[1:4]..., a_vocab+1, q_vocab+1] = true
    # A-bot turn 2: state after Q-bot's second token; action = second reply.
    a_table[a_state..., q_state[5]] += reward
    a_visited[a_state...] = true

    # Q-bot turn 1: empty-dialog state; action = first question token.
    q_table_utt[q_state[1], q_vocab+1, a_vocab+1, q_state[2]] += reward
    q_utt_visited[q_state[1], q_vocab+1, a_vocab+1] = true
    # Q-bot turn 2: state (task, Q1, A1); action = second question token.
    q_table_utt[q_state[1], q_state[2], q_state[3], q_state[4]] += reward
    q_utt_visited[q_state[1], q_state[2], q_state[3]] = true

    # Q-bot guess.
    q_table_guess[q_state..., guess[1], guess[2]] += reward
    q_guess_visited[q_state...] = true

    epsilon = min_epsilon + (max_epsilon - min_epsilon)*exp(-decay_rate*episode)

    if episode % 1000 == 0
        println(episode, " | ", epsilon)
    end

    # Measure greedy (non-exploring) performance after episode 250000 only.
    if !explore && episode > 250000
        total_rewards += reward
        num_correct += (reward == 1)
        num_exploits += 1
    end
end
# Rewards are +1/-1, so the mean reward equals 2*accuracy - 1; report both.
println("Mean reward: ", total_rewards/num_exploits)
println("Accuracy: ", num_correct/num_exploits)

1000 | 0.9057890438555999
2000 | 0.820543445547202
3000 | 0.7434100384749007
4000 | 0.6736168455752829
5000 | 0.6104653531155071
6000 | 0.5533235197330862
7000 | 0.5016194507534953
8000 | 0.45483567447604933
9000 | 0.4125039631431931
10000 | 0.3742006467597279
11000 | 0.33954237286109873
12000 | 0.30818226979308005
13000 | 0.27980647510367246
14000 | 0.2541309943021904
15000 | 0.23089885854694553
16000 | 0.20987755281470882
17000 | 0.1908566888122073
18000 | 0.17364589933937066
19000 | 0.1580729330304087
20000 | 0.1439819304042466
21000 | 0.13123186397045208
22000 | 0.11969512677871053
23000 | 0.10925625528557566
24000 | 0.09981077375651838
25000 | 0.0912641486376598
26000 | 0.08353084243219053
27000 | 0.07653345761235225
28000 | 0.07020196199896576
29000 | 0.06447298785584314
30000 | 0.05928919768418531
31000 | 0.05459871036962222
32000 | 0.05035458193858254
33000 | 0.046514335727227595
34000 | 0.04303953726072281
35000 | 0.039895409588095315
36000 | 0.03705048522281963
37000 | 0.0344

282000 | 0.010000000000560443
283000 | 0.010000000000507108
284000 | 0.010000000000458852
285000 | 0.010000000000415185
286000 | 0.010000000000375675
287000 | 0.010000000000339926
288000 | 0.010000000000307577
289000 | 0.010000000000278307
290000 | 0.010000000000251823
291000 | 0.01000000000022786
292000 | 0.010000000000206176
293000 | 0.010000000000186556
294000 | 0.010000000000168803
295000 | 0.01000000000015274
296000 | 0.010000000000138204
297000 | 0.010000000000125051
298000 | 0.010000000000113151
299000 | 0.010000000000102384
300000 | 0.010000000000092641
301000 | 0.010000000000083826
302000 | 0.010000000000075848
303000 | 0.01000000000006863
304000 | 0.010000000000062098
305000 | 0.01000000000005619
306000 | 0.010000000000050843
307000 | 0.010000000000046003
308000 | 0.010000000000041627
309000 | 0.010000000000037665
310000 | 0.01000000000003408
311000 | 0.010000000000030838
312000 | 0.010000000000027903
313000 | 0.010000000000025247
314000 | 0.010000000000022845
315000 | 0.0100