In [1]:
import torch
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt
import torchvision
import torchvision.datasets
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
from net import CNN
from mydataset import MyDataSet
from myloss import TripletLoss

torch.manual_seed(10)

<torch._C.Generator at 0x285c4e90350>

In [2]:
EPOCH = 25
BATCH_SIZE = 4
LR = 0.001

In [3]:
transform = transforms.Compose([   
        transforms.Resize([96, 96]),  
        transforms.ToTensor(),  
        transforms.Normalize(mean=(0,0,0),std=(1,1,1))
])

train_set = MyDataSet('./birds/train/',transform=transform, labels=None)
train_loader = Data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
test_set = MyDataSet('./birds/test/', transform=transform, labels=None)
test_loader = Data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

In [4]:
print(train_set.classes)
print(train_set.class_to_idx)
print(train_set.__len__)

print(test_set.classes)
print(test_set.class_to_idx)
print(test_set.__len__)

['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird']
{'001.Black_footed_Albatross': 0, '002.Laysan_Albatross': 1, '003.Sooty_Albatross': 2, '004.Groove_billed_Ani': 3, '005.Crested_Auklet': 4, '006.Least_Auklet': 5, '007.Parakeet_Auklet': 6, '008.Rhinoceros_Auklet': 7, '009.Brewer_Blackbird': 8, '010.Red_winged_Blackbird': 9}
<bound method MyDataSet.__len__ of <mydataset.MyDataSet object at 0x00000285C7E8F908>>
['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird']
{'001.Black_footed_Albatross': 0, '002.Laysan_Albatross': 1, '003.Sooty_Albatross': 2, '004.Groove_billed_Ani': 3, '005.Crested_Auklet': 4, '006.Lea

In [5]:
cnn = CNN()
print(cnn)

CNN(
  (conv1): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=18432, out_features=768, bias=True)
  (out): Linear(in_features=768, out_features=10, bias=True)
)


In [6]:
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = TripletLoss(20) # 选择损失函数

for epoch in range(EPOCH):
    print('EPOCH ' + str(epoch))
    # for step, (b_x, b_y) in enumerate(train_loader):
    for step, (anchor, positive, negative) in enumerate(train_loader):
        #output = cnn(b_x)[0]
        #loss = loss_func(output, b_y)
        anchor_output = cnn(anchor[0])[1]
        positive_output = cnn(positive[0])[1]
        negative_output = cnn(negative[0])[1]
        #print(anchor_output.detach().numpy().shape)
        
        loss = loss_func(anchor_output, positive_output, negative_output)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 50 == 0:
            correct = 0
            total = 0
            
            for _, (test, _, _) in enumerate(test_loader):
                test_output = cnn(test[0])[0]
                predicted = torch.max(test_output, 1).indices
                labels = test[1]
                total += labels.size(0)
                correct += (predicted == labels).sum()
                
            print('Accuracy on the test images: %d %%' % (100 * correct / total))
            
#             for data in test_loader:
#                 images,labels = data
#                 outputs = cnn(Variable(images))
#                 predicted = torch.max(outputs[0], 1).indices
#                 total += labels.size(0)
#                 correct += (predicted == labels).sum()
                
#             print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))
print('End')

EPOCH 0
Accuracy on the test images: 9 %
Accuracy on the test images: 10 %
EPOCH 1
Accuracy on the test images: 12 %
Accuracy on the test images: 12 %
EPOCH 2
Accuracy on the test images: 12 %
Accuracy on the test images: 10 %
EPOCH 3
Accuracy on the test images: 10 %
Accuracy on the test images: 10 %
EPOCH 4
Accuracy on the test images: 5 %
Accuracy on the test images: 10 %
EPOCH 5
Accuracy on the test images: 10 %
Accuracy on the test images: 16 %
EPOCH 6
Accuracy on the test images: 13 %
Accuracy on the test images: 10 %
EPOCH 7
Accuracy on the test images: 7 %
Accuracy on the test images: 10 %
EPOCH 8
Accuracy on the test images: 10 %
Accuracy on the test images: 10 %
EPOCH 9
Accuracy on the test images: 10 %
Accuracy on the test images: 6 %
EPOCH 10
Accuracy on the test images: 13 %
Accuracy on the test images: 10 %
EPOCH 11
Accuracy on the test images: 10 %
Accuracy on the test images: 6 %
EPOCH 12
Accuracy on the test images: 6 %
Accuracy on the test images: 12 %
EPOCH 13
Accura