In [1]:
import torch,sys,os

from tqdm import tqdm

from torch import nn

from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import torchvision.transforms as transforms

# Model architecture

In [2]:
class AlexNet(nn.Module):
    """AlexNet-style image classifier: five convolutional stages plus a 3-layer MLP head.

    Args:
        num_classes: number of logits produced by the last Linear layer.
        init_weights: when True, re-initialise Conv2d weights with Kaiming-normal
            (mode='fan_out', ReLU gain), Linear weights with N(0, 0.01) and every
            bias with 0. When False, PyTorch's default layer initialisation is kept.

    Inputs are (N, 3, H, W) float batches. For the reference 227x227 input the
    conv stack produces a 256x6x6 feature map, which is exactly what the first
    Linear layer (in_features=256*6*6) expects. An adaptive average pool maps
    other spatial sizes (both H and W >= 67) onto the same 6x6 grid, so the model
    is not hard-wired to 227x227 inputs.
    """

    def __init__(self, num_classes=102, init_weights=False):
        super().__init__()

        # Stage 1: 11x11 stride-4 conv -> ReLU -> local response norm -> 3x3/2 max-pool.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=4, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, k=2, alpha=1e-4, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2))

        # Stage 2: 5x5 "same" conv -> ReLU -> local response norm -> 3x3/2 max-pool.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=1, padding=2, bias=True),
            nn.ReLU(inplace=True),
            nn.LocalResponseNorm(size=5, k=2, alpha=1e-4, beta=0.75),
            nn.MaxPool2d(kernel_size=3, stride=2))

        # Stages 3-5: 3x3 "same" convs; only stage 5 downsamples (3x3/2 max-pool).
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
            )
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True)
            )
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

        # Force the conv5 output onto a 6x6 grid. It is an exact identity for
        # 227x227 inputs. The layer has no parameters or buffers, so state_dicts
        # saved before it existed still load unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        # Classifier head: two 4096-unit hidden layers with dropout, then logits.
        self.FC = nn.Sequential(
            nn.Linear(in_features=256*6*6, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(in_features=4096, out_features=num_classes),
        )

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Map an (N, 3, H, W) image batch to (N, num_classes) unnormalised logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        return self.FC(x)

    def _initialize_weights(self):
        """Re-initialise every Conv2d / Linear layer in place (see class docstring)."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)



# STL10 Dataset

In [3]:
# Shared data-loading settings.
BATCH_SIZE = 72
DATA_ROOT = "~/data/STL10/"

# DataLoader warns when num_workers exceeds the CPUs this process may run on,
# so cap the requested 8 workers at that count (same check DataLoader uses).
if hasattr(os, "sched_getaffinity"):
    _available_cpus = len(os.sched_getaffinity(0))
else:
    _available_cpus = os.cpu_count() or 1
NUM_WORKERS = min(8, _available_cpus)

# Pinned host memory only speeds up host->GPU copies; skip it without CUDA.
PIN_MEMORY = torch.cuda.is_available()

# ToTensor -> per-channel Normalize(mean=0.5, std=0.5) -> resize to the 227x227
# input size the AlexNet classifier head was sized for.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    transforms.Resize([227, 227]),
])


def make_stl10_loader(split, shuffle):
    """Load (downloading if needed) one STL10 split and wrap it in a DataLoader.

    Returns a (dataset, loader) pair that uses the shared batch/worker settings.
    """
    dataset = datasets.STL10(
        root=DATA_ROOT,
        split=split,
        download=True,
        transform=transform,
    )
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        pin_memory=PIN_MEMORY,
        num_workers=NUM_WORKERS,
    )
    return dataset, loader


train_set, trainloader = make_stl10_loader('train', shuffle=True)
test_set, testloader = make_stl10_loader('test', shuffle=False)

test_num = len(test_set)          # number of evaluation samples
train_steps = len(trainloader)    # batches per training epoch

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to /root/data/STL10/stl10_binary.tar.gz


  0%|          | 0/2640397119 [00:00<?, ?it/s]

Extracting /root/data/STL10/stl10_binary.tar.gz to /root/data/STL10/


  cpuset_checked))


Files already downloaded and verified


# Model definition

In [4]:
# Seed the global torch RNG (all devices) so weight initialisation, DataLoader
# shuffling and dropout masks repeat across Restart & Run All.
# NOTE: GPU kernels (e.g. cuDNN) may still introduce small run-to-run differences.
SEED = 0
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

# 10 logits: one per STL10 class.
model = AlexNet(num_classes=10, init_weights=True).to(device)

Using cuda device


# Loss function and optimizer

In [5]:
# Loss: cross-entropy computed directly on the model's raw logits.
loss_fn = nn.CrossEntropyLoss()

# Optimizer: Adam over all model parameters.
# Alternative left for experimentation: torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
LEARNING_RATE = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training schedule and checkpoint bookkeeping read by the training loop.
epochs = 40
save_path = './AlexNet.pth'
best_acc = 0.0  # best evaluation accuracy so far; a checkpoint is saved only on improvement

In [6]:
# Train for `epochs` epochs. After each epoch, evaluate and checkpoint the
# weights whenever accuracy improves.
#
# NOTE(review): the evaluation pass runs on the STL10 *test* split, and that
# score also chooses which checkpoint is saved. best_acc is therefore an
# optimistic estimate. Hold out a validation subset of the train split if an
# unbiased test number is needed.


def train_one_epoch(model, loader, loss_fn, optimizer, device, epoch, num_epochs):
    """Run one optimisation pass over `loader`.

    Returns the mean of the per-batch training losses.
    """
    model.train()
    running_loss = 0.0
    train_bar = tqdm(loader, file=sys.stdout)
    for images, labels in train_bar:
        # non_blocking pairs with pin_memory in the DataLoader (no-op on CPU).
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        loss = loss_fn(model(images), labels)
        loss.backward()
        optimizer.step()

        # .item() is the single host sync per batch; reuse the float for the bar
        # instead of formatting the tensor a second time.
        batch_loss = loss.item()
        running_loss += batch_loss
        train_bar.set_description(
            "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, num_epochs, batch_loss))
    return running_loss / len(loader)


def evaluate(model, loader, device):
    """Return the fraction of samples in `loader` whose argmax logit equals the label."""
    model.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, file=sys.stdout):  # show progress
            outputs = model(images.to(device, non_blocking=True))
            predictions = outputs.argmax(dim=1)
            correct += (predictions == labels.to(device, non_blocking=True)).sum().item()
    return correct / len(loader.dataset)


for epoch in range(epochs):
    train_loss = train_one_epoch(model, trainloader, loss_fn, optimizer, device, epoch, epochs)
    val_accurate = evaluate(model, testloader, device)
    print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
          (epoch + 1, train_loss, val_accurate))

    if val_accurate > best_acc:
        best_acc = val_accurate
        torch.save(model.state_dict(), save_path)

print('Finished Training')

train epoch[1/40] loss:2.225: 100%|██████████| 70/70 [00:19<00:00,  3.65it/s]
100%|██████████| 112/112 [00:12<00:00,  9.25it/s]
[epoch 1] train_loss: 2.220  val_accuracy: 0.284
train epoch[2/40] loss:1.804: 100%|██████████| 70/70 [00:13<00:00,  5.30it/s]
100%|██████████| 112/112 [00:17<00:00,  6.39it/s]
[epoch 2] train_loss: 1.815  val_accuracy: 0.326
train epoch[3/40] loss:1.524: 100%|██████████| 70/70 [00:13<00:00,  5.32it/s]
100%|██████████| 112/112 [00:18<00:00,  5.98it/s]
[epoch 3] train_loss: 1.645  val_accuracy: 0.402
train epoch[4/40] loss:1.598: 100%|██████████| 70/70 [00:13<00:00,  5.26it/s]
100%|██████████| 112/112 [00:17<00:00,  6.37it/s]
[epoch 4] train_loss: 1.528  val_accuracy: 0.416
train epoch[5/40] loss:1.483: 100%|██████████| 70/70 [00:13<00:00,  5.29it/s]
100%|██████████| 112/112 [00:17<00:00,  6.36it/s]
[epoch 5] train_loss: 1.404  val_accuracy: 0.464
train epoch[6/40] loss:1.030: 100%|██████████| 70/70 [00:13<00:00,  5.15it/s]
100%|██████████| 112/112 [00:17<00:00

In [7]:
# Highest evaluation accuracy reached during training. The weights that achieved
# it were saved to save_path. It was measured on the STL10 test split, which also
# selected the checkpoint, so treat it as an optimistic estimate.
best_acc

0.541625