In [1]:
import torch
from torchvision import datasets,transforms
import torch.nn as nn
import torch.utils.data as Data
from torch.autograd import Variable
import torch.optim as optim


train_filepath = 'Birds/train'
test_filepath = 'Birds/test'
data_transfrom = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0,0,0),std=(1,1,1))]
)

torch.manual_seed(1)
torch.cuda.is_available()

True

In [2]:
import os
from PIL import Image
import cv2
import numpy as np 
from torch.utils.data import DataLoader

class MyDataset(DataLoader):#定义数据读取类
    def __init__(self,basepath,transforms=None):#输入参数：数据集路径,和转化方法
        self.basepath = basepath
        self.transforms = transforms
        self.classes = sorted(os.listdir(basepath))
        self.filelist = []
        for idx,Set in enumerate(sorted(self.classes)):#按类载入所有训练数据的文件名
            files = os.listdir(os.path.join(basepath,Set))
            self.filelist.append(files)

        self.class_to_idx = dict()#类别对应标签
        for i,classes in enumerate(self.classes):
            self.class_to_idx[classes] = i
        
        self.num_class = len(self.classes)
        classesname = sorted(self.classes)
        imgset = []
        labelset = []
        for i,Set in enumerate(self.filelist):#逐类读取
            for j,file in enumerate(Set):#在每个类中逐个读取样本
                img_path = os.path.join(self.basepath,classesname[i],file)#读取图像  
                imgset.append(img_path)#载入图像路径
                labelset.append(self.class_to_idx[classesname[i]])#载入标签
                
        self.imgset = imgset
        self.labelset = labelset
        
        #获取同个标签下所有图像的索引
        self.label_to_index = {label: np.where(np.array(self.labelset) == np.array(label))[0]
                               for label in self.labelset}

        
    def __len__(self):#统计数据集样本总数
        return len(self.imgset)
    
    def getpnsample(self,idx):
        fn1,label = self.imgset[idx],self.labelset[idx]
        p_idx = idx
        while p_idx == idx:#查找同标签不同图像
            p_idx = np.random.choice(self.label_to_index[label])
        
        #查找不用标签的样本
        n_label = np.random.choice(list(set(self.labelset) - set([label])))
        n_idx = np.random.choice(self.label_to_index[n_label])
        
        fn2 = self.imgset[p_idx]
        fn3 = self.imgset[n_idx]
        
        return fn1,fn2,fn3,label
        
    def __getitem__(self,index):#获取训练数据
        fn1,fn2,fn3,label = self.getpnsample(index)
        
        img = Image.open(fn1).convert('RGB')
        imgp = Image.open(fn2).convert('RGB')
        imgn = Image.open(fn3).convert('RGB')
        
        if self.transforms is not None:
            img = self.transforms(img)
            imgp = self.transforms(imgp)
            imgn = self.transforms(imgn)
        
        return img,imgp,imgn,torch.tensor(label)


In [3]:
#获取训练测试样本
trainset = MyDataset(train_filepath,data_transfrom)
testset = MyDataset(test_filepath,data_transfrom)
traindata = DataLoader(trainset,batch_size=16,shuffle=True)
testdata = DataLoader(testset,batch_size=16)

In [11]:

import torch.nn.functional as F
from torchvision import models
class TripletLoss(nn.Module):

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()

model = models.vgg16(pretrained=True)

for parma in model.parameters():
    parma.requires_grad = False

model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 2048),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(0.5),
                                       torch.nn.Linear(2048, 10))

model = model.cuda()
crossloss = nn.CrossEntropyLoss().cuda()
tripltloss = TripletLoss(100).cuda()
optimizer = optim.Adam(model.classifier.parameters(),lr=0.001)

In [12]:
class LayerActivations:
    features = None
    
    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
        
    def hook_fn(self, module, input, output):
        self.features = output
    def remove(self):
        self.hook.remove()

In [18]:
EPOCH = 25
for epoch in range(EPOCH):
    print('epoch',epoch+1)
    running_loss = 0.0
    for i,train_data in enumerate(traindata):#开始训练

        train_inputs,p_imgs,n_imgs,tlabels = train_data
        train_inputs = Variable(train_inputs).cuda()
        p_imgs = Variable(p_imgs).cuda()
        n_imgs = Variable(n_imgs).cuda()
        tlabels = Variable(tlabels).cuda()
        
        conv_out = LayerActivations(model.classifier,0)
        train_outputs = model(train_inputs)
        conv_out.remove()
        act = conv_out.features
        
        conv_out = LayerActivations(model.classifier,0)
        p_outputs = model(p_imgs)
        conv_out.remove()
        actp = conv_out.features
        
        conv_out = LayerActivations(model.classifier,0)
        n_outputs = model(n_imgs)
        conv_out.remove()
        actn = conv_out.features
        optimizer.zero_grad()
        loss =  2 * crossloss(train_outputs,tlabels) + 0.001 * tripltloss(act,actp,actn)
        loss.backward()
        optimizer.step()
        #计算损失
        running_loss += loss.data
        print('\r%d'%(i+1),end='')

    correct = 0
    total = 0
    for j,test_data in enumerate(testdata):#验证
        simage,_,_,slabel = test_data
        simage = Variable(simage).cuda()
        slabel = Variable(slabel).cuda()
        output = model(simage)
        predicted = torch.max(output,1).indices
        total += slabel.size(0)
        correct += (predicted == slabel).sum()
    print('loss: %.3f' % (running_loss))
    print('test Accuracy: %.2f %%' % (100.0 * correct / total))
    running_loss = 0.0
print('Finished Training')

epoch 1
30loss: 284.006
test Accuracy: 67.50 %
epoch 2
30loss: 284.572
test Accuracy: 66.25 %
epoch 3
30loss: 19.475
test Accuracy: 68.75 %
epoch 4
30loss: 44.654
test Accuracy: 71.25 %
epoch 5
30loss: 4.092
test Accuracy: 70.00 %
epoch 6
30loss: 181.638
test Accuracy: 68.75 %
epoch 7
30loss: 383.566
test Accuracy: 66.25 %
epoch 8
30loss: 131.074
test Accuracy: 72.50 %
epoch 9
30loss: 99.836
test Accuracy: 66.25 %
epoch 10
30loss: 186.943
test Accuracy: 70.00 %
epoch 11
30loss: 93.492
test Accuracy: 71.25 %
epoch 12
30loss: 161.518
test Accuracy: 67.50 %
epoch 13
30loss: 79.958
test Accuracy: 70.00 %
epoch 14
30loss: 88.897
test Accuracy: 70.00 %
epoch 15
30loss: 73.733
test Accuracy: 67.50 %
epoch 16
30loss: 0.000
test Accuracy: 68.75 %
epoch 17
30loss: 198.810
test Accuracy: 68.75 %
epoch 18
30loss: 339.725
test Accuracy: 71.25 %
epoch 19
30loss: 18.306
test Accuracy: 72.50 %
epoch 20
30loss: 4.729
test Accuracy: 72.50 %
epoch 21
30loss: 26.053
test Accuracy: 70.00 %
epoch 22
30loss: