# In [None]:
from IPython import get_ipython

# Clear the interactive namespace before the experiment starts.
# `InteractiveShell.magic()` is deprecated in favour of `run_line_magic()`,
# and `get_ipython()` returns None when the file is run by a plain Python
# interpreter, so the reset is skipped in that case instead of crashing.
_ipython_shell = get_ipython()
if _ipython_shell is not None:
    _ipython_shell.run_line_magic('reset', '-sf')

# The complete experiment below (model definition, training, evaluation,
# plotting) is the body of this loop and is executed once per seed value.
for rng_itrs in [0]: # Put desired list of RNG values
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    # Release cached GPU memory held by PyTorch's allocator before this run.
    torch.cuda.empty_cache()
    import matplotlib.pyplot as plt

    import snntorch as snn
    from snntorch import surrogate
    from snntorch import functional as SF
    from snntorch import utils
    from snntorch import spikegen

    from timeit import default_timer
    from pytorch_wavelets import DWT1D, IDWT1D

    # Seed both PyTorch and NumPy with the current loop value.
    torch.manual_seed(rng_itrs)
    np.random.seed(rng_itrs)

    # NOTE(review): the star import presumably provides `LpLoss` and
    # `count_params`, which are used further down — confirm in utilities_0.
    from utilities_0 import *
    # Use the first CUDA device when one is available, otherwise the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # from utilities_1 import *
    # device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

    # %%
    
    def encode(x, nsteps):
        """Encode values as binary spike trains of length ``nsteps``.

        Every entry ``v`` of ``x`` becomes a train whose first
        ``floor(v / (1/nsteps))`` steps are 1 and whose remaining steps are 0.
        A count larger than ``nsteps`` saturates to an all-ones train.  A
        negative count yields an all-zero train (it is clamped to 0 instead
        of being used as a negative slice end, which would switch on most of
        the train).

        Args:
            x: tensor of values to encode; the file calls it with 2-D input.
            nsteps: number of time steps in each spike train.

        Returns:
            Tensor of shape ``x.shape + (nsteps,)`` in the default float
            dtype, located on the same device as ``x``.
        """
        dx = 1/nsteps

        # Number of leading ones per entry; never negative.
        counts = torch.clamp(x // dx, min=0)

        # Step k fires iff k < count, evaluated for all entries at once.
        steps = torch.arange(nsteps, device=x.device)
        return (steps < counts.unsqueeze(-1)).to(torch.get_default_dtype())

    """ Def: 1d Wavelet layer """

    class WaveConv1d(nn.Module):
        """1-D wavelet convolution layer.

        The input is decomposed with a multi-level discrete wavelet transform
        (DWT), the coarsest approximation coefficients and the coarsest detail
        coefficients are each mixed across channels with learnable weights,
        and the signal is reconstructed with the inverse transform.
        """

        def __init__(self, in_channels, out_channels, level, dummy):
            """Build the layer.

            Args:
                in_channels: number of input channels.
                out_channels: number of output channels.
                level: number of DWT decomposition levels.
                dummy: sample tensor passed through the DWT once so that the
                    length of the coarsest coefficient vector (``self.modes``)
                    is known when the weights are allocated.
            """
            super(WaveConv1d, self).__init__()

            self.in_channels = in_channels
            self.out_channels = out_channels

            self.level = level
            self.wavelet = 'db6'
            self.mode = 'symmetric'

            self.dwt_ = DWT1D(J=self.level, mode=self.mode, wave=self.wavelet).to(dummy.device)
            self.x_dwt, _ = self.dwt_(dummy)
            self.modes = self.x_dwt.shape[-1]

            self.scale = (1 / (in_channels*out_channels))
            self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes))
            self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes))

        def mul2d(self, input, weights):
            """Mix channels independently at every coefficient position.

            (batch, in_channel, x), (in_channel, out_channel, x) -> (batch, out_channel, x)
            """
            return torch.einsum("bix,iox->box", input, weights)

        def forward(self, x):
            """Apply the wavelet-domain channel mixing to ``x`` and reconstruct it."""
            # Reuse the transform built in __init__: it is a registered
            # submodule, so it follows the layer across devices and does not
            # have to be rebuilt on every call.
            x_ft, x_coeff = self.dwt_(x)

            # Multiply the final low pass and high pass coefficients
            out_ft = self.mul2d(x_ft, self.weights1)
            x_coeff[-1] = self.mul2d(x_coeff[-1], self.weights2)

            # Reconstruct the signal. The inverse transform is created per
            # call on purpose: registering it as a submodule would add entries
            # to the state_dict and break loading of existing checkpoints.
            idwt = IDWT1D(mode=self.mode, wave=self.wavelet).to(x.device)
            x = idwt((out_ft, x_coeff))
            return x

    """ The forward operation """

    class WNO1d(nn.Module):
        """Spiking 1-D wavelet neural operator.

        Per time step the network lifts the input (value and grid coordinate)
        with ``fc0``, applies four blocks of wavelet convolution plus a 1x1
        convolution, and projects with ``fc1``.  Leaky integrate-and-fire
        (LIF) neurons follow the first three blocks and ``fc1``.  The spikes
        after ``fc1`` are averaged over time and mapped to the output by
        ``fc2``.
        """

        def __init__(self, level, width, dummy_data):
            """Build the network.

            Args:
                level: number of DWT levels used by every wavelet layer.
                width: number of channels of the lifted representation.
                dummy_data: sample tensor forwarded to the wavelet layers;
                    its last dimension is taken as the number of grid points.
            """
            super(WNO1d, self).__init__()

            self.level1 = level
            self.width = width
            self.padding = 2 # pad the domain if input is non-periodic (not used in forward)
            self.dummy_data = dummy_data

            self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x)

            self.conv0 = WaveConv1d(self.width, self.width, self.level1, self.dummy_data)
            self.conv1 = WaveConv1d(self.width, self.width, self.level1, self.dummy_data)
            self.conv2 = WaveConv1d(self.width, self.width, self.level1, self.dummy_data)
            self.conv3 = WaveConv1d(self.width, self.width, self.level1, self.dummy_data)
            self.w0 = nn.Conv1d(self.width, self.width, 1)
            self.w1 = nn.Conv1d(self.width, self.width, 1)
            self.w2 = nn.Conv1d(self.width, self.width, 1)
            self.w3 = nn.Conv1d(self.width, self.width, 1)

            self.fc1 = nn.Linear(self.width, 128)
            self.fc2 = nn.Linear(128, 1)

            # lif1-lif3 act on (batch, width, n_points) tensors, so their
            # per-neuron parameters span the grid points; lif4 acts on the
            # output of fc1 and spans its features.  The construction order
            # is kept so the random initial values are drawn in the same
            # sequence as before.
            n_points = dummy_data.shape[-1]
            self.lif1 = self._make_lif(n_points)
            self.lif2 = self._make_lif(n_points)
            self.lif3 = self._make_lif(n_points)
            self.lif4 = self._make_lif(self.fc1.out_features)

        @staticmethod
        def _make_lif(num_neurons):
            """Create a LIF layer with random, learnable decay and threshold."""
            beta = torch.rand(num_neurons)
            thr = torch.rand(num_neurons)
            return snn.Leaky(beta=beta, threshold=thr, reset_mechanism='zero',
                             learn_beta=True, learn_threshold=True,
                             spike_grad=surrogate.fast_sigmoid())

        def forward(self, x):
            """Run the network over all time steps of a spike-encoded input.

            Args:
                x: tensor of shape (batch, n_steps, n_points, 1).

            Returns:
                Tuple ``(s1, s2, s3, s4, out)``: ``s1``..``s4`` are the
                fractions of neurons that spiked in the four LIF layers,
                averaged over the time steps, and ``out`` has shape
                (batch, n_points, 1).
            """
            inputs = x
            batch, n_spikes, n_points = inputs.shape[0], inputs.shape[1], inputs.shape[2]

            x_spiketime = torch.empty([batch, n_points, self.fc1.out_features, n_spikes],
                                      device=inputs.device)

            # The coordinate grid is identical at every time step.
            grid = self.get_grid((batch, n_points), inputs.device)

            mem1 = self.lif1.init_leaky()
            mem2 = self.lif2.init_leaky()
            mem3 = self.lif3.init_leaky()
            mem4 = self.lif4.init_leaky()

            s1 = 0
            s2 = 0
            s3 = 0
            s4 = 0

            for i in range(0, n_spikes):
                x = torch.cat((inputs[:, i, :, :], grid), dim=-1)

                x = self.fc0(x)

                x = x.permute(0, 2, 1)

                x = self.conv0(x) + self.w0(x)
                spike, mem1 = self.lif1(x, mem1)
                x = spike

                # Fraction of active neurons, normalised by the real size.
                s1 += spike.sum()/spike.numel()

                x = self.conv1(x) + self.w1(x)
                spike, mem2 = self.lif2(x, mem2)
                x = spike

                s2 += spike.sum()/spike.numel()

                x = self.conv2(x) + self.w2(x)
                spike, mem3 = self.lif3(x, mem3)
                x = spike

                s3 += spike.sum()/spike.numel()

                # No spiking layer directly after the last wavelet block.
                x = self.conv3(x) + self.w3(x)

                x = x.permute(0, 2, 1)

                x = self.fc1(x)
                spike, mem4 = self.lif4(x, mem4)
                x = spike

                s4 += spike.sum()/spike.numel()

                x_spiketime[:, :, :, i] = x

            # Decode: average the spikes over time, then project to one value.
            x = torch.mean(x_spiketime, 3)
            x = self.fc2(x)

            return s1/n_spikes, s2/n_spikes, s3/n_spikes, s4/n_spikes, x

        def get_grid(self, shape, device):
            """Return a uniform grid on [0, 1] of shape (shape[0], shape[1], 1)."""
            batchsize, size_x = shape[0], shape[1]
            gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
            gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
            return gridx.to(device)

    # %%

    """ Model configurations """

    # Number of training and testing samples.
    ntrain, ntest = 1000, 100

    # Sub-sampling stride and the resulting grid size (8192 // 8 = 1024).
    sub = 8
    h = 8192 // sub
    s = h

    # Optimisation settings.
    batch_size, learning_rate = 10, 5e-5

    # Training length and learning-rate schedule (StepLR period and factor).
    epochs, step_size, gamma = 500, 100, 0.5

    # Wavelet decomposition level and channel width of the network.
    level, width = 8, 64

    # %%

    """ Read data """

    import scipy.io as sio

    # Load the dataset and keep every `sub`-th grid point of input 'a' and
    # target 'u'.
    dataloader = sio.loadmat('/DATA/SG/WNO/data/burgers_data_R10.mat')

    x_data = torch.tensor(dataloader['a'][:,::sub], dtype=torch.float)
    y_data = torch.tensor(dataloader['u'][:,::sub], dtype=torch.float)

    # First `ntrain` samples for training, last `ntest` samples for testing.
    x_train = x_data[:ntrain,:]
    y_train = y_data[:ntrain,:]
    x_test = x_data[-ntest:,:]
    y_test = y_data[-ntest:,:]

    # Min-max normalisation using statistics of the training inputs only;
    # test values may therefore fall slightly outside [0, 1].
    max_x = torch.max(x_train)
    min_x = torch.min(x_train)

    x_train = (x_train-min_x)/(max_x-min_x)
    x_test = (x_test-min_x)/(max_x-min_x)

    n_spikes = 10

    # Spike-encode whole sets at once: encode() maps (N, s) to
    # (N, s, n_spikes), which is rearranged to (N, n_spikes, s, 1).  The
    # sizes follow the data instead of a hard-coded grid size.
    x_train_re = encode(x_train, n_spikes).permute(0, 2, 1).unsqueeze(-1).contiguous()
    x_test_re = encode(x_test, n_spikes).permute(0, 2, 1).unsqueeze(-1).contiguous()

    x_train = x_train.reshape(ntrain,s,1)
    x_test = x_test.reshape(ntest,s,1)

    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train_re, y_train), batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_re, y_test), batch_size=batch_size, shuffle=False)

    # %%

    """ The model definition """
    # Build the model once to report its parameter count and layer summary.
    model = WNO1d(level, width, x_train[0:1].permute(0, 2, 1)).to(device)
    print(count_params(model))

    from torchinfo import summary
    # Input layout is (batch, time steps, grid points, 1); the sizes come
    # from the configuration instead of being repeated as literals.
    print(summary(model, input_size=(batch_size, n_spikes, s, 1)))

    # %%

    """ Training and testing """

    import pickle

    # A fresh model is created for training.
    model = WNO1d(level, width, x_train[0:1].permute(0, 2, 1)).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # NOTE(review): LpLoss comes from the utilities star import; with
    # size_average=False it presumably returns the sum over the batch, which
    # is why the totals are divided by ntrain / ntest below — confirm.
    myloss = LpLoss(size_average=False)
    for ep in range(epochs):
        model.train()
        t1 = default_timer()
        train_mse = 0
        train_l2 = 0

        itr_tr = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            # Actual size of this batch; the last batch may be smaller than
            # the configured batch_size.
            n_batch = x.shape[0]

            optimizer.zero_grad()
            s1,s2,s3,s4,out = model(x)

            itr_tr += 1

            mse = F.mse_loss(out.view(n_batch, -1), y.view(n_batch, -1), reduction='mean')
            l2 = myloss(out.view(n_batch, -1), y.view(n_batch, -1))

            # Only the Lp loss is optimised; the MSE is tracked for logging.
            new_loss = l2

            new_loss.backward() # new loss

            optimizer.step()
            train_mse += mse.item()
            train_l2 += l2.item()

        scheduler.step()
        model.eval()
        test_l2 = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                n_batch = x.shape[0]

                s1,s2,s3,s4,out = model(x)
                test_l2 += myloss(out.view(n_batch, -1), y.view(n_batch, -1)).item()

        train_mse /= len(train_loader)
        train_l2 /= ntrain
        test_l2 /= ntest

        t2 = default_timer()

        # The spike rates printed here are those of the last evaluated batch.
        if ep%1 == 0:
            print('%5d %10.4f %15.4e %15.4e %15.4e %10.4f %10.4f %10.4f %10.4f'%(ep, t2-t1, train_mse,
                                                                                 train_l2, test_l2, s1,
                                                                                 s2, s3, s4))

    # %%

    import os

    # Save the trained weights, one checkpoint file per RNG value.
    filename = './model/RNG_'+str(rng_itrs)+'_Model_1DBurgers_1000TDS_WNO_SNN_TE_LearnBeTh_LR0p0001.pt'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    torch.save(model.state_dict(), filename)


    # %%

    """ Prediction """

    import pickle

    pred = torch.zeros([y_test.shape[0], y_test.shape[1]])
    myloss = LpLoss(size_average=False)

    filename = './model/RNG_'+str(rng_itrs)+'_Model_1DBurgers_1000TDS_WNO_SNN_TE_LearnBeTh_LR0p0001.pt'

    # Restore the saved weights into a fresh model.  map_location lets a
    # checkpoint written on a GPU be loaded on whichever device is in use.
    loaded_model = WNO1d(level, width, x_train[0:1].permute(0, 2, 1)).to(device)
    loaded_model.load_state_dict(torch.load(filename, map_location=device))
    loaded_model.eval()

    index = 0
    test_e = torch.zeros(y_test.shape[0])
    # One sample per batch so that an error is recorded for every sample.
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test_re, y_test), batch_size=1, shuffle=False)

    # Running sums of the per-layer spike rates over the test samples.
    ts1 = 0
    ts2 = 0
    ts3 = 0
    ts4 = 0

    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)

            s1,s2,s3,s4,out = loaded_model(x)
            pred[index,:] = out.view(-1)

            test_l2 = myloss(out.view(1, -1), y.view(1, -1)).item()
            test_e[index] = test_l2

            index = index + 1
            ts1 += s1
            ts2 += s2
            ts3 += s3
            ts4 += s4

    # Mean test error and mean spike rate per LIF layer, both in percent.
    print('%.4f'%(100*torch.mean(test_e)))
    print('%.4f %.4f %.4f %.4f'%(100*ts1.item()/ntest, 100*ts2.item()/ntest, 100*ts3.item()/ntest, 100*ts4.item()/ntest))

    # %%

    """ Plotting """

    m = pred.numpy()

    # Overlay every 10th test sample: ground truth in red, prediction in black.
    for truth, guess in zip(y_test[::10], m[::10]):
        plt.plot(truth.numpy(), 'r', label='Actual')
        plt.plot(guess, 'k', label='Prediction')
    plt.show()