In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda
import torchinfo
import time
from multiprocessing import Process
import matplotlib.pyplot as plt

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU Device:【{}:{}】".format(device.type, device.index))
    torch.cuda.set_device(0)

GPU Device:【cuda:None】


In [3]:
class CustomDataset(Dataset):
    def __init__(self, data:np.ndarray, labels:np.ndarray, transform=ToTensor(), 
    target_transform=Lambda(lambda y: torch.zeros(2, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))):
        self.data:torch.Tensor = torch.from_numpy(data)
        self.labels:torch.Tensor = torch.from_numpy(labels)
        self.transform = None
        self.target_transform = None
        # self.transform = transform
        # self.target_transform = target_transform
        # self.shuffle()
    
    def shuffle(self, seed=None):
        '\n        seed(self, seed=None)\n\n        Reseed a legacy MT19937 BitGenerator\n        '
        self.shuffle_seed = np.random.randint(1, 65535) if seed is None else seed
        print(f"随机种子：{self.shuffle_seed}")
        np.random.seed(self.shuffle_seed)
        np.random.shuffle(self.data)
        np.random.seed(self.shuffle_seed)
        np.random.shuffle(self.labels)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        data = self.data[idx]
        label = self.labels[idx, 0]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data, label

In [4]:
def load_dataset(path="dataset.npz", train_percent=0.8) -> tuple:
    with np.load(path) as dataset:
        full_data = dataset["data"].astype(np.float32)
        full_labels = dataset["labels"].astype(np.int64)
    train_size = int(full_data.shape[0]*train_percent)
    test_size = full_data.shape[0]-train_size
    seed = np.random.randint(1, 65535) # 35468
    np.random.seed(seed)
    np.random.shuffle(full_data)
    np.random.seed(seed)
    np.random.shuffle(full_labels)
    train_data, test_data = full_data[:train_size], full_data[train_size:]
    train_labels, test_labels = full_labels[:train_size], full_labels[train_size:]
    print(f"训练集大小：{train_size}", f"测试集大小：{test_size}", f"随机种子：{seed}")
    train_dataset = CustomDataset(train_data, train_labels)
    test_dataset = CustomDataset(test_data, test_labels)
    return train_dataset, test_dataset

In [5]:
train_dataset, test_dataset = load_dataset("D:\\datasets\\ABIDE\\ABIDE_augmented_dataset.npz", 0.8)

训练集大小：9636 测试集大小：2409 随机种子：35468


In [6]:
batch_size=64

# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

for X, y in test_dataloader:
    print(f"Shape of X [N, L, H]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

Shape of X [N, L, H]: torch.Size([64, 60, 116])
Shape of y: torch.Size([64]) torch.int64


In [7]:
class SimlpeLSTMCNN(nn.Module):
    def __init__(self, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.device = device
        self.__built = False
    
    def build(self, input_shape:tuple):
        # resolve output shape in model summary
        self.lstm = nn.LSTM(116, 116, 2, batch_first=True)
        self.conv = nn.Conv1d(input_shape[1], 1, 3, padding="same")
        self.activation = nn.ReLU()
        self.flatten = nn.Flatten()
        self.liner = nn.LazyLinear(2)
        self.__built = True

    def forward(self, x:np.ndarray):
        if not self.__built:
            raise RuntimeWarning("模型未完成编译！")
        x, (h_n, c_n) = self.lstm(x)
        x = self.conv(x)
        x = self.flatten(x)
        x = self.liner(x)
        return x

    def compile(self, dataloader:DataLoader, loss_fn, optimizer, lr=1e-2):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(self.parameters(), lr=lr)
        self.to(self.device)

    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            loss, correct = 0, 0
            time_start = time.time()
            for batch, (X, y) in enumerate(dataloader):
                # time_batch_start = time.time()
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= len(X)
                # time_batch_end = time.time()
                time_end = time.time()
                print(f"\r{batch+1}/{num_batches+1}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {(time_end-time_start)/num_batches*1000:>0.3f}ms/batch", end = "", flush=True)
            loss /= num_batches
            correct /= size
            time_end = time.time()
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {(time_end-time_start)*1000:>0.3f}ms")
        print("\n", torchinfo.summary(self, input_size=self.input_shape))

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [8]:
class ResBlock(nn.Module):

    def __init__(self, out_channels:int=None, kernel_size=5, device=None, *args, **wargs):
        super().__init__(*args, **wargs)
        self.__built = False
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.device = device

    def build(self, input_shape):
        self.input_shape = input_shape
        in_channels = input_shape[1]
        out_channels = in_channels if self.out_channels is None else in_channels
        # resolve output shape in model summary
        self.conv = nn.Conv1d(in_channels, out_channels, self.kernel_size, padding="same")
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm1d(in_channels)
        self.downconv = nn.Conv1d(in_channels, out_channels, self.kernel_size, padding="same")
        self.downbn = nn.BatchNorm1d(in_channels)
        self.output_shape = input_shape
        self.__built = True
        self.to(self.device)

    def forward(self, inputs:np.ndarray):
        x:np.ndarray = inputs
        fx:np.ndarray = x
        fx = self.conv(fx)
        fx = self.bn(fx)
        if fx.shape[-1] != x.shape[-1]:
            x = self.downconv(x)
            x = self.downbn(x)
        try:
            return fx + x
        except:
            raise RuntimeError(x.shape, fx.shape, inputs.shape)
    
    def freeze(self):
        for param in self.parameters():
            param.requires_grad = False

In [9]:
class SimpleResNet(nn.Module):
    def __init__(self, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__built = False
        self.device = device
    
    def build(self, input_shape):
        self.lstm = nn.LSTM(116, 116, 2, batch_first=True)
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks:nn.ModuleList = nn.ModuleList([ResBlock(), ResBlock(), ResBlock()])
        self.blocks_1:nn.ModuleList = nn.ModuleList([ResBlock(1)])
        for block in self.blocks:
            block.build(input_shape)
            block.to(self.device)
        for block in self.blocks_1:
            block.build(input_shape)
            block.to(self.device)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(np.prod(input_shape[1:]), 2)
        self.__built = True

    def compile(self, dataloader:DataLoader, loss_fn, optimizer, lr=1e-2):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(self.parameters(), lr=lr)
        self.to(self.device)
    
    def forward(self, x):
        x, (h_n, c_n) = self.lstm(x)
        for block in self.blocks:
            x = block(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            loss, correct = 0, 0
            time_start = time.time()
            for batch, (X, y) in enumerate(dataloader):
                # time_batch_start = time.time()
                X, y = X.to(device), y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= len(X)
                # time_batch_end = time.time()
                time_end = time.time()
                print(f"\r{batch+1}/{num_batches+1}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {(time_end-time_start)/num_batches*1000:>0.3f}ms/batch", end = "", flush=True)
            loss /= num_batches
            correct /= size
            time_end = time.time()
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {(time_end-time_start)*1000:>0.3f}ms")
        print("\n", torchinfo.summary(self, input_size=self.input_shape))

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [10]:
class MyResNet(nn.Module):
    def __init__(self, device=None, copy_block=False, cache=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__blocks_num:int = 1
        self.__frozen_blocks_num:int = 0
        self.__built:bool = False
        self.cache:bool = cache
        self.__cache:list(torch.TensorType) = []
        self.device:torch.DeviceObjType = device
        self.copy_block:bool = copy_block
    
    def build(self, input_shape):
        self.lstm = nn.LSTM(116, 116, 2, batch_first=True)
        # An ordinary ResNet, but put blocks in a list. New blocks will be added into this list when training.
        # 常规的残差网络，但将残差块放在一个list中，训练时会将新块添加到这里
        self.blocks:nn.ModuleList = nn.ModuleList([ResBlock()])
        self.blocks1:nn.ModuleList = nn.ModuleList([ResBlock(1)])
        for block in self.blocks:
            block.build(input_shape)
            block.to(self.device)
        for block in self.blocks1:
            block.build(input_shape)
            block.to(self.device)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(np.prod(input_shape[1:]), 2)
        self.__built = True

    def compile(self, dataloader:DataLoader, loss_fn, optimizer, lr=1e-2):
        self.batch_size:int = dataloader.batch_size
        for X, y in dataloader:
            self.input_shape:tuple = X.shape
            self.output_shape:tuple = y.shape
            break
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        self.to(self.device)
    
    def forward(self, x):
        x, (h_n, c_n) = self.lstm(x)
        if self.cache:
            for i in range(self.__frozen_blocks_num, self.__blocks_num):
                x = self.blocks[i](x)
        else:
            for i in range(self.__blocks_num):
                x = self.blocks[i](x)
        x = self.flatten(x)
        x = self.linear(x)
        return x

    def getBlocksNum(self):
        return self.__blocks_num
    
    def freeze(self, block:ResBlock):
        for param in block.parameters():
            param.requires_grad = False
    
    def __forward_cache(self, block):
        if self.cache:
            for batch, X in enumerate(self.__cache):
                self.__cache[batch] = block(X.to(self.device))

    def addNewBlock(self):
        print("----------")
        print("Adding new block...")
        newBlock = ResBlock()
        last_block:ResBlock = self.blocks[-1]
        newBlock.build(last_block.output_shape)
        self.blocks.append(newBlock)
        newBlock.to(self.device)
        self.__blocks_num += 1
        self.freeze(last_block)
        self.__frozen_blocks_num += 1
        self.__forward_cache(last_block)
        print("Success!")

    def copyLastBlock(self):
        print("----------")
        print("Copying last block...")
        newBlock = ResBlock()
        last_block:ResBlock = self.blocks[-1]
        newBlock.build(last_block.output_shape)
        if last_block.input_shape == last_block.output_shape:
            newBlock.load_state_dict(last_block.state_dict())
        else:
            print("Copy failed: shape different with last block")
        self.blocks.append(newBlock)
        newBlock.to(self.device)
        self.__blocks_num += 1
        self.freeze(last_block)
        self.__frozen_blocks_num += 1
        self.__forward_cache(last_block)
        print("Success!")

    def fit(self, dataloader:DataLoader, epochs:int=1):
        size = len(dataloader.dataset)
        num_batches = size // self.batch_size
        self.train()
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            if epoch and epoch%3 == 0:
                if self.copy_block:
                    self.copyLastBlock()
                else:
                    self.addNewBlock()
            loss, correct = 0, 0
            torch.cuda.synchronize()
            time_start = time.time()
            for batch, (X, y) in enumerate(dataloader):
                # time_batch_start = time.time()
                if self.cache:
                    if epoch == 0:
                        self.__cache.append(X)
                    X = self.__cache[batch].to(device)
                else:
                    X = X.to(device)
                y = y.to(device)

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                current = batch * self.batch_size + len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss
                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct

                batch_correct /= len(X)
                # time_batch_end = time.time()
                torch.cuda.synchronize()
                time_end = time.time()
                print(f"\r{batch+1}/{num_batches+1}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {(time_end-time_start)/num_batches*1000:>0.3f}ms/batch", end = "", flush=True)
            loss /= num_batches
            correct /= size
            torch.cuda.synchronize()
            time_end = time.time()
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {(time_end-time_start)*1000:>0.3f}ms")
        print("\n", torchinfo.summary(self, input_size=self.input_shape))

    def test(self, dataloader:DataLoader):
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                X, y = X.to(device), y.to(device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
        correct /= size
        print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [11]:
lr = 1e-2
epochs = 9
loss_fn = nn.CrossEntropyLoss
optimizer = torch.optim.Adam

In [12]:
simple_model = SimpleResNet(device)
simple_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
simple_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
151/151  [9636/9636] - batch loss: 0.632798 - batch accuracy: 66.7% - 35.031ms/batch
-- Average loss: 2.837251 - Accuracy: 52.4% - 5255.603ms
Epoch: 2/9
151/151  [9636/9636] - batch loss: 0.412600 - batch accuracy: 75.0% - 18.289ms/batch
-- Average loss: 0.506696 - Accuracy: 73.7% - 2744.315ms
Epoch: 3/9
151/151  [9636/9636] - batch loss: 0.408411 - batch accuracy: 86.1% - 18.336ms/batch
-- Average loss: 0.307631 - Accuracy: 86.2% - 2751.330ms
Epoch: 4/9
151/151  [9636/9636] - batch loss: 0.086233 - batch accuracy: 94.4% - 18.380ms/batch
-- Average loss: 0.202394 - Accuracy: 92.0% - 2757.979ms
Epoch: 5/9
151/151  [9636/9636] - batch loss: 0.265240 - batch accuracy: 88.9% - 18.207ms/batch
-- Average loss: 0.161910 - Accuracy: 93.9% - 2732.985ms
Epoch: 6/9
151/151  [9636/9636] - batch loss: 0.049649 - batch accuracy: 97.2% - 18.327ms/batchh
-- Average loss: 0.138168 - Accuracy: 94.7% - 2750.000ms
Epoch: 7/9
151/151  [9636/9636] - batch loss: 0.002106 - batch accuracy: 100.0% -

In [13]:
model = MyResNet(device)
model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
151/151  [9636/9636] - batch loss: 1.376981 - batch accuracy: 47.2% - 15.242ms/batch
-- Average loss: 2.090946 - Accuracy: 57.4% - 2287.260ms
Epoch: 2/9
151/151  [9636/9636] - batch loss: 0.616090 - batch accuracy: 83.3% - 15.801ms/batch
-- Average loss: 0.603029 - Accuracy: 77.4% - 2371.142ms
Epoch: 3/9
151/151  [9636/9636] - batch loss: 0.692257 - batch accuracy: 75.0% - 16.062ms/batch
-- Average loss: 0.326872 - Accuracy: 86.3% - 2410.265ms
Epoch: 4/9
----------
Adding new block...
Success!
151/151  [9636/9636] - batch loss: 0.740456 - batch accuracy: 80.6% - 17.104ms/batch
-- Average loss: 1.951129 - Accuracy: 75.2% - 2566.612ms
Epoch: 5/9
151/151  [9636/9636] - batch loss: 0.450994 - batch accuracy: 83.3% - 17.027ms/batch
-- Average loss: 0.918521 - Accuracy: 82.0% - 2556.100ms
Epoch: 6/9
151/151  [9636/9636] - batch loss: 1.029123 - batch accuracy: 77.8% - 16.755ms/batch
-- Average loss: 0.935290 - Accuracy: 82.1% - 2513.190ms
Epoch: 7/9
----------
Adding new block...


In [15]:
cache_model = MyResNet(device, cache=True)
cache_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
cache_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
151/151  [9636/9636] - batch loss: 0.569411 - batch accuracy: 69.4% - 15.746ms/batch
-- Average loss: 2.299541 - Accuracy: 54.7% - 2362.936ms
Epoch: 2/9
151/151  [9636/9636] - batch loss: 0.835074 - batch accuracy: 72.2% - 16.293ms/batch
-- Average loss: 0.481192 - Accuracy: 78.4% - 2444.907ms
Epoch: 3/9
151/151  [9636/9636] - batch loss: 0.584027 - batch accuracy: 86.1% - 16.102ms/batch
-- Average loss: 0.378450 - Accuracy: 85.4% - 2416.286ms
Epoch: 4/9
----------
Adding new block...
Success!
151/151  [9636/9636] - batch loss: 2.698141 - batch accuracy: 58.3% - 15.459ms/batchh
-- Average loss: 5.281126 - Accuracy: 60.9% - 2319.817ms
Epoch: 5/9
151/151  [9636/9636] - batch loss: 0.597353 - batch accuracy: 83.3% - 15.569ms/batch
-- Average loss: 1.901757 - Accuracy: 72.3% - 2336.408ms
Epoch: 6/9
151/151  [9636/9636] - batch loss: 3.392251 - batch accuracy: 58.3% - 15.391ms/batch
-- Average loss: 1.094381 - Accuracy: 81.1% - 2309.643ms
Epoch: 7/9
----------
Adding new block...

In [19]:
copy_nocache_model = MyResNet(device, copy_block=True, cache=False)
copy_nocache_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
copy_nocache_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
151/151  [9636/9636] - batch loss: 0.654439 - batch accuracy: 63.9% - 16.772ms/batch
-- Average loss: 3.038122 - Accuracy: 52.2% - 2516.848ms
Epoch: 2/9
151/151  [9636/9636] - batch loss: 0.547437 - batch accuracy: 72.2% - 16.535ms/batch
-- Average loss: 0.739187 - Accuracy: 56.0% - 2480.309ms
Epoch: 3/9
151/151  [9636/9636] - batch loss: 0.572166 - batch accuracy: 72.2% - 16.055ms/batch
-- Average loss: 0.504303 - Accuracy: 74.6% - 2409.187ms
Epoch: 4/9
----------
Copying last block...
Success!
151/151  [9636/9636] - batch loss: 0.407271 - batch accuracy: 77.8% - 16.889ms/batch
-- Average loss: 0.930404 - Accuracy: 72.1% - 2534.419ms
Epoch: 5/9
151/151  [9636/9636] - batch loss: 1.949872 - batch accuracy: 55.6% - 17.007ms/batch
-- Average loss: 0.667843 - Accuracy: 78.2% - 2552.048ms
Epoch: 6/9
151/151  [9636/9636] - batch loss: 0.276107 - batch accuracy: 86.1% - 17.152ms/batch
-- Average loss: 0.629826 - Accuracy: 81.0% - 2574.816ms
Epoch: 7/9
----------
Copying last block

In [None]:
copy_model = MyResNet(device, copy_block=True, cache=True)
copy_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
copy_model.fit(train_dataloader, epochs=epochs)

Epoch: 1/9
151/151  [9636/9636] - batch loss: 0.784736 - batch accuracy: 55.6% - 16.006ms/batch
-- Average loss: 2.576262 - Accuracy: 55.1% - 2402.967ms
Epoch: 2/9
151/151  [9636/9636] - batch loss: 0.493238 - batch accuracy: 72.2% - 16.317ms/batch
-- Average loss: 0.598163 - Accuracy: 76.5% - 2448.535ms
Epoch: 3/9
151/151  [9636/9636] - batch loss: 0.340224 - batch accuracy: 88.9% - 16.132ms/batch
-- Average loss: 0.337614 - Accuracy: 85.7% - 2420.732ms
Epoch: 4/9
----------
Copying last block...
Success!
151/151  [9636/9636] - batch loss: 0.455099 - batch accuracy: 88.9% - 15.658ms/batchh
-- Average loss: 0.221921 - Accuracy: 91.3% - 2348.665ms
Epoch: 5/9
151/151  [9636/9636] - batch loss: 0.015366 - batch accuracy: 100.0% - 15.640ms/batch
-- Average loss: 0.150251 - Accuracy: 94.3% - 2347.071ms
Epoch: 6/9
151/151  [9636/9636] - batch loss: 0.280999 - batch accuracy: 91.7% - 15.573ms/batchh
-- Average loss: 0.119422 - Accuracy: 95.6% - 2336.007ms
Epoch: 7/9
----------
Copying last bl

In [20]:
simple_model.test(test_dataloader)
model.test(test_dataloader)
cache_model.test(test_dataloader)
copy_nocache_model.test(test_dataloader)
copy_model.test(test_dataloader)

Test Error: 
 Accuracy: 93.6%, Avg loss: 0.181359 

Test Error: 
 Accuracy: 81.5%, Avg loss: 2.275847 

Test Error: 
 Accuracy: 82.7%, Avg loss: 1.053976 

Test Error: 
 Accuracy: 86.5%, Avg loss: 0.333633 

Test Error: 
 Accuracy: 94.3%, Avg loss: 0.199079 



In [21]:
simple_model.test(train_dataloader)
model.test(train_dataloader)
cache_model.test(train_dataloader)
copy_nocache_model.test(train_dataloader)
copy_model.test(train_dataloader)

Test Error: 
 Accuracy: 97.1%, Avg loss: 0.069871 

Test Error: 
 Accuracy: 84.5%, Avg loss: 1.396245 

Test Error: 
 Accuracy: 87.5%, Avg loss: 0.494309 

Test Error: 
 Accuracy: 90.9%, Avg loss: 0.191386 

Test Error: 
 Accuracy: 97.4%, Avg loss: 0.078086 

