In [6]:
#include <iostream>

/*a workaround to solve cling issue*/
#include "../inc/macos_cling_workaround.hpp"
/*set libtorch path, load libs*/
#include "../inc/load_libtorch.hpp"
/*import custom defined macros*/
#include "../inc/custom_def.hpp"
/*import matplotlibcpp*/
#include "../inc/load_matplotlibcpp.hpp"

/*import libtorch header file*/
#include <torch/torch.h>

# 图像分类数据集：Fashion-MNIST

***注意：libtorch官方有MNIST数据集分类[例程](https://pytorch.org/cppdocs/frontend.html)，需要的可以自行参考；***    
fashion-mnist数据集与mnist数据集的文件类型、数据格式一样，因此可以考虑搬用MNIST的处理代码；

In [7]:
/*design a net*/
/**
 * 同官方例子，骨干网络为三个全连接层，但是每层外加relu和dropout，
 * 最后再跟一个softmax进行分类；
 *
 */
struct Net : torch::nn::Module {
  Net() {
    // Construct and register two Linear submodules.
    fc1 = register_module("fc1", torch::nn::Linear(784, 64));
    fc2 = register_module("fc2", torch::nn::Linear(64, 32));
    fc3 = register_module("fc3", torch::nn::Linear(32, 10));
  }

  // Implement the Net's algorithm.
  torch::Tensor forward(torch::Tensor x) {
    // Use one of many tensor manipulation functions.
    x = torch::relu(fc1->forward(x.reshape({x.size(0), 784})));
    x = torch::dropout(x, /*p=*/0.5, /*train=*/is_training());
    x = torch::relu(fc2->forward(x));
    x = torch::log_softmax(fc3->forward(x), /*dim=*/1);
    return x;
  }

  // Use one of many "standard library" modules.
  torch::nn::Linear fc1{nullptr}, fc2{nullptr}, fc3{nullptr};
};

In [9]:
// Instantiate the model defined above.
auto net = std::make_shared<Net>();

// Data loader over the files in ../dataset/fashion_mnist, read through the
// MNIST dataset class. Stack<> combines the samples of a batch into one
// data tensor and one target tensor. Batches hold 100 samples each.
auto data_loader = torch::data::make_data_loader(
    torch::data::datasets::MNIST("../dataset/fashion_mnist")
        .map(torch::data::transforms::Stack<>()),
    /*batch_size=*/100);

// SGD over all parameters of the network, learning rate 0.01.
torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

// Train for 5 passes over the data loader.
for (size_t epoch = 1; epoch <= 5; ++epoch) {
    size_t batches_seen = 0;

    for (auto& minibatch : *data_loader) {
        // Clear gradients left over from the previous step.
        optimizer.zero_grad();

        // Forward pass, then negative log-likelihood loss against the targets.
        torch::Tensor log_probs = net->forward(minibatch.data);
        torch::Tensor batch_loss = torch::nll_loss(log_probs, minibatch.target);

        // Backward pass and parameter update.
        batch_loss.backward();
        optimizer.step();

        // Only every 500th batch of an epoch is reported and checkpointed.
        ++batches_seen;
        if (batches_seen % 500 != 0) {
            continue;
        }

        std::cout << "Epoch: " << epoch << " | Batch: " << batches_seen
                  << " | Loss: " << batch_loss.item<float>() << std::endl;
        // Overwrite the checkpoint file with the current model state.
        torch::save(net, "net.pt");
    }
}

std::cout << std::endl << "\r\nTraining finished!\r\n" << std::endl;

Epoch: 1 | Batch: 500 | Loss: 1.8686
Epoch: 2 | Batch: 500 | Loss: 1.09727
Epoch: 3 | Batch: 500 | Loss: 0.933681
Epoch: 4 | Batch: 500 | Loss: 0.837059
Epoch: 5 | Batch: 500 | Loss: 0.757047


Training finished!

