In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

In [None]:
from numpy.linalg import inv, det
from tqdm import tqdm
from time import sleep

from torch import sigmoid, tanh
from torch import Tensor, exp, log
from torch.nn import Sequential
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
from torch.nn.init import uniform_
from torch.nn.functional import linear

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

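Reference paper: Trask et al., *Neural Arithmetic Logic Units* (2018), https://arxiv.org/pdf/1808.00508.pdf. This notebook trains NALU-style arithmetic cells to predict the determinant of random 2 x 2 integer matrices.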

In [None]:
# class NacCell(nn.Module):
#     def __init__(self, in_shape, out_shape):
#         super().__init__()
#         self.in_shape = in_shape
#         self.out_shape = out_shape

#         self.W_ = Parameter(Tensor(out_shape, in_shape))
#         self.M_ = Parameter(Tensor(out_shape, in_shape)

#         uniform_(self.W_, 0.0, 1.0), uniform_(self.M_, 0.0, 1.0)
#         self.register_parameter('bias', None)

#     # a = Wx
#     # W = tanh(W) * sigmoid(M)
#     # * is elementwise product
#     def forward(self, X):
#         #print('W:', self.W_.shape, 'X:', X.shape)
#         W = tanh(self.W_) * sigmoid(self.M_)

#         # linear: XW^T + b
#         return linear(X, W, self.bias)
#         #return torch.matmul(X, W.T)

In [None]:
# Cell for addition and subtraction
class ASCell(nn.Module):
    """Additive cell (NAC): a bias-free linear map with constrained weights.

    The effective weight matrix is ``tanh(W_) * sigmoid(M_)`` (elementwise),
    so every entry lies in (-1, 1); each output is a signed weighted sum of
    the inputs, i.e. additions and subtractions.

    Args:
        in_shape: number of input features.
        out_shape: number of output features.
    """

    def __init__(self, in_shape, out_shape):
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape

        # Raw, unconstrained parameters, laid out as (out_shape, in_shape)
        # to match the weight layout F.linear expects.
        self.W_ = Parameter(torch.empty(out_shape, in_shape))
        self.M_ = Parameter(torch.empty(out_shape, in_shape))

        # Both start in U(0, 1): non-negative initial values only.
        for raw_param in (self.W_, self.M_):
            uniform_(raw_param, 0.0, 1.0)

        # Explicitly bias-free; registered as None so ``self.bias`` exists.
        self.register_parameter('bias', None)

    def forward(self, X):
        """Return ``X @ (tanh(W_) * sigmoid(M_)).T`` (no bias term)."""
        effective_weight = tanh(self.W_) * sigmoid(self.M_)
        return linear(X, effective_weight, self.bias)

In [None]:
# class NaluCell(nn.Module):
#     def __init__(self, in_shape, out_shape):
#         super().__init__()
#         self.in_shape = in_shape
#         self.out_shape = out_shape

#         self.G = Parameter(Tensor(out_shape, in_shape))
#         self.nac = NacCell(in_shape, out_shape)

#         uniform_(self.G, 0.0, 1.0)
#         # epsilon prevents log0
#         self.eps = 1e-5
#         self.register_parameter('bias', None)

#     # y = g * a + (1 - g) * m
#     # m = exp W(log(|x| + e)), g = sigmoid(Gx)
#     # * is elementwise product
#     # a is from nac
#     def forward(self, X):
#         a = self.nac(X)
#         g = sigmoid(linear(X, self.G, self.bias))

#         ag = g * a
#         log_in = log(abs(X) + self.eps)
#         m = exp(self.nac(log_in))
#         md = (1 - g) * m

#         return ag + md

In [None]:
# Cell for multiplication and division
class MDCell(nn.Module):
    """Multiplicative cell: ``y = exp(nac(log(|X| + eps)))``.

    An ``ASCell`` is applied in log space, so its learned signed sums of
    log-magnitudes become products/quotients of powers of the inputs:
    ``exp(sum_i w_i * log|x_i|) = prod_i |x_i| ** w_i``. This is the
    multiplicative path of a NALU with the learned gate G removed.

    Note: ``|X|`` discards the sign of the inputs, so every output is positive.

    Args:
        in_shape: number of input features.
        out_shape: number of output features.
        eps: small constant added to ``|X|`` so ``log(0)`` is never taken.
            Defaults to 1e-5 (the previously hard-coded value).
    """

    def __init__(self, in_shape, out_shape, eps=1e-5):
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape

        # Attribute name ``nac`` kept so existing state_dict keys still load.
        self.nac = ASCell(in_shape, out_shape)

        self.eps = eps
        # Not used by forward(); kept so ``self.bias`` (None) still exists
        # for any code that inspects it.
        self.register_parameter('bias', None)

    def forward(self, X):
        """Return ``exp(nac(log(|X| + eps)))``; the result is always positive."""
        log_in = log(X.abs() + self.eps)
        return exp(self.nac(log_in))

In [None]:
# class NaluLayer(nn.Module):
#     def __init__(self, input_shape, output_shape, n_layers, hidden_shape):
#         super().__init__()
#         self.input_shape = input_shape
#         self.output_shape = output_shape
#         self.n_layers = n_layers
#         self.hidden_shape = hidden_shape
        
#         layers = [NaluCell(hidden_shape if n > 0 else input_shape, hidden_shape if n < n_layers - 1 else output_shape) for n in range(n_layers)]
#         self.model = Sequential(*layers)

#     def forward(self, X):
#         return self.model(X)

In [None]:
# class NALU(nn.Module):
#     def __init__(self, input_shape, output_shape, n_layers, hidden_shape):
#         super().__init__()
#         self.nalu1 = NaluLayer(input_shape, output_shape, n_layers, hidden_shape)

#         print('input: {}, output: {}, n: {}, hidden: {}'.format(input_shape, output_shape, n_layers, hidden_shape))
        
#     def forward(self, X):
#         X = self.nalu1(X)
#         X = X.squeeze()

#         return X

In [None]:
class NALU2(nn.Module):
    """Two-stage arithmetic network: an MDCell (products) then an ASCell (sum).

    A 2 x 2 determinant ``ad - bc`` is one multiplicative stage producing the
    two products followed by one additive stage combining them, hence the
    default widths 4 -> 2 -> 1. Larger matrices need more product terms,
    e.g. ``NALU2(in_shape=9, hidden_shape=6)`` for the six terms of a 3 x 3
    determinant.

    Args:
        in_shape: number of input features (flattened matrix entries).
            Defaults to 4.
        hidden_shape: number of product terms produced by the MDCell.
            Defaults to 2.
    """

    def __init__(self, in_shape=4, hidden_shape=2):
        super().__init__()
        self.MD1 = MDCell(in_shape=in_shape, out_shape=hidden_shape)
        self.AS1 = ASCell(in_shape=hidden_shape, out_shape=1)

    def forward(self, X):
        """Map ``X`` of shape (batch, in_shape) to predictions of shape (batch,)."""
        X = self.MD1(X)
        X = self.AS1(X)
        # Drop only the trailing size-1 output dim. A bare squeeze() would also
        # drop the batch dim when batch == 1, returning a 0-d tensor that no
        # longer matches a (1,)-shaped target.
        return X.squeeze(-1)

In [None]:
# Training set: random 2 x 2 integer matrices (entries 1..49) labelled with
# their determinants; singular matrices are rejected.
n_train = 100000
np.random.seed(0)  # reproducible dataset

mats, dets = [], []
n_kept = 0
while n_kept < n_train:
    # Draw a whole batch at once instead of one matrix per Python iteration.
    A = np.random.randint(1, 50, size=(n_train - n_kept, 2, 2))
    # Integer entries => integer determinant. Rounding removes LU round-off,
    # so singular matrices are detected reliably (comparing the raw float
    # det to 0 can miss them) and the labels are exact integers.
    d = np.rint(det(A))
    keep = d != 0
    mats.append(A[keep])
    dets.append(d[keep])
    n_kept += int(keep.sum())

X_train = np.concatenate(mats).astype(np.float32)
y_train = np.concatenate(dets).astype(np.float32)

In [None]:
# Convert to float32 tensors (matching the model's parameter dtype for
# F.mse_loss) and flatten each k x k matrix into a row of k*k features, so
# the cell works for any matrix size rather than only 2 x 2.
X_train = torch.from_numpy(np.asarray(X_train, dtype=np.float32))
X_train = X_train.reshape(len(X_train), -1)
y_train = torch.from_numpy(np.asarray(y_train, dtype=np.float32))

In [None]:
# Validation set: same distribution as the training set (random 2 x 2 integer
# matrices, entries 1..49, determinant labels, singular matrices rejected).
# Deliberately not re-seeded here so it does not duplicate the training data.
n_val = 10000

mats, dets = [], []
n_kept = 0
while n_kept < n_val:
    # Draw a whole batch at once instead of one matrix per Python iteration.
    A = np.random.randint(1, 50, size=(n_val - n_kept, 2, 2))
    # Integer entries => integer determinant; rounding removes LU round-off
    # so singular matrices are caught reliably and labels are exact.
    d = np.rint(det(A))
    keep = d != 0
    mats.append(A[keep])
    dets.append(d[keep])
    n_kept += int(keep.sum())

X_val = np.concatenate(mats).astype(np.float32)
y_val = np.concatenate(dets).astype(np.float32)

In [None]:
# Convert to float32 tensors (matching the model's parameter dtype for
# F.mse_loss) and flatten each k x k matrix into a row of k*k features, so
# the cell works for any matrix size rather than only 2 x 2.
X_val = torch.from_numpy(np.asarray(X_val, dtype=np.float32))
X_val = X_val.reshape(len(X_val), -1)
y_val = torch.from_numpy(np.asarray(y_val, dtype=np.float32))

In [None]:
# Sanity check: validation inputs after flattening.
print(X_val.size())
X_val

In [None]:
# Sanity check: training inputs after flattening.
print(X_train.size())
X_train

In [None]:
# Label summary: shape and mean absolute determinant (a scale reference for
# the MSE values reported during training).
print(y_train.shape)
# Vectorised mean; the builtin sum() iterated over 100k 0-d tensors in Python
# and accumulated them one by one in float32.
print(y_train.abs().mean())
y_train

In [None]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
# Build the ungated network (MDCell -> ASCell) used for training.
# Earlier gated variant (its class is commented out above):
# model = NALU(input_shape=4, output_shape=1, n_layers=2, hidden_shape=2)
model = NALU2()

In [None]:
# Show every named parameter and its current (initial) values.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.data)

In [None]:
# Move the model and every data tensor to the selected device.
model.to(device)  # nn.Module.to() moves the parameters in place
X_train, y_train = X_train.to(device), y_train.to(device)
X_val, y_val = X_val.to(device), y_val.to(device)

X_train

In [None]:
# Pair each flattened input row with its determinant label for batching.
dataset = TensorDataset(X_train, y_train)
dataset

In [None]:
# Adam over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Mini-batches of 128 samples, reshuffled every epoch.
b_size = 128
dataloader = DataLoader(dataset, batch_size=b_size, shuffle=True)

In [None]:
epochs = 100

# Running metrics; the training loop re-zeroes them every epoch.
train_loss, val_loss = 0.0, 0.0
b_idx = 0

# Per-epoch history (epoch number, train loss, validation loss) for plotting.
epos, los, v_los = [], [], []

In [None]:
# Batches per epoch, including the final partial batch. n_train / b_size alone
# gives the fractional 781.25, not the actual number of iterations.
print(len(dataloader))

In [None]:
# Number of batches per epoch, including a final partial batch. Derived from
# the loader rather than hard-coded, so changing n_train or b_size cannot
# silently skip the end-of-epoch validation/averaging below.
n_batches = len(dataloader)

for epoch in range(epochs):
    with tqdm(dataloader) as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}/{epochs}")
        train_loss = 0
        val_loss = 0
        b_idx = 0

        # b_idx is 1-based so that, after the loop, it equals the batch count.
        for b_idx, (X_t, y_t) in enumerate(tepoch, start=1):
            pred = model(X_t)
            cost = F.mse_loss(pred, y_t)

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            loss = cost.item()
            train_loss += loss

            if b_idx < n_batches:
                tepoch.set_postfix({'Train loss(in progress)': loss})
            else:
                # Last batch: evaluate on the full validation set and report
                # the epoch's mean batch loss. This must happen inside the
                # loop, because tqdm closes the bar once iteration finishes
                # and later postfix updates would not be shown.
                with torch.no_grad():
                    p_val = model(X_val)
                    val_loss = F.mse_loss(p_val, y_val).item()

                train_loss /= b_idx
                tepoch.set_postfix({'Train loss(final)': train_loss, 'Val loss': val_loss})

        epos.append(epoch + 1)
        los.append(train_loss)
        v_los.append(val_loss)

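Whether training converges depends on the initialization. How can we make it converge every time?

-> Exclude negative numbers when initializing (weights are drawn from U(0, 1)).

-> Still unreliable. Maybe the gate G is hard to learn? What if we stop learning G and split addition and multiplication into separate cells instead?

-> Much more stable. The gate G seems to have been the problem!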

<!-- Probably because allowing negative values makes it more likely that a division operation gets learned. -->

Does training still work if we build more layers than the operation requires?

-> It does, but not as well as with a tight (minimal) architecture.

Next, try 3 x 3 and 4 x 4 matrices.

-> 3 x 3 succeeded (16023.6797).

-> 4 x 4 seems harder (493209.6875).

In [None]:
# Show every named parameter and its values after training.
for param_name, param_tensor in model.named_parameters():
    print(param_name, param_tensor.data)

In [None]:
# nn = 2
# i = 4
# o = 1
# h = 2
# for n in range(nn):
#     f1 = h if n > 0 else i
#     f2 = h if n < nn - 1 else o
    
#     print('({}, {})'.format(f1, f2))

In [None]:
# Train and validation loss per epoch on a single set of axes.
fig, ax = plt.subplots()
ax.plot(epos, los, label='train')
ax.plot(epos, v_los, label='validation')
ax.set_title('model loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('loss')
ax.legend()
# fig.savefig('aaa.png')
plt.show()

In [None]:
# Broken-y-axis view of the loss curves: the upper panel zooms into a high
# loss range and the lower panel into a low one. The limits were picked for
# one particular run; adjust them to the curves being plotted.
top_ylim = (4e+5, 5e+5)
bottom_ylim = (0, 100)

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)
fig.subplots_adjust(hspace=0.05)

# Draw both series on both panels (the standard broken-axis recipe); the
# y-limits decide which part of each curve is visible where.
for ax in (ax1, ax2):
    ax.plot(epos, los, label='train')
    ax.plot(epos, v_los, label='validation')

ax1.set_ylim(*top_ylim)
ax2.set_ylim(*bottom_ylim)

# Hide the spines between the two panels.
ax1.spines['bottom'].set_visible(False)
ax2.spines['top'].set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)
ax2.xaxis.tick_bottom()

# Diagonal "break" marks; the Line2D keyword is `marker` (`mark` is not a
# valid property and raises).
kwargs = dict(marker=[(-1, -0.5), (1, 0.5)], markersize=12, linestyle='none', color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

# Label via explicit axes: plt.title/plt.legend would only target ax2,
# putting the title between the panels and listing only one series.
ax1.set_title('model loss')
ax2.set_xlabel('Epoch')
ax1.set_ylabel('loss')
ax2.set_ylabel('loss')
ax1.legend()
plt.show()

In [None]:
# plt.plot(epos, los, label='train')
# plt.title('model loss')
# plt.xlabel('Epoch')
# plt.ylabel('loss')
# plt.legend()
# plt.savefig('aaa.png')
# plt.show()

# torch.save(model, 'aaa.pt')

# nn = 2
# i = 4
# o = 1
# h = 2
# for n in range(nn):
#     f1 = h if n > 0 else i
#     f2 = h if n < nn - 1 else o
    
#     print('({}, {})'.format(f1, f2))