In [5]:
import torch
import torch.nn as nn

import transformers

from decision_transformer.models.model import TrajectoryModel
from decision_transformer.models.trajectory_gpt2 import GPT2Model
import math
import numpy as np
import torch.nn.functional as F
from torch import distributions as pyd

class LinearTransform(pyd.transforms.Transform):
    bijective = True
    sign = +1

    def __init__(self, value_range=(0.0, 5.0), cache_size=1):
        super().__init__(cache_size=cache_size)
        self.min_val, self.max_val = value_range
        self.scale = 2 / (self.max_val - self.min_val)
        self.offset = -1 - self.min_val * self.scale
        self.domain = pyd.constraints.interval(self.min_val, self.max_val)
        self.codomain = pyd.constraints.interval(-1.0, 1.0)
        
    def _call(self, x):
        # 입력을 [min_val, max_val]에서 [-1, 1]로 스케일링
        return self.scale * x + self.offset

    def _inverse(self, y):
        # [-1, 1]에서 [min_val, max_val]로 역변환
        return (y - self.offset) / self.scale

    def log_abs_det_jacobian(self, x, y):
        # 변환은 선형이므로 Jacobian은 상수임
        return math.log(abs(self.scale))

class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # one should use `cache_size=1` instead
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7
        return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """
    Squashed Normal Distribution(s)

    If loc/std is of size (batch_size, sequence length, d),
    this returns batch_size * sequence length * d
    independent squashed univariate normal distributions.
    """

    def __init__(self, loc, std, transform_type = 'tanh', value_range = [-1.0, 1.0]):
        self.loc = loc
        self.std = std
        self.base_dist = pyd.Normal(loc, std)

        if transform_type == 'tanh':
            transforms = [TanhTransform()]
        elif transform_type == 'linear':
            print("linear transform")
            transforms = [LinearTransform(value_range=value_range)]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu

    def entropy(self, N=1):
        # sample from the distribution and then compute
        # the empirical entropy:
        x = self.rsample((N,))
        log_p = self.log_prob(x)

        # log_p: (batch_size, context_len, action_dim),
        return -log_p.mean(axis=0).sum(axis=2)

    def log_likelihood(self, x):
        # log_prob(x): (batch_size, context_len, action_dim)
        # sum up along the action dimensions
        # Return tensor shape: (batch_size, context_len)

        transformed_x = self.transforms[0]._inverse(x)  # LinearTransform의 역변환 적용
        return self.base_dist.log_prob(transformed_x).sum(axis=2)
        return self.log_prob(x).sum(axis=2)



class DiagGaussianActor(nn.Module):
    """torch.distributions implementation of an diagonal Gaussian policy."""

    def __init__(self, hidden_dim, act_dim, log_std_bounds=[-5.0, 2.0], transform_type = 'tanh', value_range = [-1.0, 1.0]):
        super().__init__()

        self.mu = torch.nn.Linear(hidden_dim, act_dim)
        self.log_std = torch.nn.Linear(hidden_dim, act_dim)
        self.log_std_bounds = log_std_bounds
        self.transform_type = transform_type
        self.value_range = value_range
        def weight_init(m):
            """Custom weight init for Conv2D and Linear layers."""
            if isinstance(m, torch.nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if hasattr(m.bias, "data"):
                    m.bias.data.fill_(0.0)

        self.apply(weight_init)

    def forward(self, obs):
        mu, log_std = self.mu(obs), self.log_std(obs)
        log_std = torch.tanh(log_std)
        # log_std is the output of tanh so it will be between [-1, 1]
        # map it to be between [log_std_min, log_std_max]
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1.0)
        std = log_std.exp()
        return SquashedNormal(mu, std, self.transform_type, value_range = self.value_range)

predict_state = DiagGaussianActor(512, 4, transform_type='linear', value_range=[0, 1.0])
obs = torch.randn(256,1, 512)
predict_s = predict_state(obs)  
predict_s.log_likelihood(torch.ones(1, 4)*2)

linear transform


tensor([[-4.4932e+02],
        [-9.9151e+03],
        [-1.9761e+04],
        [-7.7693e+00],
        [-3.6626e+03],
        [-1.2410e+03],
        [-1.0183e+01],
        [-2.4498e+04],
        [-1.2754e+03],
        [-6.9986e+03],
        [-1.5124e+04],
        [-1.7210e+01],
        [-3.7716e+02],
        [-4.1022e+03],
        [-5.4004e+02],
        [-1.0658e+02],
        [-1.9554e+02],
        [-2.2658e+02],
        [-3.9010e+03],
        [-6.3951e+03],
        [-9.8746e+02],
        [-3.4608e+04],
        [-1.1668e+04],
        [-2.8650e+03],
        [-2.4699e+04],
        [-1.1391e+04],
        [-3.2456e+03],
        [-2.0190e+04],
        [-2.5150e+04],
        [-1.4178e+04],
        [-3.8264e+04],
        [-8.1793e+03],
        [-3.0649e+03],
        [-1.3716e+03],
        [-2.5247e+02],
        [-1.1431e+03],
        [-4.1997e+03],
        [-7.2035e+02],
        [-2.8083e+02],
        [-5.0875e+04],
        [-4.9362e+03],
        [-1.0401e+03],
        [-3.9076e+02],
        [-9

torch.Size([1, 4])
tensor([[1.1000, 1.1000, 1.1000, 1.1000]])


ValueError: The value argument must be within the support