In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from numpy.linalg import inv, det
from tqdm import tqdm
from time import sleep

from torch import sigmoid, tanh
from torch import Tensor, exp, log
from torch.nn import Sequential
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform_
from torch.nn.init import uniform_
from torch.nn.functional import linear

from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

In [3]:
from mnist import MNIST

paper: Neural Arithmetic Logic Units (https://arxiv.org/pdf/1808.00508.pdf)

In [4]:
# Cell for addition and subtraction
class ASCell(nn.Module):
    def __init__(self, in_shape, out_shape):
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape

        self.W_ = Parameter(Tensor(out_shape, in_shape))
        self.M_ = Parameter(Tensor(out_shape, in_shape))

        uniform_(self.W_, 0.0, 1.0), uniform_(self.M_, 0.0, 1.0)
        self.register_parameter('bias', None)

    # a = Wx
    # W = tanh(W) * sigmoid(M)
    # * is elementwise product
    def forward(self, X):
        #print('W:', self.W_.shape, 'X:', X.shape)
        W = tanh(self.W_) * sigmoid(self.M_)

        # linear: XW^T + b
        return linear(X, W, self.bias)
        #return torch.matmul(X, W.T)

In [5]:
# Cell for multiplication and division
class MDCell(nn.Module):
    def __init__(self, in_shape, out_shape):
        super().__init__()
        self.in_shape = in_shape
        self.out_shape = out_shape

#         self.G = Parameter(Tensor(out_shape, in_shape))
        self.nac = ASCell(in_shape, out_shape)

#         uniform_(self.G, 0.0, 1.0)
        # epsilon prevents log0
        self.eps = 1e-5
        self.register_parameter('bias', None)

    # y = g * a + (1 - g) * m
    # m = exp W(log(|x| + e)), g = sigmoid(Gx)
    # * is elementwise product
    # a is from nac
    
    # y = exp W(log(|X| + e))
    # W = tanh(W) * sigmoid(M)
    # * is elementalwise product
    def forward(self, X):
#         a = self.nac(X)
#         g = sigmoid(linear(X, self.G, self.bias))

#         ag = g * a
        log_in = log(abs(X) + self.eps)
        m = exp(self.nac(log_in))
#         md = (1 - g) * m

#         return ag + md
        return m

In [6]:
class Custom(nn.Module):
    def __init__(self):
        super().__init__()
        self.MD1 = MDCell(in_shape=784, out_shape=64)
        self.MD2 = MDCell(in_shape=64, out_shape=32)
        self.MD3 = MDCell(in_shape=32, out_shape=16)
        self.AS1 = ASCell(in_shape=64, out_shape=10)
        
    def forward(self, X):
        X = self.MD1(X)
        X = self.AS1(X)

        return X

In [7]:
mnist = MNIST('/home/data/MNIST')

In [8]:
X_train, y_train = mnist.load_training()
X_test, y_test = mnist.load_training()

print(type(X_train[0]), len(X_train))

<class 'list'> 60000


In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [10]:
# model = NALU(input_shape=4, output_shape=1, n_layers=2, hidden_shape=2)
model = Custom()

In [11]:
X_train = np.array(X_train)
y_train = np.array(y_train)

X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)

In [12]:
model.to(device)
X_train = X_train.to(device)
y_train = y_train.to(device)

X_train

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], device='cuda:0')

In [19]:
y_train

tensor([5, 0, 4,  ..., 5, 6, 8], device='cuda:0', dtype=torch.uint8)

In [14]:
dataset = TensorDataset(X_train, y_train)
dataset

<torch.utils.data.dataset.TensorDataset at 0x7f30b0828fd0>

In [15]:
b_size = 128
dataloader = DataLoader(dataset, batch_size=b_size, shuffle=True)
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [16]:
epochs = 100
cur_loss = 0.0
epos = []
los = []

In [21]:
for epoch in range(epochs):
    with tqdm(dataloader) as tepoch:
        for batch_idx, samples in enumerate(tepoch):
            tepoch.set_description(f"Epoch {epoch+1}/{epochs}")
            X_t, y_t = samples
            
            pred = model(X_t)
            
            cost = F.cross_entropy(pred, y_t)
            
            optimizer.zero_grad()
            cost.backward()
            optimizer.step()
            
            tepoch.set_postfix(loss=cost.item())
            cur_loss = cost.item()
            
#         with torch.no_grad():
#             y_test = model(X_test)
#             print(f"End of epoch {epoch + 1}")
#             print_loss_accuracy(y_test_log_pred, y_test, loss_function)
#             print("---")
            
        epos.append(epoch + 1)
        los.append(cur_loss)

Epoch 1/100:   0%|          | 0/469 [00:00<?, ?it/s]


RuntimeError: Expected object of scalar type Long but got scalar type Byte for argument #2 'target' in call to _thnn_nll_loss_forward

initialization 에 따라 수렴 여부가 결정됨 -> 항상 수렴시킬 방법?

-> initialization 할 때 exclude negative number

-> 이래도 잘 안 됨. G 학습이 어려워서 그럴지도? G 를 학습시키지 말고 덧셈과 곱셈을 분리해버리면 어떨까?

-> 훨씬 안정적임. G 가 문제였던 것 같음!!

<!-- 음수가 포함되면 나누기 연산이 들어갈 확률이 높아져서 그런 것 같음 -->

필요한 연산보다 더 많은 layer 를 만들면 학습이 되나?

-> 되기는 하지만 타이트할 때보다 잘 되지는 않음

그럼 이제 3 x 3 이랑 4 x 4 도전

-> 3 x 3 은 성공 16023.6797

-> 4 x 4 는 좀 어려운듯? 493209.6875

In [None]:
# nn = 2
# i = 4
# o = 1
# h = 2
# for n in range(nn):
#     f1 = h if n > 0 else i
#     f2 = h if n < nn - 1 else o
    
#     print('({}, {})'.format(f1, f2))

In [None]:
plt.plot(epos, los, label='train')
plt.title('model loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend()
# plt.savefig('aaa.png')
plt.show()

In [None]:
# plt.plot(epos, los, label='train')
# plt.title('model loss')
# plt.xlabel('Epoch')
# plt.ylabel('loss')
# plt.legend()
# plt.savefig('aaa.png')
# plt.show()

# torch.save(model, 'aaa.pt')

# nn = 2
# i = 4
# o = 1
# h = 2
# for n in range(nn):
#     f1 = h if n > 0 else i
#     f2 = h if n < nn - 1 else o
    
#     print('({}, {})'.format(f1, f2))