In [1]:
import os
import sys
import json

import torch
import torch.nn as nn
from torchvision import transforms, datasets
import tqdm
import math
import numpy as np
import torch
import sys

import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
from torchsummary import summary
from tqdm import tqdm
from mydataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate
from model import VGG, vgg

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.cuda.is_available()
tb_writer = SummaryWriter()

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("using {} device.".format(device))

if os.path.exists("./weights") is False:
    os.makedirs("./weights")
data_path = '/Users/tanx/mtetna/MicSigV1/MicSigV1_jsn/MicSigV1/V1_spec_no_nl'
train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(data_path, 0.3)
# train_images_label = train_images_label - 1
# val_images_label = val_images_label - 1

batch_size = 16
lrm = 0.005
lrf = 0.00001
epochs = 50

data_transform = {
    "train": transforms.Compose([transforms.Resize((128,192)),
                                 transforms.ToTensor()]),
    "val": transforms.Compose([transforms.Resize((128,192)),
                                 transforms.ToTensor()])}

train_dataset = MyDataSet(images_path=train_images_path,
                          images_class=train_images_label,
                          transform=data_transform["train"])

# 实例化验证数据集
val_dataset = MyDataSet(images_path=val_images_path,
                        images_class=val_images_label,
                        transform=data_transform["val"])

using cpu device.
1187 images were found in the dataset.
832 images for training.
355 images for validation.


In [4]:
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 4])  # number of workers
print('Using {} dataloader workers every process'.format(nw))
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           pin_memory=True,
                                           num_workers=nw,
                                           collate_fn=train_dataset.collate_fn)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         pin_memory=True,
                                         num_workers=nw,
                                         collate_fn=val_dataset.collate_fn)

model = vgg(model_name='vgg_small', num_classes=5, init_weights=True, inchannels=1).to(device)

Using 4 dataloader workers every process


In [5]:
summary(model, (1, 128, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 16, 128, 192]             160
              ReLU-2         [-1, 16, 128, 192]               0
         MaxPool2d-3           [-1, 16, 64, 96]               0
            Conv2d-4           [-1, 32, 64, 96]           4,640
              ReLU-5           [-1, 32, 64, 96]               0
         MaxPool2d-6           [-1, 32, 32, 48]               0
            Conv2d-7           [-1, 64, 32, 48]          18,496
              ReLU-8           [-1, 64, 32, 48]               0
            Conv2d-9           [-1, 64, 32, 48]          36,928
             ReLU-10           [-1, 64, 32, 48]               0
        MaxPool2d-11           [-1, 64, 16, 24]               0
           Conv2d-12          [-1, 128, 16, 24]          73,856
             ReLU-13          [-1, 128, 16, 24]               0
           Conv2d-14          [-1, 128,

In [6]:
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(pg, lr=lrm, weight_decay=5E-5)
# Scheduler https://arxiv.org/pdf/1812.01187.pdf
lf = lambda x: ((1 + math.cos(x * math.pi / epochs)) / 2) * (lrm - lrf) + lrf  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

In [7]:
for epoch in range(epochs):
    # train
    train_loss, train_acc = train_one_epoch(model=model,
                                            optimizer=optimizer,
                                            data_loader=train_loader,
                                            device=device,
                                            epoch=epoch)
    scheduler.step()
    # validate
    val_loss, val_acc = evaluate(model=model,
                                 data_loader=val_loader,
                                 device=device,
                                 epoch=epoch)

    tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
    tb_writer.add_scalar(tags[0], train_loss, epoch)
    tb_writer.add_scalar(tags[1], train_acc, epoch)
    tb_writer.add_scalar(tags[2], val_loss, epoch)
    tb_writer.add_scalar(tags[3], val_acc, epoch)
    tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

[train epoch 0] loss: 1.125, acc: 0.846: 100%|██| 52/52 [00:34<00:00,  1.50it/s]
[valid epoch 0] loss: 0.609, acc: 0.882: 100%|██| 23/23 [00:23<00:00,  1.00s/it]
[train epoch 1] loss: 0.509, acc: 0.879: 100%|██| 52/52 [00:34<00:00,  1.50it/s]
[valid epoch 1] loss: 0.518, acc: 0.882: 100%|██| 23/23 [00:23<00:00,  1.00s/it]
[train epoch 2] loss: 0.455, acc: 0.879: 100%|██| 52/52 [00:34<00:00,  1.49it/s]
[valid epoch 2] loss: 0.431, acc: 0.885: 100%|██| 23/23 [00:23<00:00,  1.00s/it]
[train epoch 3] loss: 0.371, acc: 0.888: 100%|██| 52/52 [00:34<00:00,  1.49it/s]
[valid epoch 3] loss: 0.327, acc: 0.915: 100%|██| 23/23 [00:23<00:00,  1.01s/it]
[train epoch 4] loss: 0.317, acc: 0.921: 100%|██| 52/52 [00:34<00:00,  1.49it/s]
[valid epoch 4] loss: 0.290, acc: 0.921: 100%|██| 23/23 [00:23<00:00,  1.02s/it]
[train epoch 5] loss: 0.283, acc: 0.925: 100%|██| 52/52 [00:34<00:00,  1.49it/s]
[valid epoch 5] loss: 0.243, acc: 0.938: 100%|██| 23/23 [00:23<00:00,  1.00s/it]
[train epoch 6] loss: 0.268,