In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:

# Basis matrix of the uniform cubic B-spline: the row vector [t^3, t^2, t, 1]
# multiplied by this matrix gives the weights of the 4 control points that are
# active on one interval (the weights sum to 1 for every t).
matrixA_empty = torch.tensor([[-1, 3, -3, 1],
                              [3, -6, 3, 0],
                              [-3, 0, 3, 0],
                              [1, 4, 1, 0]], dtype=torch.float32) / 6

# Zero-filled scratch buffer that preactivate() slices, clones and grows on demand.
# It must be floating point: preactivate() writes float ones into the clone and
# multiplies it with float tensors, so an integer buffer breaks small first calls.
coef_win_empty = torch.zeros(5, 5, dtype=torch.float32)

# Offsets of the 4 control points inside one window.
win_matrix_base = torch.tensor([0, 1, 2, 3], dtype=torch.float)

# The window offsets tiled 16 times (64 entries) as int32; grown on demand by preactivate().
win_matrix_empty = win_matrix_base.repeat(16).int()

# Row indices 0..63 as int32; grown on demand by preactivate().
x_index_empty = torch.arange(64, dtype=torch.int32)

In [None]:
def preactivate(x, n_int, positive=False, seq = False, causal_mask = None):
    '''
        Evaluate the uniform cubic B-spline basis at every entry of x.

        The input range ([-1,1], or [0,1] when positive is True) is split into
        n_int equal intervals. Each value activates the 4 consecutive control
        points of the interval it falls in, so every output row holds 4 non-zero
        weights among n_int+3 entries.

        x: 2D tensor of shape (batch_size, input_dim) or 3D tensor of shape (batch_size, seq_dim, input_dim) if seq is True
        n_int: integer, number of intervals
        positive: bool, whether to clip the input to [0,1] instead of [-1,1]
        seq: bool, whether x carries a sequence dimension
        causal_mask: optional tensor, only used when seq is True; it is sliced as
            causal_mask[:, :seq_dim, :input_dim] and reshaped to (1, seq_dim, input_dim, 1);
            positions where it equals 0 produce an all-zero output row
        return: 3D tensor of shape (batch_size, input_dim, n_int+3) or 4D tensor of shape (batch_size, seq_dim, input_dim, n_int+3) if seq is True

        Side effects: sets the module-level matrixA and may grow the module-level
        caches coef_win_empty, x_index_empty and win_matrix_empty.
    '''
    # The explicit `== True` comparisons are kept on purpose: callers in this file
    # pass non-boolean values (e.g. the string 'cuda') in the `positive` slot and
    # rely on them being treated as False.
    unit_domain = positive == True
    sequence_input = seq == True

    device = x.device
    global matrixA
    matrixA = matrixA_empty.to(device)

    lower = 0 if unit_domain else -1
    x = x.clip(lower, 1)

    batch_size = x.shape[0]
    if sequence_input:
        seq_dim = x.shape[1]
        input_dim = x.shape[2]
        output_shape = (batch_size, seq_dim, input_dim, n_int+3)
        size = batch_size*seq_dim*input_dim
    else:
        input_dim = x.shape[1]
        output_shape = (batch_size, input_dim, n_int+3)
        size = batch_size*input_dim

    # reshape (not view) so that non-contiguous inputs are accepted as well
    x = x.reshape(-1)

    # Position of every value measured in interval widths from the lower bound
    gap = (1 - lower)/n_int
    u = (x - lower)/gap

    # Select the adjacent control points; the upper bound belongs to the last interval
    coef_index = torch.clip(torch.floor(u).int(), min=0, max=n_int-1)

    rows = 4*size
    cols = n_int+3

    # One-hot selection matrix: row 4*i+j marks control point coef_index[i]+j.
    # The cached zero buffer is rebuilt when it is too small or not floating point.
    global coef_win_empty
    if coef_win_empty.shape[0] < rows or coef_win_empty.shape[1] < cols or coef_win_empty.dtype != torch.float:
        coef_win_empty = torch.zeros(rows, cols, requires_grad=False, device=device, dtype=torch.float)
    coef_win = coef_win_empty[:rows, :cols].clone().to(device)

    global x_index_empty
    if x_index_empty.shape[0] < rows:
        x_index_empty = torch.arange(rows, dtype=torch.int32, device=device)
    x_index = x_index_empty[:rows].to(device)

    global win_matrix_empty
    if win_matrix_empty.shape[0] < rows:
        # the base holds 4 offsets, so `size` repetitions give exactly 4*size entries
        win_matrix_empty = win_matrix_base.repeat(size).int().to(device)
    win_matrix = win_matrix_empty[:rows].to(device)

    y_index = coef_index.unsqueeze(1).expand(-1, 4).reshape(-1) + win_matrix

    coef_win.index_put_((x_index.long(), y_index.long()), torch.tensor(1, device=device, dtype=torch.float))

    # Local coordinate inside the selected interval, in [0,1]. Computing it from the
    # clipped index (instead of a modulo) keeps the upper bound at t=1 rather than
    # wrapping it around to t=0.
    t = (u - coef_index).unsqueeze(-1) # (size,1)
    powers = torch.cat((torch.pow(t, 3), torch.pow(t, 2), t, torch.ones_like(t)), dim=-1) # (size,4)

    if sequence_input and causal_mask is not None:
        mask = causal_mask[:, :seq_dim, :input_dim].reshape(1, seq_dim, input_dim, 1).to(device)
        powers = powers.view(batch_size, seq_dim, input_dim, 4).masked_fill(mask == 0, 0)
        powers = powers.reshape(size, 4)

    weights = torch.matmul(powers.unsqueeze(1), matrixA) # (size,1,4)
    out = torch.matmul(weights, coef_win.view(size, 4, cols)).view(*output_shape)
    return out
    
@torch.no_grad()
def infer_coef(x, y, n_int, positive = False):
    '''
        Infer spline coefficients from samples with the least square method.

        x: 2D tensor of shape (batch_size, input_dim)
        y: 2D tensor of shape (batch_size, output_dim); it is moved to x's device
        n_int: integer, number of intervals
        positive: forwarded unchanged to preactivate
        return: 3D tensor of shape (input_dim, n_int+3, output_dim)
    '''
    device = x.device
    y = y.to(device)
    batch_size = x.shape[0]
    input_dim = x.shape[1]
    # Design matrix: one row per sample, one column per (input, control point) pair
    design = preactivate(x, n_int, positive).reshape(batch_size, -1)
    # x.device is a torch.device object and never compares equal to the string 'cpu',
    # so the device kind has to be read from `.type`. 'gelsd' (CPU only) also copes
    # with rank-deficient design matrices; other devices fall back to 'gels'.
    driver = 'gelsd' if device.type == 'cpu' else 'gels'
    coef = torch.linalg.lstsq(design, y, driver=driver).solution
    return coef.reshape(input_dim, n_int+3, -1)

def spline_forward(x: torch.Tensor, n_int: int, coef: torch.Tensor, positive = False):
    '''
        Evaluate the spline defined by coef at the points x.

        x: 2D tensor of shape (batch_size, input_dim)
        n_int: integer, number of intervals
        coef: 3D tensor of shape (input_dim, n_int+3, output_dim); it is moved to x's device
        positive: forwarded to preactivate (previously accepted but silently ignored)
        return: 2D tensor of shape (batch_size, output_dim)
    '''
    batch_size = x.shape[0]
    input_dim = x.shape[1]
    output_dim = coef.shape[2]
    n_basis = input_dim*(n_int+3)
    # (batch_size, input_dim, n_int+3) -> (batch_size, input_dim*(n_int+3))
    basis = preactivate(x, n_int, positive).reshape(batch_size, n_basis)
    weights = coef.to(x.device).reshape(n_basis, output_dim)
    return torch.matmul(basis, weights)
    

In [None]:
# Plot a 2D spline curve defined by hand-picked control points
x = torch.linspace(-1,1,100).reshape(100,1) # (100,1): one curve parameter per row, as spline_forward expects
x_int = 3
# 6 = x_int+3 control points, each with an (x, y) coordinate -> shape (1,6,2)
coef = torch.tensor([[-2,0],[-1,-1],[0,0],[1,1],[2,0],[3,-1]],dtype=torch.float32, device='cpu').reshape(1,6,2)
# Scalar control values 0..5, shape (1,6,1); not used by the plot below
coef_1d = torch.tensor(range(6),dtype=torch.float32,device='cpu').reshape(1,6,1)
y = spline_forward(x, x_int, coef) # (100,2) points on the curve
plt.plot(y.numpy()[:,0],y.numpy()[:,1])
# Plot the control points
plt.scatter(coef[:,:,0].to('cpu').numpy(),coef[:,:,1].to('cpu').numpy())
plt.show()


In [None]:
# Fit a 2D spline to random targets, then plot the curve and its control points
x = torch.linspace(-1,1,1000).unsqueeze(1) # (1000,1) dense evaluation grid
x_int = 12
x_eval = torch.linspace(-1,1,12).unsqueeze(1) # (12,1) locations of the targets
# Seeded local generator: reproducible targets without touching the global RNG
# and without requiring a CUDA device just to draw random numbers.
targets = torch.randn(12,2,dtype=torch.float32,generator=torch.Generator().manual_seed(0))
coef = infer_coef(x_eval, targets, x_int) # (1, x_int+3, 2)
y = spline_forward(x, x_int, coef).to('cpu') # (1000,2) points on the curve
plt.plot(y.numpy()[:,0],y.numpy()[:,1])
# Plot the control points: x coordinates against y coordinates of all x_int+3 points
plt.scatter(coef[0,:,0].to('cpu').numpy(),coef[0,:,1].to('cpu').numpy())
plt.show()
# Refit from the densely sampled curve
print(infer_coef(x, y, x_int))
