---
title: "Backpack Prediction Data Exploration"
format: html
jupyter: python3
---

In [8]:
import pandas as pd
from IPython.display import display

# Read the two training files and the competition test file.
train1, train2 = (pd.read_csv(path) for path in ("train.csv", "training_extra.csv"))
test = pd.read_csv("test.csv")

# Stack both training files into one frame with a fresh 0..n-1 index.
train = pd.concat((train1, train2), ignore_index=True)

# First rows, numeric summary statistics and per-column missing-value counts.
for summary in (train.head(), train.describe(), train.isna().sum()):
    display(summary)

# Distinct values and dtype of every column, in column order.
for col in train.columns:
    column = train[col]
    print(column.unique())
    print(column.dtype)

Unnamed: 0,id,Brand,Material,Size,Compartments,Laptop Compartment,Waterproof,Style,Color,Weight Capacity (kg),Price
0,0,Jansport,Leather,Medium,7.0,Yes,No,Tote,Black,11.611723,112.15875
1,1,Jansport,Canvas,Small,10.0,Yes,Yes,Messenger,Green,27.078537,68.88056
2,2,Under Armour,Leather,Small,2.0,Yes,No,Messenger,Red,16.64376,39.1732
3,3,Nike,Nylon,Small,8.0,Yes,No,Messenger,Green,12.93722,80.60793
4,4,Adidas,Canvas,Medium,1.0,Yes,Yes,Messenger,Green,17.749338,86.02312


Unnamed: 0,id,Compartments,Weight Capacity (kg),Price
count,3994318.0,3994318.0,3992510.0,3994318.0
mean,2182137.0,5.43474,18.01042,81.36217
std,1178058.0,2.893043,6.973969,38.93868
min,0.0,1.0,5.0,15.0
25%,1198579.0,3.0,12.06896,47.47002
50%,2197158.0,5.0,18.05436,80.98495
75%,3195738.0,8.0,23.98751,114.855
max,4194317.0,10.0,30.0,150.0


id                           0
Brand                   126758
Material                110962
Size                     87785
Compartments                 0
Laptop Compartment       98533
Waterproof               94324
Style                   104180
Color                   133617
Weight Capacity (kg)      1808
Price                        0
dtype: int64

[      0       1       2 ... 4194315 4194316 4194317]
int64
['Jansport' 'Under Armour' 'Nike' 'Adidas' 'Puma' nan]
object
['Leather' 'Canvas' 'Nylon' nan 'Polyester']
object
['Medium' 'Small' 'Large' nan]
object
[ 7. 10.  2.  8.  1.  3.  5.  9.  6.  4.]
float64
['Yes' 'No' nan]
object
['No' 'Yes' nan]
object
['Tote' 'Messenger' nan 'Backpack']
object
['Black' 'Green' 'Red' 'Blue' 'Gray' 'Pink' nan]
object
[11.61172281 27.07853658 16.64375995 ... 12.79080004 22.95972519
 16.64173875]
float64
[112.15875  68.88056  39.1732  ...  72.77859 100.96727 100.97298]
float64


In [9]:
import torch
import pandas as pd
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder, StandardScaler

class BackpackPriceDataset(Dataset):
    """Backpack price table wrapped as a PyTorch ``Dataset``.

    The CSV data is loaded, missing values are imputed, the categorical
    columns are label-encoded (so they can feed embedding layers) and
    ``Weight Capacity (kg)`` is standardised.

    Items are ``(categorical, numerical)`` in test mode and
    ``(categorical, numerical, target)`` otherwise, where ``categorical`` is a
    ``torch.long`` tensor ordered like ``categorical_cols``, ``numerical`` is a
    ``torch.float32`` tensor ordered like ``numerical_cols`` and ``target`` is
    a one-element ``torch.float32`` tensor holding ``Price``.

    The item tensors are materialised once at construction time; mutating
    ``self.data`` afterwards does not change what ``__getitem__`` returns
    unless ``_build_tensors()`` is called again.
    """

    def __init__(self, csv_files, test_mode=False, label_encoders=None, scaler=None) -> None:
        """Load and preprocess the data.

        Args:
            csv_files: Path of one CSV file (``str``) or an iterable of paths
                whose contents are concatenated row-wise.
            test_mode: When true the data is treated as unlabeled: no
                ``target`` attribute is created and items omit the target.
            label_encoders: Optional dict mapping every name in
                ``categorical_cols`` to an already fitted encoder exposing
                ``transform``. Pass ``train_dataset.label_encoders`` when
                building validation/test data so that integer codes match the
                ones the model was trained on. When ``None`` (default) new
                ``LabelEncoder`` objects are fitted on this data.
            scaler: Optional already fitted scaler exposing ``transform``.
                Pass ``train_dataset.scaler`` to reuse the training
                statistics. When ``None`` (default) a new ``StandardScaler``
                is fitted on this data.
        """
        # Kept for backward compatibility; no tensor is moved to it here.
        self.device = torch.device('cpu')
        self.test_mode = test_mode

        if isinstance(csv_files, str):
            self.data = pd.read_csv(csv_files)
        else:
            dfs = [pd.read_csv(file) for file in csv_files]
            self.data = pd.concat(dfs, ignore_index=True)

        cols_to_scale = ['Weight Capacity (kg)']
        self.numerical_cols = ['Size', 'Compartments', 'Laptop Compartment', 'Waterproof', 'Weight Capacity (kg)']
        self.categorical_cols = ['Brand', 'Material', 'Style', 'Color']

        self._handle_missing_values()
        self._encode_categoricals(label_encoders)
        self._scale_columns(cols_to_scale, scaler)

        # Number of distinct codes per categorical column in *this* data.
        self.num_categories = {
            col: len(self.data[col].unique()) for col in self.categorical_cols
        }

        drop_cols = ['id']
        if not test_mode:
            drop_cols.append('Price')
            self.target = self.data['Price']

        self.features = self.data.drop(drop_cols, axis=1)

        self._build_tensors()

    def __len__(self):
        """Return the number of rows in the preprocessed data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tensors for row ``idx``.

        Categorical and numerical features are returned separately because the
        categorical codes go through embedding layers.
        """
        categorical = self._cat_tensor[idx]
        numerical = self._num_tensor[idx]

        if self.test_mode:
            return categorical, numerical
        return categorical, numerical, self._target_tensor[idx]

    def _encode_categoricals(self, label_encoders=None):
        """Replace each categorical column by integer codes, in place.

        With ``label_encoders=None`` one ``LabelEncoder`` per column is fitted
        on this data. Otherwise the supplied fitted encoders are applied
        unchanged; scikit-learn's ``LabelEncoder.transform`` raises
        ``ValueError`` on labels it was not fitted on. Either way the
        encoders end up in ``self.label_encoders``.
        """
        if label_encoders is None:
            self.label_encoders = {}
            for col in self.categorical_cols:
                le = LabelEncoder()
                self.data[col] = le.fit_transform(self.data[col])
                self.label_encoders[col] = le
        else:
            self.label_encoders = label_encoders
            for col in self.categorical_cols:
                self.data[col] = label_encoders[col].transform(self.data[col])

    def _scale_columns(self, columns, scaler=None):
        """Scale ``columns`` in place, fitting a ``StandardScaler`` only if none is given."""
        if scaler is None:
            self.scaler = StandardScaler()
            self.data[columns] = self.scaler.fit_transform(self.data[columns])
        else:
            self.scaler = scaler
            self.data[columns] = scaler.transform(self.data[columns])

    def _build_tensors(self):
        """Materialise the per-row tensors served by ``__getitem__``.

        Converting whole columns once avoids several pandas ``.iloc`` lookups
        for every sample that is fetched.
        """
        self._cat_tensor = torch.tensor(
            self.data[self.categorical_cols].to_numpy(dtype='int64'), dtype=torch.long
        )
        self._num_tensor = torch.tensor(
            self.data[self.numerical_cols].to_numpy(dtype='float64'), dtype=torch.float32
        )
        if not self.test_mode:
            # Shape (n, 1) so that indexing one row yields a one-element tensor.
            self._target_tensor = torch.tensor(
                self.target.to_numpy(dtype='float64'), dtype=torch.float32
            ).unsqueeze(1)

    def _handle_missing_values(self):
        """Impute missing values and map ordinal/binary text columns to numbers, in place."""
        # https://medium.com/@felipecaballero/deciphering-the-cryptic-futurewarning-for-fillna-in-pandas-2-01deb4e411a1
        with pd.option_context('future.no_silent_downcasting', True):
            # Missing categories become their own explicit "Missing" level.
            for col in self.categorical_cols:
                self.data[col] = self.data[col].fillna("Missing")

            # Size is ordinal: Small < Medium < Large, centred on Medium.
            self.data['Size'] = self.data['Size'].fillna("Missing")
            self.data['Size'] = self.data['Size'].replace({
                'Small': -1,
                'Medium': 0,
                'Large': 1,
                'Missing': 0  # Assume missing sizes are Medium
            }).infer_objects()

            # Compartments (whole numbers): fill with the rounded median.
            median_compartments = round(self.data['Compartments'].median())
            self.data['Compartments'] = self.data['Compartments'].fillna(median_compartments).infer_objects()

            # Weight Capacity (continuous): fill with the median.
            self.data['Weight Capacity (kg)'] = self.data['Weight Capacity (kg)'].fillna(
                self.data['Weight Capacity (kg)'].median()
            )

            # Binary features (assume missing means "No")
            for col in ['Laptop Compartment', 'Waterproof']:
                self.data[col] = self.data[col].fillna("No")
                self.data[col] = self.data[col].replace({'No': 0, 'Yes': 1}).infer_objects()


# Seed reused by the random forest, the train/validation split and torch below.
random_seed = 42

# NOTE(review): as called here, each dataset fits its own LabelEncoders and
# StandardScaler, so test encodings/scaling only line up with training if
# both contain the same category values and a similar weight distribution.
train_dataset = BackpackPriceDataset(csv_files=["train.csv", "training_extra.csv"], test_mode=False)
test_dataset = BackpackPriceDataset(csv_files="test.csv", test_mode=True)

### Feature Importance 

In [10]:
from sklearn.ensemble import RandomForestRegressor
import numpy as np 

# Fit a depth-limited random forest on train.csv only and rank the
# preprocessed feature columns by the forest's `feature_importances_`.
train2 = BackpackPriceDataset("train.csv", test_mode=False)
X = train2.features
y = train2.target

rf = RandomForestRegressor(max_depth=10, n_estimators=100, random_state=random_seed)
rf.fit(X, y)

importance = pd.DataFrame(
    {'feature': X.columns, 'importance': rf.feature_importances_}
).sort_values(by='importance', ascending=False)

print(importance)

                feature  importance
8  Weight Capacity (kg)    0.515199
3          Compartments    0.111563
7                 Color    0.089124
0                 Brand    0.075244
1              Material    0.058089
6                 Style    0.052382
2                  Size    0.046405
5            Waterproof    0.026255
4    Laptop Compartment    0.025739


In [11]:
# Suggested embedding width: half the number of categories, capped at 50.
category_counts = train_dataset.num_categories
for col in category_counts:
    count = category_counts[col]
    rec_dim = min(count // 2, 50)
    print(f"{col}: {count}. Recommended embedding_dim: {rec_dim}")

Brand: 6. Recommended embedding_dim: 3
Material: 5. Recommended embedding_dim: 2
Style: 4. Recommended embedding_dim: 2
Color: 7. Recommended embedding_dim: 3


In [12]:
import torch
import torch.nn as nn

class BackpackPriceNet(nn.Module):
    """Feed-forward price regressor with one embedding per categorical column.

    The embedded categorical codes are concatenated with five numerical
    features and passed through three Linear -> ReLU -> BatchNorm1d -> Dropout
    stages (256, 128, 64 units) followed by a single-output linear layer.
    """

    def __init__(self, num_categories_dict) -> None:
        """Build the embeddings and the MLP.

        Args:
            num_categories_dict: Mapping with the keys ``'Brand'``,
                ``'Material'``, ``'Style'`` and ``'Color'`` giving the number
                of rows of each embedding table.
        """
        super().__init__()

        # (column, embedding width) in the order forward() reads the columns.
        embedding_widths = (('Brand', 3), ('Material', 2), ('Style', 2), ('Color', 3))
        self.embeddings = nn.ModuleDict({
            name: nn.Embedding(num_categories_dict[name], width)
            for name, width in embedding_widths
        })

        embedding_dim = sum(width for _, width in embedding_widths)
        numerical_dim = 5
        input_dim = embedding_dim + numerical_dim

        # Hidden stages share one layout; only width and dropout rate vary.
        layers = []
        in_features = input_dim
        for out_features, dropout_rate in ((256, 0.3), (128, 0.2), (64, 0.1)):
            layers += [
                nn.Linear(in_features, out_features),
                nn.ReLU(),
                nn.BatchNorm1d(out_features),
                nn.Dropout(dropout_rate),
            ]
            in_features = out_features
        layers.append(nn.Linear(in_features, 1))

        self.model = nn.Sequential(*layers)

    def forward(self, categorical_inputs, numerical_inputs):
        """Return the network output for a batch.

        Column ``i`` of ``categorical_inputs`` is looked up in the ``i``-th
        embedding table (in ``self.embeddings`` order); the results are
        concatenated with ``numerical_inputs`` along dim 1.
        """
        embedded = [
            table(categorical_inputs[:, position])
            for position, table in enumerate(self.embeddings.values())
        ]
        x_cat = torch.cat(embedded, dim=1)
        return self.model(torch.cat([x_cat, numerical_inputs], dim=1))

In [13]:
from torch.utils.data import DataLoader
def train_loop(dataloader: DataLoader, model: BackpackPriceNet, loss_fn, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        dataloader: Yields ``(cat_features, num_features, target)`` batches.
        model: Network called as ``model(cat_features, num_features)``.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer stepping the model parameters.

    Returns:
        Sum of the batch losses divided by the number of batches.
    """
    model.train()
    num_batches = len(dataloader)
    running_loss = 0

    for batch, samples in enumerate(dataloader):
        cat_features, num_features, target = samples
        loss = loss_fn(model(cat_features, num_features), target)
        running_loss += loss.item()

        # Backpropagate, update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Progress line every 100 batches.
        if batch % 100 == 0:
            print(f"Training batch loss: {loss.item():>7f} [{batch:>5d}/{num_batches:>5d}]")

    return running_loss / num_batches

def val_loop(dataloader: DataLoader, model: BackpackPriceNet, loss_fn):
    """Evaluate the model without gradients and return the mean per-batch loss.

    Args:
        dataloader: Yields ``(cat_features, num_features, target)`` batches.
        model: Network called as ``model(cat_features, num_features)``.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.

    Returns:
        Sum of the batch losses divided by the number of batches; the value is
        also printed.
    """
    model.eval()
    num_batches = len(dataloader)

    with torch.no_grad():
        val_loss = sum(
            loss_fn(model(cat_features, num_features), target).item()
            for cat_features, num_features, target in dataloader
        )

    val_loss /= num_batches
    print(f"\nValdidation average loss: {val_loss:>8f}\n")
    return val_loss

def predict(dataloader: DataLoader, model: BackpackPriceNet):
    """Run the model over unlabeled batches and collect its outputs.

    Args:
        dataloader: Yields ``(cat_features, num_features)`` batches.
        model: Network called as ``model(cat_features, num_features)``.

    Returns:
        ``np.ndarray`` built from the per-sample model outputs, in
        dataloader order.
    """
    model.eval()
    collected = []

    with torch.no_grad():
        for batch_inputs in dataloader:
            cat_features, num_features = batch_inputs
            batch_outputs = model(cat_features, num_features).cpu().numpy()
            # One array entry per sample, same as extending the list row by row.
            collected += list(batch_outputs)

    return np.array(collected)

In [14]:
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import train_test_split
import numpy as np
from collections import defaultdict

# Seed torch and numpy so the split, initial weights and sampling repeat.
torch.manual_seed(random_seed)
np.random.seed(random_seed)

# 80/20 train/validation split over row positions of the training dataset.
indicies = np.arange(len(train_dataset))
train_indicies, val_indicies = train_test_split(
    indicies, test_size=0.2, random_state=random_seed
)

# Both loaders read the same dataset; the samplers restrict them to their rows.
train_sampler = SubsetRandomSampler(train_indicies)
val_sampler = SubsetRandomSampler(val_indicies)

batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)
val_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=val_sampler)

my_model = BackpackPriceNet(train_dataset.num_categories)
device = torch.device('cpu')  # not referenced again in this cell
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(my_model.parameters(), lr=5e-5)
# Multiply the learning rate by 0.2 once validation loss stalls for more than 2 epochs.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer, mode='min', factor=0.2, patience=2
)
history = defaultdict(list)

# Checkpointing of the best model seen so far.
best_val_loss = float('inf')
best_model_path = "best_model.pth"

# Early stopping: stop after `patience` consecutive epochs without improvement.
patience = 3
early_stopping_counter = 0

num_epochs = 20
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}\n----------------------")
    train_loss = train_loop(train_loader, my_model, loss_fn, optimizer)
    val_loss = val_loop(val_loader, my_model, loss_fn)
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)

    improved = val_loss < best_val_loss
    if not improved:
        early_stopping_counter += 1
    else:
        early_stopping_counter = 0
        best_val_loss = val_loss
        torch.save(my_model.state_dict(), best_model_path)
        print(f"Saved best new model with val_loss: {val_loss:.4f}")

    if early_stopping_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

    # Only reached when training continues; lets the scheduler see this epoch's loss.
    scheduler.step(val_loss)

Epoch 1
----------------------
Training batch loss: 7435.022461 [    0/24965]
Training batch loss: 8452.383789 [  100/24965]
Training batch loss: 8118.629883 [  200/24965]
Training batch loss: 8527.283203 [  300/24965]
Training batch loss: 7577.905762 [  400/24965]
Training batch loss: 9326.549805 [  500/24965]
Training batch loss: 7728.471680 [  600/24965]
Training batch loss: 7010.750488 [  700/24965]
Training batch loss: 7773.871582 [  800/24965]
Training batch loss: 8372.834961 [  900/24965]
Training batch loss: 8770.481445 [ 1000/24965]
Training batch loss: 8973.727539 [ 1100/24965]
Training batch loss: 8109.026367 [ 1200/24965]
Training batch loss: 7710.382812 [ 1300/24965]
Training batch loss: 7395.491699 [ 1400/24965]
Training batch loss: 7229.483887 [ 1500/24965]
Training batch loss: 8430.384766 [ 1600/24965]
Training batch loss: 8327.617188 [ 1700/24965]
Training batch loss: 7625.402344 [ 1800/24965]
Training batch loss: 7368.419922 [ 1900/24965]
Training batch loss: 8170.317

Training batch loss: 2223.137207 [17400/24965]
Training batch loss: 2066.749512 [17500/24965]
Training batch loss: 1767.474121 [17600/24965]
Training batch loss: 1879.516113 [17700/24965]
Training batch loss: 2035.404297 [17800/24965]
Training batch loss: 2177.940918 [17900/24965]
Training batch loss: 1923.640259 [18000/24965]
Training batch loss: 1957.852051 [18100/24965]
Training batch loss: 1836.286621 [18200/24965]
Training batch loss: 1663.052612 [18300/24965]
Training batch loss: 1961.281860 [18400/24965]
Training batch loss: 1613.505005 [18500/24965]
Training batch loss: 2192.770020 [18600/24965]
Training batch loss: 1799.018433 [18700/24965]
Training batch loss: 1800.094849 [18800/24965]
Training batch loss: 2074.854980 [18900/24965]
Training batch loss: 1850.255249 [19000/24965]
Training batch loss: 1687.624023 [19100/24965]
Training batch loss: 1703.998291 [19200/24965]
Training batch loss: 1706.073608 [19300/24965]
Training batch loss: 1730.828857 [19400/24965]
Training batc

Training batch loss: 1654.461304 [ 9600/24965]
Training batch loss: 1379.119019 [ 9700/24965]
Training batch loss: 1451.369141 [ 9800/24965]
Training batch loss: 1524.604004 [ 9900/24965]
Training batch loss: 1510.420166 [10000/24965]
Training batch loss: 1439.000244 [10100/24965]
Training batch loss: 1597.928955 [10200/24965]
Training batch loss: 1296.024292 [10300/24965]
Training batch loss: 1433.372070 [10400/24965]
Training batch loss: 1430.522461 [10500/24965]
Training batch loss: 1419.919678 [10600/24965]
Training batch loss: 1647.433350 [10700/24965]
Training batch loss: 1424.681641 [10800/24965]
Training batch loss: 1388.713501 [10900/24965]
Training batch loss: 1591.324097 [11000/24965]
Training batch loss: 1620.975098 [11100/24965]
Training batch loss: 1554.707031 [11200/24965]
Training batch loss: 1632.086670 [11300/24965]
Training batch loss: 1386.722168 [11400/24965]
Training batch loss: 1449.629883 [11500/24965]
Training batch loss: 1555.185791 [11600/24965]
Training batc

Training batch loss: 1566.329346 [ 1800/24965]
Training batch loss: 1279.798706 [ 1900/24965]
Training batch loss: 1528.656006 [ 2000/24965]
Training batch loss: 1409.481201 [ 2100/24965]
Training batch loss: 1649.669312 [ 2200/24965]
Training batch loss: 1593.489014 [ 2300/24965]
Training batch loss: 1728.335449 [ 2400/24965]
Training batch loss: 1825.300537 [ 2500/24965]
Training batch loss: 1501.076538 [ 2600/24965]
Training batch loss: 1527.099609 [ 2700/24965]
Training batch loss: 1662.006470 [ 2800/24965]
Training batch loss: 1443.320190 [ 2900/24965]
Training batch loss: 1823.209473 [ 3000/24965]
Training batch loss: 1428.242920 [ 3100/24965]
Training batch loss: 1552.173096 [ 3200/24965]
Training batch loss: 1529.279541 [ 3300/24965]
Training batch loss: 1460.691895 [ 3400/24965]
Training batch loss: 1634.933594 [ 3500/24965]
Training batch loss: 1397.359253 [ 3600/24965]
Training batch loss: 1617.921265 [ 3700/24965]
Training batch loss: 1628.846313 [ 3800/24965]
Training batc

Training batch loss: 1447.532227 [19300/24965]
Training batch loss: 1505.209473 [19400/24965]
Training batch loss: 1453.030273 [19500/24965]
Training batch loss: 1838.580811 [19600/24965]
Training batch loss: 1398.358154 [19700/24965]
Training batch loss: 1552.665405 [19800/24965]
Training batch loss: 1881.666504 [19900/24965]
Training batch loss: 1731.533936 [20000/24965]
Training batch loss: 1577.401367 [20100/24965]
Training batch loss: 1559.473389 [20200/24965]
Training batch loss: 1336.409668 [20300/24965]
Training batch loss: 1505.838867 [20400/24965]
Training batch loss: 1312.485962 [20500/24965]
Training batch loss: 1457.518433 [20600/24965]
Training batch loss: 1556.842041 [20700/24965]
Training batch loss: 1436.734375 [20800/24965]
Training batch loss: 1537.797607 [20900/24965]
Training batch loss: 1697.265625 [21000/24965]
Training batch loss: 1640.639526 [21100/24965]
Training batch loss: 1532.984619 [21200/24965]
Training batch loss: 1534.218018 [21300/24965]
Training batc

Training batch loss: 1569.105835 [11600/24965]
Training batch loss: 1501.100464 [11700/24965]
Training batch loss: 1665.773682 [11800/24965]
Training batch loss: 1560.345459 [11900/24965]
Training batch loss: 1612.896973 [12000/24965]
Training batch loss: 1640.108521 [12100/24965]
Training batch loss: 1575.722046 [12200/24965]
Training batch loss: 1513.358032 [12300/24965]
Training batch loss: 1765.749023 [12400/24965]
Training batch loss: 1479.712769 [12500/24965]
Training batch loss: 1487.331299 [12600/24965]
Training batch loss: 1442.582764 [12700/24965]
Training batch loss: 1558.166016 [12800/24965]
Training batch loss: 1324.170532 [12900/24965]
Training batch loss: 1609.466553 [13000/24965]
Training batch loss: 1629.092407 [13100/24965]
Training batch loss: 1373.449097 [13200/24965]
Training batch loss: 1581.291748 [13300/24965]
Training batch loss: 1475.670044 [13400/24965]
Training batch loss: 1463.971924 [13500/24965]
Training batch loss: 1788.394409 [13600/24965]
Training batc

Training batch loss: 1493.183105 [ 3800/24965]
Training batch loss: 1547.776245 [ 3900/24965]
Training batch loss: 1458.045288 [ 4000/24965]
Training batch loss: 1911.064941 [ 4100/24965]
Training batch loss: 1596.598389 [ 4200/24965]
Training batch loss: 1537.109253 [ 4300/24965]
Training batch loss: 1583.044312 [ 4400/24965]
Training batch loss: 1629.024048 [ 4500/24965]
Training batch loss: 1372.716431 [ 4600/24965]
Training batch loss: 1473.670044 [ 4700/24965]
Training batch loss: 1470.822510 [ 4800/24965]
Training batch loss: 1319.551758 [ 4900/24965]
Training batch loss: 1549.680664 [ 5000/24965]
Training batch loss: 1344.026367 [ 5100/24965]
Training batch loss: 1553.631226 [ 5200/24965]
Training batch loss: 1568.559204 [ 5300/24965]
Training batch loss: 1339.137573 [ 5400/24965]
Training batch loss: 1520.302856 [ 5500/24965]
Training batch loss: 1707.706909 [ 5600/24965]
Training batch loss: 1459.256592 [ 5700/24965]
Training batch loss: 1571.508057 [ 5800/24965]
Training batc

Training batch loss: 1482.662720 [21300/24965]
Training batch loss: 1579.262817 [21400/24965]
Training batch loss: 1606.850098 [21500/24965]
Training batch loss: 1576.645874 [21600/24965]
Training batch loss: 1495.675415 [21700/24965]
Training batch loss: 1730.849609 [21800/24965]
Training batch loss: 1373.790161 [21900/24965]
Training batch loss: 1553.114258 [22000/24965]
Training batch loss: 1510.237915 [22100/24965]
Training batch loss: 1476.708740 [22200/24965]
Training batch loss: 1660.680176 [22300/24965]
Training batch loss: 1433.407471 [22400/24965]
Training batch loss: 1738.038330 [22500/24965]
Training batch loss: 1412.877930 [22600/24965]
Training batch loss: 1537.830200 [22700/24965]
Training batch loss: 1706.430542 [22800/24965]
Training batch loss: 1604.781982 [22900/24965]
Training batch loss: 1573.181274 [23000/24965]
Training batch loss: 1391.753418 [23100/24965]
Training batch loss: 1532.963623 [23200/24965]
Training batch loss: 1337.601562 [23300/24965]
Training batc

Training batch loss: 1739.799561 [13500/24965]
Training batch loss: 1582.417725 [13600/24965]
Training batch loss: 1712.798706 [13700/24965]
Training batch loss: 1588.728882 [13800/24965]
Training batch loss: 1711.854980 [13900/24965]
Training batch loss: 1616.841431 [14000/24965]
Training batch loss: 1535.154297 [14100/24965]
Training batch loss: 1387.232178 [14200/24965]
Training batch loss: 1642.568848 [14300/24965]
Training batch loss: 1422.808105 [14400/24965]
Training batch loss: 1655.029541 [14500/24965]
Training batch loss: 1208.858032 [14600/24965]
Training batch loss: 1649.759766 [14700/24965]
Training batch loss: 1619.447510 [14800/24965]
Training batch loss: 1672.494629 [14900/24965]
Training batch loss: 1652.381836 [15000/24965]
Training batch loss: 1518.801270 [15100/24965]
Training batch loss: 1636.125732 [15200/24965]
Training batch loss: 1597.594482 [15300/24965]
Training batch loss: 1553.082031 [15400/24965]
Training batch loss: 1450.116577 [15500/24965]
Training batc

Training batch loss: 1364.312988 [ 5800/24965]
Training batch loss: 1661.155762 [ 5900/24965]
Training batch loss: 1442.313721 [ 6000/24965]
Training batch loss: 1532.512939 [ 6100/24965]
Training batch loss: 1516.230713 [ 6200/24965]
Training batch loss: 1422.595947 [ 6300/24965]
Training batch loss: 1561.768066 [ 6400/24965]
Training batch loss: 1425.065430 [ 6500/24965]
Training batch loss: 1619.039429 [ 6600/24965]
Training batch loss: 1354.949951 [ 6700/24965]
Training batch loss: 1501.912231 [ 6800/24965]
Training batch loss: 1551.894287 [ 6900/24965]
Training batch loss: 1562.266113 [ 7000/24965]
Training batch loss: 1361.649902 [ 7100/24965]
Training batch loss: 1597.582764 [ 7200/24965]
Training batch loss: 1469.941650 [ 7300/24965]
Training batch loss: 1655.879639 [ 7400/24965]
Training batch loss: 1453.826172 [ 7500/24965]
Training batch loss: 1515.516602 [ 7600/24965]
Training batch loss: 1676.529785 [ 7700/24965]
Training batch loss: 1509.465576 [ 7800/24965]
Training batc

Training batch loss: 1433.591064 [23300/24965]
Training batch loss: 1499.760620 [23400/24965]
Training batch loss: 1652.404419 [23500/24965]
Training batch loss: 1495.310547 [23600/24965]
Training batch loss: 1584.729248 [23700/24965]
Training batch loss: 1466.127686 [23800/24965]
Training batch loss: 1695.906372 [23900/24965]
Training batch loss: 1630.242065 [24000/24965]
Training batch loss: 1562.821533 [24100/24965]
Training batch loss: 1493.789551 [24200/24965]
Training batch loss: 1439.615845 [24300/24965]
Training batch loss: 1399.052002 [24400/24965]
Training batch loss: 1690.518677 [24500/24965]
Training batch loss: 1585.938477 [24600/24965]
Training batch loss: 1555.505371 [24700/24965]
Training batch loss: 1489.509277 [24800/24965]
Training batch loss: 1457.419067 [24900/24965]

Valdidation average loss: 1512.852643

Saved best new model with val_loss: 1512.8526
Epoch 8
----------------------
Training batch loss: 1474.605347 [    0/24965]
Training batch loss: 1564.360229 [  1

Training batch loss: 1492.590576 [15500/24965]
Training batch loss: 1442.728638 [15600/24965]
Training batch loss: 1589.690308 [15700/24965]
Training batch loss: 1468.378174 [15800/24965]
Training batch loss: 1456.079468 [15900/24965]
Training batch loss: 1583.744507 [16000/24965]
Training batch loss: 1486.659668 [16100/24965]
Training batch loss: 1656.908081 [16200/24965]
Training batch loss: 1753.070312 [16300/24965]
Training batch loss: 1524.675781 [16400/24965]
Training batch loss: 1567.188843 [16500/24965]
Training batch loss: 1557.633545 [16600/24965]
Training batch loss: 1460.766235 [16700/24965]
Training batch loss: 1559.666016 [16800/24965]
Training batch loss: 1514.164551 [16900/24965]
Training batch loss: 1320.965088 [17000/24965]
Training batch loss: 1235.179199 [17100/24965]
Training batch loss: 1552.304199 [17200/24965]
Training batch loss: 1288.289551 [17300/24965]
Training batch loss: 1456.307617 [17400/24965]
Training batch loss: 1499.489990 [17500/24965]
Training batc

Training batch loss: 1525.964233 [ 7700/24965]
Training batch loss: 1642.516479 [ 7800/24965]
Training batch loss: 1547.167114 [ 7900/24965]
Training batch loss: 1408.508423 [ 8000/24965]
Training batch loss: 1514.239136 [ 8100/24965]
Training batch loss: 1503.700928 [ 8200/24965]
Training batch loss: 1670.320312 [ 8300/24965]
Training batch loss: 1594.161133 [ 8400/24965]
Training batch loss: 1382.493286 [ 8500/24965]
Training batch loss: 1659.177002 [ 8600/24965]
Training batch loss: 1480.062988 [ 8700/24965]
Training batch loss: 1783.317139 [ 8800/24965]
Training batch loss: 1694.484863 [ 8900/24965]
Training batch loss: 1737.347900 [ 9000/24965]
Training batch loss: 1687.156616 [ 9100/24965]
Training batch loss: 1557.188477 [ 9200/24965]
Training batch loss: 1709.616577 [ 9300/24965]
Training batch loss: 1703.900757 [ 9400/24965]
Training batch loss: 1508.000732 [ 9500/24965]
Training batch loss: 1401.359131 [ 9600/24965]
Training batch loss: 1418.184814 [ 9700/24965]
Training batc

Training batch loss: 1747.395752 [    0/24965]
Training batch loss: 1566.282104 [  100/24965]
Training batch loss: 1419.647705 [  200/24965]
Training batch loss: 1408.679321 [  300/24965]
Training batch loss: 1505.484375 [  400/24965]
Training batch loss: 1471.615845 [  500/24965]
Training batch loss: 1387.279297 [  600/24965]
Training batch loss: 1355.944336 [  700/24965]
Training batch loss: 1566.263672 [  800/24965]
Training batch loss: 1600.377075 [  900/24965]
Training batch loss: 1428.605469 [ 1000/24965]
Training batch loss: 1570.876831 [ 1100/24965]
Training batch loss: 1539.791992 [ 1200/24965]
Training batch loss: 1730.848877 [ 1300/24965]
Training batch loss: 1688.593628 [ 1400/24965]
Training batch loss: 1512.961426 [ 1500/24965]
Training batch loss: 1881.628662 [ 1600/24965]
Training batch loss: 1528.375610 [ 1700/24965]
Training batch loss: 1443.827148 [ 1800/24965]
Training batch loss: 1638.624756 [ 1900/24965]
Training batch loss: 1881.926758 [ 2000/24965]
Training batc

Training batch loss: 1495.550537 [17500/24965]
Training batch loss: 1581.665039 [17600/24965]
Training batch loss: 1658.697632 [17700/24965]
Training batch loss: 1277.811401 [17800/24965]
Training batch loss: 1372.690186 [17900/24965]
Training batch loss: 1686.733765 [18000/24965]
Training batch loss: 1662.855835 [18100/24965]
Training batch loss: 1768.999756 [18200/24965]
Training batch loss: 1593.817993 [18300/24965]
Training batch loss: 1736.391357 [18400/24965]
Training batch loss: 1469.734497 [18500/24965]
Training batch loss: 1392.452026 [18600/24965]
Training batch loss: 1498.351685 [18700/24965]
Training batch loss: 1563.713501 [18800/24965]
Training batch loss: 1497.589355 [18900/24965]
Training batch loss: 1451.369751 [19000/24965]
Training batch loss: 1405.109863 [19100/24965]
Training batch loss: 1575.748413 [19200/24965]
Training batch loss: 1724.825439 [19300/24965]
Training batch loss: 1802.944702 [19400/24965]
Training batch loss: 1290.926025 [19500/24965]
Training batc

Training batch loss: 1540.551025 [ 9800/24965]
Training batch loss: 1481.988403 [ 9900/24965]
Training batch loss: 1582.145020 [10000/24965]
Training batch loss: 1634.385986 [10100/24965]
Training batch loss: 1527.229492 [10200/24965]
Training batch loss: 1254.158203 [10300/24965]
Training batch loss: 1641.947021 [10400/24965]
Training batch loss: 1415.331299 [10500/24965]
Training batch loss: 1479.791992 [10600/24965]
Training batch loss: 1622.416504 [10700/24965]
Training batch loss: 1534.786011 [10800/24965]
Training batch loss: 1556.313965 [10900/24965]
Training batch loss: 1326.552979 [11000/24965]
Training batch loss: 1639.604004 [11100/24965]
Training batch loss: 1501.131470 [11200/24965]
Training batch loss: 1394.296265 [11300/24965]
Training batch loss: 1587.944458 [11400/24965]
Training batch loss: 1512.000732 [11500/24965]
Training batch loss: 1548.135010 [11600/24965]
Training batch loss: 1463.538818 [11700/24965]
Training batch loss: 1447.692627 [11800/24965]
Training batc

Training batch loss: 1468.256104 [ 2000/24965]
Training batch loss: 1622.921509 [ 2100/24965]
Training batch loss: 1504.946533 [ 2200/24965]
Training batch loss: 1450.439209 [ 2300/24965]
Training batch loss: 1436.020142 [ 2400/24965]
Training batch loss: 1579.130371 [ 2500/24965]
Training batch loss: 1609.826660 [ 2600/24965]
Training batch loss: 1625.101440 [ 2700/24965]
Training batch loss: 1592.593262 [ 2800/24965]
Training batch loss: 1635.536621 [ 2900/24965]
Training batch loss: 1554.958862 [ 3000/24965]
Training batch loss: 1407.502197 [ 3100/24965]
Training batch loss: 1533.423340 [ 3200/24965]
Training batch loss: 1619.139282 [ 3300/24965]
Training batch loss: 1462.885132 [ 3400/24965]
Training batch loss: 1208.569336 [ 3500/24965]
Training batch loss: 1465.408081 [ 3600/24965]
Training batch loss: 1670.674194 [ 3700/24965]
Training batch loss: 1473.441406 [ 3800/24965]
Training batch loss: 1563.728149 [ 3900/24965]
Training batch loss: 1566.364014 [ 4000/24965]
Training batc

Training batch loss: 1703.772827 [19500/24965]
Training batch loss: 1615.294922 [19600/24965]
Training batch loss: 1331.990967 [19700/24965]
Training batch loss: 1380.929932 [19800/24965]
Training batch loss: 1524.980347 [19900/24965]
Training batch loss: 1632.619995 [20000/24965]
Training batch loss: 1395.432861 [20100/24965]
Training batch loss: 1560.114014 [20200/24965]
Training batch loss: 1493.000000 [20300/24965]
Training batch loss: 1575.111450 [20400/24965]
Training batch loss: 1503.574951 [20500/24965]
Training batch loss: 1610.295166 [20600/24965]
Training batch loss: 1501.185303 [20700/24965]
Training batch loss: 1467.336304 [20800/24965]
Training batch loss: 1730.845215 [20900/24965]
Training batch loss: 1621.004761 [21000/24965]
Training batch loss: 1356.439575 [21100/24965]
Training batch loss: 1364.284546 [21200/24965]
Training batch loss: 1634.645996 [21300/24965]
Training batch loss: 1486.981079 [21400/24965]
Training batch loss: 1665.983276 [21500/24965]
Training batc

Training batch loss: 1610.632080 [11800/24965]
Training batch loss: 1539.104248 [11900/24965]
Training batch loss: 1510.541260 [12000/24965]
Training batch loss: 1494.773315 [12100/24965]
Training batch loss: 1545.222534 [12200/24965]
Training batch loss: 1461.443726 [12300/24965]
Training batch loss: 1614.707275 [12400/24965]
Training batch loss: 1450.517822 [12500/24965]
Training batch loss: 1390.133057 [12600/24965]
Training batch loss: 1664.048218 [12700/24965]
Training batch loss: 1532.425537 [12800/24965]
Training batch loss: 1579.309204 [12900/24965]
Training batch loss: 1521.003784 [13000/24965]
Training batch loss: 1634.213135 [13100/24965]
Training batch loss: 1353.768677 [13200/24965]
Training batch loss: 1469.040771 [13300/24965]
Training batch loss: 1618.965088 [13400/24965]
Training batch loss: 1658.867676 [13500/24965]
Training batch loss: 1308.738281 [13600/24965]
Training batch loss: 1360.612793 [13700/24965]
Training batch loss: 1673.206055 [13800/24965]
Training batc

Training batch loss: 1631.617188 [ 4100/24965]
Training batch loss: 1641.433228 [ 4200/24965]
Training batch loss: 1584.925781 [ 4300/24965]
Training batch loss: 1526.925659 [ 4400/24965]
Training batch loss: 1396.666870 [ 4500/24965]
Training batch loss: 1242.057739 [ 4600/24965]
Training batch loss: 1660.438843 [ 4700/24965]
Training batch loss: 1497.468384 [ 4800/24965]
Training batch loss: 1501.564453 [ 4900/24965]
Training batch loss: 1505.325195 [ 5000/24965]
Training batch loss: 1508.386963 [ 5100/24965]
Training batch loss: 1561.161743 [ 5200/24965]
Training batch loss: 1545.519043 [ 5300/24965]
Training batch loss: 1452.749756 [ 5400/24965]
Training batch loss: 1504.237549 [ 5500/24965]
Training batch loss: 1488.012939 [ 5600/24965]
Training batch loss: 1568.400024 [ 5700/24965]
Training batch loss: 1608.726807 [ 5800/24965]
Training batch loss: 1481.857788 [ 5900/24965]
Training batch loss: 1537.159180 [ 6000/24965]
Training batch loss: 1610.009033 [ 6100/24965]
Training batc

Training batch loss: 1429.528442 [21600/24965]
Training batch loss: 1841.033813 [21700/24965]
Training batch loss: 1626.209473 [21800/24965]
Training batch loss: 1340.996216 [21900/24965]
Training batch loss: 1610.324463 [22000/24965]
Training batch loss: 1452.061279 [22100/24965]
Training batch loss: 1600.983887 [22200/24965]
Training batch loss: 1647.770508 [22300/24965]
Training batch loss: 1415.121338 [22400/24965]
Training batch loss: 1481.910278 [22500/24965]
Training batch loss: 1481.324585 [22600/24965]
Training batch loss: 1611.178101 [22700/24965]
Training batch loss: 1400.018555 [22800/24965]
Training batch loss: 1375.686646 [22900/24965]
Training batch loss: 1478.522095 [23000/24965]
Training batch loss: 1407.957031 [23100/24965]
Training batch loss: 1466.532471 [23200/24965]
Training batch loss: 1640.520874 [23300/24965]
Training batch loss: 1582.032593 [23400/24965]
Training batch loss: 1367.555786 [23500/24965]
Training batch loss: 1635.899414 [23600/24965]
Training batc

In [15]:
# Restore the checkpoint saved at the lowest validation loss.
best_state = torch.load(best_model_path)
my_model.load_state_dict(best_state)

# Keep file order (no shuffling) so predictions line up with the test ids.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
predictions = predict(test_loader, my_model)

# Pair each test id with its flattened prediction and write the submission file.
submission_columns = {
    'id': test_dataset.data['id'],
    'Price': predictions.flatten(),
}
test_predictions_df = pd.DataFrame(submission_columns)
test_predictions_df.to_csv('predictions.csv', index=False)