In [None]:
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader

In [2]:
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

inference_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

In [3]:
train_dataset = datasets.ImageFolder(root='/Users/luka/PyCharmProjects/simple-defect-detection/data/casting_data/train', transform=train_transforms)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [4]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
    
device

device(type='mps')

In [5]:
model = models.resnet18(weights=ResNet18_Weights.DEFAULT)

# Modify the final fully connected layer to match the number of classes in your new dataset
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # num_classes should be set to the number of your new categories

model = model.to(device)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

In [7]:
model.train()
num_epochs = 3

for epoch in range(num_epochs):
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

print("Training complete")

Epoch 1/3, Loss: 0.25737475441732943
Epoch 2/3, Loss: 0.14873862069977734
Epoch 3/3, Loss: 0.1380754316968915
Training complete


In [8]:
val_dataset = datasets.ImageFolder(root='/Users/luka/PyCharmProjects/simple-defect-detection/data/casting_data/test', transform=inference_transforms)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [9]:
model.eval()
val_loss = 0.0
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        val_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Loss: {val_loss / len(val_loader)}, Accuracy: {100 * correct / total}%")

Test Loss: 0.03548820461804533, Accuracy: 99.02097902097903%


In [10]:
idx_to_class = {v: k for k, v in train_dataset.class_to_idx.items()}

def predict(img_path):
  image = Image.open(img_path)

  transformed_image = inference_transforms(image)
  transformed_image = transformed_image.unsqueeze(0)  # Add batch dimension
  transformed_image = transformed_image.to(device)

  model.eval()

  with torch.no_grad():
      outputs = model(transformed_image)
      _, predicted = torch.max(outputs, 1)
      predicted_class = idx_to_class[predicted.item()]

  return predicted_class

In [14]:
predicted_class = predict('/Users/luka/PyCharmProjects/simple-defect-detection/data/casting_data/test/def_front/cast_def_0_1447.jpeg')

print(f'Predicted class: {predicted_class}')

Predicted class: def_front
