In [5]:
import os
import sys
if '/opt/ros/kinetic/lib/python2.7/dist-packages' in sys.path:
    sys.path.remove('/opt/ros/kinetic/lib/python2.7/dist-packages')
import cv2
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
import torch 
from ConvAE import Encoder, Decoder
import torch.functional as F
import torch.nn as nn
import tensorboardX
import random

In [6]:
class ConvAE_Dataset(Dataset):
    """Dataset of grayscale images stored as ``<index>.png`` in one directory.

    Sample ``i`` is read from ``<dataset_dir>/i.png``, scaled from the 8-bit
    range into ``[0, 1]`` by dividing by 255 and returned with a leading
    channel axis added.

    Args:
        dir: Directory that contains the numbered PNG files.
    """

    def __init__(self, dir="./maps") -> None:
        super().__init__()
        self.dataset_dir = dir

    @staticmethod
    def _is_sample_file(file_name: str) -> bool:
        """Return True if ``file_name`` looks like ``<digits>.png``."""
        stem, extension = os.path.splitext(file_name)
        return extension == ".png" and stem.isdigit()

    def __len__(self) -> int:
        """Number of ``<index>.png`` files currently in the dataset directory.

        Only numbered PNG files are counted so that unrelated entries
        (hidden files, thumbnails, sub-directories) do not inflate the length
        and make samplers request indices that have no image on disk.
        """
        return sum(
            1 for file_name in os.listdir(self.dataset_dir)
            if self._is_sample_file(file_name)
        )

    def __getitem__(self, index: int):
        """Load sample ``index``.

        Returns:
            Tuple ``(image, index)`` where ``image`` is the grayscale picture
            divided by 255.0 with a new axis inserted at position 0.

        Raises:
            FileNotFoundError: If ``<index>.png`` does not exist.
            ValueError: If the file exists but OpenCV cannot decode it.
        """
        image_name = "{}.png".format(index)
        image_path = os.path.join(self.dataset_dir, image_name)
        # cv2.imread signals failure by returning None instead of raising, so
        # check explicitly to fail with a message that names the file.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(
                "No image for index {}: {}".format(index, image_path))
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError("Could not decode image: {}".format(image_path))
        image = image / 255.0
        image = np.expand_dims(image, 0)
        return image, index

In [7]:
# ---- Tunable settings ------------------------------------------------------
train_epoches = 500            # number of passes over the dataset
test_frequency = 10            # epochs between reconstruction previews
save_weight_frequency = 100    # epochs between weight checkpoints
lr = 0.005                     # Adam learning rate

# ---- Data ------------------------------------------------------------------
myDataset = ConvAE_Dataset("./localmaps")

# A 60% / 30% / remainder train/test/validate split is currently disabled;
# the whole dataset is used for training.
# train_dataset_size = int(len(myDataset) * 0.6)
# test_dataset_size = int(len(myDataset) * 0.3)
# validate_dataset_size = len(myDataset) - (train_dataset_size + test_dataset_size)
# train_dataset, test_dataset, validate_dataset = random_split(myDataset, [train_dataset_size, test_dataset_size, validate_dataset_size])
# train_dataset_loader = DataLoader(train_dataset, batch_size=20, shuffle=False)
# test_dataset_loader = DataLoader(test_dataset, batch_size=20, shuffle=False)
# validate_dataset_loader = DataLoader(validate_dataset, batch_size=20, shuffle=False)

# NOTE(review): batches are served in index order (shuffle=False) - confirm
# that this is intended for training.
train_dataset_loader = DataLoader(myDataset, batch_size=20, shuffle=False)

# ---- Device ----------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report the number of samples and the shape of the first image.
print('数据集样本数量: ', len(myDataset))
print('图片尺寸: ', myDataset[0][0].shape)

# ---- Logging ---------------------------------------------------------------
writer = tensorboardX.SummaryWriter('./log', flush_secs=2)

# ---- Loss ------------------------------------------------------------------
# Instantiated here; the training cell visible in this notebook computes an
# RMSE expression directly and leaves its loss_fn call commented out.
loss_fn = nn.MSELoss()
# loss_fn = nn.Softmax2d()

# ---- Model and optimizer ---------------------------------------------------
encoder = Encoder()
decoder = Decoder()

# One parameter group per network, both trained by the same Adam instance.
params_to_optimize = [{'params': net.parameters()} for net in (encoder, decoder)]
optimizer = torch.optim.Adam(params_to_optimize, lr=lr, weight_decay=1e-5)

encoder.to(device)
# Last expression of the cell: its return value (the module repr) is shown.
decoder.to(device)

数据集样本数量:  5424
图片尺寸:  (1, 200, 200)


Decoder(
  (decoder_conv): Sequential(
    (0): Upsample(scale_factor=2.0, mode=nearest)
    (1): ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace)
    (4): ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace)
    (7): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(1, 1))
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace)
    (10): Upsample(scale_factor=2.0, mode=nearest)
    (11): ConvTranspose2d(32, 32, kernel_size=(2, 2), stride=(1, 1))
    (12): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace)
    (14): ConvTranspose2d(32, 16, kernel_size=(2, 2), stride=(1, 1))
    (15): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, trac

In [8]:
# Train the encoder/decoder pair to reconstruct its input.
#
# Per epoch: optimise an RMSE reconstruction loss over every batch, log the
# mean loss (scaled by 1000) to TensorBoard and stdout, and periodically
#   * write a side-by-side original/reconstruction preview to "result.png"
#   * checkpoint both networks under ./weights.
#
# Schedules are expressed in completed epochs (1-based), matching the epoch
# number that is logged: with save_weight_frequency = 100 the file
# "100_encoder.pth" holds the weights after 100 epochs, and the final epoch
# is checkpointed whenever train_epoches is a multiple of the frequency.

# torch.save does not create missing directories.
os.makedirs("./weights", exist_ok=True)

for i in range(train_epoches):
    epoch = i + 1

    # The preview step below switches to eval mode; switch back for training.
    encoder.train()
    decoder.train()

    train_loss_epoch = []
    for image_batch, index_batch in train_dataset_loader:
        # The dataset yields float64 arrays; the networks expect float32.
        image_batch_tensor = image_batch.float().to(device)
        encoded_data = encoder(image_batch_tensor)
        decoded_data = decoder(encoded_data)
        # Root-mean-square reconstruction error over the whole batch.
        loss = torch.sqrt((decoded_data - image_batch_tensor).pow(2).mean())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss_epoch.append(loss.item())

    train_loss_avg = np.mean(train_loss_epoch)
    writer.add_scalar('loss', train_loss_avg * 1000, global_step=epoch)
    print("train epoch: {}, train loss: {}".format(epoch, train_loss_avg * 1000))

    if epoch % test_frequency == 0:
        print('------test------')
        # randrange excludes the upper bound; random.randint(0, n) could
        # return n, which is one past the last valid sample index.
        selected_sample_index = random.randrange(len(myDataset))
        selected_sample = myDataset[selected_sample_index][0]
        # Add a batch axis in front of the channel axis.
        selected_image_origin = np.expand_dims(selected_sample, 0)
        selected_image_origin_tensor = torch.tensor(
            selected_image_origin, dtype=torch.float).to(device)

        # Inference on a single image: use eval mode so BatchNorm applies its
        # running statistics instead of overwriting them with this sample's
        # statistics, and skip autograd bookkeeping.
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            selected_image_infer_tensor = decoder(encoder(selected_image_origin_tensor))
        selected_image_infer = selected_image_infer_tensor.cpu().numpy()[0]

        # Back to the 0-255 range, original on the left, reconstruction right.
        selected_image_origin_show = selected_sample[0] * 255.0
        selected_image_infer_show = selected_image_infer[0] * 255.0
        image_concat = np.hstack((selected_image_origin_show, selected_image_infer_show))
        cv2.imwrite("result.png", image_concat)

    if epoch % save_weight_frequency == 0:
        torch.save(encoder.state_dict(), "./weights/{}_encoder.pth".format(epoch))
        torch.save(decoder.state_dict(), "./weights/{}_decoder.pth".format(epoch))

train epoch: 1, train loss: 249.25260245800018
train epoch: 2, train loss: 215.58405458927155
train epoch: 3, train loss: 196.9134956598282
train epoch: 4, train loss: 191.64380431175232
train epoch: 5, train loss: 185.77070534229279
train epoch: 6, train loss: 179.85732853412628
train epoch: 7, train loss: 176.38416588306427
train epoch: 8, train loss: 173.51478338241577
train epoch: 9, train loss: 168.60464215278625
train epoch: 10, train loss: 168.03908348083496
------test------
train epoch: 11, train loss: 168.68948936462402
train epoch: 12, train loss: 165.63820838928223
train epoch: 13, train loss: 167.0791208744049
train epoch: 14, train loss: 161.7806851863861
train epoch: 15, train loss: 160.65073013305664
train epoch: 16, train loss: 161.2980216741562
train epoch: 17, train loss: 158.60827267169952
train epoch: 18, train loss: 160.23363173007965
train epoch: 19, train loss: 157.59053826332092
train epoch: 20, train loss: 156.64660930633545
------test------
train epoch: 21, tr

KeyboardInterrupt: 