# Import necessities


In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as Data
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torchvision.transforms as transforms

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global torch RNG so weight initialization, the random smoke-test input
# and DataLoader shuffling are reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)

# Build the model

In [None]:
class LeNet(nn.Module):
    """LeNet-5 style CNN for 32x32 RGB images with 10 output classes.

    Per-sample shapes through the network:
        c1:  Conv2d 3->6, kernel 5   (3, 32, 32) -> (6, 28, 28), then ReLU
        s2:  MaxPool2d 2x2           -> (6, 14, 14)
        c3:  Conv2d 6->16, kernel 5  -> (16, 10, 10), then ReLU
        s4:  MaxPool2d 2x2           -> (16, 5, 5)
        c5:  Linear 400 -> 120, then ReLU
        f6:  Linear 120 -> 84, then ReLU
        out: Linear 84 -> 10         raw scores (no softmax applied)
    """

    def __init__(self):
        super().__init__()
        # Keep this construction order: layer weights are initialized from the
        # global RNG in the order the submodules are created.
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()  # registered but not used by forward()
        self.c1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.s2 = nn.MaxPool2d(2, 2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.s4 = nn.MaxPool2d(2, 2)
        self.c5 = nn.Linear(16 * 5 * 5, 120)
        self.f6 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, x):
        # Feature extractor: (conv -> ReLU -> max-pool) applied twice.
        feats = self.s2(self.relu(self.c1(x)))
        feats = self.s4(self.relu(self.c3(feats)))
        # Flatten to (batch, 16 * 5 * 5) and run the fully connected head.
        flat = feats.view(feats.shape[0], -1)
        hidden = self.relu(self.f6(self.relu(self.c5(flat))))
        return self.out(hidden)


model = LeNet().to(device)
# Smoke-test the model with a randomly generated batch.
# The four dimensions are [batch_size, channels, height, width].
# The input must live on the same device as the model: a CPU tensor fed to a
# CUDA model raises a device-mismatch error.
t1 = torch.rand([10, 3, 32, 32], device=device)
with torch.no_grad():  # inference only; no need to build an autograd graph
    smoke_out = model(t1)
# One row of 10 class scores per input image.
assert smoke_out.shape == (10, 10), smoke_out.shape
smoke_out.shape

# Set optimizer and loss function

In [None]:
# Labels are integer class indices (they are compared with argmax predictions
# in the evaluation loop) and the model emits 10 raw scores per image, so
# cross-entropy is the matching loss. MSELoss would try to compare a (N, 10)
# float tensor with an (N,) integer tensor. This also matches the
# F.cross_entropy used in the Lightning version below.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Load data for training
和MNIST类似，加载数据集有两种方式：
- 使用 `Dataset` 与 `torch.utils.data.DataLoader` 加载
- 使用 CIFAR-10 官网提供的 `unpickle` 函数直接读取原始 pickle 文件

其中前者更为方便：torchvision 已提供 `torchvision.datasets.CIFAR10`（一个现成的 `Dataset`），不用手写继承 `Dataset` 的类来处理数据。

In [None]:
# Load CIFAR-10 through torchvision's ready-made dataset class.

# Normalization: after = (before - mean) / std.
# The first tuple holds the per-channel means and the second the per-channel
# stds (R, G, B). With 0.5 / 0.5, ToTensor's [0, 1] range maps to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# download=False: the extracted CIFAR-10 archive must already be under ./src/.
train_set = torchvision.datasets.CIFAR10(root='./src/', train=True,
                                         download=False, transform=transform)
# Shuffle the training data each epoch. Without it every epoch sees the batches
# in the same fixed order, which hurts mini-batch gradient descent.
train_loader = Data.DataLoader(train_set, batch_size=36,
                               shuffle=True, num_workers=0)

# Evaluation order does not matter, so the test set is not shuffled.
test_set = torchvision.datasets.CIFAR10(root='./src/', train=False,
                                        download=False, transform=transform)
test_loader = Data.DataLoader(test_set, batch_size=5000,
                              shuffle=False, num_workers=0)


def unpickle(file):
    """Read one raw CIFAR-10 batch file (a Python pickle) and return its object.

    ``encoding='bytes'`` is forwarded to ``pickle.load``, so Python-2 8-bit
    strings stored in the file are returned as ``bytes`` instead of being decoded.

    Security: unpickling can execute arbitrary code. Only call this on trusted
    files such as the official CIFAR-10 archive.
    """
    from pickle import load as _pickle_load

    with open(file, mode='rb') as handle:
        contents = _pickle_load(handle, encoding='bytes')
    return contents


# Read the first raw training batch with `unpickle` to inspect the original file
# format. NOTE(review): the result is presumably a dict with bytes keys such as
# b'data' and b'labels' (because of encoding='bytes'); confirm before relying on it.
raw_batch_path = "./src/cifar-10-batches-py/data_batch_1"
a = unpickle(raw_batch_path)

以下函数可显示读取到的数据，用于检查数据是否被正确读取。

此外需要注意的是，DataLoader 只能迭代、不能下标索引，需要先用 `iter()` 转化为迭代器，再用 `next()` 取出一个 batch。

`np.transpose` 用于调整图片的维度顺序：原本为 [3, 32, 32]（通道在前），调整后为 [32, 32, 3]（通道在后），这是 `plt.imshow` 要求的格式。

In [None]:
# 展示图片
def img_show(data, mean=0.5, std=0.5):
    """Display one normalized image tensor with matplotlib.

    Reverses ``transforms.Normalize(mean, std)`` (``x * std + mean``), converts the
    (C, H, W) tensor into the (H, W, C) NumPy layout that ``plt.imshow`` expects,
    clips to [0, 1] and shows the figure. With the defaults this matches the
    0.5 / 0.5 normalization used for the CIFAR-10 loaders.

    Args:
        data: a single image tensor laid out as (C, H, W), e.g. one element of a
            DataLoader batch. It may live on the GPU or require grad.
        mean: normalization mean, either a scalar or one value per channel.
        std: normalization std, either a scalar or one value per channel.
    """
    # .numpy() refuses CUDA tensors and tensors that require grad.
    image = data.detach().cpu()
    # Reshape the stats to (C, 1, 1), or (1, 1, 1) for scalars, so they
    # broadcast over height and width.
    std_t = torch.as_tensor(std, dtype=image.dtype).reshape(-1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=image.dtype).reshape(-1, 1, 1)
    image = image * std_t + mean_t
    # imshow wants channels last. Clip so float round-off outside [0, 1]
    # does not trigger imshow's clipping warning.
    plt.imshow(np.clip(np.transpose(image.numpy(), (1, 2, 0)), 0.0, 1.0))
    plt.show()


# A DataLoader is iterable but not indexable: wrap it with iter() and pull the
# first batch with next(). The labels are discarded.
batch_iter = iter(train_loader)
img_data, _ = next(batch_iter)

img_show(img_data[1])

# Sanity-check the display layout: denormalize and move channels last, which
# should give (32, 32, 3) for a CIFAR-10 image.
data = img_data[1] * 0.5 + 0.5
np.transpose(data.numpy(), axes=(1, 2, 0)).shape

# Start Training

In [None]:
model.train()
epochs = 5
for epoch in range(epochs):
    # Accumulate a plain Python float. Adding the loss *tensor* would chain every
    # batch's autograd history into `sum_loss` and keep it alive all epoch.
    sum_loss = 0.0
    for in_data, out_data in train_loader:
        in_data = in_data.to(device)
        out_data = out_data.to(device)
        pred = model(in_data)
        loss = criterion(pred, out_data)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # .item() detaches and copies to the host, so this also works on CUDA,
        # where np.float32(tensor) would fail.
        sum_loss += loss.item()
    print("epoch ", epoch, "  Loss: ", np.float32(sum_loss))

# Test

In [None]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for in_data, out_data in test_loader:
        # Move the batch to the model's device; a CPU batch fed to a CUDA
        # model raises a device-mismatch error.
        in_data = in_data.to(device)
        out_data = out_data.to(device)
        # The predicted class is the index of the largest output score.
        predicted = model(in_data).argmax(dim=1)
        total += out_data.size(0)
        # .item() keeps `correct` a host-side int, so the final print also works on CUDA.
        correct += (predicted == out_data).sum().item()
print("Correct rate: ", np.float32(100 * correct / total), "%")


# 尝试使用`pytorch lightning`
For more tutorials, see the [Lightning introduction guide](https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html).

To convert existing PyTorch code to Lightning, follow the [conversion guide](https://pytorch-lightning.readthedocs.io/en/latest/starter/converting.html).

In [325]:
import pytorch_lightning as pl
from torchmetrics.functional import accuracy
class LeNet_pl(pl.LightningModule):
    """LeNet-5 style CNN for 10-class classification, as a LightningModule.

    Uses the same layer layout as the plain-PyTorch ``LeNet``: two conv + max-pool
    stages followed by three fully connected layers. Inputs of shape (N, 3, 32, 32)
    produce the 16 * 5 * 5 flattened features the first Linear layer expects.
    Training and validation use cross-entropy on the raw output scores.
    """

    def __init__(self):
        super(LeNet_pl, self).__init__()
        self.relu = nn.ReLU()
        # Registered but not used by forward(); kept so the module summary is unchanged.
        self.sigmoid = nn.Sigmoid()
        self.c1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)  # Input(3, 32, 32) Output(6, 28, 28)
        self.s2 = nn.MaxPool2d(2, 2)  # Output (6, 14, 14)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)  # Input(6, 14, 14) Output(16, 10, 10)
        self.s4 = nn.MaxPool2d(2, 2)  # Output (16, 5, 5)
        self.c5 = nn.Linear(16 * 5 * 5, 120)
        self.f6 = nn.Linear(120, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw class scores of shape (N, 10) for an input batch x."""
        x = self.relu(self.c1(x))
        x = self.s2(x)
        x = self.relu(self.c3(x))
        x = self.s4(x)
        x = self.relu(self.c5(x.view(x.size()[0], -1)))
        x = self.relu(self.f6(x))
        x = self.out(x)

        return x

    def configure_optimizers(self):
        """Return an Adam optimizer over this module's own parameters."""
        # Must be self.parameters(): the global `model` may refer to a different
        # network (e.g. the plain LeNet), which would then be optimized instead.
        return optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, train_batch, batch_idx):
        """Compute and log the cross-entropy loss for one (images, labels) batch."""
        _x, _y = train_batch
        _pred = self(_x)
        _loss = F.cross_entropy(_pred, _y)
        self.log('train_loss', _loss)
        return _loss

    def validation_step(self, val_batch, batch_idx):
        """Log validation loss and accuracy for one (images, labels) batch."""
        _x, _y = val_batch
        _pred = self(_x)
        _loss = F.cross_entropy(_pred, _y)
        # Accuracy computed directly. Newer torchmetrics releases require a
        # `task=` argument for functional accuracy, so avoid depending on it.
        acc = (_pred.argmax(dim=1) == _y).float().mean()
        self.log('acc', acc, prog_bar=True)
        self.log('val_loss', _loss)
        self.log('val_loss', _loss)


# Set up for the dataset and dataloader
# Same preprocessing as the plain-PyTorch section: ToTensor(), then per-channel
# normalization (x - 0.5) / 0.5.
trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Both splits are read from the local CIFAR-10 copy under ./src/ (no download).
cifar_kwargs = dict(root='./src/', download=False, transform=trans)
train_set = torchvision.datasets.CIFAR10(train=True, **cifar_kwargs)
val_set = torchvision.datasets.CIFAR10(train=False, **cifar_kwargs)

# Shuffled small batches for training; large, fixed-order batches for validation.
train_loader = Data.DataLoader(train_set, batch_size=36, shuffle=True, num_workers=0)
val_loader = Data.DataLoader(val_set, batch_size=5000, shuffle=False, num_workers=0)

# model
from pytorch_lightning.loggers import TensorBoardLogger

# Seed Python, NumPy and torch RNGs before building the model, so weight
# initialization and training-set shuffling are reproducible.
pl.seed_everything(42)

# model
model = LeNet_pl()
# Visualize training with TensorBoard: run `tensorboard --logdir ./lightning_logs`
# and open the printed URL in a browser.
logger = TensorBoardLogger('lightning_logs/', name='LeNet-5')
# Adjust the Trainer arguments (epochs, accelerator, ...) for your own setup.
# Trainer reference: https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#module-pytorch_lightning.trainer.trainer
trainer = pl.Trainer(max_epochs=15, logger=logger)
# training
trainer.fit(model, train_loader, val_loader)





GPU available: False, used: False
TPU available: None, using: 0 TPU cores

  | Name    | Type      | Params
--------------------------------------
0 | relu    | ReLU      | 0     
1 | sigmoid | Sigmoid   | 0     
2 | c1      | Conv2d    | 456   
3 | s2      | MaxPool2d | 0     
4 | c3      | Conv2d    | 2.4 K 
5 | s4      | MaxPool2d | 0     
6 | c5      | Linear    | 48.1 K
7 | f6      | Linear    | 10.2 K
8 | out     | Linear    | 850   
--------------------------------------
62.0 K    Trainable params
0         Non-trainable params
62.0 K    Total params
0.248     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1