In [None]:
import torch
import timm
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Class labels, listed in the exact index order the classifier was trained with.
# The prediction cell indexes this list with the argmax of the model output,
# so entry i must be the name of output class i.
class_names = [
    'Alternaria',
    'Anthracnose',
    'Bacterial_Blight',
    'Cercospora',
    'Healthy',
]

In [None]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Inference-time preprocessing. It is meant to mirror the validation transforms
# used during training ("val_transforms"); any mismatch would feed the model
# inputs from a different distribution than it was trained on.
_NORM_MEAN = [0.485, 0.456, 0.406]  # standard ImageNet per-channel mean (RGB)
_NORM_STD = [0.229, 0.224, 0.225]   # standard ImageNet per-channel std (RGB)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed 224x224; aspect ratio is not preserved
    transforms.ToTensor(),          # PIL image -> float CHW tensor scaled to [0, 1]
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

In [None]:
# Load the test image and turn it into a batch of one on the target device.
image_path = "sample.jpg"  # Replace with your test image

# Open inside a context manager so the underlying file handle is released
# immediately. `convert("RGB")` returns a new, fully loaded image, so `image`
# remains usable after the `with` block (it is displayed later).
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")

# `transform` yields a single CHW tensor; unsqueeze(0) adds the batch dimension.
input_tensor = transform(image).unsqueeze(0).to(device)

In [None]:
# Build the DaViT-Base classifier and restore the fine-tuned weights.
model_name = 'davit_base'
checkpoint_path = "models/final/DaViT_Base_Epoch_28.pth"

# pretrained=False: no pretrained download is needed because load_state_dict
# (strict by default) overwrites every parameter from the checkpoint below.
model = timm.create_model(model_name, pretrained=False, num_classes=len(class_names))
# NOTE(review): create_model(num_classes=...) should already size the classifier;
# this replacement presumably mirrors the training script so the module layout
# and state_dict keys match exactly — confirm before removing.
model.head.fc = torch.nn.Linear(model.head.in_features, len(class_names))

# SECURITY: torch.load unpickles the file and can execute arbitrary code.
# Only load checkpoints from a trusted source; on torch >= 1.13 consider
# passing weights_only=True to restrict loading to tensors/plain containers.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

# Switch to inference mode (dropout off, batch-norm uses running stats).
# The trailing ';' stops Jupyter from printing the full module repr that
# eval() returns.
model.eval();

In [None]:
# Forward pass without autograd bookkeeping.
with torch.no_grad():
    output = model(input_tensor)

# Pick the highest-scoring class along dim 1 (assumes output is
# (batch=1, num_classes) logits — TODO confirm). .item() converts the
# one-element tensor to a Python int so it can index class_names.
predicted_index = output.argmax(dim=1).item()
predicted_class = class_names[predicted_index]

In [None]:
# Show the input image titled with the predicted label, then echo the label as
# plain text so it also appears in text-only output.
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_title(f"Predicted Class: {predicted_class}")
ax.axis("off")
plt.show()

print(f"Predicted Class: {predicted_class}")