In [1]:
#%env CUDA_VISIBLE_DEVICES=0

In [2]:
from datetime import datetime

import pandas as pd
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import Callback, EarlyStopping, LearningRateMonitor
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchinfo import summary

from era_data import TabletPeriodDataset, get_IDS
from era_model import EraClassifier  # also used for periods

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# Hyperparameters

In [4]:
LR = 5e-5
EPOCHS = 20
BATCH_SIZE = 16
SUFFIX = '-resnet50'
DATE = datetime.now().strftime("%B%d")
RUN_NAME_SUFFIX = '-preprocessed' 
IMG_DIR = 'output/images_preprocessed'
IDS = get_IDS(IMG_DIR=IMG_DIR)
print(len(IDS))
VERSION_NAME = f'period_clf_bs{BATCH_SIZE}_lr{LR}_{EPOCHS}epochs{SUFFIX}-{len(IDS)}_samples{RUN_NAME_SUFFIX}_{DATE}-1000_test_val'
VERSION_NAME

# Load data

In [6]:
#! du -h {IMG_DIR}

In [9]:
train_ids, test_ids = train_test_split(IDS, test_size=1000, random_state=0)
len(train_ids), len(test_ids)

(93936, 1000)

In [10]:
train_ids, val_ids = train_test_split(train_ids, test_size=1000, random_state=0)
len(train_ids), len(val_ids)

(92936, 1000)

In [11]:
ds_train = TabletPeriodDataset(IDS=train_ids, IMG_DIR=IMG_DIR)
ds_val = TabletPeriodDataset(IDS=val_ids, IMG_DIR=IMG_DIR)
ds_test = TabletPeriodDataset(IDS=test_ids, IMG_DIR=IMG_DIR)

Filtering 94936 IDS down to provided 92936...
Filtering 94936 IDS down to provided 1000...
Filtering 94936 IDS down to provided 1000...


In [12]:
def collate_fn(batch):
    data = torch.stack([torch.from_numpy(sample[1]) for sample in batch])
    labels = torch.tensor([sample[2] for sample in batch])

    return data, labels

In [13]:
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=4)
dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=4)
dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=4)

In [14]:
# save model IDs so we can keep track of what data it was trained on
pd.Series(train_ids).to_csv(f'output/clf_ids/period-train-{VERSION_NAME}.csv', index=False, header=None)
pd.Series(val_ids).to_csv(f'output/clf_ids/period-val-{VERSION_NAME}.csv', index=False, header=None)
pd.Series(test_ids).to_csv(f'output/clf_ids/period-test-{VERSION_NAME}.csv', index=False, header=None)

# Create Model

In [15]:
num_classes = len(TabletPeriodDataset.PERIOD_INDICES) + 2
num_classes

24

In [16]:
model = EraClassifier(LR=LR, num_classes=num_classes)



In [17]:
summary(model, input_size=(BATCH_SIZE, 512, 512))

Layer (type:depth-idx)                        Output Shape              Param #
EraClassifier                                 [16, 24]                  --
├─Conv2d: 1-1                                 [16, 3, 512, 512]         6
├─ResNet: 1-2                                 [16, 24]                  --
│    └─Conv2d: 2-1                            [16, 64, 256, 256]        9,408
│    └─BatchNorm2d: 2-2                       [16, 64, 256, 256]        128
│    └─ReLU: 2-3                              [16, 64, 256, 256]        --
│    └─MaxPool2d: 2-4                         [16, 64, 128, 128]        --
│    └─Sequential: 2-5                        [16, 256, 128, 128]       --
│    │    └─Bottleneck: 3-1                   [16, 256, 128, 128]       75,008
│    │    └─Bottleneck: 3-2                   [16, 256, 128, 128]       70,400
│    │    └─Bottleneck: 3-3                   [16, 256, 128, 128]       70,400
│    └─Sequential: 2-6                        [16, 512, 64, 64]         --
│    

# Train Model

In [18]:
early_stop_callback = EarlyStopping(
    monitor='val_loss',  
    min_delta=0.00001,      
    patience=3,          
    verbose=10,       
    mode='min',
    check_on_train_epoch_end=True
)

In [19]:
lr_monitor = LearningRateMonitor(logging_interval='step')

In [26]:
class PrintMetricsCallback(Callback):
    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        # Use get() with a default of 0 to avoid KeyError
        train_loss = metrics.get('train_loss', torch.tensor(0.0)).item()
        val_loss = metrics.get('val_loss', torch.tensor(0.0)).item()
        train_acc = metrics.get('train_acc', torch.tensor(0.0)).item()
        val_acc = metrics.get('val_acc', torch.tensor(0.0)).item()

        print(f"\nEpoch {trainer.current_epoch} Metrics:")
        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, "
              f"Train Acc: {train_acc * 100:.2f}%, Val Acc: {val_acc * 100:.2f}%")



In [22]:
logger = pl.loggers.TensorBoardLogger(
    save_dir='.',
    name='lightning_logs',
    version=VERSION_NAME
)

In [27]:
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator='gpu',
    devices='auto',
    val_check_interval=0.2,
    callbacks=[lr_monitor, early_stop_callback, PrintMetricsCallback()],
    logger=logger
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
print('Logs to:', VERSION_NAME)

Logs to: period_clf_bs16_lr5e-05_20epochs-resnet50-94936_samples-preprocessed_March28-1000_test_val


In [29]:
trainer.fit(model, dl_train, dl_val)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type               | Params
------------------------------------------------------
0 | gray_to_triple | Conv2d             | 6     
1 | core           | ResNet             | 23.6 M
2 | objective      | CrossEntropyLoss   | 0     
3 | train_acc      | MulticlassAccuracy | 0     
4 | val_acc        | MulticlassAccuracy | 0     
------------------------------------------------------
23.6 M    Trainable params
0         Non-trainable params
23.6 M    Total params
94.229    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 3.1577, Train Acc: 0.00%, Val Acc: 0.00%


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 0.8125, Train Acc: 0.00%, Val Acc: 52.52%


Validation: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 0.6377, Train Acc: 0.00%, Val Acc: 62.56%


Validation: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 0.6261, Train Acc: 0.00%, Val Acc: 62.14%


Validation: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 0.6321, Train Acc: 0.00%, Val Acc: 62.84%


Validation: 0it [00:00, ?it/s]


Epoch 0 Metrics:
Train Loss: 0.0000, Val Loss: 0.5534, Train Acc: 0.00%, Val Acc: 64.86%


Metric val_loss improved. New best score: 0.553


Validation: 0it [00:00, ?it/s]


Epoch 1 Metrics:
Train Loss: 0.7846, Val Loss: 1.1515, Train Acc: 59.23%, Val Acc: 38.81%


Validation: 0it [00:00, ?it/s]


Epoch 1 Metrics:
Train Loss: 0.7846, Val Loss: 0.5450, Train Acc: 59.23%, Val Acc: 68.31%


Validation: 0it [00:00, ?it/s]


Epoch 1 Metrics:
Train Loss: 0.7846, Val Loss: 0.5374, Train Acc: 59.23%, Val Acc: 61.01%


Validation: 0it [00:00, ?it/s]


Epoch 1 Metrics:
Train Loss: 0.7846, Val Loss: 4.6561, Train Acc: 59.23%, Val Acc: 16.60%


Validation: 0it [00:00, ?it/s]


Epoch 1 Metrics:
Train Loss: 0.7846, Val Loss: 0.4856, Train Acc: 59.23%, Val Acc: 67.85%


Metric val_loss improved by 0.068 >= min_delta = 1e-05. New best score: 0.486


Validation: 0it [00:00, ?it/s]


Epoch 2 Metrics:
Train Loss: 0.5303, Val Loss: 0.4759, Train Acc: 69.69%, Val Acc: 67.04%


Validation: 0it [00:00, ?it/s]


Epoch 2 Metrics:
Train Loss: 0.5303, Val Loss: 0.5001, Train Acc: 69.69%, Val Acc: 65.86%


Validation: 0it [00:00, ?it/s]


Epoch 2 Metrics:
Train Loss: 0.5303, Val Loss: 0.4870, Train Acc: 69.69%, Val Acc: 65.14%


Validation: 0it [00:00, ?it/s]


Epoch 2 Metrics:
Train Loss: 0.5303, Val Loss: 0.5470, Train Acc: 69.69%, Val Acc: 62.57%


Validation: 0it [00:00, ?it/s]


Epoch 2 Metrics:
Train Loss: 0.5303, Val Loss: 0.4805, Train Acc: 69.69%, Val Acc: 66.44%


Metric val_loss improved by 0.005 >= min_delta = 1e-05. New best score: 0.481


Validation: 0it [00:00, ?it/s]


Epoch 3 Metrics:
Train Loss: 0.4254, Val Loss: 0.5260, Train Acc: 74.78%, Val Acc: 63.08%


Validation: 0it [00:00, ?it/s]


Epoch 3 Metrics:
Train Loss: 0.4254, Val Loss: 0.4667, Train Acc: 74.78%, Val Acc: 67.76%


Validation: 0it [00:00, ?it/s]


Epoch 3 Metrics:
Train Loss: 0.4254, Val Loss: 0.4522, Train Acc: 74.78%, Val Acc: 68.36%


Validation: 0it [00:00, ?it/s]


Epoch 3 Metrics:
Train Loss: 0.4254, Val Loss: 0.5469, Train Acc: 74.78%, Val Acc: 66.56%


Validation: 0it [00:00, ?it/s]


Epoch 3 Metrics:
Train Loss: 0.4254, Val Loss: 0.4468, Train Acc: 74.78%, Val Acc: 71.18%


Metric val_loss improved by 0.034 >= min_delta = 1e-05. New best score: 0.447


Validation: 0it [00:00, ?it/s]


Epoch 4 Metrics:
Train Loss: 0.3403, Val Loss: 0.4761, Train Acc: 78.78%, Val Acc: 68.11%


Validation: 0it [00:00, ?it/s]


Epoch 4 Metrics:
Train Loss: 0.3403, Val Loss: 0.4477, Train Acc: 78.78%, Val Acc: 64.55%


Validation: 0it [00:00, ?it/s]


Epoch 4 Metrics:
Train Loss: 0.3403, Val Loss: 0.4493, Train Acc: 78.78%, Val Acc: 69.17%


Validation: 0it [00:00, ?it/s]


Epoch 4 Metrics:
Train Loss: 0.3403, Val Loss: 0.5409, Train Acc: 78.78%, Val Acc: 65.59%


Validation: 0it [00:00, ?it/s]


Epoch 4 Metrics:
Train Loss: 0.3403, Val Loss: 0.5352, Train Acc: 78.78%, Val Acc: 69.92%


Validation: 0it [00:00, ?it/s]


Epoch 5 Metrics:
Train Loss: 0.2579, Val Loss: 0.5415, Train Acc: 83.11%, Val Acc: 71.16%


Validation: 0it [00:00, ?it/s]


Epoch 5 Metrics:
Train Loss: 0.2579, Val Loss: 0.6000, Train Acc: 83.11%, Val Acc: 60.46%


Validation: 0it [00:00, ?it/s]


Epoch 5 Metrics:
Train Loss: 0.2579, Val Loss: 0.5151, Train Acc: 83.11%, Val Acc: 68.61%


Validation: 0it [00:00, ?it/s]


Epoch 5 Metrics:
Train Loss: 0.2579, Val Loss: 0.5218, Train Acc: 83.11%, Val Acc: 71.14%


Validation: 0it [00:00, ?it/s]


Epoch 5 Metrics:
Train Loss: 0.2579, Val Loss: 0.5306, Train Acc: 83.11%, Val Acc: 64.97%


Validation: 0it [00:00, ?it/s]


Epoch 6 Metrics:
Train Loss: 0.1862, Val Loss: 0.6887, Train Acc: 87.27%, Val Acc: 65.94%


Validation: 0it [00:00, ?it/s]


Epoch 6 Metrics:
Train Loss: 0.1862, Val Loss: 0.5600, Train Acc: 87.27%, Val Acc: 69.66%


Validation: 0it [00:00, ?it/s]


Epoch 6 Metrics:
Train Loss: 0.1862, Val Loss: 0.5511, Train Acc: 87.27%, Val Acc: 65.02%


Validation: 0it [00:00, ?it/s]


Epoch 6 Metrics:
Train Loss: 0.1862, Val Loss: 0.5698, Train Acc: 87.27%, Val Acc: 68.61%


Validation: 0it [00:00, ?it/s]


Epoch 6 Metrics:
Train Loss: 0.1862, Val Loss: 0.5186, Train Acc: 87.27%, Val Acc: 69.34%


Monitored metric val_loss did not improve in the last 3 records. Best score: 0.447. Signaling Trainer to stop.
