### Training the Model

In [None]:
# Importing librairies needed for the code to run
import torch #Main PyTorch library for tensor operations and neural networks
import torch.nn as nn #needed module (layers, loss functions)
import torch.optim as optim #needed module (optimization)
import torchvision.transforms as transforms # Image preprocessing (resizing, normalization)
from torchvision.datasets import ImageFolder # Automatic dataset organization from folders
from torch.utils.data import DataLoader # Data loading and batching
from efficientnet_pytorch import EfficientNet # Pre-trained EfficientNet model

### Step 1: Data Collection (specify training images folder path)
dataset_path = "/ExamplePath/training-dataset"

### Step 2: Parameters definition
# Data preprocessing: resize every image to a fixed size and convert it to a tensor.
# NOTE(review): the pretrained EfficientNet weights presumably expect
# ImageNet-normalized inputs; consider adding
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# here AND in the testing transform (the two transforms must stay identical).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # reduce images to 224x224 [adjustable]
    transforms.ToTensor(),
])

# Create the ImageFolder dataset: each subfolder of dataset_path is one class
# (ImageFolder assigns class indices in sorted folder-name order).
Training_dataset = ImageFolder(root=dataset_path, transform=transform)

# Create a DataLoader for the dataset (define iteration parameters)
batch_size = 32  # Number of samples per batch [adjustable]
shuffle = True  # Shuffle the data at every epoch
data_loader = DataLoader(Training_dataset, batch_size=batch_size, shuffle=shuffle)

### Step 3: Model definition
# Use a GPU when one is available; fine-tuning EfficientNet on CPU is very slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate EfficientNet-B0 from pretrained weights with 2 output classes (ClassA and ClassB)
model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2)
# Move the model before building the optimizer so the optimizer tracks the device-resident parameters.
model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # standard learning rate [adjustable]

### Step 4: Model training
num_epochs = 40  # Number of epochs [adjustable]
model.train()  # modules start in training mode, but make it explicit (dropout / batch-norm updates on)
for epoch in range(num_epochs):
    running_loss = 0.0
    num_samples = 0
    for inputs, labels in data_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss returns the batch mean; weight it by the batch size so
        # the epoch average stays exact when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        num_samples += inputs.size(0)
    # Report the mean loss over the whole epoch instead of the (noisy) last batch;
    # max(..., 1) avoids a division by zero if the loader produced no batches.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss}')  # keep track of the training process

# Save the trained weights into a specified folder, as .pth format.
# Tensors are copied to CPU so the checkpoint can be loaded on a machine without a GPU.
cpu_state_dict = {name: tensor.cpu() for name, tensor in model.state_dict().items()}
torch.save(cpu_state_dict, '/ExamplePath/trained_model.pth')

### Testing the Model

In [None]:
import os
import torch
from efficientnet_pytorch import EfficientNet
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Define the path to the trained model weights (the .pth file written by the training step)
model_path = '/ExamplePath/trained_model.pth'

# Define the path to the folder containing test images
test_folder_path = '/ExamplePath/testing-dataset'

# Define the transformation to apply to the test images.
# It should match the training transform exactly so the model sees the same input format.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Build the EfficientNet-B0 architecture with a 2-class head and load the trained weights.
# from_name() only constructs the network; from_pretrained() would first download
# ImageNet weights that load_state_dict() overwrites immediately anyway.
model = EfficientNet.from_name('efficientnet-b0', num_classes=2)
# map_location='cpu' lets a checkpoint saved from a GPU model load on a CPU-only machine.
model.load_state_dict(torch.load(model_path, map_location='cpu'))
model.eval()  # inference mode: disables dropout and uses batch-norm running statistics

# Dataset for the test images that ignores class labels: each item is only the
# transformed image. ImageFolder still requires test_folder_path to contain at
# least one subfolder holding the images (e.g. testing-dataset/unlabeled/);
# image files placed directly in the root folder are not picked up.
class CustomImageFolder(ImageFolder):
    """ImageFolder variant whose items are images only (the folder-derived label is dropped).

    Construct it exactly like ``ImageFolder`` (``root``, ``transform``, ...).
    The file path of item ``i`` remains available as ``self.samples[i][0]``.
    """

    def __getitem__(self, index):
        """Return the (transformed) image at ``index`` without its label."""
        # Reuse ImageFolder's loading + transform logic and discard the label.
        image, _ = super().__getitem__(index)
        return image

# Load the test dataset
test_dataset = CustomImageFolder(root=test_folder_path, transform=transform)

# Create a DataLoader for the test dataset
batch_size = 1  # Images per batch; the inference loop below works with any value [adjustable]
shuffle = False  # Must stay False: predictions are matched to file names by position below
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle)

# Perform inference on the test dataset (gradients are not needed for prediction)
predictions = []
with torch.no_grad():
    for inputs in test_loader:
        outputs = model(inputs)
        # Index of the highest-scoring class for every image in the batch
        predictions.extend(outputs.argmax(dim=1).tolist())

# Display the predictions along with the file names.
# NOTE(review): index 0 -> 'ClassA', otherwise 'ClassB' assumes the training
# subfolders sort so that ClassA comes first (ImageFolder's class order); verify.
print("Predictions for test images:")
for i, (prediction, (image_path, _)) in enumerate(zip(predictions, test_dataset.samples)):
    image_name = os.path.basename(image_path)
    print(f"Image {i+1} ({image_name}): {'ClassA' if prediction == 0 else 'ClassB'}")