In [1]:
import torch.nn as nn

class ResBlock(nn.Module):
    """1-D residual block: a two-layer convolutional branch plus a shortcut, then ReLU.

    Modes (mutually exclusive):
      * downSampling: conv(k=3, p=1) -> ReLU -> conv(k=2, stride=2); the length is
        halved (floor) and the shortcut is ``self.down`` (conv k=2, stride=2).
      * upSampling: transposed conv(k=2, stride=2) -> ReLU -> conv(k=3, p=1); the
        length is doubled and the shortcut is ``self.up``.
      * neither: conv(k=3, p=1) -> ReLU -> conv(k=3, p=1); the length is unchanged.
        The shortcut is the identity when ``inChannel == outChannel`` and a 1x1
        convolution (``self.proj``) otherwise, since an identity shortcut cannot
        be added when the channel counts differ.

    ``self.down`` and ``self.up`` are always constructed, even when unused. They
    are part of ``state_dict``, so removing them would break loading existing
    checkpoints.

    Args:
        inChannel: number of input channels.
        outChannel: number of output channels.
        downSampling: halve the sequence length.
        upSampling: double the sequence length.
    """

    def __init__(self, inChannel, outChannel, downSampling=False, upSampling=False):
        super().__init__()
        # Downsampling shortcut: (Lin + 2P - k) / S + 1  ->  floor(Lin / 2)
        self.down = nn.Conv1d(inChannel, outChannel, kernel_size=2, padding=0, stride=2)
        # Upsampling shortcut: (Lin - 1) * S - 2P + k  ->  2 * Lin
        self.up = nn.ConvTranspose1d(inChannel, outChannel, kernel_size=2, stride=2, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.downSampling = downSampling
        self.upSampling = upSampling

        assert not downSampling or not upSampling, "Can't perform downsampling and upsampling simultaneously."
        layers = []

        if downSampling:
            layers.append(nn.Conv1d(inChannel, outChannel, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Conv1d(outChannel, outChannel, kernel_size=2, padding=0, stride=2))

        elif upSampling:
            layers.append(nn.ConvTranspose1d(inChannel, outChannel, kernel_size=2, stride=2, padding=0))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Conv1d(outChannel, outChannel, kernel_size=3, padding=1))

        else:
            layers.append(nn.Conv1d(inChannel, outChannel, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Conv1d(outChannel, outChannel, kernel_size=3, padding=1))

        self.block = nn.Sequential(*layers)

        # Projection shortcut, only for the non-resampling case with a channel change
        # (that configuration previously crashed in forward). It is created last and
        # only when needed, so the state_dict keys of every previously working
        # configuration stay exactly the same.
        if not downSampling and not upSampling and inChannel != outChannel:
            self.proj = nn.Conv1d(inChannel, outChannel, kernel_size=1)
        else:
            self.proj = None

    def forward(self, inPut):
        """Return ReLU(block(inPut) + shortcut(inPut))."""
        out = self.block(inPut)
        if self.downSampling:
            residual = self.down(inPut)
        elif self.upSampling:
            residual = self.up(inPut)
        elif self.proj is not None:
            residual = self.proj(inPut)
        else:
            residual = inPut

        out += residual
        return self.relu(out)
        
        
class ResNet(nn.Module):
    """Siamese 1-D residual network that scores a pair of signals.

    Both inputs go through the same ``preLayer`` and ``firstModule`` (shared
    weights). The two feature maps are subtracted, passed through
    ``secondModule``, flattened, and mapped by ``linear1`` to ``numClasses``
    raw outputs (no activation is applied here).

    Args:
        firstChannels: channel width of each stage of the shared encoder.
        secondChannels: channel width of each stage applied to the feature
            difference; ``secondChannels[0]`` must equal ``firstChannels[-1]``.
        firstBlockRepeats: encoder stage 0 gets ``firstBlockRepeats[0]`` blocks;
            every later stage i gets one transition block plus
            ``firstBlockRepeats[i]`` same-width blocks.
        secBlockRepeats: same-width blocks added after each transition of the
            second module (index 0 is not used).
        upsampling, downsampling: per-stage resampling flags for the transition
            blocks; the repeated same-width blocks use the index-0 flag.
        useUpsampling: if True, second-module transitions flagged in
            ``upsampling`` upsample; if False, those transitions downsample.
        dataLen: length of each input signal.
        outSizeChange: if True, a 1x1 convolution reduces the second module's
            output to one channel before the linear head.
        numClasses: number of outputs of ``linear1``.
    """

    def __init__(self, firstChannels=[32, 64, 128], secondChannels=[128, 64, 32], firstBlockRepeats=[1, 1, 1], secBlockRepeats=[1, 1, 1],
                 upsampling=[False, True, True], downsampling=[False, True, True], useUpsampling=False,
                 dataLen=512, outSizeChange=False, numClasses=1):
        super(ResNet, self).__init__()
        # Sequence length reaching the linear head: each flagged transition halves
        # (down) or doubles (up) it. Each flag is counted once, so this assumes the
        # repeated blocks (which use the index-0 flags) do not resample.
        outSize = dataLen*(2**(-sum(downsampling)+sum(upsampling))) if useUpsampling else dataLen*(2**(-sum(downsampling)-sum(upsampling)))
        # Kept so forward() reshapes inputs to the configured length.
        self.dataLen = dataLen
        self.preLayer = nn.Conv1d(1, firstChannels[0], kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # Flattens (batch, C, L) to (batch, C * L) for the linear head.
        self.fc = nn.Flatten()

        self.linear1 = nn.Linear(int(outSize), numClasses) if outSizeChange else nn.Linear(int((secondChannels[-1])*outSize), numClasses)

        # Build a separate ResBlock per repeat. `[block] * n` would put the same
        # instance in the list n times, silently sharing its weights.
        self.firstModule = nn.ModuleList([ResBlock(firstChannels[0], firstChannels[0], downSampling=downsampling[0])
                                          for _ in range(firstBlockRepeats[0])])
        for i in range(1, len(firstChannels)):
            self.firstModule.append(ResBlock(firstChannels[i-1], firstChannels[i], downSampling=downsampling[i]))
            for j in range(firstBlockRepeats[i]):
                self.firstModule.append(ResBlock(firstChannels[i], firstChannels[i], downSampling=downsampling[0]))

        assert firstChannels[-1] == secondChannels[0], 'the last value of firstChannels should equal to the first value of secondChannels'

        self.secondModule = nn.ModuleList([])
        if useUpsampling:
            for i in range(1, len(secondChannels)):
                self.secondModule.append(ResBlock(secondChannels[i-1], secondChannels[i], upSampling=upsampling[i]))
                for j in range(secBlockRepeats[i]):
                    self.secondModule.append(ResBlock(secondChannels[i], secondChannels[i], upSampling=upsampling[0]))

        else:
            for i in range(1, len(secondChannels)):
                self.secondModule.append(ResBlock(secondChannels[i-1], secondChannels[i], downSampling=upsampling[i]))
                # NOTE(review): the repeated blocks in this downsampling branch use
                # upSampling=upsampling[0]; the other branches use the matching flag.
                # This is harmless while upsampling[0] is False; confirm the intent.
                for j in range(secBlockRepeats[i]):
                    self.secondModule.append(ResBlock(secondChannels[i], secondChannels[i], upSampling=upsampling[0]))

        if len(secondChannels)==1:self.secondModule.append(ResBlock(secondChannels[0], secondChannels[0]))

        if outSizeChange: self.secondModule.append(nn.Conv1d(secondChannels[-1], 1, kernel_size=1, stride=1, padding=0))

    def forward(self, inPut1, inPut2):
        """Return the raw scores for the pair (inPut1, inPut2).

        Each input is reshaped to (batch, 1, dataLen), so it must contain a
        multiple of ``dataLen`` elements. The result has shape (batch, numClasses).
        """
        data1 = inPut1.reshape(-1, 1, self.dataLen)
        data2 = inPut2.reshape(-1, 1, self.dataLen)

        out1 = self.relu(self.preLayer(data1))
        out2 = self.relu(self.preLayer(data2))

        # Shared-weight encoder: the same layers are applied to both inputs.
        for layer in self.firstModule:
            out1 = layer(out1)
            out2 = layer(out2)

        out = out1 - out2
        for layer in self.secondModule:
            out = layer(out)

        out = self.fc(out)
        return self.linear1(out)

In [2]:
import torch
# Inference runs on the CPU (the checkpoint is also loaded with map_location='cpu').
device = torch.device('cpu')

# Architecture hyper-parameters. They must describe the same network that is
# stored in the checkpoint loaded in the next cell.
netConfig = dict(
    firstChannels=[32, 48, 72],
    secondChannels=[72, 48, 32],
    firstBlockRepeats=[1, 1, 1],
    secBlockRepeats=[1, 1, 1],
    downsampling=[False, True, True],
    upsampling=[False, True, True],
    useUpsampling=True,
    outSizeChange=True,
)
myNet = ResNet(**netConfig)

In [3]:
from scipy.signal import decimate, resample
from scipy import interpolate
import numpy as np
import os

# Restore the trained weights. torch.load unpickles the file, so only load
# checkpoints from a trusted source.
modelDict = torch.load('ResNet_Max.pth', map_location='cpu')
myNet.load_state_dict(modelDict['modelState'])

# Preprocessing lengths: FFT-resample length and the network input length.
reSam, norDataLen = 3072, 512
# Integer grid 201, 202, ..., 3500 (3300 points) onto which each raw spectrum
# is interpolated. `np.int` was removed in NumPy 1.24; it was an alias of the
# builtin `int`, which is used here instead.
xNew = np.linspace(201, 3500, num=3300, endpoint=True, dtype=int)

basePATH = 'C:\\Users\\CHENTIEJUN\\Desktop\\dataset'

In [4]:
# basePATH is defined in the configuration cell above.
dataInfoPATH = os.path.join(basePATH, 'dataInfo20.txt')
dataLibraryPATH = os.path.join(basePATH, 'dataLibrary20.txt')
dataPath = os.path.join(basePATH, 'exData')


def preprocess_spectrum(rawData, xGrid, resampleLen, targetLen):
    """Turn a raw two-row spectrum into the model's min-max-scaled, fixed-length input.

    Steps: step-interpolate (kind='next') onto ``xGrid``, FFT-resample to
    ``resampleLen`` points, trim trailing zeros, IIR-decimate by
    ``len // targetLen``, then scale to [0, 1].

    Args:
        rawData: array whose row 0 holds the x values and row 1 the y values.
        xGrid: x positions to interpolate onto.
        resampleLen: number of points after resampling.
        targetLen: required length of the returned array.

    Returns:
        1-D array of length ``targetLen`` with values in [0, 1].

    Raises:
        ValueError: if too few samples remain after trimming, if decimation does
            not give exactly ``targetLen`` samples, or if the spectrum is constant.
    """
    # kind='next': each grid point takes the value of the next measured point.
    stepFunc = interpolate.interp1d(rawData[0], rawData[1], kind='next', fill_value='extrapolate')
    dataC = stepFunc(xGrid)

    dataX = np.trim_zeros(resample(dataC, resampleLen), trim='b')
    factor = len(dataX) // targetLen
    if factor < 1:
        raise ValueError(f"only {len(dataX)} samples left after trimming; need at least {targetLen}")

    z = decimate(dataX, factor, ftype='iir')
    if len(z) != targetLen:
        raise ValueError(f"decimation by {factor} gave {len(z)} samples instead of {targetLen}")

    dataMin, dataMax = z.min(), z.max()
    if dataMax == dataMin:
        raise ValueError("spectrum is constant and cannot be min-max normalised")
    return (z - dataMin) / (dataMax - dataMin)


# One library name per non-blank line. Read with plain Python: the NumPy >= 1.23
# loadtxt rejects delimiter='\n', and a one-line file would become a 0-d array.
with open(dataInfoPATH, encoding='utf-8') as infoFile:
    dataInfo = [line.strip() for line in infoFile if line.strip()]
# Row i is the reference spectrum paired with dataInfo[i].
dataLibrary = np.atleast_2d(np.loadtxt(dataLibraryPATH))
assert len(dataInfo) == len(dataLibrary), "dataInfo and dataLibrary must have one entry per library spectrum"

libraryTensor = torch.tensor(dataLibrary, dtype=torch.float32, device=device)

myNet.eval()
with torch.no_grad():
    # sorted() makes the report order independent of the filesystem.
    for fileName in sorted(os.listdir(dataPath)):
        filePATH = os.path.join(dataPath, fileName)
        if not os.path.isfile(filePATH):
            continue

        # comments=None disables comment stripping (comments='' is rejected by newer NumPy).
        rawData = np.loadtxt(filePATH, comments=None, unpack=True)
        try:
            z = preprocess_spectrum(rawData, xNew, reSam, norDataLen)
        except ValueError as err:
            print(f"Skipping {fileName}: {err}")
            continue

        # Score the test spectrum against every library entry in a single batch.
        testTensor = torch.tensor(z, dtype=torch.float32, device=device).expand_as(libraryTensor).contiguous()
        # One logit per library entry (the model is built with numClasses=1).
        logits = myNet(testTensor, libraryTensor).reshape(-1)
        # Rank on logits rather than sigmoid scores: float32 sigmoid saturates at
        # exactly 1.0, which would make the best match an arbitrary tie-break.
        best = int(torch.argmax(logits))
        bestScore = torch.sigmoid(logits[best]).item()

        print(f"Library: {dataInfo[best]}, test data: {fileName}, score: {bestScore}")

Library: 13DNB_01.txt, test data: 13DNB_01.txt, score: 1.0
Library: 13DNB_01.txt, test data: 13DNB_02.txt, score: 0.9999922513961792
Library: 13DNB_01.txt, test data: 13DNB_03.txt, score: 0.9999994039535522
Library: 13DNB_01.txt, test data: 13DNB_04.txt, score: 0.9999996423721313
Library: 13DNB_01.txt, test data: 13DNB_05.txt, score: 0.9999997615814209
Library: 13DNB_01.txt, test data: 13DNB_06.txt, score: 0.9999996423721313
Library: 13DNB_01.txt, test data: 13DNB_07.txt, score: 0.9999997615814209
Library: 13DNB_01.txt, test data: 13DNB_08.txt, score: 0.9999815225601196
Library: 13DNB_01.txt, test data: 13DNB_09.txt, score: 0.9999988079071045
Library: 13DNB_01.txt, test data: 13DNB_10.txt, score: 0.9999996423721313
Library: 13DNB_01.txt, test data: 13DNB_11.txt, score: 0.9999992847442627
Library: 13DNB_01.txt, test data: 13DNB_12.txt, score: 0.999997615814209
Library: 13DNB_01.txt, test data: 13DNB_13.txt, score: 0.9999984502792358
Library: 13DNB_01.txt, test data: 13DNB_14.txt, score:

Library: 2ADNT_01.txt, test data: 2ADNT_20.txt, score: 0.9998064637184143
Library: 2ADNT_01.txt, test data: 2ADNT_21.txt, score: 0.9999673366546631
Library: 2ADNT_01.txt, test data: 2ADNT_22.txt, score: 0.999995231628418
Library: 2ADNT_01.txt, test data: 2ADNT_23.txt, score: 0.9999974966049194
Library: 2ADNT_01.txt, test data: 2ADNT_24.txt, score: 0.99998939037323
Library: 2ADNT_01.txt, test data: 2ADNT_25.txt, score: 0.9999972581863403
Library: 2ADNT_01.txt, test data: 2ADNT_26.txt, score: 0.9999990463256836
Library: 2ADNT_01.txt, test data: 2ADNT_27.txt, score: 0.9999974966049194
Library: 2ADNT_01.txt, test data: 2ADNT_28.txt, score: 0.999993085861206
Library: 2ADNT_01.txt, test data: 2ADNT_29.txt, score: 0.999997615814209
Library: 2ADNT_01.txt, test data: 2ADNT_30.txt, score: 0.9954791069030762
Library: 2ADNT_01.txt, test data: 2ADNT_31.txt, score: 0.9992383718490601
Library: 2ADNT_01.txt, test data: 2ADNT_32.txt, score: 0.9998263716697693
Library: 2ADNT_01.txt, test data: 2ADNT_33.

Library: 4ADNT_02.txt, test data: 4ADNT_43.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_44.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_45.txt, score: 0.9999998807907104
Library: 4ADNT_02.txt, test data: 4ADNT_46.txt, score: 0.9999998807907104
Library: 4ADNT_02.txt, test data: 4ADNT_47.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_48.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_49.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_50.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_51.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_52.txt, score: 0.9999998807907104
Library: 4ADNT_02.txt, test data: 4ADNT_53.txt, score: 1.0
Library: 4ADNT_02.txt, test data: 4ADNT_54.txt, score: 1.0
Library: ADN_02.txt, test data: ADN_02.txt, score: 1.0
Library: ADN_02.txt, test data: ADN_03.txt, score: 1.0
Library: ADN_02.txt, test data: ADN_04.txt, score: 1.0
Library: ADN_02.txt, test data: ADN_05.txt, score: 1.0
Library: ADN_02.txt, test d

Library: AP_01.txt, test data: AP_39.txt, score: 1.0
Library: AP_01.txt, test data: AP_40.txt, score: 1.0
Library: AP_01.txt, test data: AP_41.txt, score: 1.0
Library: AP_01.txt, test data: AP_42.txt, score: 1.0
Library: AP_01.txt, test data: AP_43.txt, score: 1.0
Library: AP_01.txt, test data: AP_44.txt, score: 1.0
Library: AP_01.txt, test data: AP_45.txt, score: 1.0
Library: AP_01.txt, test data: AP_47.txt, score: 1.0
Library: AP_01.txt, test data: AP_48.txt, score: 1.0
Library: AP_01.txt, test data: AP_49.txt, score: 1.0
Library: AP_01.txt, test data: AP_50.txt, score: 1.0
Library: AP_01.txt, test data: AP_51.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_01.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_02.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_03.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_04.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_05.txt, score: 1.0
Library: DMDNB_01.txt, test data: DMDNB_06.txt, score: 1.0
Library: D

Library: HMX_01.txt, test data: HMX_26.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_27.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_28.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_29.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_30.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_31.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_32.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_33.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_34.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_35.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_36.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_37.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_38.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_39.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_40.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_41.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_42.txt, score: 1.0
Library: HMX_01.txt, test data: HMX_43.txt, score: 1.0
Library: H

Library: NTO_01.txt, test data: NTO_25.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_26.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_27.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_28.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_29.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_30.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_31.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_32.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_33.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_34.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_35.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_36.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_37.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_38.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_39.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_40.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_41.txt, score: 1.0
Library: NTO_01.txt, test data: NTO_42.txt, score: 1.0
Library: N

Library: RDX_01.txt, test data: RDX_24.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_25.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_26.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_27.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_28.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_29.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_30.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_31.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_32.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_33.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_34.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_35.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_36.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_37.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_38.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_39.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_40.txt, score: 1.0
Library: RDX_01.txt, test data: RDX_41.txt, score: 1.0
Library: R

Library: TMETN_01.txt, test data: TMETN_44.txt, score: 0.25772297382354736
Library: TMETN_01.txt, test data: TMETN_45.txt, score: 0.9595293998718262
Library: TMETN_01.txt, test data: TMETN_46.txt, score: 0.9565207958221436
Library: TMETN_01.txt, test data: TMETN_47.txt, score: 0.27618521451950073
Library: TMETN_01.txt, test data: TMETN_48.txt, score: 0.9633244276046753
Library: TMETN_01.txt, test data: TMETN_49.txt, score: 0.4507202208042145
Library: TMETN_01.txt, test data: TMETN_50.txt, score: 0.41415169835090637
Library: TNT_01.txt, test data: TNT_01.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_02.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_03.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_04.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_05.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_06.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_07.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_08.txt, score: 1.0
Library: TNT_01.txt, test data: TNT_09.