# In [None]:
import os
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from efficientnet_pytorch import EfficientNet
from tqdm import tqdm

# Inference device: CUDA when PyTorch reports it is available, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

# Square input resolution in pixels (height == width) and images per batch.
IMG_SIZE, BATCH_SIZE = 224, 32

# Model output index -> soil-type name. Presumably this must mirror the label
# encoding used when best_model.pth was trained — TODO confirm.
idx2label = dict(enumerate([
    'Alluvial soil',
    'Black Soil',
    'Clay soil',
    'Red soil',
]))

# Evaluation-time preprocessing, applied in order:
#   1. resize every image to a fixed IMG_SIZE x IMG_SIZE square,
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the standard ImageNet mean/std.
test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset class
class SoilDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, image_id)`` pairs for the images listed in a DataFrame.

    Args:
        df: DataFrame with an ``'image_id'`` column whose values are file
            names relative to ``img_dir``. Rows are accessed positionally
            (``iloc``), so the DataFrame index does not need to be reset.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB ``PIL.Image``
            (e.g. a torchvision ``Compose``). When ``None`` the PIL image
            itself is returned.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.df.iloc[idx]['image_id']
        path = os.path.join(self.img_dir, image_id)
        # Open the file as a context manager so its handle is always released.
        # PIL only auto-closes single-frame images after loading; multi-frame
        # files (e.g. TIFF/GIF) would otherwise stay open. convert() returns a
        # fully loaded copy that does not depend on the file.
        with Image.open(path) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, image_id

# Load model
def get_model(weights_path='best_model.pth'):
    """Build the EfficientNet-B0 soil classifier and load fine-tuned weights.

    The architecture is created with ``EfficientNet.from_name`` rather than
    ``from_pretrained``. The strict ``load_state_dict`` below overwrites every
    parameter and buffer, so first downloading ImageNet weights is wasted
    work, and it fails on machines without internet access.

    Args:
        weights_path: Path to a ``state_dict`` checkpoint for an
            EfficientNet-B0 with ``len(idx2label)`` output classes. Defaults
            to ``'best_model.pth'`` in the working directory.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.
    """
    # Classifier head size follows the label mapping used to decode predictions.
    model = EfficientNet.from_name('efficientnet-b0', num_classes=len(idx2label))
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust
    # (on torch >= 1.13, weights_only=True restricts what can be unpickled).
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model

# Inference
def run_inference(test_csv, test_dir, output_csv='submission.csv'):
    """Predict a soil type for every image listed in ``test_csv`` and save them.

    Args:
        test_csv: CSV file with an ``'image_id'`` column naming the test images.
        test_dir: Directory containing those images.
        output_csv: Destination of the predictions CSV. Defaults to
            ``'submission.csv'``, the previously hard-coded path.

    Returns:
        The submission DataFrame (columns ``'image_id'`` and ``'soil_type'``),
        with rows in the same order as ``test_csv``.
    """
    df = pd.read_csv(test_csv)
    dataset = SoilDataset(df, test_dir, test_transforms)
    # shuffle=False keeps predictions aligned with the CSV row order.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

    model = get_model()

    all_preds, all_ids = [], []
    with torch.no_grad():
        for images, image_ids in tqdm(loader, desc="Inferencing"):
            outputs = model(images.to(DEVICE))
            # .tolist() yields plain Python ints, matching idx2label's keys.
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_ids.extend(image_ids)

    submission = pd.DataFrame({
        'image_id': all_ids,
        'soil_type': [idx2label[p] for p in all_preds],
    })
    submission.to_csv(output_csv, index=False)
    print(f"Saved {output_csv}")
    return submission

# Example usage
# run_inference('soil_classification-2025/test_ids.csv', 'soil_classification-2025/test')
