<a href="https://colab.research.google.com/github/cnhzgb/MachineL/blob/main/RESNET_TINY_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import ipdb
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from einops import rearrange
from torchsummary import summary
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
device = torch.device("mps")
print(device)



True
True
mps


In [2]:
trans = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomGrayscale(), transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
dataset = datasets.CIFAR10(root="/Users/bin.guanb/code/MachineL/dataset/", transform=trans, download=False, train=True) # 5W张图片, 10种分类
loader = DataLoader(dataset, batch_size=100, shuffle=True)
batch_num,(image, label) = next(enumerate(loader))
print(len(dataset.classes), len(dataset), image.shape, label.shape) # 10; 5W; 100,3,32,32; 100

trans_test = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
dataset_test = datasets.CIFAR10(root="/Users/bin.guanb/code/MachineL/dataset/", transform=trans_test, download=False, train=False) # 1W张图片
loader_test = DataLoader(dataset_test, batch_size=100, shuffle=True)
criterion_test = nn.CrossEntropyLoss()

10 50000 torch.Size([100, 3, 32, 32]) torch.Size([100])


In [10]:
# https://www.cnblogs.com/emanlee/p/17138634.html

class Block(nn.Module):
  def __init__(self, inc, n_chans):
    super(Block, self).__init__()
    self.conv1 = nn.Conv2d(inc, n_chans, kernel_size=3, padding=1)
    self.conv2 = nn.Conv2d(n_chans, n_chans, kernel_size=3, padding=1)
    self.batch_norm = nn.BatchNorm2d(num_features=n_chans)

  def forward(self, x):
    x = self.conv1(x)
    x = F.relu(self.batch_norm(x))
    x = self.conv2(x)
    x = self.batch_norm(x)
    return x

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.pool = nn.MaxPool2d(2, 2)

        self.conv10 = nn.Conv2d(3,32,1,padding=0)
        self.conv11 = Block(3,32)
        self.conv12 = Block(32,32)

        self.conv20 = nn.Conv2d(32,64,1,padding=0)
        self.conv21 = Block(32,64)
        self.conv22 = Block(64,64)

        self.conv30 = nn.Conv2d(64,128,1,padding=0)
        self.conv31 = Block(64,128)
        self.conv32 = Block(128,128)

        self.conv40 = nn.Conv2d(128,256,1,padding=0)
        self.conv41 = Block(128,256)
        self.conv42 = Block(256,256)

        self.conv50 = nn.Conv2d(256,512,1,padding=0)
        self.conv51 = Block(256,512)
        self.conv52 = Block(512,512)

        self.fc1 = nn.Linear(512*2*2, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self,x):
        out1 = self.conv10(x) # 32,32,32
        out2 = self.conv11(x)
        x = F.relu(out1+out2)
        out3 = self.conv12(x)
        x = self.pool(F.relu(x + out3)) # ,32,16,16

        out1 = self.conv20(x) # 64,16,16
        out2 = self.conv21(x)
        x = F.relu(out1+out2)
        out3 = self.conv22(x)
        x = self.pool(F.relu(x + out3)) # ,64,8,8

        out1 = self.conv30(x) # 128,8,8
        out2 = self.conv31(x)
        x = F.relu(out1+out2)
        out3 = self.conv32(x)
        x = self.pool(F.relu(x + out3)) # 128,4,4

        out1 = self.conv40(x) # 256,4,4
        out2 = self.conv41(x)
        x = F.relu(out1+out2)
        out3 = self.conv42(x)
        x = self.pool(F.relu(x + out3)) # 256,2,2

        out1 = self.conv50(x) # 512,2,2
        out2 = self.conv51(x)
        x = F.relu(out1+out2)
        out3 = self.conv52(x)
        x = F.relu(x + out3)

        x = x.view(-1,512*2*2)

        x = F.relu(self.fc1(x))
        x = self.fc3(self.fc2(x))
        return x

model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)

#print(model)
#model_cpu = Net().to('cpu')
#summary(model_cpu, (3,32,32))

In [11]:
for epoch in range(1,20):
  total_loss = []
  errorTotal = 0
  startTime = time.time()
  for batch_idx, (img, label) in enumerate(loader):
    img = img.to(device)
    label = label.to(device)
    outputs = model(img)
    loss = criterion(outputs, label)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss.append(loss.item())
    maxV,maxIdx = outputs.max(dim=1)
    errorNum = torch.sum(torch.ne(maxIdx, label)).item()
    errorTotal += errorNum

    if(batch_idx % 100 == 99):
      print("epoch:{} batch:{} loss:{:.2f} mean:{:.2f} error:{}/100 errorTotal:{}/{} {:.2f}% time:{:.1f}".format(
          epoch, batch_idx, loss, np.mean(total_loss), errorNum, errorTotal, (batch_idx+1)*100, errorTotal/(batch_idx+1),
          time.time()-startTime))

epoch:1 batch:99 loss:1.52 mean:1.71 error:60/100 errorTotal:6454/10000 64.54% time:8.3
epoch:1 batch:199 loss:1.41 mean:1.55 error:50/100 errorTotal:11507/20000 57.53% time:16.6
epoch:1 batch:299 loss:1.22 mean:1.46 error:43/100 errorTotal:16169/30000 53.90% time:24.9
epoch:1 batch:399 loss:1.15 mean:1.38 error:40/100 errorTotal:20319/40000 50.80% time:33.1
epoch:1 batch:499 loss:0.98 mean:1.32 error:33/100 errorTotal:24195/50000 48.39% time:41.4
epoch:2 batch:99 loss:1.10 mean:0.99 error:41/100 errorTotal:3541/10000 35.41% time:8.3
epoch:2 batch:199 loss:0.89 mean:0.96 error:31/100 errorTotal:6921/20000 34.60% time:16.5
epoch:2 batch:299 loss:0.79 mean:0.95 error:28/100 errorTotal:10237/30000 34.12% time:24.8
epoch:2 batch:399 loss:0.79 mean:0.93 error:31/100 errorTotal:13379/40000 33.45% time:33.1
epoch:2 batch:499 loss:0.85 mean:0.92 error:31/100 errorTotal:16413/50000 32.83% time:41.4
epoch:3 batch:99 loss:0.77 mean:0.77 error:29/100 errorTotal:2716/10000 27.16% time:8.3
epoch:3 b

In [None]:
total_loss = []
errorTotal = 0
for batch_idx, (img, label) in enumerate(loader_test):
  img = img.to(device)
  label = label.to(device)
  outputs = model(img)
  loss = criterion_test(outputs, label)

  total_loss.append(loss.item())
  maxV,maxIdx = outputs.max(dim=1)
  errorNum = torch.sum(torch.ne(maxIdx, label)).item()
  errorTotal += errorNum

  if(batch_idx % 20 == 0):
    print("epoch:{} batch:{} loss:{:.2f} mean:{:.2f} error:{}/100 errorTotal:{}/{} {:.2f}%".format(
           epoch, batch_idx, loss, np.mean(total_loss), errorNum, errorTotal, (batch_idx+1)*100, errorTotal/(batch_idx+1)))

