# Import Modules

In [1]:
import pandas as pd
import torch

device = "cuda" if torch.cuda.is_available else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


# Import dataset

In [2]:
california_houses_df = pd.read_csv("../datasets/california_houses.csv")
california_houses_df.head()

Unnamed: 0,Median_House_Value,Median_Income,Median_Age,Tot_Rooms,Tot_Bedrooms,Population,Households,Latitude,Longitude,Distance_to_coast,Distance_to_LA,Distance_to_SanDiego,Distance_to_SanJose,Distance_to_SanFrancisco
0,452600.0,8.3252,41,880,129,322,126,37.88,-122.23,9263.040773,556529.158342,735501.806984,67432.517001,21250.213767
1,358500.0,8.3014,21,7099,1106,2401,1138,37.86,-122.22,10225.733072,554279.850069,733236.88436,65049.908574,20880.6004
2,352100.0,7.2574,52,1467,190,496,177,37.85,-122.24,8259.085109,554610.717069,733525.682937,64867.289833,18811.48745
3,341300.0,5.6431,52,1274,235,558,219,37.85,-122.25,7768.086571,555194.266086,734095.290744,65287.138412,18031.047568
4,342200.0,3.8462,52,1627,280,565,259,37.85,-122.25,7768.086571,555194.266086,734095.290744,65287.138412,18031.047568


In [3]:
california_houses_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 14 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Median_House_Value        20640 non-null  float64
 1   Median_Income             20640 non-null  float64
 2   Median_Age                20640 non-null  int64  
 3   Tot_Rooms                 20640 non-null  int64  
 4   Tot_Bedrooms              20640 non-null  int64  
 5   Population                20640 non-null  int64  
 6   Households                20640 non-null  int64  
 7   Latitude                  20640 non-null  float64
 8   Longitude                 20640 non-null  float64
 9   Distance_to_coast         20640 non-null  float64
 10  Distance_to_LA            20640 non-null  float64
 11  Distance_to_SanDiego      20640 non-null  float64
 12  Distance_to_SanJose       20640 non-null  float64
 13  Distance_to_SanFrancisco  20640 non-null  float64
dtypes: flo

In [4]:
california_houses_df.isna().sum()

Median_House_Value          0
Median_Income               0
Median_Age                  0
Tot_Rooms                   0
Tot_Bedrooms                0
Population                  0
Households                  0
Latitude                    0
Longitude                   0
Distance_to_coast           0
Distance_to_LA              0
Distance_to_SanDiego        0
Distance_to_SanJose         0
Distance_to_SanFrancisco    0
dtype: int64

In [5]:
california_houses_df.describe()

Unnamed: 0,Median_House_Value,Median_Income,Median_Age,Tot_Rooms,Tot_Bedrooms,Population,Households,Latitude,Longitude,Distance_to_coast,Distance_to_LA,Distance_to_SanDiego,Distance_to_SanJose,Distance_to_SanFrancisco
count,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0
mean,206855.816909,3.870671,28.639486,2635.763081,537.898014,1425.476744,499.53968,35.631861,-119.569704,40509.264883,269422.0,398164.9,349187.551219,386688.422291
std,115395.615874,1.899822,12.585558,2181.615252,421.247906,1132.462122,382.329753,2.135952,2.003532,49140.03916,247732.4,289400.6,217149.875026,250122.192316
min,14999.0,0.4999,1.0,2.0,1.0,3.0,1.0,32.54,-124.35,120.676447,420.5891,484.918,569.448118,456.141313
25%,119600.0,2.5634,18.0,1447.75,295.0,787.0,280.0,33.93,-121.8,9079.756762,32111.25,159426.4,113119.928682,117395.477505
50%,179700.0,3.5348,29.0,2127.0,435.0,1166.0,409.0,34.26,-118.49,20522.019101,173667.5,214739.8,459758.877,526546.661701
75%,264725.0,4.74325,37.0,3148.0,647.0,1725.0,605.0,37.71,-118.01,49830.414479,527156.2,705795.4,516946.490963,584552.007907
max,500001.0,15.0001,52.0,39320.0,6445.0,35682.0,6082.0,41.95,-114.31,333804.686371,1018260.0,1196919.0,836762.67821,903627.663298


# Split feature and target

In [6]:
y = california_houses_df.pop("Median_House_Value").values.reshape(-1, 1)
X = california_houses_df.values

print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")

X shape: (20640, 13)
y shape: (20640, 1)


# Split train and test dataset

In [7]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, test_size=0.3)
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (14448, 13)
y_train shape: (14448, 1)
X_test shape: (6192, 13)
y_test shape: (6192, 1)


In [8]:
device

'cuda'

In [9]:
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train)

X_test = torch.FloatTensor(X_test)
y_test = torch.FloatTensor(y_test)

# Build Architecture

## TensorDataset and DataLoader

In [10]:
from torch.utils.data import TensorDataset, DataLoader

# TensorDataset
tensor_dataset_train = TensorDataset(X_train, y_train)
tensor_dataset_test = TensorDataset(X_test, y_test)

# DataLoader
dataloader_train = DataLoader(tensor_dataset_train, batch_size=32)
dataloader_test = DataLoader(tensor_dataset_test, batch_size=32)

## Build Model

In [11]:
from torch import nn, optim

model = nn.Sequential(          
            # layer 2 (16 neurons) to layer 3 (8 neurons) with ReLU
            nn.Linear(13, 8),
            nn.ReLU6(),
    
            # layer 4 (8 neurons) to layer 5 (4 neurons) with ReLU
            nn.Linear(8, 4),
            nn.ReLU6(),
    
            # layer 6 (4 neurons) to layer 6 (1 neurons)
            nn.Linear(4, 1),
            nn.ReLU6()
        )

model

Sequential(
  (0): Linear(in_features=13, out_features=8, bias=True)
  (1): ReLU6()
  (2): Linear(in_features=8, out_features=4, bias=True)
  (3): ReLU6()
  (4): Linear(in_features=4, out_features=1, bias=True)
  (5): ReLU6()
)

In [12]:
model.state_dict()

OrderedDict([('0.weight',
              tensor([[ 0.1819,  0.0519, -0.0929, -0.2653, -0.2568,  0.0989, -0.0195,  0.0521,
                       -0.1715, -0.1736, -0.2099, -0.0110,  0.0383],
                      [ 0.2041,  0.0763,  0.0155,  0.0271, -0.2541,  0.0116,  0.0223,  0.0280,
                        0.0855,  0.2154, -0.1827,  0.1206, -0.0666],
                      [ 0.2276,  0.0169,  0.2722, -0.1712, -0.0190,  0.0426,  0.0740,  0.2414,
                        0.0065,  0.2074,  0.1585, -0.2765,  0.2658],
                      [-0.1326,  0.1667, -0.1672, -0.0850,  0.2032, -0.1821, -0.0065, -0.2308,
                        0.2216, -0.2628,  0.0929, -0.1886,  0.1308],
                      [ 0.0187,  0.0126,  0.1603, -0.0721,  0.1696,  0.1663,  0.2572, -0.1217,
                        0.0990,  0.0159,  0.2468, -0.0479,  0.1073],
                      [-0.0563,  0.1525, -0.1900, -0.0727,  0.0334,  0.0844,  0.2487, -0.2201,
                       -0.2651, -0.0910, -0.2458,  0.1693, 

## Loss Function

In [13]:
mse_loss = nn.MSELoss()
mse_loss

MSELoss()

## Optimizer

In [14]:
adamw_optimizer = optim.AdamW(model.parameters(), lr=0.001,)
adamw_optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)

## Loop function

In [15]:
from torchmetrics.functional import r2_score

def train_loop(dataloader, model, loss_fn, optimizer_fn):
    size = len(dataloader.dataset)
    loss_batches = []
    r2_score_batches = []
    
    for batch, (X, y) in enumerate(dataloader_train):
        # Forwardpropagation
        pred = model(X)
        loss = loss_fn(pred, y)
        r_2_score = r2_score(pred, y)
        
        # Backpropagation
        optimizer_fn.zero_grad()
        loss.backward()
        optimizer_fn.step()
        
        loss = loss.item()
        loss_batches.append(loss)
        
        r_2_score = r_2_score.item()
        r2_score_batches.append(r_2_score)
        
        current_batch = (batch + 1) * len(X)
        
        print(f"batch {batch + 1}: [{current_batch:>3d}/{size:>3d}] | r2-score: {r_2_score:>8f} | loss: {loss:>7f}")
    print(f"Mean R2 score  : {sum(r2_score_batches) / len(r2_score_batches)}")
    print(f"Mean RMSE loss : {(sum(loss_batches) / len(loss_batches))**0.5}")
    
def test_loop(dataloader, model, loss_fn, optimizer_fn):
    size = len(dataloader.dataset)
    loss_batches = []
    r2_score_batches = []
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader_test):
            pred = model(X)
            loss = loss_fn(pred, y)
            r_2_score = r2_score(pred, y)

            loss = loss.item()
            loss_batches.append(loss)

            r_2_score = r_2_score.item()
            r2_score_batches.append(r_2_score)

            current_batch = (batch + 1) * len(X)

            print(f"batch {batch + 1}: [{current_batch:>3d}/{size:>3d}] | r2-score: {r_2_score:>8f} | loss: {loss:>7f}")
    print(f"Mean R2 score  : {sum(r2_score_batches) / len(r2_score_batches)}")
    print(f"Mean RMSE loss : {(sum(loss_batches) / len(loss_batches))**0.5}")

In [16]:
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    
    print(f"Train")
    train_loop(dataloader_train, model, mse_loss, adamw_optimizer)
    
    print(f"\nTest")
    test_loop(dataloader_test, model, mse_loss, adamw_optimizer)
    print("-" * 80, end="\n\n")
print("Done!")

Epoch 1
Train
batch 1: [ 32/14448] | r2-score: -3.937635 | loss: 54569484288.000000
batch 2: [ 64/14448] | r2-score: -3.255689 | loss: 74724442112.000000
batch 3: [ 96/14448] | r2-score: -3.136192 | loss: 46658543616.000000
batch 4: [128/14448] | r2-score: -2.595669 | loss: 61764198400.000000
batch 5: [160/14448] | r2-score: -4.067076 | loss: 48541401088.000000
batch 6: [192/14448] | r2-score: -2.962251 | loss: 82635874304.000000
batch 7: [224/14448] | r2-score: -3.474829 | loss: 61706727424.000000
batch 8: [256/14448] | r2-score: -3.227703 | loss: 58890805248.000000
batch 9: [288/14448] | r2-score: -3.821945 | loss: 42836062208.000000
batch 10: [320/14448] | r2-score: -3.288178 | loss: 83215564800.000000
batch 11: [352/14448] | r2-score: -3.815562 | loss: 43159638016.000000
batch 12: [384/14448] | r2-score: -1.996954 | loss: 60432392192.000000
batch 13: [416/14448] | r2-score: -2.416570 | loss: 66400247808.000000
batch 14: [448/14448] | r2-score: -2.942170 | loss: 57085325312.000000
b

batch 116: [3712/14448] | r2-score: -2.601739 | loss: 48641368064.000000
batch 117: [3744/14448] | r2-score: -2.414576 | loss: 64811630592.000000
batch 118: [3776/14448] | r2-score: -2.979099 | loss: 67987054592.000000
batch 119: [3808/14448] | r2-score: -3.680471 | loss: 50550800384.000000
batch 120: [3840/14448] | r2-score: -4.758480 | loss: 61058289664.000000
batch 121: [3872/14448] | r2-score: -3.952586 | loss: 20654643200.000000
batch 122: [3904/14448] | r2-score: -2.988356 | loss: 52063850496.000000
batch 123: [3936/14448] | r2-score: -4.974672 | loss: 42904244224.000000
batch 124: [3968/14448] | r2-score: -2.626536 | loss: 46596788224.000000
batch 125: [4000/14448] | r2-score: -2.653584 | loss: 39928430592.000000
batch 126: [4032/14448] | r2-score: -4.313872 | loss: 66760396800.000000
batch 127: [4064/14448] | r2-score: -3.895627 | loss: 51704283136.000000
batch 128: [4096/14448] | r2-score: -3.813733 | loss: 55132794880.000000
batch 129: [4128/14448] | r2-score: -2.865551 | los

batch 242: [7744/14448] | r2-score: -3.034055 | loss: 67593441280.000000
batch 243: [7776/14448] | r2-score: -3.651489 | loss: 48890011648.000000
batch 244: [7808/14448] | r2-score: -3.485617 | loss: 41474269184.000000
batch 245: [7840/14448] | r2-score: -2.672833 | loss: 65623244800.000000
batch 246: [7872/14448] | r2-score: -3.487311 | loss: 76273287168.000000
batch 247: [7904/14448] | r2-score: -4.002019 | loss: 42759778304.000000
batch 248: [7936/14448] | r2-score: -2.835936 | loss: 51175006208.000000
batch 249: [7968/14448] | r2-score: -2.758041 | loss: 43950075904.000000
batch 250: [8000/14448] | r2-score: -4.286229 | loss: 41767714816.000000
batch 251: [8032/14448] | r2-score: -4.678197 | loss: 68732452864.000000
batch 252: [8064/14448] | r2-score: -4.742980 | loss: 39951065088.000000
batch 253: [8096/14448] | r2-score: -2.693073 | loss: 61504036864.000000
batch 254: [8128/14448] | r2-score: -3.422594 | loss: 40512749568.000000
batch 255: [8160/14448] | r2-score: -2.482856 | los

batch 442: [14144/14448] | r2-score: -2.462486 | loss: 71980859392.000000
batch 443: [14176/14448] | r2-score: -2.818179 | loss: 72616984576.000000
batch 444: [14208/14448] | r2-score: -3.057086 | loss: 60487495680.000000
batch 445: [14240/14448] | r2-score: -3.373385 | loss: 46070071296.000000
batch 446: [14272/14448] | r2-score: -2.837485 | loss: 46878265344.000000
batch 447: [14304/14448] | r2-score: -2.896800 | loss: 44405940224.000000
batch 448: [14336/14448] | r2-score: -4.931195 | loss: 51072716800.000000
batch 449: [14368/14448] | r2-score: -2.785129 | loss: 52651515904.000000
batch 450: [14400/14448] | r2-score: -4.650333 | loss: 47443722240.000000
batch 451: [14432/14448] | r2-score: -6.415433 | loss: 33703960576.000000
batch 452: [7232/14448] | r2-score: -3.532365 | loss: 37275439104.000000
Mean R2 score  : -3.5109738364683842
Mean RMSE loss : 236630.43952756553

Test
batch 1: [ 32/6192] | r2-score: -3.900288 | loss: 57208958976.000000
batch 2: [ 64/6192] | r2-score: -3.2605

batch 192: [6144/6192] | r2-score: -2.613381 | loss: 64296345600.000000
batch 193: [6176/6192] | r2-score: -3.173014 | loss: 53904875520.000000
batch 194: [3104/6192] | r2-score: -3.693748 | loss: 71074881536.000000
Mean R2 score  : -3.531563279555016
Mean RMSE loss : 237373.8872806393
--------------------------------------------------------------------------------

Epoch 2
Train
batch 1: [ 32/14448] | r2-score: -3.937423 | loss: 54567145472.000000
batch 2: [ 64/14448] | r2-score: -3.255535 | loss: 74721722368.000000
batch 3: [ 96/14448] | r2-score: -3.136001 | loss: 46656389120.000000
batch 4: [128/14448] | r2-score: -2.595532 | loss: 61761847296.000000
batch 5: [160/14448] | r2-score: -4.066849 | loss: 48539230208.000000
batch 6: [192/14448] | r2-score: -2.962117 | loss: 82633097216.000000
batch 7: [224/14448] | r2-score: -3.474656 | loss: 61704343552.000000
batch 8: [256/14448] | r2-score: -3.227534 | loss: 58888454144.000000
batch 9: [288/14448] | r2-score: -3.821715 | loss: 428340

batch 112: [3584/14448] | r2-score: -2.746157 | loss: 53708263424.000000
batch 113: [3616/14448] | r2-score: -2.871250 | loss: 57898319872.000000
batch 114: [3648/14448] | r2-score: -3.664113 | loss: 61765238784.000000
batch 115: [3680/14448] | r2-score: -3.210410 | loss: 61323534336.000000
batch 116: [3712/14448] | r2-score: -2.601656 | loss: 48640245760.000000
batch 117: [3744/14448] | r2-score: -2.414514 | loss: 64810459136.000000
batch 118: [3776/14448] | r2-score: -2.979024 | loss: 67985772544.000000
batch 119: [3808/14448] | r2-score: -3.680361 | loss: 50549608448.000000
batch 120: [3840/14448] | r2-score: -4.758348 | loss: 61056901120.000000
batch 121: [3872/14448] | r2-score: -3.952405 | loss: 20653891584.000000
batch 122: [3904/14448] | r2-score: -2.988273 | loss: 52062777344.000000
batch 123: [3936/14448] | r2-score: -4.974539 | loss: 42903293952.000000
batch 124: [3968/14448] | r2-score: -2.626458 | loss: 46595788800.000000
batch 125: [4000/14448] | r2-score: -2.653504 | los

batch 251: [8032/14448] | r2-score: -4.678193 | loss: 68732403712.000000
batch 252: [8064/14448] | r2-score: -4.742976 | loss: 39951044608.000000
batch 253: [8096/14448] | r2-score: -2.693070 | loss: 61503995904.000000
batch 254: [8128/14448] | r2-score: -3.422592 | loss: 40512729088.000000
batch 255: [8160/14448] | r2-score: -2.482856 | loss: 47083216896.000000
batch 256: [8192/14448] | r2-score: -4.282856 | loss: 54878896128.000000
batch 257: [8224/14448] | r2-score: -3.809792 | loss: 72127569920.000000
batch 258: [8256/14448] | r2-score: -3.983166 | loss: 59337818112.000000
batch 259: [8288/14448] | r2-score: -4.690547 | loss: 47141265408.000000
batch 260: [8320/14448] | r2-score: -3.403296 | loss: 53052121088.000000
batch 261: [8352/14448] | r2-score: -3.568373 | loss: 66100502528.000000
batch 262: [8384/14448] | r2-score: -4.651234 | loss: 52573659136.000000
batch 263: [8416/14448] | r2-score: -2.957242 | loss: 41605455872.000000
batch 264: [8448/14448] | r2-score: -3.443022 | los

batch 381: [12192/14448] | r2-score: -3.394852 | loss: 49255936000.000000
batch 382: [12224/14448] | r2-score: -3.300272 | loss: 85746655232.000000
batch 383: [12256/14448] | r2-score: -3.570450 | loss: 44668928000.000000
batch 384: [12288/14448] | r2-score: -3.453694 | loss: 77608706048.000000
batch 385: [12320/14448] | r2-score: -5.359710 | loss: 56365166592.000000
batch 386: [12352/14448] | r2-score: -3.006083 | loss: 76639485952.000000
batch 387: [12384/14448] | r2-score: -2.126977 | loss: 68543987712.000000
batch 388: [12416/14448] | r2-score: -2.907747 | loss: 66930733056.000000
batch 389: [12448/14448] | r2-score: -2.890823 | loss: 68520370176.000000
batch 390: [12480/14448] | r2-score: -3.296122 | loss: 58432897024.000000
batch 391: [12512/14448] | r2-score: -3.031600 | loss: 67202793472.000000
batch 392: [12544/14448] | r2-score: -3.109540 | loss: 57616412672.000000
batch 393: [12576/14448] | r2-score: -2.594259 | loss: 67677356032.000000
batch 394: [12608/14448] | r2-score: -

batch 93: [2976/6192] | r2-score: -2.796594 | loss: 56486768640.000000
batch 94: [3008/6192] | r2-score: -4.152338 | loss: 56743825408.000000
batch 95: [3040/6192] | r2-score: -4.618212 | loss: 55295135744.000000
batch 96: [3072/6192] | r2-score: -3.307025 | loss: 48638509056.000000
batch 97: [3104/6192] | r2-score: -2.539963 | loss: 52396253184.000000
batch 98: [3136/6192] | r2-score: -3.886841 | loss: 37972545536.000000
batch 99: [3168/6192] | r2-score: -3.296240 | loss: 78477033472.000000
batch 100: [3200/6192] | r2-score: -3.910409 | loss: 56994668544.000000
batch 101: [3232/6192] | r2-score: -6.220920 | loss: 38557065216.000000
batch 102: [3264/6192] | r2-score: -3.972777 | loss: 72061337600.000000
batch 103: [3296/6192] | r2-score: -2.892732 | loss: 49871388672.000000
batch 104: [3328/6192] | r2-score: -3.439802 | loss: 44475301888.000000
batch 105: [3360/6192] | r2-score: -3.735047 | loss: 71495155712.000000
batch 106: [3392/6192] | r2-score: -3.725651 | loss: 45127856128.000000

batch 91: [2912/14448] | r2-score: -4.158280 | loss: 42359644160.000000
batch 92: [2944/14448] | r2-score: -3.676334 | loss: 61644173312.000000
batch 93: [2976/14448] | r2-score: -3.636378 | loss: 49376428032.000000
batch 94: [3008/14448] | r2-score: -5.835763 | loss: 37553848320.000000
batch 95: [3040/14448] | r2-score: -2.402939 | loss: 68973240320.000000
batch 96: [3072/14448] | r2-score: -3.123542 | loss: 39038390272.000000
batch 97: [3104/14448] | r2-score: -3.670782 | loss: 35982442496.000000
batch 98: [3136/14448] | r2-score: -3.648358 | loss: 45316657152.000000
batch 99: [3168/14448] | r2-score: -2.193394 | loss: 70684901376.000000
batch 100: [3200/14448] | r2-score: -3.548864 | loss: 62534574080.000000
batch 101: [3232/14448] | r2-score: -3.582792 | loss: 55717756928.000000
batch 102: [3264/14448] | r2-score: -3.213405 | loss: 56163274752.000000
batch 103: [3296/14448] | r2-score: -2.487469 | loss: 59570118656.000000
batch 104: [3328/14448] | r2-score: -2.781914 | loss: 643686

batch 204: [6528/14448] | r2-score: -4.079494 | loss: 69615419392.000000
batch 205: [6560/14448] | r2-score: -4.295311 | loss: 42407211008.000000
batch 206: [6592/14448] | r2-score: -5.143575 | loss: 41950765056.000000
batch 207: [6624/14448] | r2-score: -2.410789 | loss: 52592652288.000000
batch 208: [6656/14448] | r2-score: -4.402518 | loss: 65966272512.000000
batch 209: [6688/14448] | r2-score: -3.166756 | loss: 37566222336.000000
batch 210: [6720/14448] | r2-score: -2.790542 | loss: 70525779968.000000
batch 211: [6752/14448] | r2-score: -3.660508 | loss: 63532904448.000000
batch 212: [6784/14448] | r2-score: -3.160166 | loss: 63771942912.000000
batch 213: [6816/14448] | r2-score: -5.848222 | loss: 26042013696.000000
batch 214: [6848/14448] | r2-score: -3.117037 | loss: 60897071104.000000
batch 215: [6880/14448] | r2-score: -4.159490 | loss: 45127376896.000000
batch 216: [6912/14448] | r2-score: -3.851748 | loss: 49173708800.000000
batch 217: [6944/14448] | r2-score: -3.961482 | los

batch 343: [10976/14448] | r2-score: -2.459085 | loss: 55873241088.000000
batch 344: [11008/14448] | r2-score: -3.458063 | loss: 61756428288.000000
batch 345: [11040/14448] | r2-score: -2.522465 | loss: 57909739520.000000
batch 346: [11072/14448] | r2-score: -2.271732 | loss: 44106043392.000000
batch 347: [11104/14448] | r2-score: -3.519719 | loss: 68647895040.000000
batch 348: [11136/14448] | r2-score: -2.676483 | loss: 62734135296.000000
batch 349: [11168/14448] | r2-score: -2.444797 | loss: 58498412544.000000
batch 350: [11200/14448] | r2-score: -3.344100 | loss: 55107555328.000000
batch 351: [11232/14448] | r2-score: -3.426898 | loss: 66927345664.000000
batch 352: [11264/14448] | r2-score: -3.926064 | loss: 58081673216.000000
batch 353: [11296/14448] | r2-score: -3.189581 | loss: 58659733504.000000
batch 354: [11328/14448] | r2-score: -3.674016 | loss: 59559616512.000000
batch 355: [11360/14448] | r2-score: -3.043524 | loss: 51558862848.000000
batch 356: [11392/14448] | r2-score: -

batch 69: [2208/6192] | r2-score: -4.529369 | loss: 52516503552.000000
batch 70: [2240/6192] | r2-score: -3.820933 | loss: 43495112704.000000
batch 71: [2272/6192] | r2-score: -2.819438 | loss: 48709664768.000000
batch 72: [2304/6192] | r2-score: -3.667513 | loss: 64925523968.000000
batch 73: [2336/6192] | r2-score: -2.453595 | loss: 55285186560.000000
batch 74: [2368/6192] | r2-score: -3.201915 | loss: 51916857344.000000
batch 75: [2400/6192] | r2-score: -3.111160 | loss: 40069357568.000000
batch 76: [2432/6192] | r2-score: -3.825997 | loss: 55163990016.000000
batch 77: [2464/6192] | r2-score: -3.751889 | loss: 56730370048.000000
batch 78: [2496/6192] | r2-score: -2.144534 | loss: 61109719040.000000
batch 79: [2528/6192] | r2-score: -3.857051 | loss: 64894205952.000000
batch 80: [2560/6192] | r2-score: -2.289415 | loss: 54254899200.000000
batch 81: [2592/6192] | r2-score: -4.480968 | loss: 53486804992.000000
batch 82: [2624/6192] | r2-score: -6.112208 | loss: 37923844096.000000
batch 

batch 81: [2592/14448] | r2-score: -2.863077 | loss: 57766371328.000000
batch 82: [2624/14448] | r2-score: -4.828343 | loss: 55172837376.000000
batch 83: [2656/14448] | r2-score: -4.256867 | loss: 46161973248.000000
batch 84: [2688/14448] | r2-score: -3.037522 | loss: 40530206720.000000
batch 85: [2720/14448] | r2-score: -3.223147 | loss: 58149322752.000000
batch 86: [2752/14448] | r2-score: -3.280288 | loss: 48522727424.000000
batch 87: [2784/14448] | r2-score: -3.321436 | loss: 63154798592.000000
batch 88: [2816/14448] | r2-score: -2.477098 | loss: 55788998656.000000
batch 89: [2848/14448] | r2-score: -2.892776 | loss: 68979900416.000000
batch 90: [2880/14448] | r2-score: -2.853538 | loss: 59324456960.000000
batch 91: [2912/14448] | r2-score: -4.158280 | loss: 42359644160.000000
batch 92: [2944/14448] | r2-score: -3.676334 | loss: 61644173312.000000
batch 93: [2976/14448] | r2-score: -3.636378 | loss: 49376428032.000000
batch 94: [3008/14448] | r2-score: -5.835763 | loss: 37553848320

batch 223: [7136/14448] | r2-score: -4.756871 | loss: 41645576192.000000
batch 224: [7168/14448] | r2-score: -5.329603 | loss: 50987319296.000000
batch 225: [7200/14448] | r2-score: -3.297402 | loss: 60870590464.000000
batch 226: [7232/14448] | r2-score: -3.536936 | loss: 54791434240.000000
batch 227: [7264/14448] | r2-score: -3.967424 | loss: 52508635136.000000
batch 228: [7296/14448] | r2-score: -6.497169 | loss: 42782023680.000000
batch 229: [7328/14448] | r2-score: -4.049312 | loss: 70164242432.000000
batch 230: [7360/14448] | r2-score: -3.515106 | loss: 50490769408.000000
batch 231: [7392/14448] | r2-score: -2.653642 | loss: 49933254656.000000
batch 232: [7424/14448] | r2-score: -3.463491 | loss: 42491682816.000000
batch 233: [7456/14448] | r2-score: -2.707104 | loss: 65934966784.000000
batch 234: [7488/14448] | r2-score: -6.502561 | loss: 36581036032.000000
batch 235: [7520/14448] | r2-score: -4.725698 | loss: 51154239488.000000
batch 236: [7552/14448] | r2-score: -3.956703 | los

batch 366: [11712/14448] | r2-score: -2.142332 | loss: 63739670528.000000
batch 367: [11744/14448] | r2-score: -2.473579 | loss: 60059246592.000000
batch 368: [11776/14448] | r2-score: -5.044917 | loss: 56186085376.000000
batch 369: [11808/14448] | r2-score: -4.687952 | loss: 68911267840.000000
batch 370: [11840/14448] | r2-score: -2.925735 | loss: 66110590976.000000
batch 371: [11872/14448] | r2-score: -3.439593 | loss: 72487870464.000000
batch 372: [11904/14448] | r2-score: -2.657843 | loss: 78351884288.000000
batch 373: [11936/14448] | r2-score: -5.033010 | loss: 54876037120.000000
batch 374: [11968/14448] | r2-score: -3.177626 | loss: 55093903360.000000
batch 375: [12000/14448] | r2-score: -2.644005 | loss: 71127769088.000000
batch 376: [12032/14448] | r2-score: -4.456605 | loss: 51121360896.000000
batch 377: [12064/14448] | r2-score: -3.953001 | loss: 53651472384.000000
batch 378: [12096/14448] | r2-score: -2.795601 | loss: 63701360640.000000
batch 379: [12128/14448] | r2-score: -

batch 101: [3232/6192] | r2-score: -6.220920 | loss: 38557065216.000000
batch 102: [3264/6192] | r2-score: -3.972777 | loss: 72061337600.000000
batch 103: [3296/6192] | r2-score: -2.892732 | loss: 49871388672.000000
batch 104: [3328/6192] | r2-score: -3.439802 | loss: 44475301888.000000
batch 105: [3360/6192] | r2-score: -3.735047 | loss: 71495155712.000000
batch 106: [3392/6192] | r2-score: -3.725651 | loss: 45127856128.000000
batch 107: [3424/6192] | r2-score: -2.542410 | loss: 57608912896.000000
batch 108: [3456/6192] | r2-score: -4.615676 | loss: 38920314880.000000
batch 109: [3488/6192] | r2-score: -2.965518 | loss: 54966185984.000000
batch 110: [3520/6192] | r2-score: -4.070800 | loss: 38725316608.000000
batch 111: [3552/6192] | r2-score: -2.947289 | loss: 48817418240.000000
batch 112: [3584/6192] | r2-score: -3.960714 | loss: 56964046848.000000
batch 113: [3616/6192] | r2-score: -3.202795 | loss: 49737781248.000000
batch 114: [3648/6192] | r2-score: -2.882857 | loss: 57810173952

batch 81: [2592/14448] | r2-score: -2.863077 | loss: 57766371328.000000
batch 82: [2624/14448] | r2-score: -4.828343 | loss: 55172837376.000000
batch 83: [2656/14448] | r2-score: -4.256867 | loss: 46161973248.000000
batch 84: [2688/14448] | r2-score: -3.037522 | loss: 40530206720.000000
batch 85: [2720/14448] | r2-score: -3.223147 | loss: 58149322752.000000
batch 86: [2752/14448] | r2-score: -3.280288 | loss: 48522727424.000000
batch 87: [2784/14448] | r2-score: -3.321436 | loss: 63154798592.000000
batch 88: [2816/14448] | r2-score: -2.477098 | loss: 55788998656.000000
batch 89: [2848/14448] | r2-score: -2.892776 | loss: 68979900416.000000
batch 90: [2880/14448] | r2-score: -2.853538 | loss: 59324456960.000000
batch 91: [2912/14448] | r2-score: -4.158280 | loss: 42359644160.000000
batch 92: [2944/14448] | r2-score: -3.676334 | loss: 61644173312.000000
batch 93: [2976/14448] | r2-score: -3.636378 | loss: 49376428032.000000
batch 94: [3008/14448] | r2-score: -5.835763 | loss: 37553848320

batch 221: [7072/14448] | r2-score: -3.533837 | loss: 55643783168.000000
batch 222: [7104/14448] | r2-score: -3.308444 | loss: 81613225984.000000
batch 223: [7136/14448] | r2-score: -4.756871 | loss: 41645576192.000000
batch 224: [7168/14448] | r2-score: -5.329603 | loss: 50987319296.000000
batch 225: [7200/14448] | r2-score: -3.297402 | loss: 60870590464.000000
batch 226: [7232/14448] | r2-score: -3.536936 | loss: 54791434240.000000
batch 227: [7264/14448] | r2-score: -3.967424 | loss: 52508635136.000000
batch 228: [7296/14448] | r2-score: -6.497169 | loss: 42782023680.000000
batch 229: [7328/14448] | r2-score: -4.049312 | loss: 70164242432.000000
batch 230: [7360/14448] | r2-score: -3.515106 | loss: 50490769408.000000
batch 231: [7392/14448] | r2-score: -2.653642 | loss: 49933254656.000000
batch 232: [7424/14448] | r2-score: -3.463491 | loss: 42491682816.000000
batch 233: [7456/14448] | r2-score: -2.707104 | loss: 65934966784.000000
batch 234: [7488/14448] | r2-score: -6.502561 | los

batch 372: [11904/14448] | r2-score: -2.657843 | loss: 78351884288.000000
batch 373: [11936/14448] | r2-score: -5.033010 | loss: 54876037120.000000
batch 374: [11968/14448] | r2-score: -3.177626 | loss: 55093903360.000000
batch 375: [12000/14448] | r2-score: -2.644005 | loss: 71127769088.000000
batch 376: [12032/14448] | r2-score: -4.456605 | loss: 51121360896.000000
batch 377: [12064/14448] | r2-score: -3.953001 | loss: 53651472384.000000
batch 378: [12096/14448] | r2-score: -2.795601 | loss: 63701360640.000000
batch 379: [12128/14448] | r2-score: -2.404190 | loss: 65833181184.000000
batch 380: [12160/14448] | r2-score: -4.323248 | loss: 68194598912.000000
batch 381: [12192/14448] | r2-score: -3.394852 | loss: 49255936000.000000
batch 382: [12224/14448] | r2-score: -3.300272 | loss: 85746655232.000000
batch 383: [12256/14448] | r2-score: -3.570450 | loss: 44668928000.000000
batch 384: [12288/14448] | r2-score: -3.453694 | loss: 77608706048.000000
batch 385: [12320/14448] | r2-score: -

batch 96: [3072/6192] | r2-score: -3.307025 | loss: 48638509056.000000
batch 97: [3104/6192] | r2-score: -2.539963 | loss: 52396253184.000000
batch 98: [3136/6192] | r2-score: -3.886841 | loss: 37972545536.000000
batch 99: [3168/6192] | r2-score: -3.296240 | loss: 78477033472.000000
batch 100: [3200/6192] | r2-score: -3.910409 | loss: 56994668544.000000
batch 101: [3232/6192] | r2-score: -6.220920 | loss: 38557065216.000000
batch 102: [3264/6192] | r2-score: -3.972777 | loss: 72061337600.000000
batch 103: [3296/6192] | r2-score: -2.892732 | loss: 49871388672.000000
batch 104: [3328/6192] | r2-score: -3.439802 | loss: 44475301888.000000
batch 105: [3360/6192] | r2-score: -3.735047 | loss: 71495155712.000000
batch 106: [3392/6192] | r2-score: -3.725651 | loss: 45127856128.000000
batch 107: [3424/6192] | r2-score: -2.542410 | loss: 57608912896.000000
batch 108: [3456/6192] | r2-score: -4.615676 | loss: 38920314880.000000
batch 109: [3488/6192] | r2-score: -2.965518 | loss: 54966185984.000