In [7]:
import torchvision.transforms as transforms

from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [8]:
import torchvision.transforms as transforms

In [9]:
transform = transforms.Compose(
    [
        transforms.Resize((224,224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5],[0.5])
    ]
)

In [10]:
trained_data = ImageFolder("data/train",transform=transform)

In [11]:
val_data = ImageFolder("data/val",transform=transform)
test_data = ImageFolder("data/test",transform=transform)

In [12]:
train_loader = DataLoader(trained_data,batch_size=32,shuffle=True)
val_loader = DataLoader(val_data,batch_size=32)
test_loader = DataLoader(test_data,batch_size=32)

In [13]:
import torch.nn as nn
import torchvision.models as models

In [5]:
#models.resnet18(pretrained=True)


In [25]:
import torch
model = models.resnet18(pretrained=False)
model.load_state_dict(torch.load('resnet18-f37072fd.pth'))

<All keys matched successfully>

In [26]:
# === Freeze all layers ===
for param in model.parameters():
    param.requires_grad = False


In [27]:

# === Replace final classifier ===
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, 1),
    nn.Sigmoid()
)


In [28]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [29]:
import torch.optim as optim
criterion = nn.BCELoss()
optimizer = optim.Adam(model.fc.parameters(),lr=0.001)

In [30]:
for epoch in range(5):  # Adjust number of epochs
    model.train()
    total_loss = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Training Loss: {total_loss/len(train_loader):.4f}")

    # Validation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.float().unsqueeze(1).to(device)
            outputs = model(images)
            preds = (outputs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    print(f"Validation Accuracy: {100 * correct / total:.2f}%")

    

Epoch 1, Training Loss: 0.5858
Validation Accuracy: 68.75%
Epoch 2, Training Loss: 0.4033
Validation Accuracy: 56.25%
Epoch 3, Training Loss: 0.3262
Validation Accuracy: 50.00%
Epoch 4, Training Loss: 0.2670
Validation Accuracy: 50.00%
Epoch 5, Training Loss: 0.2377
Validation Accuracy: 62.50%


In [31]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)
        outputs = model(images)
        preds = (outputs > 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")


Test Accuracy: 84.03%


In [32]:
torch.save(model.state_dict(), "resnet18_pneumonia_trained12.pth")
