In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
#export
###############################  part-1 : import lib ###########################
import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
from   torch import optim
import torch.nn as nn
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid
import csv
import re
from torch.autograd import Variable
from fastai.basics import *

# Absolute-error tolerance read by `accuracy` / `errsmean` (module global).
acc_threshold = 1 / 22
# Input channels of the first Conv3d, and number of regression targets
# (label columns read from label.csv and width of the final Linear layer).
G_inChans, G_outChans = 1, 3

#export
##############################  part-1 : dataset  ###########################
def get_file_path(root_path,file_list,dir_list):
    """Recursively collect paths under ``root_path``.

    Directories are appended to ``dir_list`` (before descending into them) and
    every non-directory entry to ``file_list``; both lists are mutated in place.
    Entries come in ``os.listdir`` order, depth first.
    """
    for entry in os.listdir(root_path):
        full_path = os.path.join(root_path, entry)
        if not os.path.isdir(full_path):
            file_list.append(full_path)
            continue
        dir_list.append(full_path)
        get_file_path(full_path, file_list, dir_list)
            
class jax_Dataset(Dataset):
    """Dataset of ``.npy`` arrays with float labels listed in ``label.csv`` files.

    Every sub-directory found (recursively) under ``root_dir`` must contain a
    ``label.csv``. After one header line, each row is
    ``<path>, <label_0>, ..., <label_{G_outChans-1}>[, ...]``; only the first
    ``G_outChans`` label columns are used. ``<path>`` is joined to ``root_dir``
    (not to the sub-directory that holds the CSV).

    Args:
        root_dir: directory tree to scan.
        transform: optional callable applied to the loaded numpy array.
    """
    def __init__(self, root_dir, transform=None):
        super(jax_Dataset,self).__init__()
        self.root_dir  = root_dir
        self.transform = transform
        self.size      = 0
        self.csv_list      = []
        self.npyfile_list  = []
        self.label_list    = []
        self.all_file_list = []
        self.dir_list      = []
        get_file_path(self.root_dir,self.all_file_list,self.dir_list)
        self.csv_list = [d + '/label.csv' for d in self.dir_list]
        for csv_path in self.csv_list:
            # newline='' is the mode the csv module expects for files it parses.
            with open(csv_path, 'r', newline='') as csvfile:
                next(csvfile, None)  # skip the header line; tolerate an empty file
                for row in csv.reader(csvfile):
                    if not row:  # blank line in the CSV
                        continue
                    d_labels_float = [float(item) for item in row[1:G_outChans+1]]
                    self.npyfile_list.append(root_dir+'/'+row[0])
                    # float64 tensor; the training loop casts with .float()
                    self.label_list.append(torch.from_numpy(np.array(d_labels_float)))
                    self.size += 1

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        """Return ``(array, label)`` for sample ``idx``.

        Raises:
            IndexError: if ``idx`` is outside ``[-size, size)``.
            FileNotFoundError: if the referenced ``.npy`` file does not exist.
        """
        # Raising (instead of printing and returning None) keeps the sequence
        # protocol intact and avoids an opaque failure later in collate.
        if not -self.size <= idx < self.size:
            raise IndexError("dataset out of index: %s (size %d)" % (idx, self.size))
        mat_path = self.npyfile_list[idx]
        if not os.path.isfile(mat_path):
            raise FileNotFoundError(mat_path + ' does not exist!')
        tensor_6d = np.load(mat_path)
        label_6d  = self.label_list[idx]
        if self.transform:
            tensor_6d = self.transform(tensor_6d)
        return tensor_6d,label_6d

def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Build ``(train_loader, valid_loader)``.

    The training loader shuffles with batch size ``bs``; the validation loader
    keeps order and uses ``2 * bs``. Extra ``kwargs`` go to BOTH loaders.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2, **kwargs)
    return train_loader, valid_loader

class DataBunch():
    """Container pairing a training and a validation DataLoader.

    ``c`` is stored unchanged; nothing in this notebook reads it.
    """
    def __init__(self, train_dl, valid_dl, c=None):
        self.train_dl = train_dl
        self.valid_dl = valid_dl
        self.c = c

    @property
    def train_ds(self):
        """Dataset behind the training loader."""
        return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset behind the validation loader."""
        return self.valid_dl.dataset
    
#export
##############################  part-2 : model  ###########################

def accuracy(out,yb,threshold=None):
    """Fraction of elements whose absolute error ``|out - yb|`` is within a tolerance.

    Args:
        out: prediction tensor.
        yb: target tensor, broadcast-compatible with ``out``.
        threshold: tolerance. When ``None``, the module-level ``acc_threshold``
            is read at call time, so reassigning that global still takes effect.

    Returns:
        0-d float tensor in ``[0, 1]``.
    """
    if threshold is None:
        threshold = acc_threshold
    within = (torch.abs(out-yb) <= threshold).float()
    return within.mean()
def errsmean(out,yb,threshold=None):
    """Mean absolute error over only the elements whose error exceeds a tolerance.

    Args:
        out: prediction tensor.
        yb: target tensor, broadcast-compatible with ``out``.
        threshold: tolerance. When ``None``, the module-level ``acc_threshold``
            is read at call time.

    Returns:
        0-d tensor. If no element exceeds the tolerance, returns 0 instead of
        the NaN that 0/0 would produce (NaN would poison averages across batches).
    """
    if threshold is None:
        threshold = acc_threshold
    err = torch.abs(out-yb)
    mask = (err > threshold).float()
    n_large = mask.sum()
    if n_large.item() == 0:
        return err.new_zeros(())
    return (mask*err).sum() / n_large


def flatten(x):
    """Reshape ``x`` to ``(batch, -1)``, keeping the first dimension (a view, no copy)."""
    batch = x.size(0)
    return x.view(batch, -1)
def linearMap(x):
    """Min-max scale ``x`` into ``[0, 1]``.

    For a constant tensor (max == min) returns zeros with ``x``'s shape and
    device. Floating dtypes are kept; integer input gets the default float dtype,
    matching what the division branch would produce. The previous
    ``torch.zeros(x.shape)`` always returned float32 on the CPU.
    """
    xmax = torch.max(x)
    xmin = torch.min(x)
    span = xmax - xmin
    if span != 0:
        return (x-xmin)/span
    dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
    return torch.zeros(x.shape, dtype=dtype, device=x.device)

class Lambda(nn.Module):
    """Wrap a plain function as an ``nn.Module`` so it can sit in ``nn.Sequential``."""
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x):
        """Apply the wrapped function to ``x``."""
        fn = self.func
        return fn(x)
    
def get_model_layers_3d(inChans = G_inChans,outChans = G_outChans):
    """Return the layer list for the 3-D CNN regressor (wrap it in ``nn.Sequential``).

    Four Conv3d(k=3, no padding) -> BatchNorm3d -> ReLU stages widen the channels
    64 -> 128 -> 256 -> 512 and trim 2 voxels per spatial axis each. The
    ``Linear(512, ...)`` head after flattening therefore needs a 1x1x1 map at that
    point, i.e. a 9x9x9 input volume. The MLP head maps 512 -> 256 -> 64 -> 16 -> outChans.
    TODO (author's note): add max-pooling.
    """
    layers = []
    prev_width = inChans
    # Modules are created in the same order as before, so random init is unchanged.
    for width in (64, 128, 256, 512):
        layers.append(nn.Sequential(
            nn.Conv3d(in_channels=prev_width, out_channels=width, kernel_size=3, stride=1,
                      padding=0, groups=1, bias=True, padding_mode='zeros'),
            nn.BatchNorm3d(width),
            nn.ReLU(),
        ))
        prev_width = width
    layers.append(Lambda(flatten))
    for fan_in, fan_out in ((512, 256), (256, 64), (64, 16)):
        layers.append(nn.Linear(fan_in, fan_out))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(16, outChans))
    return layers

##############################  part-2 : model-init ###########################
def model_init(model):
    """Kaiming-normal re-initialise the first module's weight in each ``nn.Sequential`` child.

    In this notebook those are the Conv3d layers of the conv stages; Linear layers
    and other non-Sequential children keep their default initialisation.
    """
    conv_stages = (layer for layer in model if isinstance(layer, nn.Sequential))
    for stage in conv_stages:
        nn.init.kaiming_normal_(stage[0].weight)
            
#export
from typing import *
def listify(o):
    """Coerce ``o`` to a list.

    ``None`` -> ``[]``; a list is returned as-is (same object); a string is wrapped
    rather than split into characters; other iterables are materialised;
    anything else becomes ``[o]``.
    """
    if o is None:
        return []
    if isinstance(o, list):
        return o
    if not isinstance(o, str) and isinstance(o, Iterable):
        return list(o)
    return [o]
#export
##########################  Optimizer ########################
def compose(x, funcs, *args, order_key='_order', **kwargs):
    """Thread ``x`` through ``funcs`` sorted by their ``order_key`` attribute (default 0).

    Each function is called as ``f(x, **kwargs)`` and its result feeds the next.
    ``*args`` is accepted but ignored. The sort is stable, so functions with
    equal order run in the order listed.
    """
    def order_of(fn):
        return getattr(fn, order_key, 0)

    result = x
    for fn in sorted(listify(funcs), key=order_of):
        result = fn(result, **kwargs)
    return result
  
def sgd_step(p, lr, **kwargs):
    """Plain SGD update, in place: ``p <- p - lr * grad``. Returns ``p`` for ``compose``."""
    # Keyword `alpha=` form; the positional add_(Number, Tensor) overload is deprecated.
    p.data.add_(p.grad.data, alpha=-lr)
    return p

#weight decay#
def weight_decay(p, lr, wd, **kwargs):
    """Decay the weights directly, in place: ``p <- p * (1 - lr * wd)``. Returns ``p``."""
    shrink = 1 - lr*wd
    p.data.mul_(shrink)
    return p
# Default hyper-parameter picked up by Optimizer via get_defaults/maybe_update.
weight_decay._defaults = dict(wd=0.)

def l2_reg(p, lr, wd, **kwargs):
    """L2 regularisation via the gradient, in place: ``grad <- grad + wd * p``. Returns ``p``."""
    # Keyword `alpha=` form; the positional add_(Number, Tensor) overload is deprecated.
    p.grad.data.add_(p.data, alpha=wd)
    return p
# Default hyper-parameter picked up by Optimizer via get_defaults/maybe_update.
l2_reg._defaults = dict(wd=0.)

def maybe_update(os, dest, f):
    """For each item in ``os``, copy entries of the dict ``f(item)`` into ``dest`` without overwriting.

    Earlier items win over later ones, and keys already in ``dest`` are kept.
    (The parameter name ``os`` shadows the module inside this function only.)
    """
    for source in os:
        for key, value in f(source).items():
            dest.setdefault(key, value)

def get_defaults(d):
    """Return ``d._defaults`` if the object defines it, else a new empty dict."""
    try:
        return d._defaults
    except AttributeError:
        return {}

class Optimizer():
    """Minimal optimizer that applies a list of stepper functions to each parameter.

    Args:
        params: an iterable of parameters, or a list of parameter lists (one list
            per param group).
        steppers: function(s) ``f(p, **hyper) -> p``, run in order of their
            ``_order`` attribute (default 0) via ``compose``.
        **defaults: hyper-parameters (e.g. ``lr``) copied into every group; keys
            missing here are filled from each stepper's ``_defaults``.
    """
    def __init__(self, params, steppers, **defaults):
        self.steppers = listify(steppers)
        maybe_update(self.steppers, defaults, get_defaults)
        self.param_groups = list(params)
        # A flat parameter list becomes a single group. An empty iterable becomes
        # one empty group instead of failing on [0] with IndexError.
        if not self.param_groups or not isinstance(self.param_groups[0], list):
            self.param_groups = [self.param_groups]
        # Separate dict per group so per-group hyper-parameters can diverge.
        self.hypers = [{**defaults} for _ in self.param_groups]

    def grad_params(self):
        """Return ``(param, group_hypers)`` pairs for every parameter that has a gradient."""
        return [(p,hyper) for pg,hyper in zip(self.param_groups,self.hypers)
            for p in pg if p.grad is not None]

    def zero_grad(self):
        """Detach and zero the gradient of every parameter that has one."""
        for p,hyper in self.grad_params():
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        """Run the steppers on each parameter with its group's hyper-parameters."""
        for p,hyper in self.grad_params(): compose(p, self.steppers, **hyper)
            
def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b): test(a,b,operator.eq,'==')
    
# Plain SGD optimizer factory. Both steppers have the default order 0, so the stable
# sort in compose runs weight_decay first, then sgd_step. Call as sgd_opt(params, lr=...).
sgd_opt = partial(
    Optimizer,
    steppers=[weight_decay, sgd_step],
)

In [4]:
#export
######################  Merge  Runner && Learner to a new :Learner ############
def param_getter(m):
    """Default Learner splitter: hand all of ``m``'s parameters to the optimizer."""
    params = m.parameters()
    return params

class Learner():
    """Training loop driving a model, an optimizer and a set of callbacks.

    Callbacks are notified via ``self(<event>)`` for the events in ``ALL_CBS``. A
    ``TrainEvalCallback`` is always installed first.

    Args:
        model: module to train.
        data: object with ``train_dl`` / ``valid_dl`` (e.g. ``DataBunch``).
        loss_func: ``loss_func(pred, target)`` -> loss tensor.
        cb_funcs: callable(s) that each return a callback instance.
        opt_func: ``opt_func(params, lr=...)`` -> optimizer (built lazily in ``fit``).
        lr: learning rate passed to ``opt_func``.
        splitter: maps the model to what ``opt_func`` receives as params.
        cbs: already-constructed callback instance(s).
    """
    def __init__(self, model, data, loss_func, cb_funcs, opt_func, lr, splitter=param_getter,cbs=None ):
        self.model,self.data,self.loss_func = model,data,loss_func
        self.opt_func,self.lr,self.splitter = opt_func,lr,splitter
        self.in_train,self.logger,self.opt  = False,print,None
        self.cbs = []
        self.add_cb(TrainEvalCallback())
        self.add_cbs(cbs)
        self.add_cbs(cbf() for cbf in listify(cb_funcs))

    def add_cbs(self, cbs):
        """Register each callback in ``cbs`` (``None`` adds nothing)."""
        for cb in listify(cbs): self.add_cb(cb)

    def add_cb(self, cb):
        """Bind ``cb`` to this learner, expose it as ``self.<cb.name>`` and register it."""
        cb.set_runner(self)
        setattr(self, cb.name, cb)
        self.cbs.append(cb)

    def remove_cbs(self, cbs):
        """Unregister callbacks (the ``self.<cb.name>`` attributes are left in place)."""
        for cb in listify(cbs): self.cbs.remove(cb)

    def one_batch(self, i, xb, yb):
        """Forward pass and loss; when training, also backward, step and zero_grad."""
        try:
            self.iter = i
            # Insert a channel axis at dim 1 before Conv3d.
            # NOTE(review): assumes loader batches are (N, D, H, W) -- confirm.
            xb = xb.unsqueeze(1)
            xb = xb.float()
            yb = yb.float()
            self.xb,self.yb = xb,yb;                        self('begin_batch')
            self.pred = self.model(self.xb);                self('after_pred')
            self.loss = self.loss_func(self.pred, self.yb);
            self('after_loss')
            if not self.in_train: return
            self.loss.backward();                           self('after_backward')
            self.opt.step();                                self('after_step')
            self.opt.zero_grad()
        except CancelBatchException:                        self('after_cancel_batch')
        finally:                                            self('after_batch')

    def all_batches(self):
        """Run ``one_batch`` over every batch of ``self.dl``."""
        self.iters = len(self.dl)
        try:
            for i,(xb,yb) in enumerate(self.dl): self.one_batch(i, xb, yb)
        except CancelEpochException: self('after_cancel_epoch')

    def do_begin_fit(self, epochs):
        self.epochs,self.loss = epochs,tensor(0.)
        self('begin_fit')

    def do_begin_epoch(self, epoch):
        self.epoch,self.dl = epoch,self.data.train_dl
        return self('begin_epoch')

    def fit(self, epochs, cbs=None, reset_opt=False):
        """Train for ``epochs`` epochs, validating after each one.

        The optimizer is created on the first call (or when ``reset_opt``) and
        reused afterwards. ``cbs`` are added for this call only.
        """
        self.add_cbs(cbs)
        if reset_opt or not self.opt: self.opt = self.opt_func(self.splitter(self.model), lr=self.lr)
        try:
            self.do_begin_fit(epochs)
            for epoch in range(epochs):
                # A callback returning True from begin_epoch skips the training pass.
                if not self.do_begin_epoch(epoch): self.all_batches()

                with torch.no_grad():
                    self.dl = self.data.valid_dl
                    # ...and from begin_validate, skips the validation pass.
                    if not self('begin_validate'): self.all_batches()
                self('after_epoch')

        except CancelTrainException: self('after_cancel_train')
        finally:
            self('after_fit')
            self.remove_cbs(cbs)

    ALL_CBS = {'begin_batch', 'after_pred', 'after_loss', 'after_backward', 'after_step',
        'after_cancel_batch', 'after_batch', 'after_cancel_epoch', 'begin_fit',
        'begin_epoch', 'begin_validate', 'after_epoch',
        'after_cancel_train', 'after_fit'}

    def __call__(self, cb_name):
        """Dispatch event ``cb_name`` to every callback, in ``_order`` order.

        Returns True if ANY callback returned True. The previous
        ``cb(cb_name) and res`` with ``res=False`` always returned False, so
        callbacks could never skip a pass.
        """
        res = False
        assert cb_name in self.ALL_CBS
        # cb(cb_name) is evaluated first, so every callback still runs.
        for cb in sorted(self.cbs, key=lambda x: x._order): res = cb(cb_name) or res
        return res

In [5]:
#export
import re
# Precompiled patterns for CamelCase -> snake_case:
# re1 splits before a capitalised word, re2 splits a lower/digit -> upper boundary.
_camel_re1 = re.compile('(.)([A-Z][a-z]+)')
_camel_re2 = re.compile('([a-z0-9])([A-Z])')
def camel2snake(name):
    """Convert ``CamelCase`` to ``snake_case`` (e.g. 'TrainEval' -> 'train_eval')."""
    partial_snake = _camel_re1.sub(r'\1_\2', name)
    return _camel_re2.sub(r'\1_\2', partial_snake).lower()

class Callback():
    """Base class for Learner callbacks.

    Subclasses define methods named after Learner events; ``__call__`` runs the
    matching one. Attributes not found on the callback itself are looked up on
    the runner set via ``set_runner``.
    """
    _order=0
    def set_runner(self, run): self.run=run
    def __getattr__(self, k):
        # __getattr__ runs only when normal lookup fails. Before set_runner() there
        # is no `run` attribute, and evaluating self.run here would re-enter
        # __getattr__('run') and recurse without end, so fail fast instead.
        if k == 'run':
            raise AttributeError(r"Jax said Clbk: object has no attribute '%s'" % k)
        try:
            return getattr(self.run, k)
        except AttributeError:
            # getattr signals a missing name with AttributeError (the old code caught KeyError).
            raise AttributeError(r"Jax said Clbk: object has no attribute '%s'" % k) from None
    @property
    def name(self):
        """snake_case class name without a trailing 'Callback' (used by Learner.add_cb)."""
        name = re.sub(r'Callback$', '', self.__class__.__name__)
        return camel2snake(name or 'callback')

    def __call__(self, cb_name):
        """Run the attribute named ``cb_name`` if present; return True only if it returns truthy."""
        f = getattr(self, cb_name, None)
        if f and f(): return True
        return False
    
class TrainEvalCallback(Callback):
    """Keeps the runner's epoch/iteration counters and train/eval mode in sync.

    ``n_epochs`` is a fractional epoch counter advanced once per training batch;
    ``n_iter`` counts training batches since ``begin_fit``.
    """
    _order = 3

    def begin_fit(self):
        runner = self.run
        runner.n_epochs = 0.
        runner.n_iter = 0

    def after_batch(self):
        # Only training batches advance the counters.
        if self.in_train:
            runner = self.run
            runner.n_epochs += 1. / self.iters
            runner.n_iter += 1

    def begin_epoch(self):
        self.run.n_epochs = self.epoch
        self.model.train()
        self.run.in_train = True

    def begin_validate(self):
        self.model.eval()
        self.run.in_train = False

class CancelTrainException(Exception):
    """Raised by a callback to abort the whole ``Learner.fit`` call."""

class CancelEpochException(Exception):
    """Raised by a callback to abort the current pass over a DataLoader."""

class CancelBatchException(Exception):
    """Raised by a callback to skip the rest of the current batch."""

import time
from fastprogress.fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import format_time

class AvgStats():
    """Running totals of the loss and metrics for one phase, weighted by batch size.

    Args:
        metrics: metric function(s) ``m(pred, yb)`` returning a scalar.
        in_train: True for the training phase, False for validation (used in ``repr``).
    """
    def __init__(self, metrics, in_train): self.metrics,self.in_train = listify(metrics),in_train
    def reset(self):
        """Zero the totals (called at the start of every epoch)."""
        self.tot_loss,self.count = 0.,0
        self.tot_mets = [0.] * len(self.metrics)
    @property
    def all_stats(self):
        """Summed ``[loss, *metrics]`` as Python floats.

        ``float()`` accepts both the initial ``0.`` and accumulated 0-d tensors;
        the old ``.item()`` raised AttributeError when nothing had been accumulated.
        """
        return [float(self.tot_loss)] + [float(m) for m in self.tot_mets]
    @property
    def avg_stats(self):
        """Per-sample averages of ``all_stats``; all NaN if no sample was seen (no ZeroDivisionError)."""
        if not self.count:
            return [float('nan')] * (1 + len(self.tot_mets))
        return [o/self.count for o in self.all_stats]

    def __repr__(self):
        if not self.count: return ""
        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"

    def accumulate(self, run):
        """Add ``run.loss`` and each metric, weighted by the batch size ``run.xb.shape[0]``."""
        bn = run.xb.shape[0]
        self.tot_loss += run.loss * bn
        self.count += bn
        for i,m in enumerate(self.metrics):
            self.tot_mets[i] += m(run.pred, run.yb) * bn

class AvgStatsCallback(Callback):
    """Collects train/valid AvgStats each epoch and logs a header plus one row per epoch.

    Row layout: epoch, train loss/metrics, valid loss/metrics, elapsed time, lr.
    """
    def __init__(self, metrics):
        self.train_stats = AvgStats(metrics, True)
        self.valid_stats = AvgStats(metrics, False)

    def begin_fit(self):
        met_names = ['loss', *(m.__name__ for m in self.train_stats.metrics)]
        header = ['epoch']
        for phase in ('train', 'valid'):
            header.extend(f'{phase}_{n}' for n in met_names)
        header += ['time', 'lr']
        self.logger(header)

    def begin_epoch(self):
        self.train_stats.reset()
        self.valid_stats.reset()
        self.start_time = time.time()

    def after_loss(self):
        if self.in_train:
            stats = self.train_stats
        else:
            stats = self.valid_stats
        with torch.no_grad():
            stats.accumulate(self.run)

    def after_epoch(self):
        row = [str(self.epoch)]
        for phase_stats in (self.train_stats, self.valid_stats):
            row.extend(f'{v:.6f}' for v in phase_stats.avg_stats)
        row.append(format_time(time.time() - self.start_time))
        row.append(self.lr)
        self.logger(row)
        
class ProgressCallback(Callback):
    """Shows fastprogress bars (one master bar over epochs, one child bar per pass)
    and redirects the learner's logger into the master bar's table."""
    _order=-1

    def begin_fit(self):
        bar = master_bar(range(self.epochs))
        self.mbar = bar
        bar.on_iter_begin()
        self.run.logger = partial(bar.write, table=True)

    def after_fit(self):
        self.mbar.on_iter_end()

    def after_batch(self):
        # NOTE(review): reports the 0-based index of the batch just finished.
        self.pb.update(self.iter)

    def begin_epoch(self):
        self.set_pb()

    def begin_validate(self):
        self.set_pb()

    def set_pb(self):
        """Start a new child progress bar over the current DataLoader."""
        self.pb = progress_bar(self.dl, parent=self.mbar)
        self.mbar.update(self.epoch)

#export
################################## CUDA  #############################
class CudaCallback(Callback):
    """Moves the model to ``device`` at fit start and each batch's inputs/targets before the forward pass."""
    def __init__(self, device):
        self.device = device

    def begin_fit(self):
        self.model.to(self.device)

    def begin_batch(self):
        target = self.device
        self.run.xb = self.xb.to(target)
        self.run.yb = self.yb.to(target)

In [6]:
class StatefulOptimizer(Optimizer):
    """Optimizer whose steppers also receive per-parameter state maintained by ``stats``.

    On a parameter's first step each stat's ``init_state`` seeds the state dict.
    Every step, each stat's ``update`` refreshes it, then the steppers are
    composed with the state entries and group hyper-parameters as keyword arguments.
    ``stats``' ``_defaults`` are merged into the hyper-parameters too.
    """
    def __init__(self, params, steppers, stats=None, **defaults):
        self.stats = listify(stats)
        maybe_update(self.stats, defaults, get_defaults)
        super().__init__(params, steppers, **defaults)
        self.state = {}

    def step(self):
        for param, hyper in self.grad_params():
            if param not in self.state:
                self.state[param] = {}
                maybe_update(self.stats, self.state[param], lambda stat: stat.init_state(param))
            current = self.state[param]
            for stat in self.stats:
                current = stat.update(param, current, **hyper)
            compose(param, self.steppers, **current, **hyper)
            self.state[param] = current
            
class Stat():
    """Interface for a per-parameter statistic tracked by StatefulOptimizer."""
    _defaults = {}

    def init_state(self, p):
        """Return the initial state entries for parameter ``p``."""
        raise NotImplementedError

    def update(self, p, state, **kwargs):
        """Update and return ``state`` for parameter ``p``."""
        raise NotImplementedError

def momentum_step(p, lr, grad_avg, **kwargs):
    """Momentum SGD update, in place: ``p <- p - lr * grad_avg``. Returns ``p``."""
    # Keyword `alpha=` form; the positional add_(Number, Tensor) overload is deprecated.
    p.data.add_(grad_avg, alpha=-lr)
    return p

def lin_comb(v1, v2, beta):
    """Linear interpolation ``beta * v1 + (1 - beta) * v2``."""
    weight_other = 1 - beta
    return beta * v1 + weight_other * v2


class AverageGrad(Stat):
    """Exponential moving average of gradients, kept in ``state['grad_avg']``.

    ``grad_avg <- mom * grad_avg + mom_damp * grad`` with ``mom_damp = 1 - mom``
    when ``dampening`` else ``1``. ``mom_damp`` is also stored in the state.
    """
    _defaults = dict(mom=0.9)

    def __init__(self, dampening:bool=False): self.dampening=dampening
    def init_state(self, p): return {'grad_avg': torch.zeros_like(p.grad.data)}
    def update(self, p, state, mom, **kwargs):
        state['mom_damp'] = 1-mom if self.dampening else 1.
        # Keyword `alpha=` form; the positional add_(Number, Tensor) overload is deprecated.
        state['grad_avg'].mul_(mom).add_(p.grad.data, alpha=state['mom_damp'])
        return state
    
class AverageSqrGrad(Stat):
    """Exponential moving average of squared gradients, kept in ``state['sqr_avg']``.

    ``sqr_avg <- sqr_mom * sqr_avg + sqr_damp * grad**2`` with
    ``sqr_damp = 1 - sqr_mom`` when ``dampening`` (the default) else ``1``.
    """
    _defaults = dict(sqr_mom=0.99)

    def __init__(self, dampening:bool=True): self.dampening=dampening
    def init_state(self, p): return {'sqr_avg': torch.zeros_like(p.grad.data)}
    def update(self, p, state, sqr_mom, **kwargs):
        state['sqr_damp'] = 1-sqr_mom if self.dampening else 1.
        # Keyword `value=` form; the positional addcmul_(Number, Tensor, Tensor) overload is deprecated.
        state['sqr_avg'].mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=state['sqr_damp'])
        return state

class StepCount(Stat):
    """Per-parameter step counter in ``state['step']`` (1 after the first update)."""
    def init_state(self, p):
        return dict(step=0)

    def update(self, p, state, **kwargs):
        state['step'] = state['step'] + 1
        return state
    
def debias(mom, damp, step):
    """Bias-correction factor for a zero-initialised EMA after ``step`` updates.

    Equals ``damp * (1 + mom + ... + mom**(step-1))`` = ``damp * (1 - mom**step) / (1 - mom)``.
    For ``mom == 1`` the closed form would divide by zero; the sum is then just
    ``step``.
    """
    if mom == 1:
        return damp * step
    return damp * (1 - mom**step) / (1-mom)

def adam_step(p, lr, mom, mom_damp, step, sqr_mom, sqr_damp, grad_avg, sqr_avg, eps, **kwargs):
    """Adam update, in place, using bias-corrected first and second moment averages.

    ``p <- p - lr/debias1 * grad_avg / (sqrt(sqr_avg/debias2) + eps)``. Returns ``p``.
    """
    debias1 = debias(mom,     mom_damp, step)
    debias2 = debias(sqr_mom, sqr_damp, step)
    denom = (sqr_avg/debias2).sqrt() + eps
    # Keyword `value=` form; the positional addcdiv_(Number, Tensor, Tensor) overload is deprecated.
    p.data.addcdiv_(grad_avg, denom, value=-lr / debias1)
    return p
adam_step._defaults = dict(eps=1e-5)

def adam_opt(xtra_step=None, **kwargs):
    """Return a factory ``f(params, **hypers)`` that builds an Adam-style StatefulOptimizer.

    Steppers: ``adam_step``, ``weight_decay`` and any ``xtra_step``. Stats: dampened
    gradient average, squared-gradient average and a step counter. ``kwargs`` are
    forwarded to StatefulOptimizer.
    """
    steppers = [adam_step, weight_decay] + listify(xtra_step)
    stats = [AverageGrad(dampening=True), AverageSqrGrad(), StepCount()]
    return partial(StatefulOptimizer, steppers=steppers, stats=stats, **kwargs)

In [7]:
def jax_get_learner(data, layers, loss_func, cb_funcs, opt_func,lr ):
    """Wrap ``layers`` in ``nn.Sequential`` and build a ``Learner`` around it."""
    net = nn.Sequential(*layers)
    return Learner(net, data, loss_func, cb_funcs, opt_func, lr)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bs = 64
train_dataset = jax_Dataset(root_dir='./train')
train_split_ds, valid_split_ds = torch.utils.data.random_split(train_dataset,
                    [round(train_dataset.size*0.8), train_dataset.size-round(train_dataset.size*0.8)])
train_dl,valid_dl = get_dls(train_split_ds, valid_split_ds, bs,num_workers=4,drop_last=True)
data = DataBunch(*get_dls(train_split_ds, valid_split_ds, bs))
cb_funcs = [ProgressCallback,partial(AvgStatsCallback,[accuracy]),partial(CudaCallback,device)]
adam_opt_func = adam_opt()
learn = jax_get_learner(data=data,layers=get_model_layers_3d(),loss_func=nn.MSELoss(),cb_funcs= cb_funcs, opt_func=adam_opt_func, lr=0.001)
model_init(learn.model)

%time learn.fit(2)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time,lr
0,0.016902,0.572451,0.007822,0.699093,1:56:28,0.001
1,0.0052,0.773527,0.008704,0.719957,05:06,0.001


CPU times: user 52min 47s, sys: 3min 23s, total: 56min 11s
Wall time: 2h 1min 35s


In [11]:
# Continue training the same Learner for 2 more epochs. The optimizer (and its state)
# is reused because fit only rebuilds it when reset_opt=True or none exists yet.
%time learn.fit(2)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time,lr
0,0.003195,0.844249,0.009021,0.773055,05:09,0.001
1,0.002268,0.885854,0.003086,0.861743,05:10,0.001


CPU times: user 21min 7s, sys: 1min 11s, total: 22min 18s
Wall time: 10min 20s


In [19]:
# Train 40 more epochs.
%time learn.fit(40)
# Save: pickles the whole nn.Module (not just its state_dict).
torch.save(learn.model, './0922.pkl')

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time,lr
0,0.017486,0.572059,0.011356,0.66003,05:35,0.001
1,0.005283,0.773031,0.003534,0.830089,05:04,0.001
2,0.003189,0.844881,0.002661,0.872505,04:56,0.001
3,0.002335,0.883372,0.001756,0.913533,05:06,0.001
4,0.001761,0.912328,0.001624,0.926338,05:11,0.001
5,0.001382,0.933256,0.001269,0.941084,05:10,0.001
6,0.001152,0.947459,0.001066,0.95363,05:13,0.001
7,0.000949,0.959971,0.001149,0.949057,05:10,0.001
8,0.000829,0.968012,0.000917,0.963374,05:09,0.001
9,0.000763,0.972704,0.000825,0.967539,05:04,0.001


CPU times: user 7h 6min 40s, sys: 21min 31s, total: 7h 28min 12s
Wall time: 3h 27min 20s


In [22]:
######################################### testing ####################################################
# Test-set accuracy: fraction of outputs within acc_threshold of the label.
device = 'cuda'
bs = 64
acc_threshold = 1/22   # accuracy() reads this module global
test_dataset = jax_Dataset(root_dir='./test')
test_dl = DataLoader( test_dataset, batch_size=bs, shuffle=False)
learn.in_train = False
# in_train does not touch the module; eval() makes BatchNorm use its running stats.
learn.model.eval()
accTmp = 0
accTot = 0
accCnt = 0
accMean = 0
with torch.no_grad():  # inference only: don't build autograd graphs
    for i,(xb,yb) in enumerate(test_dl):
        xb = xb.unsqueeze(1)
        xb = (xb.to(device)).float()
        yb = (yb.to(device)).float()
        yh = learn.model(xb)
        accTmp = accuracy(yh,yb)
        # Weight by batch size so a short final batch doesn't skew the mean.
        accTot += accTmp * xb.shape[0]
        accCnt += xb.shape[0]

accMean = accTot/accCnt
accMean

tensor(0.9311, device='cuda:0')

In [23]:
######################################### testing ####################################################
# Test-set accuracy at the looser tolerance 3/44.
device = 'cuda'
bs = 64
acc_threshold = 3/44   # accuracy() reads this module global
test_dataset = jax_Dataset(root_dir='./test')
test_dl = DataLoader( test_dataset, batch_size=bs, shuffle=False)
learn.in_train = False
# in_train does not touch the module; eval() makes BatchNorm use its running stats.
learn.model.eval()
accTmp = 0
accTot = 0
accCnt = 0
accMean = 0

with torch.no_grad():  # inference only: don't build autograd graphs
    for i,(xb,yb) in enumerate(test_dl):
        xb = xb.unsqueeze(1)
        xb = (xb.to(device)).float()
        yb = (yb.to(device)).float()
        yh = learn.model(xb)
        accTmp = accuracy(yh,yb)
        # Weight by batch size so a short final batch doesn't skew the mean.
        accTot += accTmp * xb.shape[0]
        accCnt += xb.shape[0]

accMean = accTot/accCnt
accMean

tensor(0.9769, device='cuda:0')

In [24]:
# Train 20 more epochs.
%time learn.fit(20)
# Save: pickles the whole nn.Module (not just its state_dict).
torch.save(learn.model, './0922-2.pkl')

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time,lr
0,0.000258,0.998887,0.000388,0.991784,05:13,0.001
1,0.000254,0.998901,0.000375,0.99265,05:15,0.001
2,0.00025,0.999012,0.000388,0.992194,05:20,0.001
3,0.000248,0.99901,0.000372,0.992495,05:22,0.001
4,0.000246,0.999072,0.000346,0.993563,06:04,0.001
5,0.00024,0.999171,0.000376,0.993159,06:04,0.001
6,0.000237,0.999162,0.000352,0.993658,05:27,0.001
7,0.000233,0.999271,0.000357,0.993282,05:24,0.001
8,0.000231,0.999279,0.000345,0.993684,05:17,0.001
9,0.000228,0.999301,0.000338,0.993836,05:13,0.001


CPU times: user 3h 44min 28s, sys: 11min 35s, total: 3h 56min 4s
Wall time: 1h 49min 10s


In [25]:
######################################### testing ####################################################
# Test-set accuracy at tolerance 3/44 after the extra 20 epochs.
device = 'cuda'
bs = 64
acc_threshold = 3/44   # accuracy() reads this module global
test_dataset = jax_Dataset(root_dir='./test')
test_dl = DataLoader( test_dataset, batch_size=bs, shuffle=False)
learn.in_train = False
# in_train does not touch the module; eval() makes BatchNorm use its running stats.
learn.model.eval()
accTmp = 0
accTot = 0
accCnt = 0
accMean = 0

with torch.no_grad():  # inference only: don't build autograd graphs
    for i,(xb,yb) in enumerate(test_dl):
        xb = xb.unsqueeze(1)
        xb = (xb.to(device)).float()
        yb = (yb.to(device)).float()
        yh = learn.model(xb)
        accTmp = accuracy(yh,yb)
        # Weight by batch size so a short final batch doesn't skew the mean.
        accTot += accTmp * xb.shape[0]
        accCnt += xb.shape[0]

accMean = accTot/accCnt
accMean

tensor(0.9813, device='cuda:0')

In [7]:
## for test$$
def paramRMSE(yh,yb):
    """Root-mean-square error over all elements: ``sqrt(mean((yh - yb)**2))``.

    Averaging over elements (not summing) keeps the value comparable across
    batch sizes and output widths. The sum version is the L2 norm of the error,
    not an RMSE.
    """
    err = yh - yb
    return torch.sqrt(torch.mean(err * err))
def paramAccu(yh,yb,acc_threshold):
    """Fraction of elements of ``yh`` within ``acc_threshold`` (absolute) of ``yb``.

    Same computation as ``accuracy``, but the tolerance is passed explicitly.
    """
    hits = (yh - yb).abs().le(acc_threshold)
    return hits.float().mean()

In [8]:
############## testing ###########
## Per-sample inference on ./test; one CSV row per sample goes to ./test/rslt.csv.
fpSave_rslt_csv = './test/rslt.csv'
# Row layout (no header is written):
#   pred[0..2], label[0..2], acc@3/44, rmse, acc@1/22, rmse, forward_seconds
######################################### testing ####################################################
device = 'cuda'
acc_threshold = 3/44
cb_funcs = [ProgressCallback,partial(AvgStatsCallback,[accuracy]),partial(CudaCallback,device)]
adam_opt_func = adam_opt()

test_bs = 1  # rows below index sample 0 only, so the batch size must stay 1
test_dataset = jax_Dataset(root_dir='./test')
test_dl = DataLoader( test_dataset, batch_size=test_bs, shuffle=False)
learn = jax_get_learner(data=None,layers=get_model_layers_3d(),loss_func=nn.MSELoss(),cb_funcs= cb_funcs, opt_func=adam_opt_func, lr=0.001)
learn.in_train = False
# The checkpoint is a whole pickled nn.Module (saved with torch.save(learn.model, ...)).
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which refuses
# full-module pickles; add weights_only=False there (trusted files only).
learn.model = torch.load('./0922-2.pkl', map_location=device)
# in_train does not touch the module; eval() makes BatchNorm use its running stats.
learn.model.eval()
accTmp = 0
accTot = 0
accCnt = 0
accMean = 0
testTime = 0
paramacc1 = 0
paramrmse1 = 0
paramacc2 = 0
paramrmse2 = 0
# CUDA kernels launch asynchronously; synchronize around the forward pass so
# testTime measures the computation, not just the launch.
sync_cuda = str(device).startswith('cuda') and torch.cuda.is_available()

def _to_np(t):
    """Copy a tensor to host memory as a numpy array for the CSV row."""
    return t.detach().to('cpu').numpy()

# `with` guarantees the CSV is closed even if inference fails part-way.
with open(fpSave_rslt_csv,'w',newline='',encoding='UTF-8') as f, torch.no_grad():
    csv_writer = csv.writer(f)
    for i,(xb,yb) in enumerate(test_dl):
        xb = xb.unsqueeze(1)
        xb = (xb.to(device)).float()
        yb = (yb.to(device)).float()

        if sync_cuda: torch.cuda.synchronize()
        Tstart = time.perf_counter()
        yh = learn.model(xb)
        if sync_cuda: torch.cuda.synchronize()
        Tend = time.perf_counter()
        testTime = Tend-Tstart

        acc_threshold = 3/44
        paramacc1  = paramAccu(yh,yb,acc_threshold)
        acc_threshold = 1/22
        paramacc2  = paramAccu(yh,yb,acc_threshold)
        # RMSE does not depend on the threshold: compute once, keep both CSV columns.
        paramrmse1 = paramrmse2 = paramRMSE(yh,yb)
        csv_writer.writerow([_to_np(yh[0,0]), _to_np(yh[0,1]), _to_np(yh[0,2]),
                             _to_np(yb[0,0]), _to_np(yb[0,1]), _to_np(yb[0,2]),
                             _to_np(paramacc1), _to_np(paramrmse1),
                             _to_np(paramacc2), _to_np(paramrmse2),
                             testTime])
        accTot += paramacc1
        accCnt += 1

accMean = accTot/accCnt
accMean