In [2]:
from OptTensor import OptTensor
from BinaryCrossEntropyLoss import BinaryCrossEntropyLoss
import numpy as np
from matplotlib import pyplot as plt
%load_ext autoreload
%autoreload 2

In [3]:
from torch.utils.data import Dataset, random_split, DataLoader
class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
    def __getitem__(self, index):
        sample = {
            "features": self.x[index],
            "label": self.y[index]
        }
        return sample
    def __len__(self):
        return len(self.x)

In [4]:
dim_in = 100
dim_out = [5000, 300, 2]

In [5]:
encodings = {
    'doc2vec': np.load('./data/d2v.npy'),
    'bigram': np.load('./data/pca_bigram.npy'),
    'tfidf': np.load('./data/pca_tfidf.npy'),
    'sent2vec': np.load('./data/s2v.npy'),
    'word2vec': np.load('./data/w2v.npy'),
}
labels = np.load('./data/label.npy')
choices = np.random.choice(np.arange(0, encodings['tfidf'].shape[0]), size=5000, replace=False)

In [23]:
def header(heading: str) -> str:
    return '-'*4 + heading + '-'*5

def intro(layername, data, grad):
    print(header(f'{layername} data'))
    print(data)
    print(header(f'{layername} grad'))
    print(grad)

mydata = MyDataset(encodings['word2vec'][choices], labels[choices])
loader = DataLoader(mydata, batch_size=4)
linear1 = OptTensor(np.random.randn(100, 10))
linear2 = OptTensor(np.random.randn(10, 5))
linear3 = OptTensor(np.random.randn(5, 2))
loss = BinaryCrossEntropyLoss()

In [24]:
learning_step = 0.0001
losses = []
data_iter = iter(loader)
data = next(data_iter)

In [25]:
x = data['features'].numpy()
y = data['label'].numpy()
print('current_labels', y)

current_labels [1 1 1 1]


In [26]:
lr1 = linear1(x)
relu1 = lr1.relu()
lr2 = linear2(relu1).relu()
relu2 = lr2.relu()
lr3 = linear3(relu2).relu()
relu3 = lr3.relu()
probs = relu3.softmax(temperature=1000)
current_loss = loss(probs, y)
current_loss.backward()
lr1.data -= learning_step * lr1.grad
lr2.data -= learning_step * lr2.grad
lr3.data -= learning_step * lr3.grad
current_loss.data

0.7063497734311761

In [27]:
# data
print(header('lr1'))
print(lr1.data)
print(header('relu1'))
print(relu1.data)
print(header('lr2'))
print(lr2.data)
print(header('relu2'))
print(relu2.data)
print(header('lr3'))
print(lr3.data)
print(header('relu3'))
print(relu3.data)
print(header('probs'))
print(probs.data)

----lr1-----
[[ 6.25807609  4.28419693 -0.36452158 -1.81746562 -4.48174538 -6.00675018
  -8.34381229 -6.70599973  5.40045223  2.57787236]
 [ 2.84002382 -1.36003246 -6.47156583 -2.98476361 -8.59985688 -0.78794719
  -7.46760978 -8.76397994  6.81839787 -2.61832472]
 [ 4.00704915  1.98788014 -1.59186209  1.43881693 -1.15406523 -0.41316787
  -3.87338969 -6.07594748  3.82904389  4.88647684]
 [ 1.68868859 -5.78642045  1.75742817 -1.24848944 -6.50074144  2.18706026
  -0.85237329 -4.86325113  3.73226513 -1.19498366]]
----relu1-----
[[6.25849853 4.28416773 0.         0.         0.         0.
  0.         0.         5.40066985 2.57785195]
 [2.84023112 0.         0.         0.         0.         0.
  0.         0.         6.81877747 0.        ]
 [4.00746785 1.9878512  0.         1.43849551 0.         0.
  0.         0.         3.82925959 4.88645661]
 [1.68890902 0.         1.75728574 0.         0.         2.18714629
  0.         0.         3.732701   0.        ]]
----lr2-----
[[ 0.          0.6681

In [29]:
# grad
print(header('lr1 grad'))
print(lr1.grad)
print(header('relu1 grad'))
print(relu1.grad)
print(header('lr2 grad'))
print(lr2.grad)
print(header('relu2 grad'))
print(relu2.grad)
print(header('lr3 grad'))
print(lr3.grad)
print(header('relu3 grad'))
print(relu3.grad)
print(header('probs grad'))
print(probs.grad)

----lr1 grad-----
[[ 4.2243483  -0.29202355  0.         -0.         -0.          0.
  -0.          0.          2.17629351 -0.20403274]
 [ 2.07300481  0.          0.         -0.         -0.          0.
   0.          0.          3.79595819 -0.        ]
 [ 4.18697102 -0.28943971  0.         -3.21421575 -0.          0.
  -0.          0.          2.15703754 -0.20222745]
 [ 2.20425266  0.         -1.42425438 -0.         -0.          0.86030182
   0.         -0.          4.35876275 -0.        ]]
----relu1 grad-----
[[ 4.2243483  -0.29202355  1.353444   -3.2429092  -4.76690893  1.67485444
  -2.8509284   0.79238716  2.17629351 -0.20403274]
 [ 2.07300481  0.75245202  1.51467951 -2.77176286 -1.04672067  1.98563307
   0.99484806  0.11552939  3.79595819 -1.52754434]
 [ 4.18697102 -0.28943971  1.34146865 -3.21421575 -4.72473105  1.66003523
  -2.82570323  0.78537607  2.15703754 -0.20222745]
 [ 2.20425266  3.54085758 -1.42425438 -1.66927827 -1.60677936  0.86030182
   3.42535428 -0.43166776  4.3587627