In [15]:
import torch
from captum.attr import IntegratedGradients
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader, Dataset

from PIL import Image
import matplotlib.pyplot as plt


In [16]:
#config

print(torch.__version__)
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")


2.5.0.dev20240619
Using device: mps


In [17]:
test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [18]:
test_dir = "../chest_xray/test"


In [19]:
test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transforms)


In [20]:
batch_size = 1


In [21]:
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


In [22]:
class PneumoniaScannerModel(torch.nn.Module):
    def __init__(self, num_classes=1):
        super(PneumoniaScannerModel, self).__init__()
        self.base_model = models.resnet50(weights="ResNet50_Weights.DEFAULT")

        for param in self.base_model.parameters():
            param.requires_grad = False

        num_out_ftrs = self.base_model.fc.out_features
        self.fc = torch.nn.Linear(num_out_ftrs, num_classes)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.base_model(x)
        x = self.fc(x)
        x = self.sigmoid(x)
        return x

model = PneumoniaScannerModel().to(device)


In [23]:
model_pickle_path = "../Model/PneumoniaScannerModel_wSampler.pth"


In [24]:
model.load_state_dict(torch.load(model_pickle_path))


<All keys matched successfully>

In [25]:

# Assuming your model is already defined and trained
model.eval()

dataiter = iter(test_loader)
batch = next(dataiter)
images, labels = batch

input_tensor = images.to(device)

# Initialize Integrated Gradients
ig = IntegratedGradients(model)

# Calculate attributions
attributions = ig.attribute(input_tensor, target=0, n_steps=200)

# Visualize attributions
attributions = attributions.squeeze().cpu().detach().numpy()
attributions = (attributions - attributions.min()) / (attributions.max() - attributions.min())  # Normalize to 0-1
plt.imshow(attributions.transpose(1, 2, 0))
plt.show()


RuntimeError: MPS backend out of memory (MPS allocated: 8.93 GB, other allocations: 2.38 MB, max allowed: 9.07 GB). Tried to allocate 306.25 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).