In [16]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch import save, load

In [23]:
class WhaleDataset(Dataset):
    def __init__(self, csv_file, image_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        self.image_names = self.data['image'].tolist()
        self.labels = self.data['species'].tolist()
        self.classes = self.data['species'].unique()
        self.encode = {k: i for i,k in enumerate(self.classes)}

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        image_name = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(image_name).convert('RGB')
        label = self.encode[self.labels[idx]]

        if self.transform:
            image = self.transform(image)

        return image, label

In [24]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [25]:
dataset = WhaleDataset(csv_file='filtered_train.csv', image_dir='train_images', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

In [26]:
class WhaleClassifier(nn.Module):
    def __init__(self, num_classes):
        super(WhaleClassifier, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3,padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 56 * 56, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.embedding = nn.Embedding(num_classes, num_classes)

    def forward(self, x, labels=None):
        out = self.conv1(x)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv3(out)
        out = self.relu(out)
        out = out.view(out.size(0), -1)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)

        if labels is not None:
            out = self.embedding(labels).squeeze(1)

        return out

In [27]:
num_classes = dataset.data['species'].nunique()
model = WhaleClassifier(num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [28]:
num_epochs = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs, labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    print(f'Epoch: {epoch+1}, Loss: {epoch_loss:.4f}')

Epoch: 1, Loss: 3.1922


In [29]:
with open('model_state.pt', 'wb') as f:
    save(model.state_dict(), f)

In [33]:
model = WhaleClassifier(num_classes)
model.load_state_dict(torch.load('model_state.pt'))

model.eval()

WhaleClassifier(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu): ReLU()
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=200704, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=30, bias=True)
  (embedding): Embedding(30, 30)
)

In [34]:
class InferenceDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_names = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        image_name = os.path.join(self.image_dir, self.image_names[idx])
        image = Image.open(image_name).convert('RGB')


        if self.transform:
            image = self.transform(image)

        return image

In [39]:
inference_dataset = InferenceDataset(image_dir='inference_images', transform=transform)
inference_loader = DataLoader(inference_dataset, batch_size=32, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

with torch.no_grad():
    for images in inference_loader:
        images = images.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

reverse_encode = {v:k for k, v in dataset.encode.items()}
predicted_labels = [reverse_encode[label.item()] for label in predicted]

tensor([[[[-0.8849, -0.8849, -0.8849,  ..., -1.2445, -1.2103, -1.1760],
          [-0.9363, -1.0048, -1.0733,  ..., -1.2788, -1.2103, -1.2274],
          [-0.9877, -1.0904, -1.1760,  ..., -1.2274, -1.2103, -1.1760],
          ...,
          [-1.4500, -1.4672, -1.5185,  ..., -1.6213, -1.6384, -1.6384],
          [-1.2959, -1.4843, -1.4843,  ..., -1.6042, -1.6213, -1.6384],
          [-1.5528, -1.6042, -1.3987,  ..., -1.5357, -1.5185, -1.5357]],

         [[-0.1800, -0.2150, -0.2500,  ..., -0.6352, -0.6176, -0.5826],
          [-0.2675, -0.3375, -0.3901,  ..., -0.6352, -0.6001, -0.6001],
          [-0.2850, -0.4076, -0.4601,  ..., -0.6352, -0.6176, -0.5826],
          ...,
          [-1.0378, -1.0028, -1.1253,  ..., -1.1078, -1.1253, -1.1429],
          [-0.8452, -1.0553, -1.1078,  ..., -1.0903, -1.1253, -1.1429],
          [-1.1429, -1.2129, -0.9503,  ..., -1.0378, -1.0378, -1.0728]],

         [[ 0.6531,  0.6356,  0.6182,  ...,  0.3045,  0.3393,  0.3742],
          [ 0.5834,  0.5485,  

In [38]:
predicted_labels

['bottlenose_dolpin',
 'bottlenose_dolpin',
 'frasiers_dolphin',
 'bottlenose_dolpin',
 'false_killer_whale',
 'false_killer_whale',
 'false_killer_whale',
 'false_killer_whale',
 'false_killer_whale',
 'frasiers_dolphin',
 'frasiers_dolphin',
 'false_killer_whale',
 'frasiers_dolphin',
 'frasiers_dolphin',
 'frasiers_dolphin',
 'bottlenose_dolpin',
 'false_killer_whale',
 'bottlenose_dolpin',
 'false_killer_whale',
 'false_killer_whale',
 'bottlenose_dolpin',
 'bottlenose_dolpin',
 'bottlenose_dolpin',
 'frasiers_dolphin']

In [40]:
inference_dataset

<__main__.InferenceDataset at 0x168a0ce80>