In [1]:
import glob
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage
import skimage.io as io
import torch
from torch.utils.data import DataLoader as torch_dataloader
from torch.utils.data import Dataset as torch_dataset
#%%
class MyDataset(torch_dataset):
    """Map-style dataset pairing image files on disk with integer class labels.

    Parameters
    ----------
    path : str
        Prefix prepended verbatim to every filename (keep the trailing separator,
        e.g. 'C:/S224/').
    filenamelist : sequence of str
        Image filenames relative to ``path``.
    labellist : sequence of int
        Class index for each file, aligned element-by-element with ``filenamelist``.

    Raises
    ------
    ValueError
        If ``filenamelist`` and ``labellist`` differ in length.
    """

    def __init__(self, path, filenamelist, labellist):
        if len(filenamelist) != len(labellist):
            raise ValueError(
                f"filenamelist has {len(filenamelist)} entries but labellist has {len(labellist)}"
            )
        self.path = path
        self.filenamelist = filenamelist
        self.labellist = labellist

    def __len__(self):
        """Return the number of data points."""
        return len(self.filenamelist)

    def __getitem__(self, idx):
        """Load sample ``idx`` as ``(image, label)``.

        ``image`` is a contiguous float32 tensor of shape (3, H, W); ``label`` is an
        int64 scalar tensor.
        """
        image = io.imread(self.path + self.filenamelist[idx])
        # img_as_float32 rescales integer pixel types (e.g. uint8 0..255) into [0, 1]
        image = skimage.util.img_as_float32(image)
        label = torch.tensor(self.labellist[idx], dtype=torch.int64)
        return self._to_three_channel(image), label

    @staticmethod
    def _to_three_channel(image):
        """Convert an image array to a contiguous (3, H, W) float32 tensor.

        Accepts (H, W) grayscale arrays and channel-last (H, W, 1) / (H, W, 3)
        arrays. Single-channel input is replicated into three identical channels.

        Raises
        ------
        ValueError
            For any other shape.
        """
        tensor = torch.as_tensor(np.asarray(image), dtype=torch.float32)
        if tensor.ndim == 2:
            tensor = tensor.unsqueeze(0)
        elif tensor.ndim == 3 and tensor.shape[-1] in (1, 3):
            tensor = tensor.permute(2, 0, 1)  # HWC -> CHW
        else:
            raise ValueError(
                f"expected an (H, W), (H, W, 1) or (H, W, 3) image, got shape {tuple(tensor.shape)}"
            )
        if tensor.shape[0] == 1:
            # repeat (not expand) allocates real memory, so in-place ops on the sample are safe
            tensor = tensor.repeat(3, 1, 1)
        return tensor.contiguous()
#%%
def get_dataloader(csv_path='C:/S224/train.csv', path='C:/S224/', batch_size=32,
                   num_workers=None, shuffle=True, pin_memory=None):
    """Build the training DataLoader from a CSV index of image files.

    Parameters
    ----------
    csv_path : str
        CSV with a ``filename`` column (relative to ``path``) and a ``label`` column.
    path : str
        Prefix prepended to every filename; keep the trailing separator.
    batch_size : int
        Samples per batch.
    num_workers : int or None
        Loader worker processes. ``None`` picks 0 on Windows/macOS (see comment
        below) and 2 elsewhere.
    shuffle : bool
        Reshuffle the data every epoch.
    pin_memory : bool or None
        ``None`` enables pinned memory only when CUDA is available.

    Returns
    -------
    torch.utils.data.DataLoader
        Yields ``(images, labels)`` batches built from ``MyDataset``.

    Raises
    ------
    KeyError
        If the CSV lacks a ``filename`` or ``label`` column.
    """
    if num_workers is None:
        # Windows and macOS start DataLoader workers with `spawn`, which must re-import the
        # Dataset class by name; a class defined in a notebook's __main__ cannot be found that
        # way and the workers fail, so load in the main process on those platforms.
        num_workers = 0 if sys.platform in ('win32', 'darwin') else 2
    if pin_memory is None:
        # pinned host memory only benefits host-to-GPU copies, so skip it without CUDA
        pin_memory = torch.cuda.is_available()

    df_train = pd.read_csv(csv_path)
    missing = {'filename', 'label'} - set(df_train.columns)
    if missing:
        raise KeyError(f"{csv_path} is missing column(s): {sorted(missing)}")

    dataset_train = MyDataset(path, df_train['filename'].values, df_train['label'].values)
    return torch_dataloader(dataset_train, batch_size=batch_size, num_workers=num_workers,
                            shuffle=shuffle, pin_memory=pin_memory)

In [2]:
#Construct a CNN by modifying ResNet-18 or ResNet-50 for binary classification

import torch
import torch.nn as nn
import torchvision.models as models

class ModifiedResNet(nn.Module):
    """torchvision ResNet whose final fully connected layer outputs ``num_classes`` logits.

    Parameters
    ----------
    num_classes : int
        Size of the new classification head (2 for binary classification).
    arch : str, default "resnet18"
        One of "resnet18", "resnet34", "resnet50", "resnet101", "resnet152".
    pretrained : bool, default False
        Initialise the backbone with torchvision's default ImageNet weights
        (needed for transfer learning); False keeps a random initialisation.

    Raises
    ------
    ValueError
        If ``arch`` is not a supported ResNet variant.
    """

    _ARCHS = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")

    def __init__(self, num_classes, arch="resnet18", pretrained=False):
        super().__init__()
        if arch not in self._ARCHS:
            raise ValueError(f"unsupported arch {arch!r}; expected one of {self._ARCHS}")
        builder = getattr(models, arch)
        # e.g. "resnet18" -> models.ResNet18_Weights (present in torchvision >= 0.13)
        weights_enum = getattr(models, f"ResNet{arch[len('resnet'):]}_Weights", None)
        if weights_enum is not None:
            self.resnet = builder(weights=weights_enum.DEFAULT if pretrained else None)
        else:
            # older torchvision only understands the (now deprecated) `pretrained` flag
            self.resnet = builder(pretrained=pretrained)
        # attribute name `resnet` is kept so saved state_dict keys ("resnet.*") still load
        num_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return class logits from the wrapped ResNet."""
        return self.resnet(x)

ImportError: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (C:\Users\matth\Miniconda3\envs\mam900\lib\site-packages\charset_normalizer\constant.py)

In [None]:
#train the CNN from scratch

import torch.optim as optim

def train_from_scratch(model, dataloader, num_epochs, lr=0.001):
    """Train ``model`` with Adam and cross-entropy, printing loss and accuracy per epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier returning logits of shape (batch, num_classes); moved to CUDA
        when available.
    dataloader : iterable of (inputs, labels)
        ``labels`` are int64 class indices.
    num_epochs : int
        Number of passes over ``dataloader``.
    lr : float, default 0.001
        Adam learning rate.

    Returns
    -------
    list of dict
        One ``{"loss": float, "accuracy": float}`` per epoch. ``loss`` is the mean
        per-sample training loss over the epoch.

    Raises
    ------
    ValueError
        If an epoch yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # enable training-mode behaviour (batch-norm statistics, dropout) even if the model was in eval mode
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(num_epochs):
        loss_sum = 0.0
        correct = 0
        total = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # the criterion averages over the batch; re-weight so a short final batch counts proportionally
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            raise ValueError("dataloader yielded no samples")
        epoch_loss = loss_sum / total
        epoch_acc = correct / total
        history.append({"loss": epoch_loss, "accuracy": epoch_acc})

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.4f}")
    return history

In [None]:
#train the CNN using transfer learning

def train_with_transfer_learning(model, dataloader, num_epochs, lr=0.001):
    """Fine-tune a model whose weights were already initialised (e.g. pretrained).

    Uses Adam and cross-entropy and prints loss and accuracy per epoch. Only parameters
    with ``requires_grad=True`` are optimised. Callers can therefore freeze the backbone
    (set ``requires_grad_(False)`` on its parameters) to train just the head.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier returning logits of shape (batch, num_classes); moved to CUDA
        when available.
    dataloader : iterable of (inputs, labels)
        ``labels`` are int64 class indices.
    num_epochs : int
        Number of passes over ``dataloader``.
    lr : float, default 0.001
        Adam learning rate.

    Returns
    -------
    list of dict
        One ``{"loss": float, "accuracy": float}`` per epoch. ``loss`` is the mean
        per-sample training loss over the epoch.

    Raises
    ------
    ValueError
        If an epoch yields no samples, or if no parameter is trainable.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # enable training-mode behaviour (batch-norm statistics, dropout) even if the model was in eval mode
    model.train()

    criterion = nn.CrossEntropyLoss()
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr=lr)
    history = []

    for epoch in range(num_epochs):
        loss_sum = 0.0
        correct = 0
        total = 0

        for inputs, labels in dataloader:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # the criterion averages over the batch; re-weight so a short final batch counts proportionally
            loss_sum += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        if total == 0:
            raise ValueError("dataloader yielded no samples")
        epoch_loss = loss_sum / total
        epoch_acc = correct / total
        history.append({"loss": epoch_loss, "accuracy": epoch_acc})

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Accuracy: {epoch_acc:.4f}")
    return history

In [None]:
#visualize the two models (2)&(3) using two CAM methods (GradCAM and EigenCAM)
# Both CAMs are implemented below with forward/backward hooks, so no extra CAM package is needed.

import torch

# Checkpoints saved after training the two models; keys are used in the plot titles.
CHECKPOINTS = {
    'From Scratch': 'model_from_scratch.pth',
    'Transfer Learning': 'model_transfer_learning.pth',
}
IMAGE_PATH = 'path_to_image.jpg'  # Replace with the path to your image


def load_trained_model(checkpoint_path, num_classes=2):
    """Rebuild a ModifiedResNet, load its saved state dict and switch it to eval mode."""
    model = ModifiedResNet(num_classes=num_classes)
    # map_location lets GPU-trained checkpoints load on a CPU-only machine
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model.eval()


def _scale_cam(cam, size):
    """Resize a (1, 1, h, w) map bilinearly to ``size`` and min-max scale it into [0, 1] (NumPy)."""
    cam = torch.nn.functional.interpolate(cam, size=tuple(size), mode='bilinear',
                                          align_corners=False)[0, 0]
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return cam.cpu().numpy()


def grad_cam(model, target_layer, input_batch, target_class=None):
    """Grad-CAM heat map for the first image of ``input_batch`` (N, C, H, W).

    Returns an (H, W) array in [0, 1]. ``target_class`` defaults to the predicted class.
    """
    store = {}

    def save_activation(_module, _inputs, output):
        store['activations'] = output
        output.register_hook(lambda grad: store.__setitem__('gradients', grad))

    # requiring grad on the input guarantees a gradient path even if the backbone is frozen
    input_batch = input_batch.clone().requires_grad_(True)
    handle = target_layer.register_forward_hook(save_activation)
    try:
        model.zero_grad()
        logits = model(input_batch)
        if target_class is None:
            target_class = int(logits[0].argmax())
        logits[0, target_class].backward()
    finally:
        handle.remove()
    # channel weights = spatially averaged gradients of the chosen class score
    weights = store['gradients'][:1].mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * store['activations'][:1]).sum(dim=1, keepdim=True))
    return _scale_cam(cam.detach(), input_batch.shape[-2:])


def eigen_cam(model, target_layer, input_batch):
    """Eigen-CAM heat map for the first image of ``input_batch``.

    Projects that image's ``target_layer`` activations onto their first principal
    component. Returns an (H, W) array in [0, 1].
    """
    store = {}
    handle = target_layer.register_forward_hook(
        lambda _module, _inputs, output: store.__setitem__('activations', output.detach()))
    try:
        with torch.no_grad():
            model(input_batch)
    finally:
        handle.remove()
    activations = store['activations'][0]  # (C, h, w)
    channels, height, width = activations.shape
    flat = activations.reshape(channels, height * width).T  # one row per spatial location
    flat = flat - flat.mean(dim=0, keepdim=True)
    _, _, vh = torch.linalg.svd(flat, full_matrices=False)
    projection = flat @ vh[0]
    # singular vectors have arbitrary sign; orient the map to agree with the mean activation
    if torch.dot(projection, activations.mean(dim=0).reshape(-1)) < 0:
        projection = -projection
    return _scale_cam(projection.reshape(1, 1, height, width), input_batch.shape[-2:])


# Reuse the training-time loader so the image gets exactly the preprocessing the models saw
image_tensor, _ = MyDataset('', [IMAGE_PATH], [0])[0]
input_batch = image_tensor.unsqueeze(0)
display_image = image_tensor.permute(1, 2, 0).numpy()  # (H, W, 3) for imshow

fig, axs = plt.subplots(len(CHECKPOINTS), 3, figsize=(12, 4 * len(CHECKPOINTS)), squeeze=False)
for row, (model_name, checkpoint_path) in enumerate(CHECKPOINTS.items()):
    model = load_trained_model(checkpoint_path)
    target_layer = model.resnet.layer4  # last convolutional stage of the ResNet backbone
    cams = {
        'GradCAM': grad_cam(model, target_layer, input_batch),
        'EigenCAM': eigen_cam(model, target_layer, input_batch),
    }

    axs[row, 0].imshow(display_image)
    axs[row, 0].set_title('Original Image')
    for col, (method, cam) in enumerate(cams.items(), start=1):
        axs[row, col].imshow(display_image)
        # the CAM was resized to the image size, so both layers share the same extent
        axs[row, col].imshow(cam, alpha=0.5, cmap='jet')
        axs[row, col].set_title(f'{method} ({model_name})')
    for ax in axs[row]:
        ax.axis('off')

plt.tight_layout()
plt.show()