In [3]:
import math
import random
import pdb
import copy
from typing import *

import collections as cc
import sortedcontainers as sc
import itertools as it
import functools as ft

import einops as eo
import scipy as sp
import numpy as np
import numpy.random as npr
import sklearn as skl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sb
import torch as tc
import torch.nn as tcn
import torch.nn.functional as tcf
import torch.optim as tco
import torch.distributions as tcd
import einops.layers.torch as eol
from torch.utils.data import TensorDataset, DataLoader

pd.options.display.max_rows = 40
pd.options.display.min_rows = 20
pd.options.display.max_columns = 100

from IPython.display import display, HTML, clear_output

%matplotlib inline

import gym
import gym.spaces

## Global Configs

In [70]:
class Config:
	def __init__(self, _epoch = 0, _debug = True):
		self.maxSteps = 32
		self.lamb = 0.9
		self.miniBatchSize = 64
		self.delayBatchs = 128
		self.totalEpochs = 50
		self.testEpochs = 4
		self.initEpochs = 20
		self.epoch = _epoch
		self.lr = 1E-3
		self.initlr = 3E-3
		self.modelPath = "./checkpoints/model"
		self.debug = _debug
		self.device = tc.device("cpu") if _debug or not tc.cuda.is_available() else tc.device("cuda")
		self.mapName = "4x4"
		self.stateEstimationWeight = 0.5
		self.trainingRandomInitState = True
		print(self)
	
	def stepCheckpoint(self, _epoch = None):
		if _epoch is not None:
			self.epoch = _epoch
		return "{path}_{epoch:02d}.bin".format(path = self.modelPath, epoch = self.epoch)


In [55]:
class ReplayEnvironment:
	def __init__(self, config: Config):
		self.name = "FrozenLake-v1"
		self.instance = gym.make(self.name, map_name = config.mapName)
		self.actionSpace = self.instance.action_space.n
		self.stateSpace = self.instance.observation_space.n
		self.config = config
		self.transProbs, self.transStates, self.transRewards, self.transEnds = self._convertTransMat(3, self.instance.P)
		self.instance.render()
	
	def __str__(self):
		return "ReplayEnvironment: " + self.name
	
	# (prob, nstate, rewards, end)
	@tc.no_grad()
	def _convertTransMat(self, probSpace: int, transDict: Dict[int, Dict[int, List[Tuple[float, int, float, bool]]]]) -> Tuple[tc.FloatTensor, tc.IntTensor, tc.FloatTensor, tc.BoolTensor]:
		self.probSpace = probSpace
		weights_ = tc.zeros((self.stateSpace, self.actionSpace, probSpace), dtype = tc.float)
		rewards_ = tc.zeros_like(weights_)
		nstates_ = tc.zeros((self.stateSpace, self.actionSpace, probSpace), dtype = tc.int)
		ends_ = tc.zeros((self.stateSpace, self.actionSpace, probSpace), dtype = tc.bool)

		for k1, s1 in transDict.items():
			for k2, s2 in s1.items():
				weights_[k1, k2] = tc.as_tensor([p[0] for p in it.islice(s2, probSpace)], dtype = tc.float)
				rewards_[k1, k2] = tc.as_tensor([p[2] for p in it.islice(s2, probSpace)], dtype = tc.float)
				nstates_[k1, k2] = tc.as_tensor([p[1] for p in it.islice(s2, probSpace)], dtype = tc.int)
				ends_[k1, k2] = tc.as_tensor([p[3] for p in it.islice(s2, probSpace)], dtype = tc.bool)
		
		return weights_, nstates_, rewards_, ends_
	
	# (newState, reward, done)
	@tc.no_grad()
	def step(self, states: tc.IntTensor, actions: tc.IntTensor) -> Tuple[tc.Tensor, tc.Tensor, tc.Tensor]:
		actions = actions.unsqueeze(1).unsqueeze(2).expand((-1, -1, self.probSpace))
		states = states.long()
		weights_ = self.transProbs[states].take_along_dim(actions, dim = 1).squeeze(dim = 1)
		rewards_ = self.transRewards[states].take_along_dim(actions, dim = 1).squeeze(dim = 1)
		nstates_ = self.transStates[states].take_along_dim(actions, dim = 1).squeeze(dim = 1)
		ends_ = self.transEnds[states].take_along_dim(actions, dim = 1).squeeze(dim = 1)

		idxes_ = tc.multinomial(weights_, 1, True)
		return nstates_.take_along_dim(idxes_, dim = 1).squeeze(dim = 1), rewards_.take_along_dim(idxes_, dim = 1).squeeze(dim = 1), ends_.take_along_dim(idxes_, dim = 1).squeeze(dim = 1)
			

In [16]:
class ActorModel(tcn.Module):
	def __init__(self, env: ReplayEnvironment):
		super().__init__()

		stateSize_ = env.stateSpace
		actionSize_ = env.actionSpace

		self.baseModel = tcn.Sequential(
			tcn.Linear(stateSize_, stateSize_ * 2),
			tcn.GELU(),
			tcn.Linear(stateSize_ * 2, stateSize_ * 2),
			tcn.GELU()
		)

		self.actionHead = tcn.Sequential(
			tcn.Linear(stateSize_ * 2, stateSize_ * 2),
			tcn.GELU(),
			tcn.Linear(stateSize_ * 2, actionSize_),
		)
		
		self.valueHead = tcn.Linear(stateSize_ * 2, 1)
		self.stateSize = stateSize_
	
	def decayParameters(self):
		return map(lambda x: x[1], filter(lambda x: "bias" not in x[0], self.named_parameters()))

	def nondecayParameters(self):
		return map(lambda x: x[1], filter(lambda x: "bias" in x[0], self.named_parameters()))
	
	def save(self, path):
		tc.save(self.state_dict(), path)

	def load(self, path):
		self.load_state_dict(tc.load(path))

	def forward(self, states: tc.Tensor, invtemp: float = 1.0) -> tc.Tensor:
		baseOutput_ = self.baseModel(tcf.one_hot(states.long(), self.stateSize).float())
		logits_ = self.actionHead(baseOutput_)
		actionDist_ = tcd.Categorical(tcf.softmax(logits_ * invtemp, dim = 1))
		stateValue_ = self.valueHead(baseOutput_).squeeze()
		return actionDist_, stateValue_

In [32]:
def initializePolicy(model: ActorModel, env: ReplayEnvironment, config: Config):
	opt_ = tco.AdamW([{"params": model.decayParameters(), "weight_decay": 0.01}, {"params": model.nondecayParameters(), "weight_decay": 0.0}], config.initlr)
	states_ = tc.arange(0, env.stateSpace, 1, dtype = tc.int)
	model.eval()
	with tc.no_grad():
		_, target_ = model(states_)
		target_ -= target_.min()
		target_ /= target_.max()

	model.train()
	for i in range(config.initEpochs):
		actionDists_, preds_ = model(states_)
		loss_ = tcf.mse_loss(preds_, target_)
		opt_.zero_grad()
		loss_.backward()
		opt_.step()
	print("initialization loss: {0}".format(loss_.item()))
	

In [43]:
# returns [states, values, actions, advantages]
@tc.no_grad()
def generatePaths(model: ActorModel, env: ReplayEnvironment, config: Config):
	model.eval()

	batchSize_ = config.miniBatchSize * config.delayBatchs

	if config.trainingRandomInitState:
		state_ = tc.randint(env.stateSpace, (batchSize_,))
	else:
		state_ = tc.zeros(batchSize_, dtype = tc.int)
	
	end_ = tc.zeros(batchSize_, dtype = tc.bool)
	
	states_ = list()
	values_ = list()
	advantages_ = list()
	actions_ = list()
	masks_ = list()
	value_ = None
	adv_ = None

	for s in range(config.maxSteps):
		states_.append(state_)
		masks_.append(tc.logical_not(end_))
		actDists_, estims_ = model(state_)
		acts_ = actDists_.sample()
		
		actions_.append(acts_)
		decayedEstims_ = estims_ * config.lamb
		if value_ is not None:
			value_ += decayedEstims_
		if adv_ is not None:
			adv_ += decayedEstims_
		
		state_, value_, end_ = env.step(state_, acts_)
		
		# generate "partial" value and advantage for the current
		values_.append(value_)
		adv_ = value_ - estims_
		advantages_.append(adv_)

	estims_ = model(state_)[1] * config.lamb
	value_ += estims_
	adv_ += estims_
	
	idx_ = tc.cat(masks_).to(tc.bool)
	return TensorDataset(tc.cat(states_)[idx_], tc.cat(values_)[idx_], tc.cat(actions_)[idx_], tc.cat(advantages_)[idx_])



In [9]:
def testModel(env: ReplayEnvironment, model: ActorModel, config: Config):
	paths_ = generatePaths(model, env, config)


In [61]:
def trainModel(env: ReplayEnvironment, model: ActorModel, config: Config):
	opt_ = tco.AdamW([{"params": model.decayParameters(), "weight_decay": 0.01}, {"params": model.nondecayParameters(), "weight_decay": 0.0}], config.lr)
	print("begin epoch {0}, path generation... ".format(config.epoch))
	
	dataloader_ = DataLoader(generatePaths(model, env, config), config.miniBatchSize, shuffle = True, drop_last= True)
	print("begin epoch {0}, training... ".format(config.epoch))
	
	model.train()
	losses_ = list()
	for batches_, miniBatch_ in enumerate(dataloader_):
		states_, values_, actions_, advantages_ = miniBatch_
		assert states_.shape[0] == values_.shape[0] == actions_.shape[0] == advantages_.shape[0] == config.miniBatchSize, "minibatch size doesn't match required {0}".format(config.miniBatchSize)
		

		actionDists_, pvals_ = model(states_)
		
		loss1_ = -tc.mean(actionDists_.log_prob(actions_) * advantages_)
		loss2_ = tcf.mse_loss(pvals_, values_)
		loss_ = loss1_ + loss2_ * config.stateEstimationWeight

		opt_.zero_grad()
		loss_.backward()
		opt_.step()
		
		losses_.append(loss_.item())

	print("finish epoch {0} with mean loss {1}".format(config.epoch, np.mean(losses_)))
	return losses_

In [63]:
def trainAll(env: ReplayEnvironment, config: Config) -> ActorModel:
	model_ = ActorModel(env)
	epoch_ = config.epoch
	
	if epoch_ > 0:
		print("loading checkpoint {0}".format(epoch_))
		model_.load(config.stepCheckpoint())
	else:
		initializePolicy(model_, env, config)
	
	for t in range(epoch_, config.totalEpochs):
		print("start training epoch {0}".format(t))
		path_ = config.stepCheckpoint(t)
		loss_ = trainModel(env, model_, config)
		print("finish training, save to path {0}".format(path_))
		#tc.save(model_.state_dict(), path_)

		print("finish training epoch {0}".format(t))
	
	print("finish all training steps")
	return model_

In [72]:
def main():
	config_ = Config(0, False)
	env_ = ReplayEnvironment(config_)
	model_ = trainAll(env_, config_)
	_, est_ = model_(tc.arange(0, env_.stateSpace))
	print(est_)
	
if __name__ == "__main__":
	main()

<__main__.Config object at 0x0000021048620BE0>

[41mS[0mFFF
FHFH
FFFH
HFFG
initialization loss: 0.053715143352746964
start training epoch 0
begin epoch 0, path generation... 
begin epoch 0, training... 
finish epoch 0 with mean loss -0.31609181969951783
finish training, save to path ./checkpoints/model_00.bin
finish training epoch 0
start training epoch 1
begin epoch 1, path generation... 
begin epoch 1, training... 
finish epoch 1 with mean loss -0.24551071100462432
finish training, save to path ./checkpoints/model_01.bin
finish training epoch 1
start training epoch 2
begin epoch 2, path generation... 
begin epoch 2, training... 
finish epoch 2 with mean loss -0.4038170665886862
finish training, save to path ./checkpoints/model_02.bin
finish training epoch 2
start training epoch 3
begin epoch 3, path generation... 
begin epoch 3, training... 
finish epoch 3 with mean loss -0.24417520223051212
finish training, save to path ./checkpoints/model_03.bin
finish training epoch 3
start trai