In [3]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np

In [4]:
batch_size = 128
data_all = np.load('50kdataNegZ.npy')

class QuadrotorDataset(Dataset):
  def __init__(self, data_path = 'testing.csv', time_size=1, state_size = 9, ctrlinput_size = 4, training=True, data_all = data_all):
    self.input_size = state_size + ctrlinput_size
    self.output_size = ctrlinput_size

    self.data = data_all

    np.random.shuffle(data_all)
    train_size = int(len(data_all)*0.8)


    if training:
      self.data = data_all[0:train_size]
    else:
      self.data = data_all[train_size:]

    print(self.data.shape)

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    # print(self.data[idx,:])
    inputs = self.data[idx, 0: self.input_size]
    outputs = self.data[idx, self.input_size:]
    # print(type(inputs))
    # print(type(outputs))
    # print(inputs)
    # print(inputs.dtype)
    # print("h")
    return torch.FloatTensor(inputs), torch.FloatTensor(outputs)
  

train_dataset = QuadrotorDataset(training=True)
test_dataset = QuadrotorDataset(training=False)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

(40000, 17)
(10000, 17)


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import numpy as np
import time
from torch.autograd import Function


In [6]:
class MONSingleFc(nn.Module):
    """ Simple MON linear class, just a single full multiply. """

    def __init__(self, in_dim, out_dim, m=1.0):
        super().__init__()
        self.U = nn.Linear(in_dim, out_dim)
        self.A = nn.Linear(out_dim, out_dim, bias=False)
        self.B = nn.Linear(out_dim, out_dim, bias=False)
        self.m = m

    def x_shape(self, n_batch):
        return (n_batch, self.U.in_features)

    def z_shape(self, n_batch):
        return ((n_batch, self.A.in_features),)

    def forward(self, x, *z):
        return (self.U(x) + self.multiply(*z)[0],)

    def bias(self, x):
        return (self.U(x),)

    def multiply(self, *z):
        ATAz = self.A(z[0]) @ self.A.weight
        z_out = (1 - self.m) * z[0] - ATAz + self.B(z[0]) - z[0] @ self.B.weight
        return (z_out,)

    def multiply_transpose(self, *g):
        ATAg = self.A(g[0]) @ self.A.weight
        g_out = (1 - self.m) * g[0] - ATAg - self.B(g[0]) + g[0] @ self.B.weight
        return (g_out,)

    def init_inverse(self, alpha, beta):
        I = torch.eye(self.A.weight.shape[0], dtype=self.A.weight.dtype,
                      device=self.A.weight.device)
        W = (1 - self.m) * I - self.A.weight.T @ self.A.weight + self.B.weight - self.B.weight.T
        self.Winv = torch.inverse(alpha * I + beta * W)

    def inverse(self, *z):
        return (z[0] @ self.Winv.transpose(0, 1),)

    def inverse_transpose(self, *g):
        return (g[0] @ self.Winv,)


class MONReLU(nn.Module):
    def forward(self, *z):
        return tuple(F.relu(z_) for z_ in z)

    def derivative(self, *z):
        return tuple((z_ > 0).type_as(z[0]) for z_ in z)


In [7]:
class Meter(object):
    """Computes and stores the min, max, avg, and current values"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.max = -float("inf")
        self.min = float("inf")

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.max = max(self.max, val)
        self.min = min(self.min, val)


class SplittingMethodStats(object):
    def __init__(self):
        self.fwd_iters = Meter()
        self.bkwd_iters = Meter()
        self.fwd_time = Meter()
        self.bkwd_time = Meter()

    def reset(self):
        self.fwd_iters.reset()
        self.fwd_time.reset()
        self.bkwd_iters.reset()
        self.bkwd_time.reset()

    def report(self):
        print('Fwd iters: {:.2f}\tFwd Time: {:.4f}\tBkwd Iters: {:.2f}\tBkwd Time: {:.4f}\n'.format(
                self.fwd_iters.avg, self.fwd_time.avg,
                self.bkwd_iters.avg, self.bkwd_time.avg))


class MONPeacemanRachford(nn.Module):

    def __init__(self, linear_module, nonlin_module, alpha=1.0, tol=1e-5, max_iter=50, verbose=False):
        super().__init__()
        self.linear_module = linear_module
        self.nonlin_module = nonlin_module
        self.alpha = alpha
        self.tol = tol
        self.max_iter = max_iter
        self.verbose = verbose
        self.stats = SplittingMethodStats()
        self.save_abs_err = False

    def forward(self, x):
        """ Forward pass of the MON, find an equilibirum with forward-backward splitting"""

        start = time.time()
        # Run the forward pass _without_ tracking gradients
        self.linear_module.init_inverse(1 + self.alpha, -self.alpha)
        with torch.no_grad():
            z = tuple(torch.zeros(s, dtype=x.dtype, device=x.device)
                      for s in self.linear_module.z_shape(x.shape[0]))
            u = tuple(torch.zeros(s, dtype=x.dtype, device=x.device)
                      for s in self.linear_module.z_shape(x.shape[0]))

            n = len(z)
            bias = self.linear_module.bias(x)

            err = 1.0
            it = 0
            errs = []
            while (err > self.tol and it < self.max_iter):
                u_12 = tuple(2 * z[i] - u[i] for i in range(n))
                z_12 = self.linear_module.inverse(*tuple(u_12[i] + self.alpha * bias[i] for i in range(n)))
                u = tuple(2 * z_12[i] - u_12[i] for i in range(n))
                zn = self.nonlin_module(*u)

                if self.save_abs_err:
                    fn = self.nonlin_module(*self.linear_module(x, *zn))
                    err = sum((zn[i] - fn[i]).norm().item() / (zn[i].norm().item()) for i in range(n))
                    errs.append(err)
                else:
                    err = sum((zn[i] - z[i]).norm().item() / (1e-6 + zn[i].norm().item()) for i in range(n))
                z = zn
                it = it + 1

        if self.verbose:
            print("Forward: ", it, err)

        # Run the forward pass one more time, tracking gradients, then backward placeholder
        zn = self.linear_module(x, *z)
        zn = self.nonlin_module(*zn)

        zn = self.Backward.apply(self, *zn)
        self.stats.fwd_iters.update(it)
        self.stats.fwd_time.update(time.time() - start)
        self.errs = errs
        return zn

    class Backward(Function):
        @staticmethod
        def forward(ctx, splitter, *z):
            ctx.splitter = splitter
            ctx.save_for_backward(*z)
            return z

        @staticmethod
        def backward(ctx, *g):
            start = time.time()
            sp = ctx.splitter
            n = len(g)
            z = ctx.saved_tensors
            j = sp.nonlin_module.derivative(*z)
            I = [j[i] == 0 for i in range(n)]
            d = [(1 - j[i]) / j[i] for i in range(n)]
            v = tuple(j[i] * g[i] for i in range(n))

            z = tuple(torch.zeros(s, dtype=g[0].dtype, device=g[0].device)
                      for s in sp.linear_module.z_shape(g[0].shape[0]))
            u = tuple(torch.zeros(s, dtype=g[0].dtype, device=g[0].device)
                      for s in sp.linear_module.z_shape(g[0].shape[0]))

            err = 1.0
            errs=[]
            it = 0
            while (err >sp.tol and it < sp.max_iter):
                u_12 = tuple(2 * z[i] - u[i] for i in range(n))
                z_12 = sp.linear_module.inverse_transpose(*u_12)
                u = tuple(2 * z_12[i] - u_12[i] for i in range(n))
                zn = tuple((u[i] + sp.alpha * (1 + d[i]) * v[i]) / (1 + sp.alpha * d[i]) for i in range(n))
                for i in range(n):
                    zn[i][I[i]] = v[i][I[i]]

                err = sum((zn[i] - z[i]).norm().item() / (1e-6 + zn[i].norm().item()) for i in range(n))
                errs.append(err)
                z = zn
                it = it + 1

            if sp.verbose:
                print("Backward: ", it, err)

            dg = sp.linear_module.multiply_transpose(*zn)
            dg = tuple(g[i] + dg[i] for i in range(n))

            sp.stats.bkwd_iters.update(it)
            sp.stats.bkwd_time.update(time.time() - start)
            sp.errs = errs
            return (None,) + dg

In [8]:
def expand_args(defaults, kwargs):
    d = defaults.copy()
    for k, v in kwargs.items():
        d[k] = v
    return d

MON_DEFAULTS = {
    'alpha': 1.0,
    'tol': 1e-5,
    'max_iter': 50
}

class SingleFcNet(nn.Module):

    def __init__(self, splittingMethod, in_dim=13, intermediate_dim=100, out_dim=4, m=0.1, **kwargs):
        super().__init__()
        linear_module = MONSingleFc(in_dim, intermediate_dim, m=m)
        nonlin_module = MONReLU()
        self.mon = splittingMethod(linear_module, nonlin_module, **expand_args(MON_DEFAULTS, kwargs))
        self.Wout = nn.Linear(intermediate_dim, out_dim)

    def forward(self, x):
        x = x.view(x.shape[0], -1)
        z = self.mon(x)
        return self.Wout(z[-1])

In [16]:
loaded_model = SingleFcNet(MONPeacemanRachford,
                        in_dim = 13,
                        out_dim = 4,
                        alpha=1.0,
                        max_iter=300,
                        tol=1e-2,
                        m=1.0)

In [17]:
PATH = "DEQ_model50k.pt"
loaded_model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [20]:
i = 0
bad = 0
num_trials = 1000
for data in train_dataloader:
    input_data, output_data = data
    # print(input_data[0])
    # print(output_data[0])
    output_pred = loaded_model(input_data)
    # print(output_pred[0])
    if ((output_pred[0] - output_data[0]) < torch.tensor([.08]*4)).all():
        pass
        # print("Good Prediction!")
    else:
        bad +=1
        # print("Bad Prediction")

    i += 1
    if i >= 100:
        print(f"error pct: {(bad/num_trials)*100}%")
        break

error pct: 0.8%


In [21]:
PATH = "DEQModel_New.pt"
loaded_model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [24]:
i = 0
bad = 0
num_trials = 1000
for data in train_dataloader:
    input_data, output_data = data
    # print(input_data[0])
    # print(output_data[0])
    output_pred = loaded_model(input_data)
    # print(output_pred[0])
    if ((output_pred[0] - output_data[0]) < torch.tensor([.08]*4)).all():
        pass
        # print("Good Prediction!")
    else:
        bad +=1
        # print("Bad Prediction")

    i += 1
    if i >= 100:
        print(f"error pct: {(bad/num_trials)*100}%")
        break

error pct: 2.1%
