In [1]:
import torch
import torch.nn as nn
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset, sample_img_paths
from src.models import OriginalPretextNetwork, OurPretextNetwork
from src.train_pretext import train_model

# Original Pretext Task 

### Setup
**Note:** To make data loading more efficient, each time an image is loaded we draw `samples_per_image` random samples from it at once. Drawing $x$ samples from an image in a single load is much cheaper than iterating over the entire dataset $x$ times and drawing only one sample per image on each pass. As a consequence, `len(dataset)` counts *images*, not samples, and the effective ("true") batch size is `samples_per_image * batch_size`.

In [2]:
# --- Configuration for the original pretext task -------------------------------
SEED = 0
SAMPLE_FRAC = 0.1        # fraction of the image pool passed to sample_img_paths
N_TRAIN_IMAGES = 4500    # first N paths -> training, remaining paths -> validation
SAMPLES_PER_IMAGE = 10   # random samples drawn from each image per load
BATCH_SIZE = 8           # images per batch; effective batch = BATCH_SIZE * SAMPLES_PER_IMAGE
NUM_WORKERS = 4

# Seeds torch's global RNG (DataLoader shuffling, torch-based weight init).
# NOTE(review): sample_img_paths and the dataset's patch sampling may use other RNGs
# (random / numpy) that this does not cover — confirm and seed those too if needed.
torch.manual_seed(SEED)

img_paths = sample_img_paths(frac=SAMPLE_FRAC)
print(len(img_paths))

# The split is positional; fail fast instead of silently building an empty validation set.
# NOTE(review): this split is only random if sample_img_paths returns paths in random order — confirm.
assert len(img_paths) > N_TRAIN_IMAGES, (
    f"only {len(img_paths)} images sampled; need more than {N_TRAIN_IMAGES} "
    f"for a non-empty validation split"
)

ds_train = OriginalPatchLocalizationDataset(
    img_paths=img_paths[:N_TRAIN_IMAGES], samples_per_image=SAMPLES_PER_IMAGE
)
ds_val = OriginalPatchLocalizationDataset(
    img_paths=img_paths[N_TRAIN_IMAGES:], samples_per_image=SAMPLES_PER_IMAGE
)

# len() of these datasets counts images (each image load yields SAMPLES_PER_IMAGE samples).
print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

train_loader = torch.utils.data.DataLoader(
    ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
val_loader = torch.utils.data.DataLoader(
    ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

model = OriginalPretextNetwork(backbone="resnet18")
criterion = nn.CrossEntropyLoss()

4967
Number of training images: 	 4500
Number of validation images: 	 467


In [3]:
# Run settings for this experiment, grouped so they are easy to find and tweak.
# optimizer=None is passed through unchanged; how train_model handles it is defined in src.train_pretext.
original_pretext_run = dict(
    experiment_id="original_pretext_1",
    start_epoch=0,
    num_epochs=20,
    log_frequency=10,
    optimizer=None,
)

train_model(
    model=model,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    **original_pretext_run,
)

Epoch: [0][0/563]	Time 1.153s (1.153s)	Speed 6.9 samples/s	Data 0.001s (0.001s)	Loss 2.08339 (2.08339)
Epoch: [0][10/563]	Time 1.098s (1.114s)	Speed 7.3 samples/s	Data 0.002s (0.001s)	Loss 2.18512 (2.67159)
Epoch: [0][20/563]	Time 1.101s (1.106s)	Speed 7.3 samples/s	Data 0.002s (0.001s)	Loss 2.15102 (2.42316)
Epoch: [0][30/563]	Time 1.104s (1.106s)	Speed 7.2 samples/s	Data 0.001s (0.001s)	Loss 2.05586 (2.32025)


KeyboardInterrupt: 