In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm

# Compute device shared by the model and every batch: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Mount Google Drive so the dataset folder is reachable from the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
data_dir = '/content/drive/MyDrive/BME450Final1'

# Validation preprocessing: deterministic only (no augmentation), so the
# validation metrics are comparable from epoch to epoch.
transform_val = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # single-channel input
    transforms.Resize(size=(160, 160)),           # fixed 160x160 resolution
    transforms.ToTensor(),                        # image -> tensor
    transforms.Normalize(mean=[0.5], std=[0.5]),  # (x - 0.5) / 0.5
])

# Training preprocessing: same pipeline as validation, with a random horizontal
# flip and a random rotation of up to 15 degrees applied before resizing.
transform_train = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.Resize(size=(160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Datasets are read from the 'train' and 'val' sub-folders of data_dir.
val_dataset = ImageFolder(os.path.join(data_dir, 'val'), transform=transform_val)
train_dataset = ImageFolder(os.path.join(data_dir, 'train'), transform=transform_train)

# Batch loaders: only the training data is shuffled.
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

# Class names (and their count) are taken from the training dataset.
class_names = train_dataset.classes
num_classes = len(class_names)
print("Classes:", class_names)


In [None]:
#MobileNetV2 CNN
# NOTE(review): newer torchvision releases prefer the `weights=` argument;
# `pretrained=True` is kept so the cell still runs on the installed version.
model = models.mobilenet_v2(pretrained=True) #load MobileNetV2 with pretrained weights

# Adapt the first conv layer to 1-channel grayscale input WITHOUT throwing away
# its pretrained filters. A freshly constructed Conv2d would be randomly
# initialised while everything after it stays pretrained. Summing the RGB
# filters over the input-channel axis gives the same response the original
# layer would produce for a gray image replicated across its three channels.
rgb_stem = model.features[0][0]
gray_stem = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
with torch.no_grad():
    # weight layout is (out_channels, in_channels, kH, kW): collapse in_channels 3 -> 1
    gray_stem.weight.copy_(rgb_stem.weight.sum(dim=1, keepdim=True))
model.features[0][0] = gray_stem

# Replace the classifier head so the output size matches the dataset's classes.
model.classifier = nn.Sequential(
    nn.Dropout(0.3), #dropout with 30% probability, change dropout rate HERE
    nn.Linear(model.last_channel, num_classes) #one output logit per class
)
model = model.to(device)


In [None]:
#Training function
def train_epoch(model, loader, optimizer, loss_fun):
    """Run one training pass over ``loader``, updating ``model`` in place.

    Args:
        model: Network to train; put into training mode via ``model.train()``.
        loader: Iterable of ``(images, labels)`` batches that also supports ``len()``.
        optimizer: Optimizer whose ``zero_grad()`` and ``step()`` are called once per batch.
        loss_fun: Callable ``loss_fun(outputs, labels)`` returning a scalar loss tensor.

    Returns:
        ``(avg_loss, accuracy)``: the per-batch loss values averaged over the
        number of batches, and the fraction of samples whose highest-scoring
        class equals its label.

    Raises:
        ZeroDivisionError: if ``loader`` contains no batches.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0 #initializing accumulators
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad() #clearing past gradients
        outputs = model(images) #forward pass
        # Use the loss function passed in by the caller rather than a
        # module-level global, so the function works on its own and honours
        # whatever criterion it is given.
        loss = loss_fun(outputs, labels)
        loss.backward() #backward pass
        optimizer.step() #updating weights
        running_loss += loss.item() #summing per-batch loss
        _, preds = torch.max(outputs, 1) #index of the highest score = predicted label
        correct += (preds == labels).sum().item() #number of correct predictions
        total += labels.size(0) #number of samples seen
    return running_loss / len(loader), correct / total #avg loss per batch, accuracy

#Validation Function
def eval_epoch(model, loader, loss_func):
    """Evaluate ``model`` on every batch of ``loader`` without updating weights.

    Args:
        model: Network to evaluate; put into evaluation mode via ``model.eval()``.
        loader: Iterable of ``(images, labels)`` batches that also supports ``len()``.
        loss_func: Callable ``loss_func(outputs, labels)`` returning a scalar loss tensor.

    Returns:
        ``(avg_loss, accuracy, all_preds, all_labels)``: the per-batch loss
        values averaged over the number of batches, the fraction of correctly
        classified samples, and flat lists of the predicted and true labels
        in iteration order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    all_preds, all_labels = [], []

    # Gradient tracking is switched off for the whole pass.
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_sum += loss_func(logits, batch_labels).item()

            # Index of the highest score along dim 1 is the predicted class.
            predicted = torch.max(logits, 1)[1]
            n_correct += (predicted == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            # Move results to the CPU and collect them for later reporting.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy, all_preds, all_labels


In [None]:
# Training configuration
EPOCHS = 20  # maximum number of passes over the training set
loss_func = nn.CrossEntropyLoss()  # classification criterion
optimizer = optim.Adam(
    model.parameters(),
    lr=0.01,  # CHANGE LR HERE
    weight_decay=1e-4,
)
# The learning rate follows a cosine annealing schedule over EPOCHS steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# Early-stopping state: training halts after `patience` consecutive epochs in
# which validation accuracy did not beat the best value seen so far.
best_val_acc = 0
patience = 5
patience_counter = 0

# Per-epoch history, used later for the loss/accuracy curves.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

# Training loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, loss_func)
    val_loss, val_acc, _, _ = eval_epoch(model, val_loader, loss_func)
    scheduler.step()  # advance the LR schedule once per epoch

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}")

    # New best validation accuracy: checkpoint the weights and reset the counter.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pth")
        continue

    # No improvement this epoch: count it and stop once patience is exhausted.
    patience_counter += 1
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break


In [None]:
#Plotting Loss and Accuracy Curves for Train and Val
# Create the figure explicitly (a bare `plt.figure` without parentheses does
# nothing) and size it so the two side-by-side panels are readable.
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: loss per epoch.
ax_loss.plot(train_losses, label='Train Loss')
ax_loss.plot(val_losses, label='Val Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss Curve')
ax_loss.legend()

# Right panel: accuracy per epoch.
ax_acc.plot(train_accs, label='Train Acc')
ax_acc.plot(val_accs, label='Val Acc')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_title('Accuracy Curve')
ax_acc.legend()

fig.tight_layout()
plt.show()


In [None]:
# Reload the best checkpoint. map_location lets a checkpoint that was saved on
# the GPU be restored when the current runtime only has a CPU (and vice versa).
model.load_state_dict(torch.load("best_model.pth", map_location=device))
_, _, preds, labels = eval_epoch(model, val_loader, loss_func) #predicted and true labels on the validation set

# Pass the full set of class indices explicitly. Otherwise scikit-learn infers
# the classes from the data, and a class that never appears in the validation
# labels/predictions would make `target_names` and the tick labels mismatch.
class_ids = np.arange(num_classes)

print("Classification Report:")
print(classification_report(y_true=labels, y_pred=preds, labels=class_ids, target_names=class_names))

#Computing and plotting confusion matrix (rows: true label, columns: predicted label)
cm = confusion_matrix(y_true=labels, y_pred=preds, labels=class_ids)
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title("Confusion Matrix")
plt.colorbar()
plt.xticks(class_ids, class_names, rotation=45)
plt.yticks(class_ids, class_names)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.tight_layout()
plt.show()
