In [1]:
import torch
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from hessian import hessian

In [2]:
root = './data/fashion'

In [3]:
batch_size = 100
epoch = 10

In [4]:
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = FashionMNIST(root, train=True, transform=transform, target_transform=None, download=False)
test_dataset = FashionMNIST(root, train=False, transform=transform, target_transform=None, download=False)

print('train:', len(train_dataset))
print('test:', len(test_dataset))

train_loader = DataLoader(dataset=train_dataset,
                        batch_size=batch_size,
                        shuffle=True)

test_loader = DataLoader(dataset=test_dataset,
                        batch_size=batch_size,
                        shuffle=False)

train: 60000
test: 10000


In [5]:
def train_accuracy(W, k):
    correct = 0
    for idx, (X, y) in enumerate(train_loader):
        if idx == k:
            break
            
        _X = X.view(batch_size, -1)
        y_score = torch.matmul(_X, W)
        y_pred = y_score.argmax(dim=1)
        correct_map = torch.zeros_like(y_pred)
        num_correct = (y_pred==y).sum()
        correct += num_correct.item()
    print('train_acc:', correct/(k*batch_size))

In [6]:
W = torch.rand((784,10), requires_grad=True)
loss_func = nn.CrossEntropyLoss()
learning_rate = 0.1
for i in range(epoch):
    loss = 0
    for idx, (X, y) in enumerate(train_loader):
        _X = X.view(batch_size, -1)
        y_score = torch.matmul(_X, W)
        output = loss_func(y_score, y)
        h = hessian(output, W)
        print(W.size())
        output.backward()
        loss += output.item()
        
        
        with torch.no_grad():
            W -= learning_rate * W.grad
            W.grad.zero_()
        
    print(loss)
    train_accuracy(W,100)
    
    

torch.Size([784, 10])
torch.Size([784, 10])
torch.Size([784, 10])
torch.Size([784, 10])
torch.Size([784, 10])
torch.Size([784, 10])
torch.Size([784, 10])


KeyboardInterrupt: 