In [1]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision

from torch.utils.data import Dataset, DataLoader, BatchSampler, random_split
from torchvision import transforms
from PIL import Image

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:
path = '/content/drive/MyDrive/CS4995/'

In [6]:
train_ann_df = pd.read_csv(path + 'train_data.csv')
# test_ann_df = pd.read_csv(path + 'test_data.csv')
super_map_df = pd.read_csv(path + 'superclass_mapping.csv')
sub_map_df = pd.read_csv(path + 'subclass_mapping.csv')

In [7]:
import zipfile
import os

# Define paths
base_path = '/content/drive/MyDrive/CS4995'
train_zip_path = os.path.join(base_path, 'train_images.zip')
test_zip_path = os.path.join(base_path, 'test_images.zip')

# Define extraction folders
train_extract_dir = os.path.join(base_path, 'train_images/train_images')
test_extract_dir = os.path.join(base_path, 'test_images/test_images')
novel_extract_dir = os.path.join(base_path, 'train_images/reptile_novel')

# # Unzip train images

In [8]:
train_img_dir = train_extract_dir
test_img_dir = test_extract_dir
novel_img_dir = novel_extract_dir

# image_preprocessing = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0), std=(1)),
# ])

# Filter to only keep rows where superclass is 2
train_ann_df_filtered = train_ann_df[train_ann_df['superclass_index'] == 2].reset_index(drop=True)
unique_subs = sorted(train_ann_df_filtered['subclass_index'].unique())
sub2new = {old: new for new, old in enumerate(unique_subs)}
# train_ann_df_filtered['sub_label'] = train_ann_df_filtered['subclass_index'].map(sub2new)
train_ann_df_filtered['label'] = train_ann_df_filtered['subclass_index'].map(sub2new)
num_subtypes = len(unique_subs)

In [9]:
num_subtypes

29

In [10]:
novel_filenames = os.listdir(novel_img_dir)
novel_filenames = [f for f in novel_filenames if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

# Create the novel DataFrame
novel_df = pd.DataFrame({
    'image': novel_filenames,
    'label': num_subtypes
})

# Optional: Add superclass_index and subclass_index columns as placeholder (-1 or NaN)
novel_df['superclass_index'] = 2
novel_df['subclass_index'] = -1
novel_df['image'] = "reptile_novel/" + novel_df['image']

train_ann_df_filtered['image'] = "train_images/" + train_ann_df_filtered['image']

# Concatenate with the filtered train DataFrame
train_ann_df_final = pd.concat([train_ann_df_filtered, novel_df.tail(500)], ignore_index=True)

In [11]:
train_ann_df_final

Unnamed: 0,image,superclass_index,subclass_index,description,label
0,train_images/5.jpg,2,61,"nature photograph of a reptile, specifically a...",18
1,train_images/6.jpg,2,57,"nature photograph of a reptile, specifically a...",16
2,train_images/18.jpg,2,3,"nature photograph of a reptile, specifically a...",1
3,train_images/19.jpg,2,50,"nature photograph of a reptile, specifically a...",13
4,train_images/26.jpg,2,47,"nature photograph of a reptile, specifically a...",11
...,...,...,...,...,...
2849,reptile_novel/c8b118de27.jpg,2,-1,,29
2850,reptile_novel/da7e4e32dc.jpg,2,-1,,29
2851,reptile_novel/efbf7fa21a.jpg,2,-1,,29
2852,reptile_novel/ca2645fc31.jpg,2,-1,,29


In [12]:
print((train_ann_df_final['label'].unique()))

[18 16  1 13 11 28 20 14 15 27 24 12  4  0 21 19  6 10  7 26 17 25 23  9
  2  3 22  8  5 29]


In [13]:
train_ann_df_filtered = train_ann_df_final

In [14]:
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import os
import torch

# Transformation: resize and convert to tensor (range [0, 1])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Accumulators
mean = torch.zeros(3)
std = torch.zeros(3)
total = 0

root_dir = base_path + '/train_images/'
for img_name in tqdm(train_ann_df_filtered['image']):
    img_path = os.path.join(root_dir, img_name)
    img = Image.open(img_path).convert('RGB')
    tensor = transform(img)  # Shape: [C, H, W]
    mean += tensor.mean(dim=(1, 2))
    std += tensor.std(dim=(1, 2))
    total += 1

# Average over all images
mean /= total
std /= total

print("Mean:", mean)
print("Std:", std)


  0%|          | 2/2854 [00:02<1:05:30,  1.38s/it]


KeyboardInterrupt: 

In [15]:
from torch.utils.data import random_split

In [16]:
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import torch

In [17]:
transform = transforms.Compose([
    transforms.Resize((456, 456)),
    # add
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomRotation(15),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # add
    transforms.ToTensor(),
    transforms.Normalize([0.4411, 0.4063, 0.3336],
                         [0.1785, 0.1688, 0.1633])
])
root_dir = base_path + '/train_images/'

class Reptile(Dataset):
    def __init__(self, df, transform=None, root_dir=""):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.root_dir = root_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_name = self.df.loc[idx, 'image']
        img_path = os.path.join(self.root_dir, img_name)
        label = self.df.loc[idx, 'label']

        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)

        return img, label



train_dataset = Reptile(df=train_ann_df_filtered, transform=transform, root_dir=root_dir)

total_len = len(train_dataset)
train_len = int(0.8 * total_len)
val_len = total_len - train_len
train_dataset, val_dataset = random_split(train_dataset, [train_len, val_len], generator=torch.Generator().manual_seed(42))

In [18]:
print(total_len)

2854


In [19]:
val_dataset

<torch.utils.data.dataset.Subset at 0x7d02363a8d10>

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

class TransferModel(nn.Module):
    def __init__(self, num_classes, include_none_class: bool = True):
        super().__init__()
        # Load pretrained MobileNetV2 and freeze its feature extractor
        base = mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        for p in base.features.parameters():
            p.requires_grad = False

        self.features    = base.features
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Decide how many outputs we need
        out_features = num_classes + 1 if include_none_class else num_classes

        # Build the classifier head
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(1280, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(256, out_features)  # No Softmax
        )

        # (Optional) store a label index for the "none" class
        if include_none_class:
            self.none_class_index = out_features - 1

    def forward(self, x):
        x = self.features(x)
        x = self.global_pool(x)
        probs = self.classifier(x)
        return probs

In [21]:
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=16, num_workers=8, pin_memory=True)

In [22]:
num_subtypes

29

In [23]:
import torch
import gc
from tqdm import tqdm

class Trainer():
    def __init__(self, model, criterion, optimizer, train_loader, val_loader, test_loader=None, device='cuda'):
        self.model = model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.device = device
        self.best_val_acc = 0

    def train_epoch(self):
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(self.train_loader, desc="Training", leave=False):
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            running_loss += loss.item()
            correct += (outputs.argmax(1) == labels).sum().item()
            total += labels.size(0)

        avg_loss = running_loss / len(self.train_loader)
        acc = correct / total * 100
        print(f'Training loss: {avg_loss:.4f} | Train Acc: {acc:.2f}% \n')

    def validate_epoch(self):
        self.model.eval()
        correct_seen = 0
        total_seen = 0
        correct_unseen = 0
        total_unseen = 0
        running_loss = 0.0

        with torch.no_grad():
            for inputs, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                inputs, labels = inputs.to(self.device), labels.to(self.device)

                # Get model predictions
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)

                # Separate accuracy calculations for seen (labels 0-28) and unseen (label 29)
                seen_mask = labels != 29
                unseen_mask = labels == 29

                # Accuracy for seen labels
                correct_seen += ((predicted == labels) & seen_mask).sum().item()
                total_seen += seen_mask.sum().item()

                # Accuracy for unseen labels
                correct_unseen += ((predicted == labels) & unseen_mask).sum().item()
                total_unseen += unseen_mask.sum().item()
                # if unseen_mask.sum().item() > 0:
                #     print(unseen_mask, predicted, labels)

                running_loss += loss.item()

                # Optionally clear memory
                # del inputs, labels, outputs, loss
                # torch.cuda.empty_cache()
                # gc.collect()

        # Calculate average loss and accuracies
        avg_loss = running_loss / len(self.val_loader)
        acc_seen = correct_seen / total_seen * 100 if total_seen > 0 else 0
        acc_unseen = correct_unseen / total_unseen * 100 if total_unseen > 0 else 0
        # print(correct_unseen)

        # Save the best model weights if accuracy improves
        if acc_seen > self.best_val_acc:
            self.best_val_acc = acc_seen
            torch.save(self.model.state_dict(), "best_model_weights.pth")

        # Print the results
        print(f'Validation loss: {avg_loss:.4f}')
        print(f'Seen Class Accuracy (0-28): {acc_seen:.2f}%')
        print(f'Unseen Class Accuracy (29): {acc_unseen:.2f}%')
        print(f'Overall Val Accuracy: {(correct_seen + correct_unseen) / (total_seen + total_unseen) * 100:.2f}%\n')



In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransferModel(num_classes=num_subtypes, include_none_class = True).to(device)

# # --- 5) loss & optimizer ---
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.classifier.parameters(), lr=1e-3)
# criterion = nn.BCELoss()
# optimizer = optim.Adam(model.classifier.parameters(), lr=1e-4)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 223MB/s]


In [25]:
trainer = Trainer(model, criterion, optimizer, train_loader, val_loader, device=device)

In [26]:
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import models
import os
from PIL import Image
import time
import gc
EPOCHS = 20

for epoch in range(EPOCHS):
    print(f'\nEpoch {epoch+1}/{EPOCHS}')
    start_time = time.time()

    trainer.train_epoch()
    trainer.validate_epoch()

    # torch.cuda.empty_cache()
    # gc.collect()

    end_time = time.time()
    duration = end_time - start_time
    print(f'Epoch time: {duration:.2f} seconds')

print('Finished Training')


Epoch 1/20




Training loss: 1.6663 | Train Acc: 55.45% 





Validation loss: 0.9287
Seen Class Accuracy (0-28): 71.34%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 76.53%

Epoch time: 285.55 seconds

Epoch 2/20




Training loss: 0.7795 | Train Acc: 79.11% 





Validation loss: 0.6019
Seen Class Accuracy (0-28): 79.96%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.54%

Epoch time: 6.78 seconds

Epoch 3/20




Training loss: 0.5962 | Train Acc: 82.48% 





Validation loss: 0.5277
Seen Class Accuracy (0-28): 82.97%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 85.99%

Epoch time: 6.54 seconds

Epoch 4/20




Training loss: 0.4914 | Train Acc: 84.45% 





Validation loss: 0.5565
Seen Class Accuracy (0-28): 79.74%
Unseen Class Accuracy (29): 100.00%
Overall Val Accuracy: 83.54%

Epoch time: 6.74 seconds

Epoch 5/20




Training loss: 0.4109 | Train Acc: 86.68% 





Validation loss: 0.5785
Seen Class Accuracy (0-28): 77.16%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 81.26%

Epoch time: 6.61 seconds

Epoch 6/20




Training loss: 0.3775 | Train Acc: 88.30% 





Validation loss: 0.6143
Seen Class Accuracy (0-28): 76.72%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 80.91%

Epoch time: 6.70 seconds

Epoch 7/20




Training loss: 0.3281 | Train Acc: 89.53% 





Validation loss: 0.5221
Seen Class Accuracy (0-28): 79.96%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.54%

Epoch time: 6.47 seconds

Epoch 8/20




Training loss: 0.2935 | Train Acc: 90.89% 





Validation loss: 0.4642
Seen Class Accuracy (0-28): 81.25%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 84.59%

Epoch time: 6.65 seconds

Epoch 9/20




Training loss: 0.2590 | Train Acc: 91.94% 





Validation loss: 0.6267
Seen Class Accuracy (0-28): 76.51%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 80.74%

Epoch time: 6.49 seconds

Epoch 10/20




Training loss: 0.2461 | Train Acc: 92.25% 





Validation loss: 0.5547
Seen Class Accuracy (0-28): 80.17%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.71%

Epoch time: 6.68 seconds

Epoch 11/20




Training loss: 0.2289 | Train Acc: 92.33% 





Validation loss: 0.6010
Seen Class Accuracy (0-28): 78.45%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 82.31%

Epoch time: 6.48 seconds

Epoch 12/20




Training loss: 0.2026 | Train Acc: 93.96% 





Validation loss: 0.5654
Seen Class Accuracy (0-28): 80.17%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.71%

Epoch time: 6.66 seconds

Epoch 13/20




Training loss: 0.2077 | Train Acc: 93.60% 





Validation loss: 0.5255
Seen Class Accuracy (0-28): 82.76%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 85.81%

Epoch time: 6.50 seconds

Epoch 14/20




Training loss: 0.2096 | Train Acc: 93.08% 





Validation loss: 0.6038
Seen Class Accuracy (0-28): 78.66%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 82.49%

Epoch time: 6.57 seconds

Epoch 15/20




Training loss: 0.2030 | Train Acc: 92.99% 





Validation loss: 0.5062
Seen Class Accuracy (0-28): 80.17%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.71%

Epoch time: 6.63 seconds

Epoch 16/20




Training loss: 0.1670 | Train Acc: 94.92% 





Validation loss: 0.5034
Seen Class Accuracy (0-28): 81.25%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 84.59%

Epoch time: 6.60 seconds

Epoch 17/20




Training loss: 0.1730 | Train Acc: 94.35% 





Validation loss: 0.4503
Seen Class Accuracy (0-28): 82.97%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 85.99%

Epoch time: 6.55 seconds

Epoch 18/20




Training loss: 0.1754 | Train Acc: 94.35% 





Validation loss: 0.4769
Seen Class Accuracy (0-28): 83.41%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 86.34%

Epoch time: 6.75 seconds

Epoch 19/20




Training loss: 0.1745 | Train Acc: 94.48% 





Validation loss: 0.6018
Seen Class Accuracy (0-28): 79.31%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 83.01%

Epoch time: 6.52 seconds

Epoch 20/20




Training loss: 0.1768 | Train Acc: 94.22% 



                                                           

Validation loss: 0.6168
Seen Class Accuracy (0-28): 78.45%
Unseen Class Accuracy (29): 99.07%
Overall Val Accuracy: 82.31%

Epoch time: 6.72 seconds
Finished Training




In [None]:
model.load_state_dict(torch.load("best_model_weights.pth"))
# model.eval()  # Set to evaluation mode

  model.load_state_dict(torch.load("best_model_weights.pth"))


<All keys matched successfully>

In [27]:
torch.save(model, "best_model.pt")


In [28]:
ls

best_model.pt  best_model_weights.pth  [0m[01;34mdrive[0m/  [01;34msample_data[0m/


In [3]:
import os
print(os.getcwd())


/content


In [4]:
import os
print(os.listdir(os.getcwd()))


['.config', 'drive', 'sample_data']
