In [0]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import os
from datetime import datetime
from PIL import Image
import matplotlib.pyplot as plt

In [0]:
pwd = '/content/drive/My Drive/python/PyTorch_Tutorial'
train_txt_path = os.path.join(pwd, '..','..','Data','train.txt')
train_dir = os.path.join(pwd, '..','..','Data', 'train')

valid_txt_path = os.path.join(pwd, '..','..','Data','Valid.txt')
valid_dir = os.path.join(pwd, '..','..','Data','valid')

classes_name = ['plane','car','bird','cat','deer','dog','frog','horse','ship','truck']

train_bs = 64
valid_bs = 64
lr_init = 0.001
max_epoch = 200

result_dir = os.path.join(pwd, '..','..','Data','Result')
now_time = datetime.now()
time_str = datetime.strftime(now_time, '%m-%d_%H-%M-%S')

log_dir = os.path.join(result_dir, time_str)
if not os.path.exists(log_dir):
  os.makedirs(log_dir)

In [0]:
class MyDataset(Dataset):
    def __init__(self, txt_path, transform = None, target_transform = None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split('*')
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list， 然后DataLoader中给index，通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # 像素值 0~255，在transfrom.totensor会除以255，使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform，转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

In [0]:
def validate(net, data_loader, set_name, classes_name):
  net.eval()
  cls_num = len(classes_name)
  conf_mat = np.zeros([cls_num, cls_num])

  for data in data_loader:
    images, labels = data
    images = Variable(images)
    labels = Variable(labels)

    outputs = net(images)
    outputs.detach_()

    _, predicted = torch.max(outputs.data, 1)

    # 统计混淆矩阵
    for i in range(len(labels)):
      cate_i = labels[i].numpy()
      pre_i = predicted[i].numpy()
      conf_mat[cate_i, pre_i] += 1.0


  for i in range(cls_num):
    print('class;{:<10}, total num:{:<6}, correct num:{:<5}  Recall:{:.2%}  Precision:{:.2%}'
    .format(classes_name[i], np.sum(conf_mat[i,:]), conf_mat[i, i], conf_mat[i, i]/(1+np.sum(conf_mat[i,:])),
           conf_mat[i,i] / (1 + np.sum(conf_mat[:, i]))))
    
  print('{} set Accuracy:{:.2%}'.format(set_name, np.trace(conf_mat) / np.sum(conf_mat))) # trace 对角线元素求和
  
  return conf_mat, '{:.2}'.format(np.trace(conf_mat) / np.sum(conf_mat))



def show_confMat(confusion_mat, classes, set_name, out_dir):

  #归一化
  confusion_mat_N = confusion_mat.copy()
  for i in range(len(classes)):
    confusion_mat_N[i, :] = confusion_mat[i, :] / confusion_mat[i, :].sum()

  #获取颜色
  cmap = plt.cm.get_cmap('Greys')
  plt.imshow(confusion_mat_N, cmap=cmap)
  plt.colorbar()

  #设置文字
  xlocations = np.array(range(len(classes)))
  plt.xticks(xlocations, list(classes), rotation=60)
  plt.yticks(xlocations, list(classes))
  plt.xlabel('Predict label')
  plt.ylabel('True label')
  plt.title('Confusion_Matrix_' + set_name)

  #打印数字
  for i in range(confusion_mat_N.shape[0]):
    for j in range(confusion_mat_N.shape[1]):
      plt.text(x=j, y=i, s=int(confusion_mat[i,j]), va='center', ha='center',
               color='red', fontsize=10)
      
  #保存
  plt.savefig(os.path.join(out_dir,'Confusion_Matrix' + set_name + '.png'))
  plt.close()

In [0]:
#------- step 1/5 加载数据

# 数据预处理设置

normMean = [0.4948052, 0.48568845, 0.44682974]
normStd = [0.24580306, 0.24236229, 0.2603115]
normTransform = transforms.Normalize(normMean, normStd)
trainTransform = transforms.Compose([
                                     transforms.Resize(32),
                                     transforms.RandomCrop(32, padding=4),
                                     transforms.ToTensor(),
                                     normTransform
])

validTransform = transforms.Compose([
                                     transforms.ToTensor(),
                                     normTransform
])

# 构建Mydataset实例
train_data = MyDataset(txt_path=train_txt_path, transform=trainTransform)
valid_data = MyDataset(txt_path=valid_txt_path, transform=validTransform)

#构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=train_bs, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=valid_bs)

In [6]:
#--------- step 2/5: 定义网络

class Net(nn.Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool1 = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.pool2 = nn.MaxPool2d(2, 2)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

  def forward(self, x):
    x = self.pool1(F.relu(self.conv1(x)))
    x = self.pool2(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x

  #定义权值初始化
  def initialize_weights(self):
    for m in self.modules():
      if isinstance(m,nn.Conv2d):
        torch.nn.init.xavier_normal_(m.weight.data)
        if m.bias is not None:
          m.bias.data.zero_()
      elif isinstance(m,nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()
      elif isinstance(m,nn.Linear):
        torch.nn.init.normal_(m.weight.data, 0, 0.01)
        m.bias.data.zero_()

net = Net()

# load params
pretrained_dict = torch.load(os.path.join(pwd,'net_params.pkl'))

#dict
net_state_dict = net.state_dict()

#delect
pretrained_dict_1 = {k:v for k, v in pretrained_dict.items() if k in net_state_dict}

#update
net_state_dict.update(pretrained_dict_1)

#将包含预训练模型参数的字典放到新模型中
net.load_state_dict(net_state_dict)




<All keys matched successfully>

In [0]:
#--------- step 3/5: 定义损失函数和优化器

# 按需设置学习率

#将fc3层的参数从原始网络参数中剔除
ignored_params = list(map(id, net.fc3.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())

#为fc3层设置需要的学习率
optimizer = optim.SGD(
                      [
                       {'params': base_params},
                       {'params': net.fc3.parameters(),'lr':lr_init*10}
                       ],
                      lr_init,
                      momentum=0.9,
                      weight_decay=1e-4
                      )

criterion = nn.CrossEntropyLoss()
# 每过50个epoch 学习率就乘0.1
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1) 

In [21]:
# ---------- step 4/5 :训练
for epoch in range(max_epoch):
  loss_sigma = 0.0 # 记录一个epoch的loss之和
  correct = 0.0
  total = 0.0
  

  for i, data in enumerate(train_loader):
    #获取图片和标签
    inputs, labels = data
    inputs, labels = Variable(inputs), Variable(labels)

    
    outputs = net(inputs)
    loss = criterion(outputs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    #统计预测信息
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).squeeze().sum().numpy()
    loss_sigma += loss.item()

    #每10个iter打印一次训练信息  loss为10个iter的平均
    if i % 10 == 9 :
      loss_avg = loss_sigma / 10
      loss_sigma = 0.0
      print('Training : Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss {:.4f} Acc:{:.2%}'
      .format(epoch + 1, max_epoch, i + 1, len(train_loader), loss_avg, correct / total))
  print('参数组1的学习率:{},参数组2的学习率:{}'.format(scheduler.get_lr()[0],scheduler.get_lr()[1]))
    
  loss_sigma = 0.0
  cls_num = len(classes_name)
  conf_mat = np.zeros([cls_num, cls_num])  # 混淆矩阵
  net.eval()
  for i, data in enumerate(valid_loader):
    # 获取图片和标签
    images, labels = data
    images, labels = Variable(images), Variable(labels)

    # forward
    outputs = net(images)
    outputs.detach_()

    # 计算loss
    loss = criterion(outputs, labels)
    loss_sigma += loss.item()

    # 统计
    _, predicted = torch.max(outputs.data, 1)
    # labels = labels.data    # Variable --> tensor

    # 统计混淆矩阵
    for j in range(len(labels)):
      cate_i = labels[j].numpy()
      pre_i = predicted[j].numpy()
      conf_mat[cate_i, pre_i] += 1.0

    print('{} set Accuracy:{:.2%}'.format('Valid', conf_mat.trace() / conf_mat.sum()))

  scheduler.step() #更新学习率
print('Finished Training')





[1;30;43m流式输出内容被截断，只能显示最后 5000 行内容。[0m
Training : Epoch[076/200] Iteration[020/125] Loss 0.8202 Acc:70.62%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[030/125] Loss 0.8011 Acc:70.05%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[040/125] Loss 0.7860 Acc:70.08%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[050/125] Loss 0.7955 Acc:70.56%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[060/125] Loss 0.9155 Acc:70.23%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[070/125] Loss 0.8689 Acc:70.00%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[080/125] Loss 0.8198 Acc:70.27%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[090/125] Loss 0.8417 Acc:70.19%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[100/125] Loss 0.8532 Acc:70.03%
参数组1的学习率:0.0001,参数组2的学习率:0.001
Training : Epoch[076/200] Iteration[110/125] Loss 0.8078 Acc:70.23%


In [22]:
conf_mat_train, train_acc = validate(net, train_loader, 'train', classes_name)
conf_mat_valid, valid_acc = validate(net, valid_loader, 'valid', classes_name)

show_confMat(conf_mat_train, classes_name, 'train', log_dir)
show_confMat(conf_mat_valid, classes_name, 'valid', log_dir)

class;plane     , total num:800.0 , correct num:614.0  Recall:76.65%  Precision:75.25%
class;car       , total num:800.0 , correct num:676.0  Recall:84.39%  Precision:78.33%
class;bird      , total num:800.0 , correct num:472.0  Recall:58.93%  Precision:65.46%
class;cat       , total num:800.0 , correct num:401.0  Recall:50.06%  Precision:56.80%
class;deer      , total num:800.0 , correct num:548.0  Recall:68.41%  Precision:61.71%
class;dog       , total num:800.0 , correct num:515.0  Recall:64.29%  Precision:64.54%
class;frog      , total num:800.0 , correct num:633.0  Recall:79.03%  Precision:76.27%
class;horse     , total num:800.0 , correct num:598.0  Recall:74.66%  Precision:75.41%
class;ship      , total num:800.0 , correct num:653.0  Recall:81.52%  Precision:80.72%
class;truck     , total num:800.0 , correct num:617.0  Recall:77.03%  Precision:78.50%
train set Accuracy:71.59%
class;plane     , total num:100.0 , correct num:73.0   Recall:72.28%  Precision:59.84%
class;car       ,

In [23]:
torch.cuda.get_device_name(0)

'Tesla P4'