In [80]:
import torch
from torch import nn
import time
import torch.nn.functional as F
import matplotlib.pyplot as plt
import dlc_practical_prologue as prologue

In [81]:
N_PAIRS = 1000
train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(N_PAIRS)

In [82]:
print(train_input.shape)
print(train_classes.shape)
print(train_target.shape)
print(test_input.shape)
print(test_classes.shape)
print(test_target.shape)

torch.Size([1000, 2, 14, 14])
torch.Size([1000, 2])
torch.Size([1000])
torch.Size([1000, 2, 14, 14])
torch.Size([1000, 2])
torch.Size([1000])


In [83]:
# 处理数据集合
tran_train_input=train_input.view([2000, 1, 14, 14])
tran_train_classes=train_classes.view([2000])
tran_test_input=test_input.view([2000, 1, 14, 14])
tran_test_classes=test_classes.view([2000])

In [84]:
#CNN
class CNN_one_by_one_Net(nn.Module):
    def __init__(self):
        super(CNN_one_by_one_Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5)
        self.conv2 = nn.Conv2d(64,128,kernel_size=2)
        self.conv2_drop=nn.Dropout2d()
        self.fc1 = nn.Linear(512, 128)
        self.fc2 = nn.Linear(128, 10)
        #parameters
        self.batch_size = 50
        self.criterion = nn.CrossEntropyLoss()
        self.num_epochs = 25
        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        x=self.conv1(x)
        x=F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(x)
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 512)))
        x = F.dropout(x,training=self.training)
        x = self.fc2(x)
        return x

        # Training Function

    def trainer(self, train_input, train_target):
        """
        Train the model on a training set
        :param train_input: Training features
        :param train_target: Training labels
        """
        start_time = time.time()
        self.train()
        for epoch in range(self.num_epochs):
            for batch_idx in range(0,train_input.size(0),self.batch_size):
                output = self(train_input[batch_idx:batch_idx+self.batch_size]) 
                loss = self.criterion(output, train_target[batch_idx:batch_idx+self.batch_size])  
                self.optimizer.zero_grad()                          #清零梯度(set gradients to zero)
                loss.backward()                                #反向求梯度(backpropagate)
                self.optimizer.step()
#                 每隔50组数据，输出一次loss值(Every 50 data, output loss once)
                if not batch_idx % 50:
                    print ('Epoch: %03d/%03d | Batch %03d/%03d | Loss: %.6f' 
                           %(epoch+1, self.num_epochs, batch_idx, 
                             len(train_input), loss))
            print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))

        print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))
        

        # Test error

    def compute_error(self, input_data, target):
        """
        Compute the number of error of the model on a test set
        :param input_data: test features
        :param target: test target
        :return: error rate of the input data
        """  
    
        #test mode
        self.eval()     
        errors = 0
        for idx in range(0,input_data.size(0),self.batch_size):
            input_batch=input_data.narrow(0,idx,self.batch_size)
            outputs = self(input_batch)
            _, predicted = torch.max(outputs, 1)   #返回值和索引
            target_labels = target.narrow(0, idx, self.batch_size)
            errors += torch.sum(predicted != target_labels)

        return float(errors)*100/input_data.size(0)
    
    def compare_two_digit(self, input_data, comp_targets):
        
        #test mode
        self.eval() 
        errors = 0
        for pairs,comp_target in zip(input_data, comp_targets):
            input_num1=pairs[0].view([1,1,14,14])
            input_num2=pairs[1].view([1,1,14,14])
            output_1 = self(input_num1)
            output_2 = self(input_num2)
            _, predicted_1 = torch.max(output_1, 1)   #return value and key
            _, predicted_2 = torch.max(output_2, 1)   #return value and key
            if(predicted_2-predicted_1>0):
                result=1
            else:
                result=0
            if(comp_target!=result):
                errors=errors+1
        return float(errors)*100/input_data.size(0)
        
    def save_model(self,model_name):
        """
        Save the model to a direction
        :param model_name: the model name, e.g. CNN_Net.pth
        """         
        torch.save(self, './model/'+ model_name)

In [85]:
my_model = CNN_one_by_one_Net()
my_model.save_model('CNN_one_by_one.pth')

In [86]:
# train the model
my_model.trainer(tran_train_input, tran_train_classes)

Epoch: 001/025 | Batch 000/2000 | Loss: 18.023806
Epoch: 001/025 | Batch 050/2000 | Loss: 16.137241
Epoch: 001/025 | Batch 100/2000 | Loss: 11.463906
Epoch: 001/025 | Batch 150/2000 | Loss: 7.641965
Epoch: 001/025 | Batch 200/2000 | Loss: 4.693309
Epoch: 001/025 | Batch 250/2000 | Loss: 4.641695
Epoch: 001/025 | Batch 300/2000 | Loss: 3.356478
Epoch: 001/025 | Batch 350/2000 | Loss: 2.508665
Epoch: 001/025 | Batch 400/2000 | Loss: 2.326320
Epoch: 001/025 | Batch 450/2000 | Loss: 2.258588
Epoch: 001/025 | Batch 500/2000 | Loss: 2.132219
Epoch: 001/025 | Batch 550/2000 | Loss: 2.101699
Epoch: 001/025 | Batch 600/2000 | Loss: 2.113091
Epoch: 001/025 | Batch 650/2000 | Loss: 2.017087
Epoch: 001/025 | Batch 700/2000 | Loss: 2.195000
Epoch: 001/025 | Batch 750/2000 | Loss: 1.910613
Epoch: 001/025 | Batch 800/2000 | Loss: 2.039286
Epoch: 001/025 | Batch 850/2000 | Loss: 2.151676
Epoch: 001/025 | Batch 900/2000 | Loss: 2.003508
Epoch: 001/025 | Batch 950/2000 | Loss: 1.953738
Epoch: 001/025 | 

Epoch: 005/025 | Batch 1050/2000 | Loss: 0.488937
Epoch: 005/025 | Batch 1100/2000 | Loss: 0.524890
Epoch: 005/025 | Batch 1150/2000 | Loss: 0.280886
Epoch: 005/025 | Batch 1200/2000 | Loss: 0.440577
Epoch: 005/025 | Batch 1250/2000 | Loss: 0.569734
Epoch: 005/025 | Batch 1300/2000 | Loss: 0.522786
Epoch: 005/025 | Batch 1350/2000 | Loss: 0.789116
Epoch: 005/025 | Batch 1400/2000 | Loss: 0.667226
Epoch: 005/025 | Batch 1450/2000 | Loss: 0.321491
Epoch: 005/025 | Batch 1500/2000 | Loss: 0.499212
Epoch: 005/025 | Batch 1550/2000 | Loss: 0.447713
Epoch: 005/025 | Batch 1600/2000 | Loss: 0.642900
Epoch: 005/025 | Batch 1650/2000 | Loss: 0.448431
Epoch: 005/025 | Batch 1700/2000 | Loss: 0.565500
Epoch: 005/025 | Batch 1750/2000 | Loss: 0.831777
Epoch: 005/025 | Batch 1800/2000 | Loss: 0.511505
Epoch: 005/025 | Batch 1850/2000 | Loss: 0.448998
Epoch: 005/025 | Batch 1900/2000 | Loss: 0.658491
Epoch: 005/025 | Batch 1950/2000 | Loss: 0.303392
Time elapsed: 0.04 min
Epoch: 006/025 | Batch 000/

Epoch: 009/025 | Batch 1850/2000 | Loss: 0.140820
Epoch: 009/025 | Batch 1900/2000 | Loss: 0.236900
Epoch: 009/025 | Batch 1950/2000 | Loss: 0.210475
Time elapsed: 0.07 min
Epoch: 010/025 | Batch 000/2000 | Loss: 0.282119
Epoch: 010/025 | Batch 050/2000 | Loss: 0.283419
Epoch: 010/025 | Batch 100/2000 | Loss: 0.420123
Epoch: 010/025 | Batch 150/2000 | Loss: 0.179703
Epoch: 010/025 | Batch 200/2000 | Loss: 0.199899
Epoch: 010/025 | Batch 250/2000 | Loss: 0.076157
Epoch: 010/025 | Batch 300/2000 | Loss: 0.266939
Epoch: 010/025 | Batch 350/2000 | Loss: 0.244945
Epoch: 010/025 | Batch 400/2000 | Loss: 0.315995
Epoch: 010/025 | Batch 450/2000 | Loss: 0.319256
Epoch: 010/025 | Batch 500/2000 | Loss: 0.113902
Epoch: 010/025 | Batch 550/2000 | Loss: 0.364612
Epoch: 010/025 | Batch 600/2000 | Loss: 0.353613
Epoch: 010/025 | Batch 650/2000 | Loss: 0.381521
Epoch: 010/025 | Batch 700/2000 | Loss: 0.376264
Epoch: 010/025 | Batch 750/2000 | Loss: 0.052346
Epoch: 010/025 | Batch 800/2000 | Loss: 0.4

Epoch: 014/025 | Batch 350/2000 | Loss: 0.126529
Epoch: 014/025 | Batch 400/2000 | Loss: 0.121135
Epoch: 014/025 | Batch 450/2000 | Loss: 0.203354
Epoch: 014/025 | Batch 500/2000 | Loss: 0.156234
Epoch: 014/025 | Batch 550/2000 | Loss: 0.283224
Epoch: 014/025 | Batch 600/2000 | Loss: 0.149723
Epoch: 014/025 | Batch 650/2000 | Loss: 0.085048
Epoch: 014/025 | Batch 700/2000 | Loss: 0.301092
Epoch: 014/025 | Batch 750/2000 | Loss: 0.069403
Epoch: 014/025 | Batch 800/2000 | Loss: 0.110063
Epoch: 014/025 | Batch 850/2000 | Loss: 0.219906
Epoch: 014/025 | Batch 900/2000 | Loss: 0.103472
Epoch: 014/025 | Batch 950/2000 | Loss: 0.224084
Epoch: 014/025 | Batch 1000/2000 | Loss: 0.076041
Epoch: 014/025 | Batch 1050/2000 | Loss: 0.123767
Epoch: 014/025 | Batch 1100/2000 | Loss: 0.118335
Epoch: 014/025 | Batch 1150/2000 | Loss: 0.154276
Epoch: 014/025 | Batch 1200/2000 | Loss: 0.084514
Epoch: 014/025 | Batch 1250/2000 | Loss: 0.208988
Epoch: 014/025 | Batch 1300/2000 | Loss: 0.085546
Epoch: 014/02

Epoch: 018/025 | Batch 1250/2000 | Loss: 0.045240
Epoch: 018/025 | Batch 1300/2000 | Loss: 0.191820
Epoch: 018/025 | Batch 1350/2000 | Loss: 0.173458
Epoch: 018/025 | Batch 1400/2000 | Loss: 0.207420
Epoch: 018/025 | Batch 1450/2000 | Loss: 0.114630
Epoch: 018/025 | Batch 1500/2000 | Loss: 0.087928
Epoch: 018/025 | Batch 1550/2000 | Loss: 0.062894
Epoch: 018/025 | Batch 1600/2000 | Loss: 0.081232
Epoch: 018/025 | Batch 1650/2000 | Loss: 0.137118
Epoch: 018/025 | Batch 1700/2000 | Loss: 0.068093
Epoch: 018/025 | Batch 1750/2000 | Loss: 0.117140
Epoch: 018/025 | Batch 1800/2000 | Loss: 0.058127
Epoch: 018/025 | Batch 1850/2000 | Loss: 0.073985
Epoch: 018/025 | Batch 1900/2000 | Loss: 0.097633
Epoch: 018/025 | Batch 1950/2000 | Loss: 0.113918
Time elapsed: 0.14 min
Epoch: 019/025 | Batch 000/2000 | Loss: 0.202075
Epoch: 019/025 | Batch 050/2000 | Loss: 0.094120
Epoch: 019/025 | Batch 100/2000 | Loss: 0.128343
Epoch: 019/025 | Batch 150/2000 | Loss: 0.166892
Epoch: 019/025 | Batch 200/2000

Epoch: 023/025 | Batch 100/2000 | Loss: 0.102002
Epoch: 023/025 | Batch 150/2000 | Loss: 0.113289
Epoch: 023/025 | Batch 200/2000 | Loss: 0.025732
Epoch: 023/025 | Batch 250/2000 | Loss: 0.115393
Epoch: 023/025 | Batch 300/2000 | Loss: 0.160562
Epoch: 023/025 | Batch 350/2000 | Loss: 0.032659
Epoch: 023/025 | Batch 400/2000 | Loss: 0.040091
Epoch: 023/025 | Batch 450/2000 | Loss: 0.044041
Epoch: 023/025 | Batch 500/2000 | Loss: 0.018382
Epoch: 023/025 | Batch 550/2000 | Loss: 0.090279
Epoch: 023/025 | Batch 600/2000 | Loss: 0.098186
Epoch: 023/025 | Batch 650/2000 | Loss: 0.025286
Epoch: 023/025 | Batch 700/2000 | Loss: 0.097253
Epoch: 023/025 | Batch 750/2000 | Loss: 0.081296
Epoch: 023/025 | Batch 800/2000 | Loss: 0.098267
Epoch: 023/025 | Batch 850/2000 | Loss: 0.099129
Epoch: 023/025 | Batch 900/2000 | Loss: 0.057075
Epoch: 023/025 | Batch 950/2000 | Loss: 0.071073
Epoch: 023/025 | Batch 1000/2000 | Loss: 0.067098
Epoch: 023/025 | Batch 1050/2000 | Loss: 0.041091
Epoch: 023/025 | B

In [87]:
# output the train error and test error when figuring out what the number is
print("Train error : %.1f%% \nTest error : %.1f%%" %
      (my_model.compute_error(tran_train_input, tran_train_classes),
       my_model.compute_error(tran_test_input, tran_test_classes)))

print("The total number of the parameters is: %d" % (sum(p.numel() for p in my_model.parameters())))

Train error : 0.0% 
Test error : 3.0%
The total number of the parameters is: 101514


Now, we have trained a model to figure out what a number is, then I can apply this model to compare the number pairs.

In [88]:
# output the train error and test error when comparing the two numbers
print("Train error : %.1f%% \nTest error : %.1f%%" %
      (my_model.compare_two_digit(train_input,train_target),
      my_model.compare_two_digit(test_input,test_target)))

Train error : 9.2% 
Test error : 14.0%


The advantage of this method is we can split the 1000 pairs of data into 2000 data. Written digital number recognition has acquired very high accuracy, so we are going to firstly recognize the numbers and then then compare them. Finally, we compare the results with the targets.