<a href="https://colab.research.google.com/github/lipeng2021/-python1/blob/main/5_12ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
from torch.cuda import is_available
import torch
import torch.nn as nn
import torchvision 
import torchvision.transforms as transforms
from IPython import display
import time
# Filesystem path constant. It is not referenced by the other visible cells;
# load_data_fashion_mnist hard-codes the same string as its default `root`.
FILENAME = '/home/lp'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
import torch.nn.functional as F

In [15]:
class Residual(nn.Module):
  """Two-convolution residual block with an optional 1x1 projection shortcut.

  The main path is conv3x3 -> batch norm -> ReLU -> conv3x3 -> batch norm.
  Its output is added to the shortcut and passed through a final ReLU.

  Args:
    in_channels: number of channels of the input feature map.
    out_channels: number of channels produced by the block.
    use_1x1conv: when true, the shortcut is a 1x1 convolution (with the same
      stride as the first 3x3 convolution) instead of the identity.
    stride: stride of the first 3x3 convolution and of the 1x1 shortcut.
  """

  def __init__(self,in_channels,out_channels,use_1x1conv=False,stride=1):
    super().__init__()
    main_path = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
        nn.BatchNorm2d(out_channels),
    ]
    self.conv = nn.Sequential(*main_path)
    # The projection is created after the main path so that parameter
    # initialisation order stays the same.
    self.conv_1x1 = (
        nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=1)
        if use_1x1conv
        else None
    )
    self.conv_end = nn.Sequential(nn.ReLU())

  def forward(self,x):
    """Return ReLU(main_path(x) + shortcut(x))."""
    residual = self.conv(x)
    shortcut = x if self.conv_1x1 is None else self.conv_1x1(x)
    return self.conv_end(residual + shortcut)

In [16]:
# Sanity check: with equal channel counts and stride 1 the block keeps the
# input shape unchanged.
blk = Residual(3,3)
x = torch.rand((4,3,6,6))
print(blk(x).shape)
# With the 1x1 shortcut and stride 2 the channels go 3 -> 6 and the 6x6
# spatial size becomes 3x3 (see the cell output below).
blk=Residual(3,6,use_1x1conv=True,stride=2)
blk(x).shape

torch.Size([4, 3, 6, 6])


torch.Size([4, 6, 3, 3])

In [17]:
# Network stem: 7x7 stride-2 convolution on a single-channel input, batch
# norm, ReLU, then a 3x3 stride-2 max pool. Later cells append the residual
# stages and the classification head to this same Sequential.
net = nn.Sequential(
    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
)

In [18]:
def resnet_block(in_channels,out_channels,num_residuals,first_block=False):
  """Build one stage made of `num_residuals` Residual blocks.

  Unless `first_block` is set, the first Residual of the stage uses a 1x1
  shortcut with stride 2 and maps `in_channels` to `out_channels`; every
  other Residual maps `out_channels` to `out_channels` with stride 1.

  Args:
    in_channels: channels entering the stage.
    out_channels: channels produced by every Residual in the stage.
    num_residuals: how many Residual blocks to stack.
    first_block: when true, no block in the stage downsamples, and
      `in_channels` must equal `out_channels`.

  Returns:
    An nn.Sequential holding the Residual blocks in order.
  """
  if first_block:
    assert in_channels==out_channels

  def make_residual(index):
    # Only the leading block of a non-first stage changes resolution/width.
    if index == 0 and not first_block:
      return Residual(in_channels,out_channels,use_1x1conv=True,stride=2)
    return Residual(out_channels,out_channels)

  return nn.Sequential(*[make_residual(index) for index in range(num_residuals)])

In [19]:
# Append four stages of two Residual blocks each. The first stage keeps 64
# channels (first_block=True, no downsampling); each later stage doubles the
# channel count and starts with a stride-2 Residual.
net.add_module('resnet_block1',resnet_block(64,64,2,first_block=True))
net.add_module('resnet_block2',resnet_block(64,128,2))
net.add_module('resnet_block3',resnet_block(128,256,2))
net.add_module('resnet_block4',resnet_block(256,512,2))

In [20]:
class GlobalAveragePool(nn.Module):
  """Average-pool over all dimensions after the first two."""

  def __init__(self):
    super().__init__()

  def forward(self,x):
    """Average `x` with a pooling window equal to its trailing dimensions."""
    # Using the full trailing size as the kernel collapses each of those
    # dimensions to length 1.
    window = x.shape[2:]
    return F.avg_pool2d(x, kernel_size=window)

In [21]:
class FlattenLayer(nn.Module):
  """Reshape a tensor to two dimensions, keeping the first one."""

  def __init__(self):
    super().__init__()

  def forward(self,x):
    """Return a view of `x` with shape (x.shape[0], -1)."""
    leading = x.size(0)
    return x.view(leading, -1)

In [22]:
# Classification head: global average pooling, flatten to (batch, features),
# then a linear layer mapping the 512 features to 10 outputs.
net.add_module('GlobalAveragePool',GlobalAveragePool())
net.add_module('FlattenLayer',FlattenLayer())
net.add_module('Fullcollection',nn.Linear(512,10))

In [23]:
# Push a dummy 1x1x224x224 input through the network one top-level child at
# a time and print the output shape after each child.
x = torch.rand(1,1,224,224)
for name,blk in net.named_children():
  x = blk(x)
  print(name,x.shape)

0 torch.Size([1, 64, 112, 112])
1 torch.Size([1, 64, 112, 112])
2 torch.Size([1, 64, 112, 112])
3 torch.Size([1, 64, 56, 56])
resnet_block1 torch.Size([1, 64, 56, 56])
resnet_block2 torch.Size([1, 128, 28, 28])
resnet_block3 torch.Size([1, 256, 14, 14])
resnet_block4 torch.Size([1, 512, 7, 7])
GlobalAveragePool torch.Size([1, 512, 1, 1])
FlattenLayer torch.Size([1, 512])
Fullcollection torch.Size([1, 10])


In [28]:
def load_data_fashion_mnist(batch_size,resize=None,root='/home/lp',num_workers=2):
  """Download Fashion-MNIST (if needed) and return train/test data loaders.

  Args:
    batch_size: number of samples per batch for both loaders.
    resize: optional size forwarded to torchvision.transforms.Resize; when
      falsy the images are not resized.
    root: directory where the dataset is stored / downloaded to.
    num_workers: number of worker processes used by each DataLoader.

  Returns:
    A `(train_iter, test_iter)` tuple of torch.utils.data.DataLoader objects.
    The training loader is shuffled; the test loader is not, so evaluation
    visits the samples in a fixed order.
  """
  steps = []
  if resize:
    # Resize is applied before ToTensor, i.e. before tensor conversion.
    steps.append(torchvision.transforms.Resize(size=resize))
  steps.append(torchvision.transforms.ToTensor())
  transform = torchvision.transforms.Compose(steps)

  mnist_train = torchvision.datasets.FashionMNIST(
      root=root, train=True, transform=transform, download=True)
  mnist_test = torchvision.datasets.FashionMNIST(
      root=root, train=False, transform=transform, download=True)

  train_iter = torch.utils.data.DataLoader(
      mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  # Shuffling only matters for training; keep evaluation order deterministic.
  test_iter = torch.utils.data.DataLoader(
      mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
  return train_iter,test_iter

# Build the loaders used for training/evaluation: batches of 256 images,
# each resized to 96 before conversion to a tensor.
batch_size = 256
train_iter,test_iter = load_data_fashion_mnist(batch_size,resize=96)
def evalucate_accuracy(data_iter,net,device):
  """Return the classification accuracy of `net` over `data_iter`.

  Args:
    data_iter: iterable yielding `(X, y)` batches of inputs and integer labels.
    net: either an nn.Module or a plain Python function. For a plain function
      the batches are passed through as-is (they are NOT moved to `device`),
      and `is_training=False` is passed when the function declares a variable
      named `is_training`.
    device: device the batches are moved to when `net` is an nn.Module.

  Returns:
    Fraction of samples whose arg-max prediction equals the label.

  Raises:
    ZeroDivisionError: if `data_iter` yields no samples.
  """
  correct = 0.0
  total = 0
  if isinstance(net,nn.Module):
    # Switch to eval mode once for the whole pass and restore whatever mode
    # the caller had, instead of forcing training mode afterwards.
    was_training = net.training
    net.eval()
    try:
      with torch.no_grad():
        for X,y in data_iter:
          y = y.to(device)
          hits = net(X.to(device)).argmax(dim=1) == y
          correct += hits.float().sum().item()
          total += y.shape[0]
    finally:
      net.train(was_training)
  else:
    # The signature does not change between batches, so inspect it once.
    extra = {'is_training': False} if 'is_training' in net.__code__.co_varnames else {}
    with torch.no_grad():
      for X,y in data_iter:
        hits = net(X,**extra).argmax(dim=1) == y
        correct += hits.float().sum().item()
        total += y.shape[0]
  return correct/total
def train_ch5(net,train_iter,test_iter,batch_size,device,optimizer,num_epochs):
  """Train `net` with cross-entropy loss and print per-epoch statistics.

  Args:
    net: the nn.Module to train; it is moved to `device`.
    train_iter: iterable yielding `(X, y)` training batches.
    test_iter: iterable yielding `(X, y)` batches used for the test accuracy.
    batch_size: unused; kept so existing callers keep working.
    device: device on which training and evaluation run.
    optimizer: optimizer already constructed over `net`'s parameters.
    num_epochs: number of passes over `train_iter`.

  After every batch the running global batch count is printed; after every
  epoch the mean batch loss, training accuracy, test accuracy and elapsed
  time are printed.
  """
  loss = torch.nn.CrossEntropyLoss()
  net.to(device)
  print('train on',device)
  batch_cout = 0  # batches seen across all epochs; used only for progress output
  for epoch in range(num_epochs):
    # Make sure dropout/batch-norm layers are in training mode regardless of
    # the state left behind by the caller or by evaluation.
    net.train()
    train_loss = 0.0
    train_acc_sum = 0.0
    n = 0
    epoch_batches = 0
    start = time.time()
    for X,y in train_iter:
      X = X.to(device)
      y = y.to(device)
      y_hat = net(X)
      batch_loss = loss(y_hat,y)
      optimizer.zero_grad()
      batch_loss.backward()
      optimizer.step()
      train_loss += batch_loss.item()
      # Reuse the predictions from the training forward pass: a second
      # forward pass would cost extra compute and update batch-norm running
      # statistics a second time for the same batch.
      train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().item()
      n += y.shape[0]
      epoch_batches += 1
      batch_cout += 1
      print(batch_cout,end=' ')
    test_acc = evalucate_accuracy(test_iter,net,device)
    # Average the loss over this epoch's batches only; dividing by the
    # cumulative counter would understate the loss from the second epoch on.
    print('epoch:%d,loss:%.4f,train_acc:%.3f,test_acc:%.3f,time:%.1f'%(epoch+1,train_loss/epoch_batches,train_acc_sum/n,test_acc,time.time()-start))

In [None]:
# Training configuration: Adam with learning rate 0.001 for 5 epochs, using
# the loaders, device and network defined in the earlier cells.
lr =0.001
num_epochs =5
optimizer = torch.optim.Adam(net.parameters(),lr=lr)
train_ch5(net,train_iter,test_iter,batch_size,device,optimizer,num_epochs)