In [None]:
pip install -r requirements.txt

In [None]:
#wget https://github.com/libsndfile/libsndfile/releases/download/1.1.0/libsndfile-1.1.0.tar.xz
import urllib.request
url = 'https://github.com/libsndfile/libsndfile/releases/download/1.1.0/libsndfile-1.1.0.tar.xz'
filename = 'libsndfile-1.1.0.tar.xz'
urllib.request.urlretrieve(url, filename)
!tar -xf libsndfile-1.1.0.tar.xz
! (cd ./libsndfile-1.1.0/ && ./configure && make && make install)
 

In [6]:
import torch
import torchaudio
import torchvision

# Record the installed library versions next to the notebook outputs.
for _lib in (torch, torchaudio, torchvision):
    print(_lib.__version__)


1.12.1+cu113
0.12.1+cu113
0.13.1+cu113


In [None]:
!python -v

# Prepare data for the Training process

In [None]:
# Generate the training data with run.py in "train" mode.
# NOTE(review): -w 0.7 / -s 0.08 presumably are the slice window and stride in seconds
# (the same values later passed to scan_file_for_beep); -ct is not explained here — confirm in run.py.
%run run.py -m train -ct -w 0.7 -s 0.08

# Preparing data for validating the model

In [None]:
# Generate the validation data with run.py in "val" mode, using the same flags as the training run.
# NOTE(review): flag semantics (-ct, -w, -s) live in run.py — confirm there.
%run run.py -m val -ct -w 0.7 -s 0.08

# Test the data generation process for a single file

In [None]:
# Run the slice-generation step on one file (./testwav/nextel.wav).
# NOTE(review): -od presumably is the output directory, -sr the target sample rate and -w/-s the
# window/stride in seconds — confirm in run-singlefile.py.
!python run-singlefile.py -f ./testwav/nextel.wav -od ./outputsinglewav -sr 16000 -ct -w 0.7 -s 0.08

In [12]:
import torch
import torchaudio
from torch import nn
from torch.utils.data import DataLoader
from torchvision import models
import torchvision.transforms as T
from torch.utils.data import Dataset
import pandas as pd
import audioprocess as AP
import torch.nn as nn
import torchvision.models as models
import tensor_util as TU
import queue
import copy
from datetime import datetime
import torchaudio.functional as F
import os




class BeepDetectDataset(Dataset):
    """Dataset of audio slices labelled beep / no-beep, returned as spectrogram images.

    Rows of the annotations CSV are read positionally: column 1 holds the path
    of a wav slice and column 2 its label. A truthy label is mapped to class 0
    and a falsy one to class 1 (elsewhere in this notebook class 0 is reported
    as "beep in window").

    Args:
        annotations_file: path of the annotations CSV.
        transformation: callable turning a waveform tensor into a 2-D
            spectrogram-like tensor (e.g. ``transform_spectra``).
        target_sample_rate: sample rate every slice is resampled to.
        plot: if True, plot the waveform and spectrogram of every item.
        debug: if True, print per-item diagnostics.
    """

    def __init__(self, annotations_file, transformation, target_sample_rate, plot, debug):
        pd.options.display.max_seq_items = 2000
        self.annotations = pd.read_csv(annotations_file)
        self.transformation = transformation
        self.target_sample_rate = target_sample_rate
        self.plot = plot
        self.debug = debug
        # Min/max of the most recent waveform (s*) and spectrogram (f*) built
        # by __getitem__; kept for debugging and also returned per item.
        self.smax = 0
        self.fmax = 0
        self.smin = 0
        self.fmin = 0

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        """Load, preprocess and normalise one slice.

        Returns:
            A 9-tuple ``(signal, preprocessedsignal, label, audio_sample_path,
            sr, fmin, fmax, smin, smax)``. ``signal`` is the spectrogram
            repeated to 3 channels and normalised with the ImageNet mean/std.
            ``preprocessedsignal`` is the single-channel spectrogram scaled to
            [0, 1]. ``sr`` is the sample rate of the file on disk, before
            resampling.
        """
        audio_sample_path = self._get_audio_sample_path(index)
        label = self._get_audio_sample_label(index)
        signal, sr = torchaudio.load(audio_sample_path)
        # Resample from the file's own rate (sr) to the target rate.
        signal = self._resample_if_necessary(signal, sr)
        signal = self._mix_down_if_necessary(signal)

        self.smax = torch.max(signal)
        self.smin = torch.min(signal)

        if self.debug:
            print(f'audio_sample_path {audio_sample_path} label {label}' )
            print (f' S Max {self.smax} S Min {self.smin}')

        if self.plot:
            AP.plot_waveform(signal, self.target_sample_rate)

        # Near-silent slice: zero out the low-level noise entirely.
        if self.smax < 0.01 and self.smin > -0.01:
            signal = signal.clone()
            signal[torch.logical_and(signal >= -0.01, signal <= 0.01)] = 0

        signal = self.transformation(signal)

        self.fmax = torch.max(signal)
        self.fmin = torch.min(signal)
        # torch.round returns a new tensor, so the in-place scaling below
        # cannot alias the transformation's output.
        signal = torch.round(signal, decimals=1)

        # Min-max scale to [0, 1]. A constant spectrogram is left unscaled to
        # avoid dividing by zero.
        if signal.min() != signal.max():
            signal -= signal.min()
            signal /= signal.max()

        if self.debug:
            print (f' F Max {self.fmax} F Min {self.fmin}')
        if self.plot:
            print(f'SHAPE SIGNAL {signal.shape}')
            AP.plot_spectrogram(signal)
        preprocessedsignal = signal
        # Replicate the single channel to 3 channels and apply the standard
        # ImageNet mean/std expected by the torchvision ResNet backbones.
        signal = signal.repeat(3, 1, 1)
        signal = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(signal)

        return signal, preprocessedsignal, label, audio_sample_path, sr, self.fmin, self.fmax, self.smin, self.smax

    def _get_audio_sample_path(self, index):
        """Return the wav path stored in column 1 of row ``index``."""
        return self.annotations.iloc[index, 1]

    def _get_audio_sample_label(self, index):
        """Map the CSV label in column 2 to a class index: truthy -> 0, falsy -> 1."""
        return 0 if self.annotations.iloc[index, 2] else 1

    def _resample_if_necessary(self, signal, sr):
        """Resample ``signal`` from ``sr`` to the dataset's target rate; no-op if equal."""
        if sr != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
            signal = resampler(signal)
        return signal

    def _mix_down_if_necessary(self, signal):
        """Average multi-channel audio (dim 0) into a single channel.

        Without this, a stereo file would produce a 2-channel spectrogram,
        6 channels after ``repeat(3, 1, 1)``, and the 3-channel Normalize
        would fail. Mono input is returned unchanged.
        """
        if signal.shape[0] > 1:
            signal = torch.mean(signal, dim=0, keepdim=True)
        return signal

def create_data_loader(train_data, batch_size, shaf=True):
    """Wrap ``train_data`` in a DataLoader with the given batch size; ``shaf`` toggles shuffling."""
    return DataLoader(train_data, shuffle=shaf, batch_size=batch_size)


def train_single_epoch(model, data_loader, loss_fn, optimiser, device, ep, debug=False, evalu=False):
    """Run one pass over ``data_loader`` and count per-class hits.

    In training mode (``evalu=False``) the optimiser is stepped after every
    batch. In evaluation mode only the forward pass runs, with autograd
    disabled, and ``optimiser`` may be None.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        data_loader: yields the 9-tuples produced by ``BeepDetectDataset``.
        loss_fn: loss applied to ``(logits, target)``.
        optimiser: optimiser for ``model``; unused when ``evalu`` is True.
        device: device inputs and targets are moved to.
        ep: epoch number, used only for progress output.
        debug: currently unused; kept for call compatibility.
        evalu: if True, evaluate instead of train.

    Returns:
        ``(n, n0, n1, nsp0, nsp1)``: samples seen, samples with target 0 / 1,
        and correctly predicted samples of class 0 / 1.
    """
    total = len(data_loader.dataset)
    seen = 0
    n0 = n1 = nsp0 = nsp1 = 0
    loss = None

    # No autograd graph is needed when only evaluating; this saves memory.
    with torch.set_grad_enabled(not evalu):
        for inputs, preprocessedsignal, target, audio_sample_path, sr, fmin, fmax, smin, smax in data_loader:
            inputs, target = inputs.to(device), target.to(device)
            seen += target.size(dim=0)
            per = round(seen / total * 100)

            prediction = model(inputs)
            _, predicted = torch.max(prediction, 1)
            loss = loss_fn(prediction, target)

            # Vectorised per-class totals and correct predictions.
            is0 = target == 0
            is1 = target == 1
            n0 += int(is0.sum().item())
            n1 += int(is1.sum().item())
            nsp0 += int((is0 & (predicted == 0)).sum().item())
            nsp1 += int((is1 & (predicted == 1)).sum().item())

            # Backpropagate and update weights only when training.
            if not evalu:
                optimiser.zero_grad()
                loss.backward()
                optimiser.step()
            print (f'\r v1 epoch {ep} - {per}% loss {loss} -- fmin {torch.min(fmin)} fmax {torch.max(fmax)} smin {torch.min(smin)} smax {torch.max(smax)}', end="")

    # An empty loader never computes a loss; don't reference it then.
    if loss is not None:
        print(f"\rloss: {loss.item()}")
    return seen, n0, n1, nsp0, nsp1


def _build_checkpoint(epoch, model, optimiser, loss, epochloss, n0, n1, nsp0, nsp1, n):
    # Assemble the dict written by save_ckp and read back by load_ckp.
    return {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimiser.state_dict(),
        'loss': loss,
        'epochloss': epochloss,
        'n0': n0,
        'n1': n1,
        'nsp0': nsp0,
        'nsp1': nsp1,
        'n': n,
    }


def train(model, data_loader, loss_fn, optimiser, device, epochs, debug=False, evalu=False, useckp=True):
    """Train (or, with ``evalu=True``, evaluate) ``model`` over several epochs.

    Epochs are numbered from 1 and run up to and including ``epochs``. When
    ``useckp`` is True and ``./checkpoint/checkpoint.pt`` exists, model and
    optimiser state are restored from it. Training then resumes at the epoch
    after the stored one.

    The per-epoch "loss" tracked here is the classification error rate
    ``misclassified / (n0 + n1)`` computed from train_single_epoch's
    counters, not the value of ``loss_fn``.

    In training mode the following files are written:
    - the best weights so far (lowest error rate) to ``STD_MODEL_WEIGHTS_BEST``,
      together with a numbered checkpoint;
    - the current weights to ``STD_MODEL_WEIGHTS`` every 5th epoch;
    - a rolling ``./checkpoint/checkpoint.pt`` after every epoch.

    In evaluation mode nothing is written, so ``optimiser`` may be None.
    """
    n = n0 = n1 = nsp0 = nsp1 = 0
    dtstart = datetime.now()
    lowestloss = 100000
    epochloss = []
    start_epoch = 1
    print ("Starting Training")

    if useckp:
        if os.path.isfile("./checkpoint/checkpoint.pt"):
            model, optimiser, start_epoch, lowestloss, epochloss, n0, n1, nsp0, nsp1, n = load_ckp(model, optimiser)
            print(f"Loading checkpoint start epoch {start_epoch}, lowestloss {lowestloss}, n0 {n0}, n1 {n1}, nsp0 {nsp0}, nsp1 {nsp1},n {n}")
            start_epoch = start_epoch + 1
        else:
            print ("NO CHECK POINT FILE ..SKIPPING")

    if evalu:
        print('Evaluation mode')
        model.eval()
    else:
        model.train()
        # torch.save does not create missing directories.
        os.makedirs('./checkpoint', exist_ok=True)

    # Inclusive upper bound: epochs=1 (as used by validate_model) must run
    # exactly one pass; range(start_epoch, epochs) would run none.
    for i in range(start_epoch, epochs + 1):
        print(f"Epoch {i}")
        dtepochstart = datetime.now()

        n, n0, n1, nsp0, nsp1 = train_single_epoch(model, data_loader, loss_fn, optimiser, device, i, debug, evalu)
        print("---------------------------")
        counted = n0 + n1
        # An epoch without any 0/1 samples has no defined error rate. NaN never
        # compares lower than lowestloss, so it cannot be saved as "best".
        overalloss = ((n0 - nsp0) + (n1 - nsp1)) / counted if counted else float('nan')
        dtepochstop = datetime.now()
        epochloss.append(overalloss)
        print ("epoch lossess trend:",epochloss)

        if not evalu and overalloss < lowestloss:
            print (f'Overall loss {overalloss} < lowestloss {lowestloss} saving {STD_MODEL_WEIGHTS_BEST}')
            torch.save(model.state_dict(), STD_MODEL_WEIGHTS_BEST)
            save_ckp('./checkpoint/checkpoint_' + str(i) + '_' + str(overalloss) + '.pt',
                     _build_checkpoint(i, model, optimiser, overalloss, epochloss, n0, n1, nsp0, nsp1, n))
            lowestloss = overalloss
        print (f'Overall loss {overalloss} n0 {n0} n1 {n1} nsp0 {nsp0} nsp1 {nsp1} n {n} time {dtepochstop-dtepochstart}')

        if evalu:
            # Evaluation must not overwrite training checkpoints. The
            # optimiser is also None when called from validate_model.
            continue

        if i % 5 == 0:
            torch.save(model.state_dict(), STD_MODEL_WEIGHTS)
            print(f"Net saved at {STD_MODEL_WEIGHTS}")

        print("Saving ./checkpoint/checkpoint.pt")
        save_ckp('./checkpoint/checkpoint.pt',
                 _build_checkpoint(i, model, optimiser, lowestloss, epochloss, n0, n1, nsp0, nsp1, n))

    dtstop = datetime.now()
    print(f'Finished training in {dtstop-dtstart}')

def validate_model(model, data_loader, loss_fn, device, debug):
    """Evaluate ``model`` with one pass over ``data_loader`` (no weight updates).

    Checkpoint loading is disabled. An existing ./checkpoint/checkpoint.pt
    would overwrite the weights the caller just loaded, and load_ckp would
    try to restore optimiser state into the None optimiser passed here.
    """
    train(model, data_loader, loss_fn, None, device, 1, debug, True, useckp=False)
        
def predict(model, input, expected):
    """Classify ``input`` with gradients disabled.

    Returns:
        ``(predicted, expected, predictedsoft)``: the arg-max class index per
        row of the logits, ``expected`` passed through unchanged, and the
        softmax of the logits over dim 1.
    """
    with torch.no_grad():
        logits = model(input)
        predictedsoft = torch.softmax(logits, 1)
        predicted = torch.max(logits, 1)[1]
    return predicted, expected, predictedsoft


def countvaluesinwin(slwin):
    """Return how many entries of the sliding window equal 0.

    Args:
        slwin: a ``queue.Queue`` (its backing deque is read without removing
            anything) or any plain iterable of values.
    """
    # Read the Queue's backing deque directly instead of deep-copying it into
    # a second Queue and draining that. The window is left untouched either way.
    items = slwin.queue if isinstance(slwin, queue.Queue) else slwin
    return sum(1 for value in items if value == 0)

def scanfileforbeep(values, window=5, threshold=3):
    """Return True if any sliding window over ``values`` holds more than ``threshold`` zeros.

    The window starts pre-filled with ``window`` ones (no beep). Each value is
    pushed in turn, dropping the oldest. 0 is the "beep in slice" class. The
    defaults reproduce the original 5-slot window with more than 3 zeros.

    Args:
        values: per-slice class predictions, in time order.
        window: number of most recent predictions considered.
        threshold: a beep is detected when the zero count exceeds this.
    """
    import collections

    slidingwin = collections.deque([1] * window, maxlen=window)
    for va in values:
        slidingwin.append(va)
        if sum(1 for v in slidingwin if v == 0) > threshold:
            return True
    return False

def resample_if_necessary(signal, target_sample_rate, sr):
    """Resample ``signal`` from ``sr`` to ``target_sample_rate``; return it untouched when the rates match."""
    if sr == target_sample_rate:
        return signal
    return torchaudio.transforms.Resample(sr, target_sample_rate)(signal)

def mix_down_if_necessary(signal):
    """Average all channels (dim 0) into one; single-channel input is returned as is."""
    return signal if signal.shape[0] <= 1 else signal.mean(dim=0, keepdim=True)

def getfinalsignal(wavetensor, samplerate, transformation, debug, plot):
    """Convert one waveform slice into a batch of one model-ready spectrogram image.

    Pipeline:
    1. mix down to mono;
    2. zero near-silent samples;
    3. ``AP.preprocess_wav``;
    4. ``transformation`` (spectrogram);
    5. round to 1 decimal;
    6. min-max scale to [0, 1];
    7. repeat to 3 channels;
    8. ImageNet mean/std normalisation;
    9. add a leading batch dimension.

    Args:
        wavetensor: waveform slice, presumably (channels, samples) since the
            mix-down averages dim 0 — confirm. It must already be at
            ``samplerate``: the resample call below passes ``samplerate`` as
            both source and target, so it is a no-op.
        samplerate: sample rate of ``wavetensor``.
        transformation: callable producing a 2-D spectrogram-like tensor
            (e.g. ``transform_spectra``).
        debug: print min/max diagnostics.
        plot: plot the intermediate waveform, pitch track and spectrogram.

    Returns:
        The normalised 3-channel tensor with a leading batch dimension of 1.
    """
    signal = resample_if_necessary(wavetensor, samplerate, samplerate)
    signal = mix_down_if_necessary(signal)

    smax = torch.max(signal)
    smin = torch.min(signal)

    if debug:
        print(f'samplerate {samplerate}' )
        print (f' S Max {smax} S Min {smin}')

    # Near-silent slice: zero out the low-level noise entirely.
    if smax < 0.01 and smin > -0.01:
        signal = signal.clone()
        signal[torch.logical_and(signal >= -0.01, signal <= 0.01)] = 0
    if plot:
        AP.plot_waveform(signal, samplerate)

    # NOTE(review): BeepDetectDataset.__getitem__ does not call AP.preprocess_wav,
    # so inference input is preprocessed differently from training input —
    # confirm this is intended.
    ns, wav = AP.preprocess_wav(signal, samplerate, 2)

    if plot:
        AP.plot_waveform(wav, samplerate)
        # The pitch track is only used for this diagnostic plot, so it is not
        # computed on the non-plotting path.
        pitch = F.detect_pitch_frequency(wav, samplerate)
        AP.plot_waveform(pitch, samplerate)

    signal = transformation(wav)

    fmax = torch.max(signal)
    fmin = torch.min(signal)
    # torch.round returns a new tensor, so the in-place scaling below cannot
    # alias the transformation's output.
    signal = torch.round(signal, decimals=1)

    # Min-max scale to [0, 1]. A constant spectrogram is left as is to avoid
    # dividing by zero.
    if signal.min() != signal.max():
        signal -= signal.min()
        signal /= signal.max()

    if debug:
        print (f' F Max {fmax} F Min {fmin}')
    if plot:
        print(f'SHAPE SIGNAL {signal.shape}')
        AP.plot_spectrogram(signal)

    signal = signal.repeat(3, 1, 1)
    signal = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(signal)
    return signal[None]

def preparedata(filename, samplerate, transformation, debug, plot):
    """Load an audio file, resample it to ``samplerate`` and convert it with getfinalsignal.

    Returns the batch-of-one model input produced by getfinalsignal.
    """
    signal, sr = torchaudio.load(filename)
    # getfinalsignal only receives the target rate and cannot resample, so
    # convert from the file's native rate here.
    signal = resample_if_necessary(signal, samplerate, sr)
    return getfinalsignal(signal, samplerate, transformation, debug, plot)

def doeswavhavebeep(outputdir, model=None):
    """Classify every slice listed in ``<outputdir>/wavlist.csv`` and report whether a beep was found.

    Args:
        outputdir: directory containing ``wavlist.csv``, whose ``slices``
            column lists the wav slices in time order.
        model: classifier to use; defaults to the notebook-global
            ``new_model_pred``.

    Returns:
        True if scanfileforbeep detects a beep in the per-slice predictions.
    """
    if model is None:
        model = new_model_pred
    # preparedata presumably returns a CPU tensor (torchaudio.load does), while
    # the model may live on CUDA, so move each input to the model's device.
    model_device = next(model.parameters()).device

    annotations = pd.read_csv(os.path.join(outputdir, "wavlist.csv"))
    li = []
    for file in annotations['slices'].tolist():
        res = preparedata(file, 16000, transform_spectra, False, False).to(model_device)
        predicted, expected, predictedsoft = predict(model, res, 0)
        predicted = int(predicted.int())
        print(f'file {file} predicted {predicted}')
        li.append(predicted)
    return scanfileforbeep(li)

def scan_file_for_beep(model, device, file, window, stride, targetsamplerate, debug1, debug, plot, retimmediately=False, transformation=None):
    """Slide a window over an audio file, classify each slice and locate the first beep.

    The file is resampled and mixed to mono at ``targetsamplerate``, then cut
    into slices of ``window`` seconds every ``stride`` seconds. Each slice is
    classified; class 0 means "beep in this slice". A beep is declared once
    more than 3 of the last 5 predictions are 0. Its time is the end of the
    current slice, ``window + stride * i`` seconds.

    Args:
        model: classifier on ``device``.
        device: device each slice is moved to before prediction.
        file: path of the audio file.
        window: slice length in seconds.
        stride: hop between slices in seconds.
        targetsamplerate: sample rate the audio is converted to.
        debug1: print per-slice decisions and total processing time.
        debug: forwarded to getfinalsignal.
        plot: forwarded to getfinalsignal.
        retimmediately: return as soon as the first beep is found. Otherwise
            scan the whole file and plot the waveform with the beep marker.
        transformation: spectrogram transform; defaults to the
            notebook-global ``transform_spectra``.

    Returns:
        ``(istherebeep, timeofbeep)``. ``timeofbeep`` is -1 when no beep was
        found, and ``istherebeep`` is True exactly when a beep time was set.
    """
    import collections

    if transformation is None:
        transformation = transform_spectra

    waveformpre, sample_rate, duration = AP.load_audio_from_file(file)
    waveformpre, sample_rate = AP.resample_wav_plus_mono(waveformpre, sample_rate, targetsamplerate)

    realshift = 0  # amount trimmed from the start by TU.sub_from_begin_2nd_tensor (none)
    waveform = TU.sub_from_begin_2nd_tensor(waveformpre, realshift)
    slices, number = AP.slice_the_audio_w_0_pad(waveform, sample_rate, window, stride)

    # Window length and hop in samples (loop-invariant).
    win = window * targetsamplerate
    st = stride * targetsamplerate

    # Last 5 predictions, pre-filled with 1 ("no beep").
    recent = collections.deque([1] * 5, maxlen=5)

    istherebeep = False
    timeofbeep = -1
    dtstart = datetime.now()

    for i in range(number):
        wavtensor = slices.select(1, i)
        signal = getfinalsignal(wavtensor, targetsamplerate, transformation, debug, plot).to(device)
        predicted, _, _ = predict(model, signal, 0)
        predicted = int(predicted.int())

        slice_end = (win + st * i) / targetsamplerate

        recent.append(predicted)
        if sum(1 for v in recent if v == 0) > 3:
            # Stays True once set. Previously it was reset to False right after
            # the first detection, so the result could disagree with timeofbeep.
            istherebeep = True
            if timeofbeep == -1:
                timeofbeep = slice_end
                if retimmediately:
                    return istherebeep, timeofbeep

        if debug1:
            print (f'slice {i} beepinwindow {predicted==0} beepfound {istherebeep} time {slice_end} timeofbeep {timeofbeep}' )

    dtstop = datetime.now()
    if not retimmediately:
        AP.plot_waveform_withline(waveformpre, sample_rate, timeofbeep, "beep")
    if debug1:
        print(f'Processing Start {dtstart} Stop {dtstop} duration {dtstop-dtstart}')
    return istherebeep, timeofbeep

# Rounding to N decimals is done with torch.round(x, decimals=N); no custom helper is needed.

def save_ckp(f_path, state):
    """Serialise the checkpoint dict ``state`` to ``f_path`` with torch.save.

    The parent directory is created if missing, so the first save into
    ``./checkpoint/`` does not fail on a fresh checkout.
    """
    parent = os.path.dirname(f_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(state, f_path)

def load_ckp(model, optimizer, f_path='./checkpoint/checkpoint.pt', map_location=None):
    """Restore model/optimiser state and training counters from a checkpoint.

    Args:
        model: module whose weights are overwritten in place.
        optimizer: optimiser whose state is restored. May be None (e.g. when
            evaluating); the stored optimiser state is then ignored.
        f_path: checkpoint file written by save_ckp.
        map_location: forwarded to torch.load, e.g. "cpu" to open a checkpoint
            saved on a GPU machine.

    Returns:
        ``(model, optimizer, epoch, loss, epochloss, n0, n1, nsp0, nsp1, n)``.
    """
    checkpoint = torch.load(f_path, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return (model, optimizer, checkpoint['epoch'], checkpoint['loss'], checkpoint['epochloss'],
            checkpoint['n0'], checkpoint['n1'], checkpoint['nsp0'], checkpoint['nsp1'], checkpoint['n'])


In [13]:
class Net(nn.Module):

    def __init__(self):
        super().__init__()
        # Use a pretrained model
        self.network = models.resnet152(pretrained=True)
        # Replace last layer
        #for param in self.network.parameters():
        #    param.requires_grad = False
        self.network.conv1=nn.Conv2d(1, self.network.conv1.out_channels, 
                      kernel_size=self.network.conv1.kernel_size[0], 
                      stride=self.network.conv1.stride[0], 
                      padding=self.network.conv1.padding[0])
        num_ftrs = self.network.fc.in_features
        self.network.fc = nn.Linear(num_ftrs, 2)
    def forward(self, xb):
        return self.network(xb)
    def freeze(self):
        for param in self.network.parameters():
            param.require_grad = False
        for param in self.network.fc.parameters():
            param.require_grad = True
    def unfreeze(self):
        for param in self.network.parameters():
            param.require_grad = True

class Net2(nn.Module):
    """ResNet-34 trained from scratch (``pretrained=False``) with a 2-class head."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet34(pretrained=False)
        # Swap the default classifier for a 2-way beep / no-beep head.
        backbone.fc = nn.Linear(backbone.fc.in_features, 2)
        self.network = backbone

    def forward(self, xb):
        """Return the backbone's logits for ``xb``."""
        return self.network(xb)
    
def _freeze_norm_stats(net):
    try:
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                #m.track_running_stats = False
                m.train()

    except ValueError:  
        print("errrrrrrrrrrrrrroooooooorrrrrrrrrrrr with instancenorm")
        return
    
def _set_batch_momentum(net,mom):
    try:
        for m in net.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.momentum = mom
                #m.train()

    except ValueError:  
        print("errrrrrrrrrrrrrroooooooorrrrrrrrrrrr with instancenorm")
        return

        

In [14]:
# Training / evaluation configuration.
BATCH_SIZE = 1
EPOCHS = 600
LEARNING_RATE = 0.01
ANNOTATIONS_FILE = "./outputpretrainwav/wavlist.csv"
VALIDATION_FILE = "./outputpretrainwav_val/wavlist.csv"
SAMPLE_RATE = 16000
size = 224  # side length of the square spectrogram image fed to the network
STD_MODEL_WEIGHTS = "feedforwardnet-temp.pth"
STD_MODEL_WEIGHTS_BEST = "feedforwardnet-best.pth"

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}")

# Mel spectrogram restricted to the 200-4500 Hz band.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    win_length=2048,
    hop_length=256,
    n_mels=128,
    f_min=200,
    f_max=4500,
    normalized=False,
)

# Waveform -> mel spectrogram -> fixed size x size image.
transform_spectra = T.Compose([mel_spectrogram, T.Resize((size, size))])
 


Using cuda


# Train the model

In [None]:
# Build the training set/loader and a fresh ResNet-34 classifier.
usd = BeepDetectDataset(ANNOTATIONS_FILE, transform_spectra, SAMPLE_RATE, False, False)
train_dataloader = create_data_loader(usd, BATCH_SIZE)

model_ft = Net2().to(device)
_set_batch_momentum(model_ft, 0.9)

# Loss function + optimiser.
loss_fn = nn.CrossEntropyLoss()
optimiser = torch.optim.Adam(model_ft.parameters(), lr=LEARNING_RATE)

# Resumes from ./checkpoint/checkpoint.pt when that file exists.
train(model_ft, train_dataloader, loss_fn, optimiser, device, EPOCHS)

# Save the final weights; report the file actually written.
torch.save(model_ft.state_dict(), STD_MODEL_WEIGHTS)
print(f"Trained feed forward net saved at {STD_MODEL_WEIGHTS}")



Starting Training
NO CHECK POINT FILE ..SKIPPING
Epoch 1
[2024-01-21 11:28:25.837 pytorch-1-12-gpu-py-ml-g4dn-xlarge-fb2c43dabd38fe5c0df5e678313a:67 INFO utils.py:28] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2024-01-21 11:28:26.020 pytorch-1-12-gpu-py-ml-g4dn-xlarge-fb2c43dabd38fe5c0df5e678313a:67 INFO profiler_config_parser.py:111] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.




 v1 epoch 1 - 34% loss 0.000825898430775851 -- fmin 0.0008494672947563231 fmax 1146.7117919921875 smin -0.20026475191116333 smax 0.310861945152282730273

# Validate the model

In [None]:
# Evaluate the best saved weights on the validation set (one pass, no weight updates).
mod_w = STD_MODEL_WEIGHTS_BEST
# map_location lets weights saved on a GPU load on a CPU-only machine.
state_dict = torch.load(mod_w, map_location=device)
model_val = Net2().to(device)
model_val.load_state_dict(state_dict)

usdval = BeepDetectDataset(VALIDATION_FILE, transform_spectra, SAMPLE_RATE, False, False)
val_dataloader = create_data_loader(usdval, 1, shaf=False)

# Defined here so this cell does not depend on the training cell having run.
loss_fn = nn.CrossEntropyLoss()
validate_model(model_val, val_dataloader, loss_fn, device, False)


The following cell loads the best saved weights and scans each test file listed in `files` for a beep, comparing the predicted beep time with the expected one (`-1` means no beep is expected).

In [11]:
%matplotlib inline

model_weights =STD_MODEL_WEIGHTS_BEST

new_state_dict = torch.load(model_weights)

new_model_pred= Net2().to(device)
new_model_pred.load_state_dict(new_state_dict)

print("Model LOADING STATE")

files = {
        './testwav/machine-tape-035a.wav':'-1',
        './testwav/nextel.wav':'11.021',
        './testwav/sprint.wav':'20.690',
        './testwav/verizon.wav':'20.600'
         }
         

debug=False
single=False
new_model_pred.eval()
_freeze_norm_stats(new_model_pred)

print("Model READY")

count_faild=0
failed_files=[]

if single:
    isabeep, time = scan_file_for_beep(new_model_pred,device,File_name,0.7,0.08,16000,True,debug,debug,False)   
    print (f'File {File_name} isbeep {isabeep} predicted time {time}') 
else:
    tests = files.keys()
    for f in tests:
        File_name=f
        #isabeep, time = scan_file_for_beep(new_model_pred,File_name,1,0.15,16000,False,debug,debug,True)   
        isabeep, time = scan_file_for_beep(new_model_pred,device,File_name,0.7,0.08,16000,False,debug,False,True)   
        
        timediff=abs(float(time)-float(files[f]))
        
        if timediff > 0.08*4:
            count_faild=count_faild+1
            failed_files.append(File_name)
        print (f'File {File_name} isbeep {isabeep} predicted time {time} expected {files[f]}') 
    if count_faild > 0:
        print("Test run FAILED - n of TC failed & filenames",count_faild,failed_files)
    else:
        print("Test run SUCCEDED")




Model LOADING STATE
Model READY
File ./testwav/machine-tape-035a.wav isbeep False predicted time -1 expected -1
File ./testwav/nextel.wav isbeep True predicted time 11.34 expected 11.021
File ./testwav/sprint.wav isbeep True predicted time 20.94 expected 20.690
File ./testwav/verizon.wav isbeep True predicted time 20.86 expected 20.600
Test run SUCCEDED


In [None]:
# load_ckp is already defined in the helper-function cell, next to save_ckp.
# Redefining it here would silently shadow that definition, so it is not repeated.
