In [1]:
import sys
import os

# Directory the kernel was started in (the notebook's working directory).
current_dir = os.getcwd()

# Parent of the working directory; used below as the project root for
# imports and, later in the notebook, for locating data files.
project_dir = os.path.dirname(current_dir)

# Make the project root importable. The membership check keeps this cell
# idempotent: re-running it no longer appends a duplicate sys.path entry.
if project_dir not in sys.path:
    sys.path.append(project_dir)


import numpy as np
import cupy as cp
import trytorch as torch
import trytorch.ops as ops
import trytorch.nn as nn
import trytorch.optim as optim
import trytorch.datas as data
from trytorch.array_device import *



In [2]:
class SimpleResNet(nn.Module):
    """Small convolutional classifier with a single residual stage.

    Layers are applied in this order: ``block1`` -> ``block2`` -> ``res1``
    (a ``Residual`` around two ``ConvBN`` blocks) -> ``block3`` -> ``flatten``
    -> ``linear1`` -> ``relu`` -> ``drop`` -> ``linear2``. The final linear
    layer emits 10 values per example.

    NOTE(review): the positional ``ConvBN`` arguments are presumably
    (in_channels, out_channels, kernel_size, stride) -- confirm against
    ``trytorch.nn``.
    NOTE(review): ``linear1`` expects ``64 * 2 * 2`` flattened features, which
    presumably corresponds to single-channel 32x32 inputs -- confirm.

    Args:
        device: Passed through to every ``ConvBN`` and ``Linear`` layer.
        dtype: Passed through to every ``ConvBN`` and ``Linear`` layer
            (defaults to ``"float32"``).
    """

    def __init__(self, device=None, dtype="float32"):
        super().__init__()
        # Shared keyword arguments for every parameterised layer.
        placement = {"device": device, "dtype": dtype}

        # Convolutional stem.
        self.block1 = nn.ConvBN(1, 16, 7, 4, **placement)
        self.block2 = nn.ConvBN(16, 32, 3, 2, **placement)

        # Residual stage: two 32 -> 32 ConvBN blocks wrapped by Residual.
        res_branch = nn.Sequential(
            nn.ConvBN(32, 32, 3, 1, **placement),
            nn.ConvBN(32, 32, 3, 1, **placement),
        )
        self.res1 = nn.Residual(res_branch)

        self.block3 = nn.ConvBN(32, 64, 3, 2, **placement)

        # Classification head.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(64 * 2 * 2, 256, **placement)
        self.relu = nn.ReLU()
        self.drop = nn.Dropout(0.4)
        self.linear2 = nn.Linear(256, 10, **placement)

    def forward(self, x):
        """Run ``x`` through every layer in order and return the final output."""
        pipeline = (
            self.block1,
            self.block2,
            self.res1,
            self.block3,
            self.flatten,
            self.linear1,
            self.relu,
            self.drop,
            self.linear2,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        return out

In [3]:
# Training hyperparameters.
batch_size = 256
epochs = 20

# Model (built on the CPU device), optimiser and loss function.
net = SimpleResNet(device=cpu())
optimizer = optim.Adam(
    net.parameters(),
    lr=0.001,
    weight_decay=0.001,
)
criterion = nn.SoftmaxLoss()

### 测试模型是否能够正向计算

In [4]:
# # Smoke test (disabled): check that the model runs end to end
# np.random.seed(42)
# x_array = np.random.rand(1,1,32,32)
# x = torch.Tensor(x_array)
# x.to('cpu')
# y = net(x)
# print(y)
# y.backward()

In [5]:
from pathlib import Path
from trytorch.datas import DataLoader
from trytorch.datas.datasets import SVHNDataset
from trytorch.datas.data_transform import *

# Project root as a Path so data files can be located with the `/` operator.
project_path = Path(project_dir)
print(project_path)

# Both SVHN .mat files live in the same folder; build that path once.
svhn_dir = project_path / "data" / "SVHN"

# Training split: RandomCrop is applied as a transform, and batches are
# shuffled each time the loader is iterated.
train_dataset = SVHNDataset(
    file=svhn_dir / "train_data.mat",
    transforms=[RandomCrop()],
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Test split: no transforms. Constructed the same way as the training split
# (same `file=` keyword, same DataLoader name) for consistency.
test_dataset = SVHNDataset(
    file=svhn_dir / "test_data.mat",
)
# Evaluation does not benefit from shuffling; a fixed order keeps per-batch
# results reproducible between runs.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

d:\AIExperienments\TryTorch


In [9]:
# import matplotlib.pyplot as plt

# # Show one sample from the dataset

# img, label = train_dataset[42]

# type(img), type(label)
# # The image is (1, 28, 28), so squeeze to drop the channel dimension
# # NOTE(review): this shape looks carried over from another dataset -- confirm
# plt.imshow(img.squeeze())
# plt.title(f"Label: {label}")
# plt.axis("off")
# plt.show()

# Print the length of each dataset split.
print(*map(len, (train_dataset, test_dataset)))
# Length of each data loader, shown as the cell's result.
tuple(len(loader) for loader in (train_dataloader, test_dataloader))

73257 26032


(287, 102)

In [11]:
from tqdm import tqdm

# Training loop: for each epoch, iterate over the training loader, take one
# optimiser step per batch, and track the running loss / accuracy that is
# shown on the progress bar and printed as an epoch summary.
for epoch in range(epochs):
    # Running totals for this epoch.
    total_loss = 0
    total_rights = 0
    total_examples = 0
    total_batches = 0

    # Switch to training mode once per epoch; nothing inside the batch loop
    # changes the mode, so repeating the call for every batch was redundant.
    net.train()

    # Progress bar over the training batches.
    train_bar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{epochs}', unit='batch')

    for inputs, label in train_bar:
        # Clear gradients, run forward and backward, then update parameters.
        optimizer.reset_grad()
        pred = net(inputs)
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        # Predicted class per example: argmax over axis 1,
        # i.e. (batch, classes) -> (batch,).
        label_pred = np.argmax(pred.numpy(), axis=1)

        # Number of correct predictions in this batch.
        rights = np.equal(label_pred, label.numpy()).sum()

        total_loss += loss.numpy()
        total_rights += rights
        total_batches += 1
        total_examples += inputs.shape[0]

        # Running averages so far: loss per batch, accuracy per example.
        avg_loss_so_far = total_loss / total_batches
        avg_accuracy_so_far = total_rights / total_examples

        # Show the running metrics on the progress bar.
        train_bar.set_postfix({
            'loss': f'{avg_loss_so_far:.4f}',
            'acc': f'{avg_accuracy_so_far:.4f}'
        })

    # Epoch summary (note: `epoch` is printed zero-based here, while the
    # progress bar description above is one-based).
    avg_loss = total_loss / total_batches
    avg_accuracy = total_rights / total_examples
    print(f"EPOCH {epoch}: {avg_accuracy=}, {avg_loss=}")



Epoch 1/20: 100%|██████████| 287/287 [02:01<00:00,  2.36batch/s, loss=1.0641, acc=0.6547]


EPOCH 0: avg_accuracy=0.6547360661779762, avg_loss=1.06411538023901


Epoch 2/20: 100%|██████████| 287/287 [02:00<00:00,  2.37batch/s, loss=0.8502, acc=0.7271]


EPOCH 1: avg_accuracy=0.7271114023233275, avg_loss=0.8501540264003344


Epoch 3/20: 100%|██████████| 287/287 [02:01<00:00,  2.37batch/s, loss=0.7348, acc=0.7679]


EPOCH 2: avg_accuracy=0.7678720122309131, avg_loss=0.7348087468075564


Epoch 4/20: 100%|██████████| 287/287 [02:01<00:00,  2.35batch/s, loss=0.6556, acc=0.7931]


EPOCH 3: avg_accuracy=0.7930982704724463, avg_loss=0.6556249274427153


Epoch 5/20: 100%|██████████| 287/287 [02:02<00:00,  2.34batch/s, loss=0.6028, acc=0.8091]


EPOCH 4: avg_accuracy=0.8091240427535935, avg_loss=0.6028050439750916


Epoch 6/20: 100%|██████████| 287/287 [02:07<00:00,  2.26batch/s, loss=0.5707, acc=0.8220]


EPOCH 5: avg_accuracy=0.8219555810366245, avg_loss=0.5706947777782184


Epoch 7/20: 100%|██████████| 287/287 [02:05<00:00,  2.29batch/s, loss=0.5352, acc=0.8340]


EPOCH 6: avg_accuracy=0.8340090366790887, avg_loss=0.535189157540488


Epoch 8/20: 100%|██████████| 287/287 [02:02<00:00,  2.34batch/s, loss=0.5125, acc=0.8386]


EPOCH 7: avg_accuracy=0.8386229302319232, avg_loss=0.512535637697603


Epoch 9/20: 100%|██████████| 287/287 [02:03<00:00,  2.33batch/s, loss=0.4913, acc=0.8462]


EPOCH 8: avg_accuracy=0.8461716969026851, avg_loss=0.49133031483937073


Epoch 10/20: 100%|██████████| 287/287 [02:03<00:00,  2.33batch/s, loss=0.4655, acc=0.8549]


EPOCH 9: avg_accuracy=0.8549490151111839, avg_loss=0.46551619647763737


Epoch 11/20: 100%|██████████| 287/287 [02:03<00:00,  2.33batch/s, loss=0.4581, acc=0.8562]


EPOCH 10: avg_accuracy=0.856177566648921, avg_loss=0.4581027388841347


Epoch 12/20: 100%|██████████| 287/287 [02:03<00:00,  2.33batch/s, loss=0.4389, acc=0.8640]


EPOCH 11: avg_accuracy=0.8639856941998717, avg_loss=0.4389104466578955


Epoch 13/20: 100%|██████████| 287/287 [02:02<00:00,  2.33batch/s, loss=0.4304, acc=0.8671]


EPOCH 12: avg_accuracy=0.8671116753347803, avg_loss=0.4304345804794977


Epoch 14/20: 100%|██████████| 287/287 [02:02<00:00,  2.33batch/s, loss=0.4212, acc=0.8688]


EPOCH 13: avg_accuracy=0.8688043463423291, avg_loss=0.4212249885916868


Epoch 15/20: 100%|██████████| 287/287 [02:05<00:00,  2.30batch/s, loss=0.4008, acc=0.8771]


EPOCH 14: avg_accuracy=0.8771311956536577, avg_loss=0.4007576113858574


Epoch 16/20: 100%|██████████| 287/287 [02:04<00:00,  2.30batch/s, loss=0.4048, acc=0.8745]


EPOCH 15: avg_accuracy=0.8744829845612023, avg_loss=0.4048198843666092


Epoch 17/20: 100%|██████████| 287/287 [02:05<00:00,  2.28batch/s, loss=0.4022, acc=0.8762]


EPOCH 16: avg_accuracy=0.8762439084319588, avg_loss=0.4022395106575674


Epoch 18/20: 100%|██████████| 287/287 [02:05<00:00,  2.29batch/s, loss=0.3919, acc=0.8794]


EPOCH 17: avg_accuracy=0.8794244918574334, avg_loss=0.39188964081752636


Epoch 19/20: 100%|██████████| 287/287 [02:05<00:00,  2.29batch/s, loss=0.3815, acc=0.8820]


EPOCH 18: avg_accuracy=0.8820181006593226, avg_loss=0.38150427633164435


Epoch 20/20: 100%|██████████| 287/287 [02:05<00:00,  2.29batch/s, loss=0.3815, acc=0.8826]

EPOCH 19: avg_accuracy=0.8826187258555497, avg_loss=0.38148129807163617



