In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torch.utils.data as Data
import numpy as np
import torch.nn.functional as F

In [None]:
# Load the EMNIST "letters" split as tensors (ToTensor transform).
# Only the training call is allowed to download into ./data; the test split is
# loaded with download=False and so presumably relies on those same files.
emnist_common = dict(root='data', split='letters')

train_data = torchvision.datasets.EMNIST(
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
    **emnist_common,
)

test_data = torchvision.datasets.EMNIST(
    train=False,
    download=False,
    transform=torchvision.transforms.ToTensor(),
    **emnist_common,
)

In [None]:
# Scale the (presumably 0-255) raw pixel values to [0, 1], and shift the labels
# down by one so they start at 0, as F.cross_entropy expects class indices 0..C-1.
x_train = train_data.data / 255
y_train = train_data.targets - 1

# Show one training image together with its letter label.
img_index = 16
img = x_train[img_index]
# y_train is 0-based after the shift above, so class 0 is 'a'
# (adding 96 here would print the preceding character, e.g. '`' for 'a').
print("Image Label: " + chr(int(y_train[img_index]) + ord('a')))
# NOTE(review): torchvision's EMNIST images are commonly reported to be stored
# transposed; if the letter looks mirrored/rotated, try plotting img.T instead.
plt.imshow(img.reshape((28, 28)))

img

In [None]:
# Sanity check for the patch-extraction trick: number the pixels of a 28x28
# grid 1..784 row by row, then cut it into non-overlapping 4x4 patches laid
# out as a 7x7 grid of 16-value vectors.
a = torch.arange(1, 28 * 28 + 1).view(1, 1, 28, 28).float()
c = (
    F.unfold(a, kernel_size=(4, 4), stride=4)
    .transpose(1, 2)
    .view(1, 7, 7, 16)
)
c

In [None]:
# Cut the sample image into non-overlapping 4x4 patches: a 7x7 grid whose cells
# hold each patch's 16 pixel values. (No padding here; forward_func pads by 2
# and gets an 8x8 grid instead.)
temp_image = (
    F.unfold(img.view(1, 1, 28, 28), kernel_size=(4, 4), stride=4)
    .transpose(1, 2)
    .view(7, 7, 16)
)
temp_image

In [None]:
class TwoDimensionalLSTM_fixed_direction:
    """A 2-D LSTM that scans a grid in one fixed direction.

    The scan order is given by ``rows`` and ``cols`` (e.g. ``np.arange(n)`` for
    top-to-bottom / left-to-right, ``np.arange(n - 1, -1, -1)`` for the
    reverse). ``row_decre`` / ``col_decre`` are the offsets from a cell to its
    already-visited neighbour along each axis (-1 when that axis is scanned in
    increasing order, +1 when scanned in decreasing order).

    Input:  a 3-D tensor of shape (len(rows), len(cols), input_features).
    Output: a 2-D tensor of shape (len(rows) * len(cols), hidden_units), with
            cells in row-major spatial order regardless of scan direction.
    """

    def __init__(self, input_features, hidden_units, rows, cols, row_decre, col_decre,
                 init_scale=0.01):
        self.input_features = input_features
        self.hidden_units = hidden_units
        self.rows, self.cols = rows, cols
        self.row_decre = row_decre
        self.col_decre = col_decre
        # The first row/column visited has no predecessor along that axis.
        self.row_init, self.col_init = rows[0], cols[0]

        def param(*shape):
            # Scale *before* enabling grad so the result is a leaf tensor: only
            # leaves get .grad populated by backward() and can be optimized.
            return (torch.randn(shape) * init_scale).requires_grad_()

        # TODO: the initialisation scheme still needs work (original note:
        # "to avoid scattering of tanh(), initialization issue needs to be solved").
        n_in, n_hid = input_features, hidden_units

        # Input gate
        self.weight_input_gate = param(n_in, n_hid)
        self.weight_input_state = param(n_hid, n_hid)        # peephole, shared by both axes
        self.weight_input_cellout = param(2, n_hid, n_hid)   # [axis]
        self.bias_input_gate = param(n_hid)

        # Forget gates (one per axis)
        self.weight_forget_gate = param(2, n_in, n_hid)          # [gate]
        self.weight_forget_cellout = param(2, 2, n_hid, n_hid)   # [axis][gate]
        self.weight_forget_state = param(2, n_hid, n_hid)        # [gate], peephole on its own axis
        self.bias_forget_gate = param(2, n_hid)                  # [gate]

        # Cell input
        self.weight_cell = param(n_in, n_hid)
        self.weight_cell_cellout = param(2, n_hid, n_hid)    # [axis]
        self.bias_cell = param(n_hid)

        # Output gate
        self.weight_output_gate = param(n_in, n_hid)
        self.weight_output_cellout = param(2, n_hid, n_hid)  # [axis]
        self.weight_output_state = param(n_hid, n_hid)       # peephole on the new state
        self.bias_output_gate = param(n_hid)

    def parameters(self):
        """Return the trainable tensors of this direction as a list."""
        return [self.weight_input_gate,
                self.weight_input_state,
                self.weight_input_cellout,
                self.bias_input_gate,  # Input gate
                self.weight_forget_gate,
                self.weight_forget_cellout,
                self.weight_forget_state,
                self.bias_forget_gate,  # Forget gate
                self.weight_cell,
                self.weight_cell_cellout,
                self.bias_cell,  # Cell
                self.weight_output_gate,
                self.weight_output_cellout,
                self.weight_output_state,
                self.bias_output_gate]  # Output gate

    def __call__(self, input_image):
        """Scan ``input_image`` in this direction and return every cell's output.

        Args:
            input_image: tensor indexable as ``input_image[row][col]``, each
                entry a vector of ``input_features`` values.

        Returns:
            Tensor of shape (len(rows) * len(cols), hidden_units), cells in
            row-major spatial order.
        """
        n_rows, n_cols = len(self.rows), len(self.cols)
        # Indexed by spatial position rather than scan step, so the stacked
        # result is in image order for every scan direction.
        states = [[None] * n_cols for _ in range(n_rows)]        # cell states
        cell_outputs = [[None] * n_cols for _ in range(n_rows)]  # hidden outputs

        for row in self.rows:
            for col in self.cols:
                x = input_image[row][col]

                # Predecessor (state, output) along each axis; absent on the
                # first row/column of the scan.
                s_row = h_row = s_col = h_col = None
                has_row_prev = row != self.row_init
                has_col_prev = col != self.col_init
                if has_row_prev:
                    row_m = row + self.row_decre
                    s_row, h_row = states[row_m][col], cell_outputs[row_m][col]
                if has_col_prev:
                    col_m = col + self.col_decre
                    s_col, h_col = states[row][col_m], cell_outputs[row][col_m]

                # Input gate
                input_gate = x @ self.weight_input_gate + self.bias_input_gate
                if has_row_prev:
                    input_gate = (input_gate + s_row @ self.weight_input_state
                                  + h_row @ self.weight_input_cellout[0])
                if has_col_prev:
                    input_gate = (input_gate + s_col @ self.weight_input_state
                                  + h_col @ self.weight_input_cellout[1])
                input_gate = input_gate.sigmoid()

                # Forget gates: gate `dim` sees both predecessor outputs but
                # only the predecessor state along its own axis.
                forget_gates = []
                for dim in range(2):
                    forget_gate = x @ self.weight_forget_gate[dim] + self.bias_forget_gate[dim]
                    if has_row_prev:
                        forget_gate = forget_gate + h_row @ self.weight_forget_cellout[0][dim]
                    if has_col_prev:
                        forget_gate = forget_gate + h_col @ self.weight_forget_cellout[1][dim]
                    if dim == 0 and has_row_prev:
                        forget_gate = forget_gate + s_row @ self.weight_forget_state[0]
                    if dim == 1 and has_col_prev:
                        forget_gate = forget_gate + s_col @ self.weight_forget_state[1]
                    forget_gates.append(forget_gate.sigmoid())

                # Cell input
                cell = x @ self.weight_cell + self.bias_cell
                if has_row_prev:
                    cell = cell + h_row @ self.weight_cell_cellout[0]
                if has_col_prev:
                    cell = cell + h_col @ self.weight_cell_cellout[1]

                # Cell state: each forget gate scales its predecessor state
                # element-wise (`*`; `@` would collapse it to a scalar dot product).
                state = input_gate * cell.tanh()
                if has_row_prev:
                    state = state + s_row * forget_gates[0]
                if has_col_prev:
                    state = state + s_col * forget_gates[1]
                states[row][col] = state

                # Output gate, with a peephole on the new state and its own
                # recurrent weights (weight_output_cellout).
                output_gate = (x @ self.weight_output_gate + self.bias_output_gate
                               + state @ self.weight_output_state)
                if has_row_prev:
                    output_gate = output_gate + h_row @ self.weight_output_cellout[0]
                if has_col_prev:
                    output_gate = output_gate + h_col @ self.weight_output_cellout[1]
                output_gate = output_gate.sigmoid()

                cell_outputs[row][col] = output_gate * state.tanh()

        grid = torch.stack([torch.stack(row_outputs) for row_outputs in cell_outputs])
        return grid.reshape(n_rows * n_cols, self.hidden_units)

In [None]:
class TwoDimensionalLSTM:
    """Four-direction 2-D LSTM layer.

    Runs one ``TwoDimensionalLSTM_fixed_direction`` scan from each corner of the
    grid, concatenates their per-cell outputs and mixes them with a tanh layer.

    Input:  tensor of shape (row_size, col_size, input_features).
    Output: tensor of shape (row_size, col_size * hidden_units); the
            ``hidden_units`` values of cell (r, c) occupy columns
            ``c * hidden_units`` .. ``(c + 1) * hidden_units - 1`` of row r.
    """

    def __init__(self, input_features, hidden_units, row_size, col_size):
        self.hidden_units = hidden_units
        self.row_size, self.col_size = row_size, col_size
        rows_fwd, rows_bwd = np.arange(row_size), np.arange(row_size - 1, -1, -1)
        cols_fwd, cols_bwd = np.arange(col_size), np.arange(col_size - 1, -1, -1)

        self.top_left = TwoDimensionalLSTM_fixed_direction(
            input_features, hidden_units, rows_fwd, cols_fwd, -1, -1)
        self.top_right = TwoDimensionalLSTM_fixed_direction(
            input_features, hidden_units, rows_fwd, cols_bwd, -1, 1)
        self.down_left = TwoDimensionalLSTM_fixed_direction(
            input_features, hidden_units, rows_bwd, cols_fwd, 1, -1)
        self.down_right = TwoDimensionalLSTM_fixed_direction(
            input_features, hidden_units, rows_bwd, cols_bwd, 1, 1)

        # Scale before enabling grad so these are leaf tensors that receive
        # .grad from backward().
        self.weight = (torch.randn((4 * hidden_units, hidden_units)) * 0.01).requires_grad_()
        self.bias = (torch.randn(hidden_units) * 0.01).requires_grad_()
        # Note: no `self.parameters` attribute is stored here; assigning one
        # would shadow the parameters() method below and make it uncallable.

    def parameters(self):
        """Return all trainable tensors: the four directions', then weight and bias."""
        return (self.top_left.parameters() + self.top_right.parameters()
                + self.down_left.parameters() + self.down_right.parameters()
                + [self.weight, self.bias])

    def __call__(self, input_image):
        # One row per cell, the four directions' outputs side by side
        # (expected (row_size * col_size, 4 * hidden_units) to match self.weight).
        combined = torch.cat((self.top_left(input_image), self.top_right(input_image),
                              self.down_left(input_image), self.down_right(input_image)), dim=1)
        mixed = (combined @ self.weight + self.bias).tanh()
        # Row r * col_size + c of `mixed` becomes columns
        # c * hidden_units .. (c + 1) * hidden_units - 1 of output row r.
        return mixed.reshape(self.row_size, self.col_size * self.hidden_units)

In [None]:
# Model parameters. forward_func pads each 28x28 image to 32x32 and cuts it into
# 4x4 patches, giving the 8x8 grid of 16-value vectors the first layer consumes.


def _small_param(*shape):
    # Scale *before* enabling grad so the tensor is a leaf: only leaves get
    # .grad populated by backward() and can be handed to an optimizer.
    return (torch.randn(shape) * 0.01).requires_grad_()


parameters = []

# First MDLSTM layer: 8x8 grid of 16 features -> 8 x (8 * 2)
first_layer = TwoDimensionalLSTM(16, 2, 8, 8)
parameters.extend(first_layer.parameters())

# Second MDLSTM layer: 4x8 grid of 2x2 patches (4 features) -> 4 x (8 * 4)
second_layer = TwoDimensionalLSTM(4, 4, 4, 8)
parameters.extend(second_layer.parameters())

# Third MDLSTM layer: 2x16 grid of 2x2 patches -> 2 x (16 * 8)
third_layer = TwoDimensionalLSTM(4, 8, 2, 16)
parameters.extend(third_layer.parameters())

# Fourth MDLSTM layer: 1x64 grid of 2x2 patches -> 1 x (64 * 5) = 320 features
fourth_layer = TwoDimensionalLSTM(4, 5, 1, 64)
parameters.extend(fourth_layer.parameters())

hidden_units1 = 10
hidden_units2 = 100
output_units = 26  # one class per letter

# Lookup (projection) layer: reduce the 320 MDLSTM features to hidden_units1
lookup = _small_param(320, hidden_units1)
bias_lookup = _small_param(hidden_units1)

# Hidden layer
weight1 = _small_param(hidden_units1, hidden_units2)
bias1 = _small_param(hidden_units2)

# Output layer
weight2 = _small_param(hidden_units2, output_units)
bias2 = _small_param(output_units)

parameters += [lookup, bias_lookup, weight1, bias1, weight2, bias2]

In [None]:
def forward_func(minibatch):
    """Compute class logits for a batch of 28x28 images.

    Pipeline: pad to 32x32 and cut into 4x4 patches (8x8 grid) -> MDLSTM 1 ->
    regroup into 2x2 patches -> MDLSTM 2 -> ... -> MDLSTM 4 -> 320 features ->
    tanh projection (lookup) -> tanh hidden layer -> linear output layer.

    Args:
        minibatch: tensor with N * 28 * 28 elements, viewed as (N, 28, 28).

    Returns:
        Unnormalised class scores of shape (N, output_units).
    """
    n = minibatch.shape[0]

    def to_patch_grid(maps, height, width, kernel, grid_rows, grid_cols, padding=0):
        # Cut each (height, width) map into non-overlapping kernel x kernel
        # patches and arrange them as a (grid_rows, grid_cols, kernel**2) grid.
        patches = F.unfold(maps.view(n, 1, height, width),
                           kernel_size=kernel, padding=padding, stride=kernel)
        return patches.transpose(1, 2).view(n, grid_rows, grid_cols, kernel * kernel)

    def per_sample(layer, grids):
        # The MDLSTM layers take one grid at a time, so loop over the batch.
        return torch.stack([layer(grid) for grid in grids])

    # Expected shapes after each layer are noted on the right.
    maps = per_sample(first_layer, to_patch_grid(minibatch, 28, 28, 4, 8, 8, padding=2))  # N x 8 x 16
    maps = per_sample(second_layer, to_patch_grid(maps, 8, 16, 2, 4, 8))                   # N x 4 x 32
    maps = per_sample(third_layer, to_patch_grid(maps, 4, 32, 2, 2, 16))                   # N x 2 x 128
    maps = per_sample(fourth_layer, to_patch_grid(maps, 2, 128, 2, 1, 64))                  # N x 1 x 320

    features = maps.view(n, 320)
    projected = (features @ lookup + bias_lookup).tanh()
    hidden = (projected @ weight1 + bias1).tanh()
    return hidden @ weight2 + bias2

In [None]:
batch_size = 32

# One smoke-test step on a random mini-batch: forward pass, loss, backward.
# No optimizer update is applied here.
selected = torch.randint(high=len(x_train), size=(batch_size,))  # `size` must be a tuple of ints
batch_x = forward_func(x_train[selected])  # logits, expected shape (batch_size, output_units)
batch_y = y_train[selected]
loss = F.cross_entropy(batch_x, batch_y)
loss.backward()

In [None]:
# Parameters of the first layer, scanning from four directions
#
# NOTE(review): this cell looks like an earlier hand-written prototype of the
# top-left ("_tl") scan of the first layer. Nothing visible in this notebook
# reads these *_tl tensors; the class-based model above creates its own
# weights. Several shapes also disagree with TwoDimensionalLSTM_fixed_direction:
#   - weight_forget_cellout1_tl is (2, h, h) here vs (2, 2, h, h) in the class;
#   - bias_forget_gate1_tl is (h,) here vs (2, h) in the class.
# The tensors are unscaled randn (no * 0.01) leaf tensors.
# Running this cell also rebinds hidden_units1 (set to 10 in the parameter cell)
# to 2. Consider deleting the cell — TODO confirm nothing else depends on it.

hidden_units1 = 2  # hidden units of this prototype (shadows the MLP's hidden_units1)
input_feature1 = 16  # 4x4 patch = 16 input values per grid cell

# Input gate
weight_input_gate1_tl = torch.randn((input_feature1, hidden_units1), requires_grad=True)
weight_input_state1_tl = torch.randn((hidden_units1, hidden_units1), requires_grad=True)
weight_input_cellout1_tl = torch.randn((2, hidden_units1, hidden_units1), requires_grad=True)
bias_input_gate1_tl = torch.randn(hidden_units1, requires_grad=True)

# Forget gate
weight_forget_gate1_tl = torch.randn((2, input_feature1, hidden_units1), requires_grad=True)
weight_forget_cellout1_tl = torch.randn((2, hidden_units1, hidden_units1), requires_grad=True)
weight_forget_state1_tl = torch.randn((2, hidden_units1, hidden_units1), requires_grad=True)
bias_forget_gate1_tl = torch.randn(hidden_units1, requires_grad=True)

# Cell
weight_cell1_tl = torch.randn((input_feature1, hidden_units1), requires_grad=True)
weight_cell_cellout1_tl = torch.randn((2, hidden_units1, hidden_units1), requires_grad=True)
bias_cell1_tl = torch.randn(hidden_units1, requires_grad=True)

# Output gate
weight_output_gate1_tl = torch.randn((input_feature1, hidden_units1), requires_grad=True)
weight_output_cellout1_tl = torch.randn((2, hidden_units1, hidden_units1), requires_grad=True)
weight_output_state1_tl = torch.randn((hidden_units1, hidden_units1), requires_grad=True)
bias_output_gate1_tl = torch.randn(hidden_units1, requires_grad=True)
