In [1]:
from load import *
from torch import Tensor, nn
import torch
from model_base import *
from modules import *
from anim import *
import util
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import warnings

# Data directory for this week's datasets. Not referenced by the cells shown here;
# presumably consumed by the star-imported loaders or by later cells — TODO confirm.
ROOT = "./Datas/Week 8"

# Presumably the elementary charge in coulombs; used below to scale the loaded space charge.
# NOTE(review): the exact SI value is 1.602176634e-19 — this literal drops the final digit.
# Left unchanged so results stay consistent with however the dataset was generated.
Q = 1.60217663e-19

In [2]:
# Load the three datasets and sanity-check them against the Poisson equation.
# The loaded space charge is multiplied by -Q (presumably converting a carrier density
# into a charge density — confirm units against the loader).
sc = load_space_charge() * -Q
ep = load_elec_potential()
vg = load_vgs()
poi = NormalizedPoissonMSE('cpu')
# Poisson residual of the raw data (displayed as the cell output); a value near zero
# suggests the potential and space charge are mutually consistent.
poi(ep, sc)

tensor(5.3032e-10, dtype=torch.float64)

In [3]:
def eigenify(X):
    """Rotate centred data onto all of its principal axes (full-rank PCA projection).

    Columns of ``X`` are treated as features and rows as observations: the
    covariance is taken over ``X.T`` and the mean is removed per column.

    Args:
        X: 2-D tensor of shape (n_samples, n_features). Any real dtype; it is
            promoted to float64 before any arithmetic so no precision is lost.

    Returns:
        float64 tensor of shape (n_samples, n_features). Column ``k`` is the
        projection of the centred data onto the covariance eigenvector with the
        k-th largest eigenvalue. Eigenvector signs are arbitrary, as in any PCA.
    """
    data = X.double()

    # torch.cov wants variables in rows. It squeezes a single-variable result to a
    # 0-d scalar, so atleast_2d restores the (1, 1) matrix eigh needs.
    cov = torch.atleast_2d(torch.cov(data.T))

    # A covariance matrix is symmetric, so eigh returns real eigenpairs directly.
    # (eig returned complex tensors, whose real-cast raised the UserWarning that was
    # previously silenced by globally rewriting the warnings filters.)
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    order = torch.argsort(eigenvalues, descending=True)
    P = eigenvectors[:, order]

    centred = data - data.mean(dim=0)
    return centred @ P

def eigende(X, N: int):
    """Project centred data onto its ``N`` leading principal axes (PCA reduction).

    Columns of ``X`` are treated as features and rows as observations: the
    covariance is taken over ``X.T`` and the mean is removed per column.

    Args:
        X: 2-D tensor of shape (n_samples, n_features). Any real dtype; it is
            promoted to float64 before any arithmetic so no precision is lost.
        N: number of leading components to keep. Standard slicing rules apply, so
            an ``N`` larger than n_features keeps every component.

    Returns:
        float64 tensor of shape (n_samples, min(N, n_features)). Column ``k`` is the
        projection onto the covariance eigenvector with the k-th largest eigenvalue.
        Eigenvector signs are arbitrary.
    """
    data = X.double()

    # torch.cov wants variables in rows. It squeezes a single-variable result to a
    # 0-d scalar, so atleast_2d restores the (1, 1) matrix eigh needs.
    cov = torch.atleast_2d(torch.cov(data.T))

    # Symmetric matrix -> eigh gives real eigenpairs. No complex-to-real cast means
    # no UserWarning to suppress, and no global warnings-filter mutation.
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    order = torch.argsort(eigenvalues, descending=True)
    P = eigenvectors[:, order[:N]]  # (n_features, N): the leading eigenvectors

    # (P.T @ X_.T).T == X_ @ P, so no transposes are needed.
    centred = data - data.mean(dim=0)
    return centred @ P

In [None]:
# Pass a callable object to scipy so that we can avoid closures and local functions.
class ReconstructedPoissonLoss:
    """Poisson-consistency objective over a flat vector of region values.

    The flat vector ``x`` handed in by the optimiser packs three blocks, in order:

    * ``x[:429]``     -> electric potential, rows 45:84, columns :11  (39 x 11)
    * ``x[429:663]``  -> electric potential, rows 45:84, columns 11:  (39 x 6)
    * ``x[663:]``     -> space charge,       rows 45:84, columns :11  (39 x 11)

    so ``x`` must hold 429 + 234 + 429 = 1092 values. These blocks are written into
    *copies* of ``xep[i]`` / ``xsc[i]``. The rest of each grid keeps its stored
    values. The resulting pair is scored with ``NormalizedPoissonRMSE``.

    ``xep[i]`` and ``xsc[i]`` must support ``[45:84, :11]`` slicing and reshape to
    (1, 129, 17), i.e. presumably a (129, 17) grid — confirm against callers.
    """

    def __init__(self, xep, xsc, /, device = None):
        self.i = 0  # which sample of xep / xsc is being optimised
        self.xep = xep
        self.xsc = xsc
        self.poisson_loss = NormalizedPoissonRMSE(device)
        self._device = self.poisson_loss._device

    def set_i(self, i):
        """Select the sample index to reconstruct; returns ``self`` for chaining."""
        self.i = i
        return self

    def reconstruct(self, x):
        """Splice the packed region values ``x`` into copies of sample ``self.i``.

        Returns:
            ``(reconstructed_ep, reconstructed_sc)``, each reshaped to (1, 129, 17).
            ``self.xep`` / ``self.xsc`` are left unmodified.
        """
        i = self.i

        ep_region_2 = x[:429].reshape(84 - 45, 11)
        ep_region_5 = x[429:663].reshape(84 - 45, 6)
        sc_region_2 = x[663:].reshape(84 - 45, 11)

        # self.xep[i] is a view into the caller's tensor. Without clone() every
        # objective evaluation would overwrite the stored predictions in place.
        reconstructed_ep = self.xep[i].clone()
        reconstructed_ep[45:84, :11] = ep_region_2
        reconstructed_ep[45:84, 11:] = ep_region_5
        reconstructed_ep = reconstructed_ep.reshape(1, 129, 17)

        reconstructed_sc = self.xsc[i].clone()
        reconstructed_sc[45:84, :11] = sc_region_2
        reconstructed_sc = reconstructed_sc.reshape(1, 129, 17)

        return reconstructed_ep, reconstructed_sc

    def __call__(self, x):
        """Objective for scipy: convert ``x``, reconstruct, and return the loss as a float."""
        x = torch.tensor(x).to(self._device)
        rep, rsc = self.reconstruct(x)
        rmse = self.poisson_loss(rep, rsc)
        return float(rmse.item())

class DebugeModel(Model):
    """Two-stage model: first-stage predictions, later to be made Poisson self-consistent.

    Work in progress. ``forward`` currently does three things:
    * runs the first-stage predictors;
    * fits a 2-component PCA on the predicted region values and prints its eigenvalues;
    * returns an all-zero placeholder tensor.
    """

    def __init__(self, ep1: TrainedLinear, sc1: TrainedLinear):
        # NOTE(review): Model.__init__ is never invoked here. If Model relies on its
        # own initialiser (e.g. it is an nn.Module subclass), this needs
        # super().__init__() — confirm.
        self.ep1 = ep1  # first-stage electric-potential predictor
        self.sc1 = sc1  # first-stage space-charge predictor

    def forward(self, x) -> Tensor:
        """Run the first stage, fit PCA on the region values, and return a zero placeholder."""
        batch_size = int(x.shape[0])
        # TODO: placeholder that is never filled in. 4386 == 2 * 129 * 17, presumably
        # the ep and sc grids side by side — confirm once the second stage exists.
        output = torch.zeros(batch_size, 4386)

        ep_pred = self.ep1(x).reshape(-1, 129, 17).cpu()
        sc_pred = self.sc1(x).reshape(-1, 129, 17).cpu()

        # Per the linearity plots, only region 2 matters for space charge and
        # regions 2 and 5 for electric potential. Both span rows 45:84, split at column 11.
        rows = slice(45, 84)
        region_features = torch.cat(
            (
                ep_pred[:, rows, :11].reshape(-1, 429),
                ep_pred[:, rows, 11:].reshape(-1, 234),
                sc_pred[:, rows, :11].reshape(-1, 429),
            ),
            dim=1,
        )

        extractor = PrincipalComponentExtractor(2)
        extractor.fit(region_features)
        print(extractor.eigenvalues)
        return output

In [None]:
# Fit one linear model per target (electric potential, space charge). Each maps the
# input (presumably a single feature, per TrainedLinear(1, ...)) to 2193 = 129 * 17 grid
# values. Then run DebugeModel on the same inputs.
# NOTE(review): TRAINING_IDXS[0] takes only the first element of TRAINING_IDXS. Confirm
# whether that is one sample index or the index array of a single split.
index = util.TRAINING_IDXS[0]

x = vg[index]
epy = ep[index]
scy = sc[index]
ep_linear = TrainedLinear(1, 2193, algorithm='linear').fit(x, epy.reshape(-1, 2193))
sc_linear = TrainedLinear(1, 2193, algorithm='linear').fit(x, scy.reshape(-1, 2193))
model = DebugeModel(ep_linear, sc_linear)
print(model.summary())

# Last expression is displayed; DebugeModel.forward currently returns a zero placeholder.
model(x)