In [51]:
import torch
import torchvision.transforms as trans
import torchvision
from torch.utils import data

def data_iter(batchsize,resize=None):
    tran=[trans.ToTensor()]
    if resize:
        tran.insert(0,trans.Resize(resize))
    
    tran=trans.Compose(tran)
    train=torchvision.datasets.FashionMNIST(root="../data",train=True,transform=tran,download=True)
    test=torchvision.datasets.FashionMNIST(root="../data",train=False,transform=tran,download=True)
    return (data.DataLoader(train,batchsize,True),data.DataLoader(test,batchsize,True))

In [52]:
batch_size=10
train_iter,test_iter=data_iter(batch_size)
for x,y in train_iter:
    print(x.shape)
    break;

for x,y in test_iter:
    print(x.shape)
    break;

#print(torchvision.datasets.FashionMNIST.classes)

torch.Size([10, 1, 28, 28])
torch.Size([10, 1, 28, 28])


In [53]:
class Accumulator():
    def __init__(self,n):
        self.data=[0.0]*n
    
    def add(self,*args):
        self.data=[a+float(b) for a,b in zip(self.data,args)]

    def __getitem__(self,i):
        return self.data[i]
    
def accuracy(y_hat,y):
    y_hat=y_hat.argmax(axis=1)
    cmd=y_hat.type(y.dtype)==y
    return float(cmd.type(y.dtype).sum())

In [54]:
from torch import nn

net=nn.Sequential(nn.Flatten(),nn.Linear(784,10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,std=0.01)
        nn.init.zeros_(m.bias)

net.apply(init_weights)

loss=nn.CrossEntropyLoss(reduction='none')

updater=torch.optim.SGD(net.parameters(),lr=0.1)

def train_ch3(net,loss,updater,train_iter):

    net.train()
    metric=Accumulator(3)
    for x,y in train_iter:
        y_hat=net(x)
        l=loss(y_hat,y)
        updater.zero_grad()
        l.mean().backward()
        updater.step()
        metric.add(l.sum(),accuracy(y_hat,y),y.numel())
    return metric[0]/metric[2],metric[1]/metric[2]

def train(net,loss,updater,train_iter,num):
    for t in range(num):
        train_ch3(net,loss,updater,train_iter)

train(net,loss,updater,train_iter,10)

In [55]:
def pred(test_iter):
    labels=torchvision.datasets.FashionMNIST.classes
    print(labels)
    for x,y in test_iter:
        break
    real=[labels[int(i)] for i in y]
    predin=net(x).argmax(axis=1)
    pred=[labels[int(i)] for i in predin]
    print(real)
    print(pred)

pred(test_iter)




['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
['T-shirt/top', 'Coat', 'Pullover', 'Ankle boot', 'Dress', 'Coat', 'Trouser', 'T-shirt/top', 'Sneaker', 'Pullover']
['Pullover', 'Coat', 'Pullover', 'Ankle boot', 'Dress', 'Coat', 'Trouser', 'T-shirt/top', 'Sneaker', 'Pullover']
