In [None]:
import torch
import numpy as np
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import time
from tqdm import tqdm
import torch.nn.functional as F
from torch.nn.functional import cross_entropy
import os
from thop import profile
import pandas as pd
import glob
import math
import torch.optim.lr_scheduler as lr_scheduler
import sys

# Compute device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Training hyper-parameters.
model_name = 'basenet'
batch_size, num_epoch = 32, 50
learning_rate = 3e-5
# only_train_fc = True

In [None]:
# Pre-processing: resize every image to 224x224 and convert it to a tensor.
# With mean 0 and std 1 the Normalize step computes (x - 0) / 1, so pixel values
# stay exactly as ToTensor produced them; it is kept as a hook for real statistics.
normalize = transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1])
transform = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    normalize
])

# Training images: one sub-folder per class, reshuffled every epoch.
train_dataset = ImageFolder('/kaggle/input/oral-cancer-dataset/Oral Cancer5/train/', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Held-out images. Evaluation only sums per-sample statistics, so the order is
# irrelevant; shuffling is switched off to keep evaluation deterministic.
test_dataset = ImageFolder('/kaggle/input/oral-cancer-dataset/Oral Cancer5/test/', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print('{0} for train. {1} for val'.format(len(train_dataset), len(test_dataset)))

In [None]:
class Basenet(nn.Module):
    """Small CNN image classifier with one convolutional side branch.

    Six convolution stages are interleaved with 2x2 max-pooling, followed by
    global average pooling and a single linear layer that emits class logits.
    The feature map produced after the second pooling feeds both the main path
    (conv_layer3 -> conv_layer4) and a side branch (inception -> conv_layer_res);
    the two results are summed before conv_layer5.

    The attribute names and their registration order are kept stable on purpose:
    forward() looks the layers up by these names, including on whole-model
    checkpoints restored from disk with ``torch.load``.

    Args:
        num_classes: number of output logits. Defaults to 2.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # Expected input shape: [B, 3, 224, 224].
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # All convolutions use stride 1 with "same" padding, so only the pooling
        # layers change the spatial resolution.
        self.conv_layer1 = nn.Conv2d(3, 16, 5, stride=1, padding=2)
        self.conv_layer2 = nn.Conv2d(16, 32, 3, stride=1, padding=1)
        self.conv_layer3 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv_layer4 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv_layer5 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.conv_layer6 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        # Side branch: lifts the 32-channel map to 128 channels so that it can be
        # added to the output of conv_layer4.
        self.inception = nn.Conv2d(32, 128, 3, stride=1, padding=1)
        self.conv_layer_res = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        self.relu = nn.ReLU(inplace=True)
        self.adptavgpool2d = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(512, num_classes, bias=True)
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: image batch of shape (B, 3, H, W); the notebook feeds 224x224
                images, which the shape comments below assume.

        Returns:
            Tensor of shape (B, num_classes) holding unnormalised class scores.
        """
        x = self.relu(self.conv_layer1(x))  # [B, 16, 224, 224]
        x = self.maxpool(x)  # [B, 16, 112, 112]
        x = self.relu(self.conv_layer2(x))
        shared = self.maxpool(x)  # [B, 32, 56, 56], input to both paths

        # Main path.
        x = self.relu(self.conv_layer3(shared))
        x = self.maxpool(x)  # [B, 64, 28, 28]
        x = self.conv_layer4(x)  # [B, 128, 28, 28]

        # Side branch (no activation between its two convolutions).
        branch = self.inception(shared)  # [B, 128, 56, 56]
        branch = self.maxpool(self.conv_layer_res(branch))  # [B, 128, 28, 28]

        # Merge the two paths, then continue down to a 7x7 map.
        x = self.relu(x + branch)
        x = self.maxpool(x)  # [B, 128, 14, 14]
        x = self.relu(self.conv_layer5(x))  # [B, 256, 14, 14]
        x = self.maxpool(x)
        x = self.relu(self.conv_layer6(x))  # [B, 512, 7, 7]

        # Global average pooling -> [B, 512] -> logits.
        x = self.adptavgpool2d(x)
        x = self.flatten(x)
        x = self.fc(x)
        return x

In [None]:
# Checkpoint locations: the pretrained starting point (read-only input) and the
# working directory where this run stores its own checkpoints and metrics.
pretrained_model_path = '/kaggle/input/oral-cancer-dataset/breakhis/basenet/best/basenet-oral_epoch65.pth'
save_path = '/kaggle/working/check_point'
result_path = os.path.join(save_path, model_name)
# Creates save_path as well when it is missing and is a no-op when both exist.
os.makedirs(result_path, exist_ok=True)

# Resume from this run's newest checkpoint when there is one, otherwise start
# from the pretrained weights. The non-recursive '*.pth' pattern deliberately
# ignores the 'best' sub-directory.
existing_checkpoints = glob.glob(os.path.join(result_path, '*.pth'))
if existing_checkpoints:
    model_path = max(existing_checkpoints, key=os.path.getmtime)
else:
    model_path = pretrained_model_path
# The checkpoints are whole pickled modules, so the model class must already be
# defined and the files must come from a trusted source (unpickling can execute
# arbitrary code). map_location lets a GPU-saved checkpoint load on a CPU host.
# NOTE(review): recent PyTorch releases may need weights_only=False here - confirm
# against the installed version.
model = torch.load(model_path, map_location=device)
# if only_train_fc:
#     for param in model.parameters():
#         param.requires_grad_(False)
#     for param in model.fc.parameters():
#         param.requires_grad_(True)

# Per-epoch metrics log; one row per finished epoch, so its length doubles as
# the index of the next epoch to run.
train_test_data_path = os.path.join(result_path, 'train_test_data.csv')
if os.path.exists(train_test_data_path):
    train_test_data = pd.read_csv(train_test_data_path)
else:
    train_test_data = pd.DataFrame(data=[], columns=['train_acc', 'train_loss',
                                                     'train_lr',
                                                     'test_acc', 'test_loss',
                                                     'epoch', 'test_acc_best'])
last_epoch = train_test_data.shape[0]
# An existing but empty log carries no best score yet; start from 0 in that case.
test_acc_best = train_test_data['test_acc_best'].values[-1] if last_epoch else 0
# Stop early when the configured number of epochs has already been trained.
assert num_epoch > last_epoch, '已达训练次数'

In [None]:
# Move the network onto the selected device before profiling and training.
model = model.to(device)

# Report the parameter count and the compute cost of the model. The profiled
# input is a full batch of `batch_size` images, so the figure covers the whole batch.
# NOTE(review): the factor 2 presumably converts multiply-accumulates to FLOPs - confirm against thop.
profile_input = torch.zeros((batch_size, 3, 224, 224)).to(device)
flops, params = profile(model, inputs=(profile_input,), verbose=False)
del profile_input
gflops_text = ', %.1f GFLOPS' % (flops / 1E9 * 2)
print(f'number of parameter: {params}', gflops_text)

def get_parameter_number(model):
    """Count the parameters of ``model``.

    Args:
        model: any object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        dict with 'Total' (element count over all parameters) and 'Trainable'
        (element count over parameters whose ``requires_grad`` is set).
    """
    total_num = 0
    trainable_num = 0
    for param in model.parameters():
        size = param.numel()
        total_num += size
        if param.requires_grad:
            trainable_num += size
    return {'Total': total_num, 'Trainable': trainable_num}
# Show total vs. trainable parameter counts for the loaded model.
param_counts = get_parameter_number(model)
print(param_counts)

In [None]:
# Adam over every parameter of the model, with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
# Cross-entropy summed (not averaged) over the batch: evaluate() accumulates
# these sums and divides by the number of samples itself.
loss_fn = nn.CrossEntropyLoss(reduction='sum')


def train(epoch, dataloder):
    """Run one optimisation pass over ``dataloder``.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        epoch: epoch index, used only for the progress bar and the log line.
        dataloder: iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The learning rate of the optimizer's first parameter group, read after
        the pass has finished.
    """
    model.train()

    started = time.time()
    progress = tqdm(dataloder, leave=False, desc=f'epoch:{epoch}')
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = loss_fn(logits, targets)
        batch_loss.backward()
        optimizer.step()

    now_lr = optimizer.param_groups[0]["lr"]
    # No learning-rate scheduler is stepped here; the rate stays constant.
    elapsed = round(time.time() - started, 1)
    print(f'epoch[{epoch}] time[{elapsed}]s lr:{now_lr}')
    return now_lr

def evaluate(data_loader, type):
    """Measure accuracy and mean loss of the module-level ``model`` on a dataset.

    Runs in eval mode without gradient tracking and prints a one-line summary.

    Args:
        data_loader: iterable yielding ``(images, labels)`` batches.
        type: label for the dataset (e.g. 'train' or 'test'); only used in the
            printed summary and in error messages.

    Returns:
        Tuple ``(acc, mean_loss)``: accuracy in percent and the loss per sample.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    eval_loss = 0.0
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # Softmax is monotonic, so the arg-max of the raw logits already
            # identifies the predicted class.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            # loss_fn sums over the batch, so this accumulates a dataset-wide sum.
            eval_loss += loss_fn(outputs, labels).item()
    if total == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise ValueError(f'cannot evaluate on an empty {type} set')
    acc = 100 * correct / total
    mean_loss = eval_loss / total
    print(f'Accuracy on {type} set: {round(acc, 2)}%  eval_loss:{mean_loss}')
    return acc, mean_loss


'''
每次运行之后要保存：
1.模型本身
2.模型迭代次数
3.模型学习率
4.历代模型表现情况loss,acc等

'''


if __name__ == '__main__':
    # Directory that holds only the best-scoring checkpoint (by test accuracy).
    weight_path_best = os.path.join(result_path, 'best')
    os.makedirs(weight_path_best, exist_ok=True)

    for epoch in range(last_epoch, num_epoch):
        lr = train(epoch, train_loader)
        train_acc, train_loss = evaluate(train_loader, type='train')
        test_acc, test_loss = evaluate(test_loader, type='test')

        # Write the new checkpoint first and only then drop older ones, so a
        # failed save never leaves the run without a resumable checkpoint.
        # File names use the 1-based epoch number.
        weight_name = '{}-oral_epoch{}.pth'.format(model_name, epoch + 1)
        weight_path = os.path.join(result_path, weight_name)
        torch.save(model, weight_path)
        for stale in glob.glob(os.path.join(result_path, '*.pth')):
            if os.path.abspath(stale) != os.path.abspath(weight_path):
                os.remove(stale)

        # Keep a separate copy whenever the test accuracy improves.
        if test_acc > test_acc_best:
            test_acc_best = test_acc
            best_path = os.path.join(weight_path_best, weight_name)
            torch.save(model, best_path)
            for stale in glob.glob(os.path.join(weight_path_best, '*.pth')):
                if os.path.abspath(stale) != os.path.abspath(best_path):
                    os.remove(stale)

        # Append this epoch's metrics (the 'epoch' column is 0-based) and persist
        # the log right away so an interrupted run can resume from it.
        train_test_data.loc[len(train_test_data.index)] = [train_acc, train_loss,
                                                           lr,
                                                           test_acc, test_loss,
                                                           epoch, test_acc_best]
        train_test_data.to_csv(train_test_data_path, index=False)

    # Learning curves over all logged epochs, including those of earlier runs.
    train_acc_array = train_test_data['train_acc'].values
    test_acc_array = train_test_data['test_acc'].values
    train_loss_array = train_test_data['train_loss'].values
    test_loss_array = train_test_data['test_loss'].values

    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(train_acc_array, label='Training Accuracy')
    plt.plot(test_acc_array, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.title(f'{model_name} Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(train_loss_array, label='Training Loss')
    plt.plot(test_loss_array, label='Validation Loss')
    plt.legend(loc='upper right')
    # The plotted loss is the per-sample cross-entropy returned by evaluate().
    plt.ylabel('Cross Entropy')
    plt.title(f'{model_name} Training and Validation Loss')
    plt.xlabel('epoch')
    plt.show()