In [15]:
import torchvision
import torch
import torchvision.transforms as transforms
import PIL.Image as Image

In [69]:
# Output index -> human-readable label (index i of the model's output is classes[i]).
# NOTE(review): the order presumably has to match the label indices used in training
# (ImageFolder assigns indices by sorted class-folder name) — TODO confirm.
classes = ['Beer Bottle', 'Plastic Bottle', 'Soda Bottle', 'Water Bottle', 'Wine Bottle']

In [78]:
# Load the whole pickled model object (it is called and .eval()/.to()'d directly later,
# so the file holds a module, not a state_dict).
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine;
# classify() moves the model to the selected device before inference anyway.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files. On
# PyTorch >= 2.6 (weights_only=True by default), loading a whole pickled model needs
# weights_only=False; saving/loading a state_dict would avoid that.
model = torch.load('best_model.pth', map_location='cpu')

In [71]:
def get_mean_and_std(loader):
    """Estimate the per-channel mean and std of the images yielded by ``loader``.

    Each image's per-channel mean and (unbiased) standard deviation are computed
    over its pixels, then averaged over all images with equal weight.

    The mean equals the dataset-wide pixel mean when all images have the same
    spatial size. The returned "std" is the *average per-image* std. That is an
    approximation of the dataset-wide std because it ignores variation between
    images. It is kept as-is because the model's normalization constants
    presumably came from this same computation.

    Args:
        loader: iterable of ``(images, labels)`` batches, where ``images`` is a
            tensor of shape (batch, channels, ...), e.g. a ``DataLoader``.

    Returns:
        ``(mean, std)``: two 1-D tensors with one entry per channel.

    Raises:
        ValueError: if ``loader`` yields no images.
    """
    mean_sum = 0.
    std_sum = 0.
    total_images_count = 0
    for images, _ in loader:
        batch_size = images.size(0)
        # Flatten all non-channel dims so statistics cover every pixel of a channel.
        # reshape (unlike view) also accepts non-contiguous tensors.
        flat = images.reshape(batch_size, images.size(1), -1)
        mean_sum += flat.mean(2).sum(0)
        std_sum += flat.std(2).sum(0)  # unbiased (N-1) std per image and channel
        total_images_count += batch_size

    if total_images_count == 0:
        raise ValueError("loader yielded no images; cannot compute mean and std")

    return mean_sum / total_images_count, std_sum / total_images_count

In [72]:
# Per-channel normalization statistics, estimated from the training images
# after resizing and tensor conversion (no normalization applied yet).
training_dataset_path = './splitted_bootles/train'
training_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
training_dataset = torchvision.datasets.ImageFolder(
    root=training_dataset_path,
    transform=training_transforms,
)
training_loader = torch.utils.data.DataLoader(
    dataset=training_dataset,
    batch_size=32,
    shuffle=False,  # fixed order keeps the accumulated sums reproducible
)
mean_and_std = get_mean_and_std(training_loader)
print(mean_and_std)

(tensor([0.4729, 0.4099, 0.3521]), tensor([0.1786, 0.1670, 0.1610]))


In [79]:
# Inference preprocessing: the same resize + tensor conversion used for the
# training statistics, followed by normalization with the (mean, std) pair.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(*mean_and_std),
])

In [80]:
def set_device():
    """Return the torch device to run on: CUDA when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [81]:
def classify(model, image_transforms, image_path, classes):
    """Predict the class of one image file, print the label and return it.

    Args:
        model: torch module mapping a (1, C, H, W) batch to per-class scores of
            shape (1, len(classes)). It is switched to eval mode and moved to
            the selected device in place (``eval()``/``to()`` modify the module).
        image_transforms: callable turning a PIL image into a (C, H, W) tensor.
        image_path: path of the image file to classify.
        classes: sequence mapping an output index to a label.

    Returns:
        The predicted label ``classes[index]`` (also printed).
    """
    model = model.eval()

    # Convert to RGB so grayscale / RGBA / palette files still produce 3 channels,
    # matching ImageFolder's default loader used for the training images; the
    # context manager closes the file once the pixels have been read.
    with Image.open(image_path) as img:
        image = image_transforms(img.convert('RGB'))

    device = set_device()
    batch = image.unsqueeze(0).to(device)  # add the batch dimension
    model = model.to(device)

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = model(batch)
    predicted_index = output.argmax(dim=1).item()

    label = classes[predicted_index]
    print(label)
    return label
    

In [82]:
# Sample image to classify (relative to the notebook's working directory).
path = 'plastic.jpg'

In [83]:
classify(model=model, image_transforms=image_transforms, image_path=path, classes=classes);  # label is printed inside classify; ';' avoids a duplicate Out[] echo

Plastic Bottle
