In [None]:
# Mount Google Drive so the repository stored there is reachable under
# /content/drive. `google.colab` only exists inside Colab, so when the notebook
# is run locally the import fails and mounting is skipped instead of crashing
# the first cell.
try:
    from google.colab import drive
except ImportError:
    drive = None  # not running on Colab

if drive is not None:
    drive.mount('/content/drive')

In [None]:
# Switch to the repo's notebooks/ directory on Drive. The sys.path setup and the
# relative "../data" and "../models" paths used below are resolved against the
# current working directory.
# NOTE(review): this absolute path only exists on Colab with Drive mounted; when
# running locally, start the kernel from the repo's notebooks/ directory instead.
%cd '/content/drive/MyDrive/CV-crowd-flow-estimation-/notebooks'
!pwd

In [1]:
import sys
import os
# The notebook is expected to run from <repo>/notebooks, so the repository root
# is the parent of the current working directory. It is prepended to sys.path
# (only if missing, so re-running the cell does not add duplicates), which lets
# the `src` package imports below resolve.
project_root = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.data_loader import ShanghaiTechDataModule
from src.train import train_model
from src.utils import plot_density_predictions, plot_all_decoder_predictions
from src.models import get_model

In [2]:
# 1) Prepare the data
# ShanghaiTech Part A. Images are resized to 384x384 and target density maps are
# produced at 192x192 (half the input resolution).
# NOTE(review): `validation_split=0.1` presumably holds out 10% of the training
# images, `sigma=5` is presumably the Gaussian kernel width used to build the
# density maps, and `return_count=False` presumably yields density maps rather
# than raw counts. Confirm all three in src.data_loader.
shanghaitech_config = {
    "data_folder": "../data/ShanghaiTech",
    "part": "part_A",
    "validation_split": 0.1,
    "sigma": 5,
    "return_count": False,
    "batch_size": 8,
    "num_workers": 4,
    "input_size": (384, 384),
    "density_map_size": (192, 192),
}
data_module = ShanghaiTechDataModule(**shanghaitech_config)

Using Apple MPS


### ResNet-50 density map

In [None]:
# Train the "resnet50" density-map model with pretrained=True for 15 epochs and
# write its weights to the .pth file below. The result is bound to `model`, which
# the next cell plots (later training cells rebind the same name).
resnet_hparams = {"model_name": "resnet50", "epochs": 15, "lr": 1e-4, "pretrained": True}
model = train_model(
    data_module,
    save_path="../models/pth/part_A_resnet50_15.pth",
    **resnet_hparams,
)


In [None]:
# Plot density-map predictions of the current `model` (the ResNet-50 trained just
# above, in top-to-bottom order) on the test dataloader.
# The device is taken from where the model's parameters actually live instead of
# being hard-coded to "mps". On Apple silicon this still evaluates to "mps", but
# on Colab/CUDA or CPU-only hosts it no longer points at a missing backend.
plot_density_predictions(
    model,
    data_module.test_dataloader(),
    device=next(model.parameters()).device.type,
)

### VGG19-BN density map

In [None]:
# 2) Train the model (10 epochs)
# Train the "vgg19_bn" density-map model with pretrained=True and save its
# weights. The earlier comment claimed 15 epochs, but this cell has always
# trained for `epochs=10`.
# Note: this rebinds `model`, replacing the ResNet-50 from the previous section.
model = train_model(
    data_module,
    model_name="vgg19_bn",
    epochs=10,
    lr=1e-4,
    pretrained=True,
    save_path="../models/pth/part_A_vgg19.pth"
)


In [None]:
# Plot density-map predictions of the current `model` (the VGG19-BN trained just
# above, in top-to-bottom order) on the test dataloader.
# The device is derived from the model's parameters rather than hard-coded to
# "mps", so the cell works on whichever backend the model was trained on
# (it is still "mps" on Apple silicon).
plot_density_predictions(
    model,
    data_module.test_dataloader(),
    device=next(model.parameters()).device.type,
)

### Basic U-Net

In [None]:
# Train the "unet" model with a smaller learning rate (1e-5) and more epochs (30)
# than the backbone models above, then save its weights.
# NOTE(review): confirm in src.models what pretrained=True means for "unet".
unet_hparams = {"model_name": "unet", "epochs": 30, "lr": 1e-5, "pretrained": True}
model = train_model(
    data_module,
    save_path="../models/pth/part_A_unet.pth",
    **unet_hparams,
)

In [None]:
# Plot density-map predictions of the current `model` (the U-Net trained just
# above, in top-to-bottom order) on the test dataloader.
# The device comes from the model's own parameters instead of a hard-coded
# "mps", so this also runs on CUDA (e.g. Colab) or CPU-only machines.
plot_density_predictions(
    model,
    data_module.test_dataloader(),
    device=next(model.parameters()).device.type,
)

In [None]:
# Count every parameter element of whichever network is currently bound to
# `model` (trainable and frozen alike) and report it raw and in millions.
total_params = 0
for weight in model.parameters():
    total_params += weight.numel()
print(f"Total parameters: {total_params} ({total_params / 1e6:.2f}M)")

In [None]:
# Reload the U-Net weights saved by the training cell above, so the cells below
# can run without retraining in this session.
# get_model's result is indexed with [0]; presumably the model is its first
# element. TODO confirm against src.models.get_model.
# NOTE(review): the device is hard-coded to "mps" (Apple silicon only).
model = get_model(
    "unet",
    cpt="../models/pth/part_A_unet.pth",
    device="mps",
)[0]

In [None]:
# Test-split dataloader, passed to plot_all_decoder_predictions in the next cell.
# NOTE(review): `datloader` looks like a typo for `dataloader`; rename it together
# with its use in the next cell.
datloader = data_module.test_dataloader()

In [None]:
# Visualize the outputs returned by plot_all_decoder_predictions for the loaded
# U-Net on the test dataloader.
# NOTE(review): `i=2` presumably selects which test sample/batch to show. TODO
# confirm in src.utils.
# The device is read from the model's parameters instead of being hard-coded to
# "mps", so it always matches where the model was loaded ("mps" on Apple silicon).
plot_all_decoder_predictions(
    model,
    datloader,
    device=next(model.parameters()).device.type,
    i=2,
)