In [None]:
import os
import sys
import torch
from torchinfo import summary
from torchvision.transforms import v2
from torch.utils.data import DataLoader

# A notebook has no real ``__file__``, so resolve the current working directory
# instead and put its parent on ``sys.path`` ahead of the ``modules.*`` imports.
module_dir = os.path.abspath(os.curdir)
sys.path.append(os.path.dirname(module_dir))


from modules.data import PhoneDataset
from modules.models import ConvolutionalLocator
from modules.training import EarlyStopping, ModelTrainer
from modules.transforms import (
    RandomHorizontalFlip,
    RandomVerticalFlip,
    RandomTranslation
)
from modules.utilities import (
    get_data,
    visualize_augmentations,
    train_test_split
)

In [None]:
# --- Experiment configuration ------------------------------------------------
# Data location and train/test split settings.
folder = os.path.join('..', 'find_phone_data')
seed = 0
test_size = 0.1
batch_size = 128

# Seed PyTorch's global RNG so that DataLoader shuffling and the initial weights
# of newly created layers are repeatable after "Restart Kernel -> Run All".
# NOTE(review): if the custom transforms draw from `random`/`numpy`, those
# generators need seeding as well -- confirm in modules.transforms.
torch.manual_seed(seed)

# NOTE(review): `patience` and `min_delta` are not referenced by any visible
# cell; EarlyStopping below is built with a literal 15 -- confirm which is meant.
patience = 5
min_delta = 0.0

# Training components. `optimizer` and `scheduler` are the classes themselves
# (not instances); they are handed to ModelTrainer as-is.
loss_fn = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam
early_stopping = EarlyStopping(15)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau
epochs = 50
lr = 5e-4

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Read the image paths and the location lookup from the data folder.
image_paths, locations = get_data(folder)

# Sanity check: show the first path and the location stored under its file name.
first_image_path = image_paths[0]
print(first_image_path)
print(locations[os.path.basename(first_image_path)])

In [None]:
# Split paths and locations into train/test parts using the configured
# `test_size` and `seed`.
(
    train_image_paths,
    test_image_paths,
    train_locations,
    test_locations,
) = train_test_split(image_paths, locations, test_size, seed)

# Sanity check: first training path and the location stored under its file name.
first_train_path = train_image_paths[0]
print(first_train_path)
print(train_locations[os.path.basename(first_train_path)])

In [None]:
# Augmentations for the training set, composed in this order:
# horizontal flip, vertical flip, translation (all with their default settings).
augmentation_steps = [
    RandomHorizontalFlip(),
    RandomVerticalFlip(),
    RandomTranslation(),
]
aware_transforms = v2.Compose(augmentation_steps)

In [None]:
# Training data receives the augmentation pipeline; test data is built without it.
train_dataset = PhoneDataset(train_image_paths, aware_transforms)
test_dataset = PhoneDataset(test_image_paths)

# Visual check of the augmentations on a randomly chosen training image.
visualize_augmentations(train_dataset, train_locations, random_img=True)

In [None]:
# Shuffle only the training batches; keep the test order fixed.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)

# Peek at a single training batch (if there is one) to confirm shapes and dtypes.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape} {X.dtype}")
    print(f"Shape of y: {y.shape} {y.dtype}")

In [None]:
# Bundle loaders, loss, optimizer class, early stopper, scheduler class and
# device into a trainer. An identical trainer is constructed again right
# before the first training run further down.
model_trainer = ModelTrainer(
    train_dataloader, test_dataloader,
    loss_fn, optimizer,
    early_stopping, scheduler,
    device,
)

In [None]:
# Input geometry used for the summary: channels, height, width.
c_in, h_in, w_in = 3, 326, 490

# Fetch a pretrained VGG11 from the torchvision hub (pinned to v0.10.0).
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=True)

# Per-layer input/output sizes and parameter counts for one full-size batch.
summary_columns = ["input_size", "output_size", "num_params"]
summary(
    model=model,
    input_size=(batch_size, c_in, h_in, w_in),
    col_names=summary_columns,
    col_width=20,
    row_settings=["var_names"],
)

In [None]:
# Freeze every existing parameter so only the new head below is trainable.
model.requires_grad_(False)

# Swap the last classifier layer for a fresh 2-output linear layer; a newly
# constructed Linear has trainable parameters by default.
old_head = model.classifier[-1]
model.classifier[6] = torch.nn.Linear(old_head.in_features, 2)

In [None]:
# Build the trainer from the configured loaders, loss, optimizer class,
# early stopper, scheduler class and device.
model_trainer = ModelTrainer(
    train_dataloader, test_dataloader,
    loss_fn, optimizer,
    early_stopping, scheduler,
    device,
)

# First training run, made while only the replacement head is trainable.
model_trainer(epochs, model, lr)

In [None]:
# Unfreeze the whole network for fine-tuning.
# The attribute autograd reads is `requires_grad`; the previous spelling
# (`grad_requires`) only attached an unrelated attribute and left the
# backbone frozen.
for param in model.parameters():
    param.requires_grad = True

# Second training run with every parameter trainable.
# NOTE(review): this reuses the same `model_trainer` (and therefore the same
# `early_stopping` object) and the same `lr` as the first run -- confirm that
# the early-stopping state is reset between runs.
model_trainer(epochs, model, lr)

In [None]:
def localization_accuracy(model, dataloader, device, sq_threshold=0.05):
    """Return the fraction of samples whose squared prediction error is small.

    Args:
        model: callable mapping a batch ``X`` to predictions comparable with ``y``.
        dataloader: iterable of ``(X, y)`` batches that exposes ``.dataset``
            with a length.
        device: device passed to ``Tensor.to`` for both inputs and targets.
        sq_threshold: exclusive upper bound on the squared error (summed over
            the last dimension) for a sample to count as correct.

    Returns:
        Number of correct samples divided by ``len(dataloader.dataset)``.
    """
    hits = 0
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            sq_err = torch.sum((y - pred) ** 2, dim=-1)
            # NOTE(review): the threshold is applied to the *squared* error; if
            # the intended tolerance is a distance of 0.05, compare against
            # 0.05 ** 2 instead -- confirm against the task definition.
            hits += (sq_err < sq_threshold).sum().item()
    return hits / len(dataloader.dataset)


model.eval()

# NOTE(review): `train_dataset` was built with the augmentation transforms, so
# this train score is presumably measured on augmented samples -- confirm.
# The results are stored under new names so the `test_size` split fraction
# from the configuration cell is not overwritten.
train_accuracy = localization_accuracy(model, train_dataloader, device)
test_accuracy = localization_accuracy(model, test_dataloader, device)

print(f'Train set accuracy: {train_accuracy}')
print(f'Test set accuracy: {test_accuracy}')