In [2]:
import torch
from torchvision.models import resnet18, ResNet18_Weights 
from torchvision.transforms import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import os
from PIL import Image

# Preprocessing preset that ships with the default ResNet-18 weights, wrapped
# in a Compose so it can be handed to the dataset as a single callable.
model_transforms = ResNet18_Weights.DEFAULT.transforms()
transform = transforms.Compose([model_transforms])

# ResNet-18 initialised from the default pretrained weights.
model = resnet18(weights=ResNet18_Weights.DEFAULT)

class BallDataset(Dataset):
    """Image-classification dataset driven by a CSV of file names and labels.

    The CSV is read with ``pandas.read_csv``. Column 0 of each row is used as
    the image file name (joined onto ``img_dir``) and column 1 as the class
    label, which is converted to a ``torch.long`` tensor.

    Args:
        csv_path: Path to the CSV file listing the samples.
        img_dir: Directory the file names in column 0 are relative to.
        transform: Optional callable applied to the RGB image before it is
            returned. With ``None`` the converted image is returned as-is.
    """

    def __init__(self, csv_path, img_dir, transform = None):
        super().__init__()
        self.labels_frame = pd.read_csv(csv_path)
        self.image_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the label CSV."""
        return len(self.labels_frame)

    def __getitem__(self, index):
        """Load one sample by positional row index.

        Args:
            index: Positional row index into the label CSV.

        Returns:
            A ``(image, label)`` tuple. ``image`` is the RGB-converted image,
            passed through ``transform`` when one was supplied. ``label`` is a
            0-dimensional ``torch.long`` tensor holding the value of column 1.
        """
        img_name = self.labels_frame.iloc[index, 0]
        img_path = os.path.join(self.image_dir, img_name)

        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving the handle open until
        # the image object is garbage collected.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # Build the scalar label directly; same shape and dtype as wrapping the
        # value in a one-element list and squeezing it afterwards.
        label = torch.tensor(self.labels_frame.iloc[index, 1], dtype=torch.long)

        return image, label
    



In [None]:
# Practice split: the label CSV sits inside the same directory as the images.
test_dataset = BallDataset(
    "../data/practice-data/practice.csv",
    "../data/practice-data",
    transform,
)

test_dataloader = DataLoader(dataset=test_dataset, batch_size=4)

# Pull a single batch as a sanity check of the image tensor shape and labels.
img, label = next(iter(test_dataloader))
print(img.shape, label, sep="\n")


torch.Size([2, 3, 224, 224])
tensor([805, 852])


In [16]:
# Count how many predictions match their labels across the whole loader.
model.eval()  # put the module in evaluation mode before running inference
num_correct = 0
with torch.no_grad():  # no gradients are needed for evaluation
    for X, y in test_dataloader:
        # argmax over dim=1 already returns one class index per sample, so the
        # result is 1-D. A trailing squeeze() is not needed and would collapse
        # a batch of one to a 0-d tensor.
        preds = torch.argmax(model(X), dim=1)
        print(preds.shape)
        print(y.shape)
        num_correct += (preds == y).sum().item()

print(num_correct)
        
        

torch.Size([2])
torch.Size([2])
2
