<a href="https://colab.research.google.com/github/lipeng2021/-python1/blob/main/5_12DenseNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from IPython import display
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
# Local directory path; same value as the default `root` of load_data_fashion_mnist.
# NOTE(review): FILENAME is not referenced by any code visible in this notebook — confirm it is still needed.
FILENAME = '/home/lp'
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [34]:
class GlobalAveragePool(nn.Module):
  """Average-pool with a window equal to the input's size from dimension 2 onward.

  For a 4-D (N, C, H, W) input the window is (H, W), so each channel collapses
  to a single value and the output has shape (N, C, 1, 1).
  """
  def __init__(self):
    super().__init__()
  def forward(self,x):
    # The kernel is taken from the input itself, so any spatial size is accepted.
    window = x.shape[2:]
    return F.avg_pool2d(x,kernel_size=window)
class FlattenLayer(nn.Module):
  """Collapse every dimension after the first into one: (N, ...) -> (N, -1)."""
  def __init__(self):
    super().__init__()
  def forward(self,x):
    # `view` is used, so the input must be viewable with this shape (no copy is made).
    batch = x.shape[0]
    return x.view(batch,-1)

In [23]:
def conv_bolck(in_channels,out_channels):
  """Build the BatchNorm -> ReLU -> 3x3 Conv unit used inside a dense block.

  The convolution uses stride 1 and padding 1, so height and width are
  preserved; only the channel count changes from `in_channels` to
  `out_channels`.

  Returns:
    An `nn.Sequential` holding the three layers in that order.
  """
  return nn.Sequential(
      nn.BatchNorm2d(in_channels),
      nn.ReLU(),
      nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
  )

In [24]:
from torch.nn.modules.container import ModuleList
class Dense_block(nn.Module):
  """A stack of conv blocks whose outputs are concatenated onto their inputs.

  Args:
    num_convs: number of conv blocks in the stack.
    in_channels: channel count of the tensor fed to the first conv block.
    out_channels: channels produced by each conv block (the growth per step).

  Attributes:
    net: `nn.ModuleList` of the conv blocks; block `i` expects
      `in_channels + i*out_channels` input channels.
    out_channels: channel count of the tensor returned by `forward`, i.e.
      `in_channels + num_convs*out_channels`.
  """
  def __init__(self,num_convs,in_channels,out_channels):
    super().__init__()
    # Each block sees the original input plus everything the earlier blocks produced.
    self.net = nn.ModuleList([
        conv_bolck(in_channels+idx*out_channels,out_channels)
        for idx in range(num_convs)
    ])
    self.out_channels = in_channels+num_convs*out_channels
  def forward(self,x):
    features = x
    for layer in self.net:
      # The newest feature maps are placed in front of the accumulated ones along dim 1.
      features = torch.cat((layer(features),features),dim=1)
    return features

In [25]:
# Shape check: a Dense_block with 2 conv blocks, 3 input channels and 10 channels
# per conv block should return 3 + 2*10 = 23 channels at the same 8x8 resolution.
x = torch.rand(4,3,8,8)
blk = Dense_block(2,3,10)
y = blk(x)
y.shape

torch.Size([4, 23, 8, 8])

In [26]:
def transition_block(in_channels,out_channels):
  """Build the BatchNorm -> ReLU -> 1x1 Conv -> 2x2 AvgPool transition unit.

  The 1x1 convolution changes the channel count from `in_channels` to
  `out_channels`; the stride-2 average pool then halves height and width.

  Returns:
    An `nn.Sequential` holding the four layers in that order.
  """
  stages = [
      nn.BatchNorm2d(in_channels),
      nn.ReLU(),
      nn.Conv2d(in_channels,out_channels,kernel_size=1),
      nn.AvgPool2d(kernel_size=2,stride=2),
  ]
  return nn.Sequential(*stages)

In [27]:
# Shape check: push the 23-channel output of the Dense_block check through a
# transition block; channels drop to 10 and height/width are halved.
# NOTE(review): `y` is overwritten with the 10-channel result, so re-running this
# cell on its own would feed 10 channels into BatchNorm2d(23) — confirm whether a
# separate output name is wanted.
blk = transition_block(23,10)
y = blk(y)
y.shape

torch.Size([4, 10, 4, 4])

In [29]:
# Stem of the network: expects a single-channel image and produces 64 feature maps.
# Both the stride-2 convolution and the stride-2 max pool halve the spatial size
# (96x96 -> 48x48 -> 24x24 in the shape trace further down).
net = nn.Sequential(
    nn.Conv2d(1,64,kernel_size=7,stride=2,padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
)

In [31]:
# Append four dense blocks to `net`, with a transition block between consecutive ones.
# `num_channels` tracks the channel count flowing out of the most recently added module
# and is read again by the cell that adds the classifier head.
num_channels=64  # channels produced by the stem
growth_rate=32  # channels contributed by each conv block inside a dense block
num_convs_in_dense_blocks = [4,4,4,4]
for i,num_convs in enumerate(num_convs_in_dense_blocks):
  DB = Dense_block(num_convs,num_channels,growth_rate)
  net.add_module('DenseBlock_%d'%i,DB)
  # A dense block outputs its input channels plus num_convs*growth_rate new ones.
  num_channels = DB.out_channels
  # No transition after the last dense block; every other one is followed by a
  # transition block that halves the channel count.
  if i!= len(num_convs_in_dense_blocks)-1:
    net.add_module('transition_block_%d'%i,transition_block(num_channels,num_channels//2))
    num_channels = num_channels//2

In [35]:
# Classifier head: normalise and activate the final feature maps, pool each channel
# to a single value, flatten to (batch, num_channels) and map to 10 output scores.
# Relies on `num_channels` holding the channel count left by the dense-block loop.
net.add_module('bn',nn.BatchNorm2d(num_channels))
net.add_module('relu',nn.ReLU())
net.add_module('GlobalAveragePool',GlobalAveragePool())
net.add_module('FlattenLayer',FlattenLayer())
net.add_module('Linear',nn.Linear(num_channels,10))

In [38]:
# Feed one random 1x96x96 input through `net` child by child and print the
# tensor shape after each top-level module.
x = torch.rand(1,1,96,96)
net(x)  # NOTE(review): result is discarded (not the cell's last expression) — looks like a leftover smoke test; confirm
for blk in net.children():
  x = blk(x)
  print(x.shape)

torch.Size([1, 64, 48, 48])
torch.Size([1, 64, 48, 48])
torch.Size([1, 64, 48, 48])
torch.Size([1, 64, 24, 24])
torch.Size([1, 192, 24, 24])
torch.Size([1, 96, 12, 12])
torch.Size([1, 224, 12, 12])
torch.Size([1, 112, 6, 6])
torch.Size([1, 240, 6, 6])
torch.Size([1, 120, 3, 3])
torch.Size([1, 248, 3, 3])
torch.Size([1, 248, 3, 3])
torch.Size([1, 248, 3, 3])
torch.Size([1, 248, 1, 1])
torch.Size([1, 248])
torch.Size([1, 10])


In [42]:
def load_data_fashion_mnist(batch_size,resize=None,root='/home/lp'):
  """Download Fashion-MNIST (if needed) and return train/test data loaders.

  Args:
    batch_size: batch size used by both loaders.
    resize: optional size passed to `torchvision.transforms.Resize`; when
      falsy the images are only converted to tensors.
    root: directory where the dataset is stored / downloaded.

  Returns:
    `(train_iter, test_iter)`; both loaders shuffle and use 2 worker processes.
  """
  steps = [torchvision.transforms.Resize(size=resize)] if resize else []
  steps.append(torchvision.transforms.ToTensor())
  pipeline = torchvision.transforms.Compose(steps)
  # Training split first, then the test split; both share the same transform.
  splits = [
      torchvision.datasets.FashionMNIST(root=root,train=is_train,transform=pipeline,download=True)
      for is_train in (True,False)
  ]
  train_iter,test_iter = [
      torch.utils.data.DataLoader(split,batch_size=batch_size,shuffle=True,num_workers=2)
      for split in splits
  ]
  return train_iter,test_iter

# Build the Fashion-MNIST loaders with images resized to 96 (the input size used in
# the shape trace of `net`). Running this line triggers the dataset download
# into the loader's default root directory.
batch_size = 256
train_iter,test_iter = load_data_fashion_mnist(batch_size,resize=96)
def evalucate_accuracy(data_iter,net,device):
  """Return the fraction of samples in `data_iter` that `net` classifies correctly.

  Args:
    data_iter: iterable yielding `(X, y)` batches of inputs and integer labels.
    net: either an `nn.Module` or a plain Python function. A module is
      evaluated in eval mode with the batch moved to `device`. A plain
      function is called on the batch as-is, with `is_training=False` when its
      code object declares an `is_training` variable.
    device: device the batches are moved to when `net` is an `nn.Module`.

  Returns:
    Number of correct predictions divided by the number of samples seen.

  Raises:
    ZeroDivisionError: if `data_iter` yields no samples.
  """
  acc_sum = 0.0
  n = 0
  is_module = isinstance(net,nn.Module)
  # Remember the caller's mode so evaluation does not silently flip an
  # eval-mode model back to training mode.
  was_training = net.training if is_module else False
  if is_module:
    net.eval()
  try:
    with torch.no_grad():
      for X,y in data_iter:
        if is_module:
          y_hat = net(X.to(device))
          acc_sum += (y_hat.argmax(dim=1)==y.to(device)).float().sum().cpu().item()
        elif 'is_training' in net.__code__.co_varnames:
          acc_sum += (net(X,is_training=False).argmax(dim=1)==y).float().sum().item()
        else:
          acc_sum += (net(X).argmax(dim=1)==y).float().sum().item()
        n += y.shape[0]
  finally:
    # Restore the original mode even if a batch raised.
    if is_module:
      net.train(was_training)
  return acc_sum/n
def train_ch5(net,train_iter,test_iter,batch_size,device,optimizer,num_epochs):
  """Train `net` with cross-entropy loss and print per-epoch metrics.

  Args:
    net: model to train; it is moved to `device` in place.
    train_iter: iterable yielding `(X, y)` training batches.
    test_iter: iterable passed to `evalucate_accuracy` after every epoch.
    batch_size: not used inside this function; kept for call compatibility.
    device: device the model and every batch are moved to.
    optimizer: optimizer already constructed over `net`'s parameters.
    num_epochs: number of passes over `train_iter`.

  After each epoch one line is printed with the mean per-batch training loss,
  the training accuracy, the test accuracy and the elapsed time. A running
  batch counter is printed after every batch as a progress indicator.

  Raises:
    ZeroDivisionError: if `train_iter` yields no batches in an epoch.
  """
  loss = torch.nn.CrossEntropyLoss()
  net.to(device)
  print('train on',device)
  batch_cout = 0  # cumulative over all epochs; only used for the progress printout
  for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc_sum = 0.0
    n = 0
    epoch_batches = 0  # batches seen in this epoch; denominator of the mean loss
    start = time.time()
    for X,y in train_iter:
      X=X.to(device)
      y=y.to(device)
      y_hat = net(X)
      l = loss(y_hat,y)
      optimizer.zero_grad()
      l.backward()
      optimizer.step()
      train_loss += l.cpu().item()
      # Reuse the logits from the training forward pass: another net(X) call here
      # would cost a second forward pass and, in train mode, update BatchNorm
      # running statistics a second time for the same batch.
      train_acc_sum += (y_hat.argmax(dim=1)==y).float().sum().cpu().item()
      n += y.shape[0]
      epoch_batches += 1
      batch_cout +=1
      print(batch_cout,end=' ')
    test_acc_sum = evalucate_accuracy(test_iter,net,device)
    # Average the loss over this epoch's batches only, so later epochs are not
    # divided by the cumulative batch count.
    print('epoch:%d,loss:%.4f,train_acc:%.3f,test_acc:%.3f,time:%.1f'%(epoch+1,train_loss/epoch_batches,train_acc_sum/n,test_acc_sum,time.time()-start))

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /home/lp/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting /home/lp/FashionMNIST/raw/train-images-idx3-ubyte.gz to /home/lp/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /home/lp/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting /home/lp/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /home/lp/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /home/lp/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting /home/lp/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /home/lp/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /home/lp/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting /home/lp/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /home/lp/FashionMNIST/raw

