## Imports

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Optimizer
from ultralytics import YOLO
import torch.nn as nn
import torch
import os

# Setup and Configuration

In [None]:
# Config file paths shared by the rest of the notebook:
#   data_config  -> Ultralytics dataset description created in the dataset step
#                   (a path is easier to change than a hard-coded "data.yaml")
#   model_config -> model YAML defining the custom ResNet50 backbone
data_config, model_config = "../dataset_final/data.yaml", "yolov8_resnet50.yaml"

# Initialization

In [None]:
# Instantiate the YOLO model from the architecture YAML chosen in the config cell.
print("Building model from:", model_config)
model = YOLO(model_config)

# Training Cell

In [None]:
# Custom function to run during training
def on_fit_epoch_end(trainer):
    """Print the validation mAP@0.5 reached at the end of each epoch.

    Registered on Ultralytics' ``on_fit_epoch_end`` event, which fires after
    the epoch's train + validation passes, so ``trainer.metrics`` already holds
    this epoch's validation results. (``on_train_epoch_end`` fires before
    validation, so reading metrics there reports the previous epoch's values.)

    Args:
        trainer: The Ultralytics trainer handed to every callback. Only
            ``trainer.epoch`` (0-based) and ``trainer.metrics`` are read.
    """
    current_epoch = trainer.epoch + 1
    # Guard against metrics not being populated yet (e.g. None) so logging
    # can never abort the training run.
    metrics = trainer.metrics or {}
    # "(B)" = box metrics; this key is mAP at IoU 0.5 only (not mAP50-95).
    map50 = metrics.get("metrics/mAP50(B)", 0)

    print(f"Epoch {current_epoch}: mAP50 is {map50:.4f}")

    # You could add logic here, e.g.:
    # if map50 > 0.9: print("Excellent accuracy achieved!")


# Backward-compatible alias for the function's original name.
on_train_epoch_end = on_fit_epoch_end

# Attach the hook to the event that runs *after* validation.
model.add_callback("on_fit_epoch_end", on_fit_epoch_end)

# Pick the device explicitly: torch.cuda.get_device_name(0) raises on machines
# without CUDA, which would abort the cell before training even starts.
if torch.cuda.is_available():
    train_device = 0
    print(f"Starting training on GPU: {torch.cuda.get_device_name(0)}")
else:
    train_device = "cpu"
    print("WARNING: CUDA is not available - starting training on CPU (this will be very slow).")

# Train as normal (now with the custom hook running inside it)
results = model.train(
    data=data_config,  # path from the config cell: easier to change than a hard-coded "data.yaml"
    epochs=50,
    imgsz=640,
    batch=8,
    device=train_device,

    # Project name (creates the folder structure <project>/<name> under the working directory)
    project="Bone_Fracture_Project",
    name="resnet50_run",

    # Augmentation settings
    degrees=10,      # Rotate +/- 10 degrees (bones aren't always straight, but rarely upside down)
    translate=0.1,   # Shift image 10% (fracture might be off-center)
    scale=0.5,       # Zoom in/out (+/- 50%)
    fliplr=0.5,      # Flip left-right (a left hand looks like a right hand)
    flipud=0.0,      # NO up-down flip
    mosaic=1.0,      # Mix 4 images (standard YOLO booster, good for context)
    mixup=0.0,       # OFF: do not blend two bones together (confusing for medical diagnosis)
    hsv_h=0.010,     # Hue: keep VERY low (X-rays are grayscale)
    hsv_s=0.0,       # Saturation: 0 (no color in X-rays)
    hsv_v=0.4,       # Brightness: +/- 40% (simulates over/under-exposed X-rays; tougher than 30%)
)

## Manual Training Loop (plain PyTorch, work in progress)

In [None]:
# Hyper-parameters for the manual (plain PyTorch) training loop below.
batch_size = 8   # mirrors batch=8 used in the Ultralytics training run
shuffle = True   # reshuffle the samples every epoch
# TODO: replace with a real torch.utils.data.Dataset yielding (feature, target)
# pairs. It is named `dataset` because that is the name the DataLoader cell
# reads; a DataLoader cannot be built from the None placeholder.
dataset = None

In [None]:
# Wrap the dataset in a DataLoader that yields (optionally shuffled) mini-batches.
dataLoader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Only the `Optimizer` class is imported (not the `optim` module), so build the
# optimizer through the fully qualified `torch.optim` path. `parameters()` yields
# the plain tensors the optimizer expects.
# NOTE(review): assumes `model` exposes nn.Module.parameters() — confirm for the
# object currently bound to `model`. Swap in torch.optim.Adam(..., lr=1e-4) if preferred.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

num_epochs = 50  # mirrors epochs=50 used in the Ultralytics training run

# NOTE(review): MSE is a regression loss; a detector normally needs its own
# detection loss — confirm this matches the targets the dataset yields.
criterion = nn.MSELoss()

In [None]:
# Minimal manual training loop: zero grads -> forward -> loss -> backward -> step.
# NOTE(review): `model` is an Ultralytics YOLO object; calling it is the
# inference path (the post-training cell calls model(path) and gets Results
# objects back), not a differentiable forward pass. Confirm which module this
# loop is meant to train before relying on it.
# Do NOT call model.train() here: on a YOLO object that launches the built-in
# Ultralytics trainer (as in the training cell), not nn.Module's train mode.
for _ in range(num_epochs):
    # Unpack each batch directly; the original `for data in ...` overwrote the
    # notebook's dataset variable with the last batch.
    for feature, target in dataLoader:
        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()
        # Forward pass.
        pred = model(feature)
        # Compute the loss and back-propagate it.
        loss = criterion(pred, target)
        loss.backward()
        # Update the parameters.
        optimizer.step()
        

In [None]:
from ultralytics import YOLO

# Build the custom ResNet50-backbone model from its architecture YAML.
# NOTE: a .yaml path creates a new, randomly initialised network. This is NOT
# a COCO-pretrained model; pretrained weights come from loading a .pt file.
model = YOLO("yolov8_resnet50.yaml")

# Train using the settings stored in train_custom.yaml. Ultralytics' `cfg`
# argument loads a YAML of training arguments (data, epochs, imgsz, ...)
# instead of passing each one as a keyword argument.
# NOTE(review): confirm train_custom.yaml sets `data`; otherwise the default
# dataset for the task is used.
model.train(cfg="train_custom.yaml")

In [None]:
from ultralytics import YOLO

# 1. Check that TIMM is importable. The import itself is the check, so it has
#    to sit inside the try; a bare `import timm` would abort the cell with
#    ImportError before any diagnostic could be printed.
try:
    import timm
except ImportError as err:
    print(f"\n❌ ERROR: TIMM is not importable.\n{err}")
else:
    print(f"TIMM Version: {timm.__version__}")

# 2. Try to build the model.
# NOTE: this rebinds `model` to a freshly built (untrained) network,
# replacing any trained model object from earlier cells.
try:
    # Point to the model YAML file
    model = YOLO("yolov8_resnet50.yaml")
    # Print the model summary - this verifies that the layers were built
    model.info()
except Exception as e:
    # Deliberately best-effort (diagnostic cell): report instead of raising.
    # The failure is not necessarily in the YAML itself (e.g. a missing
    # dependency), so include the exception type as well.
    print(f"\n❌ ERROR: Could not build the model (check the YAML file first).\n{type(e).__name__}: {e}")
else:
    print("\n✅ SUCCESS: ResNet50 Backbone loaded successfully!")

## Post-Training Inference (Optional)

In [None]:
# Manual Inference Display
import os

import cv2
import matplotlib.pyplot as plt
from ultralytics import YOLO

# Prefer the trained weights over whatever `model` currently holds: earlier
# cells rebind `model` to a freshly built, untrained network. Ultralytics saves
# the best checkpoint under <project>/<name>/weights/best.pt, using the
# project/name passed to model.train in the training cell.
# NOTE(review): if that run folder already existed, Ultralytics may have saved
# to an incremented name (e.g. resnet50_run2) — check the path printed after training.
weights_path = os.path.join("Bone_Fracture_Project", "resnet50_run", "weights", "best.pt")
if os.path.exists(weights_path):
    inference_model = YOLO(weights_path)
else:
    print(f"WARNING: {weights_path} not found - falling back to the current `model` object.")
    inference_model = model

# Run one prediction (TODO: point this at a real test image)
results = inference_model("path/to/test_image.jpg")
boxes = results[0].boxes

# Indexing boxes[0] on an empty detection set would raise IndexError.
if len(boxes) == 0:
    print("No boxes detected in the test image.")
else:
    box = boxes[0]  # first detection
    # Raw box as (center-x, center-y, width, height)
    x, y, w, h = box.xywh[0].tolist()
    conf = box.conf[0].item()
    print(f"Math Check: Box Center at ({x:.1f}, {y:.1f}) with confidence {conf:.2%}")