In [1]:
import torch
import torch.nn as nn
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset, sample_img_paths
from src.models import OriginalPretextNetwork, OurPretextNetwork
from src.loss import CustomLoss
from src.train_pretext import train_model

# Original Pretext Task 

### Setup
Note: To make data loading more efficient, we generate __samples_per_image__ random samples at once each time an image is loaded. Creating __x__ samples from an image in a single pass is much cheaper than iterating over the entire dataset __x__ times and creating only one sample per image each time. The effective ("true") batch size is therefore __samples_per_image__ * __batch_size__.

In [2]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: \t {device}")

# Pool of image paths that both splits below are sliced from.
# NOTE(review): `frac=0.1` presumably selects 10% of the available images —
# confirm against src.dataset.sample_img_paths.
img_paths = sample_img_paths(frac=0.1)

# Tiny split: the first 10 paths train, the following 2 validate, and a single
# sample is generated per loaded image. The training dataset is built first.
ds_train, ds_val = (
    OriginalPatchLocalizationDataset(img_paths=subset, samples_per_image=1)
    for subset in (img_paths[:10], img_paths[10:12])
)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    ds_train,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    ds_val,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Network and loss for the original pretext task, both moved to `device`.
model = OriginalPretextNetwork(
    backbone="resnet18",
).to(device)
criterion = nn.CrossEntropyLoss().to(device)

Number of training images: 	 10
Number of validation images: 	 2


### Training

In [None]:
# Train the original pretext network with the model, loaders and loss that the
# setup cell above created.
train_model(
    # Run identity and schedule.
    experiment_id="original_pretext_1",
    start_epoch=0,
    num_epochs=20,
    log_frequency=10,
    # Network and objective.
    model=model,
    criterion=criterion,
    # No optimizer is supplied here.
    # NOTE(review): presumably train_model builds a default one — confirm in
    # src.train_pretext.
    optimizer=None,
    # Data.
    train_loader=train_loader,
    val_loader=val_loader,
)

# Our Pretext Task
### Setup

In [4]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: \t {device}")

# Pool of image paths that both splits below are sliced from.
# NOTE(review): `frac=0.1` presumably selects 10% of the available images —
# confirm against src.dataset.sample_img_paths.
img_paths = sample_img_paths(frac=0.1)

# Tiny split: the first 8 paths train, the following 2 validate, and a single
# sample is generated per loaded image. The training dataset is built first.
ds_train, ds_val = (
    OurPatchLocalizationDataset(img_paths=subset, samples_per_image=1)
    for subset in (img_paths[:8], img_paths[8:10])
)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    ds_train,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    ds_val,
    batch_size=8,
    shuffle=False,
    num_workers=4,
)

# Network and custom loss for our pretext task, both moved to `device`.
model = OurPretextNetwork(
    backbone="resnet18",
).to(device)
criterion = CustomLoss(
    alpha=1.,
    symmetric=True,
).to(device)

Number of training images: 	 8
Number of validation images: 	 2


### Training


In [5]:
# Train our pretext network with the model, loaders and loss that the setup
# cell above created. The logged output of this run shows checkpoints being
# saved to ./out/our_pretext_1/.
train_model(
    # Run identity and schedule.
    experiment_id="our_pretext_1",
    start_epoch=0,
    num_epochs=20,
    log_frequency=10,
    # Network and objective.
    model=model,
    criterion=criterion,
    # No optimizer is supplied here.
    # NOTE(review): presumably train_model builds a default one — confirm in
    # src.train_pretext.
    optimizer=None,
    # Data.
    train_loader=train_loader,
    val_loader=val_loader,
)

Epoch: [0][0/1]	Time 1.629s (1.629s)	Speed 4.9 samples/s	Data 0.000s (0.000s)	Loss 4.13859 (4.13859)
Epoch: [0][0/1]	Time 1.629s (1.629s)	Speed 4.9 samples/s	Data 0.000s (0.000s)	Loss 4.13859 (4.13859)
Test: [0/1]	Time 0.146 (0.146)	Loss 33.9489 (33.9489)
Test: [0/1]	Time 0.146 (0.146)	Loss 33.9489 (33.9489)
Accuracy: 0.000
Accuracy: 0.000
Saving checkpoint to ./out/our_pretext_1/
Saving checkpoint to ./out/our_pretext_1/
Epoch: [1][0/1]	Time 1.710s (1.710s)	Speed 4.7 samples/s	Data 0.000s (0.000s)	Loss 27.60455 (27.60455)
Epoch: [1][0/1]	Time 1.710s (1.710s)	Speed 4.7 samples/s	Data 0.000s (0.000s)	Loss 27.60455 (27.60455)
Test: [0/1]	Time 0.144 (0.144)	Loss 4.8663 (4.8663)
Test: [0/1]	Time 0.144 (0.144)	Loss 4.8663 (4.8663)
Accuracy: 0.500
Accuracy: 0.500
Saving best model to ./out/our_pretext_1/
Saving best model to ./out/our_pretext_1/
Saving checkpoint to ./out/our_pretext_1/
Saving checkpoint to ./out/our_pretext_1/


KeyboardInterrupt: 