In [None]:
from pokemon_image import *
from affine_transform import *
from tqdm import tqdm
import pickle

In [None]:
# TRANSFORMATIONS = 50

# def affine_transform_data(data):
#     return [affine_transform_pokemon_image(i) for i in data]

# data = load_image_data("images", "images/annotations.json")
# full_dataset = []
# for _ in tqdm(range(TRANSFORMATIONS)):
#     transformed_data = affine_transform_data(data)
#     training_data = [
#         (i.resized_image, torch.tensor(i.resized_annotation).flatten())
#         for i in transformed_data
#     ]
#     full_dataset.extend(training_data)
# full_dataset += [
#     (i.resized_image, torch.tensor(i.resized_annotation).flatten()) for i in data
# ]

# with open("full_dataset.pkl", "wb") as f:
#     pickle.dump(full_dataset, f)

In [None]:
# Path of the pickled training set (written by the commented-out generation
# cell above; presumably a list of (image, flattened annotation) pairs).
pickle_file = "full_dataset.pkl"

# NOTE(review): pickle.load can execute arbitrary code — only load a file you
# generated yourself, never one obtained from an untrusted source.
with open(pickle_file, mode="rb") as pickle_fh:
    full_dataset = pickle.load(pickle_fh)

In [None]:
from torch.utils.data import DataLoader

# Wrap the loaded samples in the project Dataset and batch them, reshuffling
# the sample order on every pass.
dataset = PokemonData(full_dataset)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=32,  # NOTE(review): magic number — consider moving into ModelConfig
)

In [None]:
import torch.nn as nn
import timm
import torch.optim as optim
from torch.nn import MSELoss
from datetime import datetime


def create_model(num_outputs=32, device="cuda"):
    """Build a pretrained timm ``hrnet_w18`` with a fresh linear output head.

    The model's ``classifier`` is replaced by a new ``nn.Linear`` that produces
    ``num_outputs`` values. The attribute name ``classifier`` is kept because
    the phase-1 training cell freezes every parameter whose name does not
    contain "classifier".

    Args:
        num_outputs: Width of the output layer. Defaults to 32, the original
            hard-coded value (presumably the flattened annotation size —
            confirm against the dataset targets).
        device: Device the model is moved to. Defaults to ``"cuda"``.

    Returns:
        The model, moved to ``device``.
    """
    model = timm.create_model("hrnet_w18", pretrained=True)
    # Size the head from the backbone's pooled feature width (2048 for
    # hrnet_w18) instead of hard-coding it, so the head always matches.
    model.classifier = nn.Linear(model.num_features, num_outputs)
    return model.to(device)


def save_checkpoint(model, optimizer, epoch, checkpoint_dir, is_final_layer):
    """Serialize model and optimizer state to a timestamped checkpoint file.

    The file is written to
    ``<checkpoint_dir>/checkpoint_<final_layer|full_model>_epoch_<epoch>_<YYYY-mm-dd_HH-MM-SS>.pt``
    and holds a dict with keys ``"epoch"``, ``"model_state_dict"`` and
    ``"optimizer_state_dict"``. ``checkpoint_dir`` is created if missing.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Epoch value stored in the file and embedded in its name.
            ``train_model`` passes its 0-based loop index here.
        checkpoint_dir: Directory to write the checkpoint into.
        is_final_layer: Selects the ``final_layer`` vs ``full_model`` name tag.

    Returns:
        The path of the written checkpoint file.
    """
    # Neither `os` nor `torch` is imported explicitly in this notebook
    # (they would otherwise only arrive via the star imports), so import
    # them here.
    import os

    import torch

    # torch.save does not create parent directories.
    os.makedirs(checkpoint_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    layer_status = "final_layer" if is_final_layer else "full_model"
    checkpoint_filename = f"checkpoint_{layer_status}_epoch_{epoch}_{timestamp}.pt"
    checkpoint_path = os.path.join(checkpoint_dir, checkpoint_filename)

    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        checkpoint_path,
    )
    print(f"Saved checkpoint to {checkpoint_path}")
    return checkpoint_path


def train_model(
    model,
    dataloader,
    optimizer,
    loss_fn,
    num_epochs,
    is_final_layer_only,
    save_epochs,
    checkpoint_dir,
    device=None,
):
    """Run a supervised training loop, printing per-epoch loss and checkpointing.

    Each epoch iterates over ``dataloader`` once and, per batch, does
    zero_grad / forward / loss / backward / optimizer step. The mean batch
    loss is printed after each epoch.

    Args:
        model: Module to train; put into train mode before the loop.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer over (some of) ``model``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar tensor.
        num_epochs: Number of passes over ``dataloader``.
        is_final_layer_only: Forwarded to ``save_checkpoint`` to tag filenames.
        save_epochs: Save a checkpoint after every ``save_epochs``-th epoch.
            A value <= 0 disables periodic checkpoints.
        checkpoint_dir: Directory forwarded to ``save_checkpoint``.
        device: Where batches are moved. Defaults to the device of the model's
            first parameter, or ``"cuda"`` if the model has no parameters.

    Returns:
        List of per-epoch average losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    if device is None:
        # Follow the model instead of assuming a GPU is present.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cuda"

    model.train()
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0  # reset running loss for each epoch
        num_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()  # add up batch loss
            num_batches += 1

        # Counting batches (rather than len(dataloader)) also works for
        # iterables without __len__, and avoids dividing by zero.
        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; cannot compute average loss")

        avg_loss = running_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}, Loss: {avg_loss}")

        # Save a checkpoint every `save_epochs` epochs (epoch is 0-based here).
        if save_epochs > 0 and (epoch + 1) % save_epochs == 0:
            save_checkpoint(
                model, optimizer, epoch, checkpoint_dir, is_final_layer_only
            )

    return epoch_losses

In [None]:
class ModelConfig:
    """Hyperparameters for the two-phase training schedule.

    Phase 1 trains only the classifier head; phase 2 fine-tunes every parameter.
    """

    # Checkpointing (shared by both phases)
    checkpoint_dir = "model_checkpoints"
    save_epochs = 5  # save after every N-th epoch of a phase

    # Phase 1: classifier head only
    final_layer_epochs = 10
    final_layer_learning_rate = 0.01

    # Phase 2: whole network
    full_model_epochs = 5
    full_model_learning_rate = 0.001

In [None]:
# Mean-squared-error loss shared by both training phases, and the network
# returned by create_model().
loss_fn = MSELoss()
model = create_model()

In [None]:
# Phase 1: Train only the final layer (the `classifier` head); freeze the rest.
for name, param in model.named_parameters():
    param.requires_grad = "classifier" in name

head_params = [param for param in model.parameters() if param.requires_grad]
# Fail fast with a clear message if no parameter name matched "classifier";
# otherwise the optimizer/backward pass would fail with a less obvious error.
assert head_params, "no parameter names contain 'classifier'; nothing to train in phase 1"

# Only hand the trainable head parameters to the optimizer.
optimizer = optim.Adam(head_params, lr=ModelConfig.final_layer_learning_rate)
train_model(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=ModelConfig.final_layer_epochs,
    is_final_layer_only=True,
    save_epochs=ModelConfig.save_epochs,
    checkpoint_dir=ModelConfig.checkpoint_dir,
);

In [None]:
# Phase 2: unfreeze every parameter and fine-tune the whole network with the
# full-model learning rate.
model.requires_grad_(True)

optimizer = optim.Adam(
    model.parameters(),
    lr=ModelConfig.full_model_learning_rate,
)
train_model(
    model=model,
    dataloader=dataloader,
    optimizer=optimizer,
    loss_fn=loss_fn,
    num_epochs=ModelConfig.full_model_epochs,
    is_final_layer_only=False,
    save_epochs=ModelConfig.save_epochs,
    checkpoint_dir=ModelConfig.checkpoint_dir,
);