In [1]:
import os
import random
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tempfile import TemporaryDirectory
import pandas as pd
import csv

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, random_split

import torchvision
from torchvision import datasets, models, transforms

In [2]:
# Load dataset from folders
class FolderDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Initialize the FolderDataset.

        Args:
            data_dir (str): Path to the dataset directory containing subdirectories for each class.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.classes = self._find_classes()
        self.img_names, self.image_paths, self.labels = self._load_data()

    def _find_classes(self):
        """
        Find class labels by identifying subdirectories in the dataset directory.

        Returns:
            list: List of class labels.
        """
        classes = sorted([d for d in os.listdir(self.data_dir) if os.path.isdir(os.path.join(self.data_dir, d))])
        return classes

    def _load_data(self):
        """
        Load data by iterating through each class subdirectory and collecting image paths and labels.

        Returns:
            tuple: A tuple containing lists of image names, image paths, and corresponding labels.
        """
        image_paths = []
        labels = []
        img_names = []
        
        for label in self.classes:
            class_dir = os.path.join(self.data_dir, label)
            for img_name in os.listdir(class_dir):
                img_path = os.path.join(class_dir, img_name)
                img_names.append(img_name)
                image_paths.append(img_path)
                labels.append(int(label))
        return img_names, image_paths, labels

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Retrieve a sample from the dataset by index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the image name, image data, and corresponding label.
        """
        img_name = self.img_names[idx]
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        label = int(label)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return img_name, image, label


In [3]:
def find_image_path(image_name, folder_path):
    """
    Find the path of an image file with the given name in the specified folder.

    Args:
        image_name (str): The name of the image file without extension.
        folder_path (str): The path to the folder containing the image file.

    Returns:
        str or None: The path of the image file if found, otherwise None.
    """
    # List of common image file extensions
    extensions = ['.jpg', '.jpeg', '.png']

    # Iterate through each extension to check if the image file exists
    for ext in extensions:
        img_path = os.path.join(folder_path, image_name + ext)
        if os.path.exists(img_path):
            return img_path
    
    # If the image file is not found with any of the extensions, return None
    return None


In [4]:
class FromCSVDataset(Dataset):
    def __init__(self, csv_file, data_dir, transform=None):
        """
        Initialize the FromCSVDataset.

        Args:
            csv_file (str): Path to the CSV file containing image information.
            data_dir (str): Path to the directory containing the image files.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.classes = self._find_classes(csv_file)
        self.image_names, self.labels = self._load_data(csv_file)

    def _find_classes(self, csv_file):
        """
        Find unique class labels from the CSV file.

        Args:
            csv_file (str): Path to the CSV file containing image information.

        Returns:
            list: List of unique class labels.
        """
        df = pd.read_csv(csv_file)
        classes = sorted(df['label'].unique())
        return classes

    def _load_data(self, csv_file):
        """
        Load image names and labels from the CSV file.

        Args:
            csv_file (str): Path to the CSV file containing image information.

        Returns:
            tuple: A tuple containing lists of image names and corresponding labels.
        """
        df = pd.read_csv(csv_file)
        image_names = df['image_name'].tolist()
        labels = df['label'].tolist()
        return image_names, labels

    def __len__(self):
        """
        Get the length of the dataset.

        Returns:
            int: Number of samples in the dataset.
        """
        return len(self.image_names)

    def __getitem__(self, idx):
        """
        Retrieve a sample from the dataset by index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: A tuple containing the image name, image data, and corresponding label.
        """
        image_name = str(self.image_names[idx])
        
        img_path = find_image_path(image_name, self.data_dir)
        
        if not img_path:
            raise FileNotFoundError(f"No valid image file found for {image_name} with supported extensions {['.jpg', '.jpeg', '.png']}")
        
        label = int(self.labels[idx])
        
        try:
            image = Image.open(img_path).convert("RGB")
        except FileNotFoundError:
            print(f"File not found: {img_path}")
            raise
        except UnidentifiedImageError:
            print(f"Cannot identify image file: {img_path}")
            raise
        except Exception as e:
            print(f"Unexpected error loading image {img_path}: {e}")
            raise

        if self.transform:
            try:
                image = self.transform(image)
            except Exception as e:
                print(f"Error applying transform to image {img_path}: {e}")
                raise

        return image_name, image, label


In [5]:
# Data transformation
transform = transforms.Compose([
        transforms.Resize((224, 224)),  
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

In [6]:
# Location of data and csv file
csv_file = '/kaggle/input/wb-data/wb_recognition_dataset/val/labels.csv'
data_dir = '/kaggle/input/wb-data/wb_recognition_dataset/val/images'

In [7]:
# Test dataset (use suitable Dataset class)
test_dataset = FromCSVDataset(csv_file, data_dir, transform)

In [8]:
# Dataloader iterators
test_dataloader = DataLoader(test_dataset, batch_size = 32, shuffle = False, num_workers=4, pin_memory = True)

In [9]:
# Check number of images and labels
test_dataset_size = len(test_dataset)
print('Number of images in test dataset: ', len(test_dataset))
print('Number of labels in test dataset: ',len(test_dataset.classes))

Number of images in test dataset:  1392
Number of labels in test dataset:  595


In [10]:
# Number of classes
num_classes = 2139  

In [11]:
pip install efficientnet-pytorch

  pid, fd = os.forkpty()


Collecting efficientnet-pytorch
  Downloading efficientnet_pytorch-0.7.1.tar.gz (21 kB)
  Preparing metadata (setup.py) ... [?25ldone
Building wheels for collected packages: efficientnet-pytorch
  Building wheel for efficientnet-pytorch (setup.py) ... [?25ldone
[?25h  Created wheel for efficientnet-pytorch: filename=efficientnet_pytorch-0.7.1-py3-none-any.whl size=16428 sha256=f848b3085307cf7a632aaa4bb5f1c0c52637483be231c83df71ee936b0b6ec36
  Stored in directory: /root/.cache/pip/wheels/03/3f/e9/911b1bc46869644912bda90a56bcf7b960f20b5187feea3baf
Successfully built efficientnet-pytorch
Installing collected packages: efficientnet-pytorch
Successfully installed efficientnet-pytorch-0.7.1
Note: you may need to restart the kernel to use updated packages.


In [12]:
# Load pretrained model
from efficientnet_pytorch import EfficientNet
model = EfficientNet.from_pretrained('efficientnet-b0')
num_ftrs = model._fc.in_features
model._fc = nn.Linear(num_ftrs, num_classes)

# Move to gpu 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth
100%|██████████| 20.4M/20.4M [00:00<00:00, 51.9MB/s]

Loaded pretrained weights for efficientnet-b0





EfficientNet(
  (_conv_stem): Conv2dStaticSamePadding(
    3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
    (static_padding): ZeroPad2d((0, 1, 0, 1))
  )
  (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
  (_blocks): ModuleList(
    (0): MBConvBlock(
      (_depthwise_conv): Conv2dStaticSamePadding(
        32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
        (static_padding): ZeroPad2d((1, 1, 1, 1))
      )
      (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_se_reduce): Conv2dStaticSamePadding(
        32, 8, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_se_expand): Conv2dStaticSamePadding(
        8, 32, kernel_size=(1, 1), stride=(1, 1)
        (static_padding): Identity()
      )
      (_project_conv): Conv2dStaticSamePadding(
        32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False
    

In [13]:
def test_model(model, dataloader, csv_result):
    # Set the model to evaluation mode
    model.eval()

    # Keep track of the number of correct predictions
    running_corrects = 0
    
    # A list to store the results
    results = []

    for image_names, inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Get the predicted labels
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        # Accumulate the number of correct predictions
        running_corrects += torch.sum(preds == labels.data)
        
        # Store the image names and predicted labels in the results list
        for image_name, pred in zip(image_names, preds):
            # use image_name.item() for csv
            results.append([image_name, pred.item()])
            
    # Save the results to a CSV file 
    try:
        with open(csv_result, 'w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(['image_name', 'label'])
            writer.writerows(results)
        print(f'Results saved successfully to {csv_result}')
    except Exception as e:
        print(f'Error saving results to CSV: {e}')

     # Calculate the accuracy
    epoch_acc = running_corrects.double() / test_dataset_size

    print(f'Accuracy: {epoch_acc:.4f}')
            

In [14]:
# Location of saved model 
saved_model_path = '/kaggle/input/efficientnetb0-1/pytorch/efficientnet_b0_1/4/efficientNetb0-imagenet-11-best.pt'

In [15]:
try:
    if torch.cuda.is_available():
        model.load_state_dict(torch.load(saved_model_path))
    else:
        model.load_state_dict(torch.load(saved_model_path, map_location=torch.device('cpu')))
    print('Loaded saved model successfully')
except FileNotFoundError:
    print('File not found')
except Exception as e:
    print(f'An error occurred: {e}')

Loaded saved model successfully


In [16]:
csv_result_path = '/kaggle/working/result_folder.csv'

In [17]:
try:
    test_model(model, test_dataloader, csv_result_path)
except Exception as e:
    print(f'An error occurred during model testing: {e}')
    

  self.pid = os.fork()
  self.pid = os.fork()


Results saved successfully to /kaggle/working/result_folder.csv
Accuracy: 0.9511
