In [45]:
import scipy.io as sio
import matplotlib.pyplot as plt
import os
import numpy as np
import logging
import argparse
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch
import scipy
from timeit import default_timer

import Wavelets


In [2]:
parser = argparse.ArgumentParser()

parser.add_argument("-data_fp", default="/home/owen/projects/fourier_neural_operator/data/2021-03-17_training_Burgers_data_GRF1.mat")
parser.add_argument("-plots_dir", default="/home/owen/projects/fourier_neural_operator/experiments/03_use_Haar_DWT/plots")
parser.add_argument("-preds_dir", default="/home/owen/projects/fourier_neural_operator/experiments/03_use_Haar_DWT/preds")
# parser.add_argument("-results_df", default="/home/owen/projects/fourier_neural_operator/experiments/03_use_Haar_DWT/experiment_results.txt")

parser.add_argument('--subsample_rate', type=int, default=2**3)
parser.add_argument('--grid_size', type=int, default=2**13)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--freq_modes', type=int, default=16)
parser.add_argument('--l1_lambda', type=float, default=0.)
args = parser.parse_args(args=[])

fmt = "%(asctime)s: %(levelname)s - %(message)s"
time_fmt = '%Y-%m-%d %H:%M:%S'
logging.basicConfig(level=logging.INFO,
                    format=fmt,
                    datefmt=time_fmt)

In [3]:
class LpLoss(object):
    def __init__(self, d=2, p=2, size_average=True, reduction=True):
        super(LpLoss, self).__init__()

        #Dimension and Lp-norm type are postive
        assert d > 0 and p > 0

        self.d = d
        self.p = p
        self.reduction = reduction
        self.size_average = size_average

    def abs(self, x, y):
        num_examples = x.size()[0]

        #Assume uniform mesh
        h = 1.0 / (x.size()[1] - 1.0)

        all_norms = (h**(self.d/self.p))*torch.norm(x.view(num_examples,-1) - y.view(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(all_norms)
            else:
                return torch.sum(all_norms)

        return all_norms

    def rel(self, x, y):
        num_examples = x.size()[0]

        diff_norms = torch.norm(x.reshape(num_examples,-1) - y.reshape(num_examples,-1), self.p, 1)
        y_norms = torch.norm(y.reshape(num_examples,-1), self.p, 1)

        if self.reduction:
            if self.size_average:
                return torch.mean(diff_norms/y_norms)
            else:
                return torch.sum(diff_norms/y_norms)

        return diff_norms/y_norms

    def __call__(self, x, y):
        return self.rel(x, y)


# reading data
class MatReader(object):
    def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True):
        super(MatReader, self).__init__()

        self.to_torch = to_torch
        self.to_cuda = to_cuda
        self.to_float = to_float

        self.file_path = file_path

        self.data = None
        self.old_mat = None
        self._load_file()

    def _load_file(self):
        self.data = scipy.io.loadmat(self.file_path)
        self.old_mat = True

    def load_file(self, file_path):
        self.file_path = file_path
        self._load_file()

    def read_field(self, field):
        x = self.data[field]

        if not self.old_mat:
            x = x[()]
            x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1))

        if self.to_float:
            x = x.astype(np.float32)

        if self.to_torch:
            x = torch.from_numpy(x)

            if self.to_cuda:
                x = x.cuda()

        return x

    def set_cuda(self, to_cuda):
        self.to_cuda = to_cuda

    def set_torch(self, to_torch):
        self.to_torch = to_torch

    def set_float(self, to_float):
        self.to_float = to_float


In [4]:


class SimpleBlock1d(nn.Module):
    def __init__(self, modes, width, level, keep):
        super(SimpleBlock1d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .

        input: the solution of the initial condition and location (a(x), x)
        input shape: (batchsize, x=s, c=2)
        output: the solution of a later timestep
        output shape: (batchsize, x=s, c=1)
        """

        self.modes1 = modes
        self.width = width
        self.level = level
        self.keep = keep
        self.fc0 = nn.Linear(2, self.width) # input channel is 2: (a(x), x)

        self.wave0 = WaveletBlock1d(self.level, self.width, self.keep)
        self.wave1 = WaveletBlock1d(self.level, self.width, self.keep)
        self.wave2 = WaveletBlock1d(self.level, self.width, self.keep)
        self.wave3 = WaveletBlock1d(self.level, self.width, self.keep)
        
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)


        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):

        x = self.fc0(x)
        x = x.permute(0, 2, 1)

        x1 = self.wave0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.relu(x1)

        x1 = self.wave1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.relu(x1)

        x1 = self.wave2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.relu(x1)

        x1 = self.wave3(x)
        x2 = self.w3(x)
        x = x1 + x2

        x = x1.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

class Net1d(nn.Module):
    def __init__(self, modes, width, level, keep):
        super(Net1d, self).__init__()

        """
        A wrapper function
        """

        self.conv1 = SimpleBlock1d(modes, width, level, keep)


    def forward(self, x):
        x = self.conv1(x)
        return x.squeeze()

    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))

        return c


In [29]:
def main(args):

    # Figure out CUDA
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info("Running computation on device: {}".format(device))

    ################################################################
    #  configurations
    ################################################################
    ntrain = 100
    ntest = 10

    # sub = 2**3 #subsampling rate
    # h = 2**13 // sub #total grid size divided by the subsampling rate
    sub = args.subsample_rate
    h = args.grid_size // sub
    s = h

    batch_size = 10
    learning_rate = 0.001

    epochs = 4
    step_size = 100
    gamma = 0.5

    modes = 12
    width = 4
    
    level = int(Wavelets.HaarDWT().max_dwt_level(s))
    logging.info("Using DWT level {}".format(level))
    
    keep = int(s)

    # results_dd stores trial results and metadata. It will be printed as
    # a single line to a text file at args.results_fp
    results_dd = {'ntrain': ntrain,
                    'ntest': ntest,
                    'sub': sub,
                    'effective_grid_size': s,
                    'epochs': epochs,
                    'l1_lambda': args.l1_lambda,
                    'modes': modes,
                    'width': width}

    ################################################################
    # read data
    ################################################################

    # Data is of the shape (number of samples, grid size)
    dataloader = MatReader(args.data_fp)
    x_data = dataloader.read_field('a')[:,::sub]
    y_data = dataloader.read_field('u')[:,::sub]

    x_train = x_data[:ntrain,:]
    y_train = y_data[:ntrain,:]
    x_test = x_data[-ntest:,:]
    y_test = y_data[-ntest:,:]
    
    del x_data
    del y_data

    # cat the locations information
    grid = np.linspace(0, 2*np.pi, s).reshape(1, s, 1)
    grid = torch.tensor(grid, dtype=torch.float)
    x_train = torch.cat([x_train.reshape(ntrain,s,1), grid.repeat(ntrain,1,1)], dim=2)
    x_test = torch.cat([x_test.reshape(ntest,s,1), grid.repeat(ntest,1,1)], dim=2)

    train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size, shuffle=False)

    # model
    model = Net1d(modes, width, level, keep).to(device)
#     logging.info("Number of model parameters: %i" % model.count_params())


    ################################################################
    # training and evaluation
    ################################################################
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    myloss = LpLoss(size_average=False)
    for ep in range(epochs):
        model.train()
        t1 = default_timer()
        train_mse = 0
        train_l2 = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            # print("TRAINING X SHAPE: {}".format(x.size()))
            # print("TRAINING Y SHAPE: {}".format(y.size()))

            optimizer.zero_grad()
            out = model(x)

            mse = F.mse_loss(out, y, reduction='mean')
            # mse.backward()
            l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))

            # If we specify a l1 regularization on the weights, we add up the
            # l1 norms of all of the different weights and add this to the loss.
            l1_reg = torch.tensor(0.).to(device)
            for param in model.parameters():
                l1_reg += torch.norm(param, p=1)

            # Our loss function is Lp loss + regulariztion term
            loss = l2 + args.l1_lambda * l1_reg

            # backprop the gradients, then perform one step of the algorithm.
            loss.backward()
            optimizer.step()


            train_mse += mse.item()
            train_l2 += l2.item()

        scheduler.step()
        model.eval()
        test_l2 = 0.0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)

                out = model(x)
                test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()

        train_mse /= len(train_loader)
        train_l2 /= ntrain
        test_l2 /= ntest

        t2 = default_timer()
        logging.info("Epoch: {}, time: {:.2f}, train_mse: {:.4f}, train_l2: {:.4f}, test_l2: {:.4f}, weights_l1: {:.4f}".format(ep, t2-t1, train_mse, train_l2, test_l2, l1_reg))

    torch.save(model, args.model_fp)
    logging.info("Saved model at {}".format(args.model_fp))

    # Compute training errors:
    train_pred = torch.zeros(y_train.shape)
    train_y = torch.zeros(y_train.shape)
    idx = 0
    with torch.no_grad():
        for x, y in train_loader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            # print(out.size())
            # print(train_pred[idx].size())

            train_pred[batch_size*idx: batch_size*(idx + 1)] = out
            train_y[batch_size*idx: batch_size*(idx + 1)] = y

            idx += 1

    train_pred = train_pred.cpu().numpy()
    train_y = train_y.cpu().numpy()

    results_dd['train_l2_errors'], results_dd['train_l2_normalized_errors'] = find_normalized_errors(train_pred, train_y, 2)
    results_dd['train_linf_errors'], results_dd['train_linf_normalized_errors'] = find_normalized_errors(train_pred, train_y, np.inf)


    pred = torch.zeros(y_test.shape)
    index = 0
    test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
    with torch.no_grad():
        for x, y in test_loader:
            test_l2 = 0
            x, y = x.to(device), y.to(device)

            out = model(x)
            pred[index] = out

            test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
            logging.info("Test index {}, test_l2: {:.4f}".format(index, test_l2))
            index = index + 1

    scipy.io.savemat(args.preds_fp, mdict={'pred': pred.cpu().numpy()})



    # I'm doing aggregate error reporting in numpy
    pred_test = pred.cpu().numpy()
    y_test = y_test.cpu().numpy()

    results_dd['test_l2_errors'], results_dd['test_l2_normalized_errors'] = find_normalized_errors(pred_test, y_test, 2)
    results_dd['test_linf_errors'], results_dd['test_linf_normalized_errors'] = find_normalized_errors(pred_test, y_test, np.inf)
    if args.results_fp is not None:
        write_result_to_file(args.results_fp, **results_dd)
        logging.info("Wrote results to {}".format(args.results_fp))
    else:
        logging.info("No results_fp specified, so here are the results")
        logging.info(results_dd)

    logging.info("Finished")


In [52]:
class IHaarDWT(nn.Module):
    """Short summary.

    Attributes
    ----------
    c_filter : pytorch Tensor
        Filter to produce c_i coefficients in DWT. Sometimes referred to as h_0.
    d_filter : type
        Filter to produce d_i coefficients in DWT. Sometimes referred to as h_1.
    padder : pytorch nn Module
        Appropriately pads odd-length arrays
    level : int
        Level of the DWT
    """
    def __init__(self, level=1):
        super().__init__()
        self.c_filter = torch.tensor(np.divide(np.array([1., 1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,2,1))
        self.d_filter = torch.tensor(np.divide(np.array([1., -1.]),
                                                np.sqrt(2)), dtype=torch.float).reshape((1,2,1))
        self.filter_len = 2
        self.level = level
        logging.info("Loading IDWT module with level: {}".format(self.level))

    def max_dwt_level(self, data_len):
        """
        This is a function to compute the maximum level DWT that is possible
        on a 1D input of length data_len. This formula is copied from
        PyWavelets: https://tinyurl.com/y9u7yvbw
        """
        return np.floor(np.log2(data_len / (self.filter_len - 1)))

    def unfilter(self, x):
        logging.info("Unfiltering with input: {}".format(x.size()))
        xlen = x.size()[-1]
        batch_num = x.size()[0]
        x = x.reshape((batch_num, 2, int(xlen / 2))).permute(0,2,1)
        logging.info("After reshaping and permuting, shape is {}".format(x.size()))
        c_out = F.conv1d(x, self.c_filter, stride=1)
        logging.info("C out shape: {}".format(c_out.size()))
        c_out = torch.mm(x, self.c_filter)
        d_out = torch.mm(x, self.d_filter)

        out = torch.hstack((c_out, d_out)).flatten()
        return out

    def forward(self, x, verbose=False):
        # x has shape (1,1,xlen). We need to transform it into (2,xlen / 2)
        xlen = x.size()[-1]

        if xlen % 2:
            raise ValueError("Expected even-length input but recieved length {}".format(xlen))
        for l in range(self.level-1, -1, -1):
            level_i_arr_idx = int(xlen / (2 ** l))
            x_in = x[:,:,:level_i_arr_idx]
            x[:,:,:level_i_arr_idx] = self.unfilter(x_in)
        return x


In [53]:
class WaveletBlock1d(nn.Module):
    def __init__(self, level, width, keep):
        super(WaveletBlock1d, self).__init__()
        self.level = level
        self.width = width
        self.keep = keep
        
        self.DWT = Wavelets.HaarDWT(level=level)
        self.IDWT = IHaarDWT(level=level)
        self.linear_layer = nn.Conv1d(self.width, self.width, 1)
        self.weights = nn.Parameter(torch.tensor())
        
    def forward(self, x):
        logging.info("WaveletBlock input size: {}".format(x.size()))
        # Do DWT row-by-row
        z = torch.zeros(x.size())
        xlen = x.size()[-1]
        for i, row in enumerate(x.split(1, dim=1)): #Check the axis
            out = self.DWT(row)
            z[:,i,:self.keep] = out.view(-1, self.keep)
        # ok so now z has the DWT coefficients
        logging.info("Z size: {}".format(z.size()))
        z = self.linear_layer(z)
        
        out = torch.zeros(z.size())
        for i, row in enumerate(z.split(1, dim=1)):
            logging.info("IDWT Input size: {}".format(row.size()))
            idwt_out = self.IDWT(row)
            logging.info("IDWT Output size: {}".format(idwt_out.size()))
            out[:, i, :] = idwt_out.view(-1, xlen)
        return out

In [54]:
import Wavelets
main(args)

2021-04-07 16:35:23: INFO - Running computation on device: cpu
2021-04-07 16:35:23: INFO - Using DWT level 10
2021-04-07 16:35:29: INFO - Loading IDWT module with level: 10
2021-04-07 16:35:29: INFO - Loading IDWT module with level: 10
2021-04-07 16:35:29: INFO - Loading IDWT module with level: 10
2021-04-07 16:35:29: INFO - Loading IDWT module with level: 10
2021-04-07 16:35:30: INFO - WaveletBlock input size: torch.Size([10, 4, 1024])
2021-04-07 16:35:30: INFO - Z size: torch.Size([10, 4, 1024])
2021-04-07 16:35:30: INFO - IDWT Input size: torch.Size([10, 1, 1024])
2021-04-07 16:35:30: INFO - Unfiltering with input: torch.Size([10, 1, 2])
2021-04-07 16:35:30: INFO - After reshaping and permuting, shape is torch.Size([10, 1, 2])


RuntimeError: Given groups=1, weight of size [1, 2, 1], expected input[10, 1, 2] to have 2 channels, but got 1 channels instead