In [None]:
import sys
import os
import logging
import re
from time import time as ttime

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models
from torchinfo import summary
from torch.utils.data import Dataset, IterableDataset, DataLoader, random_split
import numpy as np

from PhotonDataset import transform, PhotonDataset

# Send every log record to stdout at DEBUG verbosity so it appears in the
# notebook output, while silencing PIL's noisy debug/info chatter.
logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
logging.getLogger("PIL").setLevel(logging.ERROR)
# The rest of the notebook logs through the `logging` module itself
# (i.e. the root logger configured above).
logger = logging

In [2]:
# Root of all on-disk artefacts: checkpoints go under BASE_PATH + MODEL_DIR_PATH
# and the serialized dataloaders under BASE_PATH + "/dataloader".
# Set the PHOTON_BASE_PATH environment variable to run against another location
# without editing the notebook.
BASE_PATH = os.environ.get("PHOTON_BASE_PATH", "/mnt/ossdata/")

# Class name -> class index (one classifier output per class).
INSPECT2ID = {
    "Stratum_corneum": 0,
    "DEJunction": 1,
    "ELCOR": 2,
}

# Inverse of INSPECT2ID; derived so the two tables can never drift apart.
IDS2INSPECT = {classId: className for className, classId in INSPECT2ID.items()}

# Not referenced elsewhere in this notebook — NOTE(review): confirm it is still needed.
PRECISION_WINDOWS = 3

MODEL_DIR_PATH = "/versions"      # checkpoint sub-directory (appended to BASE_PATH)
MODEL_FILE_NAME = "DoublePhoton"  # checkpoint base name; files are <name>-<unix time>.pth
THRESHOLD = 0.7                   # sigmoid probability above which a class counts as predicted
LEARNING_RATE = 0.0001
MODEL_SAVE_VER = 150              # number of checkpoint files kept on disk
WEIGHT_RATE = 0.8                 # weight of the window-label loss in the combined "weiLoss"

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

In [4]:
class WPLoss(nn.Module):
    """Masked binary-cross-entropy loss that also counts thresholded matches.

    Call as ``matchCnt, totalCnt, loss = criterion(hatouts, targets)``.

    ``hatouts`` are raw model outputs (logits); a sigmoid is applied here.
    ``targets`` is iterated row-by-row alongside ``hatouts``. Entries with
    target ``<= -1`` are treated as "no label" and ignored.

    Returns:
        matchCnt: float, number of unmasked entries whose target is exactly 1
            and whose probability is > threshold, or exactly 0 and <= threshold.
        totalCnt: float, number of unmasked entries.
        loss: tensor, the sum over rows of that row's mean BCE on its unmasked
            entries. It is always a tensor, so ``.backward()`` and ``.item()``
            work even when every entry is masked (the loss is then a zero that
            is still attached to ``hatouts``' graph).
    """

    def __init__(self, threshold=None):
        """Args:
            threshold: probability cutoff for a positive prediction. ``None``
                means the module-level ``THRESHOLD``, looked up on every call.
        """
        super(WPLoss, self).__init__()
        self._innerLoss = nn.BCELoss()
        self._sigmod = nn.Sigmoid()
        self._threshold = threshold

    def forward(self, hatouts, targets):
        threshold = THRESHOLD if self._threshold is None else self._threshold
        matchCnt, totalCnt, loss = 0.0, 0.0, 0.0
        for hatout, tgt in zip(hatouts, targets):
            # Entries labelled <= -1 carry no supervision.
            mask = torch.gt(tgt, -1)
            hatout = self._sigmod(torch.masked_select(hatout, mask))
            tgt = torch.masked_select(tgt, mask)

            if tgt.shape[0] > 0:
                totalCnt += tgt.shape[0]
                loss += self._innerLoss(hatout, tgt)

                # Vectorized match count: one device sync per row instead of
                # one per element. `<=` is spelled out (rather than `~(>)`)
                # so a NaN probability matches neither class.
                predPos = hatout > threshold
                predNeg = hatout <= threshold
                matched = ((tgt == 1) & predPos) | ((tgt == 0) & predNeg)
                matchCnt += matched.sum().item()

        if not torch.is_tensor(loss):
            # Every entry was masked. Return a zero that is still connected to
            # the graph, so callers can call backward()/item() unconditionally.
            loss = hatouts.sum() * 0.0 if torch.is_tensor(hatouts) else torch.zeros(())
        return matchCnt, totalCnt, loss

def computeParameterPrecision(hatouts, targets, threshold=None, num_classes=None):
    """Accumulate per-class accuracy and confusion counts for one batch.

    Args:
        hatouts: raw model outputs (logits), iterated row by row; a sigmoid is
            applied here. Assumed to be (batch, classes) — TODO confirm.
        targets: labels laid out like ``hatouts``. Entries ``<= -0.5`` are
            skipped. A target ``> 0.8`` counts as positive and ``< 0.2`` as
            negative. Anything in between only adds to the "evaluated" count.
        threshold: probability cutoff for a positive prediction. ``None``
            means the module-level ``THRESHOLD``.
        num_classes: number of rows in the returned tensors. ``None`` means
            ``len(INSPECT2ID)``.

    Returns:
        (preTensor, frTensor):
        preTensor is ``(num_classes, 2)``, holding per class ``[correct, evaluated]``.
        frTensor is ``(num_classes, 3)``, holding per class
        ``[truePos, falsePos, falseNeg]``; true negatives are not tracked.
    """
    if threshold is None:
        threshold = THRESHOLD
    if num_classes is None:
        num_classes = len(INSPECT2ID)

    preTensor = torch.zeros((num_classes, 2))
    frTensor = torch.zeros((num_classes, 3))
    for hatout, tgt in zip(hatouts, targets):
        probs = torch.sigmoid(hatout)
        for seq, (iValue, tValue) in enumerate(zip(probs, tgt)):
            if tValue <= -0.5:
                continue  # unlabelled entry

            isPos = bool(tValue > 0.8)
            isNeg = bool(tValue < 0.2)
            predPos = bool(iValue > threshold)
            predNeg = bool(iValue <= threshold)

            preTensor[seq][1] += 1
            if (isPos and predPos) or (isNeg and predNeg):
                preTensor[seq][0] += 1

            if isPos and predPos:
                frTensor[seq][0] += 1
            elif isPos and predNeg:
                # `<=` (not `<`): a probability exactly at the threshold is a
                # false negative, consistent with the accuracy check above.
                frTensor[seq][2] += 1
            elif isNeg and predPos:
                frTensor[seq][1] += 1
    return preTensor, frTensor

In [5]:
def listModelVer(subDir, modelName):
    """Return checkpoint file names for ``modelName`` in ``subDir``, newest first.

    A checkpoint is a file named exactly ``<modelName>-<digits>.pth``, the
    format ``saveCheckpoint`` writes. The digits are the save timestamp and
    define the order. Any other file is ignored. That includes
    ``<modelName>-<digits>.pth.bak``, names containing extra characters, and
    checkpoints of other models. ``modelName`` is matched literally, so it may
    contain ``-``, ``.`` or other regex metacharacters.
    """
    pattern = re.compile(re.escape(modelName) + r"-(\d+)\.pth")
    vers = []
    for tmpFile in os.listdir(subDir):
        matched = pattern.fullmatch(tmpFile)
        if matched:
            vers.append((tmpFile, int(matched.group(1))))
    vers.sort(key=lambda xx: xx[1], reverse=True)
    return [fileName for fileName, _ in vers]

def saveCheckpoint(model, optimizer, checkpointPath):
    """Save model (and optimizer) state as ``<checkpointPath>-<unix time>.pth``.

    ``checkpointPath`` is ``<directory>/<modelName>``. The directory is created
    if needed; a bare name saves into the current directory. After the new file
    is written, only the ``MODEL_SAVE_VER`` newest checkpoints for this model
    are kept.

    Args:
        model: network to save; a wrapper exposing ``.module`` (e.g.
            ``nn.DataParallel``) is unwrapped so the saved keys match the inner
            network.
        optimizer: optimizer whose state is saved too, or ``None``.
        checkpointPath: directory plus base model name (no suffix).
    """
    subDir = os.path.dirname(checkpointPath) or "."
    modelName = os.path.basename(checkpointPath)
    os.makedirs(subDir, exist_ok=True)

    fullModelName = "%s-%s.pth" % (modelName, int(ttime()))
    fullPath = os.path.join(subDir, fullModelName)
    logger.info("@szh:Saving model and optimizer state at iteration at {}".format(fullPath))

    innerModel = model.module if hasattr(model, "module") else model
    torch.save({
        "model": innerModel.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
    }, fullPath)

    # Prune only after the new checkpoint is on disk, so a failed save never
    # costs us older checkpoints. The new file is the newest and is kept.
    for oldVer in listModelVer(subDir, modelName)[MODEL_SAVE_VER:]:
        if oldVer != fullModelName:
            os.remove(os.path.join(subDir, oldVer))

def loadCheckPoint(model, optimizer, checkpointPath):
    """Load the newest checkpoint saved for ``checkpointPath`` into ``model``.

    ``checkpointPath`` is ``<directory>/<modelName>``. The newest matching file
    comes from ``listModelVer``, and its tensors are mapped onto ``device``.
    When an optimizer is given and the checkpoint stored optimizer state, that
    state is restored too. If there is no checkpoint, nothing is loaded. The
    directory is still created if it is missing.

    Args:
        model: network to load into; a wrapper exposing ``.module`` is unwrapped.
        optimizer: optimizer to restore, or ``None`` to leave it untouched.
        checkpointPath: directory plus base model name (no suffix).
    """
    subDir = os.path.dirname(checkpointPath)
    modelName = os.path.basename(checkpointPath)

    if not os.path.exists(subDir):
        os.makedirs(subDir)

    newestFirst = listModelVer(subDir, modelName)
    if not newestFirst:
        return  # nothing saved yet: keep the current weights

    fullPath = "%s/%s" % (subDir, newestFirst[0])
    logger.info("@szh:load model and optimizer state from file: {}".format(fullPath))
    checkpointDict = torch.load(fullPath, map_location=device)

    if optimizer is not None:
        savedOptimizerState = checkpointDict["optimizer"]
        if savedOptimizerState is not None:
            optimizer.load_state_dict(savedOptimizerState)

    targetModel = model.module if hasattr(model, "module") else model
    targetModel.load_state_dict(checkpointDict["model"])

In [None]:
# ImageNet-pretrained VGG16 whose final classifier layer is replaced by a head
# with one output per class in IDS2INSPECT.
vgg16 = models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
num_classes = len(IDS2INSPECT)
vgg16.classifier[6] = nn.Linear(vgg16.classifier[6].in_features, num_classes)

# Resume from the newest checkpoint, if any. The optimizer is passed as None,
# so saved optimizer state (e.g. SGD momentum buffers) is NOT restored.
# NOTE(review): confirm this is intended.
loadCheckPoint(vgg16, None, BASE_PATH + MODEL_DIR_PATH + "/" + MODEL_FILE_NAME)
optimizer = optim.SGD(vgg16.parameters(), lr=LEARNING_RATE, momentum=0.9)
vgg16.to(device)

metrics = WPLoss()
metrics.to(device)

In [7]:
def train_long(model, trainloaderPath, valloaderPath, loss_fn, epochs=5, optimizer=None, print_freq=10, save_freq=10):
    """Train ``model`` over serialized dataloader files, validating after each epoch.

    Every file in ``trainloaderPath`` whose name starts with ``"train"`` is
    ``torch.load``-ed and iterated; each batch yields
    ``(features, labels, winLabels)``. Parameters are updated from the loss on
    ``labels`` only. The loss and accuracy on ``winLabels`` are tracked for
    logging. After each epoch a checkpoint is saved, and the model is evaluated
    on the ``"validation*"`` files in ``valloaderPath``.

    Args:
        model: network to train; expected to already be on ``device``.
        trainloaderPath: directory holding the ``train*`` dataloader files.
        valloaderPath: directory holding the ``validation*`` dataloader files.
        loss_fn: callable ``(outputs, targets) -> (matchCnt, totalCnt, loss)``,
            e.g. ``WPLoss``.
        epochs: number of passes over the training files.
        optimizer: optimizer to step; defaults to Adam with ``LEARNING_RATE``.
        print_freq: log running metrics every this many batches of a file.
        save_freq: currently unused; a checkpoint is written once per epoch.
    """
    optimizer = optimizer or torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    def _ratio(num, den):
        # Counts can legitimately be 0 (e.g. no window labels seen yet). Log
        # NaN instead of aborting training with ZeroDivisionError.
        return num / den if den else float("nan")

    for epoch in range(epochs):
        model.train()

        norLoss, acc, cnt = 0.0, 0.0, 0.0
        sepLoss, wAcc, wCnt = 0.0, 0.0, 0.0
        prec, wPrec = torch.zeros((len(INSPECT2ID), 2)), torch.zeros((len(INSPECT2ID), 2))
        frPrec, wFRPrec = torch.zeros((len(INSPECT2ID), 3)), torch.zeros((len(INSPECT2ID), 3))

        # sorted(): os.listdir order is arbitrary; keep runs reproducible.
        for trainFile in sorted(os.listdir(trainloaderPath)):
            if not trainFile.startswith("train"):
                continue
            logger.info("@szh: load train data: {}".format(trainFile))
            # NOTE: unpickles a whole DataLoader object — only use on trusted files.
            train_loader = torch.load(trainloaderPath + "/" + trainFile)
            for batch_seq, (features, labels, winLabels) in enumerate(train_loader):
                features = features.to(device)
                lbls = labels.to(device)
                wLbls = winLabels.to(device)

                optimizer.zero_grad()
                out = model(features)

                bMatchCnt, bCnt, bNorLoss = loss_fn(out, lbls)
                wBMatchCnt, wBCnt, wBSpecLoss = loss_fn(out, wLbls)

                # Only the per-entry label loss drives the update; the window
                # loss is for monitoring. Skip the step when the batch has no
                # labelled entries at all.
                if bCnt > 0:
                    bNorLoss.backward()
                    optimizer.step()

                # float() accepts both a 0-dim tensor and a plain float loss.
                norLoss += float(bNorLoss)
                acc += bMatchCnt
                cnt += bCnt

                sepLoss += float(wBSpecLoss)
                wAcc += wBMatchCnt
                wCnt += wBCnt

                # Confusion counts are bookkeeping only: detach once and move
                # to CPU rather than syncing with the device per element.
                outCpu = out.detach().cpu()
                tmpPrec, tmpFRPrec = computeParameterPrecision(outCpu, labels.cpu())
                prec += tmpPrec
                frPrec += tmpFRPrec

                tmpPrec, tmpFRPrec = computeParameterPrecision(outCpu, winLabels.cpu())
                wPrec += tmpPrec
                wFRPrec += tmpFRPrec

                if (batch_seq + 1) % print_freq == 0:
                    weiLoss = (sepLoss / (wCnt + 0.001)) * WEIGHT_RATE + _ratio(norLoss, cnt) * (1 - WEIGHT_RATE)
                    logger.info("Epoch {}, minibatch {}: weiLoss = {}, sepcialLoss = {}, specialAcc = {}, norLoss = {}, norAcc = {}".format(
                        epoch,
                        batch_seq,
                        weiLoss,
                        _ratio(sepLoss, wCnt),
                        _ratio(wAcc, wCnt),
                        _ratio(norLoss, cnt),
                        _ratio(acc, cnt)))

            # Release the (possibly large) loader before loading the next file.
            train_loader = None

        saveCheckpoint(model, optimizer, BASE_PATH + MODEL_DIR_PATH + "/" + MODEL_FILE_NAME)

        weiLoss = (sepLoss / (wCnt + 0.001)) * WEIGHT_RATE + _ratio(norLoss, cnt) * (1 - WEIGHT_RATE)
        # Append a third column: accuracy = correct / evaluated.
        prec = torch.cat([prec, (prec[:, 0] / prec[:, 1]).unsqueeze(1)], dim=1)
        wPrec = torch.cat([wPrec, (wPrec[:, 0] / wPrec[:, 1]).unsqueeze(1)], dim=1)
        # Append precision = TP / (TP + FP) and recall = TP / (TP + FN).
        frPrec = torch.cat([frPrec, (frPrec[:, 0] / (frPrec[:, 0] + frPrec[:, 1])).unsqueeze(1), (frPrec[:, 0] / (frPrec[:, 0] + frPrec[:, 2])).unsqueeze(1)], dim=1)
        wFRPrec = torch.cat([wFRPrec, (wFRPrec[:, 0] / (wFRPrec[:, 0] + wFRPrec[:, 1])).unsqueeze(1), (wFRPrec[:, 0] / (wFRPrec[:, 0] + wFRPrec[:, 2])).unsqueeze(1)], dim=1)
        logger.info("Epoch {} done, TrainData: weiLoss = {}, sepcialLoss = {}, specialAcc = {}, norLoss = {}, norAcc = {}, prec = {}, wPrec={}, fcPrec={}, wFCPrec={}".format(
            epoch, weiLoss, _ratio(sepLoss, wCnt), _ratio(wAcc, wCnt), _ratio(norLoss, cnt), _ratio(acc, cnt), prec, wPrec, frPrec, wFRPrec))

        vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec = validateOrTest(model, valloaderPath, loss_fn)
        logger.info("Epoch {} done, ValidationData: weiLoss = {}, sepcialLoss = {}, specialAcc = {}, norLoss = {}, norAcc = {}, prec = {}, wPrec={}, fcPrec={}, wFCPrec={}".format(
            epoch, vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec))

In [45]:
def validateOrTest(model, dataloaderPath, loss_fn, datasetType='validation'):
    """Evaluate ``model`` on every ``<datasetType>*`` dataloader file in a directory.

    Runs in eval mode under ``torch.no_grad()`` and leaves the model in eval
    mode. Each batch yields ``(features, labels, winLabels)``. ``loss_fn`` is
    applied to both label sets, and confusion counts come from
    ``computeParameterPrecision``.

    Args:
        model: network to evaluate; expected to already be on ``device``.
        dataloaderPath: directory holding the serialized dataloader files.
        loss_fn: callable ``(outputs, targets) -> (matchCnt, totalCnt, loss)``.
        datasetType: file-name prefix selecting which files to use
            (e.g. ``'validation'`` or ``'test'``).

    Returns:
        A 9-tuple ``(weiLoss, sepLoss, wAcc, norLoss, acc, prec, wPrec, frPrec, wFRPrec)``:
        - Scalar entries are Python floats. A ratio whose denominator is 0 is NaN.
        - ``prec``/``wPrec`` are ``(classes, 3)``: ``[correct, evaluated, accuracy]``.
        - ``frPrec``/``wFRPrec`` are ``(classes, 5)``: ``[TP, FP, FN, precision, recall]``.
        - Entries prefixed ``w``/``sep`` refer to ``winLabels``; the others to ``labels``.
    """
    def _ratio(num, den):
        # No matching files or no labelled entries gives 0 counts -> NaN, not a crash.
        return num / den if den else float("nan")

    norLoss, acc, cnt = 0.0, 0.0, 0.0
    sepLoss, wAcc, wCnt = 0.0, 0.0, 0.0

    prec, wPrec = torch.zeros((len(INSPECT2ID), 2)), torch.zeros((len(INSPECT2ID), 2))
    frPrec, wFRPrec = torch.zeros((len(INSPECT2ID), 3)), torch.zeros((len(INSPECT2ID), 3))

    model.eval()
    with torch.no_grad():
        # sorted(): os.listdir order is arbitrary; keep logs reproducible.
        for dataloaderFile in sorted(os.listdir(dataloaderPath)):
            if not dataloaderFile.startswith(datasetType):
                continue

            # NOTE: unpickles a whole DataLoader object — only use on trusted files.
            dataloader = torch.load(dataloaderPath + "/" + dataloaderFile)
            logger.info("@szh: load {} data: {}".format(datasetType, dataloaderFile))
            for features, labels, winLabels in dataloader:
                out = model(features.to(device))

                bMatchCnt, bCnt, bNorLoss = loss_fn(out, labels.to(device))
                wBMatchCnt, wBCnt, wBSpecLoss = loss_fn(out, winLabels.to(device))

                # Accumulate host floats. float() accepts a 0-dim tensor or a
                # plain float, and avoids keeping running sums on the device.
                norLoss += float(bNorLoss)
                acc += bMatchCnt
                cnt += bCnt

                sepLoss += float(wBSpecLoss)
                wAcc += wBMatchCnt
                wCnt += wBCnt

                outCpu = out.cpu()
                tmpPrec, tmpFRPrec = computeParameterPrecision(outCpu, labels.cpu())
                prec += tmpPrec
                frPrec += tmpFRPrec

                tmpPrec, tmpFRPrec = computeParameterPrecision(outCpu, winLabels.cpu())
                wPrec += tmpPrec
                wFRPrec += tmpFRPrec
            # Release the (possibly large) loader before loading the next file.
            dataloader = None

    weiLoss = (sepLoss / (wCnt + 0.001)) * WEIGHT_RATE + _ratio(norLoss, cnt) * (1 - WEIGHT_RATE)
    # Append a third column: accuracy = correct / evaluated.
    prec = torch.cat([prec, (prec[:, 0] / prec[:, 1]).unsqueeze(1)], dim=1)
    wPrec = torch.cat([wPrec, (wPrec[:, 0] / wPrec[:, 1]).unsqueeze(1)], dim=1)
    # Append precision = TP / (TP + FP) and recall = TP / (TP + FN).
    frPrec = torch.cat([frPrec, (frPrec[:, 0] / (frPrec[:, 0] + frPrec[:, 1])).unsqueeze(1), (frPrec[:, 0] / (frPrec[:, 0] + frPrec[:, 2])).unsqueeze(1)], dim=1)
    wFRPrec = torch.cat([wFRPrec, (wFRPrec[:, 0] / (wFRPrec[:, 0] + wFRPrec[:, 1])).unsqueeze(1), (wFRPrec[:, 0] / (wFRPrec[:, 0] + wFRPrec[:, 2])).unsqueeze(1)], dim=1)
    return weiLoss, _ratio(sepLoss, wCnt), _ratio(wAcc, wCnt), _ratio(norLoss, cnt), _ratio(acc, cnt), prec, wPrec, frPrec, wFRPrec

In [11]:
# Train, validation and test dataloaders share one directory. train_long and
# validateOrTest select files by name prefix ("train", "validation", "test").
TRAIN_DATA_PATH = VALIDATION_DATA_PATH = TEST_DATA_PATH = BASE_PATH + "/dataloader"

In [None]:
# Full training run: 50 epochs with the SGD optimizer built above, logging
# running metrics every 500 minibatches. save_freq is not referenced by
# train_long; checkpoints are written once per epoch.
train_long(
    vgg16,
    TRAIN_DATA_PATH,
    VALIDATION_DATA_PATH,
    metrics,
    epochs=50,
    optimizer=optimizer,
    print_freq=500,
    save_freq=600,
)

In [None]:
# Stand-alone validation pass with the current weights of vgg16.
vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec = validateOrTest(vgg16, VALIDATION_DATA_PATH, metrics)
# Every "{" in a str.format template must be closed; an unmatched brace makes
# format() raise ValueError.
logger.info("ValidationData: weiLoss = {}, sepcialLoss = {}, specialAcc = {}, norLoss = {}, norAcc = {}, prec = {}, wPrec={}, fcprec = {}, wRcPrec={}".format(
    vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec))

In [None]:
# Final evaluation on the held-out "test*" dataloader files.
vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec = validateOrTest(
    vgg16, TEST_DATA_PATH, metrics, datasetType='test')
logger.info(
    "TestData: weiLoss = {}, sepcialLoss = {}, specialAcc = {}, norLoss = {}, "
    "norAcc = {}, prec = {}, wPrec={}, fcprec = {}, wRcPrec={}".format(
        vWeiLoss, vSepLoss, vWAcc, vNorLoss, vAcc, vPrec, vWPrec, vFCPrec, vWFCPrec))