In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
from src.dataset import OurPatchLocalizationDataset, OriginalPatchLocalizationDataset, sample_img_paths
from src.models import OriginalPretextNetwork, OurPretextNetwork
from src.loss import CustomLoss
from src.train_pretext import train_model

# Original Pretext Task 

### Setup
Note: To make the data loading process more efficient, we generate __samples_per_image__ random samples at once whenever we load an image. Creating __x__ samples at once from an image is much more efficient than iterating through the entire dataset __x__ times while creating only one sample per image. The "true" (effective) batch size is thus __samples_per_image__ * __batch_size__; with the settings used below (__samples_per_image__ = 1, __batch_size__ = 64), it is simply 64.

In [7]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Split the image paths by position: the first `num_train_paths` entries feed
# the training dataset, everything after them feeds the validation dataset.
num_train_paths = 46000
img_paths = sample_img_paths(frac=1.0)

ds_train = OriginalPatchLocalizationDataset(
    img_paths=img_paths[:num_train_paths],
    samples_per_image=1,
)
ds_val = OriginalPatchLocalizationDataset(
    img_paths=img_paths[num_train_paths:],
    samples_per_image=1,
)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Both loaders share the batch size and worker count; only the training
# loader shuffles.
loader_options = {"batch_size": 64, "num_workers": 4}
train_loader = torch.utils.data.DataLoader(ds_train, shuffle=True, **loader_options)
val_loader = torch.utils.data.DataLoader(ds_val, shuffle=False, **loader_options)

# Network and loss for the original pretext task, both moved to `device`.
model = OriginalPretextNetwork(backbone="resnet18").to(device)
criterion = nn.CrossEntropyLoss().to(device)

Device: cuda
Number of training images: 	 46000
Number of validation images: 	 3100


### Training

In [None]:
# Start the training run for the original pretext task, using the objects
# built in the setup cell. `optimizer=None` is passed through unchanged --
# presumably train_model then falls back to its own default optimizer; confirm
# in src.train_pretext.
train_model(
    experiment_id="original_pretext_1",
    # What is trained and how it is scored.
    model=model,
    criterion=criterion,
    optimizer=None,
    device=device,
    # Data sources.
    train_loader=train_loader,
    val_loader=val_loader,
    # Run length and logging settings.
    num_epochs=50,
    log_frequency=50,
)

# Our Pretext Task
### Setup

In [5]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

# Split the image paths by position: the first `num_train_paths` entries feed
# the training dataset, everything after them feeds the validation dataset.
num_train_paths = 46000
img_paths = sample_img_paths(frac=1.0)

ds_train = OurPatchLocalizationDataset(
    img_paths=img_paths[:num_train_paths],
    samples_per_image=1,
)
ds_val = OurPatchLocalizationDataset(
    img_paths=img_paths[num_train_paths:],
    samples_per_image=1,
)

print(f"Number of training images: \t {len(ds_train)}")
print(f"Number of validation images: \t {len(ds_val)}")

# Both loaders share the batch size and worker count; only the training
# loader shuffles.
loader_options = {"batch_size": 64, "num_workers": 4}
train_loader = torch.utils.data.DataLoader(ds_train, shuffle=True, **loader_options)
val_loader = torch.utils.data.DataLoader(ds_val, shuffle=False, **loader_options)

# Network and custom loss for our pretext task, both moved to `device`.
model = OurPretextNetwork(backbone="resnet18").to(device)
criterion = CustomLoss(alpha=1.0, symmetric=True).to(device)

Device: cuda
Number of training images: 	 46000
Number of validation images: 	 3100


### Training


In [None]:
# Start the training run for our pretext task, using the objects built in the
# setup cell. `optimizer=None` is passed through unchanged -- presumably
# train_model then falls back to its own default optimizer; confirm in
# src.train_pretext.
train_model(
    experiment_id="our_pretext_1",
    # What is trained and how it is scored.
    model=model,
    criterion=criterion,
    optimizer=None,
    device=device,
    # Data sources.
    train_loader=train_loader,
    val_loader=val_loader,
    # Run length and logging settings.
    num_epochs=50,
    log_frequency=50,
)