In [2]:
import torch
import torchvision
from torch.utils import data

# Pipeline: data -> model -> loss -> optimizer -> training loop

# 1 dataset

In [3]:
#mnist_train = torchvision.datasets.FashionMNIST?

In [4]:
# Training split of Fashion-MNIST stored under ../data (download=True fetches
# it when it is missing); ToTensor() is applied to every image on access.
mnist_train = torchvision.datasets.FashionMNIST(
    root='../data',
    train=True,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [5]:
# Test split of Fashion-MNIST (train=False) stored under ../data; ToTensor()
# is applied to every image on access.
mnist_test = torchvision.datasets.FashionMNIST(
    root='../data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [6]:
mnist_train

Dataset FashionMNIST
    Number of datapoints: 60000
    Split: train
    Root Location: ../data
    Transforms (if any): ToTensor()
    Target Transforms (if any): None

In [7]:
mnist_train.targets[0]

tensor(9)

In [8]:
mnist_train.data[0].shape

torch.Size([28, 28])

## 1.2 dataloader

In [9]:
batch_size = 256

In [10]:
# Mini-batch iterators over the two splits: training batches are shuffled,
# test batches keep the dataset order.
train_loader = data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=False)

# 2 model

In [11]:
net = torch.nn.Sequential?

In [None]:
net = torch.nn.Sequential

In [12]:
# NOTE(review): scratch cell - this binds `net` to the Sequential class itself,
# not to a model instance; a later cell rebinds `net` to a constructed model.
net = torch.nn.Sequential

In [13]:
# Linear classifier: flatten each input (28x28 images here) into 784 features,
# then map them to 10 output scores with a single fully connected layer.
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=28 * 28, out_features=10),
)

In [14]:
for i in net.children():print(i)

Flatten(start_dim=1, end_dim=-1)
Linear(in_features=784, out_features=10, bias=True)


In [15]:
net[1].weight.data[0][:10]

tensor([-0.0009, -0.0047, -0.0203,  0.0125,  0.0243, -0.0202, -0.0129,  0.0022,
         0.0015,  0.0210])

In [16]:
def init_weights(l):
    """Re-initialise the weight of a ``torch.nn.Linear`` layer from N(0, 0.01**2).

    Written to be passed to ``Module.apply``, which calls it once per
    sub-module. Modules that are not ``torch.nn.Linear`` are left untouched,
    and the bias of a Linear layer is not modified.

    Args:
        l: Any module; only ``torch.nn.Linear`` instances are re-initialised.

    Returns:
        None. The weight tensor is modified in place.
    """
    if not isinstance(l, torch.nn.Linear):
        return
    # torch.nn.init.normal_ fills the parameter in place with autograd
    # disabled; it is the supported replacement for mutating `weight.data`.
    torch.nn.init.normal_(l.weight, mean=0.0, std=0.01)

In [17]:
net.apply(init_weights)

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=10, bias=True)
)

In [18]:
net[1].weight.data[0][:10]

tensor([ 0.0126,  0.0062, -0.0021,  0.0056, -0.0020,  0.0026,  0.0099,  0.0028,
        -0.0119,  0.0136])

## 2.2 loss

In [19]:
torch.nn.CrossEntropyLoss?

In [20]:
torch.nn.NLLLoss?

In [21]:
torch.nn.LogSoftmax?

In [22]:
loss_func = torch.nn.CrossEntropyLoss?

In [None]:
loss_func = torch.nn.CrossEntropyLoss

In [23]:
# NOTE(review): scratch cell - this binds `loss_func` to the CrossEntropyLoss
# class itself; a later cell rebinds it to an instance.
loss_func = torch.nn.CrossEntropyLoss

In [24]:
# Cross-entropy criterion; reduction='mean' is the library default, spelled
# out here so the per-batch averaging is visible at the call site.
loss_func = torch.nn.CrossEntropyLoss(reduction='mean')

## 2.3 optimizer

In [25]:
optimize = torch.optim.Adam?

In [None]:
optimize = torch.optim.Adam

In [26]:
# NOTE(review): scratch cell - this binds `optimize` to the Adam class itself;
# a later cell rebinds it to an optimizer instance.
optimize = torch.optim.Adam

In [27]:
# Adam over every parameter of `net`.
# NOTE(review): weight_decay=0.1 is a strong penalty for this model; the logged
# training loss stays near 0.82 from the first epoch on - confirm the value is
# intended.
optimize = torch.optim.Adam(
    params=net.parameters(),
    lr=0.001,
    weight_decay=0.1,
)

# 3 train

In [28]:
import logging
from  tqdm import tqdm

In [29]:
def get_logger(filename, verbosity=1, name=None):
    """Return a logger that writes formatted records to a file and a stream.

    Calling this function again for the same logger replaces the handlers it
    installed earlier instead of stacking a second pair, so each record is
    emitted once per destination. Handlers installed by other code are kept.

    Args:
        filename: Path of the log file. It is opened in "w" mode, so an
            existing file is truncated on every call.
        verbosity: 0 -> DEBUG, 1 -> INFO, 2 -> WARNING.
        name: Name passed to ``logging.getLogger``; ``None`` selects the root
            logger.

    Returns:
        The configured ``logging.Logger``.

    Raises:
        KeyError: If ``verbosity`` is not 0, 1 or 2.
    """
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    # logging.getLogger returns the same object for the same name, so without
    # this clean-up every repeated call (e.g. re-running a notebook cell) would
    # add another file/stream pair and duplicate every log line.
    for handler in list(logger.handlers):
        if getattr(handler, "_installed_by_get_logger", False):
            logger.removeHandler(handler)
            handler.close()

    fh = logging.FileHandler(filename, "w")
    sh = logging.StreamHandler()
    for handler in (fh, sh):
        handler.setFormatter(formatter)
        # Marker used by the clean-up loop above on the next call.
        handler._installed_by_get_logger = True
        logger.addHandler(handler)

    return logger

In [30]:
class Accumulator:
    """Running sums kept in a fixed number of float slots."""

    def __init__(self,n):
        """Create `n` slots, each starting at 0.0."""
        self.data=[0.0]*n

    def add(self,*args):
        """Add each positional value (converted with float()) to its slot.

        Values beyond the number of slots are ignored. When fewer values than
        slots are given, the remaining slots keep their current totals; the
        previous zip-only form silently dropped those slots.
        """
        updated = [a+float(b) for a,b in zip(self.data,args)]
        self.data = updated + self.data[len(updated):]

    def reset(self):
        """Set every slot back to 0.0 while keeping the number of slots."""
        # The slot count is taken from the current list; the earlier `[0.0]*n`
        # referenced a name that does not exist in this scope (NameError).
        self.data = [0.0]*len(self.data)

    def __getitem__(self,idx):
        """Return the running total stored in slot `idx`."""
        return self.data[idx]

In [31]:
def accuracy(y_pred,y):
    """Count the predictions whose arg-max class equals the label.

    Args:
        y_pred: Tensor of per-class scores; the class axis is the last one.
        y: Tensor of class labels compared element-wise with the arg-max
            indices of `y_pred`.

    Returns:
        A 0-dim integer tensor holding the number of matches (a count, not a
        ratio).
    """
    predicted = torch.argmax(y_pred, dim=-1)
    # Tensor.sum() reduces in one vectorised call; the builtin sum() walked
    # the tensor element by element in Python.
    return (predicted == y).sum()

In [32]:
# Two samples, three classes: both rows have their largest score in column 2,
# while the labels are 0 and 2.
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.tensor([0, 2])

accuracy(y_hat, y)

tensor(1)

In [33]:
torch.argmax(y_hat,-1)

tensor([2, 2])

In [34]:
def evaluate_acc(net,data_iter):
    """Return the fraction of samples from `data_iter` that `net` gets right.

    A `torch.nn.Module` is switched to eval mode first (and left in that
    mode); gradient tracking is disabled for the whole pass.

    Args:
        net: Callable mapping a batch of inputs to per-class scores.
        data_iter: Iterable yielding `(inputs, labels)` batches.

    Returns:
        Correct predictions divided by the number of labels seen, as a float.

    Raises:
        ZeroDivisionError: If `data_iter` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    if is_module:
        net.eval()
    # Slot 0: correct predictions, slot 1: labels seen.
    totals = Accumulator(2)
    with torch.no_grad():
        for features, labels in data_iter:
            hits = accuracy(net(features), labels)
            totals.add(hits, labels.numel())
    correct, seen = totals[0], totals[1]
    return correct / seen

In [35]:
a = torch.Tensor([2,3,4])

In [36]:
b = torch.Tensor([1,3,2])

In [37]:
evaluate_acc(net,test_loader)

0.1211

In [38]:
def train_epoch(train_iter,net,loss,optimize):
    """Run one optimisation pass over `train_iter` and report its metrics.

    A `torch.nn.Module` is switched to train mode first. For every batch the
    loss is computed, the statistics are recorded, and then the gradients are
    cleared, back-propagated and applied through `optimize`.

    Args:
        train_iter: Iterable yielding `(inputs, labels)` batches.
        net: Callable mapping a batch of inputs to per-class scores.
        loss: Callable `loss(scores, labels)` returning a tensor that supports
            `backward()`.
        optimize: Optimizer exposing `zero_grad()` and `step()`.

    Returns:
        `(summed_loss / samples, correct / samples)` for the pass.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    # Slots: weighted loss sum, correct predictions, labels seen.
    totals = Accumulator(3)
    for features, labels in train_iter:
        outputs = net(features)
        batch_loss = loss(outputs, labels)
        # The batch loss is weighted by the batch size before accumulation
        # (assumes `loss` returns a per-batch mean - TODO confirm).
        totals.add(batch_loss * len(labels), accuracy(outputs, labels), labels.numel())
        optimize.zero_grad()
        batch_loss.backward()
        optimize.step()
    loss_sum, correct, seen = totals[0], totals[1], totals[2]
    return loss_sum / seen, correct / seen

In [39]:
# NOTE(review): this call is not a read-only check - it runs a full
# optimisation pass over train_loader, so `net` and the optimizer state are
# already updated before any later training call.
train_epoch(train_loader,net,loss_func,optimize)

(0.9580627958933512, 0.7231666666666666)

In [45]:
def train(train_iter,test_iter,net,loss,optimize,num_epochs,logger):
    """Train `net` for `num_epochs` epochs, reporting metrics after each one.

    Every epoch runs `train_epoch` on `train_iter`, evaluates accuracy on
    `test_iter`, and emits one summary line both via `print` and via
    `logger.info`.

    Args:
        train_iter: Iterable yielding `(inputs, labels)` training batches.
        test_iter: Iterable yielding `(inputs, labels)` evaluation batches.
        net: Model being trained.
        loss: Loss callable forwarded to `train_epoch`.
        optimize: Optimizer forwarded to `train_epoch`.
        num_epochs: Number of passes over `train_iter`; 0 performs no work.
        logger: Object with an `info(message)` method.

    Returns:
        None.
    """
    for epoch in tqdm(range(num_epochs)):
        train_loss, train_acc = train_epoch(train_iter,net,loss,optimize)
        # Evaluate on the iterator that was passed in; the previous version
        # read the global `test_loader` and silently ignored `test_iter`.
        test_acc = evaluate_acc(net,test_iter)
        message = 'epoch:%d\ttrain_loss:%f\ttrain_acc:%f\ttest_acc:%f'%(epoch,train_loss,train_acc,test_acc)
        print(message)
        logger.info(message)
    # The former trailing unpack of `train_metrics` was unused and raised
    # UnboundLocalError when num_epochs was 0, so it has been removed.

In [46]:
log = get_logger('./sfm.log')

In [None]:
train(train_loader,test_loader,net,loss_func,optimize,num_epochs=10,logger=log)

  0%|          | 0/10 [00:00<?, ?it/s][2021-08-15 13:35:20,888][448897896.py][line:6][INFO] epoch:0	train_loss:0.822481	train_acc:0.770567	test_acc:0.758400
[2021-08-15 13:35:20,888][448897896.py][line:6][INFO] epoch:0	train_loss:0.822481	train_acc:0.770567	test_acc:0.758400
 10%|█         | 1/10 [00:16<02:31, 16.84s/it]

epoch:0	train_loss:0.822481	train_acc:0.770567	test_acc:0.758400


[2021-08-15 13:35:36,179][448897896.py][line:6][INFO] epoch:1	train_loss:0.820440	train_acc:0.771867	test_acc:0.752800
[2021-08-15 13:35:36,179][448897896.py][line:6][INFO] epoch:1	train_loss:0.820440	train_acc:0.771867	test_acc:0.752800
 20%|██        | 2/10 [00:32<02:11, 16.38s/it]

epoch:1	train_loss:0.820440	train_acc:0.771867	test_acc:0.752800


[2021-08-15 13:35:52,547][448897896.py][line:6][INFO] epoch:2	train_loss:0.821278	train_acc:0.770150	test_acc:0.764300
[2021-08-15 13:35:52,547][448897896.py][line:6][INFO] epoch:2	train_loss:0.821278	train_acc:0.770150	test_acc:0.764300
 30%|███       | 3/10 [00:48<01:54, 16.37s/it]

epoch:2	train_loss:0.821278	train_acc:0.770150	test_acc:0.764300


[2021-08-15 13:36:08,629][448897896.py][line:6][INFO] epoch:3	train_loss:0.820255	train_acc:0.772517	test_acc:0.759500
[2021-08-15 13:36:08,629][448897896.py][line:6][INFO] epoch:3	train_loss:0.820255	train_acc:0.772517	test_acc:0.759500
 40%|████      | 4/10 [01:04<01:37, 16.29s/it]

epoch:3	train_loss:0.820255	train_acc:0.772517	test_acc:0.759500


[2021-08-15 13:36:23,864][448897896.py][line:6][INFO] epoch:4	train_loss:0.820468	train_acc:0.770900	test_acc:0.762800
[2021-08-15 13:36:23,864][448897896.py][line:6][INFO] epoch:4	train_loss:0.820468	train_acc:0.770900	test_acc:0.762800
 50%|█████     | 5/10 [01:19<01:19, 15.97s/it]

epoch:4	train_loss:0.820468	train_acc:0.770900	test_acc:0.762800


[2021-08-15 13:36:38,482][448897896.py][line:6][INFO] epoch:5	train_loss:0.820980	train_acc:0.770483	test_acc:0.749400
[2021-08-15 13:36:38,482][448897896.py][line:6][INFO] epoch:5	train_loss:0.820980	train_acc:0.770483	test_acc:0.749400
 60%|██████    | 6/10 [01:34<01:02, 15.57s/it]

epoch:5	train_loss:0.820980	train_acc:0.770483	test_acc:0.749400


[2021-08-15 13:36:52,364][448897896.py][line:6][INFO] epoch:6	train_loss:0.820612	train_acc:0.770817	test_acc:0.759000
[2021-08-15 13:36:52,364][448897896.py][line:6][INFO] epoch:6	train_loss:0.820612	train_acc:0.770817	test_acc:0.759000
 70%|███████   | 7/10 [01:48<00:45, 15.06s/it]

epoch:6	train_loss:0.820612	train_acc:0.770817	test_acc:0.759000
