In [1]:
import torch
import torchvision
from torch.utils import data

# data/model/optimize/loss/init/train

# data

In [2]:
mnist_train = torchvision.datasets.FashionMNIST('../data',train=True,download=True,
                                                transform = torchvision.transforms.ToTensor())

In [3]:
mnist_test = torchvision.datasets.FashionMNIST('../data',train=False,download=True,
                                              transform = torchvision.transforms.ToTensor())

In [4]:
batch_size = 256

In [5]:
train_loader = data.DataLoader?

In [None]:
train_loader = data.DataLoader

In [6]:
train_loader = data.DataLoader(mnist_train,batch_size=batch_size)

In [7]:
test_loader = data.DataLoader(mnist_test,batch_size=batch_size)

In [8]:
next(iter(train_loader))[0][0][0].shape

torch.Size([28, 28])

# model

In [9]:
net = torch.nn.Sequential?

In [None]:
net = torch.nn.Sequential

In [10]:
torch.nn.Conv2d?

In [11]:
torch.nn.MaxPool2d?

In [12]:
net = torch.nn.Sequential(torch.nn.Conv2d(1,6,(5,5),1,2),torch.nn.Sigmoid(),torch.nn.MaxPool2d((2,2)),
                         torch.nn.Conv2d(6,16,(5,5),1,0),torch.nn.Sigmoid(),torch.nn.MaxPool2d((2,2)),
                         torch.nn.Flatten(),torch.nn.Linear(in_features=400,out_features=120),torch.nn.Sigmoid(),
                         torch.nn.Linear(in_features=120,out_features=84),torch.nn.Sigmoid(),
                         torch.nn.Linear(in_features=84,out_features=10))

In [13]:
x = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)
for layer in net:
    x = layer(x)
    print(layer.__class__.__name__, 'output shape: \t', x.shape)

Conv2d output shape: 	 torch.Size([1, 6, 28, 28])
Sigmoid output shape: 	 torch.Size([1, 6, 28, 28])
MaxPool2d output shape: 	 torch.Size([1, 6, 14, 14])
Conv2d output shape: 	 torch.Size([1, 16, 10, 10])
Sigmoid output shape: 	 torch.Size([1, 16, 10, 10])
MaxPool2d output shape: 	 torch.Size([1, 16, 5, 5])
Flatten output shape: 	 torch.Size([1, 400])
Linear output shape: 	 torch.Size([1, 120])
Sigmoid output shape: 	 torch.Size([1, 120])
Linear output shape: 	 torch.Size([1, 84])
Sigmoid output shape: 	 torch.Size([1, 84])
Linear output shape: 	 torch.Size([1, 10])


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


# loss 

In [14]:
loss_func = torch.nn.CrossEntropyLoss()

# optimize

In [15]:
optim = torch.optim.SGD(net.parameters(),lr=0.9)

# init 

In [16]:
def init_weight(layer):
    if isinstance(layer,torch.nn.Linear) or isinstance(layer,torch.nn.Conv2d):
        torch.nn.init.xavier_normal_(layer.weight)
        

In [17]:
net.apply(init_weight)

Sequential(
  (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): Sigmoid()
  (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (4): Sigmoid()
  (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=400, out_features=120, bias=True)
  (8): Sigmoid()
  (9): Linear(in_features=120, out_features=84, bias=True)
  (10): Sigmoid()
  (11): Linear(in_features=84, out_features=10, bias=True)
)

# train

In [18]:
import logging
from  tqdm import tqdm

In [19]:
def get_logger(filename, verbosity=1, name=None):
    level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)s][line:%(lineno)d][%(levelname)s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level_dict[verbosity])

    fh = logging.FileHandler(filename, "w")
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger

In [20]:
class Accumulator:
    def __init__(self,n):
        self.data=[0.0]*n
    
    def add(self,*args):
        self.data = [a+float(b) for a,b in zip(self.data,args)]
    
    def reset(self):
        self.data = [0.0]*n
        
    def __getitem__(self,idx):
        return self.data[idx]

In [21]:
def accuracy(y_pred,y):
    y_pred = torch.argmax(y_pred,-1)
    return sum(y_pred==y)

In [22]:
def evaluate_acc(net,data_iter):
    if isinstance(net,torch.nn.Module):
        net.eval()
    metric = Accumulator(2)
    with torch.no_grad():
        for x,y in data_iter:
            y_pred = net(x)
            ##compute 
            metric.add(accuracy(y_pred,y),y.numel())        
    return metric[0]/metric[1]

In [23]:
def train_epoch(train_iter,net,loss,optimize):
    if isinstance(net,torch.nn.Module):
        net.train()
    ##metric
    metric = Accumulator(3)
    for x,y in train_iter:
        pred_y = net(x)
        #print(y)
        #print('predict',pred_y)
        #print(torch.argmax(pred_y))
        #print(y)
        loss_tmp = loss(pred_y,y)
        metric.add(loss_tmp*len(y),accuracy(pred_y,y),y.numel())
        optimize.zero_grad()
        loss_tmp.backward()
        optimize.step()
    return metric[0]/metric[2],metric[1]/metric[2]

In [24]:
def train(train_iter,test_iter,net,loss,optimize,num_epochs,logger):
    for epoch in tqdm(range(num_epochs)):
        train_metrics = train_epoch(train_iter,net,loss,optimize)
        test_acc = evaluate_acc(net,test_loader)
        print('epoch:%d\ttrain_loss:%f\ttrain_acc:%f\ttest_acc:%f'%(epoch,train_metrics[0],train_metrics[1],test_acc))
        logger.info('epoch:%d\ttrain_loss:%f\ttrain_acc:%f\ttest_acc:%f'%(epoch,train_metrics[0],train_metrics[1],test_acc))
    train_loss,train_acc = train_metrics

In [25]:
log = get_logger('./lenet.log')

In [None]:
train(train_loader,test_loader,net,loss_func,optim,num_epochs=10,logger=log)

  0%|          | 0/10 [00:00<?, ?it/s][2021-08-15 13:46:08,139][448897896.py][line:6][INFO] epoch:0	train_loss:2.316995	train_acc:0.102467	test_acc:0.100000
 10%|█         | 1/10 [00:45<06:47, 45.33s/it]

epoch:0	train_loss:2.316995	train_acc:0.102467	test_acc:0.100000


[2021-08-15 13:46:50,080][448897896.py][line:6][INFO] epoch:1	train_loss:1.575852	train_acc:0.383183	test_acc:0.623000
 20%|██        | 2/10 [01:27<05:54, 44.31s/it]

epoch:1	train_loss:1.575852	train_acc:0.383183	test_acc:0.623000


[2021-08-15 13:47:31,308][448897896.py][line:6][INFO] epoch:2	train_loss:0.833814	train_acc:0.672967	test_acc:0.693000
 30%|███       | 3/10 [02:08<05:03, 43.39s/it]

epoch:2	train_loss:0.833814	train_acc:0.672967	test_acc:0.693000


[2021-08-15 13:48:18,716][448897896.py][line:6][INFO] epoch:3	train_loss:0.671732	train_acc:0.738033	test_acc:0.750800
 40%|████      | 4/10 [02:55<04:27, 44.59s/it]

epoch:3	train_loss:0.671732	train_acc:0.738033	test_acc:0.750800


[2021-08-15 13:49:05,191][448897896.py][line:6][INFO] epoch:4	train_loss:0.602597	train_acc:0.766183	test_acc:0.776200
 50%|█████     | 5/10 [03:42<03:45, 45.16s/it]

epoch:4	train_loss:0.602597	train_acc:0.766183	test_acc:0.776200


[2021-08-15 13:49:48,421][448897896.py][line:6][INFO] epoch:5	train_loss:0.548664	train_acc:0.790233	test_acc:0.799400
 60%|██████    | 6/10 [04:25<02:58, 44.58s/it]

epoch:5	train_loss:0.548664	train_acc:0.790233	test_acc:0.799400


[2021-08-15 13:50:34,891][448897896.py][line:6][INFO] epoch:6	train_loss:0.503689	train_acc:0.808633	test_acc:0.812800
 70%|███████   | 7/10 [05:12<02:15, 45.15s/it]

epoch:6	train_loss:0.503689	train_acc:0.808633	test_acc:0.812800


[2021-08-15 13:51:24,670][448897896.py][line:6][INFO] epoch:7	train_loss:0.469131	train_acc:0.823067	test_acc:0.821500
 80%|████████  | 8/10 [06:01<01:33, 46.54s/it]

epoch:7	train_loss:0.469131	train_acc:0.823067	test_acc:0.821500
