# Import Modules

In [1]:
import pandas as pd
import torch

device = "cuda" if torch.cuda.is_available else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


# Import dataset

In [2]:
california_houses_df = pd.read_csv("../datasets/california_houses.csv")
california_houses_df.head()

Unnamed: 0,Median_House_Value,Median_Income,Median_Age,Tot_Rooms,Tot_Bedrooms,Population,Households,Latitude,Longitude,Distance_to_coast,Distance_to_LA,Distance_to_SanDiego,Distance_to_SanJose,Distance_to_SanFrancisco
0,452600.0,8.3252,41,880,129,322,126,37.88,-122.23,9263.040773,556529.158342,735501.806984,67432.517001,21250.213767
1,358500.0,8.3014,21,7099,1106,2401,1138,37.86,-122.22,10225.733072,554279.850069,733236.88436,65049.908574,20880.6004
2,352100.0,7.2574,52,1467,190,496,177,37.85,-122.24,8259.085109,554610.717069,733525.682937,64867.289833,18811.48745
3,341300.0,5.6431,52,1274,235,558,219,37.85,-122.25,7768.086571,555194.266086,734095.290744,65287.138412,18031.047568
4,342200.0,3.8462,52,1627,280,565,259,37.85,-122.25,7768.086571,555194.266086,734095.290744,65287.138412,18031.047568


In [3]:
california_houses_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 20640 entries, 0 to 20639
Data columns (total 14 columns):
 #   Column                    Non-Null Count  Dtype  
---  ------                    --------------  -----  
 0   Median_House_Value        20640 non-null  float64
 1   Median_Income             20640 non-null  float64
 2   Median_Age                20640 non-null  int64  
 3   Tot_Rooms                 20640 non-null  int64  
 4   Tot_Bedrooms              20640 non-null  int64  
 5   Population                20640 non-null  int64  
 6   Households                20640 non-null  int64  
 7   Latitude                  20640 non-null  float64
 8   Longitude                 20640 non-null  float64
 9   Distance_to_coast         20640 non-null  float64
 10  Distance_to_LA            20640 non-null  float64
 11  Distance_to_SanDiego      20640 non-null  float64
 12  Distance_to_SanJose       20640 non-null  float64
 13  Distance_to_SanFrancisco  20640 non-null  float64
dtypes: flo

In [4]:
california_houses_df.isna().sum()

Median_House_Value          0
Median_Income               0
Median_Age                  0
Tot_Rooms                   0
Tot_Bedrooms                0
Population                  0
Households                  0
Latitude                    0
Longitude                   0
Distance_to_coast           0
Distance_to_LA              0
Distance_to_SanDiego        0
Distance_to_SanJose         0
Distance_to_SanFrancisco    0
dtype: int64

In [5]:
california_houses_df.describe()

Unnamed: 0,Median_House_Value,Median_Income,Median_Age,Tot_Rooms,Tot_Bedrooms,Population,Households,Latitude,Longitude,Distance_to_coast,Distance_to_LA,Distance_to_SanDiego,Distance_to_SanJose,Distance_to_SanFrancisco
count,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0,20640.0
mean,206855.816909,3.870671,28.639486,2635.763081,537.898014,1425.476744,499.53968,35.631861,-119.569704,40509.264883,269422.0,398164.9,349187.551219,386688.422291
std,115395.615874,1.899822,12.585558,2181.615252,421.247906,1132.462122,382.329753,2.135952,2.003532,49140.03916,247732.4,289400.6,217149.875026,250122.192316
min,14999.0,0.4999,1.0,2.0,1.0,3.0,1.0,32.54,-124.35,120.676447,420.5891,484.918,569.448118,456.141313
25%,119600.0,2.5634,18.0,1447.75,295.0,787.0,280.0,33.93,-121.8,9079.756762,32111.25,159426.4,113119.928682,117395.477505
50%,179700.0,3.5348,29.0,2127.0,435.0,1166.0,409.0,34.26,-118.49,20522.019101,173667.5,214739.8,459758.877,526546.661701
75%,264725.0,4.74325,37.0,3148.0,647.0,1725.0,605.0,37.71,-118.01,49830.414479,527156.2,705795.4,516946.490963,584552.007907
max,500001.0,15.0001,52.0,39320.0,6445.0,35682.0,6082.0,41.95,-114.31,333804.686371,1018260.0,1196919.0,836762.67821,903627.663298


# Split feature and target

In [6]:
y = california_houses_df.pop("Median_House_Value").values.reshape(-1, 1)
X = california_houses_df.values

print(f"X shape: {X.shape}")
print(f"y shape: {y.shape}")

X shape: (20640, 13)
y shape: (20640, 1)


# Split train and test dataset

In [7]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, shuffle=True, test_size=0.3)
print(f"X_train shape: {X_train.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (14448, 13)
y_train shape: (14448, 1)
X_test shape: (6192, 13)
y_test shape: (6192, 1)


In [8]:
X_train = torch.FloatTensor(X_train)
y_train = torch.FloatTensor(y_train)

X_test = torch.FloatTensor(X_test)
y_test = torch.FloatTensor(y_test)

# Build Architecture

## TensorDataset and DataLoader

In [9]:
from torch.utils.data import TensorDataset, DataLoader

# TensorDataset
tensor_dataset_train = TensorDataset(X_train, y_train)
tensor_dataset_test = TensorDataset(X_test, y_test)

# DataLoader
dataloader_train = DataLoader(tensor_dataset_train, batch_size=32)
dataloader_test = DataLoader(tensor_dataset_test, batch_size=32)

## Build Model

In [10]:
from torch import nn, optim

model = nn.Sequential(          
            # layer 2 (16 neurons) to layer 3 (8 neurons) with ReLU
            nn.Linear(13, 8),
            nn.ReLU6(),
    
            # layer 4 (8 neurons) to layer 5 (4 neurons) with ReLU
            nn.Linear(8, 4),
            nn.ReLU6(),
    
            # layer 6 (4 neurons) to layer 6 (1 neurons)
            nn.Linear(4, 1),
            nn.ReLU6()
        )

model

Sequential(
  (0): Linear(in_features=13, out_features=8, bias=True)
  (1): ReLU6()
  (2): Linear(in_features=8, out_features=4, bias=True)
  (3): ReLU6()
  (4): Linear(in_features=4, out_features=1, bias=True)
  (5): ReLU6()
)

In [11]:
model.state_dict()

OrderedDict([('0.weight',
              tensor([[-0.0648,  0.1598,  0.0996,  0.0361, -0.0384,  0.2731,  0.2266,  0.0104,
                        0.2059,  0.0164, -0.1110,  0.1497, -0.1123],
                      [ 0.2502,  0.1730, -0.0255, -0.1914, -0.2052,  0.0649,  0.0854,  0.0138,
                        0.2365, -0.0535,  0.0802, -0.0408,  0.1674],
                      [-0.1456,  0.1104,  0.2604,  0.0824, -0.1296, -0.2348, -0.1655, -0.0941,
                        0.0556,  0.2487, -0.2170,  0.1475, -0.1497],
                      [ 0.0407,  0.0111, -0.0278,  0.0448,  0.0503,  0.0387, -0.0899, -0.0344,
                       -0.1585, -0.0915,  0.2308, -0.1656, -0.2288],
                      [-0.1609,  0.1879,  0.2385,  0.0220,  0.2248,  0.1938, -0.0903,  0.0356,
                        0.1139,  0.1380,  0.2133,  0.2471, -0.1938],
                      [-0.1507, -0.0386, -0.0830,  0.0814,  0.2519, -0.2158, -0.0531, -0.1548,
                       -0.0048, -0.1326, -0.2294,  0.2540, 

## Loss Function

In [12]:
mse_loss = nn.MSELoss()
mse_loss

MSELoss()

## Optimizer

In [13]:
adamw_optimizer = optim.AdamW(model.parameters(), lr=0.001,)
adamw_optimizer

AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    maximize: False
    weight_decay: 0.01
)

## Loop function

In [14]:
from torchmetrics.functional import r2_score

def train_loop(dataloader, model, loss_fn, optimizer_fn):
    size = len(dataloader.dataset)
    loss_batches = []
    r2_score_batches = []
    
    for batch, (X, y) in enumerate(dataloader_train):
        # Forwardpropagation
        pred = model(X)
        loss = loss_fn(pred, y)
        r_2_score = r2_score(pred, y)
        
        # Backpropagation
        optimizer_fn.zero_grad()
        loss.backward()
        optimizer_fn.step()
        
        loss = loss.item()
        loss_batches.append(loss)
        
        r_2_score = r_2_score.item()
        r2_score_batches.append(r_2_score)
        
        current_batch = (batch + 1) * len(X)
        
        print(f"batch {batch + 1}: [{current_batch:>3d}/{size:>3d}] | r2-score: {r_2_score:>8f} | loss: {loss:>7f}")
    print(f"Mean R2 score  : {sum(r2_score_batches) / len(r2_score_batches)}")
    print(f"Mean RMSE loss : {(sum(loss_batches) / len(loss_batches))**0.5}")
    
def test_loop(dataloader, model, loss_fn, optimizer_fn):
    size = len(dataloader.dataset)
    loss_batches = []
    r2_score_batches = []
    
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader_test):
            pred = model(X)
            loss = loss_fn(pred, y)
            r_2_score = r2_score(pred, y)

            loss = loss.item()
            loss_batches.append(loss)

            r_2_score = r_2_score.item()
            r2_score_batches.append(r_2_score)

            current_batch = (batch + 1) * len(X)

            print(f"batch {batch + 1}: [{current_batch:>3d}/{size:>3d}] | r2-score: {r_2_score:>8f} | loss: {loss:>7f}")
    print(f"Mean R2 score  : {sum(r2_score_batches) / len(r2_score_batches)}")
    print(f"Mean RMSE loss : {(sum(loss_batches) / len(loss_batches))**0.5}")

In [15]:
epochs = 5

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}")
    
    print(f"Train")
    train_loop(dataloader_train, model, mse_loss, adamw_optimizer)
    
    print(f"\nTest")
    test_loop(dataloader_test, model, mse_loss, adamw_optimizer)
    print("-" * 80, end="\n\n")
print("Done!")

Epoch 1
Train
batch 1: [ 32/14448] | r2-score: -2.256716 | loss: 55027707904.000000
batch 2: [ 64/14448] | r2-score: -4.097435 | loss: 44023836672.000000
batch 3: [ 96/14448] | r2-score: -3.492011 | loss: 44534435840.000000
batch 4: [128/14448] | r2-score: -3.338817 | loss: 44882038784.000000
batch 5: [160/14448] | r2-score: -3.299134 | loss: 65397383168.000000
batch 6: [192/14448] | r2-score: -2.944703 | loss: 63606013952.000000
batch 7: [224/14448] | r2-score: -3.634229 | loss: 50239266816.000000
batch 8: [256/14448] | r2-score: -3.868825 | loss: 73876381696.000000
batch 9: [288/14448] | r2-score: -3.300528 | loss: 65567989760.000000
batch 10: [320/14448] | r2-score: -2.428397 | loss: 46898954240.000000
batch 11: [352/14448] | r2-score: -2.928611 | loss: 61691748352.000000
batch 12: [384/14448] | r2-score: -4.161337 | loss: 83421650944.000000
batch 13: [416/14448] | r2-score: -3.811436 | loss: 59475873792.000000
batch 14: [448/14448] | r2-score: -2.927846 | loss: 52143046656.000000
b

batch 117: [3744/14448] | r2-score: -3.180305 | loss: 60321730560.000000
batch 118: [3776/14448] | r2-score: -2.810019 | loss: 47157878784.000000
batch 119: [3808/14448] | r2-score: -3.547956 | loss: 50248298496.000000
batch 120: [3840/14448] | r2-score: -3.010956 | loss: 54435852288.000000
batch 121: [3872/14448] | r2-score: -2.938738 | loss: 46999334912.000000
batch 122: [3904/14448] | r2-score: -2.908273 | loss: 52094656512.000000
batch 123: [3936/14448] | r2-score: -3.132016 | loss: 66485780480.000000
batch 124: [3968/14448] | r2-score: -2.562836 | loss: 53244928000.000000
batch 125: [4000/14448] | r2-score: -4.035034 | loss: 48143986688.000000
batch 126: [4032/14448] | r2-score: -1.942349 | loss: 35627069440.000000
batch 127: [4064/14448] | r2-score: -3.934032 | loss: 62813618176.000000
batch 128: [4096/14448] | r2-score: -3.917577 | loss: 53798871040.000000
batch 129: [4128/14448] | r2-score: -4.934822 | loss: 35772497920.000000
batch 130: [4160/14448] | r2-score: -2.989801 | los

batch 243: [7776/14448] | r2-score: -2.842750 | loss: 52819542016.000000
batch 244: [7808/14448] | r2-score: -2.964167 | loss: 51296100352.000000
batch 245: [7840/14448] | r2-score: -4.238004 | loss: 75235196928.000000
batch 246: [7872/14448] | r2-score: -3.668447 | loss: 57411772416.000000
batch 247: [7904/14448] | r2-score: -3.351680 | loss: 32643547136.000000
batch 248: [7936/14448] | r2-score: -4.694502 | loss: 73707577344.000000
batch 249: [7968/14448] | r2-score: -4.275792 | loss: 57687871488.000000
batch 250: [8000/14448] | r2-score: -3.197907 | loss: 53302280192.000000
batch 251: [8032/14448] | r2-score: -3.784241 | loss: 47864078336.000000
batch 252: [8064/14448] | r2-score: -2.807253 | loss: 52708880384.000000
batch 253: [8096/14448] | r2-score: -3.466101 | loss: 64944992256.000000
batch 254: [8128/14448] | r2-score: -3.739981 | loss: 76517900288.000000
batch 255: [8160/14448] | r2-score: -3.286695 | loss: 53634875392.000000
batch 256: [8192/14448] | r2-score: -3.141279 | los

batch 376: [12032/14448] | r2-score: -3.878399 | loss: 67381485568.000000
batch 377: [12064/14448] | r2-score: -3.867143 | loss: 56422834176.000000
batch 378: [12096/14448] | r2-score: -4.418599 | loss: 65075175424.000000
batch 379: [12128/14448] | r2-score: -3.246072 | loss: 49965957120.000000
batch 380: [12160/14448] | r2-score: -2.963912 | loss: 68437352448.000000
batch 381: [12192/14448] | r2-score: -3.677355 | loss: 41102135296.000000
batch 382: [12224/14448] | r2-score: -3.146748 | loss: 68818100224.000000
batch 383: [12256/14448] | r2-score: -3.718592 | loss: 58927382528.000000
batch 384: [12288/14448] | r2-score: -3.822782 | loss: 58998927360.000000
batch 385: [12320/14448] | r2-score: -2.127235 | loss: 53676335104.000000
batch 386: [12352/14448] | r2-score: -3.191897 | loss: 71049191424.000000
batch 387: [12384/14448] | r2-score: -3.030830 | loss: 79388762112.000000
batch 388: [12416/14448] | r2-score: -4.428813 | loss: 54505840640.000000
batch 389: [12448/14448] | r2-score: -

batch 130: [4160/6192] | r2-score: -4.490444 | loss: 39380574208.000000
batch 131: [4192/6192] | r2-score: -5.317116 | loss: 50903212032.000000
batch 132: [4224/6192] | r2-score: -4.073326 | loss: 48246284288.000000
batch 133: [4256/6192] | r2-score: -3.504773 | loss: 59625259008.000000
batch 134: [4288/6192] | r2-score: -3.133541 | loss: 61395910656.000000
batch 135: [4320/6192] | r2-score: -2.949124 | loss: 53069209600.000000
batch 136: [4352/6192] | r2-score: -3.594536 | loss: 63974203392.000000
batch 137: [4384/6192] | r2-score: -3.649611 | loss: 78070358016.000000
batch 138: [4416/6192] | r2-score: -2.998203 | loss: 70117302272.000000
batch 139: [4448/6192] | r2-score: -5.552470 | loss: 58631061504.000000
batch 140: [4480/6192] | r2-score: -3.664926 | loss: 41762136064.000000
batch 141: [4512/6192] | r2-score: -3.243754 | loss: 70553993216.000000
batch 142: [4544/6192] | r2-score: -2.517846 | loss: 41551945728.000000
batch 143: [4576/6192] | r2-score: -3.651667 | loss: 55543300096

batch 109: [3488/14448] | r2-score: -2.668370 | loss: 75137949696.000000
batch 110: [3520/14448] | r2-score: -2.738893 | loss: 64639770624.000000
batch 111: [3552/14448] | r2-score: -3.310522 | loss: 50731040768.000000
batch 112: [3584/14448] | r2-score: -2.644147 | loss: 72317493248.000000
batch 113: [3616/14448] | r2-score: -5.111373 | loss: 56193146880.000000
batch 114: [3648/14448] | r2-score: -5.360078 | loss: 64765349888.000000
batch 115: [3680/14448] | r2-score: -3.915864 | loss: 42444263424.000000
batch 116: [3712/14448] | r2-score: -3.110011 | loss: 65741275136.000000
batch 117: [3744/14448] | r2-score: -3.180263 | loss: 60321124352.000000
batch 118: [3776/14448] | r2-score: -2.809977 | loss: 47157350400.000000
batch 119: [3808/14448] | r2-score: -3.547903 | loss: 50247712768.000000
batch 120: [3840/14448] | r2-score: -3.010913 | loss: 54435266560.000000
batch 121: [3872/14448] | r2-score: -2.938691 | loss: 46998773760.000000
batch 122: [3904/14448] | r2-score: -2.908233 | los

batch 245: [7840/14448] | r2-score: -4.238003 | loss: 75235180544.000000
batch 246: [7872/14448] | r2-score: -3.668443 | loss: 57411715072.000000
batch 247: [7904/14448] | r2-score: -3.351675 | loss: 32643510272.000000
batch 248: [7936/14448] | r2-score: -4.694501 | loss: 73707560960.000000
batch 249: [7968/14448] | r2-score: -4.275787 | loss: 57687818240.000000
batch 250: [8000/14448] | r2-score: -3.197903 | loss: 53302231040.000000
batch 251: [8032/14448] | r2-score: -3.784240 | loss: 47864066048.000000
batch 252: [8064/14448] | r2-score: -2.807252 | loss: 52708864000.000000
batch 253: [8096/14448] | r2-score: -3.466099 | loss: 64944975872.000000
batch 254: [8128/14448] | r2-score: -3.739980 | loss: 76517883904.000000
batch 255: [8160/14448] | r2-score: -3.286692 | loss: 53634838528.000000
batch 256: [8192/14448] | r2-score: -3.141277 | loss: 54501801984.000000
batch 257: [8224/14448] | r2-score: -2.894618 | loss: 36435120128.000000
batch 258: [8256/14448] | r2-score: -3.077367 | los

batch 364: [11648/14448] | r2-score: -3.020008 | loss: 69567766528.000000
batch 365: [11680/14448] | r2-score: -7.769320 | loss: 53041389568.000000
batch 366: [11712/14448] | r2-score: -3.115874 | loss: 56433655808.000000
batch 367: [11744/14448] | r2-score: -6.083557 | loss: 51037892608.000000
batch 368: [11776/14448] | r2-score: -3.196107 | loss: 76927492096.000000
batch 369: [11808/14448] | r2-score: -3.074346 | loss: 64183484416.000000
batch 370: [11840/14448] | r2-score: -5.170531 | loss: 57307242496.000000
batch 371: [11872/14448] | r2-score: -3.637842 | loss: 59084423168.000000
batch 372: [11904/14448] | r2-score: -3.477482 | loss: 40328265728.000000
batch 373: [11936/14448] | r2-score: -2.382679 | loss: 58845388800.000000
batch 374: [11968/14448] | r2-score: -2.847700 | loss: 58316333056.000000
batch 375: [12000/14448] | r2-score: -3.246603 | loss: 69055037440.000000
batch 376: [12032/14448] | r2-score: -3.878398 | loss: 67381469184.000000
batch 377: [12064/14448] | r2-score: -

batch 114: [3648/6192] | r2-score: -3.758795 | loss: 41791188992.000000
batch 115: [3680/6192] | r2-score: -3.041313 | loss: 70173704192.000000
batch 116: [3712/6192] | r2-score: -3.133354 | loss: 73286672384.000000
batch 117: [3744/6192] | r2-score: -3.307627 | loss: 57681096704.000000
batch 118: [3776/6192] | r2-score: -3.591442 | loss: 61240205312.000000
batch 119: [3808/6192] | r2-score: -3.594192 | loss: 49811668992.000000
batch 120: [3840/6192] | r2-score: -3.949601 | loss: 53673775104.000000
batch 121: [3872/6192] | r2-score: -3.511272 | loss: 58047512576.000000
batch 122: [3904/6192] | r2-score: -4.733282 | loss: 63895400448.000000
batch 123: [3936/6192] | r2-score: -3.036925 | loss: 60034297856.000000
batch 124: [3968/6192] | r2-score: -2.491173 | loss: 50259046400.000000
batch 125: [4000/6192] | r2-score: -3.468369 | loss: 49074900992.000000
batch 126: [4032/6192] | r2-score: -2.224557 | loss: 48309915648.000000
batch 127: [4064/6192] | r2-score: -4.416440 | loss: 66996609024

batch 90: [2880/14448] | r2-score: -3.941823 | loss: 69056700416.000000
batch 91: [2912/14448] | r2-score: -3.407712 | loss: 64147480576.000000
batch 92: [2944/14448] | r2-score: -4.180635 | loss: 72147722240.000000
batch 93: [2976/14448] | r2-score: -3.394987 | loss: 60580667392.000000
batch 94: [3008/14448] | r2-score: -2.855589 | loss: 67632058368.000000
batch 95: [3040/14448] | r2-score: -4.244689 | loss: 55190806528.000000
batch 96: [3072/14448] | r2-score: -4.792716 | loss: 64840908800.000000
batch 97: [3104/14448] | r2-score: -3.211012 | loss: 72689926144.000000
batch 98: [3136/14448] | r2-score: -2.815260 | loss: 45591031808.000000
batch 99: [3168/14448] | r2-score: -3.486435 | loss: 67881025536.000000
batch 100: [3200/14448] | r2-score: -3.610381 | loss: 50021748736.000000
batch 101: [3232/14448] | r2-score: -3.237915 | loss: 53399486464.000000
batch 102: [3264/14448] | r2-score: -3.755632 | loss: 69359976448.000000
batch 103: [3296/14448] | r2-score: -4.697978 | loss: 3982239

batch 217: [6944/14448] | r2-score: -3.031546 | loss: 57224699904.000000
batch 218: [6976/14448] | r2-score: -4.334396 | loss: 55233425408.000000
batch 219: [7008/14448] | r2-score: -3.806149 | loss: 45195735040.000000
batch 220: [7040/14448] | r2-score: -2.945431 | loss: 52789821440.000000
batch 221: [7072/14448] | r2-score: -3.220521 | loss: 64707366912.000000
batch 222: [7104/14448] | r2-score: -3.781035 | loss: 38702620672.000000
batch 223: [7136/14448] | r2-score: -2.955949 | loss: 48234987520.000000
batch 224: [7168/14448] | r2-score: -6.101195 | loss: 43982737408.000000
batch 225: [7200/14448] | r2-score: -4.129101 | loss: 45547847680.000000
batch 226: [7232/14448] | r2-score: -3.619225 | loss: 69581471744.000000
batch 227: [7264/14448] | r2-score: -3.092141 | loss: 77347758080.000000
batch 228: [7296/14448] | r2-score: -3.738043 | loss: 76005474304.000000
batch 229: [7328/14448] | r2-score: -2.944056 | loss: 67365732352.000000
batch 230: [7360/14448] | r2-score: -3.592477 | los

batch 332: [10624/14448] | r2-score: -2.703134 | loss: 55106482176.000000
batch 333: [10656/14448] | r2-score: -4.639363 | loss: 51587776512.000000
batch 334: [10688/14448] | r2-score: -3.253098 | loss: 82552381440.000000
batch 335: [10720/14448] | r2-score: -3.189187 | loss: 47226413056.000000
batch 336: [10752/14448] | r2-score: -2.676831 | loss: 65806888960.000000
batch 337: [10784/14448] | r2-score: -5.040491 | loss: 51244699648.000000
batch 338: [10816/14448] | r2-score: -4.449162 | loss: 59130503168.000000
batch 339: [10848/14448] | r2-score: -3.424603 | loss: 58866782208.000000
batch 340: [10880/14448] | r2-score: -3.464176 | loss: 48341233664.000000
batch 341: [10912/14448] | r2-score: -4.564861 | loss: 30163767296.000000
batch 342: [10944/14448] | r2-score: -3.164110 | loss: 75946606592.000000
batch 343: [10976/14448] | r2-score: -4.648753 | loss: 45110960128.000000
batch 344: [11008/14448] | r2-score: -4.097062 | loss: 54953779200.000000
batch 345: [11040/14448] | r2-score: -

Mean R2 score  : -3.515689186290302
Mean RMSE loss : 236567.59587135763

Test
batch 1: [ 32/6192] | r2-score: -4.762015 | loss: 54215553024.000000
batch 2: [ 64/6192] | r2-score: -3.276297 | loss: 64075255808.000000
batch 3: [ 96/6192] | r2-score: -3.381619 | loss: 63481274368.000000
batch 4: [128/6192] | r2-score: -2.948027 | loss: 67447848960.000000
batch 5: [160/6192] | r2-score: -4.246280 | loss: 48563658752.000000
batch 6: [192/6192] | r2-score: -2.872000 | loss: 56147705856.000000
batch 7: [224/6192] | r2-score: -3.670346 | loss: 47867060224.000000
batch 8: [256/6192] | r2-score: -5.287085 | loss: 62501199872.000000
batch 9: [288/6192] | r2-score: -2.961820 | loss: 71605559296.000000
batch 10: [320/6192] | r2-score: -4.093207 | loss: 57342828544.000000
batch 11: [352/6192] | r2-score: -3.276094 | loss: 67835121664.000000
batch 12: [384/6192] | r2-score: -2.737924 | loss: 66985750528.000000
batch 13: [416/6192] | r2-score: -3.926473 | loss: 41226899456.000000
batch 14: [448/6192] 

batch 46: [1472/14448] | r2-score: -2.584530 | loss: 54618591232.000000
batch 47: [1504/14448] | r2-score: -3.821926 | loss: 43088187392.000000
batch 48: [1536/14448] | r2-score: -2.685828 | loss: 70149283840.000000
batch 49: [1568/14448] | r2-score: -3.527829 | loss: 55987576832.000000
batch 50: [1600/14448] | r2-score: -4.060518 | loss: 56808140800.000000
batch 51: [1632/14448] | r2-score: -3.628175 | loss: 79306530816.000000
batch 52: [1664/14448] | r2-score: -5.265008 | loss: 32945635328.000000
batch 53: [1696/14448] | r2-score: -4.225728 | loss: 55275577344.000000
batch 54: [1728/14448] | r2-score: -3.996856 | loss: 38320205824.000000
batch 55: [1760/14448] | r2-score: -3.673606 | loss: 42717356032.000000
batch 56: [1792/14448] | r2-score: -3.081332 | loss: 64801619968.000000
batch 57: [1824/14448] | r2-score: -3.207560 | loss: 74058088448.000000
batch 58: [1856/14448] | r2-score: -2.319551 | loss: 74185605120.000000
batch 59: [1888/14448] | r2-score: -2.802140 | loss: 54504169472

batch 183: [5856/14448] | r2-score: -2.653716 | loss: 39485366272.000000
batch 184: [5888/14448] | r2-score: -2.844733 | loss: 51097534464.000000
batch 185: [5920/14448] | r2-score: -4.295349 | loss: 55436963840.000000
batch 186: [5952/14448] | r2-score: -4.171212 | loss: 44541464576.000000
batch 187: [5984/14448] | r2-score: -2.743322 | loss: 75053670400.000000
batch 188: [6016/14448] | r2-score: -3.490040 | loss: 40626663424.000000
batch 189: [6048/14448] | r2-score: -3.198194 | loss: 62471467008.000000
batch 190: [6080/14448] | r2-score: -3.275435 | loss: 77647052800.000000
batch 191: [6112/14448] | r2-score: -2.602901 | loss: 60035080192.000000
batch 192: [6144/14448] | r2-score: -2.816443 | loss: 54728839168.000000
batch 193: [6176/14448] | r2-score: -5.205332 | loss: 36428996608.000000
batch 194: [6208/14448] | r2-score: -2.640963 | loss: 58689290240.000000
batch 195: [6240/14448] | r2-score: -4.061607 | loss: 40336601088.000000
batch 196: [6272/14448] | r2-score: -3.938886 | los

batch 312: [9984/14448] | r2-score: -4.387369 | loss: 36566413312.000000
batch 313: [10016/14448] | r2-score: -2.252676 | loss: 57771311104.000000
batch 314: [10048/14448] | r2-score: -4.052214 | loss: 40757018624.000000
batch 315: [10080/14448] | r2-score: -2.395498 | loss: 63781199872.000000
batch 316: [10112/14448] | r2-score: -5.154809 | loss: 44006899712.000000
batch 317: [10144/14448] | r2-score: -2.794616 | loss: 51466313728.000000
batch 318: [10176/14448] | r2-score: -2.434743 | loss: 64588607488.000000
batch 319: [10208/14448] | r2-score: -4.354911 | loss: 72199954432.000000
batch 320: [10240/14448] | r2-score: -2.940106 | loss: 42544152576.000000
batch 321: [10272/14448] | r2-score: -3.674726 | loss: 61075443712.000000
batch 322: [10304/14448] | r2-score: -3.234478 | loss: 59098161152.000000
batch 323: [10336/14448] | r2-score: -3.551503 | loss: 75983536128.000000
batch 324: [10368/14448] | r2-score: -5.949562 | loss: 26452987904.000000
batch 325: [10400/14448] | r2-score: -2

batch 446: [14272/14448] | r2-score: -3.226383 | loss: 37715693568.000000
batch 447: [14304/14448] | r2-score: -3.659529 | loss: 36858200064.000000
batch 448: [14336/14448] | r2-score: -3.081904 | loss: 61748408320.000000
batch 449: [14368/14448] | r2-score: -5.203619 | loss: 37625868288.000000
batch 450: [14400/14448] | r2-score: -2.852618 | loss: 69455077376.000000
batch 451: [14432/14448] | r2-score: -3.653930 | loss: 51205492736.000000
batch 452: [7232/14448] | r2-score: -3.163484 | loss: 52835475456.000000
Mean R2 score  : -3.515689186290302
Mean RMSE loss : 236567.59587135763

Test
batch 1: [ 32/6192] | r2-score: -4.762015 | loss: 54215553024.000000
batch 2: [ 64/6192] | r2-score: -3.276297 | loss: 64075255808.000000
batch 3: [ 96/6192] | r2-score: -3.381619 | loss: 63481274368.000000
batch 4: [128/6192] | r2-score: -2.948027 | loss: 67447848960.000000
batch 5: [160/6192] | r2-score: -4.246280 | loss: 48563658752.000000
batch 6: [192/6192] | r2-score: -2.872000 | loss: 5614770585

batch 42: [1344/14448] | r2-score: -2.849109 | loss: 43140628480.000000
batch 43: [1376/14448] | r2-score: -5.168202 | loss: 47394275328.000000
batch 44: [1408/14448] | r2-score: -3.899067 | loss: 44198002688.000000
batch 45: [1440/14448] | r2-score: -3.058224 | loss: 61290258432.000000
batch 46: [1472/14448] | r2-score: -2.584530 | loss: 54618591232.000000
batch 47: [1504/14448] | r2-score: -3.821926 | loss: 43088187392.000000
batch 48: [1536/14448] | r2-score: -2.685828 | loss: 70149283840.000000
batch 49: [1568/14448] | r2-score: -3.527829 | loss: 55987576832.000000
batch 50: [1600/14448] | r2-score: -4.060518 | loss: 56808140800.000000
batch 51: [1632/14448] | r2-score: -3.628175 | loss: 79306530816.000000
batch 52: [1664/14448] | r2-score: -5.265008 | loss: 32945635328.000000
batch 53: [1696/14448] | r2-score: -4.225728 | loss: 55275577344.000000
batch 54: [1728/14448] | r2-score: -3.996856 | loss: 38320205824.000000
batch 55: [1760/14448] | r2-score: -3.673606 | loss: 42717356032

batch 177: [5664/14448] | r2-score: -2.904813 | loss: 63696130048.000000
batch 178: [5696/14448] | r2-score: -3.163656 | loss: 63449432064.000000
batch 179: [5728/14448] | r2-score: -2.013594 | loss: 55627096064.000000
batch 180: [5760/14448] | r2-score: -3.246350 | loss: 43618217984.000000
batch 181: [5792/14448] | r2-score: -4.337768 | loss: 69649399808.000000
batch 182: [5824/14448] | r2-score: -2.618397 | loss: 49607000064.000000
batch 183: [5856/14448] | r2-score: -2.653716 | loss: 39485366272.000000
batch 184: [5888/14448] | r2-score: -2.844733 | loss: 51097534464.000000
batch 185: [5920/14448] | r2-score: -4.295349 | loss: 55436963840.000000
batch 186: [5952/14448] | r2-score: -4.171212 | loss: 44541464576.000000
batch 187: [5984/14448] | r2-score: -2.743322 | loss: 75053670400.000000
batch 188: [6016/14448] | r2-score: -3.490040 | loss: 40626663424.000000
batch 189: [6048/14448] | r2-score: -3.198194 | loss: 62471467008.000000
batch 190: [6080/14448] | r2-score: -3.275435 | los

batch 317: [10144/14448] | r2-score: -2.794616 | loss: 51466313728.000000
batch 318: [10176/14448] | r2-score: -2.434743 | loss: 64588607488.000000
batch 319: [10208/14448] | r2-score: -4.354911 | loss: 72199954432.000000
batch 320: [10240/14448] | r2-score: -2.940106 | loss: 42544152576.000000
batch 321: [10272/14448] | r2-score: -3.674726 | loss: 61075443712.000000
batch 322: [10304/14448] | r2-score: -3.234478 | loss: 59098161152.000000
batch 323: [10336/14448] | r2-score: -3.551503 | loss: 75983536128.000000
batch 324: [10368/14448] | r2-score: -5.949562 | loss: 26452987904.000000
batch 325: [10400/14448] | r2-score: -2.429657 | loss: 65356791808.000000
batch 326: [10432/14448] | r2-score: -2.682276 | loss: 59049828352.000000
batch 327: [10464/14448] | r2-score: -2.880058 | loss: 58275565568.000000
batch 328: [10496/14448] | r2-score: -3.266445 | loss: 49392828416.000000
batch 329: [10528/14448] | r2-score: -3.321722 | loss: 51269214208.000000
batch 330: [10560/14448] | r2-score: -

batch 8: [256/6192] | r2-score: -5.287085 | loss: 62501199872.000000
batch 9: [288/6192] | r2-score: -2.961820 | loss: 71605559296.000000
batch 10: [320/6192] | r2-score: -4.093207 | loss: 57342828544.000000
batch 11: [352/6192] | r2-score: -3.276094 | loss: 67835121664.000000
batch 12: [384/6192] | r2-score: -2.737924 | loss: 66985750528.000000
batch 13: [416/6192] | r2-score: -3.926473 | loss: 41226899456.000000
batch 14: [448/6192] | r2-score: -3.388986 | loss: 74495451136.000000
batch 15: [480/6192] | r2-score: -2.787923 | loss: 41757261824.000000
batch 16: [512/6192] | r2-score: -4.843940 | loss: 82316238848.000000
batch 17: [544/6192] | r2-score: -4.093804 | loss: 46184542208.000000
batch 18: [576/6192] | r2-score: -4.197429 | loss: 35894550528.000000
batch 19: [608/6192] | r2-score: -4.604744 | loss: 47541878784.000000
batch 20: [640/6192] | r2-score: -2.928231 | loss: 67264503808.000000
batch 21: [672/6192] | r2-score: -2.580067 | loss: 71168901120.000000
batch 22: [704/6192] |

In [17]:
14448 % 32

16