In [1]:
import numpy as np
from IPython.display import clear_output
import glob
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [2]:
ENVIRONMENT_SIZE = 16
BATCH_SIZE = 32

In [3]:
class BCDataset(Dataset):
    
    def __init__(self, npy_dir):
        files = glob.glob(os.path.join(npy_dir, "*.npy"))
        self.trajectories = []
        
        for file in files:
            self.trajectories.append(np.load(file, allow_pickle=True))
        
        self.trajectories = np.vstack(self.trajectories)
        
    def __len__(self):
        return len(self.trajectories)
    
    def __getitem__(self, idx):
        state, action = self.trajectories[idx]
#         state = [state[i] / ENVIRONMENT_SIZE if (state[i] != -1) and (i in state[0::4] or i in state[1::4]) 
#                  else state[i]
#                  for i in range(len(state))]
        return np.array(state, dtype=np.float32), action

In [4]:
training_dataset = BCDataset("expert-dir")

In [5]:
len(training_dataset)

5500

In [6]:
from collections import Counter, OrderedDict

def get_training_distribution(dataset):
    actions = Counter([dataset[i][1] for i in range(len(dataset))])
    actions = {action:1/count for action, count in actions.items()}
    actions = OrderedDict(sorted(actions.items()))
    return torch.tensor(list(actions.values()))

In [7]:
training_dataset[0]

(array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,  8.,
         1.,  1.,  1.,  6.,  0.,  1., -1.], dtype=float32), 0)

In [8]:
arr = [-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,  8.,
         1.,  1.,  1.,  6.,  0.,  1., -1.]

In [9]:
list(set([arr[i:i + 4] for i in range(0, len(arr), 4)][0])) == [-1.0]

True

In [10]:
training_dataloader = DataLoader(training_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

In [11]:
len(training_dataloader)

172

In [12]:
class BCModel(nn.Module):
    
    def __init__(self, state_size, action_size, hidden_size=256):
        super(BCModel, self).__init__()
        
        self.fc1 = nn.Linear(in_features=state_size, out_features=hidden_size)
        self.leaky_relu_1 = nn.LeakyReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size//2)
        self.leaky_relu_2 = nn.LeakyReLU()
        self.fc3 = nn.Linear(in_features=hidden_size//2, out_features=action_size)
        
    def forward(self, x):
        out = self.leaky_relu_1(self.fc1(x))
        out = self.fc2(out)
        out = self.fc3(out)
        
        return F.softmax(out, dim=1)

In [13]:
STATE_SIZE = 20
ACTION_SIZE = 5

model = BCModel(state_size=STATE_SIZE, action_size=ACTION_SIZE)

loss_fn = nn.CrossEntropyLoss(weight=get_training_distribution(training_dataset))
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [15]:
from tqdm import tqdm, trange

EPISODES = 2000

episode_losses = []
t = trange(EPISODES, desc="Episode")
for current_episode_num in t:
    current_episode_loss = 0
    
    for i, data in enumerate(training_dataloader):
        states, actions = data
        
        optimizer.zero_grad()
        
        predicted_actions = model(states.float())
        loss = loss_fn(predicted_actions, actions)
        loss.backward()
        optimizer.step()
        
        current_episode_loss += loss.item()
        
    t.set_description(f"Loss {current_episode_loss / len(training_dataloader)}")
    t.refresh()
       
    episode_losses.append(current_episode_loss / len(training_dataloader))   

Episode:   0%|          | 0/2000 [00:00<?, ?it/s]

torch.Size([])
1.1756950616836548
torch.Size([])
2.464120030403137
torch.Size([])
3.7692935466766357
torch.Size([])
5.007721662521362
torch.Size([])
6.366112232208252
torch.Size([])
7.72144615650177
torch.Size([])
9.083545088768005
torch.Size([])
10.638295292854309
torch.Size([])
11.931357026100159
torch.Size([])
13.203975915908813
torch.Size([])
14.575132131576538
torch.Size([])
15.922528743743896
torch.Size([])
17.17240023612976
torch.Size([])
18.5379581451416
torch.Size([])
19.853751182556152
torch.Size([])
21.224359273910522
torch.Size([])
22.606212258338928
torch.Size([])
23.9505934715271
torch.Size([])
25.26146912574768
torch.Size([])
26.63640260696411
torch.Size([])
27.934675335884094
torch.Size([])
29.271453380584717
torch.Size([])
30.5817654132843
torch.Size([])
31.946340918540955
torch.Size([])
33.24947190284729
torch.Size([])
34.54498541355133
torch.Size([])
35.88970696926117
torch.Size([])
37.181679368019104
torch.Size([])
38.42470180988312
torch.Size([])
39.70220756530762


Loss 1.340725653393324:   0%|          | 1/2000 [00:00<23:10,  1.44it/s]

torch.Size([])
206.21338164806366
torch.Size([])
207.7433022260666
torch.Size([])
209.15760672092438
torch.Size([])
210.48522758483887
torch.Size([])
211.755016207695
torch.Size([])
213.08936762809753
torch.Size([])
214.60628008842468
torch.Size([])
215.9959043264389
torch.Size([])
217.34203505516052
torch.Size([])
218.71398890018463
torch.Size([])
220.04672992229462
torch.Size([])
221.4770132303238
torch.Size([])
222.782271027565
torch.Size([])
224.2563591003418
torch.Size([])
225.58208763599396
torch.Size([])
226.78024423122406
torch.Size([])
228.06976521015167
torch.Size([])
229.41012740135193
torch.Size([])
230.60481238365173
torch.Size([])
1.3029226064682007
torch.Size([])
2.614096999168396
torch.Size([])
4.00787615776062
torch.Size([])
5.403172135353088
torch.Size([])
6.870807409286499
torch.Size([])
8.248309016227722
torch.Size([])
9.652184128761292
torch.Size([])
11.06454885005951
torch.Size([])
12.483851313591003
torch.Size([])
13.695143103599548
torch.Size([])
15.092195630073

Loss 1.3332971722580667:   0%|          | 2/2000 [00:01<23:18,  1.43it/s]

171.63901674747467
torch.Size([])
172.8774231672287
torch.Size([])
174.1353256702423
torch.Size([])
175.38874447345734
torch.Size([])
176.69085657596588
torch.Size([])
178.09736335277557
torch.Size([])
179.49808490276337
torch.Size([])
180.74580001831055
torch.Size([])
182.1838229894638
torch.Size([])
183.5393327474594
torch.Size([])
184.99563419818878
torch.Size([])
186.39143884181976
torch.Size([])
187.6566309928894
torch.Size([])
188.88085412979126
torch.Size([])
190.1889772415161
torch.Size([])
191.42294108867645
torch.Size([])
192.78841030597687
torch.Size([])
194.1062866449356
torch.Size([])
195.4563845396042
torch.Size([])
196.95916390419006
torch.Size([])
198.28413915634155
torch.Size([])
199.6402851343155
torch.Size([])
200.93042862415314
torch.Size([])
202.29053056240082
torch.Size([])
203.5521045923233
torch.Size([])
204.9373916387558
torch.Size([])
206.39654672145844
torch.Size([])
207.69684028625488
torch.Size([])
209.1502983570099
torch.Size([])
210.37115025520325
torch.S

Loss 1.3273590477400048:   0%|          | 3/2000 [00:02<24:37,  1.35it/s]

torch.Size([])
197.76398921012878
torch.Size([])
199.0357723236084
torch.Size([])
200.4335411787033
torch.Size([])
201.60650527477264
torch.Size([])
202.96568977832794
torch.Size([])
204.17104637622833
torch.Size([])
205.52243614196777
torch.Size([])
206.8466523885727
torch.Size([])
208.13315081596375
torch.Size([])
209.55352306365967
torch.Size([])
210.92491269111633
torch.Size([])
212.35694479942322
torch.Size([])
213.81451404094696
torch.Size([])
215.2104071378708
torch.Size([])
216.56552648544312
torch.Size([])
217.8339797258377
torch.Size([])
219.20999562740326
torch.Size([])
220.45356476306915
torch.Size([])
221.6935098171234
torch.Size([])
223.02180635929108
torch.Size([])
224.4798367023468
torch.Size([])
225.7848162651062
torch.Size([])
226.97969210147858
torch.Size([])
228.30575621128082
torch.Size([])
1.2477840185165405
torch.Size([])
2.574386239051819
torch.Size([])
3.816970705986023
torch.Size([])
5.221407055854797
torch.Size([])
6.586893320083618
torch.Size([])
7.909781217




KeyboardInterrupt: 

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 10))
plt.plot(episode_losses)
plt.show()

In [None]:
import time

from environment import Environment

def test(max_steps, speed=0.5, agent_pos=None, food_pos=None, render=True):
    model.eval()
    
    env = Environment(rows=16, cols=16, scope=10)

    if agent_pos != None:
        env.current_pos = env.pos(agent_pos[0], agent_pos[1])

    if food_pos != None:
        env.food = env.pos(food_pos[0], food_pos[1])

    i = 0
    success = True
    while (not env.is_done()):
        clear_output(wait=True)
        print(f"Step: {i+1}, Food: {env.consumed_count}")
        if i == max_steps or env.num_food == 0:
            success = False
            break

        state = env.get_state()
        state = torch.from_numpy(state).unsqueeze(0)
        
        with torch.no_grad():
            action_probs = model(state.float())
            print(f"Action prob: {action_probs}")
            _, action = torch.max(action_probs, 1)
            action = action.item()
            print(f"Action: {action}")

        reward = 0
        if action == 0:
            reward = env.move_up()
        elif action == 1:
            reward = env.move_left()
        elif action == 2:
            reward = env.move_down()
        elif action == 3:
            reward = env.move_right()
        elif action == 4:
            reward = env.ingest()

        if render:
            env.render()

        i += 1

        time.sleep(speed)

    return success, env.consumed_count

In [None]:
# test(max_steps=250, speed=0.1, render=True)

In [None]:
from collections import Counter

actions = []
action_num_to_name = {0: "up", 1: "left", 2: "down", 3: "right", 4: "ingest"}

for val in training_dataset:
    actions.append(val[1])
    
action_distribution = dict(Counter(actions))
action_distribution = {action:1/count for action, count in action_distribution.items()}
print(action_distribution)