In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to the `src` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset, sample_image_paths
from src.models import OriginalPretextNetwork, OurPretextNetwork
from src.loss import CustomLoss
from src.train import train_model

# Original Pretext Task 

### Setup
Note: If you run out of memory, pass `cache_images=False` to the dataset constructors below.

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Train/validation split point: the first NUM_TRAIN_IMAGES paths are used for
# training, every path after them for validation.
NUM_TRAIN_IMAGES = 46_000

img_paths = sample_image_paths(frac=1.0)

# Fail fast instead of silently building an empty validation set, which is what
# the slice below yields when no more than NUM_TRAIN_IMAGES paths come back
# (e.g. after lowering `frac` for a quick experiment).
assert len(img_paths) > NUM_TRAIN_IMAGES, (
    f"Expected more than {NUM_TRAIN_IMAGES} image paths, got {len(img_paths)}; "
    "lower NUM_TRAIN_IMAGES, otherwise the validation split is empty."
)

# NOTE(review): the split is only stable between runs if sample_image_paths
# returns the paths in a deterministic order — confirm in `src.dataset`.
train_paths = img_paths[:NUM_TRAIN_IMAGES]
val_paths = img_paths[NUM_TRAIN_IMAGES:]

ds_train = OriginalPatchLocalizationDataset(image_paths=train_paths, cache_images=True)
ds_val = OriginalPatchLocalizationDataset(image_paths=val_paths, cache_images=True)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Baseline network with a ResNet-18 backbone, trained with cross-entropy loss.
model = OriginalPretextNetwork(backbone="resnet18")
criterion = nn.CrossEntropyLoss()

### Training

In [None]:
# Launch the baseline pretext-task run with the objects built in the setup cell above.
train_model(
    # run identity
    experiment_id="original_pretext_1",
    resume_from_checkpoint=False,
    # model and objective
    model=model,
    criterion=criterion,
    optimizer=None,  # NOTE(review): presumably train_model then picks its own optimizer — confirm
    device=device,
    # data
    ds_train=ds_train,
    ds_val=ds_val,
    batch_size=64,
    num_workers=4,
    # schedule and logging
    num_epochs=50,
    log_frequency=50,
)

# Our Pretext Task
### Setup

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Defined again here so this cell sets `device` itself.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Train/validation split point: the first NUM_TRAIN_IMAGES paths are used for
# training, every path after them for validation.
NUM_TRAIN_IMAGES = 46_000

# Drop references to any datasets left over from an earlier section before
# building new ones, so two sets of cached images need not be alive at once.
# NOTE(review): assumes the datasets hold their image cache in memory — confirm.
# Harmless on a fresh kernel: the names are simply bound to None first.
ds_train = ds_val = None

img_paths = sample_image_paths(frac=1.0)

# Fail fast instead of silently building an empty validation set, which is what
# the slice below yields when no more than NUM_TRAIN_IMAGES paths come back
# (e.g. after lowering `frac` for a quick experiment).
assert len(img_paths) > NUM_TRAIN_IMAGES, (
    f"Expected more than {NUM_TRAIN_IMAGES} image paths, got {len(img_paths)}; "
    "lower NUM_TRAIN_IMAGES, otherwise the validation split is empty."
)

# NOTE(review): the split is only stable between runs if sample_image_paths
# returns the paths in a deterministic order — confirm in `src.dataset`.
train_paths = img_paths[:NUM_TRAIN_IMAGES]
val_paths = img_paths[NUM_TRAIN_IMAGES:]

ds_train = OurPatchLocalizationDataset(image_paths=train_paths, cache_images=True)
ds_val = OurPatchLocalizationDataset(image_paths=val_paths, cache_images=True)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Our network with a ResNet-18 backbone, trained with the project's CustomLoss.
model = OurPretextNetwork(backbone="resnet18")
criterion = CustomLoss(alpha=1.0, symmetric=True)

### Training


In [None]:
# Launch our pretext-task run with the objects built in the setup cell above.
train_model(
    # run identity
    experiment_id="our_pretext_1",
    resume_from_checkpoint=False,
    # model and objective
    model=model,
    criterion=criterion,
    optimizer=None,  # NOTE(review): presumably train_model then picks its own optimizer — confirm
    device=device,
    # data
    ds_train=ds_train,
    ds_val=ds_val,
    batch_size=64,
    num_workers=4,
    # schedule and logging
    num_epochs=50,
    log_frequency=50,
)