In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import convnext_tiny, ConvNeXt_Tiny_Weights
from PIL import Image
import csv
import random
import numpy as np


In [None]:
seed = 42
# Seed every RNG source the pipeline may draw from: Python's `random`, NumPy's
# legacy global RNG, the torch CPU generator and the generators of all CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn  # keep the loop variable out of the notebook namespace
# Pin cuDNN to deterministic kernels and disable its autotuner, trading speed
# for run-to-run repeatability.
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


In [None]:
def remove_outliers_quantile(df, column_names, low_q=0.01, high_q=0.99):
    """Drop rows whose values lie outside per-column quantile bounds.

    Columns are filtered sequentially: the bounds for each column are computed
    on the rows that survived the filters of the preceding columns. Bounds are
    inclusive, and rows with a missing value in a filtered column are dropped.

    Args:
        df: Input DataFrame (not modified).
        column_names: Columns to filter on, in order.
        low_q: Lower quantile (inclusive bound).
        high_q: Upper quantile (inclusive bound).

    Returns:
        A new DataFrame with a fresh 0..n-1 index.
    """
    filtered = df
    for column in column_names:
        values = filtered[column]
        bounds = values.quantile([low_q, high_q])
        lower, upper = bounds.iloc[0], bounds.iloc[1]
        in_range = (values >= lower) & (values <= upper)
        filtered = filtered[in_range]
    return filtered.reset_index(drop=True)

# === Dataset Class ===
class LocationDataset(Dataset):
    """Map-style dataset yielding ``(image, tensor([latitude, longitude]))`` pairs.

    Rows are read from a CSV containing ``filename``, ``latitude`` and
    ``longitude`` columns; each ``filename`` is resolved relative to ``image_dir``.

    Args:
        csv_file: Path to the label CSV.
        image_dir: Directory holding the image files named in ``filename``.
        transform: Optional callable applied to the RGB PIL image
            (e.g. a torchvision ``Compose``).
        filter_outliers: If True, rows are filtered with
            ``remove_outliers_quantile(df, ['latitude', 'longitude'])``.
    """

    def __init__(self, csv_file, image_dir, transform=None, filter_outliers=False):
        df = pd.read_csv(csv_file)

        if filter_outliers:
            df = remove_outliers_quantile(df, ['latitude', 'longitude'])

        # Reset so that positional index i (used via .iloc in __getitem__)
        # always corresponds to the i-th remaining row.
        self.data = df.reset_index(drop=True)
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        """Number of labelled rows (after optional outlier filtering)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load row ``idx``: the (optionally transformed) RGB image and its float32 target."""
        row = self.data.iloc[idx]
        image_path = os.path.join(self.image_dir, row['filename'])

        # Image.open is lazy and keeps the file handle open; the context manager
        # closes it deterministically. convert() returns an independent image,
        # so it remains valid after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        target = torch.tensor([row['latitude'], row['longitude']], dtype=torch.float32)
        return image, target


In [None]:
def compute_mean_std(dataset, batch_size=64):
    """Per-channel mean and standard deviation pooled over every pixel of ``dataset``.

    The statistics are computed exactly from running sums and sums of squares.
    Averaging per-image standard deviations instead would ignore the variance
    *between* images and underestimate the spread used by ``Normalize``.

    Args:
        dataset: Map-style dataset yielding ``(image_tensor, target)``. Images are
            assumed to be channel-first (``(C, ...)``) and equally sized so the
            default collate can stack them (e.g. after ``Resize`` + ``ToTensor``).
        batch_size: DataLoader batch size; affects speed/memory only.

    Returns:
        ``(mean, std)`` as Python lists with one float per channel
        (population std over all pixels).

    Raises:
        ValueError: If the dataset yields no images.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    channel_sum, channel_sq_sum, n_pixels = 0.0, 0.0, 0
    with torch.no_grad():
        for imgs, _ in loader:
            # (B, C, ...) -> (B, C, P); float64 keeps the large sums accurate.
            flat = imgs.reshape(imgs.size(0), imgs.size(1), -1).double()
            channel_sum = channel_sum + flat.sum(dim=(0, 2))
            channel_sq_sum = channel_sq_sum + (flat * flat).sum(dim=(0, 2))
            n_pixels += flat.size(0) * flat.size(2)
    if n_pixels == 0:
        raise ValueError("compute_mean_std: dataset yielded no images")
    mean = channel_sum / n_pixels
    # E[x^2] - E[x]^2; clamp guards against tiny negative values from round-off.
    var = (channel_sq_sum / n_pixels - mean * mean).clamp_min(0.0)
    return mean.tolist(), var.sqrt().tolist()

# === Training Function ===
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    NOTE(review): ``train_one_epoch`` is defined again in a later cell; once that
    cell runs it replaces this definition, so this copy is redundant and should
    be removed to keep a single source of truth.

    Returns:
        The sum of per-batch ``criterion`` values divided by ``len(loader)``.
    """
    model.train()
    batch_losses = []
    for batch_imgs, batch_targets in loader:
        batch_imgs = batch_imgs.to(device)
        batch_targets = batch_targets.to(device)
        loss = criterion(model(batch_imgs), batch_targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return sum(batch_losses) / len(loader)


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the per-sample mean loss.

    The epoch loss weights each batch's loss by its batch size, so a short final
    batch does not count as much as a full one (averaging raw batch means would
    over-weight it). This assumes ``criterion`` averages over its inputs
    (e.g. ``nn.MSELoss()`` with the default ``reduction='mean'``).

    Args:
        model: Model to train (switched to train mode).
        loader: Iterable of ``(inputs, targets)`` tensor batches.
        criterion: Loss function ``criterion(preds, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        The sample-weighted mean training loss, or ``nan`` if ``loader`` is empty.
    """
    model.train()
    weighted_loss, n_samples = 0.0, 0
    for imgs, targets in loader:
        imgs, targets = imgs.to(device), targets.to(device)
        preds = model(imgs)
        loss = criterion(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = imgs.size(0)
        weighted_loss += loss.item() * batch_size
        n_samples += batch_size
    # nan (rather than ZeroDivisionError) signals "no data" without crashing the loop.
    return weighted_loss / n_samples if n_samples else float('nan')

# === Validation Function (with exclusion) ===
def validate(model, loader, criterion, device, exclude_indices=()):
    """Mean validation loss over the samples that are not excluded.

    ``exclude_indices`` are positions in ``loader``'s iteration order, so
    ``loader`` must not shuffle. Excluded samples are removed from the loss
    entirely. Zeroing them while keeping them in the denominator would bias the
    loss downward. Batches whose samples are all excluded are skipped.

    The per-batch loss is weighted by the number of kept samples, which assumes
    ``criterion`` averages over its inputs (e.g. ``nn.MSELoss()`` default).

    Args:
        model: Model to evaluate (switched to eval mode).
        loader: Iterable of ``(inputs, targets)`` tensor batches, in a fixed order.
        criterion: Loss function ``criterion(preds, targets)``.
        device: Device the batches are moved to.
        exclude_indices: Collection of sample positions to ignore; ``None``/empty for none.

    Returns:
        The sample-weighted mean loss over kept samples, or ``nan`` if none were kept.
    """
    exclude = set(exclude_indices or ())
    model.eval()
    weighted_loss, n_kept, start = 0.0, 0, 0
    with torch.no_grad():
        for imgs, targets in loader:
            batch_size = imgs.size(0)
            keep_flags = [start + i not in exclude for i in range(batch_size)]
            start += batch_size
            n_batch_kept = sum(keep_flags)
            if n_batch_kept == 0:
                continue

            imgs, targets = imgs.to(device), targets.to(device)
            preds = model(imgs)
            keep = torch.tensor(keep_flags, dtype=torch.bool, device=preds.device)
            batch_loss = criterion(preds[keep], targets[keep]).item()
            weighted_loss += batch_loss * n_batch_kept
            n_kept += n_batch_kept
    return weighted_loss / n_kept if n_kept else float('nan')


In [None]:
def save_predictions(model, loader, device, exclude_indices, output_file='Latlong_predictions.csv', total_samples=738):
    """Predict on ``loader`` and write ``id, Latitude, Longitude`` rows to a CSV file.

    ``id`` is the sample's position in ``loader``'s iteration order, so ``loader``
    must not shuffle. Samples whose id is in ``exclude_indices`` are written as
    ``0, 0``. If fewer than ``total_samples`` ids were produced, the remaining
    ids are padded with ``0, 0`` rows.

    Args:
        model: Model whose output columns 0 and 1 are used as latitude and longitude.
        loader: Iterable of ``(inputs, _)`` batches; the second element is ignored.
        device: Device the inputs are moved to.
        exclude_indices: Collection of ids to blank out as ``0, 0``; ``None``/empty for none.
        output_file: Destination CSV path.
        total_samples: Minimum number of data rows to write (padding with ``0, 0``).
    """
    exclude = set(exclude_indices or ())
    predictions = []
    global_idx = 0  # running sample id across all batches

    model.eval()
    with torch.no_grad():
        for imgs, _ in loader:
            preds = model(imgs.to(device))
            for pred in preds:
                if global_idx in exclude:
                    lat, lon = 0, 0
                else:
                    # NOTE(review): predictions are rounded to integers; confirm the
                    # submission format really expects integer coordinates.
                    lat = round(pred[0].item())
                    lon = round(pred[1].item())
                predictions.append([global_idx, lat, lon])
                global_idx += 1

    # Pad missing ids up to total_samples (no-op if we already have enough rows).
    predictions.extend([i, 0, 0] for i in range(global_idx, total_samples))

    with open(output_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['id', 'Latitude', 'Longitude'])
        writer.writerows(predictions)

    print(f"Predictions (with excluded IDs marked as 0,0) saved to '{output_file}'")



In [None]:
if __name__ == "__main__":
    import copy

    # --- Paths & run configuration ---
    TRAIN_CSV = '/kaggle/input/csv-files/labels_train.csv'
    TRAIN_IMG_DIR = '/kaggle/input/images-train/images_train'
    # NOTE(review): validation labels are read from the *training* CSV while the
    # images come from images_val; confirm this pairing (a separate validation
    # label file may be intended).
    VAL_CSV = '/kaggle/input/csv-files/labels_train.csv'
    VAL_IMG_DIR = '/kaggle/input/images-val/images_val'
    NUM_EPOCHS = 1

    # Pre-transform used only to estimate normalization statistics on the training images
    pre_transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    tmp_dataset = LocationDataset(TRAIN_CSV, TRAIN_IMG_DIR, transform=pre_transform, filter_outliers=True)
    mean, std = compute_mean_std(tmp_dataset)
    print(f"Computed Mean: {mean}, Std: {std}")

    # Training transforms: augmentation + normalization with the computed stats
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Validation transforms (no augmentation)
    val_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Datasets & loaders. Training images come from the same directory the
    # normalization statistics were computed on.
    train_dataset = LocationDataset(TRAIN_CSV, TRAIN_IMG_DIR, transform=train_transform, filter_outliers=True)
    val_dataset = LocationDataset(VAL_CSV, VAL_IMG_DIR, transform=val_transform)
    train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)
    # No shuffling: exclude_indices and prediction ids are positions in this order.
    val_loader = DataLoader(val_dataset, batch_size=20)

    # Model setup: pretrained ConvNeXt-Tiny with a 2-output regression head
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights = ConvNeXt_Tiny_Weights.DEFAULT
    model = convnext_tiny(weights=weights)
    model.classifier[2] = nn.Linear(model.classifier[2].in_features, 2)
    model = model.to(device)

    # Training setup
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-4)
    # `verbose=` is deprecated for LR schedulers; LR reductions are logged manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.1)

    # Validation-order positions passed to validate() and save_predictions()
    exclude_indices = set([95, 145, 146, 158, 159, 160, 161])

    best_val_loss = float('inf')
    best_model_wts = None

    # Train loop
    for epoch in range(NUM_EPOCHS):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss = validate(model, val_loader, criterion, device, exclude_indices)
        print(f"Epoch {epoch+1} | Train MSE Loss: {train_loss:.4f} | Val MSE Loss: {val_loss:.4f}")

        lr_before = optimizer.param_groups[0]['lr']
        scheduler.step(val_loss)
        lr_after = optimizer.param_groups[0]['lr']
        if lr_after < lr_before:
            print(f"Reducing learning rate to {lr_after:.4e}.")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameters; deep-copy so
            # later optimizer steps cannot overwrite the "best" snapshot.
            best_model_wts = copy.deepcopy(model.state_dict())
            print("Validation loss improved, saving model...")
            torch.save(best_model_wts, 'latlong.pth')

    # Predict with the best checkpoint rather than the last epoch's weights.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    save_predictions(model, val_loader, device, exclude_indices)
