<a href="https://colab.research.google.com/github/eisbetterthanpi/vision/blob/main/resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title data
import torch
import torch.nn as nn
# import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# https://github.com/python-engineer/pytorchTutorial/blob/master/14_cnn.py

# The dataset yields PIL images; convert them to tensors and normalize every
# channel with mean 0.5 and std 0.5 (range [0, 1] -> [-1, 1]).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR10: 60000 32x32 color images in 10 classes, with 6000 images per class
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles; the test loader keeps dataset order.
batch_size = 4
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

# Human-readable names indexed by class label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

def imshow(img):
    """Show a normalized image tensor in a matplotlib figure.

    The tensor is mapped back with ``img / 2 + 0.5`` (the inverse of a
    mean-0.5 / std-0.5 normalization), converted to a NumPy array, and its
    first axis is moved last before plotting.

    Args:
        img: image tensor; assumed to be laid out channels-first
            (C, H, W) and to live on the CPU -- TODO confirm at call sites.
    """
    restored = img / 2 + 0.5  # undo the normalization
    pixels = restored.numpy().transpose(1, 2, 0)  # channels-first -> channels-last
    plt.imshow(pixels)
    plt.show()

dataiter = iter(train_loader) # get some random training images
# Use the built-in next(): DataLoader iterators implement __next__, while the
# Python-2-style .next() method is not available on current PyTorch iterators.
images, labels = next(dataiter)
# imshow(torchvision.utils.make_grid(images))



In [None]:
# @title simplifi
# https://github.com/JayPatwardhan/ResNet-PyTorch/blob/master/ResNet/ResNet.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class Block(nn.Module):
    """Basic residual block built from two 3x3 convolutions.

    The main branch is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    Its output is added to the shortcut branch and passed through a ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions
            (``expansion`` is 1, so this is also the block's output width).
        i_downsample: optional module applied to the input to build the
            shortcut when the main branch changes the shape; ``None`` means
            the input itself is used as the shortcut.
        stride: stride of the first convolution.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels), nn.ReLU(),
            # Only the first conv may be strided. Striding this one as well
            # would downsample twice and no longer match the shortcut, which
            # is downsampled once by `stride`.
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.i_downsample = i_downsample

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``."""
        # The main branch never writes into `x`, so no defensive copy is needed.
        identity = x if self.i_downsample is None else self.i_downsample(x)
        out = self.conv(x) + identity
        return F.relu(out)


class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The last 1x1 convolution widens the feature map to
    ``out_channels * expansion`` channels. The result is added to the
    shortcut branch and passed through a ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: width of the first two convolutions.
        i_downsample: optional module that reshapes the input for the
            shortcut; ``None`` uses a copy of the input unchanged.
        stride: stride of the middle 3x3 convolution.
    """
    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()
        widened = out_channels * self.expansion
        stages = [
            # 1x1: project the input to the narrow width
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 3x3: the only convolution that applies `stride`
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            # 1x1: expand to the block's output width
            nn.Conv2d(out_channels, widened, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(widened),
        ]
        self.conv = nn.Sequential(*stages)
        self.i_downsample = i_downsample

    def forward(self, x):
        """Return ``relu(conv(x) + shortcut(x))``."""
        shortcut = x.clone()
        out = self.conv(x)
        if self.i_downsample is not None:
            # reshape the shortcut so it can be added to the main branch
            shortcut = self.i_downsample(shortcut)
        out += shortcut
        return nn.ReLU()(out)


class ResNet(nn.Module):
    """ResNet-style classifier assembled from a configurable residual block.

    The network is a 7x7 stride-2 stem (conv -> BatchNorm -> ReLU), a 3x3
    stride-2 max-pool, four stages of residual blocks, global average
    pooling and one linear layer.

    Args:
        ResBlock: block class. It must expose an ``expansion`` class attribute
            and accept ``(in_channels, out_channels, i_downsample=..., stride=...)``.
        layer_list: number of blocks in each of the four stages.
        num_classes: size of the output layer.
        num_channels: number of channels of the input images.
        plane_list: base width of each of the four stages. The default
            ``(4, 8, 16, 32)`` is the reduced-width configuration this class
            previously hard-coded.

    Raises:
        ValueError: if ``plane_list`` does not contain exactly four widths.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3, plane_list=(4, 8, 16, 32)):
        super().__init__()
        if len(plane_list) != 4:
            raise ValueError(f'plane_list must hold 4 stage widths, got {len(plane_list)}')
        # Tracks the channel count feeding the next stage; updated by _make_layer.
        self.in_channels = plane_list[0]
        self.conv = nn.Sequential(
            nn.Conv2d(num_channels, plane_list[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(plane_list[0]), nn.ReLU(),

            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),

            # Stage 1 keeps the resolution; stages 2-4 each use stride 2.
            self._make_layer(ResBlock, layer_list[0], plane_list[0]),
            self._make_layer(ResBlock, layer_list[1], plane_list[1], stride=2),
            self._make_layer(ResBlock, layer_list[2], plane_list[2], stride=2),
            self._make_layer(ResBlock, layer_list[3], plane_list[3], stride=2),

            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(plane_list[3] * ResBlock.expansion, num_classes)

    def forward(self, x):
        """Return the class scores for a batch of images."""
        x = self.conv(x)
        # Collapse everything except the batch dimension before the linear layer.
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks of base width ``planes``.

        Only the first block receives ``stride``. It also receives a
        1x1 conv + BatchNorm shortcut projection whenever ``stride`` is not 1
        or the incoming channel count differs from the stage's output width.
        ``self.in_channels`` is updated to that output width.
        """
        out_channels = planes * ResBlock.expansion
        ii_downsample = None
        if stride != 1 or self.in_channels != out_channels:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        layers = [ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride)]
        self.in_channels = out_channels
        for _ in range(blocks - 1):
            layers.append(ResBlock(self.in_channels, planes))
        return nn.Sequential(*layers)

        
        
def ResNet50(num_classes, channels=3):
    """Build a Bottleneck-based ResNet with stage depths [3, 4, 6, 3].

    Args:
        num_classes: size of the classifier output.
        channels: number of input image channels.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(Bottleneck, stage_depths, num_classes, num_channels=channels)
    
def ResNet101(num_classes, channels=3):
    """Build a Bottleneck-based ResNet with stage depths [3, 4, 23, 3].

    Args:
        num_classes: size of the classifier output.
        channels: number of input image channels.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(Bottleneck, stage_depths, num_classes, num_channels=channels)

def ResNet152(num_classes, channels=3):
    """Build a Bottleneck-based ResNet with stage depths [3, 8, 36, 3].

    Args:
        num_classes: size of the classifier output.
        channels: number of input image channels.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet(Bottleneck, stage_depths, num_classes, num_channels=channels)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ResNet50(num_classes=10, channels=3).to(device)
# ResNet's constructor names its input-channel argument `num_channels`
# (only the ResNet50/101/152 helpers take `channels`), so passing
# `channels=` here would raise a TypeError.
model = ResNet(Bottleneck, [3,4,6,3], num_classes=10, num_channels=3).to(device)
# print(model)


In [None]:
# @title train/test
# Loss and optimizer shared by train(): cross-entropy on the model outputs and
# SGD (learning rate 0.001) over all parameters of the module-level `model`.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

def train(num_epochs=5, log_every=2000):
    """Train the module-level ``model`` on ``train_loader``.

    Uses the module-level ``device``, ``criterion`` and ``optimizer``. The
    loss of the current batch is printed every ``log_every`` steps, starting
    with the first step of each epoch.

    Args:
        num_epochs: number of full passes over ``train_loader``.
        log_every: number of steps between two progress lines.
    """
    model.train()
    n_total_steps = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            # Clear gradients left over from the previous step before backprop.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % log_every == 0:
                print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
# Run the training loop, then save the model's state_dict (its parameters and
# buffers, not the module object itself) to PATH.
train()
print('Finished Training')
PATH = './cnn.pth'
torch.save(model.state_dict(), PATH)

def test():
    """Evaluate the module-level ``model`` on ``test_loader`` and print accuracies.

    Prints the overall accuracy followed by one line per entry of the
    module-level ``classes``. Runs in eval mode with gradients disabled and
    uses the module-level ``device``.
    """
    model.eval()
    num_classes = len(classes)
    n_correct = 0
    n_samples = 0
    n_class_correct = [0] * num_classes
    n_class_samples = [0] * num_classes
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # max returns (value ,index)
            _, predicted = torch.max(outputs, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            # Iterate over the samples actually present: the final batch can be
            # smaller than batch_size, so indexing by range(batch_size) is unsafe.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                if label == pred:
                    n_class_correct[label] += 1
                n_class_samples[label] += 1

    if n_samples == 0:
        # Nothing to average over; avoid dividing by zero below.
        print('No test samples to evaluate')
        return

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network: {acc} %')
    for name, hits, total in zip(classes, n_class_correct, n_class_samples):
        if total == 0:
            # A class that never appears has no defined accuracy.
            print(f'Accuracy of {name}: n/a (no samples)')
            continue
        acc = 100.0 * hits / total
        print(f'Accuracy of {name}: {acc} %')

# Evaluate the trained model and print overall and per-class accuracy.
test()

Epoch [1/5], Step [1/12500], Loss: 2.2956
Epoch [1/5], Step [2001/12500], Loss: 2.4555
Epoch [1/5], Step [4001/12500], Loss: 2.7042
Epoch [1/5], Step [6001/12500], Loss: 2.0968
Epoch [1/5], Step [8001/12500], Loss: 1.9578
Epoch [1/5], Step [10001/12500], Loss: 1.8977
Epoch [1/5], Step [12001/12500], Loss: 2.1082
Epoch [2/5], Step [1/12500], Loss: 2.0538
Epoch [2/5], Step [2001/12500], Loss: 2.0453
Epoch [2/5], Step [4001/12500], Loss: 1.7704
Epoch [2/5], Step [6001/12500], Loss: 1.6398
Epoch [2/5], Step [8001/12500], Loss: 2.0163
Epoch [2/5], Step [10001/12500], Loss: 1.8784
Epoch [2/5], Step [12001/12500], Loss: 2.2474
Epoch [3/5], Step [1/12500], Loss: 2.3727
Epoch [3/5], Step [2001/12500], Loss: 2.2510
Epoch [3/5], Step [4001/12500], Loss: 1.8058
Epoch [3/5], Step [6001/12500], Loss: 1.9301
Epoch [3/5], Step [8001/12500], Loss: 2.0612
Epoch [3/5], Step [10001/12500], Loss: 2.0406
Epoch [3/5], Step [12001/12500], Loss: 1.5219
Epoch [4/5], Step [1/12500], Loss: 1.9874
Epoch [4/5], Ste