In [15]:
import os
import numpy as np
import cv2
import pandas as pd
from torch.utils import data
from sklearn.model_selection import train_test_split
from net import Modle
import torch
from torch import nn
import torch.optim as optim
import time

In [16]:
def img_recover(img_list, label):
    """Reassemble four tiles into a single 2x2 mosaic.

    The tiles are ordered by ascending ``label`` value (ties keep their
    original position). The first two tiles in that order are joined side by
    side to form the top row, the last two form the bottom row, and the two
    rows are then stacked vertically.

    Args:
        img_list: indexable collection of four array tiles.
        label: four sortable values, one per tile.

    Returns:
        The concatenated array.
    """
    slots = range(0, 4)
    # lexsort uses the LAST key as the primary one: sort by label, then by slot
    by_label = np.lexsort((slots, label))
    top_left, top_right, bottom_left, bottom_right = (
        img_list[slots[pos]] for pos in by_label
    )
    upper_row = np.concatenate((top_left, top_right), axis=1)
    lower_row = np.concatenate((bottom_left, bottom_right), axis=1)
    return np.concatenate((upper_row, lower_row), axis=0)


In [17]:
# Dataset subclass that serves each puzzle image as a stack of its four quadrants
class Mydataset(data.Dataset):
    """Map-style dataset pairing image files on disk with their labels.

    Each item is ``(tiles, label)``: ``tiles`` is a float32 tensor holding the
    four quadrants of the grayscale image stacked along a new leading axis, and
    ``label`` is ``labels[index]`` returned untouched.
    """

    def __init__(self, imgs_path, labels ):
        """Store the inputs; nothing is read from disk until an item is requested.

        Args:
            imgs_path: indexable collection of image file paths.
            labels: indexable collection of labels, looked up with the same
                index as ``imgs_path``.
        """
        self.imgs_path = imgs_path
        self.labels = labels

    def __getitem__(self, index):
        """Load image ``index`` and return ``(float32 tile tensor, label)``."""
        tiles = self.read_img(self.imgs_path[index])
        tiles = torch.tensor(tiles, dtype=torch.float32)
        label = self.labels[index]

        return tiles, label

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.imgs_path)

    def read_img(self, path):
        """Read ``path`` as grayscale and return its four quadrants stacked.

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising;
            # fail here with the offending path instead of a later TypeError.
            raise OSError('cv2.imread could not read image: {!r}'.format(path))
        return self.add_dim(self.preprocess(image))

    def preprocess(self, image):
        """Split ``image`` into its four quadrants.

        The cut is made at half the height and half the width, so a 200x200
        image yields four 100x100 tiles.

        Returns:
            Tuple ``(top_left, top_right, bottom_left, bottom_right)``.
        """
        half_h = image.shape[0] // 2
        half_w = image.shape[1] // 2
        first = image[:half_h, :half_w]
        second = image[:half_h, half_w:]
        third = image[half_h:, :half_w]
        fourth = image[half_h:, half_w:]
        return first, second, third, fourth

    def add_dim(self, img_list):
        """Stack equally-shaped tiles along a new leading axis.

        Four ``(h, w)`` tiles become one ``(4, h, w)`` array.
        """
        return np.stack(list(img_list), axis=0)

In [54]:
def read_data(path,type):
    """Collect the image paths and the label matrix of one dataset split.

    Expects ``<path>/<type>/`` to contain images named ``<integer>.<ext>`` and
    ``<path>/<type>.csv`` to have a ``label`` column whose cells are
    whitespace-separated integers. Images and CSV rows are paired purely by
    position: images sorted by their integer stem, rows in file order.

    Args:
        path: dataset root directory.
        type: split name, e.g. ``'train'`` or ``'valid'``. The name shadows
            the built-in but is kept because callers pass it by keyword.

    Returns:
        ``(img_paths, labels)``: a list of image paths and a float32 tensor
        with one row per CSV row.

    Raises:
        ValueError: if the CSV has no rows, or if the number of numbered
            images differs from the number of label rows.
    """
    img_dir = os.path.join(path, type)
    label_csv = os.path.join(path, type + '.csv')

    def numeric_stem(name):
        # Integer before the first dot, or None for entries that are not
        # numbered images (e.g. '.DS_Store', '.ipynb_checkpoints').
        try:
            return int(name.split('.')[0])
        except ValueError:
            return None

    # Sort numerically so that '10.png' comes after '2.png'
    numbered = [(numeric_stem(name), name) for name in os.listdir(img_dir)]
    numbered = [pair for pair in numbered if pair[0] is not None]
    numbered.sort(key=lambda pair: pair[0])
    img_paths = [os.path.join(img_dir, name) for _, name in numbered]

    # Parse every label cell once and build the matrix in a single step
    label_column = pd.read_csv(label_csv).label
    label_rows = [[int(tok) for tok in cell.split()] for cell in label_column]
    if not label_rows:
        raise ValueError('no label rows found in {!r}'.format(label_csv))
    labels = torch.tensor(np.array(label_rows), dtype=torch.float32)

    # Pairing is positional, so a count mismatch means the pairs are wrong
    if len(img_paths) != len(label_rows):
        raise ValueError('{} images in {!r} but {} label rows in {!r}'.format(
            len(img_paths), img_dir, len(label_rows), label_csv))

    return img_paths, labels


In [19]:
# Root folder of the dataset; read_data() joins it with the split name
train_path = './data/puzzle_2x2/'
# Image paths and label tensor of the 'train' split
train_img,train_label = read_data(train_path,type='train')


In [20]:
# Split the 'train' data into a training part and a held-out part.
# No test_size is passed, so scikit-learn's default ratio applies; the fixed
# random_state makes the shuffled split repeatable.
X_train, X_test, Y_train, Y_test = train_test_split(train_img, train_label, random_state=42,shuffle=True)

In [21]:
# Training set
train_dataset = Mydataset(X_train,Y_train)
# Training loader. shuffle=True re-draws the sample order every epoch; without
# it the order is fixed and, combined with drop_last=True, the same trailing
# samples would be discarded in every epoch.
train_dataloader = data.DataLoader(dataset=train_dataset,
                                   batch_size=64,
                                   shuffle=True,
                                   # 0 workers: batches are loaded in the main process
                                   num_workers=0,
                                   drop_last=True)

# Held-out set (sample order does not affect evaluation, so it is not shuffled)
test_dataset = Mydataset(X_test,Y_test)
test_dataloader = data.DataLoader(dataset=test_dataset,
                                  batch_size=128)


In [22]:
# # 测试代码
# imgs_batch, labels_batch = next(iter(train_dataloader))
# # 前向传播一次
# model = Modle()
# model.eval()
# pre = model(imgs_batch)
# print(pre.shape)
# loss_fn = nn.CrossEntropyLoss()
# loss = loss_fn(pre,labels_batch)
# # loss.backward()
# loss

In [51]:
def decode(pre_data):
    """Convert one row of scores into the rank of each entry.

    Entry ``i`` of the result is the position score ``i`` would take if the
    scores were sorted ascending (0 = smallest). Equal scores are ranked in
    index order. Works for any row length, not only 4.

    Args:
        pre_data: 1-D scores; a CPU torch tensor (it is read through ``.data``,
            so it may require grad), a numpy array, or a plain sequence.

    Returns:
        Integer array of shape ``(1, n)``.
    """
    if isinstance(pre_data, np.ndarray):
        scores = pre_data
    else:
        # ``.data`` strips autograd tracking from tensors; plain sequences pass through
        scores = np.asarray(getattr(pre_data, 'data', pre_data))

    # Stable sort so that ties are broken by index
    order = np.argsort(scores, kind='stable')
    # The inverse of the sorting permutation is the rank of each element
    ranks = np.argsort(order, kind='stable')
    return ranks.reshape(1, -1)

# a = np.array([0.1,2.5,1.5,1.9])
# decode(a)

def pre2label(pre):
    """Decode every row of ``pre`` with ``decode`` and stack the results.

    Returns an array with one decoded row per input row.
    """
    decoded_rows = [decode(pre[0])]
    decoded_rows.extend(decode(row) for row in pre[1:])
    return np.concatenate(decoded_rows, axis=0)

# labels_batch
# a = pre2label(pre)
def acc(pre,label):
    """Fraction of rows whose decoded prediction matches ``label`` in all 4 slots.

    ``pre`` is decoded with ``pre2label``; ``label`` must provide ``.numpy()``.
    """
    decoded = pre2label(pre)
    target = label.numpy()
    # A row counts only when every one of its four positions agrees
    matches_per_row = (decoded == target).sum(1)
    row_is_exact = matches_per_row == 4
    return row_is_exact.mean()

# acc(pre,labels_batch)

In [24]:
# Use the first CUDA device when one is available, otherwise the CPU
device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
# Build the network and move its parameters to the chosen device
model = Modle()
model.to(device)

# Loss function and optimizer.
loss_fn = nn.CrossEntropyLoss()
# Adam is reached through ``torch.optim`` on purpose: the assignment below
# rebinds the name ``optim`` (imported as the module) to the optimizer
# instance that train() uses, so ``optim.Adam`` would raise AttributeError
# the second time this cell is run.
optim = torch.optim.Adam(model.parameters(), lr=0.001)


def train(train_dataloader,epoches=50):
    """Train ``model`` and print train/held-out metrics after every epoch.

    Relies on notebook-level globals: ``model``, ``optim`` (the optimizer
    instance), ``loss_fn``, ``device``, ``test_dataloader`` and ``acc``.

    Loss and accuracy are averaged per sample: each batch value is weighted by
    the batch size, so a smaller final batch does not count as much as a full
    one. This assumes ``loss_fn`` returns the mean over the batch.

    Args:
        train_dataloader: iterable of ``(x, y)`` training batches.
        epoches: number of passes over ``train_dataloader``.

    Raises:
        ZeroDivisionError: if either loader yields no samples.
    """
    for epoch in range(epoches):
        # Per-epoch accumulators (sums over samples) and sample counters
        train_acc = 0.0
        train_loss = 0.0
        train_seen = 0
        test_acc = 0.0
        test_loss = 0.0
        test_seen = 0
        start = time.time()

        # Training mode: layers such as dropout / batch norm behave stochastically
        model.train()
        for x, y in train_dataloader:
            # Move the batch to the same device as the model
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            optim.zero_grad()
            loss = loss_fn(y_pred, y)
            loss.backward()
            optim.step()

            # Metrics only: detach so no autograd state is carried along
            batch_n = y.shape[0]
            train_acc = train_acc + acc(y_pred.detach().cpu(), y.cpu()) * batch_n
            train_loss = train_loss + loss.item() * batch_n
            train_seen = train_seen + batch_n

        # Evaluation mode: dropout off, batch norm uses running statistics
        model.eval()
        with torch.no_grad():
            for x, y in test_dataloader:
                x, y = x.to(device), y.to(device)
                y_pred = model(x)
                batch_n = y.shape[0]
                test_loss = test_loss + loss_fn(y_pred, y).item() * batch_n
                test_acc = test_acc + acc(y_pred.cpu(), y.cpu()) * batch_n
                test_seen = test_seen + batch_n

        end = time.time()
        # Turn the sums into per-sample averages
        train_loss = train_loss / train_seen
        train_acc = train_acc / train_seen

        test_loss = test_loss / test_seen
        test_acc = test_acc / test_seen
        print('当前epoch为:{},训练集损失为:{},训练集正确率为:{},验证集损失为:{},验证集正确率为:{},用时:{}s'.format(epoch,
                                                                                    train_loss,
                                                                                    train_acc,
                                                                                    test_loss,
                                                                                    test_acc,
                                                                                    end-start))



当前epoch为:0,训练集损失为:8.18324263526364,训练集正确率为:0.0,验证集损失为:7.816941675249037,验证集正确率为:0.0,用时:230.42760157585144s
当前epoch为:1,训练集损失为:7.6636555380607065,训练集正确率为:0.0,验证集损失为:7.559809344155448,验证集正确率为:0.0,用时:109.05908536911011s
当前epoch为:2,训练集损失为:7.492563196544141,训练集正确率为:0.0,验证集损失为:7.439709521911957,验证集正确率为:0.0,用时:107.99856066703796s
当前epoch为:3,训练集损失为:7.404663341620119,训练集正确率为:0.0,验证集损失为:7.389477543778472,验证集正确率为:0.0,用时:107.68720865249634s
当前epoch为:4,训练集损失为:7.32446755628647,训练集正确率为:0.0,验证集损失为:7.2936125089834025,验证集正确率为:0.0,用时:114.0962426662445s
当前epoch为:5,训练集损失为:7.286965325160599,训练集正确率为:0.0,验证集损失为:7.251766225793859,验证集正确率为:0.0,用时:100.54153203964233s
当前epoch为:6,训练集损失为:7.219972894565432,训练集正确率为:0.0,验证集损失为:7.15572113257188,验证集正确率为:0.0,用时:99.08280158042908s
当前epoch为:7,训练集损失为:7.160276596530876,训练集正确率为:0.0,验证集损失为:7.10502288367722,验证集正确率为:0.0,用时:100.00805616378784s
当前epoch为:8,训练集损失为:7.090392413432176,训练集正确率为:0.0,验证集损失为:7.06576985841269,验证集正确率为:0.0,用时:99.27525091171265s
当前epoch为:9,训练集损失为:7.04465181777282

In [27]:
# Save the model: only its state_dict (the parameters) is written to ./model.pkl
torch.save(model.state_dict(), './model.pkl')

In [58]:
# Evaluate on the separate 'valid' split

path_valid = './data/puzzle_2x2/'
valid_img,valid_label = read_data(path_valid,type='valid')

# Dataset over the 'valid' split
valid_dataset = Mydataset(valid_img,valid_label)
# DataLoader for it; no shuffle argument is passed
valid_dataloader = data.DataLoader(dataset=valid_dataset,
                                   batch_size=512,)

In [59]:
def valid(model,valid_dataloader,loss_fn):
    """Evaluate ``model`` on ``valid_dataloader`` and print loss and accuracy.

    Runs in eval mode and under ``torch.no_grad()`` so no autograd graph is
    built for the evaluation batches. Both metrics are averaged per sample
    (each batch value is weighted by its batch size), which assumes
    ``loss_fn`` returns the mean over the batch. Uses the notebook-level
    globals ``device`` and ``acc``. The model is left in eval mode.

    Args:
        model: network to evaluate.
        valid_dataloader: iterable of ``(x, y)`` batches.
        loss_fn: callable returning a scalar loss tensor for ``(y_pred, y)``.

    Returns:
        ``(valid_acc, valid_loss)``.

    Raises:
        ZeroDivisionError: if the loader yields no samples.
    """
    acc_sum = 0.0
    loss_sum = 0.0
    n_seen = 0
    model.eval()
    with torch.no_grad():
        for x, y in valid_dataloader:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            batch_n = y.shape[0]
            loss_sum = loss_sum + loss_fn(y_pred, y).item() * batch_n
            acc_sum = acc_sum + acc(y_pred.cpu(), y.cpu()) * batch_n
            n_seen = n_seen + batch_n

    valid_loss = loss_sum / n_seen
    valid_acc = acc_sum / n_seen
    print('valid数据集损失为:{},valid数据集正确率为:{}'.format(valid_loss,valid_acc))

    return valid_acc,valid_loss




valid数据集损失为:6.723812103271484,valid数据集正确率为:0.7194805194805195


  x = self.softmax_1(x)
