# In [None]:
"""
End-to-end pipeline to train a model that predicts car prices from images + tabular metadata.
Save predictions (id, true_price, pred_price) to a CSV.

How to use:
    python car_price_pipeline.py \
        --parquet data/cars.parquet \
        --images_dir data/images/ \
        --output_dir outputs/ \
        --image_ext .jpg \
        --epochs 10

The script expects the parquet file to contain at least an `id` column and a `price` column
(or you can provide a test set where `price` is missing / NaN).
If images are named by id (e.g. <id>.jpg) the script will find them in images_dir.
If your parquet already contains an `image_path` column, it will use that.

The model: ResNet backbone (from torchvision; pass `--pretrained` to load pretrained weights, otherwise it is randomly initialised) -> embedding -> concat tabular features -> MLP head -> single regression output.
Tabular preprocessing uses sklearn ColumnTransformer (StandardScaler for numeric, OneHotEncoder for categorical).

Outputs:
 - best model saved to output_dir/best_model.pth
 - last model to output_dir/last_model.pth
 - predictions to output_dir/predictions.csv
 - training log to output_dir/train_log.csv

"""

import os
import argparse
from pathlib import Path
import math
import random
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision.models as models

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline

# ----------------------------- Utilities ---------------------------------

def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) random generators.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)


# ----------------------------- Dataset ----------------------------------
class CarPriceDataset(Dataset):
    """Image-only dataset yielding one ``{'id', 'image'[, 'price']}`` dict per row of `df`.

    Args:
        df: DataFrame with at least the `id_col` column; its index is reset.
        images_dir: Directory searched for ``<id><image_ext>`` files.
        image_ext: Extension appended to the id when building the file name.
        tabular_transform: Stored for backward compatibility only. It is NOT
            applied per item (doing so would transform the whole frame for
            every sample); use `CarPriceDatasetFast` with a precomputed array
            when tabular features are needed.
        image_transform: Optional callable applied to the loaded PIL image.
        id_col: Name of the identifier column.
        price_col: Name of the target column (may be absent or NaN).
    """

    def __init__(self, df, images_dir, image_ext='.jpg', tabular_transform=None, image_transform=None, id_col='id', price_col='price'):
        self.df = df.reset_index(drop=True)
        self.images_dir = Path(images_dir)
        self.image_ext = image_ext
        self.tabular_transform = tabular_transform
        self.image_transform = image_transform
        self.id_col = id_col
        self.price_col = price_col

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _to_python(value):
        """Convert a NumPy numeric/bool scalar to the matching Python scalar."""
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            return value.item()
        return value

    def _row_fields(self, idx):
        """Read the id / price / image_path cells of row `idx` column by column.

        `df.iloc[idx]` would coerce the whole row to one dtype, turning an
        integer id into a float (``7`` -> ``7.0``) whenever the other columns
        are numeric, which breaks the ``<id><ext>`` file lookup.
        """
        columns = self.df.columns
        wanted = (self.id_col, self.price_col, 'image_path')
        return {name: self._to_python(self.df[name].iat[idx]) for name in wanted if name in columns}

    def _get_image_path(self, row):
        """Return the image path for `row` (any mapping-like row: dict or Series)."""
        # prefer explicit image path if present
        if 'image_path' in row and pd.notna(row['image_path']):
            return Path(row['image_path'])
        # otherwise look up by id
        img_name = f"{row[self.id_col]}{self.image_ext}"
        return self.images_dir / img_name

    def __getitem__(self, idx):
        """Load the image for row `idx` and return the sample dict.

        Returns:
            dict with ``'id'`` and ``'image'``; ``'price'`` (float) is added
            only when the price column exists and the value is not NaN.
        """
        import warnings
        from PIL import Image

        row = self._row_fields(idx)
        img_path = self._get_image_path(row)
        try:
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as exc:
            # Best effort: keep the batch going with a black placeholder, but say so.
            warnings.warn(
                f"Could not load image {img_path} ({exc!r}); using a black placeholder.",
                RuntimeWarning,
            )
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.image_transform:
            img = self.image_transform(img)
        sample = {
            'id': row[self.id_col],
            'image': img
        }
        # price if exists
        if self.price_col in row and pd.notna(row[self.price_col]):
            sample['price'] = float(row[self.price_col])
        return sample

# We'll implement a more efficient dataset that takes precomputed tabular numpy array
class CarPriceDatasetFast(Dataset):
    """Dataset pairing each image with a precomputed tabular feature row.

    Args:
        df: DataFrame with at least the `id_col` column; its index is reset.
        tabular_array: Array indexed by row position, aligned with `df`, or
            None (an empty float32 vector is then returned per sample).
        images_dir: Directory searched for ``<id><image_ext>`` files.
        image_ext: Extension appended to the id when building the file name.
        image_transform: Optional callable applied to the loaded PIL image.
        id_col: Name of the identifier column.
        price_col: Name of the target column (may be absent or NaN).
    """

    def __init__(self, df, tabular_array, images_dir, image_ext='.jpg', image_transform=None, id_col='id', price_col='price'):
        self.df = df.reset_index(drop=True)
        self.tabular = tabular_array
        self.images_dir = Path(images_dir)
        self.image_ext = image_ext
        self.image_transform = image_transform
        self.id_col = id_col
        self.price_col = price_col

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _to_python(value):
        """Convert a NumPy numeric/bool scalar to the matching Python scalar."""
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            return value.item()
        return value

    def _row_fields(self, idx):
        """Read the id / price / image_path cells of row `idx` column by column.

        `df.iloc[idx]` would coerce the whole row to one dtype, turning an
        integer id into a float (``7`` -> ``7.0``) whenever the other columns
        are numeric; the ``<id><ext>`` lookup then misses every file and the
        dataset silently serves black placeholders.
        """
        columns = self.df.columns
        wanted = (self.id_col, self.price_col, 'image_path')
        return {name: self._to_python(self.df[name].iat[idx]) for name in wanted if name in columns}

    def _get_image_path(self, row):
        """Return the image path for `row` (any mapping-like row: dict or Series)."""
        if 'image_path' in row and pd.notna(row['image_path']):
            return Path(row['image_path'])
        img_name = f"{row[self.id_col]}{self.image_ext}"
        return self.images_dir / img_name

    def __getitem__(self, idx):
        """Return ``{'id', 'image', 'tabular'[, 'price']}`` for row `idx`.

        ``'tabular'`` is the float32 feature row; ``'price'`` (float) is added
        only when the price column exists and the value is not NaN.
        """
        import warnings
        from PIL import Image

        row = self._row_fields(idx)
        img_path = self._get_image_path(row)
        try:
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except Exception as exc:
            # Best effort: keep the batch going with a black placeholder, but say so.
            warnings.warn(
                f"Could not load image {img_path} ({exc!r}); using a black placeholder.",
                RuntimeWarning,
            )
            img = Image.new('RGB', (224, 224), (0, 0, 0))
        if self.image_transform:
            img = self.image_transform(img)
        if self.tabular is not None:
            tab = self.tabular[idx].astype(np.float32)
        else:
            tab = np.array([], dtype=np.float32)
        sample = {
            'id': row[self.id_col],
            'image': img,
            'tabular': tab
        }
        if self.price_col in row and pd.notna(row[self.price_col]):
            sample['price'] = float(row[self.price_col])
        return sample

# ----------------------------- Model ------------------------------------
class ImageTabularRegressor(nn.Module):
    """ResNet image encoder fused with tabular features, ending in a scalar regression head.

    Args:
        backbone_name: Name of a torchvision ``resnet*`` constructor.
        pretrained: Whether to load the torchvision pretrained weights.
        tab_dim: Number of tabular features concatenated to the image embedding.
        embed_dim: Size of the projected image embedding.
        head_hidden: Sizes of the hidden layers of the MLP head.
        dropout: Dropout probability applied after each hidden layer.

    Raises:
        NotImplementedError: If `backbone_name` does not start with ``'resnet'``.
    """

    def __init__(self, backbone_name='resnet18', pretrained=True, tab_dim=0, embed_dim=512, head_hidden=(256, 64), dropout=0.2):
        super().__init__()
        # load backbone
        if not backbone_name.startswith('resnet'):
            raise NotImplementedError('Only resnet* backbones implemented in this example')
        factory = getattr(models, backbone_name)
        try:
            # Newer torchvision selects weights via `weights=`; IMAGENET1K_V1 is
            # the checkpoint the legacy `pretrained=True` flag maps to.
            model = factory(weights='IMAGENET1K_V1' if pretrained else None)
        except TypeError:
            # Older torchvision only understands the legacy keyword.
            model = factory(pretrained=pretrained)
        # keep everything except the final fc classifier
        self.backbone = nn.Sequential(*list(model.children())[:-1])
        self.backbone_embed_dim = model.fc.in_features

        # projection from backbone output to embed_dim
        self.img_proj = nn.Linear(self.backbone_embed_dim, embed_dim)
        self.tab_dim = tab_dim
        # head: (embed_dim + tab_dim) -> hidden layers -> 1
        layers = []
        prev = embed_dim + tab_dim
        for h in head_hidden:
            layers.extend([nn.Linear(prev, h), nn.ReLU(), nn.Dropout(dropout)])
            prev = h
        layers.append(nn.Linear(prev, 1))
        self.head = nn.Sequential(*layers)

    def forward(self, images, tabular=None):
        """Predict one value per sample.

        Args:
            images: Image batch fed to the backbone.
            tabular: Optional per-sample feature matrix concatenated to the
                image embedding along dim 1; skipped when None.

        Returns:
            1-D tensor with one prediction per sample.
        """
        x = self.backbone(images)
        x = torch.flatten(x, 1)  # collapse everything after the batch dim
        x = self.img_proj(x)
        if tabular is not None:
            x = torch.cat([x, tabular], dim=1)
        out = self.head(x).squeeze(1)
        return out

# ----------------------------- Training ---------------------------------

def get_transforms(train=True, size=224):
    """Build the torchvision preprocessing pipeline.

    Args:
        train: When truthy, random horizontal flip and (with p=0.5) colour
            jitter are inserted after the resize.
        size: Images are resized to ``(size, size)``.

    Returns:
        A ``T.Compose`` ending in ``ToTensor`` and channel normalisation.
    """
    steps = [T.Resize((size, size))]
    if train:
        # augmentation only for training
        steps.append(T.RandomHorizontalFlip())
        steps.append(T.RandomApply([T.ColorJitter(0.2, 0.2, 0.2, 0.02)], p=0.5))
    steps.append(T.ToTensor())
    steps.append(T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return T.Compose(steps)


def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one optimisation pass over `loader` and return the mean loss per sample.

    Args:
        model: Module called as ``model(images, tabular_or_None)``.
        loader: Iterable of batch dicts with ``'image'`` and optionally
            ``'tabular'`` / ``'price'``; batches without ``'price'`` are skipped.
        criterion: Loss called as ``criterion(preds, prices)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scaler: Optional ``GradScaler``; when given, the forward pass runs
            under CUDA autocast and the scaler drives backward/step.

    Returns:
        Sample-weighted mean loss, or 0.0 when no batch carried prices.
    """
    model.train()
    running_loss = 0.0
    count = 0
    for batch in loader:
        prices = batch.get('price', None)
        if prices is None:
            continue
        images = batch['image'].to(device)
        tabs = batch['tabular'].to(device) if 'tabular' in batch else None
        # Python floats collate to float64; cast so the target matches the
        # float32 predictions instead of risking a dtype error in the loss.
        prices = prices.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        if scaler is not None:
            with torch.cuda.amp.autocast():
                preds = model(images, tabs)
                loss = criterion(preds, prices)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            preds = model(images, tabs)
            loss = criterion(preds, prices)
            loss.backward()
            optimizer.step()
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        count += batch_size
    return running_loss / max(1, count)


def validate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradients.

    Args:
        model: Module called as ``model(images, tabular_or_None)``.
        loader: Iterable of batch dicts with ``'id'``, ``'image'`` and
            optionally ``'tabular'`` / ``'price'``.
        criterion: Unused; kept so the signature mirrors ``train_one_epoch``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(rmse, ids, trues, preds)``: RMSE over batches that carry prices
        (None when there are none), ids as plain Python values, the true
        prices of priced batches, and the predictions for every sample.
    """
    model.eval()
    squared_error = 0.0
    count = 0
    preds_list = []
    ids = []
    trues = []
    with torch.no_grad():
        for batch in loader:
            images = batch['image'].to(device)
            tabs = batch['tabular'].to(device) if 'tabular' in batch else None
            prices = batch.get('price', None)
            batch_ids = batch['id']
            # Numeric ids are collated into a tensor; unwrap it so callers get
            # plain Python values instead of 0-d tensors.
            ids.extend(batch_ids.tolist() if torch.is_tensor(batch_ids) else batch_ids)
            pred = model(images, tabs)
            preds_list.extend(pred.detach().cpu().numpy().tolist())
            if prices is not None:
                prices = prices.to(device)
                trues.extend(prices.cpu().numpy().tolist())
                squared_error += ((pred - prices) ** 2).sum().item()
                count += images.size(0)
    mse = squared_error / max(1, count)
    rmse = math.sqrt(mse) if count > 0 else None
    return rmse, ids, trues, preds_list

# ----------------------------- Main pipeline -----------------------------

def _dense_one_hot_encoder():
    """Build a dense-output OneHotEncoder on both old and new scikit-learn."""
    try:
        # Newer scikit-learn names the flag `sparse_output`.
        return OneHotEncoder(handle_unknown='ignore', sparse_output=False)
    except TypeError:
        # Older scikit-learn only knows `sparse`.
        return OneHotEncoder(handle_unknown='ignore', sparse=False)


def _clean_tabular(frame, numeric_cols, cat_cols, numeric_fill):
    """Return the tabular columns of `frame` with missing values filled per column kind.

    Numeric columns become float64 with NaN replaced by `numeric_fill`
    (per-column values); categorical columns become strings with missing
    values replaced by ''.
    """
    parts = []
    if numeric_cols:
        parts.append(frame[numeric_cols].astype('float64').fillna(numeric_fill))
    if cat_cols:
        cats = frame[cat_cols].astype(object)
        parts.append(cats.where(cats.notna(), '').astype(str))
    return pd.concat(parts, axis=1)


def _as_python_list(values):
    """Turn a collated batch field (tensor or sequence) into a list of Python values."""
    return values.tolist() if torch.is_tensor(values) else list(values)


def main(args):
    """Train the image+tabular regressor and write checkpoints, log and predictions.

    Steps: read the parquet, split priced rows into train/val (rows without a
    price form the test set), fit the tabular preprocessing on the training
    rows, train for ``args.epochs`` epochs while checkpointing, then predict
    on val + test and write ``predictions.csv``.

    Args:
        args: Parsed command-line namespace (see the argparse setup below).

    Raises:
        RuntimeError: If the parquet lacks the ``args.id_col`` column.
    """
    seed_everything(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    # 1) read parquet
    df = pd.read_parquet(args.parquet)
    if args.id_col not in df.columns:
        raise RuntimeError(f'Parquet must contain an `{args.id_col}` column')

    # detect price column
    price_col = args.price_col
    if price_col not in df.columns:
        print(f"Price column `{price_col}` not found in parquet. Proceeding as test-only (no targets).")

    # split train/val if prices exist
    has_price = price_col in df.columns and df[price_col].notna().any()
    if has_price:
        train_df, val_df = train_test_split(df[df[price_col].notna()], test_size=args.val_size, random_state=args.seed)
        # If there are rows without price, treat them as test
        test_df = df[df[price_col].isna()]
    else:
        train_df = pd.DataFrame(columns=df.columns)
        val_df = pd.DataFrame(columns=df.columns)
        test_df = df.copy()

    # 2) prepare tabular features
    # choose tabular columns = all except id, price, image_path
    ignore_cols = {args.id_col, price_col, 'image_path'}
    tab_cols = [c for c in df.columns if c not in ignore_cols]
    print('Tabular columns used:', tab_cols)

    numeric_cols = [c for c in tab_cols if pd.api.types.is_numeric_dtype(df[c])]
    cat_cols = [c for c in tab_cols if c not in numeric_cols]

    transformers = []
    if numeric_cols:
        transformers.append(('num', StandardScaler(), numeric_cols))
    if cat_cols:
        transformers.append(('cat', _dense_one_hot_encoder(), cat_cols))

    if transformers:
        col_transformer = ColumnTransformer(transformers)
        # Fit on the training rows only (all rows when there is nothing to train on).
        fit_df = train_df if len(train_df) > 0 else df
        numeric_fill = None
        if numeric_cols:
            # Median imputation; an all-NaN column falls back to 0.
            numeric_fill = fit_df[numeric_cols].astype('float64').median().fillna(0.0)
        fit_clean = _clean_tabular(fit_df, numeric_cols, cat_cols, numeric_fill)
        col_transformer.fit(fit_clean)
        # Probe one row for the output width so it is known even for empty splits.
        tab_dim = col_transformer.transform(fit_clean.head(1)).shape[1]

        def make_tab_array(df_part):
            # sklearn refuses to transform zero rows, and an empty test split
            # is the normal case when every row has a price.
            if len(df_part) == 0:
                return np.zeros((0, tab_dim), dtype=np.float32)
            arr = col_transformer.transform(_clean_tabular(df_part, numeric_cols, cat_cols, numeric_fill))
            if hasattr(arr, 'toarray'):
                arr = arr.toarray()
            return np.asarray(arr, dtype=np.float32)

        train_tab = make_tab_array(train_df)
        val_tab = make_tab_array(val_df)
        test_tab = make_tab_array(test_df)
    else:
        # no tabular features
        train_tab = np.zeros((len(train_df), 0), dtype=np.float32)
        val_tab = np.zeros((len(val_df), 0), dtype=np.float32)
        test_tab = np.zeros((len(test_df), 0), dtype=np.float32)
        tab_dim = 0

    # 3) Transforms and Datasets
    train_transform = get_transforms(train=True, size=args.img_size)
    val_transform = get_transforms(train=False, size=args.img_size)

    def make_dataset(df_part, tab_part, transform):
        return CarPriceDatasetFast(df_part, tab_part, args.images_dir, image_ext=args.image_ext, image_transform=transform, id_col=args.id_col, price_col=price_col)

    train_ds = make_dataset(train_df, train_tab, train_transform)
    val_ds = make_dataset(val_df, val_tab, val_transform)
    test_ds = make_dataset(test_df, test_tab, val_transform)

    device = torch.device('cuda' if torch.cuda.is_available() and not args.use_cpu else 'cpu')
    pin_memory = device.type == 'cuda'  # pinning only helps host-to-GPU copies

    def make_loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=shuffle, num_workers=args.num_workers, pin_memory=pin_memory)

    # A shuffling loader over an empty dataset can fail at construction time.
    train_loader = make_loader(train_ds, shuffle=True) if len(train_ds) > 0 else None
    val_loader = make_loader(val_ds, shuffle=False)
    test_loader = make_loader(test_ds, shuffle=False)

    # 4) Model
    model = ImageTabularRegressor(backbone_name=args.backbone, pretrained=args.pretrained, tab_dim=tab_dim, embed_dim=args.embed_dim, head_hidden=args.head_hidden, dropout=args.dropout)
    model.to(device)

    # 5) Loss, optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # `verbose=` is deprecated on PyTorch schedulers; LR drops are printed manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    scaler = torch.cuda.amp.GradScaler() if (device.type == 'cuda' and args.use_amp) else None

    # 6) Training loop
    best_rmse = float('inf')
    history = []
    if train_loader is None:
        print('No rows with a price to train on: skipping training, predictions come from an untrained model.')
        n_epochs = 0
    else:
        n_epochs = args.epochs
    for epoch in range(1, n_epochs + 1):
        start = time.time()
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler=scaler)
        val_rmse, _, _, _ = validate(model, val_loader, criterion, device)
        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_rmse if val_rmse is not None else train_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after != lr_before:
            print(f"Reducing learning rate to {lr_after:.4e}")
        elapsed = time.time() - start
        print(f"Epoch {epoch}/{args.epochs} — train_loss: {train_loss:.4f} — val_rmse: {val_rmse} — time: {elapsed:.1f}s")
        history.append({'epoch': epoch, 'train_loss': train_loss, 'val_rmse': val_rmse})
        # checkpoint
        torch.save(model.state_dict(), os.path.join(args.output_dir, 'last_model.pth'))
        if val_rmse is not None and val_rmse < best_rmse:
            best_rmse = val_rmse
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'best_model.pth'))

    # save training log
    pd.DataFrame(history).to_csv(os.path.join(args.output_dir, 'train_log.csv'), index=False)

    # 7) Inference on val + test with the final-epoch weights
    model.eval()

    def run_inference(loader):
        ids = []
        preds = []
        trues = []
        with torch.no_grad():
            for batch in loader:
                images = batch['image'].to(device)
                tabs = batch['tabular'].to(device) if 'tabular' in batch else None
                out = model(images, tabs)
                preds.extend(out.detach().float().cpu().numpy().tolist())
                # Collated ids/prices may be tensors; store plain values so the
                # CSV holds numbers rather than "tensor(...)" strings.
                batch_ids = _as_python_list(batch['id'])
                ids.extend(batch_ids)
                if 'price' in batch:
                    trues.extend(_as_python_list(batch['price']))
                else:
                    trues.extend([None] * len(batch_ids))
        return ids, trues, preds

    all_ids = []
    all_trues = []
    all_preds = []
    for dataset, loader in ((val_ds, val_loader), (test_ds, test_loader)):
        if len(dataset) > 0:
            ids, trues, preds = run_inference(loader)
            all_ids.extend(ids)
            all_trues.extend(trues)
            all_preds.extend(preds)

    out_df = pd.DataFrame({'id': all_ids, 'true_price': all_trues, 'pred_price': all_preds})
    out_csv = os.path.join(args.output_dir, 'predictions.csv')
    out_df.to_csv(out_csv, index=False)
    print('Saved predictions to', out_csv)


if __name__ == '__main__':
    # Command-line entry point: every option is declared once in this table
    # (in help order), registered on the parser, then handed to main().
    cli_options = [
        ('--parquet', dict(type=str, required=True)),
        ('--images_dir', dict(type=str, required=True)),
        ('--output_dir', dict(type=str, default='outputs')),
        ('--image_ext', dict(type=str, default='.jpg')),
        ('--id_col', dict(type=str, default='id')),
        ('--price_col', dict(type=str, default='price')),
        ('--val_size', dict(type=float, default=0.15)),
        ('--img_size', dict(type=int, default=224)),
        ('--batch_size', dict(type=int, default=32)),
        ('--epochs', dict(type=int, default=10)),
        ('--lr', dict(type=float, default=1e-4)),
        ('--weight_decay', dict(type=float, default=1e-5)),
        ('--num_workers', dict(type=int, default=4)),
        ('--seed', dict(type=int, default=42)),
        ('--backbone', dict(type=str, default='resnet18')),
        ('--pretrained', dict(action='store_true')),
        ('--embed_dim', dict(type=int, default=512)),
        ('--head_hidden', dict(nargs='+', type=int, default=[256, 64])),
        ('--dropout', dict(type=float, default=0.2)),
        ('--use_amp', dict(action='store_true')),
        ('--use_cpu', dict(action='store_true')),
    ]
    parser = argparse.ArgumentParser()
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()
    main(args)
