# Training Script

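This notebook fine-tunes an ImageNet-pretrained EfficientNetV2-S to classify soil images into four classes (Alluvial, Black, Clay, Red). It uses an 80/20 train/validation split of the labelled data and reports weighted F1 on the validation split.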

In [None]:
# Team / submission metadata, kept as a plain string for reference.
info ="""

Author: Annam.ai IIT Ropar
Team Name: SoilClassifiers
Team Members: Caleb Chandrasekar, Sarvesh Chandran, Swaraj Bhattacharjee, Karan Singh, Saatvik Tyagi
Leaderboard Rank: 103

"""

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from sklearn.metrics import f1_score, classification_report
import pandas as pd
from PIL import Image
import os

# Device configuration: use the GPU when one is visible, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Constants
IMG_SIZE = 384   # side length of the square images fed to the backbone
BATCH_SIZE = 32
NUM_EPOCHS = 10
LR = 1e-3        # initial Adam learning rate (reduced on plateau later)
SEED = 42        # global seed so weight init and shuffling are repeatable

# Seed torch's default generator (CPU and CUDA) before anything random happens.
torch.manual_seed(SEED)

# Soil classes (must match CSV labels exactly)
CLASSES = ['Alluvial soil', 'Black Soil', 'Clay soil', 'Red soil']
# Label string -> contiguous integer index, used as the CrossEntropyLoss target.
# This must be a plain dict comprehension. Wrapping it in a second pair of
# braces would build a set containing a dict and raise TypeError.
class_to_idx = {c: i for i, c in enumerate(CLASSES)}

# Transformations
# Preprocessing applied to every image, in this order:
#   1. resize to a fixed IMG_SIZE x IMG_SIZE square,
#   2. convert the PIL image to a float tensor,
#   3. normalise each channel with the standard ImageNet mean/std.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
_pipeline_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
transform = transforms.Compose(_pipeline_steps)

# Dataset
class SoilDataset(Dataset):
    """Soil-image dataset driven by a labels CSV.

    Each CSV row supplies an image identifier (column 0) and a class-name
    label (column 1). Columns are read by position, not by header name.
    Items are ``(image, label_index)`` tuples.

    Args:
        csv_file: Path or buffer accepted by ``pd.read_csv``.
        root_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB PIL image.
        label_map: Optional ``{label: index}`` mapping. When omitted, the
            module-level ``class_to_idx`` is used.
    """

    def __init__(self, csv_file, root_dir, transform=None, label_map=None):
        self.annotations = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Resolve the mapping once. Existing three-argument calls keep using
        # the notebook-level class_to_idx.
        self.label_map = label_map if label_map is not None else class_to_idx

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _normalise_filename(image_id):
        """Map a CSV image id to its on-disk filename.

        Ids without a '.' get a '.jpg' suffix. Otherwise the extension is
        lower-cased (e.g. ``'a.JPG'`` -> ``'a.jpg'``).
        """
        # NOTE(review): lower-casing assumes the files on disk use lower-case
        # extensions; confirm on a case-sensitive filesystem.
        if '.' not in image_id:
            return image_id + '.jpg'
        base, ext = os.path.splitext(image_id)
        return base + ext.lower()

    def _encode_label(self, raw_label):
        """Strip surrounding whitespace from a CSV label and return its index.

        Raises:
            KeyError: if the label is not in ``self.label_map``. The message
                names the offending label and the accepted ones.
        """
        # str() guards against non-string cells (e.g. NaN), which would
        # otherwise fail with an AttributeError on .strip().
        label = str(raw_label).strip()
        if label not in self.label_map:
            raise KeyError(
                f"Unknown label {label!r}; expected one of {list(self.label_map)}"
            )
        return self.label_map[label]

    def __getitem__(self, idx):
        filename = self._normalise_filename(str(self.annotations.iloc[idx, 0]))
        img_path = os.path.join(self.root_dir, filename)

        # The context manager releases the file handle even if decoding fails.
        # convert() loads the pixels into a new image that outlives the handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self._encode_label(self.annotations.iloc[idx, 1])

# Prepare DataLoader
# One labelled dataset is split 80/20 into train/validation subsets. Both
# subsets share the same (non-augmenting) `transform`.
DATA_ROOT = '/kaggle/input/soil-classification/soil_classification-2025'

train_dataset = SoilDataset(
    csv_file=os.path.join(DATA_ROOT, 'train_labels.csv'),
    root_dir=os.path.join(DATA_ROOT, 'train'),
    transform=transform
)

# A dedicated, fixed-seed generator keeps the train/val partition identical
# across kernel restarts, so validation scores are comparable between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_data, val_data = torch.utils.data.random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)

# Model setup: ImageNet-pretrained EfficientNetV2-S with a new classification head.
model = models.efficientnet_v2_s(weights='IMAGENET1K_V1')
# Size the replacement head from the existing final Linear layer rather than
# hard-coding its width, so a different backbone variant cannot silently
# mismatch the feature dimension.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, len(CLASSES))
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)
# Lower the learning rate when the value passed to scheduler.step() (the
# validation loss) stops decreasing for more than `patience` epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2)

# Training loop

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        Mean per-sample training loss over the samples seen this epoch.
    """
    model.train()
    total_loss, n_samples = 0.0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (CrossEntropyLoss default).
        # Re-weighting by batch size keeps the epoch figure per-sample even
        # when the last batch is smaller.
        total_loss += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
    return total_loss / max(n_samples, 1)


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, y_true, y_pred)``. The label lists are Python ints in
        loader order, and each prediction is the arg-max of the model's outputs.
    """
    model.eval()
    total_loss, n_samples = 0.0, 0
    y_true, y_pred = [], []
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        total_loss += criterion(outputs, labels).item() * inputs.size(0)
        n_samples += inputs.size(0)
        y_true.extend(labels.cpu().tolist())
        y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
    return total_loss / max(n_samples, 1), y_true, y_pred


for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    # y_true / y_pred stay at module level so the final report can use the
    # last epoch's validation predictions.
    val_loss, y_true, y_pred = evaluate(model, val_loader, criterion, device)
    # zero_division=0 makes the score for never-predicted classes explicit
    # instead of emitting an UndefinedMetricWarning every epoch.
    val_f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], "
          f"Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, "
          f"Val F1: {val_f1:.4f}")

    # Step on the per-sample mean, the same value printed above, rather than
    # the raw summed loss.
    scheduler.step(val_loss)

# Per-class precision/recall/F1 on the validation split, computed from the
# predictions collected during the last training epoch.
print("\nClassification Report:")
final_report = classification_report(y_true, y_pred, target_names=CLASSES)
print(final_report)