In [None]:
from torchvision import transforms

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])


### Same test-time transforms as used by the original authors


In [2]:
import os

from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

# Paths: raw per-class images are read from `source_dir`; resized copies are written to
# `output_dir` (relative to the notebook's working directory) with the same
# <class>/<filename> layout, so the result can be loaded with ImageFolder.
source_dir = '/Users/kushal/2nd SEM classes /Deep learning /Final Project /blood_dataset_external/iran_2_classes/Test-B'
output_dir = 'iran_2_classes_preprocessed'

# Only the resize is applied before saving. Normalization is left to the load-time
# transform, so the saved files stay plain 8-bit RGB images.
resize = transforms.Resize((224, 224))
# Tensor pipeline kept for any later cell that uses `transform`; the saving loop below
# does not need it.
transform = transforms.Compose([
    resize,
    transforms.ToTensor()
])

classes = ['neutrophil', 'lymphocyte']

for cls in classes:
    class_dir = os.path.join(source_dir, cls)
    output_class_dir = os.path.join(output_dir, cls)
    if not os.path.isdir(class_dir):
        print(f"Directory for class '{cls}' not found in source directory.")
        continue
    # Create the output folder only for classes present in the source: an empty class
    # folder would make ImageFolder fail when the output directory is loaded later.
    os.makedirs(output_class_dir, exist_ok=True)

    # Sorted for a deterministic order; hidden files (e.g. macOS '.DS_Store') are skipped.
    img_names = sorted(name for name in os.listdir(class_dir) if not name.startswith('.'))
    for img_name in tqdm(img_names, desc=f"Processing {cls} images"):
        img_path = os.path.join(class_dir, img_name)
        save_path = os.path.join(output_class_dir, img_name)
        try:
            # `with` releases the source file handle as soon as the pixels are read.
            with Image.open(img_path) as img:
                resized = resize(img.convert('RGB'))
            # Resizing the PIL image directly avoids a ToTensor -> ToPILImage round trip,
            # whose float -> uint8 step truncates (.byte()) instead of rounding.
            # NOTE(review): saving under the original file name re-encodes lossy formats
            # such as JPEG at Pillow's default quality; consider a lossless format.
            resized.save(save_path)
        except Exception as e:
            # Best-effort: report the failing file and continue with the rest.
            print(f"Error processing {img_name}: {e}")


Processing neutrophil images: 100%|██████████| 1971/1971 [00:08<00:00, 225.36it/s]
Processing lymphocyte images: 100%|██████████| 148/148 [00:00<00:00, 245.66it/s]


## Load data — Iran (2-class) external test set


In [3]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

# Preprocessed images, laid out as <data_dir>/<class name>/<image file>.
data_dir = '/Users/kushal/2nd SEM classes /Deep learning /Final Project /MedViTV2/iran_2_classes_preprocessed'

# ImageFolder derives labels from the sorted sub-folder names (see iran_dataset.class_to_idx).
# NOTE(review): confirm these label indices line up with the model's output classes.
iran_dataset = ImageFolder(root=data_dir, transform=test_transform)

# Evaluation loader: no shuffling, so batches follow the order of iran_dataset.samples.
iran_loader = DataLoader(
    iran_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)


In [6]:
import torch

# Pick the compute device: Apple-silicon MPS first (the original target), then CUDA, and
# fall back to CPU so later `.to(device)` calls also work on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

mps


# Load the pre-trained model

In [None]:

import torch
from MedViT import MedViT_tiny  # Ensure this import matches your project structure

# Build the network that will receive the pre-trained weights.
# NOTE(review): the checkpoint below is named "bloodmnist_MedViT_small", while this builds
# MedViT_tiny with a 2-class head. If the checkpoint was saved from a different variant or
# with a different number of output classes, the strict load_state_dict call below raises
# on missing/unexpected keys or shape mismatches. Confirm the architecture and num_classes.
model = MedViT_tiny(num_classes=2)  # Adjust num_classes as needed

# NOTE(review): the file has a ".json" extension but is read with torch.load, which expects
# a PyTorch checkpoint. Confirm this is the intended file.
checkpoint_path = '/Users/kushal/2nd SEM classes /Deep learning /Final Project /MedViTV2/checkpoints/bloodmnist_MedViT_small.json'  # Replace with your actual path

# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which can execute
# code embedded in the file. Only load checkpoints from a trusted source; if the file holds
# only tensors and plain containers, weights_only=True is the safer choice.
# map_location='cpu' makes loading independent of the device the tensors were saved from.
checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

# Accept either a bare state_dict or a training checkpoint that nests it under a
# 'state_dict' or 'model' entry. A bare state_dict maps names to tensors, so neither
# lookup matches for it and it is used unchanged.
state_dict = checkpoint
if isinstance(checkpoint, dict):
    for wrapper_key in ('state_dict', 'model'):
        if isinstance(checkpoint.get(wrapper_key), dict):
            state_dict = checkpoint[wrapper_key]
            break

# Strict (default) loading: any key or shape mismatch raises a RuntimeError listing them.
model.load_state_dict(state_dict)

# Inference setup: move the model to the selected device and put layers such as dropout
# and batch-norm into evaluation mode. The trailing ';' keeps Jupyter from printing
# the full model repr.
model.to(device)
model.eval();

