In [1]:
import sys
import os

In [2]:
import torch
import torch.nn as nn

In [3]:
sys.path.append(os.path.abspath("../.."))
from utils import read_features, read_targets, print_info_features, print_info_targets, train_val_split, get_dimensions, scale, \
    metrics_r, get_device, plot_history

In [4]:
sys.path.append(os.path.abspath(".."))
from PotteryDataset import create_pottery_dataloaders, create_pottery_datasets, feature_types, feature_type_combos
from PotteryChronologyPredictor import PotteryChronologyPredictor, train, tune

## Set Working Device

In [5]:
device = get_device()

PyTorch Version: 2.5.1
CUDA is available
GPU: NVIDIA GeForce RTX 4080
Using Device: cuda


## Read Features and Targets

In [6]:
path = os.path.abspath(os.path.join(os.getcwd(), "../../../data/chronology_prediction"))

In [7]:
targets = ["StartYear", "YearRange"]

In [8]:
X = read_features(path, f_type="tensors")
y = read_targets(path, targets, f_type="np")

Loaded X_train_tfidf
Loaded X_train_bert
Loaded X_train_cannyhog
Loaded X_train_resnet
Loaded X_train_vit
Loaded X_test_tfidf
Loaded X_test_bert
Loaded X_test_cannyhog
Loaded X_test_resnet
Loaded X_test_vit
Loaded y_train
Loaded y_test


## Train-Validation Split

In [9]:
X, y = train_val_split(X, y)

## Scale Regression Targets

In [10]:
y, y_scaler = scale(y)

In [11]:
y = {subset: torch.tensor(_y, dtype=torch.float32, device=device) for subset, _y in y.items()}

## Dimensions

In [12]:
X_dim, y_dim = get_dimensions(X, y)

X Dimensions: {'tfidf': 300, 'bert': 768, 'cannyhog': 2917, 'resnet': 2048, 'vit': 768}
y Dimensions: 2


## Torch Datasets and Dataloaders

In [13]:
datasets = create_pottery_datasets(X, y)
loaders = create_pottery_dataloaders(datasets, batch_size=64)

## Tune Model



In [14]:
criterion = nn.MSELoss()
metrics = metrics_r

param_grid = {
    # Architecture Params
    # "hidden_size": [128, 256, 512, 1024],
    "hidden_size": [256, 512],
    "activation": ["relu", "gelu"],
    # "dropout": [0, 0.1, 0.3, 0.5],
    "dropout": [0.1, 0.3],
    # "blocks": [1, 2, 3, 5, 10],
    "blocks": [1, 3, 5],
    "hidden_size_pattern": ["decreasing", "constant"],

    # Training Params
    # "lr": [1e-2, 1e-3, 1e-4, 1e-5],
    "lr": [1e-3, 1e-4],
    # "weight_decay": [1e-5, 1e-6, 1e-7],
    "weight_decay": [1e-5, 1e-6],
}

In [15]:
results = tune(param_grid, [X_dim["tfidf"]], y_dim, loaders["train"]["tfidf"], loaders["val"]["tfidf"], criterion, metrics, device, ["mae"], y_scaler, chronology_target="years")

+-------------+------------+---------+--------+---------------------+--------+--------------+-----------+-----------+-----------+
| hidden_size | activation | dropout | blocks | hidden_size_pattern |     lr | weight_decay |  val_loss |     mae_0 |     mae_1 |
+-------------+------------+---------+--------+---------------------+--------+--------------+-----------+-----------+-----------+
|         256 |       relu |  0.1000 |      1 |          decreasing | 0.0010 |       0.0000 |    0.6990 |   37.9618 |   10.3504 |
|         256 |       relu |  0.1000 |      1 |          decreasing | 0.0010 |       0.0000 |    0.6753 |   38.4136 |   10.3322 |
|         256 |       relu |  0.1000 |      1 |          decreasing | 0.0001 |       0.0000 |    0.6938 |   39.4314 |   10.3666 |
|         256 |       relu |  0.1000 |      1 |          decreasing | 0.0001 |       0.0000 |    0.6904 |   39.2657 |   10.2208 |
|         256 |       relu |  0.1000 |      1 |            constant | 0.0010 |       0.000

In [16]:
results

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
0,relu,1,0.1,256,decreasing,0.0010,0.000010,0.440393,0.699005,37.961823,10.350368,50.200699,12.746825,0.495570,0.192912,26.750259,9.624872
1,relu,1,0.1,256,decreasing,0.0010,0.000001,0.576279,0.675343,38.413647,10.332208,49.812416,12.485808,0.503343,0.225627,28.441040,9.299501
2,relu,1,0.1,256,decreasing,0.0001,0.000010,0.508080,0.693785,39.431416,10.366599,50.411591,12.670953,0.491323,0.202491,29.544495,10.105228
3,relu,1,0.1,256,decreasing,0.0001,0.000001,0.479013,0.690392,39.265663,10.220801,50.641953,12.548526,0.486664,0.217828,27.704361,9.855229
4,relu,1,0.1,256,constant,0.0010,0.000010,0.520643,0.698958,39.309059,10.283403,50.983768,12.608711,0.479711,0.210307,28.291138,9.913374
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,gelu,5,0.3,512,decreasing,0.0001,0.000001,0.570172,0.733297,40.612545,10.576758,52.966240,12.891362,0.438462,0.174505,29.140717,10.502915
188,gelu,5,0.3,512,constant,0.0010,0.000010,0.551382,0.727946,41.134708,10.389469,52.877045,12.800246,0.440351,0.186133,31.306046,9.541079
189,gelu,5,0.3,512,constant,0.0010,0.000001,0.653696,0.726964,41.370792,10.460917,53.239540,12.694678,0.432652,0.199502,32.876190,9.492807
190,gelu,5,0.3,512,constant,0.0001,0.000010,0.571836,0.720266,40.745872,10.346955,52.496986,12.729712,0.448368,0.195078,29.599594,10.100544


In [17]:
results.sort_values("val_loss").head()

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
85,relu,5,0.3,256,constant,0.001,1e-06,0.363592,0.662539,35.151901,10.478162,47.814095,12.666899,0.542393,0.203002,22.736145,9.401186
81,relu,5,0.3,256,decreasing,0.001,1e-06,0.380675,0.665269,34.600182,9.969546,48.136787,12.614222,0.536195,0.209617,23.651154,8.323781
8,relu,1,0.1,512,decreasing,0.001,1e-05,0.449879,0.673214,37.370804,10.218632,49.024063,12.553938,0.51894,0.217153,28.222229,9.568018
48,relu,3,0.3,256,decreasing,0.001,1e-05,0.503125,0.675213,37.067127,10.22918,49.067825,12.637218,0.51808,0.206732,29.996384,9.440868
1,relu,1,0.1,256,decreasing,0.001,1e-06,0.576279,0.675343,38.413647,10.332208,49.812416,12.485808,0.503343,0.225627,28.44104,9.299501


In [18]:
results.sort_values("mae_0").head()

Unnamed: 0,activation,blocks,dropout,hidden_size,hidden_size_pattern,lr,weight_decay,train_loss,val_loss,mae_0,mae_1,rmse_0,rmse_1,r2_0,r2_1,medae_0,medae_1
81,relu,5,0.3,256,decreasing,0.001,1e-06,0.380675,0.665269,34.600182,9.969546,48.136787,12.614222,0.536195,0.209617,23.651154,8.323781
80,relu,5,0.3,256,decreasing,0.001,1e-05,0.620988,0.783341,34.824268,12.1759,49.089817,14.294971,0.517648,-0.015041,23.146652,8.927105
92,relu,5,0.3,512,constant,0.001,1e-05,0.207395,0.693536,34.98735,9.990408,48.883034,12.872267,0.521703,0.176949,24.942993,6.846298
85,relu,5,0.3,256,constant,0.001,1e-06,0.363592,0.662539,35.151901,10.478162,47.814095,12.666899,0.542393,0.203002,22.736145,9.401186
44,relu,3,0.1,512,constant,0.001,1e-05,0.196567,0.68393,35.759853,10.215884,48.49725,12.815312,0.529223,0.184216,25.3461,8.990006
