In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda
import torchinfo
import time
from multiprocessing import Process
import matplotlib.pyplot as plt

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU Device:【{}:{}】".format(device.type, device.index))
    torch.cuda.set_device(0)

GPU Device:【cuda:None】


In [3]:
class CustomDataset(Dataset):
    def __init__(self, data:np.ndarray, labels:np.ndarray, transform=ToTensor(), 
    target_transform=Lambda(lambda y: torch.zeros(2, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))):
        self.data:torch.Tensor = torch.from_numpy(data)
        self.labels:torch.Tensor = torch.from_numpy(labels)
        self.transform = None
        self.target_transform = None
        # self.transform = transform
        # self.target_transform = target_transform
        # self.shuffle()
    
    def shuffle(self, seed=None):
        '\n        seed(self, seed=None)\n\n        Reseed a legacy MT19937 BitGenerator\n        '
        self.shuffle_seed = np.random.randint(1, 65535) if seed is None else seed
        print(f"随机种子：{self.shuffle_seed}")
        np.random.seed(self.shuffle_seed)
        np.random.shuffle(self.data)
        np.random.seed(self.shuffle_seed)
        np.random.shuffle(self.labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        data = self.data[idx]
        label = self.labels[idx, 0]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data, label

In [4]:
def load_dataset(path="dataset.npz", train_percent=0.8) -> tuple:
    with np.load(path) as dataset:
        full_data = dataset["data"].astype(np.float32).reshape((-1, 1, 116, 116))
        full_labels = dataset["labels"].astype(np.int64)
    train_size = int(full_data.shape[0]*train_percent)
    test_size = full_data.shape[0]-train_size
    seed = np.random.randint(1, 65535)
    np.random.seed(seed)
    np.random.shuffle(full_data)
    np.random.seed(seed)
    np.random.shuffle(full_labels)
    train_data, test_data = full_data[:train_size], full_data[train_size:]
    train_labels, test_labels = full_labels[:train_size], full_labels[train_size:]
    print(f"训练集大小：{train_size}", f"测试集大小：{test_size}", f"随机种子：{seed}")
    train_dataset = CustomDataset(train_data, train_labels)
    test_dataset = CustomDataset(test_data, test_labels)
    return train_dataset, test_dataset

In [5]:
train_dataset, test_dataset = load_dataset("D:\\datasets\\ABIDE\\ABIDE_FC_dataset.npz", 0.8)
# train_dataset, test_dataset = load_dataset("D:\\datasets\\ABIDE\\ABIDE_FC_augmented_dataset.npz", 0.8)

训练集大小：706 测试集大小：177 随机种子：8832


In [7]:
batch_size=64

# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, C, H, W]: torch.Size([64, 1, 116, 116])
Shape of y: torch.Size([64]) torch.int64


In [8]:
class SimlpeCNN(nn.Module):
    def __init__(self, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.__built = False
        # self.conv = nn.Conv2d(1, 1, 5, padding="same")
        # self.activation = nn.ReLU()
        # self.flatten = nn.Flatten()
        # self.dense = nn.Linear(2)
    
    def build(self, input_shape:tuple):
        # resolve output shape in model summary
        self.conv = nn.Conv2d(input_shape[1], 1, 5, padding="same")
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.liner = nn.Linear(np.prod(input_shape[1:]), 2)
        self.__built = True
    
    def compile(self, dataloader:DataLoader, loss_fn, optimizer):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(self.parameters(), lr=1e-2)
        self.to(self.device)
    
    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch}/{epochs}")
            loss, correct = 0, 0
            for batch, (X, y) in enumerate(dataloader):
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= self.batch_size
                print(f"\r{batch+1}/{num_batches+1}  [{current:>5d}/{size:>5d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}%", end = "")
            loss /= num_batches
            correct /= size
            print(f"\nAverage loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}%")
        torchinfo.summary(self, input_size=self.input_shape)

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

    def forward(self, x:np.ndarray):
        if not self.__built:
            raise RuntimeWarning("模型未完成编译！")
        x = self.conv(x)
        x = self.flatten(x)
        x = self.liner(x)
        return x

In [9]:
model = SimlpeCNN(device)
model.compile(train_dataloader, nn.CrossEntropyLoss, torch.optim.SGD)
model.fit(train_dataloader, epochs=10)

Epoch: 0/10
12/12  [  706/  706] - batch loss: 0.900903 - batch accuracy: 0.0%%
Average loss: 0.770886 - Accuracy: 52.7%
Epoch: 1/10
12/12  [  706/  706] - batch loss: 0.563308 - batch accuracy: 3.1%%
Average loss: 0.736654 - Accuracy: 59.8%
Epoch: 2/10
12/12  [  706/  706] - batch loss: 0.395109 - batch accuracy: 3.1%%
Average loss: 0.701510 - Accuracy: 62.7%
Epoch: 3/10
12/12  [  706/  706] - batch loss: 0.291984 - batch accuracy: 3.1%%
Average loss: 0.677356 - Accuracy: 65.9%
Epoch: 4/10
12/12  [  706/  706] - batch loss: 0.223830 - batch accuracy: 3.1%%
Average loss: 0.657725 - Accuracy: 68.0%
Epoch: 5/10
12/12  [  706/  706] - batch loss: 0.176781 - batch accuracy: 3.1%%
Average loss: 0.640561 - Accuracy: 69.4%
Epoch: 6/10
12/12  [  706/  706] - batch loss: 0.143226 - batch accuracy: 3.1%%
Average loss: 0.624871 - Accuracy: 70.3%
Epoch: 7/10
12/12  [  706/  706] - batch loss: 0.118615 - batch accuracy: 3.1%%
Average loss: 0.610100 - Accuracy: 71.8%
Epoch: 8/10
12/12  [  706/  706]

In [10]:
model.test(test_dataloader)

Test Error: 
 Accuracy: 64.4%, Avg loss: 0.626033 



In [11]:
class ResBlock(nn.Module):

    def __init__(self, device=None, *args, **wargs):
        super().__init__(*args, **wargs)
        self.__built = False
        self.device = device

    def build(self, input_shape):
        self.input_shape = input_shape
        # resolve output shape in model summary
        self.conv = nn.Conv2d(input_shape[1], 1, 5, padding="same")
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm2d(input_shape[1])
        self.downconv = nn.Conv2d(input_shape[1], 1, 5, padding="same")
        self.downbn = nn.BatchNorm2d(input_shape[1])
        self.output_shape = input_shape
        self.__built = True
        self.to(self.device)

    def forward(self, inputs:np.ndarray):
        x:np.ndarray = inputs
        fx:np.ndarray = x
        fx = self.conv(fx)
        fx = self.bn(fx)
        if fx.shape[-1] != x.shape[-1]:
            x = self.downconv(x)
            x = self.downbn(x)
        try:
            # print(self.name, x.shape, fx.shape, inputs.shape)
            return fx + x
        except:
            raise RuntimeError(x.shape, fx.shape, inputs.shape)
    
    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

In [12]:
class SimpleResNet(nn.Module):
    def __init__(self, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__built = False
        self.device = device
    
    def build(self, input_shape):
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks:nn.ModuleList = nn.ModuleList([ResBlock(), ResBlock(), ResBlock()])
        for block in self.blocks:
            block.build(input_shape)
            block.to(self.device)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(np.prod(input_shape[1:]), 2)
        self.__built = True

    def compile(self, dataloader:DataLoader, loss_fn, optimizer, lr=1e-2):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(self.parameters(), lr=lr)
        self.to(self.device)
    
    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            loss, correct = 0, 0
            time_start = time.time()
            for batch, (X, y) in enumerate(dataloader):
                # time_batch_start = time.time()
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= len(X)
                # time_batch_end = time.time()
                time_end = time.time()
                print(f"\r{batch+1}/{num_batches+1}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {(time_end-time_start)/(batch+1)*1000:>0.3f}ms/batch", end = "", flush=True)
            loss /= num_batches
            correct /= size
            time_end = time.time()
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {(time_end-time_start)*1000:>0.3f}ms\n")
        print(torchinfo.summary(self, input_size=self.input_shape))

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [13]:
class MyResNet(nn.Module):
    def __init__(self, device=None, copy_block=False, cache=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num:int = 1
        self.__frozen_blocks_num:int = 0
        self.__built:bool = False
        self.cache:bool = cache
        self.__cache:list(torch.TensorType) = []
        self.device:torch.DeviceObjType = device
        self.copy_block:bool = copy_block
    
    def build(self, input_shape):
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks:nn.ModuleList = nn.ModuleList([ResBlock()])
        for block in self.blocks:
            block.build(input_shape)
            block.to(self.device)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(np.prod(input_shape[1:]), 2)
        self.__built = True

    def compile(self, dataloader:DataLoader, loss_fn, optimizer, lr=1e-2):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        self.to(self.device)
    
    def forward(self, x):
        if self.cache:
            for i in range(self.__frozen_blocks_num, self.__blocks_num):
                x = self.blocks[i](x)
        else:
            for i in range(self.__blocks_num):
                x = self.blocks[i](x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def getBlocksNum(self):
        return self.__blocks_num
    
    def freeze(self, block:ResBlock):
        for param in block.parameters():
            param.requires_grad = False
    
    def __forward_cache(self, block):
        if self.cache:
            for batch, X in enumerate(self.__cache):
                self.__cache[batch] = block(X.to(self.device))

    def addNewBlock(self):
        print("----------")
        print("Adding new block...")
        newBlock = ResBlock()
        last_block:ResBlock = self.blocks[-1]
        newBlock.build(last_block.output_shape)
        self.blocks.append(newBlock)
        newBlock.to(self.device)
        self.__blocks_num += 1
        self.freeze(last_block)
        self.__frozen_blocks_num += 1
        self.__forward_cache(last_block)
        print("Success!")

    def copyLastBlock(self):
        print("----------")
        print("Copying last block...")
        newBlock = ResBlock()
        last_block:ResBlock = self.blocks[-1]
        newBlock.build(last_block.output_shape)
        if last_block.input_shape == last_block.output_shape:
            newBlock.load_state_dict(last_block.state_dict())
        else:
            print("Copy failed: shape different with last block")
        self.blocks.append(newBlock)
        newBlock.to(self.device)
        self.__blocks_num += 1
        self.freeze(last_block)
        self.__frozen_blocks_num += 1
        self.__forward_cache(last_block)
        print("Success!")

    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            if epoch and epoch%3 == 0:
                if self.copy_block:
                    self.copyLastBlock()
                else:
                    self.addNewBlock()
            loss, correct = 0, 0
            torch.cuda.synchronize()
            time_start = time.time()
            for batch, (X, y) in enumerate(dataloader):
                # time_batch_start = time.time()
                if self.cache:
                    if epoch == 0:
                        self.__cache.append(X)
                    X = self.__cache[batch].to(device)
                else:
                    X = X.to(device)
                y = y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= len(X)
                # time_batch_end = time.time()
                torch.cuda.synchronize()
                time_end = time.time()
                print(f"\r{batch+1}/{num_batches+1}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {(time_end-time_start)/(batch+1)*1000:>0.3f}ms/batch", end = "", flush=True)
            loss /= num_batches
            correct /= size
            torch.cuda.synchronize()
            time_end = time.time()
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {(time_end-time_start)*1000:>0.3f}ms\n")
        print(torchinfo.summary(self, input_size=self.input_shape))

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [14]:
lr = 1e-2
epochs = 9
loss_fn = nn.CrossEntropyLoss
optimizer = torch.optim.Adam

In [15]:
simple_model = SimpleResNet(device)
simple_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
simple_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
12/12  [706/706] - batch loss: 19.599588 - batch accuracy: 50.0% - 54.592ms/batchh
-- Average loss: 28.612175 - Accuracy: 53.1% - 656.102ms

Epoch: 2/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 45.823ms/batch
-- Average loss: 4.673749 - Accuracy: 68.1% - 550.872ms

Epoch: 3/9
12/12  [706/706] - batch loss: 0.000027 - batch accuracy: 100.0% - 45.714ms/batch
-- Average loss: 5.003527 - Accuracy: 63.7% - 549.563ms

Epoch: 4/9
12/12  [706/706] - batch loss: 0.000172 - batch accuracy: 100.0% - 46.030ms/batch
-- Average loss: 0.941706 - Accuracy: 83.6% - 553.358ms

Epoch: 5/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 45.743ms/batch
-- Average loss: 0.248906 - Accuracy: 93.3% - 548.914ms

Epoch: 6/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 45.654ms/batch
-- Average loss: 0.403793 - Accuracy: 87.5% - 548.852ms

Epoch: 7/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 45.540ms/batch
-- Av

In [16]:
model = MyResNet(device)
model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
12/12  [706/706] - batch loss: 29.952320 - batch accuracy: 50.0% - 19.300ms/batch
-- Average loss: 25.903312 - Accuracy: 51.3% - 232.600ms

Epoch: 2/9
12/12  [706/706] - batch loss: 0.000265 - batch accuracy: 100.0% - 19.197ms/batch
-- Average loss: 7.807234 - Accuracy: 65.6% - 231.365ms

Epoch: 3/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 19.381ms/batch
-- Average loss: 2.280597 - Accuracy: 76.9% - 233.571ms

Epoch: 4/9
----------
Adding new block...
Success!
12/12  [706/706] - batch loss: 0.000003 - batch accuracy: 100.0% - 20.457ms/batch
-- Average loss: 2.137368 - Accuracy: 80.9% - 246.489ms

Epoch: 5/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 20.415ms/batch
-- Average loss: 4.956413 - Accuracy: 71.8% - 245.982ms

Epoch: 6/9
12/12  [706/706] - batch loss: 4.860724 - batch accuracy: 50.0% - 20.129ms/batch
-- Average loss: 8.302202 - Accuracy: 71.1% - 242.553ms

Epoch: 7/9
----------
Adding new block...
Success!
12/12  [7

In [17]:
cache_model = MyResNet(device, cache=True)
cache_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
cache_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
12/12  [706/706] - batch loss: 8.489141 - batch accuracy: 0.0% - 19.022ms/batchh
-- Average loss: 11.886777 - Accuracy: 54.0% - 229.271ms

Epoch: 2/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.831ms/batch
-- Average loss: 6.986750 - Accuracy: 59.6% - 226.974ms

Epoch: 3/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.872ms/batch
-- Average loss: 5.222652 - Accuracy: 65.7% - 227.458ms

Epoch: 4/9
----------
Adding new block...
Success!
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 17.877ms/batch
-- Average loss: 3.838403 - Accuracy: 72.9% - 215.519ms

Epoch: 5/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 17.803ms/batch
-- Average loss: 3.381629 - Accuracy: 76.9% - 213.635ms

Epoch: 6/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.171ms/batch
-- Average loss: 3.303361 - Accuracy: 79.0% - 219.055ms

Epoch: 7/9
----------
Adding new block...
Success!
12/12  [7

In [18]:
copy_nocache_model = MyResNet(device, copy_block=True)
copy_nocache_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
copy_nocache_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
12/12  [706/706] - batch loss: 22.267452 - batch accuracy: 50.0% - 21.167ms/batch
-- Average loss: 18.543943 - Accuracy: 51.0% - 254.998ms

Epoch: 2/9
12/12  [706/706] - batch loss: 1.545585 - batch accuracy: 50.0% - 20.253ms/batch
-- Average loss: 4.117221 - Accuracy: 67.8% - 244.032ms

Epoch: 3/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.719ms/batch
-- Average loss: 2.602998 - Accuracy: 76.2% - 225.630ms

Epoch: 4/9
----------
Copying last block...
Success!
12/12  [706/706] - batch loss: 0.000001 - batch accuracy: 100.0% - 20.031ms/batch
-- Average loss: 2.626776 - Accuracy: 58.4% - 240.376ms

Epoch: 5/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 20.410ms/batch
-- Average loss: 2.205239 - Accuracy: 68.4% - 245.924ms

Epoch: 6/9
12/12  [706/706] - batch loss: 0.000045 - batch accuracy: 100.0% - 20.150ms/batch
-- Average loss: 2.230779 - Accuracy: 72.7% - 241.797ms

Epoch: 7/9
----------
Copying last block...
Success!
12/12

In [19]:
copy_model = MyResNet(device, copy_block=True, cache=True)
copy_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
copy_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
12/12  [706/706] - batch loss: 33.051132 - batch accuracy: 50.0% - 18.894ms/batch
-- Average loss: 22.686199 - Accuracy: 51.8% - 227.731ms

Epoch: 2/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 20.945ms/batch
-- Average loss: 5.781325 - Accuracy: 64.0% - 251.338ms

Epoch: 3/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 19.083ms/batch
-- Average loss: 2.372716 - Accuracy: 75.2% - 229.993ms

Epoch: 4/9
----------
Copying last block...
Success!
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.197ms/batch
-- Average loss: 3.477890 - Accuracy: 72.1% - 219.366ms

Epoch: 5/9
12/12  [706/706] - batch loss: 0.000000 - batch accuracy: 100.0% - 18.617ms/batch
-- Average loss: 1.935719 - Accuracy: 82.2% - 224.392ms

Epoch: 6/9
12/12  [706/706] - batch loss: 0.000001 - batch accuracy: 100.0% - 18.619ms/batch
-- Average loss: 1.298425 - Accuracy: 86.5% - 224.425ms

Epoch: 7/9
----------
Copying last block...
Success!
12/1

In [20]:
simple_model.test(test_dataloader)
model.test(test_dataloader)
cache_model.test(test_dataloader)
copy_nocache_model.test(test_dataloader)
copy_model.test(test_dataloader)

Test Error: 
 Accuracy: 58.2%, Avg loss: 2.896871 

Test Error: 
 Accuracy: 57.6%, Avg loss: 19.668694 

Test Error: 
 Accuracy: 60.5%, Avg loss: 4.995947 

Test Error: 
 Accuracy: 55.4%, Avg loss: 7.237557 

Test Error: 
 Accuracy: 60.5%, Avg loss: 2.154367 



In [21]:
simple_model.test(train_dataloader)
model.test(train_dataloader)
cache_model.test(train_dataloader)
copy_nocache_model.test(train_dataloader)
copy_model.test(train_dataloader)

Test Error: 
 Accuracy: 94.3%, Avg loss: 0.186212 

Test Error: 
 Accuracy: 87.1%, Avg loss: 1.684372 

Test Error: 
 Accuracy: 56.1%, Avg loss: 3.725947 

Test Error: 
 Accuracy: 85.7%, Avg loss: 1.146361 

Test Error: 
 Accuracy: 71.8%, Avg loss: 1.191327 

