In [1]:
import numpy as np
from scipy.ndimage import *
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.hub
import torch.utils.data as Data

from torchvision.datasets import *
import torchvision
import torchvision.transforms as transforms

# from torchsummary import summary
# import tensorboardX as tbx
# from tensorboardX import SummaryWriter

import random
import os
import time
import copy
import datetime
from PIL import *
import cv2
from cv2 import *
from collections import *

import argparse

from utility.output import *
from utility.metrics import computeMetrics
from utility.network_process import net_freeze_layer
from utility.plot import plotResultCurve
# from utility.edataset import *

In [3]:
# p = argparse.ArgumentParser()

# p.add_argument("--EPOCH", type=int, default=20)
# p.add_argument("--BATCH_SIZE", type=int, default=6)
# p.add_argument("--CV", type=int, default=5)
# p.add_argument("--NUM_CLASS", type=int, default=4)
# p.add_argument("--SCHEDULE_EPOCH", type=int, default=5)
# p.add_argument("--SCHEDULE_REGRESS", type=int, default=0.2)
# p.add_argument("--PARTIAL_TRAIN", action="store_true", default=False)
# p.add_argument("--PARTIAL_TRAIN_RATIO", type=float, default=0.003)
# p.add_argument("--NET_FREEZE", action="store_true", default=False)
# p.add_argument("--train_ratio", type=float, default=0.7)
# p.add_argument("--init_lr", type=float, default=1e-3)

# args = p.parse_args()

# EPOCH = args.EPOCH
# BATCH_SIZE = args.BATCH_SIZE
# NUM_CLASS = args.NUM_CLASS
# CV = args.CV

# SCHEDULE_EPOCH = args.SCHEDULE_EPOCH
# SCHEDULE_REGRESS = args.SCHEDULE_REGRESS

# ### Partial training
# _PARTIAL_TRAIN = args.PARTIAL_TRAIN
# _PARTIAL_TRAIN_RATIO = args.PARTIAL_TRAIN_RATIO

# ### Freeze network layers
# _NET_FREEZE = args.NET_FREEZE
# _NET_NO_GRAD = []

# P_lr = args.init_lr
# train_ratio = args.train_ratio 

In [2]:
#### Directory where model checkpoints are written and looked up
MODEL_SAVE_PATH = './model'
#### Root directory of the image data (passed to ImageFolder in later cells)
#### NOTE(review): absolute, machine-specific Windows path — adjust before running elsewhere
DATA_PATH = r'E:\buffer\dataset\train'

# Define the SummaryWriter
# tensorboard --logdir=D:\IDE\MyProject\python\jupyter_notebook\Research\git-metallic\metallic\log_res
# writer = SummaryWriter(log_dir='./log_res',comment='resnet18')   # data is stored in this folder
# writer.export_scalars_to_json("./log_res/all_scalars.json")
# writer.close()

In [3]:
def getCurrentTime():
    """Return the current local time formatted as ``YYYY-mm-dd-HH-MM-SS``.

    Only digits and dashes are produced; elsewhere in this file the result
    is used as the name of a time-stamped output directory.
    """
    now = datetime.datetime.fromtimestamp(time.time())
    return now.strftime('%Y-%m-%d-%H-%M-%S')

# def RemoveImgInFolder(path_dir, ratio) :
#     # path_dir: folder path; ratio: fraction of files to delete
#     filelist = os.listdir(path_dir)
#     rmfilelist = random.sample(filelist, int(ratio*len(filelist)))
#     filepaths = [ os.path.join(path_dir, f) for f in rmfilelist ]
#     _ = [os.remove(filepath) if os.path.isfile(filepath) for filepath in filepaths]
#     return

def ImgSummary(path_dirs):
    """Count the entries of each directory in ``path_dirs``.

    Args:
        path_dirs: iterable of directory paths. It is consumed exactly once,
            so one-shot iterators (generators, ``iter(...)``) are fine.

    Returns:
        list[int]: number of entries reported by ``os.listdir`` for each
        directory, in input order. Sub-directories count as entries too.

    Raises:
        OSError: propagated from ``os.listdir`` when a path cannot be listed.
    """
    return [len(os.listdir(folder)) for folder in path_dirs]

#### Save the model
def checkpoint(model, optimizer, epoch, useTimeDir=False):
    """Write a training checkpoint to disk.

    The checkpoint is a dict with the keys ``'net'`` (model state dict),
    ``'optimizer'`` (optimizer state dict) and ``'epoch'``. The file is
    named ``<ModelClassName>_model.pth``, where the class name is the text
    before the first ``(`` in ``str(model)``.

    Args:
        model: network whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: epoch counter stored alongside the weights.
        useTimeDir: when truthy, save into a fresh time-stamped
            sub-directory of ``MODEL_SAVE_PATH``; otherwise save directly
            into ``MODEL_SAVE_PATH`` (overwriting the previous file).

    Returns:
        The time-stamped directory path when ``useTimeDir`` is truthy,
        otherwise ``None``.
    """
    state = {'net':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
    model_name = str(model).split('(')[0]
    # The same truthiness test is used for choosing the directory and for the
    # return value, so the returned path is always where the file was written.
    if useTimeDir:
        savePath = os.path.join(MODEL_SAVE_PATH, getCurrentTime())
    else:
        savePath = MODEL_SAVE_PATH
    # Create the target (and any missing parents) so the first save on a
    # clean checkout does not fail; harmless when it already exists.
    os.makedirs(savePath, exist_ok=True)
    save_file = os.path.join(savePath, model_name+'_model.pth')
    torch.save(state, save_file)
    return savePath if useTimeDir else None

#### Restore the model
def modelrestore(model):
    """Load the weights saved by ``checkpoint`` into ``model``.

    Reads ``<ModelClassName>_model.pth`` from ``MODEL_SAVE_PATH``, where the
    class name is the text before the first ``(`` in ``str(model)``. Only
    the ``'net'`` and ``'epoch'`` entries are used; the stored optimizer
    state is not restored here.

    Args:
        model: network instance to load the weights into (modified in place).

    Returns:
        tuple: ``(model, epoch)`` where ``epoch`` is the stored epoch plus one.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g. when the
        checkpoint file is missing or does not match the model.
    """
    model_name = str(model).split('(')[0]
    ckpt_file = os.path.join(MODEL_SAVE_PATH,model_name+'_model.pth')
    # map_location='cpu' lets a checkpoint written on a GPU machine be read on
    # a CPU-only one; load_state_dict copies the values onto the model's device.
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    saved_state = torch.load(ckpt_file, map_location='cpu')
    model.load_state_dict(saved_state['net'])
    # NOTE(review): the training cell in this file stores a counter that was
    # already incremented for the running epoch, so this "+ 1" may skip an
    # epoch on resume — confirm the intended convention.
    epoch = saved_state['epoch'] + 1
    return model, epoch


def saveParameters(root_path):
    """Dump the notebook's hyper-parameters to ``<root_path>/params.txt`` as JSON.

    The values are read from the module-level configuration names
    (``EPOCH``, ``BATCH_SIZE``, ... ``train_ratio``) at call time.

    Args:
        root_path: existing directory in which ``params.txt`` is written
            (an existing file is overwritten).

    Returns:
        None
    """
    # Imported locally: nothing else in this file imports json explicitly.
    import json

    params = {
        'EPOCH': EPOCH,
        'BATCH_SIZE': BATCH_SIZE,
        'NUM_CLASS': NUM_CLASS,
        'CV': CV,
        'SCHEDULE_EPOCH': SCHEDULE_EPOCH,
        'SCHEDULE_REGRESS': SCHEDULE_REGRESS,
        '_PARTIAL_TRAIN': _PARTIAL_TRAIN,
        '_PARTIAL_TRAIN_RATIO': _PARTIAL_TRAIN_RATIO,
        '_NET_FREEZE': _NET_FREEZE,
        '_NET_NO_GRAD':  _NET_NO_GRAD,
        'P_lr': P_lr,
        'train_ratio': train_ratio,
    }
    # Write-only access is all that is needed here.
    with open(os.path.join(root_path,'params.txt'),'w') as f:
        json.dump(params,f)
    print('parameters stored')

In [4]:
# Upper bound of the epoch loop in the training cell
EPOCH = 20
# Batch size used for every DataLoader in this notebook
BATCH_SIZE =20
# Number of output classes requested from the model
NUM_CLASS = 4
# Number of folds (repetitions of the split/train/evaluate loop)
CV = 3

# Passed positionally to lr_scheduler.StepLR after the optimizer
# (step size in epochs, then the decay factor)
SCHEDULE_EPOCH = 10
SCHEDULE_REGRESS = 0.1

### Partial training: train on a random subset of the dataset
_PARTIAL_TRAIN = True
_PARTIAL_TRAIN_RATIO = 0.5

### Freeze network layers
# NOTE(review): these two are only written to params.txt in the visible code
_NET_FREEZE = False
_NET_NO_GRAD = []

# Initial learning rate of the Adam optimizer
P_lr = 1e-3
# Fraction of the (possibly sub-sampled) dataset used for training; the rest is the test split
train_ratio = 0.7

In [5]:
#### settings shared by both pipelines
_TARGET_SIZE = (256,256)
_NORM_MEAN = [0.5,0.5,0.5]
_NORM_STD = [0.5, 0.5, 0.5]

#### pipeline for the original images: resize -> tensor -> normalise
_origin_steps = [
    transforms.Resize(_TARGET_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
]
data_transform_origin = transforms.Compose(_origin_steps)

#### pipeline for the augmented images: same as above plus random flips
#### applied before the conversion to a tensor
_augment_steps = [
    transforms.Resize(_TARGET_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor(),
#     transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
    transforms.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
]
data_transform_aug = transforms.Compose(_augment_steps)


# full_dataset = ImageFolder(DATA_PATH,transform = None)


class EDataset(Data.Dataset):
    """Original images followed by a random selection of augmented images.

    Indices ``0 .. len_origin-1`` return samples of an ``ImageFolder`` built
    with ``basic_transform``. The following ``len_augment`` indices return
    samples of a second ``ImageFolder`` over the same directory built with
    ``aug_transform``; which samples are used is drawn once, at construction,
    with ``np.random.permutation``.

    Args:
        root_path: image root directory handed to ``ImageFolder``.
        basic_transform: transform for the original samples.
        aug_transform: transform for the additional (augmented) samples.
        aug_ratio: fraction of the eligible samples to add as augmented
            copies. Values outside ``[0, 1]`` are clamped to that range.
        aug_class: non-empty list of class indices (``int`` or their string
            form, e.g. ``'3'``) whose samples are eligible for augmentation.
            An empty list or a non-list value makes every sample eligible.
            The default list is never modified.
    """

    def __init__(self, root_path, basic_transform, aug_transform, aug_ratio=0.3, aug_class=['3']):
        self.root_path = root_path
        self.basic_transform = basic_transform
        self.aug_transform = aug_transform
        self.image_origin = ImageFolder(self.root_path,
                                        transform = self.basic_transform)
        self.image_augment = ImageFolder(self.root_path,
                                        transform = self.aug_transform)

        # Select eligible samples from the (path, class_index) listing instead
        # of from loaded samples: no image is decoded at construction time and
        # the random augmentation is re-drawn on every access.
        labels = [label for _, label in self.image_augment.imgs]
        if isinstance(aug_class, list) and len(aug_class):
            # Labels are integers while the classes may be given as strings,
            # so both sides are compared in string form.
            wanted = {str(c) for c in aug_class}
            candidates = [i for i, label in enumerate(labels) if str(label) in wanted]
        else:
            candidates = list(range(len(labels)))

        self.len_origin = len(self.image_origin)

        # Never expose more augmented slots than there are eligible samples.
        n_augment = int(len(candidates)*aug_ratio)
        if n_augment < 0:
            n_augment = 0
        elif n_augment > len(candidates):
            n_augment = len(candidates)
        self.len_augment = n_augment

        # Shuffled sample indices into ``image_augment``; the first
        # ``len_augment`` of them are the ones actually served.
        self.idx_augment = np.random.permutation(np.asarray(candidates, dtype=np.int64))

    def __len__(self):
        """Number of original samples plus number of augmented samples."""
        return self.len_origin + self.len_augment

    def __getitem__(self, idx):
        """Return the sample at ``idx``; raise ``IndexError`` past the end."""
        if idx < self.len_origin:
            return self.image_origin[idx]
        offset = idx - self.len_origin
        if offset >= self.len_augment:
            # Lets plain ``for sample in dataset`` loops terminate correctly.
            raise IndexError('index {} is out of range for a dataset of size {}'.format(idx, len(self)))
        return self.image_augment[int(self.idx_augment[offset])]

    
def getModel(NUM_CLASS,name='se_resnet50'):
    """Fetch a network from the ``moskomule/senet.pytorch`` torch.hub repository.

    Args:
        NUM_CLASS: forwarded to the hub entry point as ``num_classes``.
        name: hub entry point to load (``'se_resnet50'`` by default).

    Returns:
        Whatever ``torch.hub.load`` returns for that entry point.
    """
    hub_repo = 'moskomule/senet.pytorch'
    model = torch.hub.load(hub_repo, name, num_classes=NUM_CLASS)
    return model



# Build the dataset (original images plus augmented samples) and remember
# its size; both names are read by the training cell further down.
full_dataset = EDataset(DATA_PATH,
                        basic_transform=data_transform_origin,
                        aug_transform=data_transform_aug,
                        aug_ratio=0.3 ### data augmentation ratio: 0.3
                       )
total_size = len(full_dataset)

KeyboardInterrupt: 

In [None]:
class EDataset1(Data.Dataset):
    """Original images followed by per-class oversampled, augmented images.

    Indices ``0 .. origin_len-1`` return samples of an ``ImageFolder`` built
    with ``basic_transform``. After them come augmented samples: for every
    class ``k`` in ``augment``, ``int(augment[k] * <samples of class k>)``
    extra entries are added, cycling through that class's samples in listing
    order. Augmented entries are read through an ``ImageFolder`` built with
    ``aug_transform``.

    Args:
        root_path: image root directory handed to ``ImageFolder``.
        basic_transform: transform for the original samples.
        aug_transform: transform for the additional (augmented) samples.
        augment: mapping ``class index -> multiplier``; keys may be ``int``
            or their string form (e.g. ``'3'``). Multipliers that yield a
            quota of zero or less add nothing. The mapping is copied and
            never modified.

    Raises:
        KeyError: if ``augment`` names a class that has no samples.
    """

    def __init__(self, root_path, basic_transform, aug_transform, augment={'3':2}):
        self.root_path = root_path
        self.basic_transform = basic_transform
        self.aug_transform = aug_transform
        # Private copy: the default dict is shared between calls and the
        # caller's dict must stay untouched.
        self.augment = dict(augment)

        self.image_origin = ImageFolder(self.root_path,
                                        transform = self.basic_transform)
        self.origin_len = len(self.image_origin)

        # Samples per class, keyed by the string form of the class index
        # (``imgs`` is a list of (path, class_index) pairs).
        self.num_classes = dict(Counter(str(label) for _, label in self.image_origin.imgs))

        # Number of augmented entries still to be placed, per class.
        remaining = {}
        for k, v in self.augment.items():
            quota = int(v*self.num_classes[str(k)])
            if quota > 0:
                remaining[str(k)] = quota

        # Maps dataset index -> (position in ``image_augment``, sample index).
        self.idxs = {}
        self.image_augment = []
        next_idx = self.origin_len

        if remaining:
            # One augmented view is enough: the random transform is applied on
            # every access, so repeated indices still give different images.
            aug_folder = ImageFolder(self.root_path, transform = self.aug_transform)
            self.image_augment.append(aug_folder)
            aug_labels = [str(label) for _, label in aug_folder.imgs]

            while remaining:
                placed = False
                for sample_idx, k in enumerate(aug_labels):
                    left = remaining.get(k)
                    if left is None:
                        continue
                    self.idxs[next_idx] = (0, sample_idx)
                    next_idx += 1
                    placed = True
                    if left == 1:
                        del remaining[k]
                    else:
                        remaining[k] = left - 1
                if not placed:
                    # Safety net: a full pass that places nothing can never
                    # finish the quotas, so stop instead of spinning forever.
                    break

        self.total_len = self.origin_len + len(self.idxs)

    def __len__(self):
        """Number of original samples plus number of augmented entries."""
        return self.total_len

    def __getitem__(self, idx):
        """Return the sample at ``idx``; raise ``IndexError`` past the end."""
        if idx < self.origin_len:
            return self.image_origin[idx]
        if idx >= self.total_len:
            # Lets plain ``for sample in dataset`` loops terminate correctly.
            raise IndexError('index {} is out of range for a dataset of size {}'.format(idx, self.total_len))
        folder_idx, sample_idx = self.idxs[idx]
        return self.image_augment[folder_idx][sample_idx]

# extend_dataset = EMultiSampleDataset(DATA_PATH, data_transform_aug, multiple=2, aug_class=['3'])
# extend_dataset[2]
# ``augment`` is a constructor argument of EDataset1 (EDataset takes
# ``aug_ratio`` / ``aug_class``), so EDataset1 is the class to build here.
full_dataset = EDataset1(DATA_PATH,
                         basic_transform=data_transform_origin,
                         aug_transform=data_transform_aug,
                         augment={'3':2}
                        )
# full_dataset = ImageFolder(DATA_PATH,transform = data_transform_origin)
total_size = len(full_dataset)
# for i in range(total_size):
#     if i%1000==1:
#         print(i)
#     item = full_dataset[i]
# np.where(np.array([1,2,3])>=1)

In [None]:
# Cross-validation style training: for every fold a model is built (or
# restored), trained for the remaining epochs and evaluated after each epoch.
# The metrics of every evaluated epoch of every fold are collected in _metrics.
_metrics = []
epoch_save = 0

#### use CUDA if available (the device is the same for every fold)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#### optional sub-sampling of the dataset
#### Done once and bound to new names: re-splitting full_dataset inside the
#### fold loop would shrink the data by _PARTIAL_TRAIN_RATIO again on every
#### fold (and on every re-run of this cell).
if _PARTIAL_TRAIN:
    cv_size = int(_PARTIAL_TRAIN_RATIO*total_size)
    cv_dataset, _ = torch.utils.data.random_split(full_dataset, [cv_size, total_size - cv_size])
else:
    cv_size = total_size
    cv_dataset = full_dataset

for cv in range(CV):

    hub_model = getModel(NUM_CLASS=NUM_CLASS,name='se_resnet50')

    #### load model
    print('DEBUG:: fold ',cv)
    try:
        hub_model, epo = modelrestore(hub_model)
        print('Model successfully loaded')
        print('-' * 60)
    except Exception as e:
        print('Model not found, use the initial model')
        # Show the actual reason so real failures are not mistaken for a missing file.
        print('DEBUG:: restore failed with:', repr(e))
        epo = 0
        print('-' * 60)

    # NOTE(review): the checkpoint written while training one fold is picked
    # up by modelrestore at the start of the next fold. If its epoch counter
    # already reached EPOCH, that fold performs no training at all — confirm
    # whether resume should apply across folds.
    if epo >= EPOCH:
        print('WARNING:: restored epoch {} >= EPOCH {}; fold {} will not train'.format(epo, EPOCH, cv))

    net = hub_model
    net.to(device)

    #### define criterian & optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=P_lr)
    scheduler = lr_scheduler.StepLR(optimizer, SCHEDULE_EPOCH, SCHEDULE_REGRESS)

    #### data splitting (a fresh random train/test split for every fold)
    train_size = int(np.floor( cv_size * train_ratio ))
    test_size = int(cv_size - train_size)
    dataset_train, dataset_test = torch.utils.data.random_split(cv_dataset, [train_size,test_size])

    # The loaders do not change between epochs; shuffle=True still reshuffles
    # on every pass over a loader.
    trainloader = Data.DataLoader(dataset=dataset_train, batch_size=BATCH_SIZE, shuffle=True)
    testloader = Data.DataLoader(dataset=dataset_test, batch_size=BATCH_SIZE, shuffle=True)

    #### training
    _loss = []
    __record_train_num = 0
    epoch_save = epo
    for epoch in range(epo, EPOCH):  # loop over the dataset multiple times
        epoch_save += 1
        print('DEBUG:: training epoch ',epoch)

        # Training mode (evaluation below switches the network to eval mode).
        net.train()

        running_loss = 0.0
        i = -1  # stays -1 when the training loader yields no batch
        for i, data in enumerate(trainloader, 0):
            #### get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            #### zero the parameter gradients
            optimizer.zero_grad()
            #### forward + backward + optimize
            outputs = net(inputs)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            #### print statistics
            running_loss += loss.item()
            # _loss records the cumulative loss of the current epoch after every batch
            _loss.append(running_loss)

            __record_train_num += len(labels)
            if __record_train_num % (BATCH_SIZE * 50) == 0:
                print('DEBUG:: num has trained',__record_train_num)
            if i % 50 == 0:
                print('DEBUG:: trainloader:{}/{}'.format(i, len(trainloader)))
            if i % 50 == 0:
                # Best-effort intermediate save: a failure is reported but
                # must not abort the training run.
                try:
                    checkpoint(net, optimizer, epoch_save)
                    print('*' * 60)
                    print('Model is saved successfully at epoch {}'.format(str(epoch)))
                    print('*' * 60)
                except Exception as e:
                    print('*' * 60)
                    print('Something is wrong!',e)
                    print('*' * 60)

        # Advance the learning-rate schedule once per epoch; without this call
        # StepLR never changes the learning rate.
        scheduler.step()

        #### predicting
        print('=' * 60)
        print('i:', i)
        print('Start predicting')
        # Evaluation mode: BatchNorm/Dropout layers behave deterministically.
        net.eval()
        Ypred = []
        Ytest = []
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                images = images.to(device)
                labels = labels.to(device)
                outputs = net(images)
                _, predicted = torch.max(outputs, -1)
                Ytest.extend(labels.tolist())
                Ypred.extend(predicted.tolist())

        _metrics.append(computeMetrics(Ypred,Ytest))
        print("accuracy is {}".format(_metrics[-1]['acc']) )
        print("auc is {}".format(_metrics[-1]['auc']) )
        print('=' * 60)


print('-' * 60)
print('Training is over, saving the model')
print('-' * 60)
try:
    savePath = checkpoint(net, optimizer, epoch_save, useTimeDir=True)
    saveResult(_metrics,savePath)
    saveParameters(savePath)
    print('Model is saved successfully')
except Exception as e:
    print('Something is wrong!',e)
    # Re-raise with the original traceback intact.
    raise
    
    
    

In [None]:
#### plotting
# Each call receives the metrics collected by the training cell, a pair of
# metric keys and a title-like string.
# NOTE(review): plotResultCurve comes from utility.plot; presumably it draws
# the two named metrics as curves — confirm against that module.
plotResultCurve(_metrics,['acc','auc'],'acc-auc')
plotResultCurve(_metrics,['fpr','tpr'],'fpr-tpr')

In [15]:
def make_weights_for_balanced_classes(images, nclasses):
    """Compute one sampling weight per sample, inversely proportional to class size.

    Args:
        images: sequence of items whose element ``[1]`` is the integer class
            index of the sample (e.g. ``(path, class_index)`` pairs).
        nclasses: number of classes; every class index must be below it.

    Returns:
        list[float]: for each sample, ``len(images) / <samples in its class>``,
        in input order. An empty ``images`` gives an empty list.
    """
    count = [0] * nclasses
    for item in images:
        count[item[1]] += 1

    N = float(len(images))
    # A class without samples gets weight 0.0 instead of dividing by zero;
    # that weight is never looked up because no sample belongs to the class.
    weight_per_class = [N/float(c) if c else 0.0 for c in count]

    return [weight_per_class[item[1]] for item in images]

# num_imgs = ImgSummary([
#             DATA_PATH+r'\0',
#             DATA_PATH+r'\1',
#             DATA_PATH+r'\2',
#             DATA_PATH+r'\3',
# ])

# weights = 1.0/np.array(num_imgs)
# weights = weights/np.sum(weights)
# list(Data.WeightedRandomSampler(weights, len(weights)))
# weights

# NOTE: this rebinds full_dataset to a plain ImageFolder (no augmented samples)
full_dataset = ImageFolder(DATA_PATH,transform = data_transform_origin)
# For unbalanced dataset we create a weighted sampler
# One weight per sample, inversely proportional to the size of its class
weights = make_weights_for_balanced_classes(full_dataset.imgs, len(full_dataset.classes))
weights = torch.DoubleTensor(weights)
# The sampler is asked for len(weights) draws, i.e. as many as there are samples
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

# The sampler decides the order of the samples, so no shuffle argument is passed
train_loader = torch.utils.data.DataLoader(dataset=full_dataset, batch_size=BATCH_SIZE, sampler = sampler)

[tensor([[[[-0.7020, -0.7020, -0.7098,  ..., -0.0118, -0.0275, -0.0353],
          [-0.7176, -0.7176, -0.7176,  ..., -0.0196, -0.0275, -0.0353],
          [-0.7333, -0.7255, -0.7255,  ..., -0.0353, -0.0431, -0.0510],
          ...,
          [-0.0745, -0.0824, -0.0980,  ..., -0.0824, -0.0980, -0.1059],
          [-0.0902, -0.0902, -0.0902,  ..., -0.1059, -0.1216, -0.1216],
          [-0.1059, -0.1059, -0.0980,  ..., -0.1216, -0.1294, -0.1373]],

         [[-0.7333, -0.7333, -0.7412,  ..., -0.3098, -0.3255, -0.3333],
          [-0.7490, -0.7490, -0.7490,  ..., -0.3176, -0.3255, -0.3333],
          [-0.7647, -0.7569, -0.7569,  ..., -0.3255, -0.3333, -0.3412],
          ...,
          [-0.2863, -0.2941, -0.3098,  ..., -0.3098, -0.3255, -0.3333],
          [-0.3020, -0.3020, -0.3020,  ..., -0.3333, -0.3490, -0.3490],
          [-0.3176, -0.3176, -0.3098,  ..., -0.3490, -0.3569, -0.3647]],

         [[-0.7569, -0.7569, -0.7647,  ..., -0.3804, -0.3961, -0.4039],
          [-0.7725, -0.7725, 

KeyboardInterrupt: 