In [255]:
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score
from PIL import Image

In [256]:
IMAGE_DIR = "train"
NUM_IMAGES = 5488

In [257]:
def get_image_path(image_id):
    return os.path.join(IMAGE_DIR, f"img_{image_id:06d}.jpg")

def open_image(image_path):
    try:
        return Image.open(image_path)
    except IOError:
        print(f"Error opening image: {image_path}")
        return None

In [258]:
class CreateModel:
    def __init__(
        self, 
        model_class,
        model_params,
        get_training_data,
        test_size=0.2,
        random_state=42,
    ):
        self.model = model_class(**model_params)
        self.features, self.labels = get_training_data(random_state = random_state)
        self.test_size = test_size
        self.random_state = random_state

    def training_test_split(self):
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.features, 
            self.labels, 
            test_size = self.test_size, 
            random_state = self.random_state
        )
        return self.X_train, self.X_test, self.y_train, self.y_test

    def train(self):
        self.training_test_split()
        self.model.fit(self.X_train, self.y_train)

    def evaluate(self, score_function):
        y_pred = self.model.predict(self.X_test)
        score = score_function(self.y_test, y_pred)
        print(f"Score: {score:.4f}")
        return score


In [259]:
def load_image_data():
    data = []

    for i in range(1, NUM_IMAGES + 1):
        image = open_image(get_image_path(i))
        if image:
            flat = np.array(image).flatten()
            row = {"image_path": image.filename.split("/")[-1]}
            # Add one column per pixel
            row.update({f"pixel_{j}": flat[j] for j in range(len(flat))})
            data.append(row)

    images = pd.DataFrame(data)
    return images
            

    # images = np.stack(images)

def get_training_data(random_state = 0):
    # Features
    additional_features = pd.read_csv("train/Features/additional_features.csv")
    color_histogram = pd.read_csv("train/Features/color_histogram.csv")
    hog_pca = pd.read_csv("train/Features/hog_pca.csv")

    # Image data
    images = load_image_data()

    # Combine features on image_path
    # features = pd.merge(additional_features, color_histogram, on="image_path")
    # features = pd.merge(features, hog_pca, on="image_path")
    # features = pd.merge(features, images, on="image_path")
    features = images.copy()

    features = features.drop(columns=["image_path"])
    print(features)

    train_metadata = pd.read_csv("train/train_metadata.csv")
    labels = train_metadata["ClassId"]

    return features, labels

def weighted_f1(y_true, y_pred):
    return f1_score(y_true, y_pred, average='weighted')

def main():
    model = CreateModel(
        model_class=RandomForestClassifier,
        model_params={
            "n_estimators": 100,
        },
        get_training_data = get_training_data,
        test_size=0.2,

    )

    # model.train()
    # model.evaluate(weighted_f1)

if __name__ == "__main__":
    main()

      pixel_0  pixel_1  pixel_2  pixel_3  pixel_4  pixel_5  pixel_6  pixel_7  \
0          81       70       71       75       66       69       65       58   
1          44       34       27       46       36       29       46       36   
2          72       95      114       60       73       91       77       81   
3          24       19       20       25       24       22       16       18   
4          61       75       75       54       73       72       55       79   
...       ...      ...      ...      ...      ...      ...      ...      ...   
5483       56       50       49       56       49       48       58       50   
5484      150      142      156      144      136      149      142      134   
5485       30       31       24       40       41       29       31       32   
5486      217      220      229      213      219      226      211      221   
5487      103      117      135      102      117      138      106      122   

      pixel_8  pixel_9  ...  pixel_1191