# Inference Script with TTA

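This notebook runs inference on the soil-classification test set with test-time augmentation (TTA): each test image is scored under several augmented views, the model's logits are averaged over those views, and the arg-max class is written to a submission CSV.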

In [None]:
# Team / submission metadata for this notebook.
info = """

Author: Annam.ai IIT Ropar
Team Name: SoilClassifiers
Team Members: Caleb Chandrasekar, Sarvesh Chandran, Swaraj Bhattacharjee, Karan Singh, Saatvik Tyagi
Leaderboard Rank: 103

"""

In [None]:
import torch
from torchvision import transforms, models
from PIL import Image
import pandas as pd
import os

# Compute device: the model and every input tensor are moved here later on.
# Prefer CUDA when a GPU is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Square input resolution fed to the network.
IMG_SIZE = 384

# Soil class names; list position i is the class for output logit i.
CLASSES = ['Alluvial soil', 'Black Soil', 'Clay soil', 'Red soil']

# Base (un-augmented) preprocessing: resize, convert to a tensor, then
# normalize each channel with the usual ImageNet mean/std statistics.
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    _normalize,
])

# TTA transforms: each pipeline yields one view of a test image; the inference
# loop averages the model's logits over all of these views.
def _tta_pipeline(*augmentations):
    """Build Resize -> *augmentations -> ToTensor -> Normalize.

    Keeps the resize/normalization identical to the base `transform` so the
    views differ only by the inserted augmentation(s).
    """
    return transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


tta_transforms = [
    # 1) identity view: base preprocessing only
    transform,
    # 2) mirrored view: p=1.0 means the flip is always applied (deterministic)
    _tta_pipeline(transforms.RandomHorizontalFlip(p=1.0)),
    # 3) color view: ColorJitter draws new brightness/contrast factors on every
    #    call, so this view is stochastic -- seed torch's RNG before inference
    #    if predictions must be reproducible.
    _tta_pipeline(transforms.ColorJitter(brightness=0.2, contrast=0.2)),
]

# Load model: EfficientNetV2-S architecture whose final Linear layer is replaced
# by a len(CLASSES)-way head, then restore trained weights from disk.
# TODO: point MODEL_PATH at the trained checkpoint before running.
MODEL_PATH = '/path/to/saved_model.pth'

model = models.efficientnet_v2_s(weights=None)  # architecture only; weights come from MODEL_PATH
# Read the head's input width from the existing layer instead of hard-coding 1280.
model.classifier[1] = torch.nn.Linear(model.classifier[1].in_features, len(CLASSES))
# NOTE(security): torch.load unpickles the file and can execute arbitrary code;
# only load trusted checkpoints (on torch>=1.13, weights_only=True is safer for a state_dict).
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

# Load test IDs and perform TTA
TEST_DATA_DIR = '/kaggle/input/soil-classification/soil_classification-2025'
# One TTA view uses ColorJitter, which is random; seeding makes reruns give identical predictions.
TTA_SEED = 42

test_ids_df = pd.read_csv(os.path.join(TEST_DATA_DIR, 'test_ids.csv'))
test_ids = test_ids_df['image_id'].tolist()

torch.manual_seed(TTA_SEED)
preds = []
# Pure inference: no_grad skips building autograd graphs, saving memory and time.
with torch.no_grad():
    for img_id in test_ids:
        img_name = str(img_id)
        # IDs without an extension are assumed to refer to .jpg files.
        if '.' not in img_name:
            img_name += '.jpg'
        img_path = os.path.join(TEST_DATA_DIR, 'test', img_name)
        # Close the file handle promptly; convert() returns an in-memory copy.
        with Image.open(img_path) as img_file:
            image = img_file.convert('RGB')
        # Stack all TTA views into one batch for a single forward pass. The model
        # is in eval mode, so its normalization layers use running statistics and
        # views in the same batch do not affect each other's outputs.
        views = torch.stack([aug(image) for aug in tta_transforms]).to(device)
        mean_logits = model(views).mean(dim=0)
        preds.append(int(mean_logits.argmax()))

# Map predicted class indices back to soil-type names and write the submission.
# NOTE(review): on Kaggle only /kaggle/working is writable and is where submissions
# are collected -- confirm this target directory.
SUBMISSION_PATH = '/mnt/data/submission.csv'

predicted_labels = [CLASSES[i] for i in preds]
submission = pd.DataFrame({
    'image_id': test_ids,
    'soil_type': predicted_labels
})
# DataFrame.to_csv does not create missing parent directories.
os.makedirs(os.path.dirname(SUBMISSION_PATH), exist_ok=True)
submission.to_csv(SUBMISSION_PATH, index=False)
submission.head()