In [None]:
# notebooks/inference.ipynb

# import
import os
import sys

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet18
from tqdm import tqdm

# The project modules live outside the default import path, so sys.path must
# be extended before they are imported.
sys.path.append("/kaggle/working/src")
from postprocessing import generate_submission
from preprocessing import load_and_prepare_data

# Label Encoder
# load_and_prepare_data returns three values; only the third is kept here.
# NOTE(review): `le` is presumably the label encoder fitted on the training
# labels -- confirm against preprocessing.py.
train_labels_csv = "/kaggle/input/soil-classification/soil_classification-2025/train_labels.csv"
_, _, le = load_and_prepare_data(train_labels_csv)

# Transforms
# Inference-time preprocessing: resize every image to a fixed square, convert
# it to a tensor, then normalise each of the three channels.
img_size = 224
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        # These are the ImageNet channel statistics listed in the torchvision
        # model docs. NOTE(review): confirm training used the same values.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Dataset
class TestSoilDataset(Dataset):
    """Dataset yielding ``(image, image_id)`` pairs for inference.

    Args:
        df: Table with an ``image_id`` column holding image file names. A
            copy with a fresh 0..len-1 index is stored, so rows can be
            addressed by position regardless of the caller's index.
        img_dir: Directory the ``image_id`` file names are joined onto.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, df, img_dir, transform=None):
        # reset_index returns a new frame; the caller's frame is untouched.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        """Return the number of rows in the stored table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the image referenced by row ``idx``.

        Returns:
            A tuple ``(image, file_name)``. ``image`` is the RGB PIL image,
            or the transform's output when a truthy transform was supplied.
        """
        file_name = self.df.loc[idx, 'image_id']
        rgb = Image.open(os.path.join(self.img_dir, file_name)).convert('RGB')
        if not self.transform:
            return rgb, file_name
        return self.transform(rgb), file_name

# Load Test Data
# The id list and the image folder both live under the competition directory,
# so the root is written once and both paths are derived from it.
data_root = "/kaggle/input/soil-classification/soil_classification-2025"
test_df = pd.read_csv(os.path.join(data_root, "test_ids.csv"))
test_dataset = TestSoilDataset(
    test_df,
    os.path.join(data_root, "test"),
    transform=val_transforms,
)
# shuffle=False keeps the batches in the row order of test_ids.csv.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load Model
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the architecture the checkpoint is loaded into: a ResNet-18 whose
# final fully connected layer is replaced by a linear head of num_classes
# outputs.
# NOTE(review): 4 presumably equals the number of classes known to the label
# encoder -- confirm against the training code.
num_classes = 4
model = resnet18()
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# torch.load unpickles the file. Only a state_dict is needed here, so
# weights_only=True restricts the unpickler to tensors and plain containers
# instead of allowing arbitrary objects to be constructed.
state_dict = torch.load(
    "/kaggle/working/best_model.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # switch the module to evaluation mode before predicting

# Predict
# Collect one (image name, predicted class index) pair per test image.
predictions = []
with torch.no_grad():  # gradients are not needed for inference
    for batch, names in tqdm(test_loader):
        logits = model(batch.to(device))
        # The position of the largest score along dim 1 is the predicted class.
        class_ids = torch.max(logits, dim=1).indices
        predictions.extend(zip(names, class_ids.cpu().numpy()))

# Submit
# NOTE(review): generate_submission presumably maps the class indices back to
# label names via `le` and writes the submission file -- confirm in
# postprocessing.py.
generate_submission(predictions, le)