In [23]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch import save, load

class WhaleDataset(Dataset):
    def __init__(self, csv_file, image_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        self.image_names = self.data['image'].tolist()
        self.labels = self.data['species'].tolist()
        self.classes = self.data['species'].unique()
        self.encode = {k: i for i,k in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        image_name = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(image_name).convert('RGB')
        label = self.encode[self.labels[idx]]

        if self.transform:
            image = self.transform(image)

        return image, label

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = WhaleDataset(csv_file='filtered_train.csv', image_dir='train_images', transform=transform)
num_classes = dataset.data['species'].nunique()

class InferenceDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_names = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(image_name).convert('RGB')


        if self.transform:
            image = self.transform(image)

        return image, image_name

class WhaleClassifier(nn.Module):
    def __init__(self, num_classes):
        super(WhaleClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3,padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 56 * 56, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.embedding = nn.Embedding(num_classes, num_classes)

    def forward(self, x, labels=None):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv3(out)
        out = self.relu(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)

        if labels is not None:
            out = self.embedding(labels).squeeze(1)

        return out

model = WhaleClassifier(num_classes)
model.load_state_dict(torch.load('model_state.pt'))

model.eval()

WhaleClassifier(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=200704, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=30, bias=True)
  (embedding): Embedding(30, 30)
)

In [25]:
correct, total = 0, 0

df = pd.read_csv('filtered_train.csv')

In [28]:
inference_dataset = InferenceDataset(image_dir='inference_images', transform=transform)
inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

with torch.no_grad():
    for images, image_names in inference_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        reverse_encode = {v: k for k, v in dataset.encode.items()}
        predicted_labels = [reverse_encode[label.item()] for label in predicted]

        # Process the predictions and image names here
        for image_name, predicted_label in zip(image_names, predicted_labels):
            # Do something with the image name and predicted label
            # print(f"Image: {image_name}, Predicted Label: {predicted_label}")
            image_name = image_name.replace('inference_images/', '')
            actual_label = df.loc[df['image'] == image_name, 'species'].values[0]

            if actual_label == predicted_label:
                correct += 1

            total += 1

print(f'')

melon_headed_whale
beluga
cuviers_beaked_whale
beluga
beluga
beluga
blue_whale
spinner_dolphin
pantropic_spotted_dolphin
bottlenose_dolphin
spinner_dolphin
bottlenose_dolphin
humpback_whale
humpback_whale
false_killer_whale
beluga
beluga
bottlenose_dolpin
kiler_whale
humpback_whale
long_finned_pilot_whale
humpback_whale
killer_whale
blue_whale
