<a href="https://colab.research.google.com/github/ketanp23/sit-neuralnetworks-class/blob/main/KFoldCrossValidation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import numpy as np

# Data
iris = load_iris()
X = iris.data
y = iris.target
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Neural Network
class IrisNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 10)
        self.fc2 = nn.Linear(10, 3)  # 3 classes

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# K-Fold CV
k = 5
kf = KFold(n_splits=k, shuffle=True, random_state=42)
accuracies = []

for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
    # Split data
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    # Convert to tensors
    X_train = torch.tensor(X_train, dtype=torch.float)
    X_test = torch.tensor(X_test, dtype=torch.float)
    y_train = torch.tensor(y_train, dtype=torch.long)
    y_test = torch.tensor(y_test, dtype=torch.long)

    # Model, loss, optimizer
    model = IrisNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Train
    for epoch in range(100):
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

    # Test
    with torch.no_grad():
        outputs = model(X_test)
        _, preds = torch.max(outputs, 1)
        accuracy = (preds == y_test).float().mean()
        accuracies.append(accuracy)
        print(f"Fold {fold+1} Accuracy: {accuracy:.4f}")

# Average performance
print(f"Average Accuracy: {np.mean(accuracies):.4f}")

Fold 1 Accuracy: 1.0000
Fold 2 Accuracy: 0.9667
Fold 3 Accuracy: 0.9333
Fold 4 Accuracy: 0.9667
Fold 5 Accuracy: 0.9667
Average Accuracy: 0.9667
