In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets, models  # Import models module
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score


In [2]:
class ResNetModel(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetModel, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        in_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

In [3]:
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

In [4]:
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


In [7]:
import os
import cv2
import numpy as np
def load_images(directory, label):
    images = []
    labels = []
    for filename in os.listdir(directory):
        if filename.endswith(".jpg") or filename.endswith(".png"):
            img = cv2.imread(os.path.join(directory, filename))
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
            images.append(img)
            labels.append(label)
    return images, labels

In [9]:
clean_images, clean_labels = load_images("cleaned", label=0)
dirty_images, dirty_labels = load_images("polluted", label=1)

all_images = clean_images + dirty_images
all_labels = clean_labels + dirty_labels

In [10]:
X_train, X_test, y_train, y_test = train_test_split(all_images, all_labels, test_size=0.2, random_state=42)

train_dataset = CustomDataset(X_train, y_train, transform=transform)
test_dataset = CustomDataset(X_test, y_test, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [12]:
model = ResNetModel()  # Change here to use the ResNet model
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/pranaymishra/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:41<00:00, 1.12MB/s]


In [13]:
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.numpy())
            all_labels.extend(labels.numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")


Epoch 1/10, Loss: 0.9377, Accuracy: 0.2857
Epoch 2/10, Loss: 0.0901, Accuracy: 0.7500
Epoch 3/10, Loss: 0.4257, Accuracy: 0.8929
Epoch 4/10, Loss: 0.0319, Accuracy: 0.6964
Epoch 5/10, Loss: 0.0407, Accuracy: 0.9107
Epoch 6/10, Loss: 0.0269, Accuracy: 0.9643
Epoch 7/10, Loss: 0.0094, Accuracy: 0.9821
Epoch 8/10, Loss: 0.0016, Accuracy: 0.9821
Epoch 9/10, Loss: 0.0034, Accuracy: 0.9821
Epoch 10/10, Loss: 0.0502, Accuracy: 0.9821


In [14]:
test_image = cv2.imread("/Users/pranaymishra/Desktop/ml_practice/ocean_dataset/cleaned/3.jpg")
test_image = cv2.cvtColor(test_image, cv2.COLOR_BGR2RGB)
test_image = transform(test_image).unsqueeze(0)


In [15]:
model.eval()
with torch.no_grad():
    output = model(test_image)
    _, predicted_class = torch.max(output, 1)

print(f"The image is predicted to be {'clean' if predicted_class.item() == 0 else 'dirty'}.")

The image is predicted to be clean.
