<a href="https://colab.research.google.com/github/ericodle/J_PlanktoNET/blob/main/VGG_finetuner_aug13.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [52]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# Set up paths and parameters
data_dir = '/content/drive/MyDrive/puafolder'
image_size = 224
batch_size = 32

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


from PIL import Image

def pil_loader(path):
    try:
        img = Image.open(path)
        return img.convert("RGB")
    except Exception as e:
        print(f"Error loading image: {path} - {e}")
        return None


import os
from torch.utils.data import Dataset
from torchvision import transforms

import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.loader = self.pil_loader
        self.samples = self.make_samples()
        self.class_to_idx = self.find_classes()

    def pil_loader(self, path):
        try:
            img = Image.open(path)
            return img.convert("RGB")
        except Exception as e:
            print(f"Error loading image: {path} - {e}")
            return None

    def make_samples(self):
        samples = []
        for class_name in os.listdir(self.data_dir):
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue  # Skip non-directory entries

            for filename in os.listdir(class_dir):
                img_path = os.path.join(class_dir, filename)
                if not os.path.isfile(img_path):
                    continue  # Skip non-file entries

                samples.append((img_path, class_name))
        return samples

    def find_classes(self):
        classes = [item[1] for item in self.samples]
        classes = list(set(classes))
        classes.sort()
        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
        return class_to_idx

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_path, target_name = self.samples[index]
        img = self.loader(img_path)

        if img is None:
            return None, None

        if self.transform is not None:
            img = self.transform(img)

        target = self.class_to_idx[target_name]
        return img, target


from torch.utils.data import DataLoader

# Define transformations
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Create CustomDataset instance
train_dataset = CustomDataset(data_dir=data_dir, transform=train_transform)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [48]:
import pickle

# Save train_dataset labels to a file
with open('/content/drive/MyDrive/pua_path/train_labels.pkl', 'wb') as f:
    pickle.dump(train_dataset.samples, f)

with open('/content/drive/MyDrive/pua_path/train_labels.pkl', 'rb') as f:
    saved_samples = pickle.load(f)

# Extract class names from saved_samples
class_labels = list(set(item[1] for item in saved_samples))

print(class_labels)

num_classes = len(class_labels)

['Zooplankton', 'Dinophysiales', 'ceratium', 'Hemiaulaceae', 'Dinoflagellate', 'Pennales', 'Ciliophora', 'radiolaria', 'Junk', 'pyrodinium', 'Noctilucales', 'copepods_exoskeleton_parts', 'nematoda', 'gonyaulax', 'tintinnida_alive', 'tintinnida_house', 'Ostreopsidaceae', 'Gymnodiniales', 'Peridiniales', 'Acantharea', 'Centrales', 'Oxytoxum', 'copepods_alive']


In [53]:
num_classes

23

In [62]:

import torchvision.models as models

#try resnet50 next

# Load pre-trained VGG16 model
base_model = models.vgg19(pretrained=True)
num_features = base_model.classifier[6].in_features
base_model.classifier[6] = nn.Linear(num_features, num_classes)
base_model = base_model.to(device)

epochs = 20

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(base_model.parameters(), lr=0.0001)

losses = []
# Training loop
for epoch in range(epochs):
    base_model.train()
    for inputs, labels in train_loader:
        non_none_indices = [i for i, item in enumerate(inputs) if item is not None]

        if len(non_none_indices) == 0:
            continue  # Skip batch with all problematic images

        inputs = [inputs[i] for i in non_none_indices]
        labels = [labels[i] for i in non_none_indices]

        inputs = torch.stack(inputs).to(device)
        labels = torch.tensor(labels).to(device)

        optimizer.zero_grad()
        outputs = base_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    print(f"Epoch [{epoch+1}/{epochs}] - Loss: {loss.item():.4f}")

import matplotlib.pyplot as plt

plt.plot(losses, label='Training Loss')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('Training Losses')
plt.legend()
plt.show()


# Save the trained model
torch.save(base_model.state_dict(), "/content/drive/MyDrive/pua_path/vgg16_fine_tuned.pth")


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 67.3MB/s]


AttributeError: ignored

In [61]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import os
import shutil

output_dir = '/content/drive/MyDrive/pua_test'

test_dir = '/content/drive/MyDrive/unknown'

# Set the image size for resizing and normalization
image_size = 224

# Define the transformation for the unsorted images
unsorted_data_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Load the saved model
test_model = models.vgg19(pretrained=False)
test_model.classifier[6] = nn.Linear(num_features, num_classes)
test_model.load_state_dict(torch.load('/content/drive/MyDrive/pua_path/vgg16_fine_tuned.pth'))
test_model.eval()

# Move the model to the appropriate device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
test_model = test_model.to(device)

# Get the class-to-index mapping from the dataset
class_to_idx = train_dataset.class_to_idx

# Retrieve the list of class labels
class_labels = list(class_to_idx.keys())

# Print the list of class labels
print(class_labels)

os.makedirs(output_dir, exist_ok=True)

# Iterate over the unsorted images
for filename in os.listdir(test_dir):
    image_path = os.path.join(test_dir, filename)

    try:
        img = Image.open(image_path)
    except (OSError, IOError) as e:
        print(f"Error opening {image_path}: {e}")
        continue  # Skip to the next image

    try:
        img = unsorted_data_transforms(img).unsqueeze(0).to(device)

        with torch.no_grad():
            outputs = test_model(img)
            _, preds = torch.max(outputs, 1)
            predicted_class = preds.item()

    except Exception as e:
        print(f"Error processing {image_path}: {e}")

    # Get the class name from the fine-tuning dataset
    predicted_class_name = class_labels[predicted_class]

    # Create the target class directory if it doesn't exist
    target_class_dir = os.path.join(output_dir, predicted_class_name)
    os.makedirs(target_class_dir, exist_ok=True)

    # Move the image to the corresponding class directory
    target_image_path = os.path.join(target_class_dir, filename)
    shutil.copy(image_path, target_image_path)

    print(f"Image {filename} moved to class {predicted_class_name}")




['Acantharea', 'Centrales', 'Ciliophora', 'Dinoflagellate', 'Dinophysiales', 'Gymnodiniales', 'Hemiaulaceae', 'Junk', 'Noctilucales', 'Ostreopsidaceae', 'Oxytoxum', 'Pennales', 'Peridiniales', 'Zooplankton', 'ceratium', 'copepods_alive', 'copepods_exoskeleton_parts', 'gonyaulax', 'nematoda', 'pyrodinium', 'radiolaria', 'tintinnida_alive', 'tintinnida_house']
Image D20230712T020618_IFCB108_01586.png moved to class Dinoflagellate
Image D20230712T031622_IFCB108_00510.png moved to class Hemiaulaceae
Image D20230712T022940_IFCB108_01092.png moved to class Dinoflagellate
Image D20230712T020618_IFCB108_00634.png moved to class Dinoflagellate
Image D20230711T231856_IFCB108_00151.png moved to class Pennales
Image D20230711T063656_IFCB108_00213.png moved to class ceratium
Image D20230711T072340_IFCB108_01509.png moved to class Pennales
Image D20230711T063656_IFCB108_01116.png moved to class Pennales
Image D20230711T063656_IFCB108_00563.png moved to class Hemiaulaceae
Image D20230711T061225_IFCB1