In [54]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import numpy as np
from MyModel import MyModelBase, MyResNet

In [55]:
# Choose the accelerator once, then report which one is in use.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if use_gpu else "cpu")
banner = "GPU Device:【{}:{}】" if use_gpu else "CPU Device:【{}:{}】"
print(banner.format(device.type, device.index))
if use_gpu:
    # `device` carries no explicit index, so it refers to the current CUDA device;
    # pin that to GPU 0.
    torch.cuda.set_device(0)

CPU Device:【cpu:None】


In [56]:
# download=True lets the notebook run from scratch on a fresh machine; when the
# files are already present under ./data they are verified and reused, not re-fetched.
to_tensor = transforms.ToTensor()
train_dataset = datasets.CIFAR100(root="./data", train=True, transform=to_tensor, download=True)
test_dataset = datasets.CIFAR100(root="./data", train=False, transform=to_tensor, download=True)

In [57]:
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
# Evaluation gains nothing from a random order; a fixed order keeps test passes reproducible.
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)

# Peek at a single batch to record the input and label shapes.
X, y = next(iter(test_dataloader))
data_shape = X.shape
label_shape = y.shape
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([128, 3, 32, 32])
Shape of y: torch.Size([128]) torch.int64


In [58]:

class SimpleResBlock(nn.Module):
    """Single-convolution residual block whose layers are created lazily by `build`.

    After `build`, `forward` computes ``activation(bn(conv(x)) + shortcut(x))``. The
    shortcut is the identity when the channel count is unchanged and ``downconv``
    otherwise. The layers do not exist until `build(input_shape)` has been called.

    Args:
        out_channels: output channel count; ``None`` keeps the input channel count.
        kernel_size: kernel size for both ``conv`` and ``downconv``. Both use
            ``padding="same"``, so H and W are preserved.
        device: device the built layers are moved to; ``None`` leaves them in place.
    """

    def __init__(self, out_channels: int = None, kernel_size: int = 3, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__built = False
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.device = device
        self.return_attn = False  # NOTE(review): not read anywhere in this block

    def build(self, input_shape: tuple):
        """Create the layers for inputs of shape ``(N, C, H, W)`` and set ``output_shape``."""
        self.input_shape = input_shape
        batch_size, in_channels, in_H, in_W = input_shape
        out_channels = in_channels if self.out_channels is None else self.out_channels

        self.conv: nn.Conv2d = nn.Conv2d(in_channels, out_channels, self.kernel_size, padding="same")
        self.activation: nn.ReLU = nn.ReLU()
        # BN normalises the conv output, so it is sized by out_channels (not in_channels).
        self.bn: nn.BatchNorm2d = nn.BatchNorm2d(out_channels)
        # Projection applied to the shortcut when the channel count changes.
        self.downconv: nn.Conv2d = nn.Conv2d(in_channels, out_channels, self.kernel_size, padding="same")

        # padding="same" keeps H and W; only the channel dimension can change.
        self.output_shape = input_shape if out_channels == in_channels else (batch_size, out_channels, in_H, in_W)
        self.__built = True
        if self.device is not None:
            self.to(self.device)

    def forward(self, x: torch.Tensor, return_attn=False):
        """Apply the block to ``x``.

        ``return_attn`` is accepted for interface compatibility and is ignored.
        """
        fx = self.bn(self.conv(x))
        # Compare channels (dim 1). padding="same" never changes the width, so a
        # width comparison could not detect the case where a projection is needed.
        if fx.shape[1] != x.shape[1]:
            x = self.downconv(x)
        # Non-linearity after the residual sum; without it stacked blocks stay affine.
        return self.activation(fx + x)

    def freeze(self):
        """Disable gradients for every parameter of this block."""
        for param in self.parameters():
            param.requires_grad = False

In [59]:
class SimpleBaseModel(MyModelBase):
    """Baseline model: a fixed stack of four `SimpleResBlock`s and a linear classifier.

    Layers are created in `build(input_shape)`. ``input_shape`` is indexed at position 1
    for the channel count, presumably ``(N, C, H, W)``; TODO confirm with MyModelBase.

    Args:
        classes: number of output logits.
        device: forwarded to `MyModelBase`; the built layers are placed on ``self.device``.
    """

    def __init__(self, classes=2, device=None, *args, **kwargs):
        super().__init__(device, *args, **kwargs)
        self.classes = classes

    def build(self, input_shape):
        """Create the residual stack and the classifier head for ``input_shape``."""
        # Each block keeps the input channel count and moves itself to self.device in build().
        self.blocks: nn.ModuleList = nn.ModuleList(
            [SimpleResBlock(input_shape[1], device=self.device) for _ in range(4)]
        )
        for block in self.blocks:
            block.build(input_shape)
        self.flatten = nn.Flatten()
        # np.prod yields a NumPy integer; nn.Linear declares in_features as a plain int.
        self.linear = nn.Linear(int(np.prod(input_shape[1:])), self.classes)
        if self.device is not None:
            # Keep the head on the same device as the blocks.
            self.linear.to(self.device)
        # NOTE(review): name mangling stores this as `_SimpleBaseModel__built`; a
        # `self.__built` check written inside MyModelBase would not see it. Confirm.
        self.__built = True

    def forward(self, x):
        """Run every block in order, flatten, and return the class logits."""
        for blk in self.blocks:
            x = blk(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

In [60]:
class SimpleResNet(MyResNet):
    """Growable residual network that starts with one `SimpleResBlock`.

    `addNewBlock` appends a block, and `add_condition` decides when that should happen.
    Presumably both are driven by MyResNet during training; TODO confirm.

    Args:
        classes: number of output logits.
        device: forwarded to `MyResNet`; built layers are placed on ``self.device``.
        copy_block, cache, freeze_block: forwarded to `MyResNet`. This class reads
            ``self.cache`` in `forward` and ``self.freeze_block`` in `addNewBlock`.
            It assumes MyResNet stores them under those names.
    """

    def __init__(self, classes=2, device=None, copy_block=False, cache=False, freeze_block=False, *args, **kwargs):
        # The comma after freeze_block matters. Without it, `freeze_block *args` is parsed
        # as a multiplication that always yields an empty tuple, which disables freezing.
        super().__init__(*args, device=device, copy_block=copy_block, cache=cache, freeze_block=freeze_block, **kwargs)
        self.classes = classes

    def build(self, input_shape):
        """Create the first block and the classifier head, and reset the block counters."""
        self.blocks: nn.ModuleList = nn.ModuleList([SimpleResBlock(input_shape[1], device=self.device)])
        for block in self.blocks:
            block.build(input_shape)
        self.flatten = nn.Flatten()
        # torch.prod only accepts tensors and input_shape is a shape tuple, so use np.prod.
        self.linear = nn.Linear(int(np.prod(input_shape[1:])), self.classes)
        if self.device is not None:
            self.linear.to(self.device)
        # Double-underscore names here mangle to `_SimpleResNet__...`. A write to
        # `self.__blocks_num` inside MyResNet would set `_MyResNet__blocks_num` instead,
        # so these counters must be initialised in this class.
        self.__blocks_num = len(self.blocks)
        self.__frozen_blocks_num = 0
        self.__built = True

    def forward(self, x):
        """Run the active blocks, flatten, and classify.

        When ``self.cache`` is set, the frozen leading blocks are skipped. In that case
        ``x`` is presumably the cached output of the last frozen block; TODO confirm.
        """
        first = self.__frozen_blocks_num if self.cache else 0
        for i in range(first, self.__blocks_num):
            x = self.blocks[i](x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def addNewBlock(self, copy_last=False):
        """Append a block shaped after the current last one.

        Args:
            copy_last: initialise the new block with the last block's weights.
        """
        print("----------")
        print(f"{'Copying' if copy_last else 'Adding'} new block...")
        last_block: SimpleResBlock = self.blocks[-1]
        new_block = SimpleResBlock(device=self.device)
        new_block.build(last_block.output_shape)
        if copy_last:
            # Copy the previous block's parameters.
            new_block.load_state_dict(last_block.state_dict())
        self.blocks.append(new_block)
        new_block.to(self.device)
        self.__blocks_num += 1
        if self.freeze_block:
            self.freeze(last_block)
        self.__frozen_blocks_num += 1
        # NOTE(review): this mangles to `_SimpleResNet__forward_cache`, which this class
        # does not define. A `__forward_cache` defined in MyResNet is stored as
        # `_MyResNet__forward_cache`. Confirm against MyResNet.
        self.__forward_cache(last_block)
        print("Success!")

    def add_condition(self, epoch):
        """Return True every 5th epoch (excluding epoch 0) while fewer than 4 blocks exist."""
        if self.__blocks_num >= 4:
            return False
        return bool(epoch) and epoch % 5 == 0


In [61]:
# Training hyper-parameters shared by both experiments below.
lr = 5e-5
epochs = 20
# Passed as classes, not instances; presumably instantiated inside compile(). TODO confirm.
loss_fn = nn.CrossEntropyLoss
optimizer = torch.optim.Adam
# Per-experiment results, keyed by experiment name.
correct, loss, timing = {}, {}, {}

# Seed torch's global RNG so that weight init and DataLoader shuffling are
# reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED);

In [None]:
# Baseline experiment: fixed-depth stack of residual blocks.
num_classes = len(train_dataset.classes)
base_model = SimpleBaseModel(num_classes, device)
base_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
base_results = base_model.fit(
    train_dataloader,
    epochs=epochs,
    test_dataloader=test_dataloader,
)
correct["Base"], loss["Base"], timing["Base"] = base_results

In [None]:
# Second experiment: the growable SimpleResNet. Before this fix the cell built
# SimpleBaseModel again, so the "Test" entry duplicated the baseline.
res_model = SimpleResNet(len(train_dataset.classes), device)
res_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
correct["Test"], loss["Test"], timing["Test"] = res_model.fit(train_dataloader, epochs=epochs, test_dataloader=test_dataloader)