In [None]:
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
from matplotlib import pyplot as plt

def create_gaussian_condition_distribution(data, prior=True, data2=None, given=None):
    """Fit a joint Gaussian to paired samples (data, data2) and return p(data | data2).

    The blocks of the joint covariance are *sample* covariances (normalised by
    D - 1, the same convention as ``torch.cov``) and the conditional uses the
    standard Gaussian formulas::

        mu_{1|2}    = mu_1 + S_12 S_22^{-1} (x_2 - mu_2)
        Sigma_{1|2} = S_11 - S_12 S_22^{-1} S_21

    Args:
        data: 2-D tensor of shape (D, d1); rows are samples of the variable
            whose conditional distribution is returned.
        prior: if True, ``data2`` is replaced by D fresh samples from
            N(0, I_{d1}) (any ``data2`` passed in is ignored).
        data2: 2-D tensor of shape (D, d2) with samples paired row-by-row with
            ``data``. Required when ``prior`` is False.
        given: optional value x_2 of shape (d2,) to condition on. When None,
            the correction term is averaged over the rows of ``data2``; because
            those rows are centred this is zero up to round-off, so the mean is
            (numerically) ``data.mean(0)`` -- the original behaviour.

    Returns:
        MultivariateNormal over d1 dimensions.

    Raises:
        ValueError: if ``data`` is not 2-D with at least 2 rows, or if
            ``prior`` is False and ``data2`` is missing or not row-aligned.
    """
    if data.ndim != 2 or data.shape[0] < 2:
        raise ValueError("data must be a 2-D tensor with at least 2 rows (samples)")
    num_samples, state_dim = data.shape

    if prior:
        # Sample in data's dtype so the matmuls below do not mix float32/float64.
        init_dist = MultivariateNormal(
            torch.zeros(state_dim, dtype=data.dtype),
            torch.eye(state_dim, dtype=data.dtype),
        )
        data2 = init_dist.sample((num_samples,))
    elif data2 is None or data2.ndim != 2 or data2.shape[0] != num_samples:
        raise ValueError(
            "with prior=False, data2 must be a 2-D tensor with the same number of rows as data"
        )

    mean_data = data.mean(dim=0)
    mean_data2 = data2.mean(dim=0)
    diffs_data = data - mean_data
    diffs_init = data2 - mean_data2

    # Normalise the scatter matrices; without this the conditional covariance
    # is inflated by a factor of (D - 1).
    norm = num_samples - 1
    cov_11 = diffs_data.T @ diffs_data / norm
    cov_22 = diffs_init.T @ diffs_init / norm
    cov_21 = diffs_init.T @ diffs_data / norm

    # gain = S_12 S_22^{-1}; S_22 is symmetric, so (S_22^{-1} S_21)^T equals it.
    # A linear solve is numerically preferable to forming the explicit inverse.
    gain = torch.linalg.solve(cov_22, cov_21).T

    if given is None:
        offset = diffs_init.mean(dim=0)  # ~0: diffs_init is centred
    else:
        offset = given - mean_data2
    conditional_mean = mean_data + gain @ offset

    conditional_cov = cov_11 - gain @ cov_21
    # Remove round-off asymmetry before MultivariateNormal's positive-definite check.
    conditional_cov = 0.5 * (conditional_cov + conditional_cov.T)
    return MultivariateNormal(conditional_mean, conditional_cov)

# Fit the conditional Gaussian to the Iris feature matrix (converted to float32).
# NOTE(review): `load_iris` is not imported in any visible cell -- presumably
# `sklearn.datasets.load_iris`; add it to the imports cell so a fresh kernel runs.
iris = load_iris()
iris_features = iris.data
df_iris = torch.tensor(iris_features, dtype=torch.float32)
conditional_dist = create_gaussian_condition_distribution(
    df_iris,
    prior=True,
)

In [None]:

def non_linear_gaussian_ssm(state_dim, transition_fn=None, emission_fn=None, noise_var=0.009):
    """Build a sampler for a non-linear Gaussian state-space model.

    The generated sequence follows::

        z_0 ~ N(0, I)
        z_t = f(z_{t-1}) + q_t,   q_t ~ N(0, noise_var * I)
        y_t = h(z_t)     + r_t,   r_t ~ N(0, noise_var * I)

    with fresh noise at every step (the original drew q_t / r_t once and
    re-sampled z from the prior at each step, so no trajectory was produced).

    Args:
        state_dim: dimension of the latent state; observations are stored in
            a tensor of the same width, so h must map state_dim -> state_dim.
        transition_fn: f. Defaults to the notebook-level ``f`` looked up when
            sampling (so it may be defined after this cell).
        emission_fn: h. Defaults to the notebook-level ``h``, looked up likewise.
        noise_var: variance of both the process and observation noise.

    Returns:
        ``samples(num_steps=100)`` returning ``(states, observations)``, each
        of shape (num_steps, state_dim).
    """
    p_z = MultivariateNormal(torch.zeros(state_dim), torch.eye(state_dim))
    noise_std = noise_var ** 0.5

    def samples(num_steps=100):
        transition = transition_fn if transition_fn is not None else f
        emission = emission_fn if emission_fn is not None else h
        states = torch.zeros(num_steps, state_dim)
        result = torch.zeros(num_steps, state_dim)
        z_prev = p_z.sample()
        for i in range(num_steps):
            # Propagate the previous state and draw new noise each step.
            z_t = transition(z_prev) + noise_std * torch.randn(state_dim)
            y_t = emission(z_t) + noise_std * torch.randn(state_dim)
            states[i] = z_t
            result[i] = y_t
            z_prev = z_t
        return states, result

    return samples

def f(z):
    """Transition function z + 0.4 * [sin(z[1]), cos(z[0])]; assumes a 2-element state."""
    # torch.stack keeps z's dtype/device/autograd graph, whereas torch.tensor([...])
    # copies the two values into a fresh tensor.
    return z + 0.4 * torch.stack([torch.sin(z[1]), torch.cos(z[0])])


def h(z):
    """Emission function: identity, so observations live in state space."""
    return z


state_dim = 2
# NOTE: `f` is non-linear, so this is not a *linear* Gaussian SSM; the `lgssm`
# name is kept in case later cells refer to it.
lgssm = non_linear_gaussian_ssm(state_dim)
states, emissions = lgssm(num_steps=100)
# NOTE(review): `plot_inference` is not defined in any visible cell -- define or import it first.
plot_inference(states, emissions, title="Noisy observations from hidden trajectory")

In [None]:
# Solve the 3x3 linear system A @ x = B. The single 0.5 (and 0.0) make torch
# infer the default floating dtype for the whole matrix / vector.
coefficients = [
    [3, 2, -1],
    [2, -2, 4],
    [-1, 0.5, -1],
]
rhs = [1, -2, 0.0]
A, B = torch.tensor(coefficients), torch.tensor(rhs)
x = torch.linalg.solve(A, B)
print(x)

In [None]:
class kalman_filter:
    """Work-in-progress filter over an observation sequence ``yt``.

    NOTE(review): despite the name, no Kalman measurement update (gain,
    posterior mean/covariance) is implemented yet. Step 0 draws a latent
    sample from N(0, I); each later step t draws a sample from a Gaussian
    whose mean/covariance are the running empirical moments of ``yt[:t]``.

    Args:
        yt: observations of shape (T, y_dim); rows are time steps.
        z_dim: latent dimension. Must equal ``y_dim`` because the latent
            samples are drawn from observation-space moments.

    Raises:
        ValueError: if ``yt`` is not 2-D or ``z_dim != y_dim``.
    """

    # Diagonal jitter keeping the empirical covariances positive definite:
    # the covariance of k <= y_dim observations is rank-deficient, which made
    # MultivariateNormal's Cholesky factorisation fail.
    JITTER = 1e-3

    def __init__(self, yt, z_dim):
        if yt.ndim != 2:
            raise ValueError(f"yt must be 2-D (T, y_dim), got shape {tuple(yt.shape)}")
        if yt.shape[1] != z_dim:
            raise ValueError(
                f"z_dim ({z_dim}) must equal the observation dimension ({yt.shape[1]})"
            )
        self.z_dim = z_dim
        self.yt = yt
        self.y_dim = yt.shape[1]
        self.total_steps = yt.shape[0]
        # The original called self.params_init_kalman(), which is not defined
        # anywhere (AttributeError). TODO: hold the model matrices here.
        self.parameters = None

    def distribution(self, mean, cov):
        """Return a multivariate normal with the given mean and covariance."""
        return MultivariateNormal(mean, cov)

    def init_values(self):
        """Step-0 values: a N(0, I) latent sample, the first observation and JITTER * I."""
        dtype = self.yt.dtype
        z_mean = torch.zeros(self.z_dim, dtype=dtype)
        z_cov = torch.eye(self.z_dim, dtype=dtype)
        init_pz_samples = self.distribution(z_mean, z_cov).sample()
        y_mean = self.yt[:1].mean(dim=0)
        y_cov = self.JITTER * torch.eye(self.y_dim, dtype=dtype)
        return init_pz_samples, y_mean, y_cov

    def next_step(self, step_num, prev_z_mean=None, prev_z_cov=None):
        """Draw the latent sample for ``step_num`` and return updated moments.

        Returns ``(z_sample, y_mean, y_cov)``. For ``step_num == 0`` see
        ``init_values``; otherwise the sample is drawn from
        N(prev_z_mean, prev_z_cov) and the moments are those of
        ``yt[:step_num + 1]`` (covariance plus ``JITTER * I``).
        """
        if step_num == 0:
            return self.init_values()
        if prev_z_mean is None or prev_z_cov is None:
            raise ValueError("prev_z_mean and prev_z_cov are required for step_num > 0")
        next_pz_samples = self.distribution(prev_z_mean, prev_z_cov).sample()
        yt_current = self.yt[: step_num + 1]
        y_mean = yt_current.mean(dim=0)
        # torch.cov squeezes a single-variable result to 0-d; restore (y_dim, y_dim).
        y_cov = torch.atleast_2d(torch.cov(yt_current.T))
        y_cov = y_cov + self.JITTER * torch.eye(self.y_dim, dtype=y_cov.dtype)
        return next_pz_samples, y_mean, y_cov

    def execute(self):
        """Run over every observation and return the (total_steps, z_dim) latent samples."""
        zts = torch.zeros(self.total_steps, self.z_dim, dtype=self.yt.dtype)
        y_mean = y_cov = None
        for i in range(self.total_steps):
            # Step i samples from the moments produced at step i - 1.
            zt, y_mean, y_cov = self.next_step(i, y_mean, y_cov)
            zts[i] = zt
        # TODO: add the measurement update. With S_t the predicted observation
        # covariance and C_t = Cov(z_t, y_t): K_t = C_t S_t^{-1},
        # mean_post = mean_prior + K_t (y_t - y_mean),
        # cov_post = cov_prior - K_t S_t K_t^T.
        return zts

# Run the work-in-progress filter over the observation sequence.
# NOTE(review): `yt` is not assigned in any visible cell (the simulation cell
# produces `states` / `emissions`); make sure it exists before running this.
filter_preditions = kalman_filter(yt=yt, z_dim=2)
filter_preditions.execute()
# TODO: once the filter returns posterior moments, plot them, e.g.
# plot_inference(zt, yt, <filtered means>, title='custom linear gaussian state space model')