In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import torchinfo
from multiprocessing import Process
import matplotlib.pyplot as plt

In [2]:
# Select the compute device once for the whole notebook.
# `device` is always bound (falling back to the CPU) so that later cells that
# reference it do not raise NameError on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("GPU Device:【{}:{}】".format(device.type, device.index))
    # Make GPU 0 the current CUDA device for subsequent allocations.
    torch.cuda.set_device(0)
else:
    print("CUDA is not available; falling back to CPU.")

GPU Device:【cuda:None】


In [3]:
# Load the MNIST train and test splits from the local "MNIST_data" folder,
# downloading them when missing. Each sample is passed through ToTensor().
training_data = datasets.MNIST(
    root="MNIST_data",
    train=True,
    download=True,
    transform=ToTensor(),
)
test_data = datasets.MNIST(
    root="MNIST_data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting MNIST_data\MNIST\raw\train-images-idx3-ubyte.gz to MNIST_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data\MNIST\raw\train-labels-idx1-ubyte.gz


102.8%


Extracting MNIST_data\MNIST\raw\train-labels-idx1-ubyte.gz to MNIST_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting MNIST_data\MNIST\raw\t10k-images-idx3-ubyte.gz to MNIST_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data\MNIST\raw\t10k-labels-idx1-ubyte.gz


112.7%

Extracting MNIST_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to MNIST_data\MNIST\raw






In [4]:
# Number of samples per mini-batch, shared by both loaders.
batch_size = 32

# Wrap each dataset in a batching iterator.
# NOTE(review): neither loader shuffles; confirm that sequential order is
# intended for the training split.
train_dataloader = DataLoader(dataset=training_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)

# Peek at the first test batch (if any) to confirm tensor shapes and label dtype.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([32, 1, 28, 28])
Shape of y: torch.Size([32]) torch.int64


In [5]:
class ResBlock(nn.Module):
    """Residual block (5x5 "same" convolution + batch norm, plus a skip connection).

    Layers are not created in ``__init__``; call :meth:`build` with the input
    shape first. The block preserves the shape of its input, so
    ``output_shape == input_shape`` after building.

    Parameters and buffers are exposed through the standard
    ``state_dict()`` / ``load_state_dict()`` API of ``nn.Module``.
    """

    def __init__(self, device=None, *args, **wargs):
        """Store the target device; remaining arguments go to ``nn.Module.__init__``.

        Args:
            device: Device the layers are moved to by :meth:`build`
                (``None`` leaves them where PyTorch creates them).
        """
        super().__init__(*args, **wargs)
        self.__built = False
        self.device = device

    def build(self, input_shape):
        """Create the layers for inputs of shape ``input_shape``.

        Args:
            input_shape: Sequence whose element ``[1]`` is the channel count
                (e.g. ``(N, C, H, W)``).
        """
        channels = int(input_shape[1])
        self.input_shape = input_shape
        # The convolutions must emit `channels` feature maps: the batch-norm
        # layers below are sized for `channels`, and the skip connection adds
        # the result to the (channels-wide) input.
        self.conv = nn.Conv2d(channels, channels, 5, padding="same")
        # NOTE(review): this activation is created but never applied in
        # forward(); confirm where (or whether) it is meant to be used.
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm2d(channels)
        # Projection path for the skip connection, used only when the main
        # branch changes the spatial width.
        self.downconv = nn.Conv2d(channels, channels, 5, padding="same")
        self.downbn = nn.BatchNorm2d(channels)
        self.output_shape = input_shape
        if self.device is not None:
            self.to(self.device)
        self.__built = True

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``bn(conv(inputs)) + skip(inputs)``.

        Args:
            inputs: Tensor with the channel count given to :meth:`build`.

        Raises:
            RuntimeError: If the block has not been built, or if the two
                branches cannot be added (the args then carry the shapes of
                the skip branch, the main branch and the input).
        """
        if not self.__built:
            raise RuntimeError("ResBlock.build(input_shape) must be called before forward()")
        skip = inputs
        residual = self.bn(self.conv(inputs))
        if residual.shape[-1] != skip.shape[-1]:
            # Width changed on the main branch: project the skip path as well.
            skip = self.downbn(self.downconv(skip))
        try:
            return residual + skip
        except RuntimeError as err:
            # Surface the offending shapes while keeping the original traceback.
            raise RuntimeError(skip.shape, residual.shape, inputs.shape) from err

In [6]:
class MyResNet(nn.Module):
    """Growable residual classifier: a stack of ``ResBlock`` modules and a linear head.

    The network is created lazily by :meth:`compile`, which reads one batch
    from a data loader to discover the input shape. Blocks are kept in an
    ``nn.ModuleList`` so their parameters are registered with the module
    (and therefore seen by ``parameters()``, ``to()``, ``train()``/``eval()``
    and ``state_dict()``). More blocks can be appended during training with
    :meth:`addNewBlock` or :meth:`copyLastBlock`.

    NOTE(review): recent PyTorch releases also define ``nn.Module.compile``;
    the ``compile`` method here overrides it — confirm nothing relies on the
    built-in one.
    """

    def __init__(self, device=None, *args, **kwargs):
        """Store the target device; remaining arguments go to ``nn.Module.__init__``."""
        super().__init__(*args, **kwargs)
        self.__blocks_num = 1  # number of leading blocks used by forward()
        self.__frozen_blocks_num = 0
        self.__built = False
        self.device = device

    def _new_block(self, input_shape):
        """Create a ``ResBlock`` built for ``input_shape`` on this model's device."""
        block = ResBlock(self.device)
        block.build(input_shape)
        if self.device is not None:
            block.to(self.device)
        return block

    def _register_with_optimizer(self, block):
        """Hand a late-added block's parameters to the optimizer, if one exists."""
        optimizer = getattr(self, "optimizer", None)
        if optimizer is not None:
            optimizer.add_param_group({"params": list(block.parameters())})

    def _move_batch(self, X, y):
        """Return the batch on ``self.device`` (unchanged when no device was given)."""
        if self.device is None:
            return X, y
        return X.to(self.device), y.to(self.device)

    def build(self, input_shape):
        """Create the initial block and the classification head.

        Args:
            input_shape: Shape of one input batch, ``(N, C, H, W)``.
        """
        # An ordinary ResNet, but the residual blocks are kept in a list;
        # new blocks are appended to this list during training.
        self.blocks = nn.ModuleList([self._new_block(input_shape)])
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(int(np.prod(input_shape[1:])), 10)
        self.__built = True

    def compile(self, dataloader: DataLoader, loss_fn, optimizer, lr: float = 1e-2):
        """Build the network from the first batch and create loss and optimizer.

        Args:
            dataloader: Loader whose first batch defines the input/output shapes
                and whose ``batch_size`` is recorded on the model.
            loss_fn: Loss class; instantiated with no arguments.
            optimizer: Optimizer class; called as ``optimizer(params, lr=lr)``.
            lr: Learning rate passed to the optimizer.
        """
        self.batch_size = dataloader.batch_size
        X, y = next(iter(dataloader))
        self.input_shape = X.shape
        self.output_shape = y.shape
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        if self.device is not None:
            self.to(self.device)
        self.optimizer = optimizer(self.parameters(), lr=lr)

    def forward(self, x):
        """Run the active residual blocks, then flatten and classify."""
        active = min(self.__blocks_num, len(self.blocks))
        for index in range(active):
            x = self.blocks[index](x)
        x = self.flatten(x)
        return self.linear(x)

    def getBlocksNum(self):
        """Return the number of blocks currently used by :meth:`forward`."""
        return self.__blocks_num

    def addNewBlock(self):
        """Activate one more block, creating a fresh one if none is left to activate."""
        print("----------")
        print("Adding new block...")
        if self.__blocks_num >= len(self.blocks):
            newBlock = self._new_block(self.blocks[-1].output_shape)
            self.blocks.append(newBlock)
            # The optimizer was created before this block existed.
            self._register_with_optimizer(newBlock)
        self.__blocks_num += 1
        print("Success!")

    def copyLastBlock(self):
        """Append a new block initialised from the last block's weights.

        The weights are copied only when the last block keeps its input shape;
        otherwise the new block is still appended, with its fresh initialisation.
        """
        print("----------")
        print("Copying last block...")
        last_block = self.blocks[-1]
        newBlock = self._new_block(last_block.output_shape)
        copied = last_block.input_shape == last_block.output_shape
        if copied:
            newBlock.load_state_dict(last_block.state_dict())
        else:
            print("Copy failed: shape different with last block")
        self.blocks.append(newBlock)
        self._register_with_optimizer(newBlock)
        self.__blocks_num += 1
        if copied:
            # Only report success when the weights were actually copied.
            print("Success!")

    def fit(self, dataloader: DataLoader, epochs: int = 1):
        """Train for ``epochs`` passes over ``dataloader``, printing progress.

        Prints per-batch loss/accuracy, per-epoch averages and finally a
        torchinfo summary of the model.
        """
        size = len(dataloader.dataset)
        # len(dataloader) also counts a trailing partial batch.
        num_batches = len(dataloader)
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch + 1}/{epochs}")
            loss, correct, current = 0, 0, 0
            for batch, (X, y) in enumerate(dataloader):
                X, y = self._move_batch(X, y)

                # Compute prediction error
                pred = self(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current += len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                # Normalise by the real batch length (the last batch may be short).
                batch_correct /= max(len(X), 1)
                print(f"\r{batch + 1}/{num_batches}  [{current:>5d}/{size:>5d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}%", end = "")
            loss /= max(num_batches, 1)
            correct /= max(size, 1)
            print(f"\nAverage loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}%")
        # The summary object must be printed explicitly: its value is otherwise
        # discarded because it is not the last expression of a notebook cell.
        print(torchinfo.summary(self, input_size=tuple(self.input_shape), verbose=0))

    def test(self, dataloader: DataLoader):
        """Evaluate on ``dataloader`` and print accuracy and average batch loss."""
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = self._move_batch(X, y)
                pred = self(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= max(num_batches, 1)
        correct /= max(size, 1)
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [None]:
# Instantiate the network on the selected device, build it from the first
# training batch with cross-entropy loss and SGD, then train for 10 epochs.
model = MyResNet(device=device)
model.compile(
    dataloader=train_dataloader,
    loss_fn=nn.CrossEntropyLoss,
    optimizer=torch.optim.SGD,
)
model.fit(dataloader=train_dataloader, epochs=10)