In [12]:
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
import os


In [13]:
def create_image_dataframe(directory):
    image_paths = []
    labels = []
    for label in ['PNEUMONIA', 'NORMAL']:
        subfolder_path = os.path.join(directory, label)
        for filename in os.listdir(subfolder_path):
            file_path = os.path.join(subfolder_path, filename)
            image_paths.append(file_path)
            labels.append(label)
    
    df = pd.DataFrame({
        'image_path': image_paths,
        'label': labels
    })
    return df


In [14]:
directory = "C:\\Users\\Vatsal\\Documents\\project_pneumonia\\pneumonia_dataset\\chest_xray\\train"
df = create_image_dataframe(directory)

label_map = {'PNEUMONIA': 0, 'NORMAL': 1}
df['label'] = df['label'].map(label_map)

model_name_or_path = 'google/vit-base-patch16-224-in21k'
vit_feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
model = ViTForImageClassification.from_pretrained(model_name_or_path, num_labels=2)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
class CustomImageDataset(Dataset):
    def __init__(self, dataframe, feature_extractor):
        self.dataframe = dataframe
        self.feature_extractor = feature_extractor
    
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['image_path']
        label = self.dataframe.iloc[idx]['label']
        
        image = Image.open(img_path).convert("RGB")
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        
        return {
            'pixel_values': inputs['pixel_values'].squeeze(),  
            'label': torch.tensor(label, dtype=torch.long)
        }

In [16]:
train_df = df.sample(frac=0.75, random_state=42)
test_df = df.drop(train_df.index)

train_dataset = CustomImageDataset(train_df, vit_feature_extractor)
test_dataset = CustomImageDataset(test_df, vit_feature_extractor)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

In [17]:

for epoch in range(10): 
    model.train()
    running_loss = 0.0
    for batch in train_loader:
        inputs = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        outputs = model(pixel_values=inputs)
        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

Epoch 1, Loss: 0.14056869612907674
Epoch 2, Loss: 0.05988417710853714
Epoch 3, Loss: 0.03331004919013835
Epoch 4, Loss: 0.020098323670238758
Epoch 5, Loss: 0.018952239647122948
Epoch 6, Loss: 0.019254430187180455
Epoch 7, Loss: 0.0018411772641075245
Epoch 8, Loss: 0.0008188466303877091
Epoch 9, Loss: 0.00048089469015238337
Epoch 10, Loss: 0.00035028869275704545


In [18]:
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        print(f"Inputs shape: {inputs.shape}")
        print(f"Logits shape: {outputs.logits.shape}")
        print(f"Labels shape: {labels.shape}")

        inputs = batch['pixel_values'].to(device)
        labels = batch['label'].to(device)
        
        outputs = model(pixel_values=inputs)
        _, predicted = torch.max(outputs.logits, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total}%")

Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Logits shape: torch.Size([8, 2])
Labels shape: torch.Size([8])
Inputs shape: torch.Size([8, 3, 224, 224])
Log

In [19]:
model_save_path = "vit_pneumonia_predictor.pth"
torch.save(model.state_dict(), model_save_path)
print (f"Model saved to {model_save_path}")

Model saved to vit_pneumonia_predictor.pth
