In [17]:
import torch
from model import load_resnet_model
import os
from train_regression_weighted_loss import train_model
import re
from dense_weight import DenseWeight
from dataloader import create_dataloader
from tqdm import tqdm

In [2]:
# Fix the torch RNG up front so that any randomly initialised layers in the model
# and any torch-driven shuffling later in the notebook are repeatable across runs.
# torch.manual_seed seeds the generator on all devices (CPU and CUDA).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Single-output ResNet-50 (num_classes=1), i.e. a scalar regression head.
# NOTE(review): the model is not moved to `device` here — presumably train_model does that; confirm.
model = load_resnet_model('resnet50', num_classes=1)



In [12]:
def get_file_paths(directory, file_extension=None):
    """
    Get the paths of all entries in a directory, optionally filtered by extension.

    Args:
    - directory (str): The directory to list.
    - file_extension (str | None): If given (e.g. ".tfrecord" or "tfrecord"), only
      entries whose names end with this extension are returned. A leading dot is
      added if missing. If None or empty (the default), every entry is returned.

    Returns:
    - List[str]: ``os.path.join(directory, name)`` for each matching entry name.
      Entries come straight from ``os.listdir``, so subdirectories are included.
      The order is arbitrary, so callers that need a specific order must sort.
    """
    entry_names = os.listdir(directory)
    if file_extension:
        # Accept both "rec" and ".rec" so callers need not remember the dot.
        suffix = file_extension if file_extension.startswith('.') else '.' + file_extension
        entry_names = [name for name in entry_names if name.endswith(suffix)]
    return [os.path.join(directory, name) for name in entry_names]

# Sort key: pull the numeric N out of a "data_batch_<N>" name so batches order numerically.
def extract_batch_number(file_path):
    """
    Return the integer N from the first "data_batch_<N>" occurrence in ``file_path``.

    The search runs over the whole string (directories included), not just the
    basename. This is used as a ``sorted`` key so that, e.g., data_batch_10 comes
    after data_batch_2, which plain string ordering would not give. Paths that
    contain no such pattern map to -1, so they sort ahead of every matched batch.
    """
    found = re.search(r"data_batch_(\d+)", file_path)
    return int(found.group(1)) if found else -1


In [13]:
# os.listdir order is arbitrary, so order each split's batch files by their
# numeric "data_batch_<N>" suffix. get_file_paths returns a fresh list, so an
# in-place sort is safe here.
train_files = get_file_paths('./data/records/train/')
train_files.sort(key=extract_batch_number)

val_files = get_file_paths('./data/records/val/')
val_files.sort(key=extract_batch_number)

In [None]:
# Gather the label of every training sample into one flat list, then fit DenseWeight
# on it. This makes one full pass over all training files before training starts.
# NOTE(review): this cell is "In [None]" (never executed in the saved notebook), yet
# the train_model cell below depends on `dense_weight_model` — run cells in order.
training_labels = []

for train_file in train_files:
    # NOTE(review): positional args — presumably (path, shuffle, batch_size); confirm against
    # create_dataloader. If the second argument is `shuffle`, shuffling is not needed just to read labels.
    train_loader = create_dataloader(train_file, True, 32)
    for _, labels in tqdm(train_loader):
        # Move the batch labels to host memory as a NumPy array. extend() iterates the array's
        # first axis, appending one entry per element/row.
        training_labels.extend(labels.detach().cpu().numpy())

# NOTE(review): 0.5 is an unnamed constructor argument — presumably DenseWeight's alpha /
# weighting-strength parameter; confirm and consider lifting it into a named config constant.
dense_weight_model = DenseWeight(0.5)
dense_weight_model.fit(training_labels)

In [5]:
# NOTE(review): the saved output below is stale. It reports "Epoch N/10" although this call
# passes epochs=15, and this cell ran (In [5]) before the DenseWeight cell above was ever
# executed (In [None]). Restart the kernel and Run All before trusting these metrics.
train_model(model, train_files, val_files, device, dense_weight_model,epochs=15)

100%|██████████| 16/16 [00:09<00:00,  1.63it/s]
100%|██████████| 16/16 [00:05<00:00,  3.03it/s]
100%|██████████| 16/16 [00:05<00:00,  3.05it/s]
100%|██████████| 16/16 [00:05<00:00,  2.95it/s]
100%|██████████| 16/16 [00:05<00:00,  2.93it/s]
100%|██████████| 16/16 [00:05<00:00,  2.84it/s]
100%|██████████| 4/4 [00:01<00:00,  2.20it/s]
100%|██████████| 16/16 [00:03<00:00,  4.97it/s]
100%|██████████| 16/16 [00:03<00:00,  4.28it/s]
100%|██████████| 3/3 [00:01<00:00,  2.89it/s]


Epoch 1/10, Train Loss: 1.0007, Val Loss: 1.2756, F1 Macro: 0.2628, F1 Micro: 0.5494
Model saved as checkpoint_epoch_1.pth


100%|██████████| 16/16 [00:05<00:00,  3.00it/s]
100%|██████████| 16/16 [00:05<00:00,  2.80it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]
100%|██████████| 16/16 [00:05<00:00,  2.84it/s]
100%|██████████| 16/16 [00:05<00:00,  2.86it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.82it/s]
100%|██████████| 16/16 [00:03<00:00,  4.50it/s]
100%|██████████| 3/3 [00:00<00:00,  3.39it/s]


Epoch 2/10, Train Loss: 0.8812, Val Loss: 1.5962, F1 Macro: 0.2859, F1 Micro: 0.5023


100%|██████████| 16/16 [00:05<00:00,  2.99it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 16/16 [00:05<00:00,  2.76it/s]
100%|██████████| 16/16 [00:05<00:00,  2.80it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 4/4 [00:01<00:00,  2.23it/s]
100%|██████████| 16/16 [00:03<00:00,  4.82it/s]
100%|██████████| 16/16 [00:03<00:00,  4.12it/s]
100%|██████████| 3/3 [00:00<00:00,  3.31it/s]


Epoch 3/10, Train Loss: 0.8563, Val Loss: 3.7346, F1 Macro: 0.2336, F1 Micro: 0.4231


100%|██████████| 16/16 [00:05<00:00,  2.94it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]
100%|██████████| 16/16 [00:05<00:00,  2.77it/s]
100%|██████████| 16/16 [00:05<00:00,  2.77it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.88it/s]
100%|██████████| 16/16 [00:03<00:00,  4.13it/s]
100%|██████████| 3/3 [00:00<00:00,  3.23it/s]


Epoch 4/10, Train Loss: 0.8281, Val Loss: 1.6483, F1 Macro: 0.2699, F1 Micro: 0.4881


100%|██████████| 16/16 [00:05<00:00,  2.96it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]
100%|██████████| 16/16 [00:05<00:00,  2.83it/s]
100%|██████████| 4/4 [00:01<00:00,  2.19it/s]
100%|██████████| 16/16 [00:03<00:00,  4.83it/s]
100%|██████████| 16/16 [00:03<00:00,  4.11it/s]
100%|██████████| 3/3 [00:00<00:00,  3.20it/s]


Epoch 5/10, Train Loss: 0.7970, Val Loss: 1.1841, F1 Macro: 0.4077, F1 Micro: 0.5993
Model saved as checkpoint_epoch_5.pth


100%|██████████| 16/16 [00:05<00:00,  2.96it/s]
100%|██████████| 16/16 [00:05<00:00,  2.74it/s]
100%|██████████| 16/16 [00:05<00:00,  2.84it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 16/16 [00:05<00:00,  2.77it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.83it/s]
100%|██████████| 16/16 [00:03<00:00,  4.47it/s]
100%|██████████| 3/3 [00:00<00:00,  3.29it/s]


Epoch 6/10, Train Loss: 0.7685, Val Loss: 0.9513, F1 Macro: 0.4207, F1 Micro: 0.6249
Model saved as checkpoint_epoch_6.pth


100%|██████████| 16/16 [00:05<00:00,  2.94it/s]
100%|██████████| 16/16 [00:05<00:00,  2.73it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]
100%|██████████| 16/16 [00:05<00:00,  2.80it/s]
100%|██████████| 16/16 [00:05<00:00,  2.83it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.85it/s]
100%|██████████| 16/16 [00:03<00:00,  4.10it/s]
100%|██████████| 3/3 [00:00<00:00,  3.21it/s]


Epoch 7/10, Train Loss: 0.7459, Val Loss: 1.0888, F1 Macro: 0.3900, F1 Micro: 0.5956


100%|██████████| 16/16 [00:05<00:00,  2.98it/s]
100%|██████████| 16/16 [00:06<00:00,  2.66it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 16/16 [00:05<00:00,  2.75it/s]
100%|██████████| 16/16 [00:05<00:00,  2.76it/s]
100%|██████████| 16/16 [00:05<00:00,  2.79it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.82it/s]
100%|██████████| 16/16 [00:03<00:00,  4.42it/s]
100%|██████████| 3/3 [00:00<00:00,  3.33it/s]


Epoch 8/10, Train Loss: 0.7365, Val Loss: 0.9052, F1 Macro: 0.4246, F1 Micro: 0.6299
Model saved as checkpoint_epoch_8.pth


100%|██████████| 16/16 [00:05<00:00,  2.95it/s]
100%|██████████| 16/16 [00:05<00:00,  2.75it/s]
100%|██████████| 16/16 [00:05<00:00,  2.73it/s]
100%|██████████| 16/16 [00:05<00:00,  2.72it/s]
100%|██████████| 16/16 [00:05<00:00,  2.77it/s]
100%|██████████| 16/16 [00:05<00:00,  2.84it/s]
100%|██████████| 4/4 [00:01<00:00,  2.22it/s]
100%|██████████| 16/16 [00:03<00:00,  4.77it/s]
100%|██████████| 16/16 [00:03<00:00,  4.40it/s]
100%|██████████| 3/3 [00:00<00:00,  3.31it/s]


Epoch 9/10, Train Loss: 0.7045, Val Loss: 1.1353, F1 Macro: 0.3708, F1 Micro: 0.5732


100%|██████████| 16/16 [00:05<00:00,  2.96it/s]
100%|██████████| 16/16 [00:05<00:00,  2.76it/s]
100%|██████████| 16/16 [00:05<00:00,  2.81it/s]
100%|██████████| 16/16 [00:05<00:00,  2.78it/s]
100%|██████████| 16/16 [00:05<00:00,  2.82it/s]
100%|██████████| 16/16 [00:05<00:00,  2.83it/s]
100%|██████████| 4/4 [00:01<00:00,  2.25it/s]
100%|██████████| 16/16 [00:03<00:00,  4.80it/s]
100%|██████████| 16/16 [00:03<00:00,  4.40it/s]
100%|██████████| 3/3 [00:00<00:00,  3.31it/s]


Epoch 10/10, Train Loss: 0.6635, Val Loss: 1.2270, F1 Macro: 0.3441, F1 Micro: 0.5499
Model saved as final_model.pth
Training completed and final model saved.
