In [17]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import tqdm
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
import matplotlib.pyplot as plt
from Behaviour_pi import Behavioural, Policy
import numpy as np

In [18]:
# KL between 2D Bernoulli probabilities for one batch
def bernoulli_kl(q, p, eps=1e-8):
    p = torch.clamp(p, eps, 1 - eps)
    q = torch.clamp(q, eps, 1 - eps)
    return (p * torch.log(p / q) + (1 - p) * torch.log((1 - p) / (1 - q))).sum(dim=-1)


In [19]:
# Compute KL between two MLP policies over a batch of states
def kl_between_policies(pi_perturbed, pi, states):
    with torch.no_grad():
        probs_pi = pi(states)
        probs_pert = pi_perturbed(states)
        kl_vals = bernoulli_kl(probs_pert, probs_pi)
    return kl_vals.mean()


In [4]:
# Compute KL over multiple perturbed models
def kl_over_perturbed_set(pi, perturbed_pis, states):
    kl_results = []
    for i, perturbed_pi in enumerate(perturbed_pis):
        kl_val = kl_between_policies(perturbed_pi, pi, states)
        kl_results.append(kl_val.item())
        print(f"KL(pi_perturbed_{i}|| pi) = {kl_val:.6f}")
    return kl_results


In [None]:
# Compute KL for single perturbed models
# def kl_over_perturbed_set(perturbed_pi, pi, states):
#     kl_val = kl_between_policies(perturbed_pi, pi, states)
#     return kl_val

In [20]:
df = pd.read_csv("./traj.csv")  # or use your DataFrame directly

states = torch.tensor(df.iloc[:60000, :5].values, dtype=torch.float32)
behavior_pi = Behavioural()
behavior_pi.load_state_dict(torch.load("./Behavioural_model_2.pth"))
pi = torch.load("./Policys/Perturbed_model_2.pth")
t = 0
kl_score_ls = []
for i in range(100):
    pi = torch.load(f"./Policys/Perturbed_model_{i}.pth")
    kl_scores = kl_between_policies(pi, behavior_pi, states)
    kl_score_ls.append(kl_scores.item())
    print(kl_scores)
    # if kl_scores < 0.3:
    #     t += 1
np.savetxt("./Test/kl_diver.txt", kl_score_ls)
# print(t)

# states = torch.randn(100, 5)  # input_dim = 5

# Suppose this is your original policy
# pi = Behavioural(input_dim=5, hidden_dim=64, output_dim=2)

# # Create N perturbed copies
# perturbed_pis = [
#     get_multiplicatively_perturbed_copy(pi, omega=0.02)
#     for _ in range(10)
# ]

# Compute KL divergence for each
# kl_scores = kl_between_policies(pi, behavior_pi, states)
# kl_scores


tensor(0.3342)
tensor(1.3984)
tensor(0.4517)
tensor(0.9462)
tensor(1.0949)
tensor(0.6914)
tensor(0.1616)
tensor(0.8555)
tensor(2.1016)
tensor(1.4483)
tensor(0.6429)
tensor(0.4607)
tensor(0.7295)
tensor(0.6676)
tensor(0.6063)
tensor(0.4905)
tensor(1.6391)
tensor(0.4566)
tensor(1.5285)
tensor(0.3781)
tensor(1.1725)
tensor(0.3113)
tensor(0.3304)
tensor(0.8087)
tensor(0.2493)
tensor(1.8349)
tensor(0.5938)
tensor(0.8975)
tensor(0.6358)
tensor(0.5642)
tensor(0.2502)
tensor(0.4852)
tensor(0.3306)
tensor(1.1133)
tensor(0.3592)
tensor(0.5447)
tensor(0.9337)
tensor(0.3455)
tensor(0.3058)
tensor(1.1078)
tensor(1.3052)
tensor(0.3511)
tensor(0.4927)
tensor(0.4710)
tensor(0.1862)
tensor(0.4161)
tensor(0.5183)
tensor(0.6850)
tensor(0.5280)
tensor(0.7506)
tensor(0.7821)
tensor(0.4182)
tensor(0.5157)
tensor(inf)
tensor(0.6555)
tensor(0.7821)
tensor(0.4213)
tensor(0.1785)
tensor(0.4118)
tensor(inf)
tensor(0.8901)
tensor(0.6524)
tensor(0.2094)
tensor(0.2055)
tensor(0.2296)
tensor(0.9009)
tensor(0.7924)
t