In [1]:
import os
import random
import sys

import numpy as np

In [2]:
import torch
import torch.nn as nn

In [3]:
# Make the shared `utils` module (two directories up) importable.
# os.path.abspath resolves "../.." against the kernel's current working
# directory, so this presumably assumes the kernel was started in this
# notebook's folder — TODO confirm. print_info_features, print_info_targets and
# plot_history are not used in the cells shown here.
sys.path.append(os.path.abspath("../.."))
from utils import read_features, read_targets, print_info_features, print_info_targets, train_val_split, get_dimensions, scale, \
    metrics_r, get_device, plot_history

In [4]:
# Make the pottery-specific dataset and model modules (one directory up)
# importable; the relative path is resolved against the kernel's current
# working directory.
sys.path.append(os.path.abspath(".."))
from PotteryDataset import create_pottery_dataloaders, create_pottery_datasets, feature_types, feature_type_combos
from PotteryChronologyPredictor import PotteryChronologyPredictor, train, tune

## Set Working Device

In [5]:
# Seed every RNG the pipeline may draw from (Python, NumPy, torch) so that any
# random train/val split, weight initialisation and loader shuffling, and hence
# the tuning table below, are reproducible on a fresh kernel. GPU kernels
# (e.g. cuDNN) can still introduce small run-to-run differences.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

device = get_device()

PyTorch Version: 2.5.1
CUDA is available
GPU: NVIDIA GeForce RTX 4080
Using Device: cuda


## Read Features and Targets

In [6]:
# Root directory of the chronology-prediction features and targets.
# os.path.abspath joins a relative path with the current working directory
# itself, so no explicit os.getcwd() is needed.
path = os.path.abspath("../../../data/chronology_prediction")

In [7]:
# Regression target columns, in output order. The per-target metric suffixes
# (_0, _1) in the tuning results presumably follow this order — confirm.
targets = [
    "StartYear",
    "YearRange",
]

In [8]:
# Load the feature sets and the `targets` columns from `path`. f_type
# presumably selects the returned container (torch tensors for features,
# NumPy arrays for targets) — confirm in utils. Each file loaded prints a
# "Loaded ..." line, features first and then targets.
X = read_features(path, f_type="tensors")
y = read_targets(path, targets, f_type="np")

Loaded X_train_tfidf
Loaded X_train_bert
Loaded X_train_cannyhog
Loaded X_train_resnet
Loaded X_train_vit
Loaded X_test_tfidf
Loaded X_test_bert
Loaded X_test_cannyhog
Loaded X_test_resnet
Loaded X_test_vit
Loaded y_train
Loaded y_test


## Train-Validation Split

In [9]:
# Carve a validation subset out of the training data (the "train"/"val" keys
# are used by the loaders below).
# NOTE(review): this rebinds X and y, so re-running this cell on its own feeds
# the already-split data back into train_val_split. Rerun from the loading cell.
X, y = train_val_split(X, y)

## Scale Regression Targets

In [10]:
# Scale the regression targets. y_scaler is kept and passed to tune() so that
# metrics can presumably be reported back in the original units (years).
# NOTE(review): rebinds y, so re-running this cell alone would presumably
# scale the already-scaled targets again.
y, y_scaler = scale(y)

In [11]:
# Convert each subset's scaled targets into a float32 tensor on the working
# device, keeping the subset keys and their order.
# NOTE(review): not idempotent; on a re-run y already holds tensors and
# torch.tensor() will warn about copy-constructing from a tensor.
y = dict(
    (subset_name, torch.tensor(subset_values, dtype=torch.float32, device=device))
    for subset_name, subset_values in y.items()
)

## Dimensions

In [12]:
# Input width per feature type (dict) and number of regression outputs; the
# printed result shows y_dim == 2, matching len(targets).
X_dim, y_dim = get_dimensions(X, y)

X Dimensions: {'tfidf': 300, 'bert': 768, 'cannyhog': 2917, 'resnet': 2048, 'vit': 768}
y Dimensions: 2


## Torch Datasets and Dataloaders

In [13]:
# Mini-batch size handed to the DataLoader factory (presumably applied to
# every split/feature-type loader alike).
BATCH_SIZE = 64

datasets = create_pottery_datasets(X, y)
loaders = create_pottery_dataloaders(datasets, batch_size=BATCH_SIZE)

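## Tune Model

Grid-search the model hyperparameters on the TF-IDF features and select the configuration with the lowest validation loss.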



In [14]:
# Training loss (MSE on the targets fed to the model) and the regression
# metric set reported per target; metrics_r presumably supplies the
# mae/rmse/r2/medae columns seen in the results table.
criterion, metrics = nn.MSELoss(), metrics_r

# Hyperparameter grid for tune(): 2*2*2*3*2*2*2 = 192 combinations.
# Key order is the column order of the tuning table, so it is kept fixed.
# Narrowed from the wider search space originally considered:
#   hidden_size  ∈ {128, 256, 512, 1024}
#   dropout      ∈ {0, 0.1, 0.3, 0.5}
#   blocks       ∈ {1, 2, 3, 5, 10}
#   lr           ∈ {1e-2, 1e-3, 1e-4, 1e-5}
#   weight_decay ∈ {1e-5, 1e-6, 1e-7}
param_grid = dict(
    # Architecture parameters
    hidden_size=[256, 512],
    activation=["relu", "gelu"],
    dropout=[0.1, 0.3],
    blocks=[1, 3, 5],
    hidden_size_pattern=["decreasing", "constant"],
    # Training parameters
    lr=[1e-3, 1e-4],
    weight_decay=[1e-5, 1e-6],
)

In [15]:
# Grid-search param_grid on a single feature type (TF-IDF), training on the
# train loader and scoring on the validation loader.
FEATURE_TYPE = "tfidf"

results, best_result = tune(
    param_grid,
    [X_dim[FEATURE_TYPE]],           # input sizes; presumably one entry per combined feature type
    y_dim,
    loaders["train"][FEATURE_TYPE],
    loaders["val"][FEATURE_TYPE],
    criterion,
    metrics,
    device,
    ["mae"],                         # NOTE(review): presumably the metrics echoed in the progress table (mae_0/mae_1) — confirm
    y_scaler,
    chronology_target="years",
)

+-----------+-------------+------------+---------+--------+---------------------+--------+--------------+-----------+-----------+-----------+
| combo_idx | hidden_size | activation | dropout | blocks | hidden_size_pattern |     lr | weight_decay |  val_loss |     mae_0 |     mae_1 |
+-----------+-------------+------------+---------+--------+---------------------+--------+--------------+-----------+-----------+-----------+
|   001/192 |         256 |       relu |     0.1 |      1 |          decreasing |  0.001 |        1e-05 |    0.6989 |   39.3670 |   10.3009 |
|   002/192 |         256 |       relu |     0.1 |      1 |          decreasing |  0.001 |        1e-06 |    0.6982 |   39.1106 |   10.2933 |
|   003/192 |         256 |       relu |     0.1 |      1 |          decreasing | 0.0001 |        1e-05 |    0.6933 |   39.8462 |   10.3249 |
|   004/192 |         256 |       relu |     0.1 |      1 |          decreasing | 0.0001 |        1e-06 |    0.7050 |   40.0092 |   10.4330 |
|   00

In [16]:
# One row per hyperparameter combination: the settings, train/val loss, and
# per-target metrics (suffix _0 / _1 per target).
results

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
0,relu,1,0.1,256,decreasing,0.0010,0.000010,0.522023,0.698869,39.367023,10.300889,51.193939,12.628751,0.475412,0.207795,29.236832,9.547012
1,relu,1,0.1,256,decreasing,0.0010,0.000001,0.517046,0.698203,39.110558,10.293305,50.627907,12.682905,0.486948,0.200986,28.037323,9.509384
2,relu,1,0.1,256,decreasing,0.0001,0.000010,0.536878,0.693303,39.846188,10.324909,50.945854,12.588129,0.480484,0.212883,30.815109,9.927046
3,relu,1,0.1,256,decreasing,0.0001,0.000001,0.536639,0.705013,40.009163,10.433044,51.314739,12.718434,0.472934,0.196503,29.053375,10.400080
4,relu,1,0.1,256,constant,0.0010,0.000010,0.473692,0.703977,38.856178,10.335842,50.318436,12.809311,0.493202,0.184980,27.403519,9.425718
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,gelu,5,0.3,512,decreasing,0.0001,0.000001,0.719363,0.817804,40.860920,12.180607,52.299404,14.220887,0.452512,-0.004547,30.999985,11.025905
188,gelu,5,0.3,512,constant,0.0010,0.000010,0.624552,0.737823,40.929306,10.468419,53.317257,12.836107,0.430994,0.181566,28.804153,9.941649
189,gelu,5,0.3,512,constant,0.0010,0.000001,0.563985,0.739436,40.902794,10.329735,53.090523,12.902822,0.435823,0.173037,30.947601,9.709642
190,gelu,5,0.3,512,constant,0.0001,0.000010,0.762742,0.827137,42.253788,12.132495,53.383808,14.211568,0.429573,-0.003231,34.998215,10.111308


In [17]:
# Configuration with the lowest validation loss, shown as a one-row table.
results.sort_values(by="val_loss", ascending=True).iloc[:1]

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
84,relu,5,0.3,256,constant,0.001,1e-05,0.408783,0.665797,34.990295,10.363354,46.973461,12.888397,0.558342,0.174885,25.417358,9.503178


In [18]:
# The five configurations with the lowest mae_0 (presumably StartYear MAE in
# years), for comparison with the val_loss ranking.
results.sort_values(by="mae_0", ascending=True).head(n=5)

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
88,relu,5,0.3,512,decreasing,0.001,1e-05,0.378869,0.669857,34.234901,10.174505,47.100307,12.714995,0.555953,0.196938,23.516068,8.998859
77,relu,5,0.1,512,constant,0.001,1e-06,0.313494,0.690476,34.640507,10.384533,47.853043,12.937171,0.541647,0.168628,22.753494,9.562515
53,relu,3,0.3,256,constant,0.001,1e-06,0.284399,0.691974,34.887783,10.434025,47.85622,12.956754,0.541586,0.166109,24.237122,8.834621
84,relu,5,0.3,256,constant,0.001,1e-05,0.408783,0.665797,34.990295,10.363354,46.973461,12.888397,0.558342,0.174885,25.417358,9.503178
64,relu,5,0.1,256,decreasing,0.001,1e-05,0.442005,0.676297,35.891453,10.347416,50.224258,12.499561,0.495097,0.22392,26.430099,10.197212


In [19]:
# Select the configuration with the lowest validation loss as a plain dict.
# idxmin() returns the *first* minimum, so ties resolve deterministically
# (sort_values' default quicksort is not stable), and it avoids a full sort.
# NOTE(review): this overwrites the best_result returned by tune(); the
# criterion tune() uses internally is not visible here — confirm they agree.
if results["val_loss"].isna().all():  # also true when results is empty
    raise ValueError("tuning produced no row with a non-NaN val_loss")
best_row_label = results["val_loss"].idxmin()
# Double brackets keep a one-row DataFrame, so to_dict(orient="records")
# yields the same record format as before (plain Python scalars).
best_result = results.loc[[best_row_label]].to_dict(orient="records")[0]

In [20]:
# Hyperparameters and metrics of the selected configuration.
best_result

{'activation': 'relu',
 'blocks': 5,
 'dropout': 0.3,
 'hidden_size': 256,
 'hidden_size_pattern': 'constant',
 'lr': 0.001,
 'weight_decay': 1e-05,
 'train_loss': 0.408782958984375,
 'val_loss': 0.6657971739768982,
 'mae_0': 34.99029541015625,
 'mae_1': 10.363353729248047,
 'rmse_0': 46.97346115112305,
 'rmse_1': 12.888397216796875,
 'r2_0': 0.558341920375824,
 'r2_1': 0.17488467693328857,
 'medae_0': 25.4173583984375,
 'medae_1': 9.503177642822266}