In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import torch
from torch import nn 
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [3]:
#device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

0) Prepare data

In [4]:
insurance = pd.read_csv('./insurance.csv')

In [5]:
X_array = pd.get_dummies(insurance.drop(['charges'], axis=1), drop_first=True, dtype=int).values
y_array = insurance.charges.values

In [6]:
X_train,X_test,y_train,y_test = train_test_split(X_array,y_array,test_size = 0.2,random_state = 0)
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train,y_train)
X_test = scaler.transform(X_test)

In [7]:
X_train_tensor = torch.tensor(X_train).float()
X_test_tensor = torch.tensor(X_test).float()
y_train_tensor = torch.tensor(y_train).float()
y_test_tensor = torch.tensor(y_test).float()

In [8]:
dataloader_train = DataLoader(dataset = TensorDataset(X_train_tensor,y_train_tensor),shuffle = True,batch_size = 10)
dataloader_test = DataLoader(dataset = TensorDataset(X_test_tensor,y_test_tensor),shuffle = True,batch_size = 10)

In [9]:
#for feature,label in dataloader_train:
#    print(feature,label)

1) Define model

In [10]:
class insurance_net(nn.Module):
    def __init__(self):
        super(insurance_net,self).__init__()
        self.hidden1 = nn.Linear(in_features = 8,out_features = 160)
        self.hidden2 = nn.Linear(in_features = 160,out_features = 80)
        self.hidden3 = nn.Linear(in_features = 80,out_features = 40)
        self.hidden4 = nn.Linear(in_features = 40,out_features = 20)
        self.hidden5 = nn.Linear(in_features = 20,out_features = 10)
        self.regression = nn.Linear(in_features = 10,out_features = 1)
    
    def forward(self,x):
        x = F.relu(self.hidden1(x))
        x = F.relu(self.hidden2(x))
        x = F.relu(self.hidden3(x))
        x = F.relu(self.hidden4(x))
        x = F.relu(self.hidden5(x))
        y = self.regression(x)
        return y

2) Loss and optimizer

In [11]:
from sklearn.metrics import r2_score
#Create my model
model = insurance_net()
#Choose MSELoss as the loss function
model.loss_func = nn.MSELoss()
#Choose Adam as the optimizer
model.optimizer = torch.optim.Adam(params = model.parameters(),lr = 0.1)
#Define r2 evaluation
def r2(y_pred,y_true):
    y_pred_cls = y_pred.data
    return r2_score(y_true,y_pred_cls)
# Choose r2 as score
model.metric_func = r2
# Set metric_name as r2
model.metric_name = 'r2'

3) Training function

In [12]:
def train(model,features,targets):
    # Switch to train mode
    model.train()
    # Reset gradient
    model.optimizer.zero_grad()
    # Forward pass and loss
    predictions = model(features)
    loss = model.loss_func(predictions,targets)
    # Evaluation metric calculation
    metric = model.metric_func(predictions,targets)
    # Backward pass
    loss.backward()
    # Updata
    model.optimizer.step()
    return loss.item(),metric.item()

4) Validating function

In [13]:
#Prediction mode without gradient
@torch.no_grad()
def valid(model,features,targets):
    #Switch to eval mode
    model.eval()
    # Forward pass and loss
    prediction = model(features)
    loss = model.loss_func(prediction,targets)
    # Evaluation metric calculation
    metric = model.metric_func(prediction,targets)
    return loss.item(),metric.item()

5) Training loop

In [14]:
def model_train(model,epochs,dataloader_train,dataloader_test,log_print_frequency):
    metric_name = model.metric_name
    # Record 'epoch', 'loss', etc as a dataframe
    metirc_history = pd.DataFrame(columns = ['epoch','loss',metric_name,'val_loss','val_'+metric_name])
    print('Training starts')
    for epoch in tqdm(range(1, epochs+1)):
        #Training loop
        # Loss
        loss_sum = 0.0
        # Evaluation metric
        metric_sum = 0.0
        step = 1
        for step,(features,targets) in enumerate(dataloader_train,1):
            loss,metric = train(model,features,targets)
            # Every 'log_step_freq' times print steps and loss, etc.
            loss_sum += loss
            metric_sum += metric
            if step % log_print_frequency == 0:
                print(f'[step = {step}] loss:{loss_sum/step},{metric_name}:{metric_sum/step}')
        
        # Evaluation loop
        val_loss_sum = 0.0
        val_metric_sum = 0.0
        val_step = 1
        for val_step, (features,targets) in enumerate(dataloader_test, 1):
            val_loss,val_metric = valid(model,features,targets)
            val_loss_sum += val_loss
            val_metric_sum += val_metric
        # Records
        info = (epoch, loss_sum/step, metric_sum/step, val_loss_sum/val_step, val_metric_sum/val_step)
        metirc_history.loc[epoch-1] = info
        # Print log
        print(f'epoch = {info[0]},loss = {info[1]},{metric_name} = {info[2]},val_loss = {info[3]},val_+{metric_name} = {info[4]}')
    print('Training done')
    return metirc_history

6) Training model

In [None]:
epochs = 800
# The frequency of prting log
log_step_freq = 10
history1 = model_train(model,epochs,dataloader_train,dataloader_test,log_step_freq)

Training starts


  0%|                                                                                          | 0/800 [00:00<?, ?it/s]

[step = 10] loss:247878840.8,r2:-0.6470218056703669
[step = 20] loss:210651302.0,r2:-0.31607034610133566
[step = 30] loss:194536907.06666666,r2:-0.1609323218892128
[step = 40] loss:189627257.5,r2:-0.19394927849262708
[step = 50] loss:198322856.4,r2:-0.27050073448405326
[step = 60] loss:187163379.6,r2:-0.28670681086870364
[step = 70] loss:190194461.14285713,r2:-0.25954480867393814
[step = 80] loss:185184502.75,r2:-0.23522895934886917
[step = 90] loss:184712407.84444445,r2:-0.5249806938698098
[step = 100] loss:174882394.0,r2:-0.4745227922481824


  0%|                                                                                  | 1/800 [00:00<03:04,  4.33it/s]

epoch = 1,loss = 172032104.7102804,r2 = -0.4447051109773044,val_loss = 167788468.2222222,val_+r2 = -0.03459773460855811
[step = 10] loss:194784852.4,r2:-0.2721542252693042
[step = 20] loss:168769747.5,r2:-0.19901295570743768
[step = 30] loss:149449405.06666666,r2:-0.19184195290589962
[step = 40] loss:147530234.2,r2:-0.3284215860886385
[step = 50] loss:151588658.0,r2:-0.28366812025230326
[step = 60] loss:151593011.73333332,r2:-0.2920279572057111
[step = 70] loss:165472908.2857143,r2:-0.33097184839662813
[step = 80] loss:166328243.1,r2:-0.29713084418117713
[step = 90] loss:161753200.24444443,r2:-0.2741738930732425
[step = 100] loss:161888982.74,r2:-0.24787788992458168


  0%|▏                                                                                 | 2/800 [00:00<03:08,  4.23it/s]

epoch = 2,loss = 159695983.34579438,r2 = -0.2264036238231761,val_loss = 190509108.5925926,val_+r2 = -0.23541218916247997
[step = 10] loss:150572427.6,r2:-0.8658903705367254
[step = 20] loss:145253293.9,r2:-0.5884766083923421
[step = 30] loss:164570403.8,r2:-0.5455585605220726
[step = 40] loss:160042771.75,r2:-0.41908180998435435
[step = 50] loss:154760387.42,r2:-0.3935873361124227
[step = 60] loss:155342831.51666668,r2:-0.3600838494885636
[step = 70] loss:151582674.41428572,r2:-0.3558340375638499
[step = 80] loss:155277676.8625,r2:-0.34343099882571343
[step = 90] loss:149101921.63333333,r2:-0.33389255478870483
[step = 100] loss:150770005.95,r2:-0.294276506896535


  0%|▎                                                                                 | 3/800 [00:00<03:15,  4.08it/s]

epoch = 3,loss = 156495281.2990654,r2 = -0.2828030446877348,val_loss = 167721226.96296296,val_+r2 = -0.30820036201602363
[step = 10] loss:137379977.0,r2:-0.24966163783633605
[step = 20] loss:152675118.1,r2:-0.10132136827479872
[step = 30] loss:149924918.46666667,r2:-0.06969116555785117
[step = 40] loss:152337326.85,r2:-0.16698886018576098
[step = 50] loss:152529575.88,r2:-0.14058412244083274
[step = 60] loss:150831337.83333334,r2:-0.11807078886069734
[step = 70] loss:156473910.08571428,r2:-0.10772206651922205
[step = 80] loss:163812595.225,r2:-0.1691480731358853
[step = 90] loss:159190992.75555557,r2:-0.15971236430620436


  0%|▍                                                                                 | 4/800 [00:00<03:20,  3.97it/s]

[step = 100] loss:155675313.16,r2:-0.16212123484004023
epoch = 4,loss = 153074320.9906542,r2 = -0.1571470622582098,val_loss = 164297205.03703704,val_+r2 = -0.13553896467780988
[step = 10] loss:108781682.4,r2:0.03367464264908411
[step = 20] loss:118925978.8,r2:-0.008734896585813556
[step = 30] loss:142853657.86666667,r2:0.010909488881032835
[step = 40] loss:144665595.6,r2:-0.0038965556832044777
[step = 50] loss:141512162.24,r2:-0.030613890730882814
[step = 60] loss:138877929.13333333,r2:-0.020995420506079603
[step = 70] loss:130902998.65714286,r2:-0.06747932157710283
[step = 80] loss:134383223.425,r2:-0.07044460250318774


  1%|▌                                                                                 | 5/800 [00:01<03:14,  4.09it/s]

[step = 90] loss:143216580.55555555,r2:-0.05649145941924955
[step = 100] loss:146185536.54,r2:-0.07139835172054645
epoch = 5,loss = 149283138.9906542,r2 = -0.07799065481547932,val_loss = 158451463.55555555,val_+r2 = -0.2727169215900086
[step = 10] loss:151395405.6,r2:-0.5912686040946016
[step = 20] loss:156509214.2,r2:-0.3510825828564852
[step = 30] loss:151387585.6,r2:-0.2188404949875896
[step = 40] loss:143884325.3,r2:-0.192968334176934
[step = 50] loss:147833229.16,r2:-0.22856072844609515
[step = 60] loss:146167339.1,r2:-0.309796754081059
[step = 70] loss:153964665.05714285,r2:-0.26708187497874786


  1%|▌                                                                                 | 6/800 [00:01<03:22,  3.91it/s]

[step = 80] loss:158956379.425,r2:-0.2538444281110136
[step = 90] loss:158834732.4888889,r2:-0.23794810344096154
[step = 100] loss:158244274.04,r2:-0.38596311635119174
epoch = 6,loss = 156366135.47663552,r2 = -0.385245346439068,val_loss = 164913522.37037036,val_+r2 = -0.10927967643050457
[step = 10] loss:108305502.4,r2:-0.01906494294591061
[step = 20] loss:115483330.2,r2:-0.22782672816249314
[step = 30] loss:112776176.33333333,r2:-0.15087072779739727
[step = 40] loss:124307482.85,r2:-0.19651467849968543
[step = 50] loss:138870937.48,r2:-0.17857077472153282


  1%|▋                                                                                 | 7/800 [00:01<03:20,  3.96it/s]

[step = 60] loss:135732826.56666666,r2:-0.18892554045078555
[step = 70] loss:149812643.57142857,r2:-0.16522926286507186
[step = 80] loss:146410514.725,r2:-0.16191150495523413
[step = 90] loss:146878536.95555556,r2:-0.17357266887084727
[step = 100] loss:150959941.04,r2:-0.20521550234579058
epoch = 7,loss = 156011522.46728972,r2 = -0.1967606852357548,val_loss = 183584967.4074074,val_+r2 = -1.3642879162683472
[step = 10] loss:152315806.8,r2:-0.8658587738501378
[step = 20] loss:148413409.8,r2:-0.5158267067250002
[step = 30] loss:138789816.46666667,r2:-0.4695195462112425
[step = 40] loss:140271760.875,r2:-0.3472972900230442
[step = 50] loss:153381444.7,r2:-0.2777581143532723
[step = 60] loss:168203743.58333334,r2:-0.25404767664450134
[step = 70] loss:165198452.12857142,r2:-0.24510946863596292
[step = 80] loss:160911921.2375,r2:-0.37888874242864157
[step = 90] loss:162797147.12222221,r2:-0.401151822846465
[step = 100] loss:161333888.29,r2:-0.4212725534047246


  1%|▊                                                                                 | 8/800 [00:02<03:40,  3.59it/s]

epoch = 8,loss = 162354410.6635514,r2 = -0.4091104620670851,val_loss = 160976664.0,val_+r2 = 0.0021048891290227235
[step = 10] loss:123963670.4,r2:0.012187534604684603
[step = 20] loss:123220369.15,r2:0.008834265558453397
[step = 30] loss:130381091.16666667,r2:0.010531831540764473
[step = 40] loss:129603481.275,r2:-0.05113132350384891
[step = 50] loss:131599018.54,r2:-0.07964398397658722
[step = 60] loss:144574613.58333334,r2:-0.05593724977973642
[step = 70] loss:141694577.92857143,r2:-0.06978004617212172
[step = 80] loss:149476702.2875,r2:-0.10467410277073255


  1%|▉                                                                                 | 9/800 [00:02<03:46,  3.49it/s]

[step = 90] loss:146697683.34444445,r2:-0.09761196627910673
[step = 100] loss:147738802.29,r2:-0.09180975964157206
epoch = 9,loss = 146239230.8504673,r2 = -0.09022847650285157,val_loss = 163501759.85185185,val_+r2 = -0.0741859191587651
[step = 10] loss:152824798.4,r2:-0.2539058263339341
[step = 20] loss:180245754.0,r2:-0.1331512555546151
[step = 30] loss:159001157.46666667,r2:-0.15623413552169374
[step = 40] loss:150717596.95,r2:-0.14755666355813984
[step = 50] loss:144176240.12,r2:-0.11455041113648645


  1%|█                                                                                | 10/800 [00:02<03:45,  3.50it/s]

[step = 60] loss:136391491.4,r2:-0.09518396273764375
[step = 70] loss:142656458.57142857,r2:-0.08789838082484575
[step = 80] loss:144913185.85,r2:-0.14256790359337773
[step = 90] loss:148225714.65555555,r2:-0.14315484966733844
[step = 100] loss:152049611.55,r2:-0.1652668503748649
epoch = 10,loss = 150436842.55140188,r2 = -0.1675610843148004,val_loss = 176307926.66666666,val_+r2 = -0.05009710233926081
[step = 10] loss:191971736.0,r2:-0.1699418998855243
[step = 20] loss:186757432.4,r2:-0.04555053065887079
[step = 30] loss:166004845.6,r2:-0.3187536716748273
[step = 40] loss:162500573.15,r2:-0.28528705007706306
[step = 50] loss:161393563.96,r2:-0.2132117035203412
[step = 60] loss:154927039.0,r2:-0.19172543108148749
[step = 70] loss:153596366.91428572,r2:-0.1621333038321382
[step = 80] loss:151655317.95,r2:-0.14965908001872025
[step = 90] loss:148627615.35555556,r2:-0.14218007046226608
[step = 100] loss:148899452.22,r2:-0.15932375560883721


  1%|█                                                                                | 11/800 [00:02<03:54,  3.36it/s]

epoch = 11,loss = 150559568.76635513,r2 = -0.17621633430196376,val_loss = 161551667.4074074,val_+r2 = -0.04530146263053904
[step = 10] loss:118731357.2,r2:-0.054179306053726374
[step = 20] loss:149737619.3,r2:-0.08293533978632817
[step = 30] loss:164447041.53333333,r2:-0.05428955222453621
[step = 40] loss:164369017.75,r2:-0.12020459964926618
[step = 50] loss:149359838.48,r2:-0.12661985271304732
[step = 60] loss:152011099.93333334,r2:-0.1489359370717982
[step = 70] loss:154210389.91428572,r2:-0.18765594239450226


  2%|█▏                                                                               | 12/800 [00:03<03:59,  3.29it/s]

[step = 80] loss:157108018.575,r2:-0.19484941595810654
[step = 90] loss:149817988.6111111,r2:-0.18693184982457037
[step = 100] loss:149971173.99,r2:-0.18618114021389665
epoch = 12,loss = 150268918.94392523,r2 = -0.18834646928842302,val_loss = 158689582.07407406,val_+r2 = -0.2256974732219127
[step = 10] loss:201043073.4,r2:-0.1959464990434723
[step = 20] loss:200262175.9,r2:-0.22721303664788298
[step = 30] loss:184472212.2,r2:-0.20693427387023683
[step = 40] loss:170038922.15,r2:-0.19844643544376464


  2%|█▎                                                                               | 13/800 [00:03<04:03,  3.24it/s]

[step = 50] loss:163733785.92,r2:-0.19815125291086894
[step = 60] loss:155723447.63333333,r2:-0.20819844673982205
[step = 70] loss:150602237.0,r2:-0.44219174742726436
[step = 80] loss:148113180.775,r2:-0.3925917837532759
[step = 90] loss:150345498.02222222,r2:-0.36830768585206003
[step = 100] loss:152239233.18,r2:-0.35432594190181044
epoch = 13,loss = 153110831.682243,r2 = -0.3396971874631235,val_loss = 186349722.64814815,val_+r2 = -0.17795746262093426
[step = 10] loss:127874258.4,r2:-0.25165926280383344
[step = 20] loss:153233922.2,r2:-0.36302460910450973
[step = 30] loss:157606942.26666668,r2:-0.4997232328089305
[step = 40] loss:163634345.4,r2:-0.3826735005850831
[step = 50] loss:165805867.52,r2:-0.44513277952984154
[step = 60] loss:162741032.8,r2:-0.5528903602214299
[step = 70] loss:153865385.85714287,r2:-0.48435906489600516
[step = 80] loss:151988260.625,r2:-0.43482640722698757


  2%|█▍                                                                               | 14/800 [00:03<04:07,  3.18it/s]

[step = 90] loss:152874609.4888889,r2:-0.385904189338993
[step = 100] loss:152599180.14,r2:-0.3624868298915122
epoch = 14,loss = 156578289.1588785,r2 = -0.3706356635038268,val_loss = 213631082.7777778,val_+r2 = -0.4357794516493211
[step = 10] loss:167933977.6,r2:-0.543778818510804
[step = 20] loss:162378797.2,r2:-0.3501235378548694
[step = 30] loss:160305757.73333332,r2:-0.24947215558877128
[step = 40] loss:163744791.95,r2:-0.23342385882339092
[step = 50] loss:170404434.72,r2:-0.1900974121755797
[step = 60] loss:163600262.86666667,r2:-0.23329918513592757


  2%|█▌                                                                               | 15/800 [00:04<03:57,  3.31it/s]

[step = 70] loss:160321007.82857144,r2:-0.24512008934327376
[step = 80] loss:163038727.2,r2:-0.23288677568295707
[step = 90] loss:167889131.64444444,r2:-0.22065710494173094
[step = 100] loss:159215409.6,r2:-0.19233530930986895
epoch = 15,loss = 155347822.24299064,r2 = -0.20555923540564247,val_loss = 157876815.7777778,val_+r2 = -0.2189075646720184
[step = 10] loss:149859744.0,r2:-0.14436799155139912
[step = 20] loss:155039018.0,r2:-0.04897973935521932
[step = 30] loss:144932039.86666667,r2:-0.04958925114322526
[step = 40] loss:136269483.9,r2:-0.06483371184584989
[step = 50] loss:134261159.92,r2:-0.14192254461130177
[step = 60] loss:137884426.66666666,r2:-0.1586357429694884
[step = 70] loss:145361335.77142859,r2:-0.19946961216717318
[step = 80] loss:146353344.2,r2:-0.2056267102735207


  2%|█▌                                                                               | 16/800 [00:04<04:33,  2.87it/s]

[step = 90] loss:146792623.37777779,r2:-0.17481105502545224
[step = 100] loss:147093591.88,r2:-0.19316859984906848
epoch = 16,loss = 148561884.59813085,r2 = -0.17765866548191483,val_loss = 158405973.33333334,val_+r2 = -0.14324758206282803
[step = 10] loss:144774320.8,r2:-0.014705009440594986
[step = 20] loss:190633593.2,r2:-0.8418330719871087
[step = 30] loss:171648723.4,r2:-0.6987653594182455
[step = 40] loss:169872654.7,r2:-0.569776869164935
[step = 50] loss:162700043.36,r2:-0.5808572905521276
[step = 60] loss:169797793.46666667,r2:-0.6771642277300413
[step = 70] loss:163362214.45714286,r2:-0.608697371311286


  2%|█▋                                                                               | 17/800 [00:05<05:11,  2.51it/s]

[step = 80] loss:159880489.2,r2:-0.5311925365178881
[step = 90] loss:155783198.0,r2:-0.47417092161374597
[step = 100] loss:157207631.6,r2:-0.42482346090227385
epoch = 17,loss = 153059515.7009346,r2 = -0.4349056570119749,val_loss = 172056287.44444445,val_+r2 = -0.11413326042436482
[step = 10] loss:152545394.8,r2:-0.08673794374532998
[step = 20] loss:142083614.6,r2:-0.02278151671900418
[step = 30] loss:155009135.06666666,r2:-0.002988426376597379
[step = 40] loss:152771572.275,r2:-0.19539467470373528
[step = 50] loss:161041377.02,r2:-0.16930261935071506
[step = 60] loss:160595546.85,r2:-0.1426009626363612


  2%|█▊                                                                               | 18/800 [00:05<05:38,  2.31it/s]

[step = 70] loss:151701195.8,r2:-0.12372820433357987
[step = 80] loss:154269481.925,r2:-0.12245159048373863
[step = 90] loss:152004406.0,r2:-0.1707905902175865
[step = 100] loss:152867838.64,r2:-0.17312441486623115
epoch = 18,loss = 150199611.73831776,r2 = -0.19448475878021063,val_loss = 162411540.07407406,val_+r2 = -0.1130182134801065
[step = 10] loss:135275103.0,r2:-0.46693866121900196
[step = 20] loss:135754561.2,r2:-0.3260153594261233
[step = 30] loss:146592186.0,r2:-0.24222486702847396
[step = 40] loss:141509016.4,r2:-0.1997185371974768
[step = 50] loss:132290236.46,r2:-0.1599443034513134
[step = 60] loss:138558795.85,r2:-0.14652595071484217
[step = 70] loss:137702001.58571428,r2:-0.29330442875717777
[step = 80] loss:139747431.2375,r2:-0.306576748067731
[step = 90] loss:145798436.92222223,r2:-0.31433554674859504
[step = 100] loss:145138551.11,r2:-0.28485140887505633


  2%|█▉                                                                               | 19/800 [00:06<05:59,  2.17it/s]

epoch = 19,loss = 146540183.9906542,r2 = -0.26392852394928384,val_loss = 158962066.03703704,val_+r2 = -0.2510868083766684
[step = 10] loss:177858778.4,r2:-0.027110098877868673
[step = 20] loss:161406698.4,r2:-0.03173819818367046
[step = 30] loss:154627620.8,r2:-0.2745629570910656
[step = 40] loss:144751508.1,r2:-0.23011971802363282
[step = 50] loss:141965927.08,r2:-0.20559298949314467
[step = 60] loss:150893608.96666667,r2:-0.1819150251243906
[step = 70] loss:149107805.5142857,r2:-0.18446210423130846
[step = 80] loss:149145136.575,r2:-0.16374398566619336
[step = 90] loss:148309007.4888889,r2:-0.16645640877801662


  2%|██                                                                               | 20/800 [00:06<06:10,  2.11it/s]

[step = 100] loss:144935193.52,r2:-0.1634402819371164
epoch = 20,loss = 145960352.8224299,r2 = -0.15879952000890965,val_loss = 164111376.8888889,val_+r2 = -0.4012420573892784
[step = 10] loss:164221398.4,r2:-0.19870201833945092
[step = 20] loss:138235788.5,r2:-0.09401960191757745
[step = 30] loss:166701866.6,r2:-0.06228886864808463
[step = 40] loss:165896157.25,r2:-0.2483019110973615
[step = 50] loss:163820026.84,r2:-0.2098764130070542
[step = 60] loss:159609485.96666667,r2:-0.21738913820164155
[step = 70] loss:160524704.2,r2:-0.4669500073569445
[step = 80] loss:159409203.9,r2:-0.40712449089861014


  3%|██▏                                                                              | 21/800 [00:07<06:18,  2.06it/s]

[step = 90] loss:152714706.4,r2:-0.38990669063419275
[step = 100] loss:151805930.36,r2:-0.349612981949587
epoch = 21,loss = 147857158.6915888,r2 = -0.3286856735753435,val_loss = 164282000.2962963,val_+r2 = -0.1328020511989839
[step = 10] loss:99305706.4,r2:-0.5439844008411451
[step = 20] loss:137943562.2,r2:-0.3195425376765596
[step = 30] loss:146693284.66666666,r2:-0.23846572960606852
[step = 40] loss:149015781.8,r2:-0.1780224121344292
[step = 50] loss:152624227.6,r2:-0.1470292317511613
[step = 60] loss:148420619.33333334,r2:-0.37410002788511754
[step = 70] loss:153756060.34285715,r2:-0.3344661271784217


  3%|██▏                                                                              | 22/800 [00:07<06:25,  2.02it/s]

[step = 80] loss:156332084.5,r2:-0.29973417115296
[step = 90] loss:149015524.1111111,r2:-0.2823907866958081
[step = 100] loss:146395198.34,r2:-0.26085090923137855
epoch = 22,loss = 147314850.78504673,r2 = -0.28137586856478536,val_loss = 161114983.7037037,val_+r2 = -0.4120136072363105
[step = 10] loss:120733052.8,r2:-0.14438358689014083
[step = 20] loss:144511545.0,r2:-0.17277503327994187
[step = 30] loss:148115299.6,r2:-0.0887981648667167
[step = 40] loss:143179127.5,r2:-0.0882928976377048
[step = 50] loss:142967659.56,r2:-0.0847540307614836
[step = 60] loss:142741408.16666666,r2:-0.10823892686187583


  3%|██▎                                                                              | 23/800 [00:08<06:31,  1.98it/s]

[step = 70] loss:142594887.8,r2:-0.08939523028960704
[step = 80] loss:144601735.425,r2:-0.08547923912235433
[step = 90] loss:145665658.2888889,r2:-0.07539577966491816
[step = 100] loss:147514223.58,r2:-0.07542504160249791
epoch = 23,loss = 145911657.8504673,r2 = -0.1732514478642759,val_loss = 177527400.14814815,val_+r2 = -0.15669439730519397
[step = 10] loss:132558331.6,r2:-0.040433442990750024
[step = 20] loss:138765671.0,r2:0.0004922755136798407
[step = 30] loss:141477158.66666666,r2:-0.007038947876247277
[step = 40] loss:138263237.1,r2:-0.040503195747226085
[step = 50] loss:134549622.88,r2:-0.05743437435761134
[step = 60] loss:132412810.13333334,r2:-0.04527235096659602
[step = 70] loss:140990349.42857143,r2:-0.16580384570204676
[step = 80] loss:138867648.4,r2:-0.15644343328143526
[step = 90] loss:137613936.55555555,r2:-0.13730384271998503
[step = 100] loss:140460918.78,r2:-0.13845292910380103


  3%|██▍                                                                              | 24/800 [00:08<06:36,  1.96it/s]

epoch = 24,loss = 144852399.86915886,r2 = -0.1374794274454882,val_loss = 163392905.85185185,val_+r2 = -0.42465661271900246
[step = 10] loss:171583705.2,r2:-0.042181765896554216
[step = 20] loss:154214534.2,r2:-0.1042673340500386
[step = 30] loss:140429128.86666667,r2:-0.09193198008839785
[step = 40] loss:145101741.9,r2:-0.11289811783992765
[step = 50] loss:149569382.96,r2:-0.11261801385927515
[step = 60] loss:151432885.2,r2:-0.3147011177807615
[step = 70] loss:151787551.34285715,r2:-0.2890951799499891
[step = 80] loss:150111466.125,r2:-0.2933613741430377
[step = 90] loss:147762845.08888888,r2:-0.2691147229170398


  3%|██▌                                                                              | 25/800 [00:09<06:43,  1.92it/s]

[step = 100] loss:150166967.54,r2:-0.2768249152987044
epoch = 25,loss = 147917139.45794392,r2 = -0.26068677228954334,val_loss = 173504634.57407406,val_+r2 = -0.05738981778513696
[step = 10] loss:116928319.8,r2:-0.3537732080685724
[step = 20] loss:128664011.1,r2:-0.251655231549513
[step = 30] loss:134878838.73333332,r2:-0.3803297096978763
[step = 40] loss:143728589.35,r2:-0.35534701027056403
[step = 50] loss:142694479.64,r2:-0.2930115972890636
[step = 60] loss:136741148.03333333,r2:-0.25694154619491216
[step = 70] loss:137440968.1142857,r2:-0.22064352217051375


  3%|██▋                                                                              | 26/800 [00:09<06:49,  1.89it/s]

[step = 80] loss:142605562.5,r2:-0.20932333113754295
[step = 90] loss:148819503.2,r2:-0.18803353975537287
[step = 100] loss:148587592.56,r2:-0.2303852699175023
epoch = 26,loss = 145867983.10280374,r2 = -0.23772119089596905,val_loss = 175041287.14814815,val_+r2 = -0.12363672135239027
[step = 10] loss:151841802.0,r2:-0.05285510639862665
[step = 20] loss:137337850.4,r2:-0.0377972294429126
[step = 30] loss:136516221.86666667,r2:-0.015457801420021547
[step = 40] loss:139240397.2,r2:-0.008399840229001157
[step = 50] loss:145709863.2,r2:-0.016793184654963968
[step = 60] loss:148009201.2,r2:-0.007047508196634284
[step = 70] loss:144522667.2,r2:-0.017146498892978864
[step = 80] loss:148227123.65,r2:-0.04569291977996496
[step = 90] loss:143012380.84444445,r2:-0.13398761433974252


  3%|██▋                                                                              | 27/800 [00:10<06:57,  1.85it/s]

[step = 100] loss:143528145.88,r2:-0.1384941889951931
epoch = 27,loss = 145841463.06542057,r2 = -0.12500413272679542,val_loss = 158816816.2962963,val_+r2 = -0.40517562920942246
[step = 10] loss:184727796.8,r2:-0.13956445564108158
[step = 20] loss:173134086.4,r2:-0.16190524160428527
[step = 30] loss:169222659.2,r2:-0.15005968900872416
[step = 40] loss:159883381.125,r2:-0.27541771087615236
[step = 50] loss:159722720.82,r2:-0.22763675675857
[step = 60] loss:161106167.81666666,r2:-0.21890543665562834
[step = 70] loss:160657331.1,r2:-0.21434447723344724


  4%|██▊                                                                              | 28/800 [00:11<07:00,  1.83it/s]

[step = 80] loss:159299895.3125,r2:-0.20602378561791337
[step = 90] loss:151417387.03333333,r2:-0.21868988007424509
[step = 100] loss:151326629.77,r2:-0.2261131627252472
epoch = 28,loss = 146432663.52336448,r2 = -0.21439562291097267,val_loss = 167344644.5185185,val_+r2 = -0.05234705039595472
[step = 10] loss:121618287.6,r2:-0.02007986761439694
[step = 20] loss:158794812.8,r2:-0.08157702197998383
[step = 30] loss:141863157.93333334,r2:-0.20010455373405184
[step = 40] loss:141693662.1,r2:-0.1995999168052112
[step = 50] loss:141099453.44,r2:-0.24012603422944068
[step = 60] loss:145892729.53333333,r2:-0.20602058101425494
[step = 70] loss:141918678.22857141,r2:-0.18384650595759275
[step = 80] loss:142462744.8,r2:-0.19559782166967365
[step = 90] loss:148515563.2888889,r2:-0.1690472218165493


  4%|██▉                                                                              | 29/800 [00:11<07:01,  1.83it/s]

[step = 100] loss:150071972.36,r2:-0.17125262662529223
epoch = 29,loss = 147474849.02803737,r2 = -0.17915094645593072,val_loss = 171567282.07407406,val_+r2 = -0.10357301995546557
[step = 10] loss:184439121.2,r2:-0.2132750316872108
[step = 20] loss:151224368.2,r2:-0.1070570782095239
[step = 30] loss:145785658.13333333,r2:-0.08911790627765483
[step = 40] loss:148134578.0,r2:-0.12376388507542897
[step = 50] loss:153377081.12,r2:-0.20434252215239404
[step = 60] loss:156803859.73333332,r2:-0.24092858815296925
[step = 70] loss:148919929.82857144,r2:-0.2411853363775195


  4%|███                                                                              | 30/800 [00:12<07:02,  1.82it/s]

[step = 80] loss:146744218.2,r2:-0.20892577082292219
[step = 90] loss:141587003.37777779,r2:-0.1958483495095076
[step = 100] loss:147480333.36,r2:-0.21363872175247575
epoch = 30,loss = 147028620.41121495,r2 = -0.25486781779540657,val_loss = 169172903.4074074,val_+r2 = -0.153503299955975
[step = 10] loss:151626724.4,r2:-0.028050155137159073
[step = 20] loss:155448256.6,r2:-0.02757111740489402
[step = 30] loss:165541371.06666666,r2:-0.035597673854579334
[step = 40] loss:162079619.8,r2:-0.06532539275703661
[step = 50] loss:159440500.0,r2:-0.06501332965760752
[step = 60] loss:150173064.0,r2:-0.07811608524912785
[step = 70] loss:148526571.68571427,r2:-0.1410459988162047
[step = 80] loss:144967770.9875,r2:-0.24913027294210935
[step = 90] loss:146292894.12222221,r2:-0.20808170642212995
[step = 100] loss:145995639.73,r2:-0.20459176996388784


  4%|███▏                                                                             | 31/800 [00:12<07:02,  1.82it/s]

epoch = 31,loss = 145402465.42990655,r2 = -0.19532873691405037,val_loss = 167356473.85185185,val_+r2 = -0.08292562630842688
[step = 10] loss:170358491.2,r2:-0.7434648883328531
[step = 20] loss:158712510.0,r2:-0.44449105208533385
[step = 30] loss:155572345.86666667,r2:-0.3981237926365128
[step = 40] loss:156683498.8,r2:-0.32737779911479
[step = 50] loss:153479165.68,r2:-0.3436001550345192
[step = 60] loss:159304082.06666666,r2:-0.34956575641129817
[step = 70] loss:161156907.02857143,r2:-0.2935704273651357
[step = 80] loss:152741404.95,r2:-0.29488088965262
[step = 90] loss:151583310.31111112,r2:-0.262945412894301


  4%|███▏                                                                             | 32/800 [00:13<07:01,  1.82it/s]

[step = 100] loss:146167198.52,r2:-0.23412995745137152
epoch = 32,loss = 145111905.94392523,r2 = -0.22845814500896955,val_loss = 157660734.66666666,val_+r2 = -0.16924883351684675
[step = 10] loss:175120688.0,r2:0.06194853692240726
[step = 20] loss:160413931.2,r2:0.03607108272958585
[step = 30] loss:141388417.6,r2:0.026691460529508867
[step = 40] loss:141926084.3,r2:-0.00386704640428476
[step = 50] loss:144061123.76,r2:-0.03710555191452996
[step = 60] loss:145331900.86666667,r2:-0.052647027748343796
[step = 70] loss:145168474.17142856,r2:-0.06867955139763463


  4%|███▎                                                                             | 33/800 [00:13<07:02,  1.82it/s]

[step = 80] loss:145841271.6,r2:-0.06439832928876105
[step = 90] loss:143900612.62222221,r2:-0.14487182618226077
[step = 100] loss:146447447.28,r2:-0.148504385395395
epoch = 33,loss = 145292074.95327103,r2 = -0.14583042822576578,val_loss = 161686609.4814815,val_+r2 = 0.06380463308371444
[step = 10] loss:165024424.0,r2:0.10987536470516202
[step = 20] loss:152819822.6,r2:-0.032969813798424376
[step = 30] loss:148263746.26666668,r2:-0.028171139580789872
[step = 40] loss:142705463.35,r2:-0.026748349167312058
[step = 50] loss:134799541.32,r2:-0.2311563487724899
[step = 60] loss:139456804.06666666,r2:-0.214662781501995
[step = 70] loss:144517435.8857143,r2:-0.3049887425211124
[step = 80] loss:143442334.8375,r2:-0.331734528133079
[step = 90] loss:148959171.5,r2:-0.31501570482658514
[step = 100] loss:146866688.51,r2:-0.32071512188316853
epoch = 34,loss = 151392431.46728972,r2 = -0.31190642957948517,val_loss = 178035529.03703704,val_+r2 = -0.12752467929735378


  4%|███▍                                                                             | 34/800 [00:14<07:03,  1.81it/s]

[step = 10] loss:150607159.4,r2:-1.9330946679267562
[step = 20] loss:152518468.7,r2:-0.932906779683257
[step = 30] loss:139047904.5,r2:-0.6824735528324953
[step = 40] loss:133982150.275,r2:-0.49322334011364843
[step = 50] loss:149947835.58,r2:-0.38595762318991866
[step = 60] loss:149302780.65,r2:-0.34590352537594554
[step = 70] loss:145210913.75714287,r2:-0.29936166878756715
[step = 80] loss:145159994.3375,r2:-0.27203300239713946


  4%|███▌                                                                             | 35/800 [00:14<07:05,  1.80it/s]

[step = 90] loss:147059721.9,r2:-0.32719489917264016
[step = 100] loss:144704449.93,r2:-0.3022929461920551
epoch = 35,loss = 145120879.52336448,r2 = -0.3135564571700079,val_loss = 161676687.85185185,val_+r2 = 0.020754187240739578
[step = 10] loss:143868937.6,r2:-0.14905006100082993
[step = 20] loss:147942183.2,r2:-0.058408709394532
[step = 30] loss:135187849.73333332,r2:-0.10156917880054633
[step = 40] loss:137443528.65,r2:-0.18575737652421404
[step = 50] loss:135011463.24,r2:-0.16099227099211547
[step = 60] loss:136417641.83333334,r2:-0.13585050073266944
[step = 70] loss:143547392.54285714,r2:-0.24003305592230711


  4%|███▋                                                                             | 36/800 [00:15<07:05,  1.79it/s]

[step = 80] loss:146846091.725,r2:-0.35052878507043855
[step = 90] loss:150197211.66666666,r2:-0.3078014993654484
[step = 100] loss:148721814.08,r2:-0.2773634949656185
epoch = 36,loss = 145724273.94392523,r2 = -0.28992901031746915,val_loss = 168005026.96296296,val_+r2 = -0.05041185185509545
[step = 10] loss:143549364.8,r2:0.017098257941605677
[step = 20] loss:133881848.0,r2:-0.37543131915580596
[step = 30] loss:136341431.86666667,r2:-0.2545706428071519
[step = 40] loss:130601652.4,r2:-0.326037286289046
[step = 50] loss:135190967.36,r2:-0.29769425111894676
[step = 60] loss:145618189.56666666,r2:-0.2409203567340012
[step = 70] loss:142547244.31428573,r2:-0.2953377111306397
[step = 80] loss:151891244.15,r2:-0.2918351220068989
[step = 90] loss:147560678.93333334,r2:-0.3103401659533327
[step = 100] loss:149387938.68,r2:-0.3262542794859343


  5%|███▋                                                                             | 37/800 [00:16<07:06,  1.79it/s]

epoch = 37,loss = 148773323.40186915,r2 = -0.3143750445209967,val_loss = 159384118.66666666,val_+r2 = -0.16029766761827524
[step = 10] loss:154953440.4,r2:-0.343960555761979
[step = 20] loss:165960067.0,r2:-0.24777570037511515
[step = 30] loss:159006714.53333333,r2:-0.4561759123732612
[step = 40] loss:144497653.6,r2:-0.4267634639562258
[step = 50] loss:150479317.84,r2:-0.3625041420411293
[step = 60] loss:146708044.73333332,r2:-0.37013518777343407
[step = 70] loss:148325082.74285713,r2:-0.31758726229888445
[step = 80] loss:146945383.15,r2:-0.27150773846144616
[step = 90] loss:151476111.86666667,r2:-0.2773694172062287


  5%|███▊                                                                             | 38/800 [00:16<07:06,  1.79it/s]

[step = 100] loss:147585863.4,r2:-0.2767569178799459
epoch = 38,loss = 146696600.42990655,r2 = -0.25877202115858405,val_loss = 170840729.7777778,val_+r2 = -0.04122515944432733
[step = 10] loss:145243529.2,r2:-0.3022182484931414
[step = 20] loss:126053879.1,r2:-0.5405941742239211
[step = 30] loss:133693740.73333333,r2:-0.3453874619568748
[step = 40] loss:138061235.85,r2:-0.28524052203035927
[step = 50] loss:136516009.64,r2:-0.3577103909406443
[step = 60] loss:131246557.6,r2:-0.3039043416021045
[step = 70] loss:138551304.45714286,r2:-0.2639491820332798


  5%|███▉                                                                             | 39/800 [00:17<07:05,  1.79it/s]

[step = 80] loss:140531590.3,r2:-0.22836039737826877
[step = 90] loss:139862220.4888889,r2:-0.21690851031970687
[step = 100] loss:145334576.2,r2:-0.1883579312410487
epoch = 39,loss = 143611218.05607477,r2 = -0.21086077282263802,val_loss = 160927526.74074075,val_+r2 = -0.06749066109584302
[step = 10] loss:155910053.6,r2:-0.032646806404150075
[step = 20] loss:148804285.2,r2:-0.3781094685203974
[step = 30] loss:135505769.8,r2:-0.4296835086015618
[step = 40] loss:134331902.15,r2:-0.3318773217444131
[step = 50] loss:129016682.44,r2:-0.3514969363413979
[step = 60] loss:132278379.83333333,r2:-0.2887004664559012
[step = 70] loss:128850393.74285714,r2:-0.3150369673697017
[step = 80] loss:133062123.475,r2:-0.27271684474984337
[step = 90] loss:141264264.06666666,r2:-0.24975088208094362
[step = 100] loss:142249624.02,r2:-0.2550579723490285


  5%|████                                                                             | 40/800 [00:17<07:05,  1.79it/s]

epoch = 40,loss = 145682869.62616822,r2 = -0.2402653039062207,val_loss = 159547356.0,val_+r2 = -0.19931944207566352
[step = 10] loss:146321934.8,r2:-0.09931189494230724
[step = 20] loss:146163305.2,r2:-0.07448714114288661
[step = 30] loss:152912612.8,r2:-0.06643957965425656
[step = 40] loss:149970590.9,r2:-0.06396233931581244
[step = 50] loss:139579055.88,r2:-0.21091242353601575
[step = 60] loss:147036547.1,r2:-0.18741789898330005
[step = 70] loss:154509552.77142859,r2:-0.1424614711772906
[step = 80] loss:151950777.625,r2:-0.15685486514348476
[step = 90] loss:147090468.7111111,r2:-0.1637824346785649


  5%|████▏                                                                            | 41/800 [00:18<07:04,  1.79it/s]

[step = 100] loss:149335169.28,r2:-0.16082473086046178
epoch = 41,loss = 148452452.78504673,r2 = -0.2638828462059231,val_loss = 162538429.62962964,val_+r2 = -0.17829896677209978
[step = 10] loss:154999495.6,r2:-0.11157728927562922
[step = 20] loss:152437257.6,r2:-0.20987135312312905
[step = 30] loss:153226509.6,r2:-0.2617742540371823
[step = 40] loss:154774703.35,r2:-0.2950957460847595
[step = 50] loss:152610850.8,r2:-0.2846773581615367
[step = 60] loss:144839387.66666666,r2:-0.2652427567798594
[step = 70] loss:136678839.02857143,r2:-0.2521026181819752


  5%|████▎                                                                            | 42/800 [00:18<07:05,  1.78it/s]

[step = 80] loss:143947727.65,r2:-0.21702071239249543
[step = 90] loss:145962712.97777778,r2:-0.19051253928305892
[step = 100] loss:144814503.84,r2:-0.21285802809759288
epoch = 42,loss = 145201049.75700936,r2 = -0.20607603272660605,val_loss = 166346245.55555555,val_+r2 = -0.36167068465444385
[step = 10] loss:128230732.6,r2:-0.08081485584335765
[step = 20] loss:129263834.2,r2:-0.28558582455478404
[step = 30] loss:145960649.86666667,r2:-0.24360734805103915
[step = 40] loss:137052394.8,r2:-0.28087362773423213
[step = 50] loss:136637659.84,r2:-0.2866002698502265
[step = 60] loss:136197659.13333333,r2:-0.2318649145772128
[step = 70] loss:135751934.22857141,r2:-0.21108217379275832
[step = 80] loss:138431670.45,r2:-0.19734545926965788
[step = 90] loss:139595948.93333334,r2:-0.20728916409159426
[step = 100] loss:142893112.66,r2:-0.1949317883200117


  5%|████▎                                                                            | 43/800 [00:19<07:04,  1.78it/s]

epoch = 43,loss = 145084784.24299064,r2 = -0.1799924559549175,val_loss = 159424336.74074075,val_+r2 = -0.22449907829824906
[step = 10] loss:156946299.2,r2:-0.16085353384113982
[step = 20] loss:153055627.65,r2:-0.10923498274361369
[step = 30] loss:144101645.63333333,r2:-0.11376392167510495
[step = 40] loss:139362569.525,r2:-0.2843148335008962
[step = 50] loss:148636961.06,r2:-0.2337821451498096
[step = 60] loss:149005754.35,r2:-0.255553787683021
[step = 70] loss:143219537.24285713,r2:-0.23800991555949957
[step = 80] loss:144332477.0375,r2:-0.22056132756036445
[step = 90] loss:146452705.32222223,r2:-0.19534789621591608


  6%|████▍                                                                            | 44/800 [00:19<07:02,  1.79it/s]

[step = 100] loss:146853797.27,r2:-0.19672900271371518
epoch = 44,loss = 146794506.62616822,r2 = -0.2503198512131788,val_loss = 158361906.2222222,val_+r2 = -0.13057639665195872
[step = 10] loss:102649410.8,r2:-1.017584994894904
[step = 20] loss:136284461.8,r2:-0.5713060960660467
[step = 30] loss:147815126.33333334,r2:-0.3894547084198154
[step = 40] loss:152976847.75,r2:-0.3991605078057545
[step = 50] loss:151616534.84,r2:-0.4024812628501018
[step = 60] loss:152371052.3,r2:-0.39549872912071277
[step = 70] loss:149153385.4,r2:-0.3529209998494578
[step = 80] loss:150152476.6,r2:-0.36428583923873903


  6%|████▌                                                                            | 45/800 [00:20<07:03,  1.78it/s]

[step = 90] loss:149912086.8888889,r2:-0.3293260449659338
[step = 100] loss:146354998.88,r2:-0.32571936397429996
epoch = 45,loss = 145003272.8598131,r2 = -0.3236113150198572,val_loss = 161885492.07407406,val_+r2 = -0.08003786441454078
[step = 10] loss:196126943.2,r2:-0.12011506698263352
[step = 20] loss:159952525.6,r2:-0.16550793325639274
[step = 30] loss:162626893.33333334,r2:-0.09864905988282506
[step = 40] loss:160268644.1,r2:-0.11053586633651954
[step = 50] loss:154737504.56,r2:-0.0948839963637564
[step = 60] loss:153449497.53333333,r2:-0.10648307543189468


  6%|████▋                                                                            | 46/800 [00:21<07:02,  1.78it/s]

[step = 70] loss:148124825.25714287,r2:-0.12744244796023704
[step = 80] loss:147248326.45,r2:-0.14079300011908272
[step = 90] loss:146690202.2222222,r2:-0.12104262925615161
[step = 100] loss:144445279.2,r2:-0.10345965291720308
epoch = 46,loss = 145945972.74766356,r2 = -0.09785123224249134,val_loss = 160196049.92592594,val_+r2 = -0.2683817989138987
[step = 10] loss:140408477.0,r2:-0.011628898409322774
[step = 20] loss:130937294.5,r2:-0.1751290641421294
[step = 30] loss:143260728.46666667,r2:-0.11867877234572628
[step = 40] loss:144106438.85,r2:-0.1899684275398294
[step = 50] loss:147479491.96,r2:-0.23418620818594618
[step = 60] loss:145449129.75,r2:-0.21774303039899814
[step = 70] loss:143786042.12857142,r2:-0.21794653916147563
[step = 80] loss:143731468.7625,r2:-0.1800930026827839
[step = 90] loss:144332537.2111111,r2:-0.17717214431778538
[step = 100] loss:144901658.45,r2:-0.16344171694955786


  6%|████▊                                                                            | 47/800 [00:21<07:03,  1.78it/s]

epoch = 47,loss = 143836556.1588785,r2 = -0.18880311664869723,val_loss = 159421236.74074075,val_+r2 = -0.39417715031622536
[step = 10] loss:157855810.0,r2:-0.8345181377305076
[step = 20] loss:150189393.7,r2:-0.45503515347968043
[step = 30] loss:151926231.26666668,r2:-0.26540487536815793
[step = 40] loss:158338179.65,r2:-0.19624286534106553
[step = 50] loss:157786868.28,r2:-0.25003746239301705
[step = 60] loss:151891877.46666667,r2:-0.2404122300516849
[step = 70] loss:151248150.94285715,r2:-0.22930114617114136
[step = 80] loss:147331689.875,r2:-0.33434531570820186
[step = 90] loss:145989787.24444443,r2:-0.3105974673037658


  6%|████▊                                                                            | 48/800 [00:22<07:02,  1.78it/s]

[step = 100] loss:146040958.44,r2:-0.32186989730082133
epoch = 48,loss = 144801451.47663552,r2 = -0.2983202082591687,val_loss = 158133210.37037036,val_+r2 = -0.07436046513942311
[step = 10] loss:173287277.6,r2:0.0750059337248815
[step = 20] loss:146842973.8,r2:-0.016049848367058646
[step = 30] loss:147706397.4,r2:-0.06125353516992678
[step = 40] loss:138436904.85,r2:-0.06158864568806157
[step = 50] loss:147894134.88,r2:-0.15609214087954507
[step = 60] loss:147708759.26666668,r2:-0.2255386872251876


  6%|████▉                                                                            | 49/800 [00:22<07:00,  1.79it/s]

[step = 70] loss:139225579.7142857,r2:-0.2847996685932569
[step = 80] loss:138645927.9,r2:-0.26183311651187696
[step = 90] loss:143713622.4888889,r2:-0.22768860169187402
[step = 100] loss:144344908.24,r2:-0.20838166804117667
epoch = 49,loss = 144420530.54205608,r2 = -0.200248747162152,val_loss = 158237265.4814815,val_+r2 = -0.15167420425927694
[step = 10] loss:135660433.6,r2:-0.381752802232572
[step = 20] loss:148054365.0,r2:-0.22070713742766915
[step = 30] loss:151665402.33333334,r2:-0.3202800163887549
[step = 40] loss:143743262.35,r2:-0.4916083162995156
[step = 50] loss:142817206.6,r2:-0.41938206472052497
[step = 60] loss:141044935.9,r2:-0.445347485107037
[step = 70] loss:145963788.54285714,r2:-0.4633775214978428
[step = 80] loss:143625450.775,r2:-0.41383249344034523
[step = 90] loss:142671903.62222221,r2:-0.3865219689293457
[step = 100] loss:144506873.1,r2:-0.38186328287082405


  6%|█████                                                                            | 50/800 [00:23<06:59,  1.79it/s]

epoch = 50,loss = 145055258.07476637,r2 = -0.38459038971313075,val_loss = 156402917.7777778,val_+r2 = -0.14634723945929018
[step = 10] loss:153408574.4,r2:0.09918999186098312
[step = 20] loss:158587699.8,r2:0.06671547181813545
[step = 30] loss:156821929.06666666,r2:-0.58456798619332
[step = 40] loss:161244745.2,r2:-0.4434212451596785
[step = 50] loss:153374547.6,r2:-0.45777174581885804
[step = 60] loss:152540745.53333333,r2:-0.37693640758639796
[step = 70] loss:148366081.25714287,r2:-0.36693268301706594
[step = 80] loss:147559819.225,r2:-0.371398321088407
[step = 90] loss:144572527.13333333,r2:-0.3355941917521749


  6%|█████▏                                                                           | 51/800 [00:23<06:58,  1.79it/s]

[step = 100] loss:146634691.8,r2:-0.3004433188696723
epoch = 51,loss = 145375997.53271028,r2 = -0.29408316004876905,val_loss = 159468932.8888889,val_+r2 = -0.04334783075467191
[step = 10] loss:179567093.6,r2:-0.2842432186266657
[step = 20] loss:149460045.4,r2:-0.10196281790732879
[step = 30] loss:141336370.4,r2:-0.094180434389529
[step = 40] loss:140316017.5,r2:-0.190510893917866
[step = 50] loss:153719934.32,r2:-0.13511791952527968
[step = 60] loss:150956345.8,r2:-0.13755633681912063
[step = 70] loss:150669442.34285715,r2:-0.1165186837268152
[step = 80] loss:146503322.275,r2:-0.17743829700510724


  6%|█████▎                                                                           | 52/800 [00:26<14:59,  1.20s/it]

[step = 90] loss:147564330.2888889,r2:-0.17817892623367648
[step = 100] loss:145080867.02,r2:-0.1561339181823571
epoch = 52,loss = 143811407.42056075,r2 = -0.16341990099681372,val_loss = 165095594.74074075,val_+r2 = -0.03043106803583013
[step = 10] loss:157096332.0,r2:-0.008635546216525536
[step = 20] loss:138484563.8,r2:-0.0536201330147791
[step = 30] loss:131504466.93333334,r2:-0.16153084393614262
[step = 40] loss:123359135.875,r2:-0.1532163468458392
[step = 50] loss:123068289.66,r2:-0.22894846127259946
[step = 60] loss:133946652.18333334,r2:-0.1806317617606369


  7%|█████▎                                                                           | 53/800 [00:27<12:38,  1.01s/it]

[step = 70] loss:137125126.3857143,r2:-0.2121423114504197
[step = 80] loss:141728973.9375,r2:-0.19495548751022243
[step = 90] loss:142371356.83333334,r2:-0.18581497870118938
[step = 100] loss:142348146.63,r2:-0.17819382459766542
epoch = 53,loss = 145317399.91588786,r2 = -0.1678313345284009,val_loss = 159265120.44444445,val_+r2 = -0.00031450126172564936
[step = 10] loss:180839993.6,r2:-0.41172281619636414
[step = 20] loss:169980390.2,r2:-0.6015176702395907
[step = 30] loss:167199962.66666666,r2:-0.37628573263827036
[step = 40] loss:165619082.6,r2:-0.27066989124704
[step = 50] loss:159663537.6,r2:-0.27200606618338424
[step = 60] loss:153713399.2,r2:-0.31967654086549985
[step = 70] loss:154942820.34285715,r2:-0.2789845811501281
[step = 80] loss:151817324.85,r2:-0.33878914502967744
[step = 90] loss:145961654.31111112,r2:-0.31922835724594034
[step = 100] loss:148657068.46,r2:-0.2855345900056642


  7%|█████▍                                                                           | 54/800 [00:27<10:59,  1.13it/s]

epoch = 54,loss = 145909607.8317757,r2 = -0.2724902817074665,val_loss = 158433710.66666666,val_+r2 = -0.12318285711336803
[step = 10] loss:122371130.4,r2:-0.15200998701519505
[step = 20] loss:118414993.6,r2:-0.11614450808668406
[step = 30] loss:140846265.33333334,r2:-0.03776018484812679
[step = 40] loss:142681906.8,r2:-0.041889526997806306
[step = 50] loss:142884884.8,r2:-0.08841560469169454
[step = 60] loss:148595501.06666666,r2:-0.09422434027851226
[step = 70] loss:148419466.45714286,r2:-0.2497592027973534
[step = 80] loss:145858343.025,r2:-0.24028545938322918
[step = 90] loss:142153370.82222223,r2:-0.2168266416736431


  7%|█████▌                                                                           | 55/800 [00:28<09:48,  1.27it/s]

[step = 100] loss:141716623.46,r2:-0.25752279899863656
epoch = 55,loss = 146210885.1401869,r2 = -0.23720496743418637,val_loss = 157608856.0,val_+r2 = -0.36762912539013265
[step = 10] loss:129752473.2,r2:-0.5598329193991419
[step = 20] loss:164002090.6,r2:-0.4531802623318978
[step = 30] loss:142229602.53333333,r2:-0.2812729222051217
[step = 40] loss:133077606.2,r2:-0.194476041830612
[step = 50] loss:138664418.48,r2:-0.1734714897847413
[step = 60] loss:135470773.8,r2:-0.17905288722950854
[step = 70] loss:139187462.68571427,r2:-0.17562090885113574


  7%|█████▋                                                                           | 56/800 [00:28<08:57,  1.39it/s]

[step = 80] loss:137776114.775,r2:-0.1522710893166917
[step = 90] loss:142530345.84444445,r2:-0.13669822733370662
[step = 100] loss:146221721.82,r2:-0.1414364186906724
epoch = 56,loss = 144673742.03738317,r2 = -0.16224677217055572,val_loss = 158432541.14814815,val_+r2 = -0.12025349627367922
[step = 10] loss:148687125.6,r2:-0.4063637221881383
[step = 20] loss:135337898.0,r2:-0.2296326907853008
[step = 30] loss:144345598.8,r2:-0.31305997974417565
[step = 40] loss:149832909.7,r2:-0.21790113585059157
[step = 50] loss:155018294.8,r2:-0.17343283318625777
[step = 60] loss:149028317.26666668,r2:-0.1572044550360942


  7%|█████▊                                                                           | 57/800 [00:29<08:20,  1.48it/s]

[step = 70] loss:142210800.5142857,r2:-0.16682082174772814
[step = 80] loss:139330512.05,r2:-0.1694611717968486
[step = 90] loss:139226562.2222222,r2:-0.22395173700230775
[step = 100] loss:143457284.32,r2:-0.21589884697363104
epoch = 57,loss = 145294085.45794392,r2 = -0.20848795808290543,val_loss = 157452406.2222222,val_+r2 = -0.15329279965925394
[step = 10] loss:90923217.2,r2:-0.04925969438605354
[step = 20] loss:116019892.0,r2:-0.17065404974347148
[step = 30] loss:131650185.86666666,r2:-0.06149667324678202
[step = 40] loss:139468928.4,r2:-0.054011552605939386
[step = 50] loss:137615824.08,r2:-0.05625849482430697
[step = 60] loss:134168518.26666667,r2:-0.10774843118608667
[step = 70] loss:138286416.5142857,r2:-0.11674404816070594
[step = 80] loss:141198688.8,r2:-0.142967548532482
[step = 90] loss:140445554.53333333,r2:-0.18908191286323286
[step = 100] loss:143054116.96,r2:-0.1952313409844425


  7%|█████▊                                                                           | 58/800 [00:30<07:55,  1.56it/s]

epoch = 58,loss = 145044968.89719626,r2 = -0.17594554228654574,val_loss = 158357176.8888889,val_+r2 = -0.5040262033928421
[step = 10] loss:152479746.4,r2:-0.025156672376215272
[step = 20] loss:139487053.5,r2:-0.09996938256392929
[step = 30] loss:146991032.86666667,r2:-0.1052194878108305
[step = 40] loss:148164765.05,r2:-0.20392874487032575
[step = 50] loss:146697241.8,r2:-0.15177666544432247
[step = 60] loss:147205510.3,r2:-0.15011730815824323
[step = 70] loss:147829276.77142859,r2:-0.24684665712481443
[step = 80] loss:146559743.35,r2:-0.20085170686786497
[step = 90] loss:152650618.17777777,r2:-0.18632214606748576


  7%|█████▉                                                                           | 59/800 [00:30<07:36,  1.62it/s]

[step = 100] loss:147806324.7,r2:-0.22826627959278808
epoch = 59,loss = 145332307.92523363,r2 = -0.22751601463858726,val_loss = 167829162.5185185,val_+r2 = -0.15720312384337093
[step = 10] loss:103807983.2,r2:-0.09000017254975773
[step = 20] loss:113711977.8,r2:-0.46588318604648904
[step = 30] loss:129223280.6,r2:-0.36238433214944893
[step = 40] loss:129593840.65,r2:-0.3354640354296901
[step = 50] loss:139635329.16,r2:-0.2886330928608349
[step = 60] loss:143298466.16666666,r2:-0.3641071171769407
[step = 70] loss:139823659.4,r2:-0.32784260770449863


  8%|██████                                                                           | 60/800 [00:31<07:24,  1.66it/s]

[step = 80] loss:140863602.625,r2:-0.2866635698140172
[step = 90] loss:137024520.91111112,r2:-0.3844722323408926
[step = 100] loss:141598169.22,r2:-0.34137258294446843
epoch = 60,loss = 144546011.90654206,r2 = -0.31535533672722466,val_loss = 161988223.1111111,val_+r2 = -0.23198872702618484
[step = 10] loss:162615607.6,r2:-0.4288109856277128
[step = 20] loss:156790903.6,r2:-0.41882472950125516
[step = 30] loss:151358797.26666668,r2:-0.37617038134150804
[step = 40] loss:152588822.35,r2:-0.36550901806225095
[step = 50] loss:150588532.84,r2:-0.3308130185126854
[step = 60] loss:142124842.5,r2:-0.3190870886760902


  8%|██████▏                                                                          | 61/800 [00:31<07:14,  1.70it/s]

[step = 70] loss:141306224.37142858,r2:-0.28902659055222146
[step = 80] loss:140204721.625,r2:-0.34151358904682405
[step = 90] loss:145210474.42222223,r2:-0.3268434234133472
[step = 100] loss:144944727.62,r2:-0.2990898175350176
epoch = 61,loss = 144698970.09345794,r2 = -0.32225157828642054,val_loss = 158325548.5185185,val_+r2 = -0.050620271231674244
[step = 10] loss:174982144.4,r2:-0.40458127381810466
[step = 20] loss:146639971.6,r2:-0.3535315391449283
[step = 30] loss:153091504.53333333,r2:-0.22311980807351572
[step = 40] loss:161463729.7,r2:-0.1816796031556735
[step = 50] loss:154053512.08,r2:-0.1893865097703505
[step = 60] loss:156745715.93333334,r2:-0.17115794917943117
[step = 70] loss:153470324.62857142,r2:-0.16696415580270138
[step = 80] loss:152563049.95,r2:-0.19071810953620438
[step = 90] loss:146010944.65555555,r2:-0.18697895068861506


  8%|██████▎                                                                          | 62/800 [00:32<07:08,  1.72it/s]

[step = 100] loss:146626281.03,r2:-0.17763769114925368
epoch = 62,loss = 146478857.2990654,r2 = -0.18170318497618762,val_loss = 158300777.85185185,val_+r2 = -0.03166914631785008
[step = 10] loss:107543952.8,r2:-0.43744422398861815
[step = 20] loss:137606676.0,r2:-0.23588142262833597
[step = 30] loss:143329463.2,r2:-0.4631286605031334
[step = 40] loss:139816348.8,r2:-0.38413183926495614
[step = 50] loss:161506743.12,r2:-0.3626214417896373
[step = 60] loss:156089585.8,r2:-0.42769319168770964
[step = 70] loss:151044495.54285714,r2:-0.38912139940398155
[step = 80] loss:152839007.0,r2:-0.34593796177063474


  8%|██████▍                                                                          | 63/800 [00:32<07:02,  1.74it/s]

[step = 90] loss:147690872.44444445,r2:-0.40542939858714894
[step = 100] loss:152010735.12,r2:-0.3876260797648494
epoch = 63,loss = 150188481.86915886,r2 = -0.393895363162631,val_loss = 160644287.25925925,val_+r2 = -0.4219627393611897
[step = 10] loss:134519580.4,r2:-0.25564905153272044
[step = 20] loss:153486901.1,r2:-0.21096905413380757
[step = 30] loss:151485753.26666668,r2:-0.18252497274138385
[step = 40] loss:154983831.25,r2:-0.3839573096441856
[step = 50] loss:153475763.56,r2:-0.3689175470964603
[step = 60] loss:152366014.03333333,r2:-0.3442067236302385
[step = 70] loss:154161879.05714285,r2:-0.33495352551933966
[step = 80] loss:147450461.175,r2:-0.41144405856168725
[step = 90] loss:144358404.97777778,r2:-0.3745622942236408
[step = 100] loss:148880106.24,r2:-0.32729356596965586


  8%|██████▍                                                                          | 64/800 [00:33<06:59,  1.75it/s]

epoch = 64,loss = 145882566.46728972,r2 = -0.3229471092397064,val_loss = 157410254.5185185,val_+r2 = -0.28759030605143593
[step = 10] loss:168500953.6,r2:0.027712494860624038
[step = 20] loss:158730722.0,r2:-0.3581423257813742
[step = 30] loss:163645697.33333334,r2:-0.23870597126788506
[step = 40] loss:166107801.0,r2:-0.19096374061920177
[step = 50] loss:165522177.6,r2:-0.24478225362027436
[step = 60] loss:160522544.06666666,r2:-0.23416491104397302
[step = 70] loss:156833295.14285713,r2:-0.21377491147550756
[step = 80] loss:155338374.25,r2:-0.17396339893171753
[step = 90] loss:149343706.31111112,r2:-0.17639137338973868


  8%|██████▌                                                                          | 65/800 [00:33<06:57,  1.76it/s]

[step = 100] loss:145565977.28,r2:-0.16177030054744407
epoch = 65,loss = 145876549.58878505,r2 = -0.1685546542870398,val_loss = 157279744.0,val_+r2 = -0.04156502696674209
[step = 10] loss:166046167.2,r2:0.08566853286442341
[step = 20] loss:183221420.0,r2:-0.06830021130715191
[step = 30] loss:169652350.8,r2:-0.07997064395594018
[step = 40] loss:166998588.2,r2:-0.062307181747544085
[step = 50] loss:160712237.6,r2:-0.11544169156322379
[step = 60] loss:164359384.13333333,r2:-0.08796167932113766
[step = 70] loss:163709374.8,r2:-0.08559265897455849
[step = 80] loss:156814057.2,r2:-0.10677790360874576
[step = 90] loss:148935256.04444444,r2:-0.10460868215952798
[step = 100] loss:145871945.51,r2:-0.11567547545939856


  8%|██████▋                                                                          | 66/800 [00:34<06:58,  1.75it/s]

epoch = 66,loss = 143731968.25233644,r2 = -0.09715977818492073,val_loss = 159430277.4074074,val_+r2 = 0.015457005196617865
[step = 10] loss:107191527.8,r2:-0.10444390500580858
[step = 20] loss:142590795.5,r2:-0.003618944258244927
[step = 30] loss:137256295.0,r2:-0.040345944299663115
[step = 40] loss:136591723.45,r2:-0.07131698860930044
[step = 50] loss:150443868.52,r2:-0.05838123761197249
[step = 60] loss:146499244.7,r2:-0.219636352406588
[step = 70] loss:139703033.77142859,r2:-0.21895680479691318
[step = 80] loss:139624127.65,r2:-0.22346515485690116
[step = 90] loss:141870184.2222222,r2:-0.22528089863597575


  8%|██████▊                                                                          | 67/800 [00:35<06:56,  1.76it/s]

[step = 100] loss:146051815.64,r2:-0.19652782965906632
epoch = 67,loss = 145524134.50467288,r2 = -0.18102926934920235,val_loss = 158948333.7777778,val_+r2 = -0.10218719978296364
[step = 10] loss:204193078.4,r2:-0.034495972406976716
[step = 20] loss:152864024.0,r2:-0.5969624263716043
[step = 30] loss:136612299.96666667,r2:-0.4414961937135129
[step = 40] loss:148196340.975,r2:-0.33071439009860903
[step = 50] loss:150561383.5,r2:-0.45952015977657673
[step = 60] loss:154532463.05,r2:-0.41016898223857096
[step = 70] loss:145472861.12857142,r2:-0.39243219472457713
[step = 80] loss:145438243.4375,r2:-0.33901958150466616


  8%|██████▉                                                                          | 68/800 [00:35<06:56,  1.76it/s]

[step = 90] loss:147464161.76666668,r2:-0.3522439337797288
[step = 100] loss:149167690.95,r2:-0.3273512503165953
epoch = 68,loss = 145892294.19626167,r2 = -0.324423238882012,val_loss = 165663003.1851852,val_+r2 = -0.0279311696281948
[step = 10] loss:157457267.2,r2:0.011773479241477902
[step = 20] loss:156588045.6,r2:0.005533861582356336
[step = 30] loss:151628125.2,r2:0.0015054336401340758
[step = 40] loss:145789124.6,r2:-0.019884542343998383
[step = 50] loss:151974164.0,r2:-0.02816699695440456
[step = 60] loss:151363049.13333333,r2:-0.18621113192198047
[step = 70] loss:144597898.1142857,r2:-0.2050685032256547


  9%|██████▉                                                                          | 69/800 [00:36<06:54,  1.76it/s]

[step = 80] loss:142309179.3,r2:-0.18724644003167307
[step = 90] loss:141941385.7777778,r2:-0.15849450820611408
[step = 100] loss:144845830.2,r2:-0.14237889613781024
epoch = 69,loss = 145421782.1682243,r2 = -0.142909067787897,val_loss = 160504367.1111111,val_+r2 = -0.0887865097800388
[step = 10] loss:146170185.2,r2:-0.2522668975610623
[step = 20] loss:131665016.6,r2:-0.15176632618096927
[step = 30] loss:141677540.53333333,r2:-0.11565364386576078
[step = 40] loss:145807838.6,r2:-0.16945623208585517
[step = 50] loss:139976850.56,r2:-0.1476801948112827
[step = 60] loss:138666919.0,r2:-0.12446299915315447
[step = 70] loss:141473305.37142858,r2:-0.12653997765271585
[step = 80] loss:138702711.45,r2:-0.11938595724395065
[step = 90] loss:140300993.7111111,r2:-0.13705387962341592
[step = 100] loss:140772455.7,r2:-0.18645695816375657
epoch = 70,loss = 143717272.317757,r2 = -0.15934809311023213,val_loss = 157316582.5185185,val_+r2 = -0.34842829983610457


  9%|███████                                                                          | 70/800 [00:36<06:54,  1.76it/s]

[step = 10] loss:123026027.6,r2:-0.35938758501805523
[step = 20] loss:125908655.1,r2:-0.33004984585423686
[step = 30] loss:135335961.2,r2:-0.2643948779917609
[step = 40] loss:144427978.9,r2:-0.24964538676041345
[step = 50] loss:137509430.96,r2:-0.31688894612176793
[step = 60] loss:140574073.66666666,r2:-0.275580713958462
[step = 70] loss:143352803.42857143,r2:-0.2418353433062698
[step = 80] loss:144397737.15,r2:-0.25748038652932864
[step = 90] loss:143661157.75555557,r2:-0.2285870166496601


  9%|███████▏                                                                         | 71/800 [00:37<06:57,  1.75it/s]

[step = 100] loss:146629837.26,r2:-0.20156065742094906
epoch = 71,loss = 145892449.47663552,r2 = -0.18601130849497682,val_loss = 158279816.74074075,val_+r2 = -0.23935698477323072
[step = 10] loss:171505041.2,r2:-0.01799850319199895
[step = 20] loss:145531150.9,r2:-0.1121637558309441
[step = 30] loss:146129332.33333334,r2:-0.33782068888877775
[step = 40] loss:139701515.95,r2:-0.26730988426708613
[step = 50] loss:134915357.48,r2:-0.2254111553745806
[step = 60] loss:145118242.7,r2:-0.406642256956327
[step = 70] loss:150381761.62857142,r2:-0.3615618101135437


  9%|███████▎                                                                         | 72/800 [00:37<06:55,  1.75it/s]

[step = 80] loss:147926593.775,r2:-0.34258180215885303
[step = 90] loss:148762158.2,r2:-0.29583337021191347
[step = 100] loss:148605691.06,r2:-0.2991566285782969
epoch = 72,loss = 145155062.48598132,r2 = -0.3046913706822222,val_loss = 162259985.55555555,val_+r2 = -0.0023905619276715045
[step = 10] loss:124182013.2,r2:0.061639084340750126
[step = 20] loss:135712918.2,r2:0.036153149726698156
[step = 30] loss:136234020.66666666,r2:-0.04618318706687109
[step = 40] loss:146001626.0,r2:-0.07011064999450721
[step = 50] loss:148560726.6,r2:-0.06898753509619629
[step = 60] loss:145766400.56666666,r2:-0.050258774611909175
[step = 70] loss:157642387.1142857,r2:-0.04885747846210664
[step = 80] loss:154102130.125,r2:-0.045984884900978115
[step = 90] loss:148639858.8,r2:-0.04968226238542783
[step = 100] loss:146957480.72,r2:-0.05898050742368536


  9%|███████▍                                                                         | 73/800 [00:38<06:55,  1.75it/s]

epoch = 73,loss = 144724376.48598132,r2 = -0.054835225157856174,val_loss = 158727659.55555555,val_+r2 = -0.01751824718827318
[step = 10] loss:116990529.6,r2:0.00831300748706173
[step = 20] loss:134650785.6,r2:-0.10522910628860319
[step = 30] loss:141378467.86666667,r2:-0.09822629381067859
[step = 40] loss:135187782.8,r2:-0.08105176972147198
[step = 50] loss:140211619.28,r2:-0.11632680938477108
[step = 60] loss:139315137.66666666,r2:-0.3217274392022057
[step = 70] loss:142553337.45714286,r2:-0.27908233078461375
[step = 80] loss:142064551.675,r2:-0.2503869849128566
[step = 90] loss:142750592.55555555,r2:-0.21863074765213106


  9%|███████▍                                                                         | 74/800 [00:39<06:51,  1.76it/s]

[step = 100] loss:141631956.22,r2:-0.22521075230402063
epoch = 74,loss = 144182829.682243,r2 = -0.2111445035929469,val_loss = 157374166.07407406,val_+r2 = -0.28206148466366565
[step = 10] loss:143447487.2,r2:-0.3197219088354459
[step = 20] loss:139415304.4,r2:-0.27883613605180935
[step = 30] loss:121099203.46666667,r2:-0.26431838711126326
[step = 40] loss:136586064.0,r2:-0.2075112651490497
[step = 50] loss:144061536.24,r2:-0.17853941405594548
[step = 60] loss:139664901.6,r2:-0.1584961091416524
[step = 70] loss:140596678.0,r2:-0.19371278863592337
[step = 80] loss:147265822.15,r2:-0.22474364933460195


  9%|███████▌                                                                         | 75/800 [00:39<06:49,  1.77it/s]

[step = 90] loss:142144816.91111112,r2:-0.21326430355661052
[step = 100] loss:144736642.98,r2:-0.2061217420490333
epoch = 75,loss = 145850676.91588786,r2 = -0.19616733498342953,val_loss = 166097312.5925926,val_+r2 = -0.8354401871282989
[step = 10] loss:117674084.4,r2:-1.0983616450474685
[step = 20] loss:123069418.6,r2:-0.7005673391458792
[step = 30] loss:118455111.06666666,r2:-0.4959367088574317
[step = 40] loss:127146153.4,r2:-0.37425728287782534
[step = 50] loss:142166779.12,r2:-0.28756271393601185
[step = 60] loss:147890754.26666668,r2:-0.28650016068535605
[step = 70] loss:148894627.08571428,r2:-0.3056124697557501
[step = 80] loss:146171691.975,r2:-0.27353034236945584
[step = 90] loss:142894036.33333334,r2:-0.24276801339168214
[step = 100] loss:142278381.98,r2:-0.25813711782119264


 10%|███████▋                                                                         | 76/800 [00:40<06:48,  1.77it/s]

epoch = 76,loss = 146114695.3084112,r2 = -0.2532734704675702,val_loss = 158250876.44444445,val_+r2 = -0.2638819814821139
[step = 10] loss:203993623.2,r2:-0.732277742214562
[step = 20] loss:167936283.6,r2:-0.5267515571606092
[step = 30] loss:169012888.53333333,r2:-0.4054113723100208
[step = 40] loss:161391672.2,r2:-0.3564192105642711
[step = 50] loss:156868587.68,r2:-0.3078242236346999
[step = 60] loss:146513456.5,r2:-0.29736324653369933
[step = 70] loss:142539273.77142859,r2:-0.26096787653199527
[step = 80] loss:146780612.15,r2:-0.260544407072692
[step = 90] loss:143819233.02222222,r2:-0.2632682133242293


 10%|███████▊                                                                         | 77/800 [00:40<06:47,  1.77it/s]

[step = 100] loss:145530444.14,r2:-0.23900042179627815
epoch = 77,loss = 146274268.91588786,r2 = -0.22462283246575523,val_loss = 158600834.96296296,val_+r2 = -0.4625827937079955
[step = 10] loss:179568557.6,r2:-0.22749548815826098
[step = 20] loss:167270130.6,r2:-0.1413285948574244
[step = 30] loss:155758696.53333333,r2:-0.09284612483907972
[step = 40] loss:140223361.65,r2:-0.17308103252859114
[step = 50] loss:142551054.44,r2:-0.17273140854191632
[step = 60] loss:140331269.1,r2:-0.12256507418617761
[step = 70] loss:144385439.22857141,r2:-0.10118119140077088
[step = 80] loss:140646720.375,r2:-0.13272700870402396
[step = 90] loss:144030985.75555557,r2:-0.12632663266300118
[step = 100] loss:141650246.46,r2:-0.16156478907991825


 10%|███████▉                                                                         | 78/800 [00:41<06:48,  1.77it/s]

epoch = 78,loss = 144731776.6915888,r2 = -0.1514002879165535,val_loss = 157801610.96296296,val_+r2 = -0.056175928701649
[step = 10] loss:126919517.6,r2:-0.13266519046525832
[step = 20] loss:143560211.6,r2:-0.2241005255849732
[step = 30] loss:141352000.06666666,r2:-0.19856215194975066
[step = 40] loss:151835709.55,r2:-0.3713858022778454
[step = 50] loss:149643461.12,r2:-0.3160285406042743
[step = 60] loss:154083964.53333333,r2:-0.2835055985728514
[step = 70] loss:150569215.42857143,r2:-0.25922403406180144
[step = 80] loss:157148164.0,r2:-0.2293400799413714
[step = 90] loss:153839420.93333334,r2:-0.2560761674216892


 10%|███████▉                                                                         | 79/800 [00:41<06:47,  1.77it/s]

[step = 100] loss:150336645.08,r2:-0.2551881743989986
epoch = 79,loss = 145445878.59813085,r2 = -0.23946795552092204,val_loss = 164962334.44444445,val_+r2 = -0.044024694956937914
[step = 10] loss:199185410.0,r2:0.05144092049915329
[step = 20] loss:176854043.6,r2:-0.12127290888360742
[step = 30] loss:164193825.06666666,r2:-0.27350971119231515
[step = 40] loss:157862165.95,r2:-0.20922957579475732
[step = 50] loss:149734101.48,r2:-0.2261826607179025
[step = 60] loss:146453636.63333333,r2:-0.1831720523613843
[step = 70] loss:142060085.5142857,r2:-0.25576951510215984


 10%|████████                                                                         | 80/800 [00:42<06:46,  1.77it/s]

[step = 80] loss:139956416.8,r2:-0.22097519302599755
[step = 90] loss:141326744.13333333,r2:-0.20517289621953416
[step = 100] loss:142557461.48,r2:-0.1844616013514417
epoch = 80,loss = 144295525.34579438,r2 = -0.21961226747317325,val_loss = 160918202.37037036,val_+r2 = 0.041810713768862175
[step = 10] loss:156705087.2,r2:0.08867159523123194
[step = 20] loss:138298163.4,r2:-0.08583846126092351
[step = 30] loss:151493108.93333334,r2:-0.05375557290901236
[step = 40] loss:144641805.3,r2:-0.14355622049015063
[step = 50] loss:145921425.16,r2:-0.14215990420913097
[step = 60] loss:139822693.36666667,r2:-0.15965872189406533
[step = 70] loss:146872050.65714285,r2:-0.15233054790230569
[step = 80] loss:151624410.625,r2:-0.11893905098439927
[step = 90] loss:147616575.08888888,r2:-0.13839499008207715
[step = 100] loss:142975719.22,r2:-0.1323689402837923


 10%|████████▏                                                                        | 81/800 [00:43<06:47,  1.76it/s]

epoch = 81,loss = 144851080.95327103,r2 = -0.13465539177514485,val_loss = 160381302.74074075,val_+r2 = -0.18663130094375652
[step = 10] loss:175384082.8,r2:-1.309116591820874
[step = 20] loss:150627974.0,r2:-0.8087023129594714
[step = 30] loss:168801695.2,r2:-0.6209303266012169
[step = 40] loss:149742476.0,r2:-0.47791765808599
[step = 50] loss:149549038.44,r2:-0.47368840019633984
[step = 60] loss:144063901.6,r2:-0.4021050276713208
[step = 70] loss:141757904.0,r2:-0.352481607258487
[step = 80] loss:147638201.9,r2:-0.4576273003127934
[step = 90] loss:145519972.7111111,r2:-0.40247639131701973


 10%|████████▎                                                                        | 82/800 [00:43<06:51,  1.74it/s]

[step = 100] loss:143483586.14,r2:-0.3580934456655993
epoch = 82,loss = 144551496.24299064,r2 = -0.3403598884492594,val_loss = 159542966.5185185,val_+r2 = -0.08913464761309234
[step = 10] loss:156909073.2,r2:-0.24423239246623735
[step = 20] loss:151111171.6,r2:-0.19555913225969873
[step = 30] loss:146944157.06666666,r2:-0.12097529286799906
[step = 40] loss:156425726.4,r2:-0.10445423100452896
[step = 50] loss:155637205.12,r2:-0.08414914112022742
[step = 60] loss:151223222.53333333,r2:-0.08916112329493404
[step = 70] loss:146620307.22857141,r2:-0.09600130586160865
[step = 80] loss:146866644.425,r2:-0.09500958077794534
[step = 90] loss:143642749.75555557,r2:-0.10133507662125112
[step = 100] loss:144853502.94,r2:-0.09676407130775427
epoch = 83,loss = 144092394.52336448,r2 = -0.08535107129985668,val_loss = 158860343.7037037,val_+r2 = -0.09172127860697284


 10%|████████▍                                                                        | 83/800 [00:44<06:51,  1.74it/s]

[step = 10] loss:178893896.4,r2:0.008152522895560576
[step = 20] loss:144537893.8,r2:-0.3738022125052734
[step = 30] loss:127458136.53333333,r2:-0.2716568173116628
[step = 40] loss:131071166.8,r2:-0.21668509037239864
[step = 50] loss:142036319.12,r2:-0.31204301606058
[step = 60] loss:136438671.33333334,r2:-0.48234271849534915
[step = 70] loss:134732497.94285715,r2:-0.44133489646847873
[step = 80] loss:130657850.825,r2:-0.3880484043965385
[step = 90] loss:135025522.5111111,r2:-0.34309951041020414


 10%|████████▌                                                                        | 84/800 [00:44<06:51,  1.74it/s]

[step = 100] loss:140366015.38,r2:-0.3122782104109407
epoch = 84,loss = 144572644.65420562,r2 = -0.29317841376435716,val_loss = 165990678.66666666,val_+r2 = -0.3993732994066285
[step = 10] loss:127707262.6,r2:-0.17910550782821105
[step = 20] loss:148666779.1,r2:-0.10541006215391886
[step = 30] loss:156236210.73333332,r2:-0.1197425651315282
[step = 40] loss:145870727.75,r2:-0.1161442435136784
[step = 50] loss:137763043.3,r2:-0.13112620646536097
[step = 60] loss:134180722.11666666,r2:-0.13694315602341361
[step = 70] loss:141774088.44285715,r2:-0.2236851617357321
[step = 80] loss:143064410.0875,r2:-0.2054669543598501
[step = 90] loss:144457419.0111111,r2:-0.1890468839788468
[step = 100] loss:145453148.63,r2:-0.18083661623688088


 11%|████████▌                                                                        | 85/800 [00:45<06:52,  1.73it/s]

epoch = 85,loss = 143684206.19626167,r2 = -0.1969819136259535,val_loss = 161253702.74074075,val_+r2 = -0.45747752160186267
[step = 10] loss:120260755.6,r2:0.002320795792721353
[step = 20] loss:149079990.0,r2:-0.014235931126204082
[step = 30] loss:142988725.2,r2:-0.008663833380107263
[step = 40] loss:147759219.7,r2:0.001961595186846146
[step = 50] loss:145176207.52,r2:-0.010656859589223864
[step = 60] loss:151051100.8,r2:-0.027783501956304968
[step = 70] loss:146380044.8,r2:-0.028861064168696826
[step = 80] loss:142602817.05,r2:-0.0645029681920262
[step = 90] loss:139637340.5111111,r2:-0.07837589848430906


 11%|████████▋                                                                        | 86/800 [00:45<06:51,  1.74it/s]

[step = 100] loss:148305151.34,r2:-0.07176458681396886
epoch = 86,loss = 145649118.63551402,r2 = -0.12911276425855325,val_loss = 158008319.03703704,val_+r2 = -0.07043168852596234
[step = 10] loss:177285676.0,r2:0.035910358414688745
[step = 20] loss:173088381.8,r2:-0.14908117224532272
[step = 30] loss:146587760.0,r2:-0.37728985099830703
[step = 40] loss:136781003.5,r2:-0.31394963206159515
[step = 50] loss:142652664.8,r2:-0.22987287615675947
[step = 60] loss:140595216.53333333,r2:-0.16921115720273558
[step = 70] loss:149666419.77142859,r2:-0.15577005994284376


 11%|████████▊                                                                        | 87/800 [00:46<06:50,  1.74it/s]

[step = 80] loss:148995392.9,r2:-0.16301090890346032
[step = 90] loss:145622285.95555556,r2:-0.17281904747299726
[step = 100] loss:146880046.36,r2:-0.178756013811093
epoch = 87,loss = 145066151.51401868,r2 = -0.20141656576012912,val_loss = 159346404.8888889,val_+r2 = -0.05857572921442012
[step = 10] loss:123064660.4,r2:-0.4144229170500665
[step = 20] loss:125955936.2,r2:-0.1748179963034806
[step = 30] loss:120212694.8,r2:-0.2161416085932564
[step = 40] loss:131644215.75,r2:-0.1327569934044178
[step = 50] loss:137463358.2,r2:-0.0966552978962226
[step = 60] loss:139683924.23333332,r2:-0.20135582573704275
[step = 70] loss:136293791.2,r2:-0.21835870711372501
[step = 80] loss:135542755.9,r2:-0.21167868140984955


 11%|████████▉                                                                        | 88/800 [00:47<06:49,  1.74it/s]

7) Take a look of record

In [None]:
history1.head()

In [None]:
history1.to_csv('./logs/log1.csv',index=False)

In [None]:
def line_plotling(df,metric):
    custom_params = {"axes.spines.right": True, "axes.spines.top": True}
    sns.set_theme(style="ticks", rc=custom_params)
    sns.lineplot(x = 'epoch',y = metric,data =df,color = 'r',markers=True, dashes=False)
    sns.lineplot(x = 'epoch',y = 'val_'+metric,data =df,color = 'b',markers=True, dashes=False)
    plt.legend(["train_"+metric,'ignore ', 'val_'+metric])

In [None]:
line_plotling(history1,'loss')

In [None]:
# 查看r2
line_plotling(history1,'r2')

In [None]:
#y_pred = model(torch.tensor(X_test_tensor).float()).data
#y_pred

In [None]:
print(model.state_dict().keys())

In [None]:
# Save parameters to 'model_parameters' path
#torch.save(model.state_dict(),'./models/model1.pkl')

In [None]:
# Create a new model
#new_model = insurance_net()
# Load parameters to the untrained model
#new_model.load_state_dict(torch.load('./models/model1.pkl'))

In [None]:
#new_model

In [None]:
my_info_list = [[28, 20.11, 0, 1, 1, 0, 1 ,0]]
my_info_array = np.array(my_info_list)
my_info_array = scaler.fit_transform(my_info_array)
my_info_tensor = torch.tensor(my_info_array).float()
model.forward(my_info_tensor)