In [3]:
# extra utils
import logging
from abc import ABC, abstractmethod
import math

# torch
import torch
import torch.nn as nn
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# this project
from utils import finite_field_truncation, to_int_domain, ToIntDomain, from_int_to_real_domain, to_int_domain_int


class SimpleNetwork(nn.Module):
    """Bias-free two-layer perceptron with a square activation.

    The network computes ``output_layer(hidden_layer(x) ** 2)`` where ``x`` is
    the input sample flattened after reversing its axes (for a 2-D image this
    is the transpose, i.e. a column-major flattening).

    Args:
        num_class: width of the output layer.
        input_size: number of elements in one flattened sample.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, num_class=10, input_size=784, hidden_size=64):
        super().__init__()
        self.hidden_layer = nn.Linear(input_size, hidden_size, bias=False)
        self.output_layer = nn.Linear(hidden_size, num_class, bias=False)

    def forward(self, data):
        """Run the network on one sample or on a batch of samples.

        Args:
            data: tensor holding either a single sample (any shape whose
                element count equals the hidden layer's ``in_features`` once
                size-1 axes are dropped) or a batch with the batch axis first.

        Returns:
            A 1-D tensor of ``num_class`` scores for a single sample, or a
            ``(batch, num_class)`` tensor for a batch of more than one sample.
        """
        # Drop every size-1 axis (a batch axis of size 1 and the channel axis
        # of a grayscale image both disappear here).
        data = data.squeeze()

        if data.numel() == self.hidden_layer.in_features:
            # Exactly one sample: reverse the axes and flatten. ``permute`` is
            # used instead of ``.T`` because ``.T`` is deprecated for tensors
            # that are not 2-D; for a 2-D image both give the transpose.
            if data.dim() >= 2:
                data = data.permute(*range(data.dim() - 1, -1, -1))
            data = data.reshape(-1)
        elif data.dim() >= 3:
            # Batch of samples: keep axis 0 as the batch axis and flatten each
            # sample exactly as the single-sample branch would.
            data = data.permute(0, *range(data.dim() - 1, 0, -1)).reshape(data.shape[0], -1)
        # Any other shape is passed through unchanged; nn.Linear accepts an
        # already-flattened batch and raises on a genuine size mismatch.

        data = self.hidden_layer(data)
        data = torch.square(data)
        data = self.output_layer(data)
        return data


class AbstractVectorizedNet(ABC):
    """Base class for a hand-written two-layer network stored as raw tensors.

    The constructor creates two weight matrices and moves them to the chosen
    device:

    * ``_weight_1`` has shape ``(input_vector_size, hidden_layer_size)``
    * ``_weight_2`` has shape ``(hidden_layer_size, num_classes)``

    Each matrix is drawn uniformly from ``[-1/sqrt(fan_in), 1/sqrt(fan_in))``
    where ``fan_in`` is its first dimension.

    Subclasses supply the loss, the optimizer step, the forward and backward
    passes and the training loop.

    Args:
        input_vector_size: number of elements in one input vector.
        hidden_layer_size: width of the hidden layer.
        num_classes: width of the output layer.
        device: target device; when ``None``, ``'cuda'`` is used if available
            and ``'cpu'`` otherwise.
        verbose: when true, progress messages are printed in addition to
            being sent to ``logging``.
    """

    def __init__(self, input_vector_size=784, hidden_layer_size=64, num_classes=10,
                 device=None, verbose=True):
        self.__input_vector_size = input_vector_size
        self.__hidden_layer_size = hidden_layer_size
        self.__num_classes = num_classes
        if device is None:
            self.__device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.__device = device
        self.__verbose = verbose
        self._weight_1 = None
        self._weight_2 = None

        self.__init_weight()
        self.__to_device()

    def __init_weight(self):
        """Draw both weight matrices from a uniform distribution scaled by fan-in."""
        range_low, range_high = -1 / math.sqrt(self.__input_vector_size), 1 / math.sqrt(self.__input_vector_size)
        self._weight_1 = range_low + torch.rand((self.__input_vector_size, self.__hidden_layer_size)) * (range_high -
                                                                                                         range_low)
        range_low, range_high = -1 / math.sqrt(self.__hidden_layer_size), 1 / math.sqrt(self.__hidden_layer_size)
        self._weight_2 = range_low + torch.rand((self.__hidden_layer_size, self.__num_classes)) * (range_high -
                                                                                                   range_low)
        logging.info('weights are initialized')
        if self.__verbose:
            print('weights are initialized')

    def __to_device(self):
        """Move both weight matrices to the currently configured device."""
        self._weight_1 = self._weight_1.to(self.__device)
        self._weight_2 = self._weight_2.to(self.__device)

        logging.info('weights are sent to {}'.format(self.__device))
        if self.__verbose:
            print('weights are sent to {}'.format(self.__device))

    @property
    def input_vector_size(self):
        """Number of elements in one input vector."""
        return self.__input_vector_size

    @input_vector_size.setter
    def input_vector_size(self, value):
        # Only the recorded size changes; existing weights are not re-created.
        self.__input_vector_size = value

    @property
    def hidden_layer_size(self):
        """Width of the hidden layer."""
        return self.__hidden_layer_size

    @hidden_layer_size.setter
    def hidden_layer_size(self, value):
        # Only the recorded size changes; existing weights are not re-created.
        self.__hidden_layer_size = value

    @property
    def num_classes(self):
        """Width of the output layer."""
        return self.__num_classes

    @num_classes.setter
    def num_classes(self, value):
        # Only the recorded size changes; existing weights are not re-created.
        self.__num_classes = value

    @property
    def device(self):
        """Device on which the weight matrices live."""
        return self.__device

    @device.setter
    def device(self, value):
        """Switch device and move the weights along with the attribute.

        Without the move, ``self.device`` and the actual location of the
        weights would disagree, and any code that places its inputs with
        ``self.device`` would hit a device-mismatch error.
        """
        previous = self.__device
        self.__device = value
        if self._weight_1 is not None and self._weight_2 is not None:
            try:
                self.__to_device()
            except Exception:
                # The move failed, so keep reporting the device the weights
                # are really on.
                self.__device = previous
                raise

    @property
    def verbose(self):
        """Whether progress messages are also printed to stdout."""
        return self.__verbose

    @verbose.setter
    def verbose(self, value):
        self.__verbose = value

    @abstractmethod
    def _criterion(self, label: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
        """Return the loss between ``label`` and ``prediction``."""
        pass

    @abstractmethod
    def _optimizer(self, learning_rate):
        """Apply one weight update using ``learning_rate``."""
        pass

    @abstractmethod
    def _forward(self, input_vector: torch.Tensor, mode: str = 'train') -> torch.Tensor:
        """Compute the network output for ``input_vector``."""
        pass

    @abstractmethod
    def _backward(self):
        """Compute the gradients for the most recent forward pass."""
        pass

    @abstractmethod
    def train(self, data_path: str, num_of_epochs: int, learning_rate):
        """Train the network on the data found at ``data_path``."""
        pass

# noinspection DuplicatedCode
class ScaledVectorizedIntegerNet(AbstractVectorizedNet):
    """Two-layer square-activation network operating on scaled integer tensors.

    The weights created by the base class are converted with ``to_int_domain``
    and every product in the forward and backward passes is passed through
    ``finite_field_truncation``. Both helpers come from the project's
    ``utils`` module; presumably they scale by / truncate a power of two
    given by the scale parameter -- TODO confirm against ``utils``.

    Args:
        scale_input_parameter: scale parameter used for the input data.
        scale_weight_parameter: scale parameter used for the weights, the
            labels and the network output.
        scale_learning_rate_parameter: scale parameter used for the learning
            rate inside the optimizer step.
        **kwargs: forwarded unchanged to ``AbstractVectorizedNet``.
    """

    def __init__(self, scale_input_parameter, scale_weight_parameter, scale_learning_rate_parameter, **kwargs):
        super().__init__(**kwargs)
        # Tensors cached by _forward/_criterion for use in _backward.
        self.__save_for_backward = None
        # Gradients produced by _backward and consumed by _optimizer.
        self.__gradients = None

        self.__running_loss = None
        self.__running_acc = None

        self.__scale_input_parameter = scale_input_parameter
        self.__scale_weight_parameter = scale_weight_parameter
        self.__scale_learning_rate_parameter = scale_learning_rate_parameter

        self.__scale_init_weight()

    def __scale_init_weight(self):
        """Convert the real-valued initial weights with ``to_int_domain``."""
        self._weight_1 = to_int_domain(self._weight_1, self.__scale_weight_parameter)
        self._weight_2 = to_int_domain(self._weight_2, self.__scale_weight_parameter)

    @property
    def running_loss(self):
        """Externally managed loss bookkeeping value (not written by this class)."""
        return self.__running_loss

    @running_loss.setter
    def running_loss(self, value):
        self.__running_loss = value

    @property
    def running_acc(self):
        """Externally managed accuracy bookkeeping value (not written by this class)."""
        return self.__running_acc

    @running_acc.setter
    def running_acc(self, value):
        self.__running_acc = value

    @property
    def scale_input_parameter(self):
        """Scale parameter applied to the input data."""
        return self.__scale_input_parameter

    @scale_input_parameter.setter
    def scale_input_parameter(self, value):
        self.__scale_input_parameter = value

    @property
    def scale_weight_parameter(self):
        """Scale parameter applied to weights, labels and outputs."""
        return self.__scale_weight_parameter

    @scale_weight_parameter.setter
    def scale_weight_parameter(self, value):
        self.__scale_weight_parameter = value

    @property
    def scale_learning_rate_parameter(self):
        """Scale parameter applied to the learning rate."""
        return self.__scale_learning_rate_parameter

    @scale_learning_rate_parameter.setter
    def scale_learning_rate_parameter(self, value):
        self.__scale_learning_rate_parameter = value

    def _criterion(self, label: torch.Tensor, prediction: torch.Tensor) -> torch.Tensor:
        """Return the squared L2 distance between label and prediction.

        Both tensors are first mapped back with ``from_int_to_real_domain``
        using the weight scale. The label is also cached for ``_backward``.
        """
        # The cache only exists after a train-mode forward pass; create it on
        # demand so that scoring an eval-mode forward does not fail on None.
        if self.__save_for_backward is None:
            self.__save_for_backward = {}
        self.__save_for_backward['label'] = label
        real_label = from_int_to_real_domain(label, self.__scale_weight_parameter)
        real_prediction = from_int_to_real_domain(prediction, self.__scale_weight_parameter)
        diff = real_label - real_prediction
        diff_norm = torch.linalg.norm(diff)
        return torch.square(diff_norm)

    def _optimizer(self, learning_rate: float):
        """Apply one gradient-descent step using the gradients from ``_backward``.

        The learning rate is converted with ``to_int_domain_int`` and each
        scaled gradient is truncated with the learning-rate scale before it
        is subtracted from the corresponding weight matrix.
        """
        learning_rate = to_int_domain_int(learning_rate, self.__scale_learning_rate_parameter)
        weight_2_grad = finite_field_truncation(learning_rate * self.__gradients['weight_2_grad'],
                                                self.__scale_learning_rate_parameter)
        weight_1_grad = finite_field_truncation(learning_rate * self.__gradients['weight_1_grad'],
                                                self.__scale_learning_rate_parameter)
        self._weight_2 = self._weight_2 - weight_2_grad
        self._weight_1 = self._weight_1 - weight_1_grad

    def _forward(self, input_vector: torch.Tensor, mode: str = 'train') -> torch.Tensor:
        """Compute the network output for one column vector.

        Steps (operands are cast to float before each matmul/square):
        ``h = trunc(W1^T x, input scale)``, ``h = trunc(h**2, weight scale)``,
        ``out = trunc(W2^T h, weight scale)``.

        Args:
            input_vector: input column vector; ``train`` passes a tensor
                reshaped to ``(-1, 1)``.
            mode: when ``'train'``, the input, the hidden activation and the
                output are cached for ``_backward``.

        Returns:
            The truncated output of the second layer.
        """

        first_forward = finite_field_truncation(torch.matmul(torch.t(self._weight_1).type(torch.float),
                                                             input_vector.type(torch.float)),
                                                self.__scale_input_parameter)
        first_forward = finite_field_truncation(torch.square(first_forward.type(torch.float)),
                                                self.__scale_weight_parameter)
        out = finite_field_truncation(torch.matmul(torch.t(self._weight_2).type(torch.float),
                                                   first_forward.type(torch.float)), self.__scale_weight_parameter)
        if mode == 'train':
            # Replaces any previous cache, so _criterion must run after this
            # call for 'label' to be present when _backward is invoked.
            self.__save_for_backward = {
                'input_vector': input_vector,
                'first_forward': first_forward,
                'out': out
            }

        return out

    def _backward(self):
        """Compute gradients of the squared-error loss for both weight matrices.

        Reads the tensors cached by ``_forward`` (train mode) and
        ``_criterion`` and stores the result in the gradient cache used by
        ``_optimizer``.
        """
        first_forward, out, label, input_vector = self.__save_for_backward['first_forward'], \
            self.__save_for_backward['out'], self.__save_for_backward['label'], self.__save_for_backward['input_vector']

        # weight_2 gradient: -2 * h (label - out)^T, truncated with the weight scale
        weight_2_grad = -2 * finite_field_truncation(torch.matmul(first_forward.type(torch.float),
                                                                  torch.t(label - out).type(torch.float)),
                                                     self.__scale_weight_parameter)

        # weight_1 gradients
        # second_chain: derivative of the square activation, 2 * diag(W1^T x),
        # recomputed from the current weights and the cached input.
        second_chain = 2 * finite_field_truncation(torch.diag(torch.matmul(torch.t(self._weight_1).type(torch.float),
                                                                           input_vector.type(torch.float)).reshape(-1)),
                                                   self.__scale_input_parameter)
        third_chain = torch.t(self._weight_2)
        fourth_chain = -2 * torch.t(label - out)
        # Multiply the chain factors together, truncating after every product.
        weight_1_grad = second_chain
        weight_1_grad = finite_field_truncation(torch.matmul(third_chain.type(torch.float),
                                                             weight_1_grad.type(torch.float)),
                                                self.__scale_weight_parameter)
        weight_1_grad = finite_field_truncation(torch.matmul(fourth_chain.type(torch.float),
                                                             weight_1_grad.type(torch.float)),
                                                self.__scale_weight_parameter)
        weight_1_grad = finite_field_truncation(torch.matmul(input_vector.type(torch.float),
                                                             weight_1_grad.type(torch.float)),
                                                self.__scale_input_parameter)

        self.__gradients = {
            'weight_2_grad': weight_2_grad,
            'weight_1_grad': weight_1_grad
        }

    # noinspection DuplicatedCode
    def train(self, data_path: str, num_of_epochs: int, learning_rate: float):
        """Run the training loop on FashionMNIST (downloaded to ``data_path``).

        NOTE(review): both loops end with ``break``, so only the very first
        sample is processed; its tensor types are printed and the method
        returns. This looks like a deliberate dtype inspection -- remove the
        two ``break`` statements to train on the whole dataset.

        Args:
            data_path: directory where FashionMNIST is stored / downloaded.
            num_of_epochs: number of passes over the training set.
            learning_rate: real-valued step size handed to ``_optimizer``.
        """
        # transformations
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
            ToIntDomain(self.__scale_input_parameter)
        ])

        # One-hot width follows the configured number of classes so that the
        # label always matches the shape of the network output.
        num_classes = self.num_classes
        target_transform = transforms.Compose([
            transforms.Lambda(lambda y: torch.zeros(num_classes, dtype=torch.float)
                              .scatter_(0, torch.tensor(y), 1)),
            ToIntDomain(self.__scale_weight_parameter)
        ])

        # load data
        train_dataset = FashionMNIST(data_path, train=True, transform=transform, target_transform=target_transform,
                                     download=True)
        train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

        # Accumulated loss; it is neither returned nor stored at present.
        curr_loss = torch.zeros(1).to(self.device)
        for epoch in range(num_of_epochs):
            for idx, (data, label) in enumerate(train_loader):
                data, label = data.to(self.device), label.to(self.device)
                # Flatten the transposed image and the label into column vectors.
                data, label = data.squeeze().T.reshape(-1, 1), label.reshape(-1, 1)

                out = self._forward(data)
                loss = self._criterion(label, out)
                self._backward()
                self._optimizer(learning_rate)
                curr_loss += loss

                if idx == 0:
                    print('input, label type: {}'.format(data.type()))
                    print('weight 1 type: {}'.format(self._weight_1.type()))
                    print('weight 2 type: {}'.format(self._weight_2.type()))
                    for key in self.__gradients.keys():
                        print('{} type: {}'.format(key, self.__gradients[key].type()))
                    for key in self.__save_for_backward.keys():
                        print('{} type: {}'.format(key, self.__save_for_backward[key].type()))
                    break
            break

In [4]:
# Build the integer-domain network on the CPU and run its training routine
# on FashionMNIST stored under ./data (one epoch, real learning rate 0.001).
scaled_net = ScaledVectorizedIntegerNet(
    scale_input_parameter=8,
    scale_weight_parameter=8,
    scale_learning_rate_parameter=10,
    device='cpu',
)
scaled_net.train(data_path='./data', num_of_epochs=1, learning_rate=0.001)

weights are initialized
weights are sent to cpu
input, label type: torch.LongTensor
weight 1 type: torch.LongTensor
weight 2 type: torch.LongTensor
weight_2_grad type: torch.LongTensor
weight_1_grad type: torch.LongTensor
input_vector type: torch.LongTensor
first_forward type: torch.LongTensor
out type: torch.LongTensor
label type: torch.LongTensor
