In [2]:
import numpy as np
import torch
import torch.nn as nn
import os
from dataclasses import dataclass
from neural_verification import *
import matplotlib.pyplot as plt
import itertools
import copy
from sklearn.linear_model import LinearRegression


# Seed both NumPy and PyTorch so the random subset drawn later in the notebook
# (np.random.choice) is reproducible across runs.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ModifiedRNN(GeneralRNN):
    """GeneralRNN whose ``forward_sequence`` additionally returns the hidden
    state that was fed into the final time step."""

    def forward_sequence(self, x):
        """Run the RNN over a whole sequence of inputs.

        Args:
            x: input tensor of shape (batch_size, sequence_length, input_dim).

        Returns:
            A tuple ``(outputs, hidden, hidden2)`` where
            ``outputs`` stacks the per-step outputs as
            (batch_size, sequence_length, output_dim),
            ``hidden`` is the hidden state after the last step, and
            ``hidden2`` is the hidden state after the second-to-last step,
            i.e. the state that entered the last step. For a sequence of
            length 1 this is the all-zero initial state.
        """
        # x shape: (batch_size, sequence_length, input_dim)
        batch_size = x.size(0)
        seq_length = x.size(1)
        hidden = torch.zeros(batch_size, self.hidden_dim).to(self.device)
        assert x.size(2) == self.config.input_dim
        # Start from the initial state so `hidden2` is always bound, even when
        # the sequence is too short for the loop below to reach step
        # seq_length-2.
        hidden2 = hidden.clone()
        outs = []
        for i in range(seq_length):
            out, hidden = self.forward(x[:, i, :], hidden)
            if i == seq_length - 2:
                # Snapshot the state right before the final step is applied.
                hidden2 = hidden.clone()
            outs.append(out)
        # stacked: (sequence_length, batch_size, output_dim) -> batch first
        return torch.stack(outs).permute(1, 0, 2), hidden, hidden2
        


# Task dataset; data[0] is fed to the model and data[1] is used as the
# regression target further down the notebook.
data = torch.load(
    "../tasks/rnn_identity_numerical/data.pt",
    map_location=device,
)

# Rebuild the network from its saved configuration, then restore the weights.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files from
# a trusted source.
config = torch.load(
    "./rnn_models/rnn_identity_numerical/model_config.pt",
    map_location=device,
)
model = ModifiedRNN(config, device=device)
trained_weights = torch.load(
    "./rnn_models/rnn_identity_numerical/model_perfect.pt",
    map_location=device,
)
model.load_state_dict(trained_weights)

<All keys matched successfully>

In [3]:
# Run the network once over every sequence. Gradients are not needed for the
# analysis, so skip building the autograd graph.
with torch.no_grad():
    rnn_states = model.forward_sequence(data[0].unsqueeze(dim=2))
# hidden: state after the last step; hidden_last: state that entered the last
# step (second and third return values of forward_sequence).
hidden = rnn_states[1].cpu().detach().numpy()
hidden_last = rnn_states[2].cpu().detach().numpy()
D = hidden.shape[1]  # hidden-state dimension, used by the lattice helpers below
thres = 1e-1  # tolerance for the "is (close to) an integer / zero" checks below

In [4]:
def subset(arr, criterion=None, batch_size = 32):
    '''Draw a random batch of rows from `arr`.

    If arr contains too many samples, computation can be expensive.
    Choosing a subset of samples is usually enough for lattice detection.

    Args:
        arr: array whose first axis indexes samples.
        criterion: optional callable; rows are kept where ``criterion(arr)``
            is truthy (its first index array from ``np.where`` is used).
        batch_size: number of rows to draw.

    Returns:
        ``batch_size`` rows of (the filtered) `arr`. Rows are drawn with
        replacement via ``np.random.choice``, so duplicates are possible and
        the result depends on NumPy's global random state.
    '''
    # Identity comparison: `!= None` would dispatch to a user-defined __ne__.
    if criterion is not None:
        arr = arr[np.where(criterion(arr))[0]]
    batch_id = np.random.choice(arr.shape[0], batch_size)
    return arr[batch_id]


# GCD

def get_vol_arr(x):
    '''Given n points in R^d, return the "volume" spanned by every choice of
    d+1 points.

    Args:
        x: array of shape (n, d); each row is a point.

    Returns:
        1-D array with one entry per (d+1)-point combination, holding
        |det| of the d edge vectors from the first point of the combination
        to the others (the simplex volume scaled by d!).
    '''
    # The determinant below needs square (d x d) edge matrices, so the number
    # of points per group is fixed by the point dimension itself.
    dim = x.shape[1]
    simplices = np.array(list(itertools.combinations(x, dim + 1)))
    # Edge vectors relative to the first vertex of each simplex.
    edges = simplices[:, 1:, :] - simplices[:, [0], :]
    return np.abs(np.linalg.det(edges))


def GCD_arr(arr):
    '''Find a "greatest common divisor" basis for an array of lattice points.

    Repeatedly runs a Euclid-style reduction (GCD_2num_v) on the first two
    rows, using the determinants they form with rows 2..dim as the numbers
    being reduced, until every simplex volume from get_vol_arr is an integer
    multiple (within the tolerance of check_integer_arr) of the resulting gcd.

    Args:
        arr: array of shape (n, dim); each row is a point. The caller's array
            is not modified.

    Returns:
        Array formed by stacking the reduced row on top of rows 2..dim of the
        working array at the time the integer check passed.

    NOTE(review): if the rows are exhausted before the check passes, the
    row indexing below raises IndexError.
    '''
    # Work on a private copy: `arr = arr[1:]` below is a view, so the following
    # `arr[0] = vbp` would otherwise write into the caller's array.
    arr = np.array(arr, copy=True)
    # The determinants below need dim x dim matrices, so the number of rows
    # stacked together is fixed by the point dimension.
    dim = arr.shape[1]
    vol_arr = get_vol_arr(arr)
    while True:
        va = arr[[0]]; vb = arr[[1]]; v3 = arr[2:dim+1]
        a = np.linalg.det(np.concatenate([va, v3], axis=0))
        b = np.linalg.det(np.concatenate([vb, v3], axis=0))

        # Degenerate rows are dropped and the step retried.
        # NOTE(review): these are exact floating-point zero tests; the original
        # author already flagged that a threshold may be more robust.
        if b == 0 or a == 0:
            if b == 0:
                arr = np.delete(arr, 1, 0)
            if a == 0:
                arr = np.delete(arr, 0, 0)
            continue
        gcd, vbp = GCD_2num_v(a, b, va, vb)
        if check_integer_arr(vol_arr/gcd):
            break
        # Not yet a common divisor: drop the first row and continue reducing
        # with the combined vector in its place.
        arr = arr[1:]
        arr[0] = vbp
    return np.concatenate([vbp, v3], axis=0)

def GCD_2num_v(a, b, va, vb):
    '''Euclid-style GCD of two numbers a and b, carried out approximately.

    Every swap / subtraction applied to (a, b) is mirrored on the vectors
    (va, vb). Iteration stops once the remainder is smaller in magnitude than
    the module-level tolerance `thres`.

    Returns:
        (|gcd|, vector) -- the absolute value of the last non-negligible
        number and the vector that went through the same operations.
    '''
    while True:
        # Exchange the roles of the two numbers and of their vectors.
        a, b = b, a
        va, vb = vb, va

        # Remove the nearest integer multiple of b from a (and vb from va).
        multiple = np.round(a/b)
        a, va = a - multiple * b, va - multiple * vb

        if np.abs(a) < thres:
            return np.abs(b), vb

def check_integer_arr(arr):
    '''Return True when every entry of `arr` lies within the module-level
    tolerance `thres` of an integer.'''
    distance_to_int = np.abs(arr - np.round(arr))
    return np.logical_not(np.any(distance_to_int > thres))

def normalize_basis(basis, max_iter=5):
    '''Shorten a lattice basis by integer combinations of its own vectors.

    For each basis vector in turn, the rounded projection coefficient of every
    other vector onto it is subtracted from that other vector. The sweep is
    repeated until nothing changes or `max_iter` sweeps have run. Finally each
    vector is sign-flipped so that the sum of its components is non-negative.

    Args:
        basis: array of shape (k, d); each row is a basis vector. Not modified.
        max_iter: maximum number of reduction sweeps (default 5).

    Returns:
        A new array with the same shape as `basis`.
    '''
    basis = np.array(basis, copy=True)
    n_vec = basis.shape[0]
    for _ in range(max_iter):
        changed = False
        for i in range(n_vec):
            # Integer part of the projection of every row onto row i.
            proj = np.round(np.sum(basis*basis[[i]], axis=1)/np.linalg.norm(basis[i])**2)
            proj[i] = 0  # never subtract a vector from itself
            basis -= proj[:, np.newaxis] * basis[[i]]
            changed = changed or bool(np.any(proj != 0))
        if not changed:
            break
    # Fix the sign convention: component sum of every row is >= 0.
    signs = np.where(np.sum(basis, axis=1) < 0, -1, 1)
    basis *= signs[:, np.newaxis]
    return basis

In [36]:
# Translate the sampled hidden states so that one sample sits at the origin,
# i.e. (0,0) is on the lattice. The sample with the smallest component sum is
# used as the origin.
hidden_batch = subset(hidden)
shift_id = np.argmin(hidden_batch.sum(axis=1))
shift = hidden_batch[[shift_id]].copy()
hidden_batch_shift = hidden_batch - shift

# Extract a lattice basis from the shifted samples, then shorten it.
basis = GCD_arr(hidden_batch_shift)
basis_comp = normalize_basis(basis)

In [37]:
# Express the hidden states in lattice coordinates: subtract the origin and
# change basis with the inverse of the compact basis.
basis_inv = np.linalg.inv(basis_comp)
h_int_lattice = (hidden - shift) @ basis_inv
h_last_int_lattice = (hidden_last - shift) @ basis_inv

# Last element of every input / target sequence.
input_lattice = data[0][:, -1].detach().cpu().numpy()
output_lattice = data[1][:, -1].detach().cpu().numpy()

In [38]:
def linear(input, output):
    """Fit ``output`` as a linear function of ``input`` with scikit-learn's
    LinearRegression.

    Returns:
        (score, coeff, intercept): the fit's ``score`` on the training data,
        its ``coef_`` and its ``intercept_``.
    """
    fitted = LinearRegression().fit(input, output)
    return fitted.score(input, output), fitted.coef_, fitted.intercept_

In [39]:
# State update: (h, x) => h
# Regress the rounded lattice coordinate of the final state on the rounded
# lattice coordinate of the previous state together with the last input:
# (h_last_lattice, input_lattice) => h_lattice
hx = np.concatenate(
    (np.round(h_last_int_lattice), np.round(input_lattice)[:, None]),
    axis=1,
)
h_next = np.round(h_int_lattice)

reg = linear(input=hx, output=h_next)

In [40]:
# Show the state-update fit as the (score, coeff, intercept) tuple returned by linear().
reg

(0.9999999949643814,
 array([[ 7.1282074e-08, -9.9999577e-01]], dtype=float32),
 array([95.99774], dtype=float32))

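$h_{t} = -x_t + 96$ (rounded from the fit above: the coefficient on $h_{t-1}$ is $\approx 0$, the coefficient on $x_t$ is $\approx -1$, and the intercept is $\approx 96$)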

In [41]:
# Readout: h => y
# Regress the last target on the rounded lattice coordinate of the final state:
# h_lattice => output_lattice
reg = linear(input=np.round(h_int_lattice), output=output_lattice)

In [44]:
# Show the readout fit as the (score, coeff, intercept) tuple returned by linear().
reg

(0.9999999998887231, array([-0.99998945], dtype=float32), 95.99951)

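$y_t = -h_t + 96$ (rounded from the fit above: coefficient $\approx -1$, intercept $\approx 96$). Combined with the state-update fit $h_t \approx -x_t + 96$, this gives $y_t = x_t$, i.e. the identity map.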

In [43]:
# !jupyter nbconvert --to script rnn_identity_numerical.ipynb