In [18]:
import pandas as pd
import numpy as np
import ast

import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn, optim
from torchvision import datasets, transforms, utils
from torchsummary import summary

import matplotlib.pyplot as plt
from PIL import Image
import os

Data Cleaning

In [7]:
labels = pd.read_csv('data/small_data/labels.csv', index_col=0)
print(labels.shape)
labels.head()

(2218, 7)


Unnamed: 0,index,genes,sex,origin,price,birth,url
0,0-0,"['Yellow Belly', 'Pastel', 'Het Puzzle', '50% ...",female,Self Produced,650.0,7th April 2023,https://www.morphmarket.com/us/c/reptiles/pyth...
1,1-0,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...
2,1-1,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...
3,2-0,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...
4,2-1,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...


In [8]:
# 'Normal' if there is no genes
labels.loc[labels["genes"] == "[]", "genes"] = '["Normal"]'

In [9]:
clean_genes = []
list_genes = [ast.literal_eval(gene) for gene in labels['genes']]
for lst in list_genes:
    for element in lst:
        if "Het" in element:
            clean_genes.append('Het' + ' '.join(element.split('Het')[1:]))
        else:
            clean_genes.append(element)

clean_possible_genes = list(set(clean_genes))
print(f'Number of possible genes: {len(clean_possible_genes)}')
clean_possible_genes[:5]

Number of possible genes: 147


['Spotnose', 'Sugar', 'KRG', 'Leopard', 'Blade']

In [10]:
gene_extension_df = pd.DataFrame(np.zeros([labels.shape[0], len(clean_possible_genes)]), dtype=int, columns=clean_possible_genes)
labels_extended = pd.concat([labels, gene_extension_df], axis=1)
labels_extended.head()

Unnamed: 0,index,genes,sex,origin,price,birth,url,Spotnose,Sugar,KRG,...,Bamboo,Bongo,Desert Ghost,Het Citrus Hypo,Axanthic (TSK),Hypo,Axanthic (VPI),Orange Ghost,Stranger,Het Puzzle
0,0-0,"['Yellow Belly', 'Pastel', 'Het Puzzle', '50% ...",female,Self Produced,650.0,7th April 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1-0,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1-1,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,2-0,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2-1,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [11]:
len([len(lst) for lst in list_genes])

2218

In [12]:
start_row = 0
count = 0
for gene_col in clean_genes:
    labels_extended.loc[start_row, gene_col] = 1
    count += 1
    if count == [len(lst) for lst in list_genes][start_row]:
        start_row += 1
        count = 0
labels_extended.head()

Unnamed: 0,index,genes,sex,origin,price,birth,url,Spotnose,Sugar,KRG,...,Bamboo,Bongo,Desert Ghost,Het Citrus Hypo,Axanthic (TSK),Hypo,Axanthic (VPI),Orange Ghost,Stranger,Het Puzzle
0,0-0,"['Yellow Belly', 'Pastel', 'Het Puzzle', '50% ...",female,Self Produced,650.0,7th April 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,1
1,1-0,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,1-1,"['Piebald', 'Albino']",female,Self Produced,450.0,15th October 2023,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,2-0,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,2-1,"['Butter', 'Yellow Belly', 'Hurricane', 'Leopa...",male,Self Produced,750.0,1st May 2022,https://www.morphmarket.com/us/c/reptiles/pyth...,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [17]:
assert list(labels_extended.select_dtypes('int').sum(axis=1)) == [len(lst) for lst in list_genes]

## Load Data

In [61]:
image = Image.open("data/small_data/img/0-0.png")

# Convert the image to a tensor
transform = transforms.ToTensor()
tensor_image = transform(image)
tensor_image.shape

torch.Size([3, 1420, 1420])

In [62]:
class PythonGeneDataset(Dataset):
    def __init__(self, labels_df, img_dir, indices=None, transform=None):
        self.labels_df = labels_df
        if indices is not None:
            self.labels_df = self.labels_df.iloc[indices]
        self.img_dir = img_dir
        self.transform = transform
    
    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, f"{self.labels_df.iloc[idx, 0]}.png")
        image = Image.open(img_name)
        # Parse labels here based on your CSV structure and required format
        labels = torch.tensor(self.labels_df.iloc[idx, 7:].astype('float32').values)
        
        if self.transform:
            image = self.transform(image)

        return image, labels


In [63]:
IMAGE_SIZE = 32
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()])

full_dataset = PythonGeneDataset(labels_df=labels_extended, img_dir='data/small_data/img/', transform=transform)

# Split dataset
total_size = len(full_dataset)
train_size = int(0.8 * total_size)
valid_size = total_size - train_size
train_indices, valid_indices = torch.utils.data.random_split(np.arange(total_size), [train_size, valid_size])

# Create train and validation datasets
train_dataset = Subset(full_dataset, train_indices)
valid_dataset = Subset(full_dataset, valid_indices)

# Initialize DataLoaders
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [64]:
class PythonGeneClassifier(nn.Module):
    def __init__(self, num_classes):
        super(PythonGeneClassifier, self).__init__()
        # Increasing the complexity of the network
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25)  # Adding dropout for regularization
        )

        # Flatten layer is moved to the forward function
        self.classifier = nn.Sequential(
            nn.Linear(64 * 4 * 4, 512),  # Adjusted for 32x32 input images; calculate accordingly
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = self.classifier(x)
        return x

In [65]:
model = PythonGeneClassifier(num_classes=len(clean_possible_genes))
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters())

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Training on {device}')
model.to(device)

def train_model(model, criterion, optimizer, num_epochs):
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)

        # Calculate average loss for the epoch
        train_loss = train_loss / len(train_loader.dataset)

        # Validation of the model
        model.eval()  # Set model to evaluate mode
        valid_loss = 0.0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item() * inputs.size(0)

        # Calculate average loss over validation data
        valid_loss = valid_loss / len(valid_loader.dataset)

        # Print training/validation statistics
        print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}')

# Call to train the model
num_epochs = 5  # Set the number of epochs
train_model(model, criterion, optimizer, num_epochs)


Epoch 1/5, Train Loss: 0.1603, Valid Loss: 0.0830
Epoch 2/5, Train Loss: 0.0865, Valid Loss: 0.0819
Epoch 3/5, Train Loss: 0.0849, Valid Loss: 0.0811
Epoch 4/5, Train Loss: 0.0846, Valid Loss: 0.0815
Epoch 5/5, Train Loss: 0.0843, Valid Loss: 0.0814
