In [1]:
import numpy as np
import torch
import torch.nn as nn
import os
from dataclasses import dataclass
from neural_verification import *
import matplotlib.pyplot as plt
import itertools
import copy

# This GeneralRNN has been defined in neural_verification, but I modified it:
# (1) comment out some assertions which throw out errors but aren't really errors?
# (2) Not only return hidden (last hidden states), but also hidden2 (second-to-last hidden states)
@dataclass
class GeneralRNNConfig:
    input_dim: int = 2
    output_dim: int = 1
    hidden_dim: int = 40
    hidden_mlp_depth: int = 2 # this would be 1 hidden layer
    hidden_mlp_width: int = 100
    output_mlp_depth: int = 2 # this would be 1 hidden layer
    output_mlp_width: int = 100
    activation: type = nn.SiLU

class GeneralRNN(nn.Module):
    def __init__(self, config, device='cpu'):
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        hmlp_config = MLPConfig(
            config.hidden_dim + config.input_dim, 
            config.hidden_dim, 
            config.hidden_mlp_width, 
            config.hidden_mlp_depth, 
            config.activation
        )
        self.hmlp = MLP(hmlp_config).to(device)
        ymlp_config = MLPConfig(
            config.hidden_dim, 
            config.output_dim, 
            config.output_mlp_width, 
            config.output_mlp_depth, 
            config.activation
        )
        self.ymlp = MLP(ymlp_config).to(device)
        self.device = device

    def forward(self, x, h=None):
        """The transition is given by:
            h_t = f([h_{t-1}, x_t])
            y_t = g(h_t)
        where f and g are MLPs.

        This function takes in the input and the hidden state and 
        returns an output and a hidden state.
        """
        # x shape: (batch_size, input_dim)
        # h shape: (batch_size, hidden_dim)
        if h is None:
            h = torch.zeros(x.size(0), self.hidden_dim).to(self.device)
        else:
            assert h.size(0) == x.size(0)
            assert h.size(1) == self.hidden_dim
            #assert h.device == self.device
        assert x.size(1) == self.config.input_dim
        #assert x.device == self.device
        hx = torch.cat((h, x), dim=1)
        h = self.hmlp(hx)
        y = self.ymlp(h)
        return y, h
    
    def forward_sequence(self, x):
        """This function takes in a sequence of inputs and returns a sequence of outputs
        as well as the final hidden state."""
        # x shape: (batch_size, sequence_length, input_dim)
        batch_size = x.size(0)
        seq_length = x.size(1)
        hidden = torch.zeros(batch_size, self.hidden_dim).to(self.device)
        assert x.size(2) == self.config.input_dim
        #assert x.device == self.device
        outs = []
        for i in range(seq_length):
            out, hidden = self.forward(x[:,i,:], hidden)
            if i == seq_length-2:
                hidden2 = hidden.clone()
            outs.append(out)
        # out shape: (batch_size, sequence_length, output_dim)
        return torch.stack(outs).permute(1,0,2), hidden, hidden2

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

data = torch.load("../tasks/rnn_identity_numerical/data.pt")
config = torch.load("../tasks/rnn_identity_numerical/model_config.pt", map_location=torch.device('cpu'))
model = GeneralRNN(config,)
model.load_state_dict(torch.load("../tasks/rnn_identity_numerical/model_perfect.pt",map_location=torch.device('cpu')))

<All keys matched successfully>

In [2]:
hidden = model.forward_sequence(data[0].unsqueeze(dim=2))[1].cpu().detach().numpy()
hidden_last = model.forward_sequence(data[0].unsqueeze(dim=2))[2].cpu().detach().numpy()
D = hidden.shape[1]
thres = 1e-1

In [3]:
outs = model.forward_sequence(data[0].unsqueeze(dim=2))[0].cpu().detach().numpy()


In [4]:
outs[0]

array([[69.      ],
       [46.      ],
       [58.      ],
       [89.      ],
       [97.      ],
       [73.      ],
       [88.00001 ],
       [52.      ],
       [42.999996],
       [81.00001 ]], dtype=float32)

In [5]:
data[0].unsqueeze(dim=2)[0]

tensor([[69],
        [46],
        [58],
        [89],
        [97],
        [73],
        [88],
        [52],
        [43],
        [81]], dtype=torch.int8)

In [6]:
hidden[0]

array([-64.661835], dtype=float32)

In [7]:
hidden_last[0]

array([-34.50504], dtype=float32)

In [8]:
first_how_many=10

In [9]:
reg = data[0][:first_how_many]

In [10]:
# from RNN_to_DFA_vedang import NFA
# weights = model.state_dict()

# nfa = NFA(10, weights)
# from RNN_to_DFA_vedang import NFA
# weights = model.state_dict()

# nfa = NFA(10, weights)

In [11]:
model

GeneralRNN(
  (hmlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=2, out_features=1, bias=True)
    )
  )
  (ymlp): MLP(
    (mlp): Sequential(
      (0): Linear(in_features=1, out_features=1, bias=True)
    )
  )
)

In [12]:
# for k,v in nfa.weights.items():
#     print((k,v.shape))

In [13]:
"""
('in2hidden.weight', torch.Size([2, 3]))
('hidden2hidden.weight', torch.Size([2, 2]))
('hidden2hidden.bias', torch.Size([2]))
('hidden2out.weight', torch.Size([3, 2]))
('hidden2out.bias', torch.Size([3]))
"""

"\n('in2hidden.weight', torch.Size([2, 3]))\n('hidden2hidden.weight', torch.Size([2, 2]))\n('hidden2hidden.bias', torch.Size([2]))\n('hidden2out.weight', torch.Size([3, 2]))\n('hidden2out.bias', torch.Size([3]))\n"

In [14]:
# from RNN_to_DFA_vedang import NFA
# weights = model.state_dict()

# nfa = NFA(10, weights)
# points = list(np.concatenate([hidden[:10], hidden_last[:10]], axis=1))


# nfa.accommodate_points(points, 0.8)
# nfa.find_transitions()

# points = np.stack(points, axis=1)

In [32]:
def subset(arr, criterion=None, batch_size = 32):
    '''if arr contains too many samples, computation can be expensive.
    Choosing a subset of samples if usually enough for lattice detection.'''
    if criterion != None:
        arr = arr[np.where(criterion(arr))[0]]
    batch_id = np.random.choice(arr.shape[0], batch_size)
    arr = arr[batch_id]
    return arr  


# GCD

def get_vol_arr(x):
    '''Given n points in R^D, return all possible volumes of D+1 points'''
    num = x.shape[0]
    groups = np.array(list(itertools.combinations(x, D+1)))
    groups = groups[:,1:,:] - groups[:,[0],:]
    vols = np.abs(np.linalg.det(np.array(groups)))
    return np.array(vols)


def GCD_arr(arr):
    '''Find the greatest common divisor (GCD) for an array'''
    vol_arr = get_vol_arr(arr)
    while True:
        va = arr[[0]]; vb = arr[[1]]; v3 = arr[2:D+1]
        a = np.linalg.det(np.concatenate([va, v3], axis=0))
        b = np.linalg.det(np.concatenate([vb, v3], axis=0))
        
        # should be more robust? Threshold?
        if b == 0 or a == 0:
            if b == 0:
                arr = np.delete(arr, 1, 0)
            if a == 0:
                arr = np.delete(arr, 0, 0)
            continue
        gcd, vbp = GCD_2num_v(a, b, va, vb)
        flag = check_integer_arr(vol_arr/gcd)
        if flag == True:
            break
        else:
            arr = arr[1:]
            arr[0] = vbp
    return np.concatenate([vbp, v3], axis=0)

def GCD_2num_v(a, b, va, vb):
    '''Find the greatest common divisor (GCD) for two number a, b;
    and apply the same rule to two vectors va and vb'''
    
    while True:
        temp = a
        a = b
        b = temp

        temp = va
        va = vb
        vb = temp
        
        proj = np.round(a/b)
        a = a - proj * b
        va = va - proj * vb
        if np.abs(a) < thres:
            break
    return np.abs(b), vb

def check_integer_arr(arr):
    non_integer = np.abs(arr - np.round(arr)) > thres
    all_integer = np.sum(non_integer) == 0
    return all_integer

def normalize_basis(basis):
    basis = copy.deepcopy(basis)
    ii = 0
    while True and ii < 5:
        projs = []
        for i in range(D):
            proj = np.round(np.sum(basis*basis[[i]],axis=1)/np.linalg.norm(basis[i])**2)
            proj[i] = 0
            basis -= proj[:,np.newaxis] * basis[[i]]
            projs.append(proj)
        projs = np.array(projs)
        if np.sum(np.abs(projs)) == 0:
            break
        ii += 1
    basis *= (-1)**(np.sum(basis, axis=1)<0)[:,np.newaxis]
    return basis

In [33]:

# shift lattice such that (0,0) is on lattice
hidden_batch = subset(hidden)
shift_id = np.argmin(np.sum(hidden_batch, axis=1))
shift = copy.deepcopy(hidden_batch[[shift_id]])
hidden_batch_shift = hidden_batch - shift
basis = GCD_arr(hidden_batch_shift)
basis_comp = normalize_basis(basis)

TypeError: sum() received an invalid combination of arguments - got (out=NoneType, axis=int, ), but expected one of:
 * (*, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: out, axis
 * (tuple of ints dim, bool keepdim, *, torch.dtype dtype)
 * (tuple of names dim, bool keepdim, *, torch.dtype dtype)


In [None]:
h_int_lattice = np.matmul(hidden - shift, np.linalg.inv(basis_comp))
h_last_int_lattice = np.matmul(hidden_last - shift, np.linalg.inv(basis_comp))
input_lattice = data[0][:,-1].detach().numpy()
output_lattice = data[1][:,-1].detach().numpy()

In [None]:
# (h, x) => h
# (h_last_lattice, input_lattice) => h_lattice

hx = np.concatenate([np.round(h_last_int_lattice), np.round(input_lattice)[:,np.newaxis]], axis=1)
h_next = np.round(h_int_lattice)

from sklearn.linear_model import LinearRegression
reg1 = LinearRegression().fit(hx, h_next)
reg1.score(hx, h_next)

In [34]:
reg1.coef_, reg1.intercept_

(array([[ 3.5641033e-08, -9.9999571e-01]], dtype=float32),
 array([95.99774], dtype=float32))

In [35]:
reg1.predict([[1,1]])

array([[94.99774603]])

$h_{t} = x_t + 92$

In [36]:
# h => y
# h_lattice => output_lattice

from sklearn.linear_model import LinearRegression
reg2 = LinearRegression().fit(np.round(h_int_lattice), output_lattice)
reg2.score(np.round(h_int_lattice), output_lattice)

0.9999999998936863

In [37]:
reg2.coef_, reg2.intercept_

(array([-0.9999897], dtype=float32), 95.99952)

$y_t = - x_t + 92$

In [38]:
hstart = np.zeros_like(hidden)

In [39]:
testx = data[0].unsqueeze(dim=2)

In [40]:
batch_size = testx.size(0)
seq_length = testx.size(1)

#hack
hidden_dim = 1
hidden = torch.zeros(batch_size, hidden_dim)

outs = []
for i in range(seq_length):
    hx = torch.cat((hidden, testx[:,i,:]), dim=1)
    out, hidden = torch.tensor(reg2.predict(hidden)), torch.tensor(reg1.predict(hx))
    out=out.unsqueeze(1)
    outs.append(out)
outs, hidden = torch.stack(outs).permute(1,0,2), hidden

In [44]:
outs = outs.squeeze(2)

In [45]:
outs.shape, hidden.shape

(torch.Size([900000, 10]), torch.Size([900000, 1]))

In [47]:
data[1].shape

torch.Size([900000, 10])

In [48]:
outs[0]

tensor([95.9995, 69.0018, 46.0021, 58.0019, 89.0015, 97.0014, 73.0017, 88.0015,
        52.0020, 43.0021])

In [49]:
data[0][0]

tensor([69, 46, 58, 89, 97, 73, 88, 52, 43, 81], dtype=torch.int8)

In [51]:
data[1][0]

tensor([69, 46, 58, 89, 97, 73, 88, 52, 43, 81], dtype=torch.int8)

In [None]:
np.unique