In [None]:
import sys
import os

# Resolve the project root as the parent of the *current working directory*
# (for a notebook that is normally the folder the notebook lives in) and put it
# at the front of the import path — presumably so the `pokemon` package imported
# below can be found; verify against the repository layout.
project_root = os.path.abspath('..')
# Guard the insert so re-running this cell does not keep stacking duplicate
# entries at the front of sys.path.
if sys.path[:1] != [project_root]:
    sys.path.insert(0, project_root)

In [None]:
%matplotlib inline

In [None]:
from pokemon import *
from tqdm import tqdm
import pickle
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import MSELoss
import random

In [None]:
class DataConfig:
    """Relative filesystem locations of the inputs and cached datasets used by this notebook."""

    # --- raw inputs -------------------------------------------------------
    # Folder holding the annotated training images.
    image_dir: str = "train_images"
    # Annotation file that accompanies the training images.
    annotations_path: str = "train_images/annotations.json"
    # Folder holding the images used in the evaluation cells.
    test_image_dir: str = "test_images"

    # --- pickled datasets -------------------------------------------------
    # Cached training samples (read back with pickle further down).
    training_data_path: str = "data/background_training_data.pkl"
    # Cached validation samples (read back with pickle further down).
    validation_data_path: str = "data/background_val_data.pkl"

In [None]:
# ---------------------------------------------------------------------------
# DISABLED: one-off build of the pickled datasets that the next cell loads.
# When uncommented, this cell:
#   1. loads the annotated images and keeps only the listed background images,
#   2. holds out 5 randomly sampled items for validation (the rest is training),
#   3. runs each split through augment_data TRANSFORMATIONS times on a thread pool,
#   4. appends the un-augmented items, and
#   5. pickles the results to DataConfig.training_data_path / validation_data_path.
# NOTE(review): random.sample is called without a seed being set in this cell,
# so the train/validation split will differ every time the data is regenerated.
# ---------------------------------------------------------------------------
# from concurrent.futures import ThreadPoolExecutor

# TRANSFORMATIONS = 200
# BACKGROUND_IMAGE_INDEXES = [23, 30, 32, 35, 38, 43, 44, 48, 51, 54, 64, 70, 74, 80, 85, 93, 109, 115, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]
# background_image_names = [f"Image {i}.jpg" for i in BACKGROUND_IMAGE_INDEXES] + [f"Image_{i}.jpg" for i in BACKGROUND_IMAGE_INDEXES]

# data = load_image_data(DataConfig.image_dir, DataConfig.annotations_path)
# data = [i for i in data if i.image_name in background_image_names]
# val_data = random.sample(data, 5)
# train_data = [item for item in data if item not in val_data]

# def affine_transform_data(data):
#     return [augment_data(i) for i in data]

# train_dataset = []
# # Create a thread pool
# with ThreadPoolExecutor() as executor:
#     for transformed_data in tqdm(executor.map(affine_transform_data, [train_data]*TRANSFORMATIONS), total=TRANSFORMATIONS):
#         training_data = [
#             (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float())
#             for i in transformed_data
#         ]
#         train_dataset.extend(training_data)

# train_dataset += [
#     (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float()) for i in train_data
# ]

# with open(DataConfig.training_data_path, "wb") as f:
#     pickle.dump(train_dataset, f)

# val_dataset = []
# # Create a thread pool
# with ThreadPoolExecutor() as executor:
#     for transformed_data in tqdm(executor.map(affine_transform_data, [val_data]*TRANSFORMATIONS), total=TRANSFORMATIONS):
#         training_data = [
#             (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float())
#             for i in transformed_data
#         ]
#         val_dataset.extend(training_data)

# val_dataset += [
#     (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float()) for i in val_data
# ]

# with open(DataConfig.validation_data_path, "wb") as f:
#     pickle.dump(val_dataset, f)

In [None]:
# Read the cached training and validation samples back from disk (in that order).
# NOTE: pickle.load can execute arbitrary code embedded in the file — only load
# caches that were generated locally by this project.
loaded_splits = []
for pickle_path in (DataConfig.training_data_path, DataConfig.validation_data_path):
    with open(pickle_path, "rb") as handle:
        loaded_splits.append(pickle.load(handle))

train_dataset, val_dataset = loaded_splits

In [None]:
class ModelConfig:
    """Training hyper-parameters and checkpoint settings for the "background_hrnet" run."""

    # --- identity / checkpointing -----------------------------------------
    # Name handed to create_model and used when loading/saving checkpoints.
    model_name: str = "background_hrnet"
    # Directory passed to the checkpoint helpers.
    checkpoint_dir: str = "model_checkpoints"
    # Forwarded to train_model as `save_epochs` (presumably the checkpoint interval — confirm).
    save_epochs: int = 5

    # --- phase 1: final layer only ----------------------------------------
    final_layer_epochs: int = 10
    final_layer_learning_rate: float = 1e-2

    # --- phase 2: full model ----------------------------------------------
    full_model_epochs: int = 5
    full_model_learning_rate: float = 1e-3

    # --- shared optimisation settings -------------------------------------
    batch_size: int = 32
    weight_decay: float = 1e-3

In [None]:
# Wrap the unpickled training samples in the project's PokemonData dataset.
# This cell rebinds `train_dataset`, so the isinstance guard keeps it idempotent:
# re-running the cell will not wrap an already-wrapped dataset a second time.
if not isinstance(train_dataset, PokemonData):
    train_dataset = PokemonData(train_dataset)

# Shuffle so the training batches are drawn in a different order each epoch.
train_dataloader = DataLoader(train_dataset, batch_size=ModelConfig.batch_size, shuffle=True)

In [None]:
# Wrap the unpickled validation samples in the project's PokemonData dataset.
# This cell rebinds `val_dataset`, so the isinstance guard keeps it idempotent:
# re-running the cell will not wrap an already-wrapped dataset a second time.
if not isinstance(val_dataset, PokemonData):
    val_dataset = PokemonData(val_dataset)

# No shuffling for validation: a fixed batch order keeps evaluation runs comparable.
validation_dataloader = DataLoader(val_dataset, batch_size=ModelConfig.batch_size, shuffle=False)

In [None]:
# Objective used by the training cells: mean squared error.
loss_fn = MSELoss()

# Build the network selected by ModelConfig.model_name.
model = create_model(ModelConfig.model_name)

In [None]:
# ---------------------------------------------------------------------------
# DISABLED: Phase 1 of training, kept for reference. When uncommented, it
# freezes every parameter whose name does not contain "classifier", builds an
# Adam optimizer with ModelConfig.final_layer_learning_rate, resumes from the
# latest checkpoint and trains for ModelConfig.final_layer_epochs epochs.
# ---------------------------------------------------------------------------
# # Phase 1: Train only the final layer
# for name, param in model.named_parameters():
#     if "classifier" not in name:
#         param.requires_grad = False

# optimizer = optim.Adam(model.parameters(), lr=ModelConfig.final_layer_learning_rate, weight_decay=ModelConfig.weight_decay)
# model, optimizer = load_latest_checkpoint(ModelConfig.checkpoint_dir, ModelConfig.model_name, model)

# train_model(
#     model=model,
#     train_dataloader=train_dataloader,
#     val_dataloader=validation_dataloader,
#     optimizer=optimizer,
#     loss_fn=loss_fn,
#     num_epochs=ModelConfig.final_layer_epochs,
#     is_final_layer_only=True,
#     save_epochs=ModelConfig.save_epochs,
#     checkpoint_dir=ModelConfig.checkpoint_dir,
#     model_name=ModelConfig.model_name
# )

In [None]:
# Phase 2: Train the entire model
# Make every parameter trainable (the optional Phase 1 cell freezes all
# parameters whose name does not contain "classifier").
for param in model.parameters():
    param.requires_grad = True

optimizer = optim.Adam(model.parameters(), lr=ModelConfig.full_model_learning_rate, weight_decay=ModelConfig.weight_decay)
# Resume from the most recent checkpoint; the helper hands back the model and optimizer to use.
model, optimizer = load_latest_checkpoint(ModelConfig.checkpoint_dir, ModelConfig.model_name, model, optimizer)
# Re-apply the configured hyper-parameters after the load. NOTE(review): this
# assumes load_latest_checkpoint restores the optimizer state (confirm); if so,
# the param groups carry the checkpoint's values — e.g. the larger Phase 1
# learning rate — so both lr and weight_decay are reset to this phase's settings.
for group in optimizer.param_groups:
    group['lr'] = ModelConfig.full_model_learning_rate
    group['weight_decay'] = ModelConfig.weight_decay

train_model(
    model=model,
    train_dataloader=train_dataloader,
    val_dataloader=validation_dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=ModelConfig.full_model_epochs,
    is_final_layer_only=False,
    save_epochs=ModelConfig.save_epochs,
    checkpoint_dir=ModelConfig.checkpoint_dir,
    model_name=ModelConfig.model_name
)

# Evaluation

In [None]:
# Rebuild the architecture and restore the most recent checkpoint for inference.
model = create_model(ModelConfig.model_name)
restored = load_latest_checkpoint(ModelConfig.checkpoint_dir, ModelConfig.model_name, model)
# The training cell unpacks this helper's result as (model, optimizer). Binding
# that pair straight to `model` would make model.eval() fail on a tuple, so take
# the first element when a sequence comes back and accept a bare model otherwise.
model = restored[0] if isinstance(restored, (tuple, list)) else restored
# Switch to inference mode and run on the CPU; the trailing semicolons suppress
# the module repr that the notebook would otherwise print.
model.eval();
model.to("cpu");

In [None]:
# Load the evaluation images from DataConfig.test_image_dir. The result is
# iterated by the next cell, which calls predict_annotations(model) on each item.
data = load_test_image_data(DataConfig.test_image_dir)

In [None]:
# Run the restored model over every loaded test item.
for test_item in data:
    test_item.predict_annotations(model)