#### Importing the needed libraries

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
import torch.optim as optim
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os

#### Set hyperparameters

In [None]:
# Training hyperparameters:
#   NUM_EPOCHS - number of full passes over the training split
#   BATCH_SIZE - images per mini-batch for the train/val/test DataLoaders
NUM_EPOCHS, BATCH_SIZE = 5, 32

#### Check hardware accelerator availability (CUDA / MPS)

In [None]:
# Pick the best available backend: CUDA GPU first, then Apple MPS, then CPU.
# `device` is kept as a plain string (e.g. "cuda"), which Tensor.to / Module.to accept.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

# Select GPU 0 only when CUDA is actually present: calling torch.cuda.set_device
# unconditionally raises on MPS- or CPU-only machines.
if device == "cuda":
    torch.cuda.set_device(0)


#### Define transforms for data augmentation and normalization

In [None]:
# Per-channel ImageNet mean / std, shared by both pipelines so the two
# normalizations can never drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random augmentation, then tensor conversion + normalization.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # Random crop covering 80%-100% of the image area, resized to 224x224
    transforms.RandomHorizontalFlip(),  # Mirror left-right (default probability 0.5)
    transforms.RandomRotation(25),  # Random rotation by up to +/-25 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Randomly scale brightness and contrast by up to +/-20%
    transforms.RandomAffine(0, translate=(0.1, 0.1), scale=(0.9, 1.1)),  # Random shift (up to 10% of width/height) and zoom (0.9x-1.1x); degrees=0, so no extra rotation
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic resize + center crop, same normalization.
transform_val_test = transforms.Compose([
    transforms.Resize(256),  # Shorter side resized to 256 px
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

#### Load the dataset from the local machine and define the data loaders

In [None]:
from torch.utils.data import Subset

# Load the whole image folder
DATA_ROOT = '/home/mehdirexon/Desktop/vsc/image classifcation/dataset/plastic'
# Fixed seed: every run (and every reloaded checkpoint) sees the same train/val/test
# split, so test images are never ones an earlier run trained on.
SPLIT_SEED = 42

# Two views of the same folder: augmented for training, deterministic for evaluation.
# Splitting a single dataset with random_split would give val/test the training
# augmentations, because all three subsets would share one dataset and its transform.
# Both instances scan the same root, so index i refers to the same file in each
# (ImageFolder orders classes and files deterministically).
dataset = ImageFolder(root=DATA_ROOT, transform=transform)
eval_dataset = ImageFolder(root=DATA_ROOT, transform=transform_val_test)

# Split dataset into train, validation, and test sets (80% / 10% / remainder)
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Shuffle the indices once with a seeded generator and carve the three splits out of them.
indices = torch.randperm(len(dataset), generator=torch.Generator().manual_seed(SPLIT_SEED)).tolist()
train_dataset = Subset(dataset, indices[:train_size])
val_dataset = Subset(eval_dataset, indices[train_size:train_size + val_size])
test_dataset = Subset(eval_dataset, indices[train_size + val_size:])

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

#### Visualize sample training images

In [None]:
def matplotlib_imshow_grid(images, labels, classes, num_rows=2, num_cols=4, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], one_channel=False):
    num_images = len(images)
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(7, 7))
    for i in range(num_rows):
        for j in range(num_cols):
            index = i * num_cols + j
            if index < num_images:
                img = images[index]
                label = labels[index]
                if one_channel:
                    img = img.mean(dim=0)
                else:
                    # Unnormalize the image
                    for t, m, s in zip(img, mean, std):
                        t.mul_(s).add_(m)
                img = img / 2 + 0.5     # Unnormalize
                npimg = img.numpy()
                if one_channel:
                    axes[i, j].imshow(npimg, cmap="Greys")
                else:
                    npimg = np.transpose(npimg, (1, 2, 0))
                    npimg = np.clip(npimg, 0, 1)    # Clip to [0, 1] in case of numerical errors
                    axes[i, j].imshow(npimg)
                axes[i, j].set_title(classes[label])
                axes[i, j].axis('off')
            else:
                axes[i, j].axis('off')
    plt.tight_layout()
    plt.show()

# Grab one shuffled training batch to eyeball what the model will see.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Lay the first 25 images of the batch out in a 5 x 5 greyscale grid.
matplotlib_imshow_grid(
    images,
    labels,
    dataset.classes,
    num_rows=5,
    num_cols=5,
    one_channel=True,
)

#### Define the ResNet18 model, loss function and optimizer


In [None]:
# Start from torchvision's default pretrained ResNet18 weights.
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

# Replace the final fully connected layer so it outputs one logit per dataset class.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(dataset.classes))

# NOTE(review): nothing in this cell moves the model to `device`, so it stays on the CPU;
# moving it requires every cell that feeds it batches to move those batches too.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

#### Train the model

In [None]:
# Per-epoch history: one entry per epoch for training/validation loss and accuracy (%).
train_loss_values, train_accuracy_values = [], []
val_loss_values, val_accuracy_values = [], []



# Define a function to save the checkpoint
def save_checkpoint(model, optimizer, epoch, train_loss, train_accuracy, val_loss, val_accuracy, save_dir):
    """
    Save a training checkpoint to ``<save_dir>/checkpoint_epoch_<epoch>.pth``.

    ``save_dir`` (and any missing parents) is created first, because ``torch.save``
    does not create directories.

    Parameters:
        model (torch.nn.Module): Model whose ``state_dict`` is saved.
        optimizer (torch.optim.Optimizer): Optimizer whose ``state_dict`` is saved.
        epoch (int): Epoch number; also used in the file name.
        train_loss (float): Training loss.
        train_accuracy (float): Training accuracy.
        val_loss (float): Validation loss.
        val_accuracy (float): Validation accuracy.
        save_dir (str): Directory path to save the checkpoint.

    Returns:
        str: Path of the written checkpoint file.
    """
    os.makedirs(save_dir, exist_ok=True)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'train_accuracy': train_accuracy,
        'val_loss': val_loss,
        'val_accuracy': val_accuracy,
    }
    checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at {checkpoint_path}")
    return checkpoint_path

# Training loop with checkpointing.
# `epoch` is 0-based: it is printed as epoch+1 but passed as-is to save_checkpoint,
# so "Epoch 1 completed" is stored as checkpoint_epoch_0.pth.
for epoch in range(NUM_EPOCHS):
    # Send every batch to the device holding the model's weights, so the loop works
    # whether the model is on the CPU or has been moved to an accelerator.
    model_device = next(model.parameters()).device

    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(model_device), labels.to(model_device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Compute training accuracy
        predicted_train = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted_train == labels).sum().item()
        print(f" mini batch: {i} / {len(train_loader)}")

    # Report NaN instead of raising ZeroDivisionError when a split is empty
    # (e.g. val_size = int(0.1 * N) is 0 for datasets with fewer than 10 images).
    epoch_train_loss = running_loss / len(train_loader) if len(train_loader) else float("nan")
    epoch_train_accuracy = 100 * correct_train / total_train if total_train else float("nan")

    # ---- Validation pass: loss and accuracy ----
    model.eval()
    val_running_loss = 0.0
    correct_val = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(model_device), labels.to(model_device)
            outputs = model(images)
            val_running_loss += criterion(outputs, labels).item()
            predicted_val = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted_val == labels).sum().item()

    epoch_val_loss = val_running_loss / len(val_loader) if len(val_loader) else float("nan")
    epoch_val_accuracy = 100 * correct_val / total_val if total_val else float("nan")

    # Append values for visualization
    train_loss_values.append(epoch_train_loss)
    train_accuracy_values.append(epoch_train_accuracy)
    val_loss_values.append(epoch_val_loss)
    val_accuracy_values.append(epoch_val_accuracy)

    print(f"Epoch {epoch+1} completed. Train Loss: {epoch_train_loss:.3f}, Train Accuracy: {epoch_train_accuracy:.2f}%, "
          f"Val Loss: {epoch_val_loss:.3f}, Val Accuracy: {epoch_val_accuracy:.2f}%")

    # Save checkpoint
    save_checkpoint(model, optimizer, epoch, epoch_train_loss, epoch_train_accuracy, epoch_val_loss, epoch_val_accuracy, "./epochs")

# Plot the loss and accuracy curves recorded by the training loop.
# The x-axis is sized from the recorded history rather than NUM_EPOCHS, so the
# plot still works if training was interrupted or NUM_EPOCHS was edited afterwards.
epoch_numbers = range(1, len(train_loss_values) + 1)

plt.figure(figsize=(12, 5))

# Loss subplot
plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, train_loss_values, label='Train Loss')
plt.plot(epoch_numbers, val_loss_values, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Accuracy subplot
plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, train_accuracy_values, label='Train Accuracy')
plt.plot(epoch_numbers, val_accuracy_values, label='Val Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

#### Loading the best checkpoint

In [None]:
# Example usage:
checkpoint_path = '/home/mehdirexon/Desktop/vsc/image classifcation/epochs/checkpoint_epoch_12.pth'  # Change this to the path of your checkpoint file
# NOTE(review): with NUM_EPOCHS = 5 the training cell only writes checkpoint_epoch_0..4 into ./epochs,
# so this epoch-12 file must come from an earlier, longer run; confirm it was trained on the same split.

# map_location="cpu" lets a checkpoint written on a GPU machine load on any machine;
# load_state_dict then copies the tensors into the model's existing parameters.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Load model state
model.load_state_dict(checkpoint['model_state_dict'])

# Restore the optimizer state too (only needed to resume training)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Epoch value stored in the checkpoint (0-based, as passed by the training loop)
epoch_loaded = checkpoint['epoch']

print(f"Checkpoint loaded from {checkpoint_path}. Epoch: {epoch_loaded}")

#### Evaluation on validation set


In [None]:
# Top-1 accuracy of the current model on the validation split.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in val_loader:
        predicted = model(images).max(dim=1).indices  # index of the highest logit per image
        total += len(labels)
        correct += int((predicted == labels).sum())

print('Accuracy on validation set: %d %%' % (100 * correct / total))

#### Evaluation on test set

In [None]:
# Top-1 accuracy of the current model on the held-out test split.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        predicted = model(images).max(dim=1).indices  # index of the highest logit per image
        total += len(labels)
        correct += int((predicted == labels).sum())

print('Accuracy on test set: %d %%' % (100 * correct / total))

#### Calculating metrics and confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Evaluate the model on the test set, collecting true and predicted class indices
model.eval()
true_labels = []
predicted_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # .cpu() so this keeps working if the model and batches ever live on a GPU
        true_labels.extend(labels.cpu().numpy())
        predicted_labels.extend(predicted.cpu().numpy())

class_names = dataset.classes
# Pin the label set to every class index so the matrix and report always have one
# row/column per class, even when a class never occurs in this test split.
# Without it, the heatmap tick labels below could be misaligned with the matrix.
label_ids = list(range(len(class_names)))

# Compute confusion matrix (rows: true class, columns: predicted class)
cm = confusion_matrix(true_labels, predicted_labels, labels=label_ids)

# Compute classification report with class names; zero_division=0 avoids the
# warning for classes that are never predicted.
report = classification_report(
    true_labels,
    predicted_labels,
    labels=label_ids,
    target_names=class_names,
    zero_division=0,
)

# Calculate accuracy
accuracy = np.sum(np.array(true_labels) == np.array(predicted_labels)) / len(true_labels)

print("\nClassification Report:")
print(report)
print("\nAccuracy:", accuracy)

# Normalize each row by its total; rows of classes absent from the test split
# stay 0 instead of becoming NaN from a 0/0 division.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)

# Plot normalized confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm_normalized, annot=True, cmap='Blues', fmt=".2f", xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Normalized Confusion Matrix')
plt.show()


#### Predicting a single image (manual test)

In [None]:
# Use the deterministic evaluation pipeline (resize, center crop, normalize) rather than
# redefining the global `transform`, which the dataset cell uses for training augmentation.
image_path = '/home/mehdirexon/Desktop/vsc/image classifcation/pete.jpeg'  # Path of your image

# Load and preprocess the image; `with` closes the file handle.
with Image.open(image_path) as img:
    # Force 3-channel RGB: the Normalize step expects 3 channels, so greyscale or
    # RGBA files would otherwise fail.
    image = transform_val_test(img.convert('RGB'))
image = image.unsqueeze(0)  # Add a batch dimension

# Inference mode, as in the evaluation cells, so this cell is correct even when it
# runs directly after training.
model.eval()
with torch.no_grad():
    outputs = model(image)
    _, predicted = torch.max(outputs, 1)
    probabilities = torch.softmax(outputs, dim=1)

# Interpret the predictions: class index plus the softmax probability of every class
predicted_class = predicted.item()
class_probabilities = probabilities[0].tolist()

In [None]:
# Display the predicted class index together with the per-class probabilities.
(predicted_class, class_probabilities)

#### Saving the model

In [None]:
# Pickle the whole model object (architecture and weights) into the working directory.
# NOTE(review): loading this file needs the same model classes importable; saving
# model.state_dict() instead would be more portable. Confirm which format consumers expect.
model_path = 'plastic_model.pth'
torch.save(model, model_path)