# Imports

In [1]:
import os
import numpy as np

from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from baselines import bench

from utils.gym_wrapper import WrapPyTorch
from utils.agent import agent
from utils.config import *
import torch
import torch.nn.functional as F

from PIL import Image
import numpy as np

# Simulate Teacher Data

In [2]:
env_id = "BreakoutNoFrameskip-v4"
env = make_atari(env_id)
env = bench.Monitor(env, os.path.join("log", env_id))
env = wrap_deepmind(env, episode_life=True, clip_rewards=True, frame_stack=False, scale=True)
env = WrapPyTorch(env)
teacher = agent(env=env)
teacher.load_w()



In [3]:
MAX_TEACHER_SIMULATIONS = 5000

In [4]:
teacher_observation = []
list_of_states = []
list_of_q_values = []
observation = env.reset()

for frame_idx in range(1, MAX_TEACHER_SIMULATIONS + 1):
    teacher_observation.append(observation)
    X = torch.tensor([observation], device=device, dtype=torch.float)
    Q = teacher.model(X).data
    list_of_states.append(X)
    list_of_q_values.append(Q)
    action = teacher.get_action(observation, 0.01)

    prev_observation=observation
    observation, _, done, _ = env.step(action)
    observation = None if done else observation

    if done:
        observation = env.reset()

# Save Teacher Simulation For Create Videos

In [6]:
# save teacher simulation in TEACHER FOLDER
for index,obs in enumerate(teacher_observation):
    #Rescale to 0-255 and convert to uint8
    data = obs[0]
    rescaled = (255.0 / data.max() * (data - data.min())).astype(np.uint8)

    im = Image.fromarray(rescaled)
    im.save('TEACHER/%s.png'%str(index))

# Create Tensor From states and q_values

In [5]:
Y = torch.Tensor(MAX_TEACHER_SIMULATIONS,1,4).cuda()
torch.cat(list_of_q_values, out=Y)

tensor([[2.7094, 2.7138, 2.7128, 2.7907],
        [2.7149, 2.7421, 2.7498, 2.7578],
        [2.6924, 2.7157, 2.7606, 2.7533],
        ...,
        [2.4868, 2.4219, 2.4703, 2.4487],
        [2.4538, 2.3915, 2.4308, 2.4219],
        [2.4762, 2.4093, 2.4672, 2.4565]], device='cuda:0')

In [6]:
inputs = torch.Tensor(MAX_TEACHER_SIMULATIONS, 1,1,84,84).cuda()
torch.cat(list_of_states, out=inputs)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0.

In [7]:
x = torch.squeeze(inputs,1).view(MAX_TEACHER_SIMULATIONS,84*84)

In [8]:
x.shape

torch.Size([5000, 7056])

In [9]:
Y.shape

torch.Size([5000, 4])

# Create MSE Student

In [12]:
class MSE_Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(MSE_Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        return x
    
class MSE_Student():
    def __init__(self, net, device):
        self.net = net
        self.device=device
        self.net = self.net.to(device)
        self.optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
        self.loss_func = torch.nn.MSELoss()
        
    def train(self,x,Y,epoch=1000):
        for t in range(epoch):
            prediction = self.net(x)     # input x and predict based on x

            loss = self.loss_func(prediction, Y)     # must be (1. nn output, 2. target)

            self.optimizer.zero_grad()   # clear gradients for next train
            loss.backward(retain_graph=True)    # backpropagation, compute gradients
            self.optimizer.step()  # apply gradients
            
    def get_action(self,s,num_actions, eps=0.1):
        with torch.no_grad():
            if np.random.random() >= eps:
                x_obs = torch.tensor([s], device=device, dtype=torch.float).view(-1,84*84)
                a = self.net(x_obs).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, num_actions)
            
    def save_video(self,env,max_frames):
        observation = env.reset()
        for frame_idx in range(1, max_frames + 1):
            data_for_saved = observation[0]
            im = Image.fromarray((255.0 / data_for_saved.max() * (data_for_saved - data_for_saved.min())).astype(np.uint8))
            
            im.save('STUDENT_MSE/%s.png'%str(frame_idx))
            action = self.get_action(observation,env.action_space.n)

            observation, reward, done, _ = env.step(action)
            observation = None if done else observation

            if done:
                observation = env.reset()

In [13]:
mse_net = MSE_Net(n_feature=84*84, n_hidden=500, n_output=4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [14]:
mse_student = MSE_Student(mse_net,device)
mse_student.train(x,Y,epoch=100)

In [15]:
mse_student.save_video(env,max_frames=500)

In [16]:
mse_net(x)

tensor([[2.0373, 2.0245, 2.0213, 2.0390],
        [2.0359, 2.0282, 2.0194, 2.0394],
        [2.0316, 2.0233, 2.0099, 2.0295],
        ...,
        [2.0322, 2.0286, 2.0131, 2.0355],
        [2.0306, 2.0241, 2.0174, 2.0337],
        [2.0280, 2.0246, 2.0124, 2.0351]], device='cuda:0',
       grad_fn=<AddmmBackward>)

# Create Distilled_KL Student

In [45]:
class Distilled_KL_Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Distilled_KL_Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.predict = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x, eps=0.00001):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.predict(x)             # linear output
        x = F.softmax(x)+eps
        return x
    
class Distilled_KL_Student():
    def __init__(self, net, device):
        self.net = net
        self.device=device
        self.net = self.net.to(device)
        self.optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
        self.loss_func = torch.nn.KLDivLoss()
        
    def train(self,x,Y,epoch=1000):
        for t in range(epoch):
            prediction = self.net(x)     # input x and predict based on x

            loss = self.loss_func(prediction, Y)     # must be (1. nn output, 2. target)

            self.optimizer.zero_grad()   # clear gradients for next train
            loss.backward(retain_graph=True)    # backpropagation, compute gradients
            self.optimizer.step()  # apply gradients
            
    def get_action(self,s,num_actions, eps=0.1):
        with torch.no_grad():
            if np.random.random() >= eps:
                x_obs = torch.tensor([s], device=device, dtype=torch.float).view(-1,84*84)
                a = self.net(x_obs).max(1)[1].view(1, 1)
                return a.item()
            else:
                return np.random.randint(0, num_actions)
            
    def save_video(self,env,max_frames):
        observation = env.reset()
        for frame_idx in range(1, max_frames + 1):
            data_for_saved = observation[0]
            im = Image.fromarray((255.0 / data_for_saved.max() * (data_for_saved - data_for_saved.min())).astype(np.uint8))
            
            im.save('STUDENT_DISTILLED_KL/%s.png'%str(frame_idx))
            action = self.get_action(observation,env.action_space.n)

            observation, reward, done, _ = env.step(action)
            observation = None if done else observation

            if done:
                observation = env.reset()

In [46]:
distilled_kl_net = Distilled_KL_Net(n_feature=84*84, n_hidden=500, n_output=4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
tau = 0.01
eps = 0.00001
Y_kl = F.softmax(Y/tau)+eps 

  This is separate from the ipykernel package so we can avoid doing imports until


In [48]:
distilled_kl_student = Distilled_KL_Student(distilled_kl_net,device)
distilled_kl_student.train(x,Y_kl,epoch=100)

  # Remove the CWD from sys.path while we load stuff.


In [50]:
distilled_kl_student.save_video(env,max_frames=500)

  # Remove the CWD from sys.path while we load stuff.


In [51]:
distilled_kl_net(x)

  # Remove the CWD from sys.path while we load stuff.


tensor([[0.1188, 0.0867, 0.3074, 0.4871],
        [0.1191, 0.0868, 0.3060, 0.4882],
        [0.1200, 0.0872, 0.3056, 0.4873],
        ...,
        [0.1254, 0.0877, 0.3026, 0.4843],
        [0.1250, 0.0876, 0.3022, 0.4853],
        [0.1256, 0.0878, 0.3021, 0.4845]], device='cuda:0',
       grad_fn=<AddBackward0>)