In [None]:
# Vegas Training

# import libraries
import os, time, math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
import rasterio
import numpy as np
from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_absolute_error

In [None]:
# Pick the fastest available backend: Apple-Silicon MPS first (the original
# target of this notebook), then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Logging interval; not referenced by the cells shown in this notebook.
PRINT_EVERY = 50

# LoRA hyperparams
LORA_R = 32          # rank of the low-rank update
LORA_ALPHA = 64      # the LoRA branch is scaled by LORA_ALPHA / LORA_R
LORA_DROPOUT = 0.05  # dropout probability on the input of the LoRA branch

# timm model identifier and the square side length images are resized to.
# 518 = 37 * 14; presumably chosen to match the patch-14 grid implied by the
# model name -- TODO confirm.
MODEL_NAME = "vit_base_patch14_dinov2"
IMAGE_SIZE = 518

# ------------------- LoRA modules -------------------
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) residual branch.

    The forward pass returns ``base(x) + (alpha / r) * (dropout(x) @ A) @ B``
    where ``A`` has shape ``(in_features, r)`` and ``B`` has shape
    ``(r, out_features)``.  ``B`` is initialised to zero, so the LoRA branch
    contributes nothing until it has been trained.

    Args:
        base_linear: Linear layer to wrap; stored as ``self.base``.
        r: Rank of the update.  ``r <= 0`` disables the LoRA branch entirely.
        alpha: Scaling numerator; the branch is multiplied by ``alpha / r``.
        dropout: Dropout probability applied to the LoRA branch input.
    """

    def __init__(self, base_linear, r=LORA_R, alpha=LORA_ALPHA, dropout=LORA_DROPOUT):
        super().__init__()
        self.base = base_linear
        in_f, out_f = base_linear.in_features, base_linear.out_features
        self.r = r

        if r > 0:
            # Allocate the adapter on the same device/dtype as the wrapped
            # layer so wrapping an already-moved or non-float32 model works.
            weight = base_linear.weight

            # kaiming_uniform_ derives fan-in from dim 1.  Initialise A in
            # (r, in_features) layout so fan-in is in_features (as in the
            # reference LoRA implementation), then store it transposed to keep
            # the (in_features, r) layout used by forward() and checkpoints.
            a_init = weight.new_empty(r, in_f)
            nn.init.kaiming_uniform_(a_init, a=math.sqrt(5))
            self.lora_A = nn.Parameter(a_init.t().contiguous())

            # Zero B => the wrapped layer initially behaves like the base layer.
            self.lora_B = nn.Parameter(weight.new_zeros(r, out_f))
            self.scaling = alpha / r
            self.dropout = nn.Dropout(dropout)
        else:
            # LoRA disabled: keep the attributes so callers can inspect them.
            self.lora_A = None
            self.lora_B = None
            self.scaling = 1
            self.dropout = nn.Identity()

    def forward(self, x):
        """Return the base projection plus the scaled low-rank correction."""
        base_out = self.base(x)
        if self.r > 0:
            dx = self.dropout(x)
            lora_out = (dx @ self.lora_A) @ self.lora_B
            return base_out + self.scaling * lora_out
        return base_out


def replace_modules_by_path(model, substrs=("qkv", "proj")):
    """Wrap matching ``nn.Linear`` submodules of ``model`` in ``LoRALinear``.

    A submodule is wrapped when it is an ``nn.Linear`` and its fully qualified
    name contains any of ``substrs``.  The replacement is done in place on
    ``model``.

    Linears that already sit directly inside a ``LoRALinear`` (its ``base``)
    are skipped, so calling this function again on a patched model is a no-op
    instead of nesting adapters.

    Args:
        model: Root module to patch in place.
        substrs: Substrings tested against each submodule's dotted name.

    Returns:
        List of the dotted names that were replaced, in traversal order.
    """
    replaced = []
    # Snapshot the module list: we mutate the tree while walking it.
    for full_name, module in list(model.named_modules()):
        # The root itself (empty name) has no parent attribute to rebind.
        if not full_name or not isinstance(module, nn.Linear):
            continue
        if not any(s in full_name for s in substrs):
            continue

        parent_name, _, leaf = full_name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model

        # Already-wrapped layer: its inner ``base`` Linear also matches the
        # substring test (e.g. "attn.qkv.base"), but must not be wrapped again.
        if isinstance(parent, LoRALinear):
            continue

        setattr(parent, leaf, LoRALinear(module))
        replaced.append(full_name)
    return replaced

In [None]:
class ImgDataset(Dataset):
    """Image-regression dataset driven by a CSV of image paths and targets.

    Args:
        csv_path: CSV file read with ``pandas.read_csv``.
        image_col: Column holding the image file paths.
        target_col: Column holding the regression target (cast to float).
        image_size: Images are resized to ``(image_size, image_size)``.

    Attributes:
        df: The loaded DataFrame.
        paths: List of image paths taken from ``image_col``.
        targets: Float array of targets; read again on every ``__getitem__``,
            so reassigning it changes what the dataset yields.
        tf: torchvision pipeline: resize -> tensor -> channel normalisation
            (the mean/std values are the commonly used ImageNet statistics).
    """

    def __init__(self, csv_path, image_col="image_path", target_col="target",
                 image_size=IMAGE_SIZE):
        self.df = pd.read_csv(csv_path)
        self.paths = self.df[image_col].tolist()
        self.targets = self.df[target_col].astype(float).values

        self.tf = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize([0.485,0.456,0.406], [0.229,0.224,0.225])
        ])

    def load_image(self, path):
        """Load ``path`` as an RGB ``PIL.Image``.

        First tries rasterio (bands 1-3, clipped to [0, 255] and cast to
        uint8); if that fails for any ordinary error, falls back to
        ``PIL.Image.open(path).convert("RGB")``.
        """
        try:
            with rasterio.open(path) as src:
                arr = src.read([1,2,3])
                # rasterio returns (bands, rows, cols); PIL wants (rows, cols, bands).
                arr = np.transpose(arr, (1,2,0))
                # NOTE(review): values above 255 saturate here; confirm the
                # rasters are 8-bit, otherwise they need rescaling.
                arr = np.clip(arr, 0, 255).astype(np.uint8)
                return Image.fromarray(arr)
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit propagate instead of being swallowed by a bare except.
            return Image.open(path).convert("RGB")

    def __len__(self):
        """Number of rows (images) in the CSV."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` with a float32 scalar target."""
        img = self.load_image(self.paths[idx])
        img = self.tf(img)
        y = float(self.targets[idx])
        return img, torch.tensor(y, dtype=torch.float32)

In [None]:
print("Device:", DEVICE)
print("Loading:", MODEL_NAME)

# Instantiate the pretrained backbone; it is transferred to DEVICE at the end
# of this cell.
model = timm.create_model(MODEL_NAME, pretrained=True)

# Feature width reported by the model (fall back to embed_dim if absent).
if hasattr(model, "num_features"):
    embed_dim = model.num_features
else:
    embed_dim = model.embed_dim
print("Embedding dim:", embed_dim)

# Drop the built-in classifier and attach a single-output regression head.
model.reset_classifier(num_classes=0)
model.head = nn.Linear(embed_dim, 1)

# Freeze every parameter that exists at this point (backbone and new head).
model.requires_grad_(False)

print("Patching LoRA...")
replaced = replace_modules_by_path(model, ("qkv","proj"))
print("LoRA replaced modules:", len(replaced))

# Turn gradients back on for the LoRA matrices and the regression head only.
for param_name, param in model.named_parameters():
    if param_name.startswith("head") or "lora_" in param_name:
        param.requires_grad = True

# Transfer the patched model to the compute device.
model = model.to(DEVICE)
print("Model moved to device.")

In [None]:
CSV = "../data/output/vegas_merged_full_clean.csv"
SPLIT_SEED = 42  # fixed so Restart & Run All reproduces the same split

ds = ImgDataset(CSV)
N = len(ds)

# === Reproducible 80/20 train/validation split ===
rng = np.random.default_rng(SPLIT_SEED)
idx = rng.permutation(N)

train_idx = idx[: int(0.8*N)]
val_idx   = idx[int(0.8*N):]

# === Compute training mean/std for normalization ===
# Statistics come from the training rows only, so nothing about the
# validation targets leaks into the normalisation.
raw_targets = ds.targets
t_mean = raw_targets[train_idx].mean()
t_std  = raw_targets[train_idx].std()

print("Target mean:", t_mean, " Target std:", t_std)

# A zero/NaN std (constant targets or an empty training split) would silently
# turn every normalised target into NaN/inf; fail loudly instead.
if not np.isfinite(t_std) or t_std == 0:
    raise ValueError(
        f"Cannot normalize targets: training std is {t_std} (N={N})"
    )

# Store normalized targets into dataset (train and val share the same
# training-derived statistics).
ds.targets = (raw_targets - t_mean) / t_std

train_ds = torch.utils.data.Subset(ds, train_idx)
val_ds   = torch.utils.data.Subset(ds, val_idx)

# === Updated batch size ===
# A seeded generator makes the per-epoch shuffling order reproducible too.
train_loader = DataLoader(
    train_ds, batch_size=8, shuffle=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader   = DataLoader(val_ds, batch_size=8, shuffle=False)

print("Train:", len(train_ds), " Val:", len(val_ds))

In [None]:
def forward_reg(m, x):
    """Run ``m`` on ``x`` and make sure a 1-D result becomes a column.

    A 1-D output of shape ``(B,)`` is returned as ``(B, 1)``; outputs of any
    other rank are passed through unchanged.
    """
    prediction = m(x)
    return prediction.unsqueeze(1) if prediction.ndim == 1 else prediction

# === MAE loss ===  (L1 between predictions and targets)
criterion = nn.L1Loss()

# Hand the optimizer only those parameters that still require gradients.
trainable = list(filter(lambda prm: prm.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(trainable, lr=3e-5, weight_decay=1e-6)

print("Trainable params:", sum(w.numel() for w in trainable))

In [None]:
# LoRA training

EPOCHS = 50
best_val = float("inf")

# Per-epoch MAE curves, appended to once per epoch.
history_train_mae = []
history_val_mae = []

for ep in range(1, EPOCHS+1):
    # ---- optimisation pass over the training loader ----
    model.train()
    running_loss, seen = 0, 0
    progress = tqdm(train_loader, desc=f"Train ep{ep}/{EPOCHS}")

    for batch_x, batch_y in progress:
        batch_x = batch_x.to(DEVICE)
        batch_y = batch_y.to(DEVICE).unsqueeze(1)  # (B,) -> (B, 1) to match predictions

        optimizer.zero_grad()
        batch_pred = forward_reg(model, batch_x)
        batch_loss = criterion(batch_pred, batch_y)
        batch_loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # figure is a per-sample average even with a short final batch.
        loss_value = batch_loss.item()
        batch_size = batch_x.size(0)
        running_loss += loss_value * batch_size
        seen += batch_size
        progress.set_postfix(loss=loss_value)

    train_mae = running_loss / seen
    history_train_mae.append(train_mae)

    # ---- validation pass (no gradients, eval mode) ----
    model.eval()
    abs_err_sum, val_seen = 0, 0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE).unsqueeze(1)
            batch_pred = forward_reg(model, batch_x)
            abs_err_sum += (batch_pred - batch_y).abs().sum().item()
            val_seen += batch_x.size(0)

    val_mae = abs_err_sum / val_seen
    history_val_mae.append(val_mae)

    print(f"Epoch {ep}/{EPOCHS} Train MAE={train_mae:.4f}  Val MAE={val_mae:.4f}")

    # ---- keep only the LoRA + head weights of the best epoch so far ----
    if val_mae < best_val:
        best_val = val_mae
        os.makedirs("../models", exist_ok=True)

        adapter_state = {}
        for key, tensor in model.state_dict().items():
            if "lora_" in key or key.startswith("head"):
                adapter_state[key] = tensor.cpu()

        torch.save(adapter_state, "../models/dinov2_vegas_lora_best.pth")
        print("Saved new best checkpoint")

In [None]:
# final evaluation and history logging

best_ckpt_path = "../models/dinov2_vegas_lora_best.pth"
last_path = "../models/dinov2_vegas_lora_last.pth"
history_path = "../models/dinov2_vegas_training_history.csv"

os.makedirs("../models", exist_ok=True)

# Save last epoch LoRA separately (for reference).
# This MUST happen before the best checkpoint is loaded below: loading
# overwrites the LoRA/head weights in `model`, so saving afterwards would
# write the best weights under the "last" filename.
# NOTE: re-running this cell after the best weights have been loaded will
# overwrite this file with the best weights.
torch.save(
    {k: v.cpu() for k, v in model.state_dict().items()
     if ("lora_" in k or k.startswith("head"))},
    last_path
)
print("Saved final epoch LoRA params to:", last_path)

# Load BEST LoRA checkpoint before evaluation.
# strict=False because the checkpoint holds only LoRA + head tensors; the
# frozen backbone weights are intentionally absent from it.
state = torch.load(best_ckpt_path, map_location=DEVICE)
model.load_state_dict(state, strict=False)
model.eval()

print("Loaded best validation checkpoint:", best_ckpt_path)

# Final evaluation on the validation loader
all_preds, all_targets = [], []

with torch.no_grad():
    for imgs, ys in val_loader:
        imgs = imgs.to(DEVICE)
        ys = ys.to(DEVICE).unsqueeze(1)
        preds = forward_reg(model, imgs)

        all_preds.append(preds.cpu().numpy())
        all_targets.append(ys.cpu().numpy())

# Stack the per-batch (B, 1) arrays and flatten to 1-D vectors.
preds = np.vstack(all_preds).ravel()
targets = np.vstack(all_targets).ravel()

# INVERSE NORMALIZE predictions and targets back to the original target scale
preds = preds * t_std + t_mean
targets = targets * t_std + t_mean

final_r2 = r2_score(targets, preds)
final_mae = mean_absolute_error(targets, preds)

print("Final R2 (best model):", final_r2)
print("Final MAE (best model):", final_mae)

# Save training history (one row per completed epoch)
df_hist = pd.DataFrame({
    "epoch": list(range(1, len(history_train_mae) + 1)),
    "train_mae": history_train_mae,
    "val_mae": history_val_mae
})
df_hist.to_csv(history_path, index=False)
print("Saved training history to:", history_path)