In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader,Dataset
from torchvision import models,datasets,transforms

  Referenced from: '/opt/anaconda3/envs/aloha/lib/python3.9/site-packages/torchvision/image.so'
  warn(


In [2]:
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [3]:
batch_size = 100
train_data = datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=False)
test_data = datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor(),download=False)
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=False)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Sequential(nn.Conv2d(1,64,kernel_size=3,padding=1),
                                   nn.ReLU(),
                                   nn.Conv2d(64,128,kernel_size=3,padding=1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(kernel_size=2,stride=2))
        self.fc = nn.Sequential(nn.Linear(128*14*14,1024),
                                nn.ReLU(),
                                nn.Dropout(p=0.5),
                                nn.Linear(1024,10))
    def forward(self,X):
        X = self.conv1(X)
        X = X.view(-1,128*14*14)
        X = self.fc(X)
        return X

In [5]:
model = Net()
optimizer = optim.Adam(model.parameters(),lr=0.001)
criterion = nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)

CrossEntropyLoss()

In [18]:
output[:10][0].softmax(dim=-1)

tensor([7.3876e-05, 3.6528e-08, 4.6603e-06, 1.0847e-03, 6.0481e-08, 9.9870e-01,
        4.5019e-06, 6.8175e-06, 1.1494e-04, 7.8394e-06], device='mps:0',
       grad_fn=<SoftmaxBackward0>)

In [8]:
label[:10]

tensor([5, 1, 0, 0, 7, 8, 5, 1, 3, 0], device='mps:0')

In [6]:
epoch_nums = 5
#writer = SummaryWriter(log_dir='log')
scheduler = optim.lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.1)
for epoch in range(epoch_nums):
    # train
    for batch_idx, (img, label) in enumerate(train_loader):
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        output = model(img)
        train_loss = criterion(output, label)
        train_loss.backward()
        optimizer.step()
        if (batch_idx+1)*batch_size % 1000 == 0:
            for param_group in optimizer.param_groups:
                current_lr = param_group['lr']
            print(f'Epoch {epoch+1}: lr={current_lr} {(batch_idx+1)*batch_size}/{batch_size*len(train_loader)} \
                    train_loss: {train_loss.item()}')
    writer.add_scalar('Loss/train', train_loss.item(), epoch)
    scheduler.step()
    # test
    total_loss, total_acc = 0, 0
    model.eval()
    with torch.no_grad():
        for img, label in test_loader:
            img, label = img.to(device), label.to(device)
            output = model(img)
            test_loss = criterion(output,label)
            pred = output.argmax(dim=1)
            total_loss += test_loss.item()
            total_acc += pred.eq(label.view_as(pred)).sum().item()
    print(f'Epoch {epoch+1}: lr={current_lr} test_loss={total_loss/len(test_loader)} test_acc={total_acc/len(test_loader)}')
    writer.add_scalar('Loss/test', test_loss.item(), epoch)        
            

Epoch 1: lr=0.001 1000/60000                     train_loss: 1.1865397691726685
Epoch 1: lr=0.001 2000/60000                     train_loss: 0.5906389355659485
Epoch 1: lr=0.001 3000/60000                     train_loss: 0.5485586524009705
Epoch 1: lr=0.001 4000/60000                     train_loss: 0.5059179067611694
Epoch 1: lr=0.001 5000/60000                     train_loss: 0.30246537923812866
Epoch 1: lr=0.001 6000/60000                     train_loss: 0.4741373360157013
Epoch 1: lr=0.001 7000/60000                     train_loss: 0.20468351244926453
Epoch 1: lr=0.001 8000/60000                     train_loss: 0.2261902391910553
Epoch 1: lr=0.001 9000/60000                     train_loss: 0.1321466565132141


KeyboardInterrupt: 

In [7]:
torch.save(model.state_dict(), 'MNISTmodel5.pth')
model = Net()
model.load_state_dict(torch.load('MNISTmodel5.pth'))
model.to(device)

Net(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=25088, out_features=1024, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=1024, out_features=10, bias=True)
  )
)

In [12]:
# 引入预训练模型
pretrained_model_path = '/pretrainedmodel/resnet18-f37072fd.pth'
# 加载模型
state_dict = torch.load(pretrained_model_path)

resnet = models.resnet18()
resnet.load_state_dict(state_dict)
resnet.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
for param in resnet.parameters():
    param.requires_grad = False
num_classes = 10   
resnet.fc = nn.Linear(resnet.fc.in_features, num_classes)
for param in resnet.fc.parameters():
    param.requires_grad = True
    
for name, param in resnet.named_parameters():
    if param.requires_grad:
        print(f"{name} is trainable.")
    else:
        print(f"{name} is frozen.")
model = resnet

conv1.weight is frozen.
conv1.bias is frozen.
bn1.weight is frozen.
bn1.bias is frozen.
layer1.0.conv1.weight is frozen.
layer1.0.bn1.weight is frozen.
layer1.0.bn1.bias is frozen.
layer1.0.conv2.weight is frozen.
layer1.0.bn2.weight is frozen.
layer1.0.bn2.bias is frozen.
layer1.1.conv1.weight is frozen.
layer1.1.bn1.weight is frozen.
layer1.1.bn1.bias is frozen.
layer1.1.conv2.weight is frozen.
layer1.1.bn2.weight is frozen.
layer1.1.bn2.bias is frozen.
layer2.0.conv1.weight is frozen.
layer2.0.bn1.weight is frozen.
layer2.0.bn1.bias is frozen.
layer2.0.conv2.weight is frozen.
layer2.0.bn2.weight is frozen.
layer2.0.bn2.bias is frozen.
layer2.0.downsample.0.weight is frozen.
layer2.0.downsample.1.weight is frozen.
layer2.0.downsample.1.bias is frozen.
layer2.1.conv1.weight is frozen.
layer2.1.bn1.weight is frozen.
layer2.1.bn1.bias is frozen.
layer2.1.conv2.weight is frozen.
layer2.1.bn2.weight is frozen.
layer2.1.bn2.bias is frozen.
layer3.0.conv1.weight is frozen.
layer3.0.bn1.weig