In [None]:
import numpy as np
import torch
import scipy
import itertools
from scipy.special import binom
from torch import nn

In [None]:
class basis(nn.Module):
    """A family of ``n`` small MLPs used as learnable basis functions.

    Basis function ``i`` is ``last(NL(... NL(hidden_i(NL(first(y)))) ...))``:
    the ``first`` and ``last`` Linear layers are shared by every basis
    function, while the hidden stack ``self.basis[i]`` belongs to ``i`` only.

    Args:
        dim_in: size of the last axis of the input points.
        dim_out: size of the last axis of the output.
        n: number of basis functions.
        shapes: hidden-layer widths; ``len(shapes) - 1`` hidden Linear layers
            are created per basis function. Copied, never aliased.
        NL: activation class; it is instantiated with ``inplace=True``.
        batch_size: number of copies of the input stacked along a new
            leading axis in ``forward``.
    """

    def __init__(self, dim_in, dim_out, n=8, shapes=(16, 16), NL=nn.ELU, batch_size=8):
        super(basis, self).__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.n = n
        # Copy so neither the caller's list nor a shared default is aliased.
        self.shapes = list(shapes)
        self.first = nn.Linear(dim_in, self.shapes[0])
        # self.basis[i] holds the hidden Linear layers of basis function i.
        self.basis = nn.ModuleList(
            nn.ModuleList(
                nn.Linear(self.shapes[k], self.shapes[k + 1])
                for k in range(len(self.shapes) - 1)
            )
            for _ in range(n)
        )
        self.last = nn.Linear(self.shapes[-1], dim_out)
        self.NL = NL(inplace=True)
        self.batch_size = batch_size

    def generate_basis(self):
        """Return the ModuleList of per-basis-function hidden layer stacks."""
        return self.basis

    def forward(self, i, y):
        """Evaluate basis function ``i`` at ``y``.

        ``y`` gets a new leading axis repeated ``batch_size`` times, so a 1-D
        point ``(dim_in,)`` yields ``(batch_size, 1, dim_out)`` and an
        ``(N, dim_in)`` batch yields ``(batch_size, N, dim_out)``.
        """
        y_in = y.unsqueeze(0).repeat(self.batch_size, 1, 1)
        # Call modules (not .forward) so registered hooks run.
        out = self.NL(self.first(y_in))
        for layer in self.basis[i]:
            out = self.NL(layer(out))
        return self.last(out)

    def basis_size(self):
        """Number of basis functions ``n``."""
        return self.n

In [None]:
# 8 basis functions, batch_size=8 (constructor defaults)
A = basis(dim_in=3, dim_out=5)

In [None]:
# ModuleList holding, for each basis function, its ModuleList of hidden Linear layers.
B = A.generate_basis()

In [None]:
# One entry per basis function, i.e. A.n (8 with the default n).
len(B)

In [None]:
# Rebuild with batch_size=16 to match the Leray_Schauder setup below.
A = basis(dim_in=3, dim_out=5, batch_size=16)

In [None]:
# A single 3-D point through basis function 7 -> (batch_size, 1, dim_out)
A.forward(i=7, y=torch.rand(3)).size()

In [None]:
from source.integrators import MonteCarlo

In [None]:
# Module-level Monte Carlo integrator; Leray_Schauder.norm integrates through the name `mc`.
mc = MonteCarlo()

In [None]:
class F_NN(nn.Module):
    """Plain MLP ``in_dim -> shapes[0] -> ... -> shapes[-1] -> out_dim``.

    Args:
        in_dim: size of the last axis of the input.
        out_dim: size of the last axis of the output.
        shapes: hidden widths; ``len(shapes) - 1`` hidden Linear layers.
        NL: activation class, instantiated with ``inplace=True``.
        batch_size: number of copies of the input stacked on a new leading
            axis in ``forward``. Defaults to 16, the value that used to be
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, in_dim, out_dim, shapes, NL=nn.ELU, batch_size=16):
        super(F_NN, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.n_layers = len(shapes) - 1
        self.shapes = shapes
        self.batch_size = batch_size
        self.first = nn.Linear(in_dim, shapes[0])
        self.layers = nn.ModuleList([nn.Linear(shapes[i], shapes[i + 1]) for i in range(self.n_layers)])
        self.last = nn.Linear(shapes[-1], out_dim)
        self.NL = NL(inplace=True)

    def forward(self, y):
        """Apply the MLP to ``y`` replicated ``batch_size`` times on a new axis 0.

        ``(in_dim,)`` -> ``(batch_size, 1, out_dim)``;
        ``(N, in_dim)`` -> ``(batch_size, N, out_dim)``.
        """
        y_in = y.unsqueeze(0).repeat(self.batch_size, 1, 1)
        out = self.NL(self.first(y_in))
        for layer in self.layers:
            out = self.NL(layer(out))
        return self.last(out)

In [None]:
class Leray_Schauder(nn.Module):
    """Leray-Schauder (Schauder) projection onto a finite learned basis.

    With basis functions ``e_1..e_n`` supplied by ``basis``, a function ``f``
    is mapped to ``sum_i mu_i(f) e_i / sum_i mu_i(f)`` where
    ``mu_i(f) = epsilon - ||f - e_i||`` if that distance is at most
    ``epsilon`` and 0 otherwise.

    Args:
        basis: module exposing ``basis_size()`` and callable as ``basis(i, x)``.
        epsilon: radius beyond which a basis function receives zero weight.
        dim: dimension of the integration domain passed to the integrator.
        modes: output width of a basis function (used to reshape in ``proj``).
        N: number of samples passed to the integrator.
        p: exponent of the L^p norm.
        batch_size: leading batch size produced by ``basis`` / ``func``.
        integrator: object with ``integrate(fn, dim, N, out_dim)``. When
            ``None`` (default) the module-level ``mc`` is used, as before.
    """

    def __init__(self, basis, epsilon=.1, dim=1, modes=2, N=1000, p=2, batch_size=8, integrator=None):
        super(Leray_Schauder, self).__init__()
        self.basis = basis
        self.epsilon = epsilon
        self.dim = dim
        self.modes = modes
        self.N = N
        self.p = p
        self.n = self.basis.basis_size()
        self.batch_size = batch_size
        self.integrator = integrator

    def norm(self, func):
        """L^p norm of ``func`` over the domain, integrating along axis -2.

        ``|func|`` (not ``func``) is raised to ``p`` so that for odd ``p``
        positive and negative parts cannot cancel (or yield NaN under the
        final root). For ``p = 2`` this is identical to squaring.
        """
        integrator = mc if self.integrator is None else self.integrator
        integral = integrator.integrate(
            fn=lambda s: func(s).abs() ** self.p,
            dim=self.dim,
            N=self.N,
            out_dim=-2,
        )
        return torch.pow(integral, 1 / self.p)

    def mu_i(self, func, i):
        """Schauder weight of basis function ``i`` for ``func``.

        Returns ``epsilon - ||func - e_i||`` where that distance is at most
        ``epsilon`` and 0 elsewhere (including NaN distances), as float32
        with a trailing singleton axis.
        """
        dist = torch.norm(self.norm(lambda s: func(s) - self.basis(i, s)), p=self.p, dim=[-1])
        dist = dist.unsqueeze(-1)
        weight = torch.where(dist <= self.epsilon, self.epsilon - dist, torch.zeros_like(dist))
        return weight.float()

    def proj(self, func, x):
        """Evaluate the projection of ``func`` at the single point ``x``.

        Each basis evaluation is reshaped to ``(batch_size, modes)``; the
        1e-7 floor keeps the division finite when every weight is zero.
        """
        out = torch.zeros(self.batch_size, self.modes)
        normalization = torch.full((self.batch_size, 1), 1e-7)
        for i in range(self.n):
            mui = self.mu_i(func, i)
            out = out + mui * self.basis(i, x).view(self.batch_size, self.modes)
            normalization = normalization + mui
        return out / normalization

    def proj_coeff(self, func):
        """Normalized Schauder weights of ``func``, basis index on the last axis."""
        normalization = torch.full((self.batch_size, 1), 1e-7)
        weights = []
        for i in range(self.n):
            mui = self.mu_i(func, i)
            weights.append(mui)
            normalization = normalization + mui
        # A single cat instead of growing a legacy empty tensor in the loop.
        coeff = torch.cat(weights, dim=-1) if weights else torch.zeros(self.batch_size, 0)
        return coeff / normalization

    def basis_eval(self, i, x):
        """Evaluate basis function ``i`` at ``x``."""
        return self.basis(i, x)

    def return_basis(self):
        """Return the underlying basis module."""
        return self.basis

    def return_modes(self):
        """Return the configured number of output modes."""
        return self.modes

In [None]:
# Stand-in "input function" R^3 -> R^5 to project.
func = F_NN(in_dim=3, out_dim=5, shapes=[16, 16])

In [None]:
LS = Leray_Schauder(basis=A, dim=3, epsilon=2.0, modes=5, batch_size=16)

In [None]:
LS.mu_i(func=func, i=2).size()

In [None]:
LS.proj(func=func, x=torch.rand(3)).size()

In [None]:
LS.proj_coeff(func=func).size()

In [None]:
# Regular 10x10 grid on [0, 1]^2. indexing="ij" pins the current default,
# which torch.meshgrid warns will change in a future release.
spatial_domain_xy = torch.meshgrid(*[torch.linspace(0, 1, 10) for _ in range(2)], indexing="ij")

x_space = spatial_domain_xy[0].flatten().unsqueeze(-1)
y_space = spatial_domain_xy[1].flatten().unsqueeze(-1)

# (100, 2) tensor of grid points; with "ij" indexing x varies slowest.
spatial_domain = torch.cat([x_space, y_space], -1)

In [None]:
spatial_domain.size()

In [None]:
class Leray_Schauder_model(nn.Module):
    """Project a function onto the LS basis, transform the coefficients, recompose.

    Args:
        LS_map: object exposing ``proj_coeff``, ``basis_eval``,
            ``return_basis`` and ``return_modes`` (e.g. ``Leray_Schauder``).
        proj_NN: module whose ``forward`` maps projection coefficients to
            new coefficients. Assumed to keep the basis index on the last
            axis — TODO confirm for the networks used with this class.
        batch_size: leading batch size used to reshape basis evaluations.
    """

    def __init__(self, LS_map, proj_NN, batch_size=8):
        super(Leray_Schauder_model, self).__init__()
        self.LS_map = LS_map
        self.proj_NN = proj_NN
        self.n = LS_map.return_basis().basis_size()
        self.basis = LS_map.return_basis()
        self.batch_size = batch_size

    def recompose(self, coeff):
        """Return ``s -> sum_i coeff[..., i] * e_i(s)``.

        The basis index lives on the LAST axis of ``coeff`` (``proj_coeff``
        concatenates weights along -1), so ``coeff[..., i]`` is the weight of
        basis ``i``. Indexing ``coeff[i]`` would instead select batch row
        ``i``. Two trailing singleton axes let the weight broadcast against
        the ``(batch_size, 1, modes)`` basis evaluation. A 1-D coefficient
        vector behaves exactly as before (scalar weight per basis function).
        """
        modes = self.LS_map.return_modes()
        n_basis = self.basis.basis_size()

        def func(s):
            terms = [
                coeff[..., i].unsqueeze(-1).unsqueeze(-1)
                * self.LS_map.basis_eval(i, s).view(self.batch_size, 1, modes)
                for i in range(n_basis)
            ]
            return torch.cat(terms, dim=-2).sum(dim=-2)

        return func

    def projected_function(self, func):
        """Map ``func`` to its transformed projection, returned as a callable."""
        projection_coeff = self.LS_map.proj_coeff(func)
        out = self.proj_NN.forward(projection_coeff)
        return self.recompose(out)

In [None]:
# Coefficient-to-coefficient network: 8 inputs/outputs, one per basis function.
NN = F_NN(in_dim=8, out_dim=8, shapes=[16, 16])

In [None]:
func = F_NN(in_dim=3, out_dim=5, shapes=[16, 16])

In [None]:
model = Leray_Schauder_model(LS_map=LS, proj_NN=NN, batch_size=16)

In [None]:
out_func = model.projected_function(func=func)

In [None]:
# NOTE(review): basis.forward turns a (16, 3) input into (batch_size, 16, modes),
# which recompose's .view(batch_size, 1, modes) cannot reshape — confirm this cell runs.
out_func(torch.rand((16, 3))).size()

In [None]:
out_func(torch.rand(size=(16, 3)))

In [None]:
# Default dtype of torch.rand (float32 unless the global default was changed).
torch.rand((5,)).dtype

In [None]:
x = torch.rand((5,), dtype=torch.float64)

In [None]:
# Scratch check: torch.where with a Python-scalar fallback on a float64 tensor, cast to float32.
torch.where(x >= 0.0, x, 0.0).to(torch.float32)