In [1]:
import torch,sys,os
from torch import nn

from torch.utils.data import DataLoader
from torchvision import datasets

from torchvision.transforms import ToTensor,transforms

from tqdm import tqdm

In [2]:

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [3]:
from statistics import mode
from turtle import forward


class LeNet(nn.Module):
    def __init__(self, num_classes, init_weights=False) -> None:
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2))

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(kernel_size=2, stride=2))
        
        self.out = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=400,out_features=120),
            nn.Linear(in_features=120,out_features=84),
            nn.Linear(in_features=84,out_features=num_classes)
            )
            
        if init_weights:
            self.weight_init()

    def forward(self,x):
        x = self.conv1(x)
        x = self.conv2(x)
        y = self.out(x)
        return y

    def weight_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='tanh')
                # nn.init.normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

In [4]:
# Download training data from open datasets.
train_set = datasets.CIFAR10(
    root="~/data/CIFAR10/",
    train=True,
    download=True,
    transform=ToTensor() # transform,
)
trainloader=torch.utils.data.DataLoader(
	train_set,
	batch_size=144,
	shuffle=False,
	pin_memory=True,
    num_workers=8
	)


# Download test data from open datasets.
test_set = datasets.CIFAR10(
    root="~/data/CIFAR10/",
    train=False,
    download=True,
    transform=ToTensor() # transform,
)
testloader=torch.utils.data.DataLoader(
	test_set,
	batch_size=10000,
	shuffle=False,
	pin_memory=True,
    num_workers=8
	)

test_data_iter=iter(testloader)
test_image,test_label=test_data_iter.next()
test_num  = len(test_set)

train_steps = len(trainloader)

model = LeNet(num_classes=10,init_weights=True).to(device)
print(model)

Files already downloaded and verified
Files already downloaded and verified
LeNet(
  (conv1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (out): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=400, out_features=120, bias=True)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Linear(in_features=84, out_features=10, bias=True)
  )
)


In [5]:
# 定义一个损失函数
loss_fn = nn.CrossEntropyLoss()

# 定义一个优化器
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)

epochs = 40

save_path= './LeNet.pth'
best_acc = 0.0


In [6]:
# train

for epoch in range(epochs):
        # train
        model.train()
        running_loss = 0.0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            outputs = model(images.to(device))
            loss = loss_fn(outputs, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        model.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(testloader, file=sys.stdout) # show progress
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = model(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / test_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(model.state_dict(), save_path)

print('Finished Training')

train epoch[1/40] loss:2.221: 100%|██████████| 348/348 [00:04<00:00, 83.40it/s] 
100%|██████████| 1/1 [00:01<00:00,  1.61s/it]
[epoch 1] train_loss: 1.949  val_accuracy: 0.341
train epoch[2/40] loss:2.067: 100%|██████████| 348/348 [00:03<00:00, 95.58it/s] 
100%|██████████| 1/1 [00:01<00:00,  1.65s/it]
[epoch 2] train_loss: 1.738  val_accuracy: 0.399
train epoch[3/40] loss:1.915: 100%|██████████| 348/348 [00:03<00:00, 92.90it/s] 
100%|██████████| 1/1 [00:01<00:00,  1.43s/it]
[epoch 3] train_loss: 1.632  val_accuracy: 0.429
train epoch[4/40] loss:1.809: 100%|██████████| 348/348 [00:03<00:00, 97.63it/s] 
100%|██████████| 1/1 [00:01<00:00,  1.44s/it]
[epoch 4] train_loss: 1.567  val_accuracy: 0.445
train epoch[5/40] loss:1.761: 100%|██████████| 348/348 [00:03<00:00, 102.18it/s]
100%|██████████| 1/1 [00:01<00:00,  1.44s/it]
[epoch 5] train_loss: 1.517  val_accuracy: 0.457
train epoch[6/40] loss:1.707: 100%|██████████| 348/348 [00:03<00:00, 101.81it/s]
100%|██████████| 1/1 [00:01<00:00,  1.4

In [7]:
test_index = 5

test_lab = test_label[test_index]
test_img = val_images[test_index].unsqueeze(0)

weights_path = "./LeNet.pth"
assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
model.load_state_dict(torch.load(weights_path))

model.eval()

LeNet(
  (conv1): Sequential(
    (0): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (conv2): Sequential(
    (0): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (out): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=400, out_features=120, bias=True)
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [8]:

with torch.no_grad():
    outputs = model(test_img.to(device))
    predict_y = torch.max(outputs, dim=1)[1]
    print(predict_y)
    print(test_lab)

tensor([6], device='cuda:0')
tensor(6)
