# CartPole Imitation Learning

* Anton Karazeev — feel free to contact me: [```anton.karazeev@gmail.com```](mailto:anton.karazeev@phystech.edu) or [t.me/akarazeev](https://t.me/akarazeev)

# Load transitions

In [None]:
import pickle

# Unpickle expert's policy actions
# there are 100 batches with length of 200 states
# `transitions` is indexed as transitions[batch][step]; each step is a pair whose
# element 0 is the state and element 1 is the expert's action.
# NOTE: `pickle.load` can execute arbitrary code while loading, so only open a
# `transitions.pkl` that comes from a trusted source.
with open('transitions.pkl', 'rb') as f:
    transitions = pickle.load(f)

In [None]:
import time
import math
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

import matplotlib.pyplot as plt
%matplotlib inline


def timeSince(since):
    """
    Formats the wall-clock time elapsed since `since`
    :param since: start timestamp in seconds, as returned by `time.time()`
    :returns: string of the form '<minutes>m <seconds>s' (both truncated to integers)
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)

In [None]:
# Defining helper functions

def states_actions(chunk):
    """
    Splits a `chunk` of history into two parallel lists
    :param chunk: sequence of transitions; element 0 of each is the state, element 1 the action
    :returns: tuple `(states, actions)` of lists in the same order as `chunk`
    """
    states = [transition[0] for transition in chunk]
    actions = [transition[1] for transition in chunk]
    return states, actions

def random_chunk(chunk_size=50):
    """
    Samples `chunk_size` consecutive transitions from one randomly chosen
    element of `transitions`
    :param chunk_size: size of history sample
    :returns: slice of length `chunk_size` taken from the chosen element
    :raises ValueError: if the chosen element holds fewer than `chunk_size` transitions
    """
    full_chunk = transitions[np.random.randint(len(transitions))]

    # Number of valid start positions. The `+ 1` makes the last window (the one
    # ending on the final transition) reachable and allows
    # `chunk_size == len(full_chunk)`; `np.random.randint(n)` excludes `n` itself.
    n_starts = len(full_chunk) - chunk_size + 1
    if n_starts < 1:
        raise ValueError(
            'chunk_size (%d) exceeds the number of available transitions (%d)'
            % (chunk_size, len(full_chunk))
        )

    start_index = np.random.randint(n_starts)
    end_index = start_index + chunk_size
    return full_chunk[start_index:end_index]

def training_batch(chunk_size=50):
    """
    Builds one training sample from a random chunk of `transitions`
    :param chunk_size: size of history sampled from `transitions`
    :returns: `states_tensor` and `actions_tensor` sampled from `transitions`;
              states are concatenated and viewed as (-1, 1, width), where width
              is the size of dim 1 of the first state
    """
    chunk = random_chunk(chunk_size)
    states, actions = states_actions(chunk)

    stacked_states = torch.cat(states)
    state_width = states[0].shape[1]
    states_tensor = stacked_states.view(-1, 1, state_width)

    actions_tensor = torch.cat(actions)
    return Variable(states_tensor), Variable(actions_tensor)

# Feedforward Network definition 

In [None]:
# Define our simple network
# and try to learn supervised expert's policy

class FFN(nn.Module):
    """
    Feedforward network: a linear layer with ReLU followed by a linear output layer
    """

    def __init__(self, input_size, hidden_size, output_size):
        """
        :param input_size: number of input features
        :param hidden_size: width of the hidden layer
        :param output_size: number of output scores
        """
        super(FFN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """
        :param x: input tensor whose last dimension is `input_size`
        :returns: raw (unnormalized) scores flattened to shape (x.size(0), -1)
        """
        hidden = F.relu(self.fc1(x))
        scores = self.fc2(hidden)
        n_rows = scores.size(0)
        return scores.view(n_rows, -1)


# Network instance used below: 4 inputs, 128 hidden units, 2 outputs
ffn = FFN(4, 128, 2)

In [None]:
# Loss used by `train_ffn`; cross-entropy is applied to the raw scores returned by `ffn`
criterion = nn.CrossEntropyLoss()

# Step size of the manual gradient-descent update performed in `train_ffn`
learning_rate = 0.01

# Training function
def train_ffn(states_tensor, actions_tensor):
    """
    Trains `ffn` on one chunk, taking a separate gradient-descent step per sample
    :param states_tensor: states indexed along dim 0; `states_tensor[i]` is fed to `ffn`
    :param actions_tensor: expert actions (class indices) aligned with `states_tensor`
    :returns: output of `ffn` for the last sample and the mean loss over the chunk
    :raises ValueError: if `states_tensor` contains no samples
    """
    n_samples = states_tensor.size(0)
    if n_samples == 0:
        raise ValueError('states_tensor must contain at least one sample')

    total_loss = 0.0

    for i in range(n_samples):
        ffn.zero_grad()

        output = ffn(states_tensor[i])               # Predict action
        # `.view(-1)` keeps the target 1-D even when indexing yields a 0-dim
        # tensor (PyTorch >= 0.4), matching the single-row `output`.
        target = actions_tensor[i].view(-1)
        loss = criterion(output, target)             # Calculate error
        # `.item()` replaces `loss.data[0]`, which fails on 0-dim losses
        total_loss += loss.item()
        loss.backward()                              # Error backpropagation

        # Update net's parameters; done under `no_grad` so the in-place
        # update is not tracked by autograd.
        with torch.no_grad():
            for p in ffn.parameters():
                p.sub_(learning_rate * p.grad)

    return output, total_loss / n_samples

In [None]:
# Train loop: fit `ffn` on `n_iters` random chunks of 190 transitions each
# and plot the mean loss per chunk.

n_iters = 100
print_every = 10

all_losses = []
start = time.time()

# The loop variable is `iteration` rather than `iter`: at notebook scope a
# variable called `iter` would shadow the builtin for every later cell.
for iteration in range(1, n_iters + 1):
    output, loss = train_ffn(*training_batch(190))
    all_losses.append(loss)

    if iteration % print_every == 0:
        # elapsed time, iteration number, percent complete, loss of this chunk
        print('%s (%d %d%%) %.4f' % (timeSince(start), iteration, iteration / n_iters * 100, loss))

plt.plot(all_losses)
plt.ylabel('Loss')
plt.xlabel('time');

In [None]:
# Load CartPole-v0 environment

from itertools import count
import gym

# NOTE(review): `.unwrapped` presumably returns the bare environment without
# gym's wrappers (including any episode step limit) — confirm against the
# installed gym version; if so, an episode only ends when `done` is reported.
env = gym.make('CartPole-v0').unwrapped

def preprocess_state(state):
    """
    Converts an environment observation into network input
    :param state: sequence of 4 numbers
    :returns: `Variable` wrapping a float tensor of shape (1, 4)
    """
    as_tensor = torch.Tensor(state)
    return Variable(as_tensor.view(1, 4))

In [None]:
# FFN Evaluation: run `ffn` greedily in `env` for `n_epoch` episodes and
# plot how many steps each episode lasted.

n_epoch = 50

durations = []

for epoch in range(n_epoch):
    # Get initial state
    state = env.reset()
    state = preprocess_state(state)

    for t in count():
        output = ffn(state)
        # Index of the highest-scoring action. `int(...)` guarantees a plain
        # Python int for `env.step`, since indexing may yield a 0-dim tensor.
        action = int(output.data.topk(1)[1][0][0])

        state, _, done, _ = env.step(action)
        state = preprocess_state(state)

        # Same step cap as the RNN evaluation loop: without it a policy that
        # never fails would keep this loop running forever.
        if done or t > 500:
            durations.append(t)
            break

plt.plot(durations)
plt.ylabel('Durations')
plt.xlabel('Iteration');

# Recurrent Neural Network definition

In [None]:
# Define our simple recurrent network
# and try to learn supervised expert's policy

class RNN(nn.Module):
    """
    Simple recurrent cell: the input is concatenated with the previous hidden
    state, and two linear layers produce the next hidden state and the
    log-probabilities of the outputs
    """

    def __init__(self, input_size, hidden_size, output_size):
        """
        :param input_size: number of input features
        :param hidden_size: size of the hidden state
        :param output_size: number of output classes
        """
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        self.input_size = input_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        # Explicit `dim=1`: normalize over the class dimension of the 2-D
        # output instead of relying on the deprecated implicit choice.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inp, hidden):
        """
        :param inp: input of shape (batch, input_size)
        :param hidden: previous hidden state of shape (batch, hidden_size)
        :returns: log-probabilities of shape (batch, output_size) and the next hidden state
        """
        combined = torch.cat((inp, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        """
        :returns: zero hidden state of shape (1, hidden_size)
        """
        return Variable(torch.zeros(1, self.hidden_size))

    def initPrev(self):
        """
        :returns: zero tensor of shape (1, input_size)
        """
        return Variable(torch.zeros(1, self.input_size))

# Network instance used below: 4 inputs, hidden state of size 8, 2 outputs
rnn = RNN(4, 8, 2)

# Everything is ok
# Smoke test: sample a batch and run a single forward step of `rnn` on its
# first state, starting from the zero hidden state.
states, actions = training_batch()
hidden = rnn.initHidden()

output, hidden = rnn(states[0], hidden)

In [None]:
# Negative log-likelihood loss, applied to the log-probabilities returned by `rnn`.
# NOTE: this rebinds the module-level `criterion` that `train_ffn` also reads, so
# calling `train_ffn` after this cell would use NLLLoss instead of CrossEntropyLoss.
criterion = nn.NLLLoss()
# Adam optimizer over the parameters of `rnn`; stepped once per chunk in `train_rnn`
optimizer = optim.Adam(rnn.parameters(), lr=0.01)

def train_rnn(states_tensor, actions_tensor):
    """
    Trains `rnn` on one chunk: the loss is summed over every step of the chunk
    and a single optimizer step is taken after backpropagating through the whole pass
    :param states_tensor: states indexed along dim 0; `states_tensor[i]` is fed to `rnn`
    :param actions_tensor: expert actions (class indices) aligned with `states_tensor`
    :returns: output of `rnn` for the last step and the mean loss per step
    :raises ValueError: if `states_tensor` contains no steps
    """
    n_steps = states_tensor.size(0)
    if n_steps == 0:
        raise ValueError('states_tensor must contain at least one step')

    hidden = rnn.initHidden()                           # Initialize hidden state of `rnn`
    optimizer.zero_grad()

    loss = 0

    for i in range(n_steps):                            # Pass over whole batch
        output, hidden = rnn(states_tensor[i], hidden)  # Predict action and get next hidden state
        # `.view(-1)` keeps the target 1-D even when indexing yields a 0-dim
        # tensor (PyTorch >= 0.4), matching the single-row `output`.
        target = actions_tensor[i].view(-1)
        loss += criterion(output, target)               # Compute loss at current step

    loss.backward()                                     # Error backpropagation through the whole pass
    optimizer.step()                                    # Update parameters

    # `.item()` replaces `loss.data[0]`, which fails on 0-dim losses
    return output, loss.item() / n_steps

In [None]:
# Train loop: fit `rnn` on `n_iters` random chunks of 50 transitions each
# and plot the mean loss per chunk.

n_iters = 700
print_every = 50

all_losses = []
start = time.time()

# The loop variable is `iteration` rather than `iter`: at notebook scope a
# variable called `iter` would shadow the builtin for every later cell.
for iteration in range(1, n_iters + 1):
    output, loss = train_rnn(*training_batch(50))
    all_losses.append(loss)

    if iteration % print_every == 0:
        # elapsed time, iteration number, percent complete, loss of this chunk
        print('%s (%d %d%%) %.4f' % (timeSince(start), iteration, iteration / n_iters * 100, loss))

plt.plot(all_losses);

In [None]:
# Load CartPole-v0 environment

from itertools import count
import gym

# NOTE(review): `.unwrapped` presumably returns the bare environment without
# gym's wrappers (including any episode step limit) — confirm against the
# installed gym version; if so, an episode only ends when `done` is reported.
env = gym.make('CartPole-v0').unwrapped

def preprocess_state(state):
    """
    Converts an environment observation into network input
    :param state: sequence of 4 numbers
    :returns: `Variable` wrapping a float tensor of shape (1, 4)
    """
    as_tensor = torch.Tensor(state)
    return Variable(as_tensor.view(1, 4))

In [None]:
# RNN Evaluation: run `rnn` greedily in `env` for `n_epoch` episodes and
# plot how many steps each episode lasted.

n_epoch = 10

durations = []

for epoch in range(n_epoch):
    # Get initial state
    state = env.reset()
    state = preprocess_state(state)
    hidden = rnn.initHidden()                 # Fresh hidden state for every episode

    for t in count():
        output, hidden = rnn(state, hidden)   # Predict action and get next hidden state
        # Index of the most likely action. `int(...)` guarantees a plain
        # Python int for `env.step`, since indexing may yield a 0-dim tensor.
        action = int(output.data.topk(1)[1][0][0])

        state, _, done, _ = env.step(action)  # Make step
        state = preprocess_state(state)

        # Stop on failure, or cap the episode so a policy that never fails
        # cannot keep this loop running forever.
        if done or t > 500:
            durations.append(t)
            break

plt.plot(durations)
plt.ylabel('Durations')
plt.xlabel('Iteration');