In [1]:
# Render matplotlib figures (used by plot_prediction in the evaluation section) inline in the notebook.
%matplotlib inline 

In [2]:
from pokemon_image import *
from affine_transform import *
from model import *
from tqdm import tqdm
import pickle
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn import MSELoss

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class DataConfig:
    """Filesystem locations used by this notebook.

    Holds the directory of raw images, the annotation file describing them, and
    the pickled training set that is loaded for training.
    """

    # Directory that load_image_data reads the source images from.
    image_dir: str = "images"
    # Annotation file passed to load_image_data alongside image_dir.
    annotations_path: str = "images/annotations.json"
    # Pickle holding the pre-built training pairs (loaded with pickle.load).
    data_path: str = "full_dataset.pkl"

In [None]:
# One-off dataset generation (disabled). Re-enable to rebuild the pickle that the
# next cell loads. For each of TRANSFORMATIONS rounds, every annotated image is
# passed through affine_transform_pokemon_image, and the resulting
# (float image tensor, flattened float annotation tensor) pairs are collected.
# The un-transformed originals are appended at the end, and the whole list is
# pickled.
# NOTE(review): the literal paths below duplicate DataConfig.image_dir,
# DataConfig.annotations_path and DataConfig.data_path; switch to DataConfig if
# this cell is re-enabled so the two cannot drift apart.
# TRANSFORMATIONS = 500

# def affine_transform_data(data):
#     return [affine_transform_pokemon_image(i) for i in data]

# data = load_image_data("images", "images/annotations.json")
# full_dataset = []
# for _ in tqdm(range(TRANSFORMATIONS)):
#     transformed_data = affine_transform_data(data)
#     training_data = [
#         (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float())
#         for i in transformed_data
#     ]
#     full_dataset.extend(training_data)
# full_dataset += [
#     (i.resized_image.float(), torch.tensor(i.resized_annotation).flatten().float()) for i in data
# ]

# with open("full_dataset.pkl", "wb") as f:
#     pickle.dump(full_dataset, f)

In [4]:
# Load the pre-built training pairs, presumably produced by the commented-out
# generation cell above.
# Security: pickle.load can execute arbitrary code, so only load a dataset file you built yourself.
with open(DataConfig.data_path, mode="rb") as pickle_file:
    full_dataset = pickle.load(pickle_file)

In [5]:
class ModelConfig:
    """Model choice, training hyper-parameters and checkpoint settings.

    Training runs in two phases: a head-only phase (final_layer_*) and a
    full-network phase (full_model_*).
    """

    # Architecture name, passed to create_model, train_model and load_latest_checkpoint.
    model_name: str = "basic_hrnet"

    # Phase 1: only the final layer is trained.
    final_layer_epochs: int = 10
    final_layer_learning_rate: float = 0.01

    # Phase 2: the entire network is trained.
    full_model_epochs: int = 5
    full_model_learning_rate: float = 0.001

    # Settings shared by both phases.
    batch_size: int = 32
    # Forwarded to train_model; presumably a checkpoint interval in epochs (confirm).
    save_epochs: int = 5
    checkpoint_dir: str = "model_checkpoints"

In [6]:
# Wrap the loaded pairs in the project Dataset and serve them in shuffled mini-batches.
dataset = PokemonData(full_dataset)
dataloader = DataLoader(
    dataset,
    batch_size=ModelConfig.batch_size,
    shuffle=True,
)

In [7]:
# Mean-squared error between predicted and target annotation values (default 'mean' reduction).
loss_fn = MSELoss()
# Instantiate the architecture selected in ModelConfig.
model = create_model(ModelConfig.model_name)

In [None]:
# Phase 1 (currently disabled). Freezes every parameter whose name does not
# contain "classifier", then trains for ModelConfig.final_layer_epochs at the
# higher ModelConfig.final_layer_learning_rate.
# NOTE(review): this assumes the head's parameters are named "classifier..." in
# the model returned by create_model; confirm this, otherwise every parameter
# ends up frozen.
# # Phase 1: Train only the final layer
# for name, param in model.named_parameters():
#     if "classifier" not in name:
#         param.requires_grad = False

# optimizer = optim.Adam(model.parameters(), lr=ModelConfig.final_layer_learning_rate)

# train_model(
#     model=model,
#     dataloader=dataloader,
#     optimizer=optimizer,
#     loss_fn=loss_fn,
#     num_epochs=ModelConfig.final_layer_epochs,
#     is_final_layer_only=True,
#     save_epochs=ModelConfig.save_epochs,
#     checkpoint_dir=ModelConfig.checkpoint_dir,
#     model_name=ModelConfig.model_name
# )

In [8]:
# Phase 2: Train the entire model, resuming from the latest checkpoint if one exists.
# Unfreeze everything; Phase 1 may have frozen non-classifier parameters in this kernel.
for param in model.parameters():
    param.requires_grad = True

optimizer = optim.Adam(model.parameters(), lr=ModelConfig.full_model_learning_rate)
model, optimizer = load_latest_checkpoint(ModelConfig.checkpoint_dir, ModelConfig.model_name, model, optimizer)

# Restoring an optimizer state_dict also restores its param-group hyper-parameters.
# A checkpoint saved during Phase 1 would therefore silently bring back
# final_layer_learning_rate. Re-apply the Phase 2 rate explicitly. This is a
# no-op when the checkpoint already used it or no optimizer state was restored.
# NOTE(review): assumes load_latest_checkpoint restores optimizer state via load_state_dict; confirm.
for param_group in optimizer.param_groups:
    param_group["lr"] = ModelConfig.full_model_learning_rate

# Ensure any dropout/batch-norm layers are in training mode, even if an
# evaluation cell switched the model to eval mode earlier in this session.
model.train()

train_model(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=ModelConfig.full_model_epochs,
    is_final_layer_only=False,
    save_epochs=ModelConfig.save_epochs,
    checkpoint_dir=ModelConfig.checkpoint_dir,
    model_name=ModelConfig.model_name
)

  0%|          | 0/5 [00:00<?, ?it/s]

# Evaluation

In [None]:
# Reload the annotated source images (before any affine augmentation) for visual evaluation.
data = load_image_data(
    DataConfig.image_dir,
    DataConfig.annotations_path,
)

In [None]:
# Plot the model's predictions on every source image.
# NOTE(review): if full_dataset.pkl was built by the commented-out generation cell,
# these un-augmented images were also part of the training set. This is therefore
# a qualitative sanity check, not a held-out evaluation.
device = next(model.parameters()).device  # feed inputs on whatever device the model lives on
was_training = model.training
# Inference mode: batch-norm uses its running statistics instead of per-image
# statistics (and does not overwrite them), and dropout is disabled.
model.eval()
try:
    with torch.no_grad():  # inference only, so no autograd graph is needed
        for image_data in data:
            # Add a batch dimension of 1 and move the input to the model's device.
            model_input = image_data.resized_image.float().unsqueeze(0).to(device)
            pred_annotations = model(model_input).cpu()

            # Reshape the flat output into one 2-value point per row for plotting.
            image_data.plot_prediction(pred_annotations.reshape(-1, 2).tolist())
finally:
    # Restore the previous mode so a later training cell is not left in eval mode.
    model.train(was_training)